In [None]:
# default_exp classification.modelling.heads

In [None]:
# hide
# Notebook-only extensions (no `# export` marker): nb_black auto-formats cells,
# autoreload re-imports edited library modules before each cell runs.
%load_ext nb_black
%load_ext autoreload
%autoreload 2

<IPython.core.display.Javascript object>

In [None]:
# hide
# Notebook-only setup (no `# export` marker): nbdev helpers such as `show_doc`
# and `notebook2script`, both used further down.
import warnings

from nbdev.export import *
from nbdev.showdoc import *

# NOTE(review): this silences every warning for the whole session; narrow the
# filter if a specific warning ever matters.
warnings.filterwarnings("ignore")

<IPython.core.display.Javascript object>

# Heads
> A head is a regular `torch.nn.Module` that can be attached to a backbone.


> Note: For image classification, a `head` typically contains a pooling layer followed by the classifier.

In [None]:
# export
from typing import *

import torch
from fastcore.all import L, ifnone, store_attr
from timm.models.layers.classifier import _create_fc, _create_pool
from torch import nn

from gale.classification.modelling.backbones import filter_weight_decay
from gale.core.classes import GaleModule
from gale.core.logging import setup_logger
from gale.core.nn import ACTIVATION_REGISTRY
from gale.core.nn.shape_spec import ShapeSpec
from gale.core.nn.utils import trainable_params
from gale.core.structures import IMAGE_CLASSIFIER_HEADS

<IPython.core.display.Javascript object>

In [None]:
# export
# Module-level logger for the exported module.
# NOTE(review): not referenced by anything else in this notebook.
_logger = setup_logger()

<IPython.core.display.Javascript object>

In [None]:
# export
class ImageClassificationHead(GaleModule):
    """
    Abstract base class for image-classification heads.

    `get_lrs` relies on `build_param_dicts`, which the concrete heads implement.
    """

    def __init__(self):
        """
        Subclasses are free to declare their own `__init__` arguments.
        """
        super().__init__()

    def get_lrs(self) -> List:
        """
        Return the learning rate of every parameter group, in order.

        Schedulers such as `torch.optim.lr_scheduler.OneCycleLR` need the max
        lr of each param group; this collects the `"lr"` entry of every dict
        returned by `build_param_dicts`.
        """
        return [group["lr"] for group in self.build_param_dicts()]

<IPython.core.display.Javascript object>

In [None]:
# Render the documentation of `ImageClassificationHead.get_lrs`.
show_doc(ImageClassificationHead.get_lrs)

<h4 id="ImageClassificationHead.get_lrs" class="doc_header"><code>ImageClassificationHead.get_lrs</code><a href="__main__.py#L13" class="source_link" style="float:right">[source]</a></h4>

> <code>ImageClassificationHead.get_lrs</code>()

Returns a List containining the Lrs' for
each parameter group. This is required to build schedulers
like `torch.optim.lr_scheduler.OneCycleScheduler` which needs
the max lrs' for all the Param Groups.

<IPython.core.display.Javascript object>

In [None]:
# export
@IMAGE_CLASSIFIER_HEADS.register()
class FullyConnectedHead(ImageClassificationHead):
    """
    Classifier head w/ configurable global pooling and dropout.
    From - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/classifier.py

    Layout: global pool -> (optional) dropout -> fc (a `Linear`, or a conv layer when `use_conv=True`).
    """

    def __init__(
        self,
        input_shape: ShapeSpec,
        num_classes: int,
        pool_type: str = "avg",
        drop_rate: float = 0.0,
        use_conv: bool = False,
        lr: float = 2e-03,
        wd: float = 0,
        filter_wd: bool = False,
    ):
        """
        Args:
            input_shape: shape of the backbone output; only `channels` is read.
            num_classes: number of output classes.
            pool_type: timm global pooling type (e.g. "avg", "catavgmax").
            drop_rate: dropout probability between pooling and fc; 0 disables dropout.
            use_conv: use a conv layer instead of a `Linear` as the final fc layer.
            lr: learning rate for this head's parameter group(s).
            wd: weight decay for this head's parameter group(s).
            filter_wd: if True, build the param groups with `filter_weight_decay`.
        """
        super(FullyConnectedHead, self).__init__()
        self.drop_rate = drop_rate
        in_planes = input_shape.channels
        # fmt: off
        self.global_pool, num_pooled_features = _create_pool(in_planes, num_classes, pool_type, use_conv=use_conv)
        # fmt: on
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
        # With `use_conv=True` timm's pool keeps the spatial dims, so the conv fc
        # output must be flattened to (batch, num_classes) afterwards.
        self.flatten_after_fc = use_conv and pool_type

        store_attr("lr, wd, filter_wd")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Pool `x`, optionally apply dropout, then run the fc layer.
        """
        x = self.global_pool(x)
        if self.drop_rate:
            x = nn.functional.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        if self.flatten_after_fc:
            x = torch.flatten(x, 1)
        return x

    def build_param_dicts(self) -> Any:
        """
        Return the optimizer param group(s) for this head using `self.lr` and `self.wd`.

        With `filter_wd=True` the groups come from `filter_weight_decay`; otherwise
        a single group holds `trainable_params(self)`.
        """
        if self.filter_wd:
            ps = filter_weight_decay(self, lr=self.lr, weight_decay=self.wd)
        else:
            # fmt: off
            ps = [{"params": trainable_params(self),"lr": self.lr,"weight_decay": self.wd}]
            # fmt: on
        return ps

<IPython.core.display.Javascript object>

Arguments to `FullyConnectedHead`:
- `input_shape` (ShapeSpec): input shape
- `num_classes` (int): Number of classes for the head.
- `pool_type` (str): The pooling layer to use. Check [here](https://github.com/rwightman/pytorch-image-models/blob/9a1bd358c7e998799eed88b29842e3c9e5483e34/timm/models/layers/adaptive_avgmax_pool.py#L79).
- `drop_rate` (float): If > 0.0, applies dropout between the pooling layer and the fc layer.
- `use_conv` (bool): Use a convolutional layer as the final fc layer.
- `lr` (float): Learning rate for the modules.
- `wd` (float): Weight decay for the modules.
- `filter_wd` (bool): Filter out `bias`, `bn` from `weight_decay`.

In [None]:
# Build a FullyConnectedHead for a 512-channel backbone output and 10 classes.
input_shape = ShapeSpec(channels=512)
tst = FullyConnectedHead(input_shape, 10)
tst

FullyConnectedHead(
  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=True)
  (fc): Linear(in_features=512, out_features=10, bias=True)
)

<IPython.core.display.Javascript object>

In [None]:
# hide
# Pooling + fc must map (batch, channels, H, W) -> (batch, num_classes).
o = tst(torch.randn(2, 512, 2, 2))
assert o.shape == (2, 10)
o.shape

torch.Size([2, 10])

<IPython.core.display.Javascript object>

In [None]:
# export
@IMAGE_CLASSIFIER_HEADS.register()
class FastaiHead(ImageClassificationHead):
    """
    Model head that pools the backbone features, runs them through `lin_ftrs`, and outputs `num_classes` classes.

    Layout: pool -> for every block [BatchNorm (optional for the first block)] -> Dropout -> Linear
    -> activation (the last block has no activation unless `bn_final=True`) -> optional final BatchNorm.

    From -
    https://github.com/fastai/fastai/blob/8b1da8765fc07f1232c20fa8dc5e909d2835640c/fastai/vision/learner.py#L76
    """

    def __init__(
        self,
        input_shape: ShapeSpec,
        num_classes: int,
        act: str = "ReLU",
        lin_ftrs: Optional[List] = None,
        ps: Union[List, float] = 0.5,
        concat_pool: bool = True,
        first_bn: bool = True,
        bn_final: bool = False,
        lr: float = 2e-03,
        wd: float = 0,
        filter_wd: bool = False,
    ):
        """
        Args:
            input_shape: shape of the backbone output; only `channels` is read.
            num_classes: number of output classes.
            act: name of an activation in `ACTIVATION_REGISTRY`; None means "ReLU".
            lin_ftrs: hidden sizes of the Linear layers (defaults to [512]).
            ps: dropout probability per block, or a single value that is expanded
                to `p / 2` for every block but the last, which uses `p`.
            concat_pool: use "catavgmax" pooling instead of "avg".
            first_bn: whether the first block starts with a BatchNorm layer.
            bn_final: append a final BatchNorm (and give the last block an activation).
            lr: learning rate for this head's parameter group(s).
            wd: weight decay for this head's parameter group(s).
            filter_wd: if True, build the param groups with `filter_weight_decay`.

        Raises:
            ValueError: if fewer dropout probabilities than blocks are given.
        """
        super(FastaiHead, self).__init__()
        in_planes = input_shape.channels
        pool_type = "catavgmax" if concat_pool else "avg"
        # `nf` is the number of pooled features (doubled by "catavgmax").
        pool, nf = _create_pool(in_planes, num_classes, pool_type, use_conv=False)

        # fmt: off
        lin_ftrs = [nf, 512, num_classes] if lin_ftrs is None else [nf] + list(lin_ftrs) + [num_classes]
        # fmt: on
        n_blocks = len(lin_ftrs) - 1

        # Only the first block's BatchNorm can be switched off.
        bns = [first_bn] + [True] * (n_blocks - 1)

        ps = list(L(ps))
        if len(ps) == 1:
            # Single value `p`: hidden blocks use p / 2, the last block uses p.
            ps = [ps[0] / 2] * (n_blocks - 1) + ps
        if len(ps) < n_blocks:
            # zip() below would silently drop blocks and the head would not end at `num_classes`.
            raise ValueError(
                f"Expected {n_blocks} dropout probabilities (one per block), got {len(ps)}: {ps}"
            )

        act = ifnone(act, "ReLU")
        act_cls = ACTIVATION_REGISTRY.get(act)
        # The same activation instance is reused for every hidden block; the last block gets none
        # unless `bn_final` is set.
        actns = [act_cls(inplace=True)] * (n_blocks - 1) + [None]
        if bn_final:
            actns[-1] = act_cls(inplace=True)

        layers = [pool]
        for ni, no, bn, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], bns, ps, actns):
            if bn:
                layers.append(nn.BatchNorm1d(ni))
            layers.append(nn.Dropout(p))
            # As in fastai's LinBnDrop, the Linear only gets a bias when its block has no BatchNorm.
            layers.append(nn.Linear(ni, no, bias=not bn))
            if actn is not None:
                layers.append(actn)

        if bn_final:
            layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
        self.layers = nn.Sequential(*layers)

        store_attr("lr, wd, filter_wd")

    def forward(self, xb: torch.Tensor) -> Any:
        """
        Run `xb` through the pooling layer and all blocks.
        """
        return self.layers(xb)

    def build_param_dicts(self) -> Any:
        """
        Return the optimizer param group(s) for `self.layers` using `self.lr` and `self.wd`.

        With `filter_wd=True` the groups come from `filter_weight_decay`; otherwise
        a single group holds `trainable_params(self.layers)`.
        """
        if self.filter_wd:
            ps = filter_weight_decay(self.layers, lr=self.lr, weight_decay=self.wd)
        else:
            # fmt: off
            ps = [{"params": trainable_params(self.layers),"lr": self.lr,"weight_decay": self.wd}]
            # fmt: on
        return ps

<IPython.core.display.Javascript object>

The head begins with concatenated average + max pooling (`pool_type="catavgmax"`, like fastai's `AdaptiveConcatPool2d`) if `concat_pool=True`; otherwise it uses traditional average pooling. The pooled features are flattened and then passed through blocks of `BatchNorm`, `Dropout` and `Linear` layers.

Those blocks start at the number of pooled features (`in_planes`, or `2 * in_planes` when `concat_pool=True`), go through every element of `lin_ftrs` (defaults to `[512]`) and end at `num_classes`. `ps` is a list of dropout probabilities, one per block. If you pass a single value `p`, every block except the last uses `p / 2` and the last block uses `p`.

Arguments to `FastaiHead`:
- `input_shape` (ShapeSpec): input shape
- `num_classes` (int): Number of classes for the head.
- `act` (str): name of the activation function to use. If `None`, `ReLU` is used; otherwise the name must be registered in `ACTIVATION_REGISTRY`. An activation layer follows every block (`BatchNorm`, `Dropout` and `Linear` layers) except the last one, which also gets an activation when `bn_final=True`.
- `lin_ftrs` (List): Features of the Linear layers. (defaults to [512])
- `ps` (List or float): dropout probabilities, one per block. A single value is expanded as described above.
- `concat_pool` (bool): Whether to use concatenated average + max pooling (`catavgmax`, like `AdaptiveConcatPool2d`) instead of plain average pooling (`avg`).
- `first_bn` (bool): Whether the first block starts with a BatchNorm layer (right after pooling).
- `bn_final` (bool): Append a final BatchNorm layer after the last Linear layer.
- `lr` (float): Learning rate for the modules.
- `wd` (float): Weight decay for the modules.
- `filter_wd` (bool): Filter out `bias`, `bn` from `weight_decay`.

In [None]:
# Build a FastaiHead with default settings for a 512-channel input and 10 classes.
input_shape = ShapeSpec(channels=512)
tst = FastaiHead(input_shape=input_shape, num_classes=10)
tst

FastaiHead(
  (layers): Sequential(
    (0): SelectAdaptivePool2d (pool_type=catavgmax, flatten=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=False)
    (4): ReLU(inplace=True)
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Dropout(p=0.5, inplace=False)
    (7): Linear(in_features=512, out_features=10, bias=False)
  )
)

<IPython.core.display.Javascript object>

In [None]:
# hide
# Pooling + all blocks must map (batch, channels, H, W) -> (batch, num_classes).
o = tst(torch.randn(2, 512, 2, 2))
assert o.shape == (2, 10)
o.shape

torch.Size([2, 10])

<IPython.core.display.Javascript object>

### Instantiation using a config:

In [None]:
from dataclasses import dataclass, field
from omegaconf import OmegaConf, DictConfig, MISSING

<IPython.core.display.Javascript object>

In [None]:
# Structured config mirroring the `FastaiHead.__init__` arguments (all except `input_shape`),
# with the same defaults.
@dataclass
class HeadConf:
    # MISSING: must be supplied when the config is created.
    num_classes: int = MISSING
    act: str = "ReLU"
    lin_ftrs: Optional[List] = None
    # NOTE(review): typed `Any` presumably because OmegaConf cannot express Union[List, float].
    ps: Any = 0.5
    concat_pool: bool = True
    first_bn: bool = True
    bn_final: bool = False
    lr: float = 0.002
    wd: float = 0
    filter_wd: bool = False


conf = OmegaConf.structured(HeadConf(num_classes=10))

# `input_shape` is not part of the config, so it is passed separately (reusing the one defined above).
# `from_config_dict` comes from `GaleModule` (not shown here).
tst = FastaiHead.from_config_dict(conf, input_shape=input_shape)
tst

FastaiHead(
  (layers): Sequential(
    (0): SelectAdaptivePool2d (pool_type=catavgmax, flatten=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=False)
    (4): ReLU(inplace=True)
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Dropout(p=0.5, inplace=False)
    (7): Linear(in_features=512, out_features=10, bias=False)
  )
)

<IPython.core.display.Javascript object>

## Export -

In [None]:
# hide
# Convert the notebooks' `# export` cells into the library modules.
notebook2script()

Converted 00_core.logging.ipynb.
Converted 00a_core.structures.ipynb.
Converted 00b_core.visualize.ipynb.
Converted 01_core.nn.utils.ipynb.
Converted 01a_core.nn.losses.ipynb.
Converted 02_core.nn.optim.optimizers.ipynb.
Converted 02a_core.nn.optim.lr_schedulers.ipynb.
Converted 03_core.config.ipynb.
Converted 03a_core.classes.ipynb.
Converted 04_classification.modelling.backbones.ipynb.
Converted 04a_classification.modelling.heads.ipynb.
Converted 04b_classification.modelling.meta_arch.ipynb.
Converted 05_collections.pandas.ipynb.
Converted 06a_collections.callbacks.notebook.ipynb.
Converted 06b_collections.callbacks.ema.ipynb.
Converted index.ipynb.


<IPython.core.display.Javascript object>