In [None]:
# hide
# skip
!git clone https://github.com/benihime91/gale # install gale on colab
!pip install -e "gale[dev]"

In [None]:
# default_exp classification.model.backbones

In [None]:
# hide
%load_ext nb_black
%load_ext autoreload
%autoreload 2
%matplotlib inline

<IPython.core.display.Javascript object>

In [None]:
# hide
import warnings

from nbdev.export import *
from nbdev.showdoc import *
from timm.utils import *

warnings.filterwarnings("ignore")

setup_default_logging()

<IPython.core.display.Javascript object>

# Backbones 
> Backbones/feature extractors for use in Image Classification Tasks

In [None]:
# export
import logging
import re
from collections import namedtuple
from dataclasses import dataclass
from typing import *

import timm
import torch
from fastcore.all import store_attr, use_kwargs_dict
from omegaconf import MISSING
from timm.models import ResNet
from torch import nn

from gale.core_classes import BasicModule
from gale.torch_utils import build_discriminative_lrs, set_bn_eval, trainable_params
from gale.utils.activs import ACTIVATION_REGISTRY
from gale.utils.shape_spec import ShapeSpec
from gale.utils.structures import IMAGE_CLASSIFIER_BACKBONES

_logger = logging.getLogger(__name__)

<IPython.core.display.Javascript object>

In [None]:
# export
# names listed here are presumably picked up by nbdev and added to the generated
# module's `__all__`, so that `IMAGE_CLASSIFIER_BACKBONES` (imported above from
# gale.utils.structures) is re-exported from this module — confirm against nbdev
_all_ = ["IMAGE_CLASSIFIER_BACKBONES"]

<IPython.core.display.Javascript object>

In [None]:
# hide
from omegaconf import MISSING, DictConfig, OmegaConf
from fastcore.test import *

<IPython.core.display.Javascript object>

## Utility functions

In [None]:
# export
def _is_pool_type(l: nn.Module) -> bool:
    """
    True if `l` is a pooling layer.

    The check is purely name based: it matches any module whose class name
    ends in `Pool1d`, `Pool2d` or `Pool3d` (e.g. `MaxPool2d`, `AdaptiveAvgPool2d`).
    From: https://github.com/fastai/fastai/blob/master/fastai/vision/learner.py#L76
    """
    # `re.search` returns a Match object or None; compare against None so the
    # function really returns a `bool` as annotated.
    return re.search(r"Pool[123]d$", l.__class__.__name__) is not None


def has_pool_type(m: nn.Module) -> bool:
    """
    Recursively check whether `m` is, or contains, a pooling layer.

    Returns `True` when `m` itself is a pooling layer or when any module among
    its (nested) children is one, and `False` otherwise.
    From: https://github.com/fastai/fastai/blob/master/fastai/vision/learner.py#L76
    """
    if _is_pool_type(m):
        return True
    # depth-first search over the children; `any` stops at the first match
    return any(has_pool_type(child) for child in m.children())

<IPython.core.display.Javascript object>

In [None]:
# export
def prepare_backbone(model: nn.Module, cut=None):
    """
    Cut off the body of a typically pretrained `model` as determined by `cut`.

    Args:
        model: the module whose direct children are sliced.
        cut: where to cut the model.
            - `None`: cut just before the last direct child that is (or
              contains) a pooling layer.
            - `int`: keep `list(model.children())[:cut]` (negative values
              count from the end).
            - callable: `cut(model)` is called and its result is returned.

    Returns:
        The truncated model (an `nn.Sequential` for `None` / `int` cuts).

    Raises:
        ValueError: if `cut` is `None` and no child of `model` has a pooling layer.
        TypeError: if `cut` is neither `None`, an `int`, nor a callable.
    """
    children = list(model.children())
    if cut is None:
        # search from the end so the cut lands before the *last* pooling child
        cut = next(
            (i for i in reversed(range(len(children))) if has_pool_type(children[i])),
            None,
        )
        if cut is None:
            raise ValueError(
                "Could not find a pooling layer among the children of `model`; "
                "pass `cut` explicitly."
            )
    if isinstance(cut, int):
        return nn.Sequential(*children[:cut])
    if callable(cut):
        return cut(model)
    # the original referenced an undefined `NamedError`, which surfaced as a NameError
    raise TypeError("cut must be either integer or a function")

<IPython.core.display.Javascript object>

In [None]:
# fmt: off
tst = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5), nn.AvgPool2d(1), nn.Linear(3, 4))

# cut=None: cut before the last child containing a pooling layer (AvgPool2d, index 2)
m = prepare_backbone(tst)
test_eq(len(m), 2)

# an integer cut keeps children[:cut]
m = prepare_backbone(tst, cut=3)
test_eq(len(m), 3)

# negative integer cuts count from the end (drops the final Linear)
m = prepare_backbone(tst, cut=-1)
test_eq(len(m), 3)

<IPython.core.display.Javascript object>

In [None]:
# export
def filter_weight_decay(
    model: nn.Module,
    lr: float,
    weight_decay: float = 1e-5,
    skip_list=(),
) -> List[Dict]:
    """
    Split the trainable parameters of `model` into two optimizer param groups.

    A parameter is exempt from weight decay when it is 1-d (biases, norm-layer
    weights, ...), when its name ends with `.bias`, or when its name is in
    `skip_list`. Exempt parameters go into a group with `weight_decay=0.0`;
    every other trainable parameter goes into a group with
    `weight_decay=weight_decay`. Both groups use the learning rate `lr`.
    Parameters with `requires_grad=False` are left out entirely.

    Returns:
        `[no_decay_group, decay_group]`, each a dict with `params`,
        `weight_decay` and `lr` keys.

    Modified from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/optim_factory.py
    """
    # key True -> exempt from weight decay, key False -> decayed
    buckets = {True: [], False: []}
    for pname, tensor in model.named_parameters():
        if not tensor.requires_grad:
            # frozen weights are not handed to the optimizer at all
            continue
        exempt = len(tensor.shape) == 1 or pname.endswith(".bias") or pname in skip_list
        buckets[exempt].append(tensor)

    no_decay_group = {"params": buckets[True], "weight_decay": 0.0, "lr": lr}
    decay_group = {"params": buckets[False], "weight_decay": weight_decay, "lr": lr}
    return [no_decay_group, decay_group]

<IPython.core.display.Javascript object>

## ImageClassificationBackbone -

In [None]:
# export
class ImageClassificationBackbone(BasicModule):
    """
    Abstract class for ImageClassification BackBones.

    Subclasses are expected to implement `forward`, `output_shape` and
    `build_param_dicts`. `build_param_dicts` is used by `hypers` and should return
    a list of optimizer param-group dicts with at least `lr` and `weight_decay` keys.
    """

    # container returned by `hypers`: `lr` and `wd` hold one entry per param group
    _hypers = namedtuple("hypers", field_names=["lr", "wd"])

    def filter_params(self, parameters: List[Dict]) -> List[Dict]:
        """
        Filters out any empty parameter groups in `parameters`.

        The `params` entry of each group may be a single tensor, a list/tuple
        or any other iterable of tensors (e.g. the generator returned by
        `nn.Module.parameters()`). It is materialized into a list so that
        emptiness can be checked reliably and the group can be iterated more
        than once. An empty generator or tuple never compares equal to `[]`.
        The input dicts are not modified; shallow copies with the
        materialized `params` list are returned.
        """
        filtered = []
        for group in parameters:
            params = group["params"]
            # mirror torch.optim, which accepts a single tensor as `params`
            if isinstance(params, torch.Tensor):
                params = [params]
            else:
                params = list(params)
            if params:
                filtered.append({**group, "params": params})
        return filtered

    @property
    def hypers(self) -> Tuple:
        """
        Returns a `hypers(lr, wd)` namedtuple.

        `lr` and `wd` are lists holding the learning rate and the weight decay
        of each param group returned by `build_param_dicts`, in order.
        """
        groups = list(self.build_param_dicts())
        return self._hypers(
            [g["lr"] for g in groups], [g["weight_decay"] for g in groups]
        )

    def output_shape(self) -> ShapeSpec:
        """
        Returns the output shape. For most backbones
        this means it will contain the channels in the
        output layer.

        The base implementation returns `None`; subclasses override it.
        """
        pass

<IPython.core.display.Javascript object>

Some `meta_arch`s in gale require backbones, and for image classification all the backbones should inherit from `ImageClassificationBackbone`.

In [None]:
class TstModule(ImageClassificationBackbone):
    """Minimal two-layer backbone used to exercise `ImageClassificationBackbone`."""

    def __init__(self):
        super(TstModule, self).__init__()
        layers = [nn.Linear(3, 4), nn.Linear(4, 5)]
        self.layers = nn.Sequential(*layers)

    def forward(self, o):
        return self.layers(o)

    def output_shape(self):
        # hard-coded value for the test (the last Linear actually outputs 5 features)
        return ShapeSpec(4, None, None)

    def build_param_dicts(self):
        # one param group per layer, each with its own lr / weight decay;
        # `params` are the generators returned by `Module.parameters()`
        p0 = {"params": self.layers[0].parameters(), "lr": 1e-06, "weight_decay": 0.001}
        p1 = {"params": self.layers[1].parameters(), "lr": 1e-03, "weight_decay": 0.1}
        return [p0, p1]


tst = TstModule()

<IPython.core.display.Javascript object>

### Properties-

In [None]:
show_doc(ImageClassificationBackbone.hypers)

<h4 id="ImageClassificationBackbone.hypers" class="doc_header"><code>ImageClassificationBackbone.hypers</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Returns list of parameters like `lr` and `wd`
for each param group

<IPython.core.display.Javascript object>

In [None]:
# `hypers` collects the lr / weight decay of each group returned by `build_param_dicts`
test_eq(tst.hypers.lr, [1e-06, 1e-03])
test_eq(tst.hypers.wd, [0.001, 0.1])

<IPython.core.display.Javascript object>

In [None]:
show_doc(ImageClassificationBackbone.output_shape)

<h4 id="ImageClassificationBackbone.output_shape" class="doc_header"><code>ImageClassificationBackbone.output_shape</code><a href="__main__.py#L34" class="source_link" style="float:right">[source]</a></h4>

> <code>ImageClassificationBackbone.output_shape</code>()

Returns the output shape. For most backbones
this means it will contain the channels in the
output layer.

<IPython.core.display.Javascript object>

In [None]:
# `TstModule.output_shape` is hard-coded to 4 channels
assert tst.output_shape() == ShapeSpec(4, None, None)

<IPython.core.display.Javascript object>

In [None]:
show_doc(ImageClassificationBackbone.filter_params)

<h4 id="ImageClassificationBackbone.filter_params" class="doc_header"><code>ImageClassificationBackbone.filter_params</code><a href="__main__.py#L9" class="source_link" style="float:right">[source]</a></h4>

> <code>ImageClassificationBackbone.filter_params</code>(**`parameters`**:`List`\[`Dict`\])

Filters any empty paramter groups in `p`

<IPython.core.display.Javascript object>

## TimmBackboneBase -

In [None]:
# export
class TimmBackboneBase(ImageClassificationBackbone):
    """
    Create a model from `timm` and convert it into an Image Classification Backbone.

    The timm model is created without a classifier (`num_classes=0`) and
    without global pooling (`global_pool=""`). It is then cut by
    `prepare_backbone` just before its last child that contains a pooling layer.
    """

    @use_kwargs_dict(
        keep=True,
        pretrained=True,
        drop_block_rate=None,
        drop_path_rate=None,
        bn_tf=False,
    )
    def __init__(
        self,
        model_name: str,
        input_shape: ShapeSpec,
        act: str = None,
        lr: float = 1e-03,
        wd: float = 0,
        freeze_bn: bool = False,
        freeze_at: int = False,
        filter_wd: bool = False,
        **kwargs,
    ):
        super(TimmBackboneBase, self).__init__()

        store_attr("lr, wd, filter_wd, input_shape")

        # resolve the activation name to a layer; `None` keeps timm's default
        act_layer = None if act is None else ACTIVATION_REGISTRY.get(act)

        timm_model = timm.create_model(
            model_name,
            act_layer=act_layer,
            global_pool="",
            num_classes=0,
            in_chans=input_shape.channels,
            **kwargs,
        )

        # keep some metadata from the timm model before it gets cut
        self.num_features = timm_model.num_features
        self.timm_model_cfg = timm_model.default_cfg
        self.model = prepare_backbone(timm_model)

        # a falsy `freeze_at` (False / 0 / None) leaves the whole backbone trainable
        if freeze_at:
            self.freeze_to(freeze_at)
        else:
            self.unfreeze()

        if freeze_bn:
            set_bn_eval(self.model)

    def forward(self, xb: torch.Tensor) -> torch.Tensor:
        "Run the batch `xb` through the cut timm model."
        return self.model(xb)

    def build_param_dicts(self) -> List:
        """
        Build the optimizer param groups of the backbone.

        With `filter_wd`, biases and other 1-d parameters are placed in a
        separate group without weight decay (see `filter_weight_decay`).
        Otherwise a single group holds all trainable parameters. Empty groups
        are dropped through `filter_params`.
        """
        if not self.filter_wd:
            groups = [
                {
                    "params": trainable_params(self.model),
                    "lr": self.lr,
                    "weight_decay": self.wd,
                }
            ]
        else:
            groups = filter_weight_decay(self.model, lr=self.lr, weight_decay=self.wd)
        return self.filter_params(groups)

    def output_shape(self) -> ShapeSpec:
        "Channels of the backbone output (`num_features` of the timm model)."
        return ShapeSpec(self.num_features, None, None)

<IPython.core.display.Javascript object>

This class provides a simple way to load a model from timm using all its arguments. It then cuts the model at the pooling layer before the classifier, i.e., only the feature extractor is kept and converted into the backbone. You can optionally choose to partially or fully freeze the parameter groups of the backbone using `freeze_at`. `freeze_bn` sets the BatchNorm layers of the model to eval mode, and if `filter_wd` is set, weight decay is not applied to the biases and other 1d parameters of the backbone.

`TimmBackboneBase.build_param_dicts()` is responsible for building the parameter groups of the model. Currently it returns the `trainable_params` of the model with `lr` and `wd`. If `filter_wd` is set, the parameters are split so that `wd` is not applied to biases and other 1d parameters. For more advanced options you should probably override this method.

**Arguments to `TimmBackboneBase`:**
- `input_shape` (ShapeSpec): Shape of the Inputs
- `model_name` (str): name of model to instantiate.
- `act` (str): name of the activation function to use. If None uses the default activations else the name must be in `ACTIVATION_REGISTRY`.
- `lr` (float): learning rate for the modules.
- `wd` (float): weight decay for the modules.
- `freeze_bn` (bool): freeze the batch normalization layers of the model.
- `freeze_at` (int): freeze the layers of the backbone up to `freeze_at`; `False` means train all.
- `filter_wd` (bool): Filter out bias, bn from weight_decay.
- `pretrained` (bool): load pretrained ImageNet-1k weights if true.
- `drop_block_rate` (float): Drop block rate
- `drop_path_rate` (float): Drop_path_rate
- `bn_tf` (bool): Use Tensorflow BatchNorm defaults for models that support it.
- `kwargs` (optional): Optional kwargs passed onto `timm.create_model()`

In [None]:
# TimmBackboneBase only reads `input_shape.channels`; height/width match the test batch below
input_shape = ShapeSpec(channels=3, height=224, width=224)
bk = TimmBackboneBase(model_name="resnet18", pretrained=True, input_shape=input_shape)
m = timm.create_model("resnet18")

i = torch.randn(2, 3, 224, 224)
o1 = bk(i)
# the cut resnet18 maps a 224x224 batch to 512 x 7 x 7 feature maps
test_eq(o1.shape, torch.Size([2, 512, 7, 7]))
test_eq(bk.output_shape().channels, m.num_features)

Loading pretrained weights from url (https://download.pytorch.org/models/resnet18-5c106cde.pth)


<IPython.core.display.Javascript object>

### Dataclass

In [None]:
# export
@dataclass
class TimmBackboneDataClass:
    """
    Base config file for `TimmBackboneBase`. You need to pass in a
    `model_name`; the other parameters are optional.
    """

    model_name: str = MISSING  # required: name of the timm model to create
    act: Optional[str] = None  # name in ACTIVATION_REGISTRY; None keeps the model's default
    lr: Any = 1e-03
    wd: Any = 0.0
    freeze_bn: bool = False  # put the BatchNorm layers in eval mode
    freeze_at: Any = False  # falsy -> the whole backbone is trainable
    filter_wd: bool = False  # exclude biases / 1d params from weight decay
    # the fields below are presumably forwarded to `timm.create_model` through
    # `TimmBackboneBase(**kwargs)` — confirm against `from_config_dict`
    pretrained: bool = True
    drop_block_rate: Optional[float] = None
    drop_path_rate: Optional[float] = None
    bn_tf: bool = False

<IPython.core.display.Javascript object>

The config for `TimmBackboneBase` is described by the `TimmBackboneDataClass` above. We convert the dataclass into an OmegaConf config with `OmegaConf.structured` and then use the `from_config_dict` method to instantiate our class:

In [None]:
# create a config to instantiate the same backbone as above
conf = TimmBackboneDataClass(model_name="resnet18", pretrained=True)
# convert the dataclass into a structured OmegaConf config
conf = OmegaConf.structured(conf)

# we need to explicitly pass in the input_shape argument
m = TimmBackboneBase.from_config_dict(conf, input_shape=input_shape)

o2 = m(i)
test_eq(o2.shape, torch.Size([2, 512, 7, 7]))

# same pretrained weights and same input: the config-built backbone matches `bk`
test_eq(o1.data, o2.data)

Loading pretrained weights from url (https://download.pytorch.org/models/resnet18-5c106cde.pth)


<IPython.core.display.Javascript object>

## ResNetBackbone-

In [None]:
# export
class ResNetBackbone(ImageClassificationBackbone):
    """
    A Backbone for ResNet based models from timm. Note: this class
    supports all the models listed
    [here](https://github.com/rwightman/pytorch-image-models/blob/e8a64fb88108b592da192e98054095b1ee25e96e/timm/models/resnet.py)

    The timm `ResNet` is re-organized as `nn.Sequential(stem, stages)`, where
    `stem = (conv1, bn1, act1, maxpool)` and `stages = (layer1, ..., layer4)`.
    """

    @use_kwargs_dict(
        keep=True,
        pretrained=True,
        drop_block_rate=0.0,
        drop_path_rate=0.0,
    )
    def __init__(
        self,
        model_name: str,
        input_shape: ShapeSpec,
        act: str = None,
        lr: float = 1e-03,
        wd: float = 1e-02,
        lr_div: float = 100,
        freeze_at: int = 0,
        freeze_bn: bool = False,
        **kwargs,
    ):
        super(ResNetBackbone, self).__init__()
        store_attr("freeze_at, wd, lr, lr_div, input_shape, freeze_bn")

        if act is not None:
            act = ACTIVATION_REGISTRY.get(act)

        model = timm.create_model(
            model_name,
            act_layer=act,
            global_pool="",
            num_classes=0,
            in_chans=input_shape.channels,
            **kwargs,
        )

        assert isinstance(model, ResNet), "ResNetBackbone supports only ResNet models"
        # save some of the input information from timm models
        self.num_features = model.num_features
        self.timm_model_cfg = model.default_cfg

        # break up the model
        # the stem for the resnet model consists of a convolutional block, norm, act, pool
        stem = nn.Sequential(model.conv1, model.bn1, model.act1, model.maxpool)

        # stages consist of the remaining 4 residual layers
        stages = nn.Sequential(model.layer1, model.layer2, model.layer3, model.layer4)

        # create the module
        self.resnet = nn.Sequential(stem, stages)
        self.prepare_model(self.resnet)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.resnet(x)

    def build_param_dicts(self) -> Any:
        """
        Build two param groups for discriminative learning rates.

        Follows the fastai ResNet split
        (https://github.com/fastai/fastai/blob/master/fastai/vision/learner.py):

        - group 0: the stem, `layer1` and `layer2`
        - group 1: `layer3` and `layer4`

        Both groups use weight decay `self.wd`. Learning rates are assigned by
        `build_discriminative_lrs` from `self.lr` and `self.lr / self.lr_div`.
        Empty groups (e.g. fully frozen ones) are dropped.
        """
        stem, stages = self.resnet[0], self.resnet[1]
        # previously split at `stages[:3]`, which put layer3 in the first group,
        # contradicting both the fastai split and this module's documentation
        p0 = {
            "params": trainable_params(stem) + trainable_params(stages[:2]),
            "weight_decay": self.wd,
        }
        p1 = {"params": trainable_params(stages[2:]), "weight_decay": self.wd}
        ps, _ = build_discriminative_lrs([p0, p1], self.lr, self.lr / self.lr_div)
        return self.filter_params(ps)

    def freeze_block(self, m: nn.Module):
        """
        Make this block `m` not trainable: disable gradients for all its
        parameters and switch it to eval mode.
        """
        for p in m.parameters():
            p.requires_grad = False
        m.eval()

    def prepare_model(self, m: nn.Module):
        """
        Freeze the first several stages of the `ResNet`. Commonly used in fine-tuning.

        `m` is the `nn.Sequential(stem, stages)` built in `__init__`.
        `freeze_at >= 1` freezes the stem, and `freeze_at >= k` (k >= 2) also
        freezes residual stage `k - 1`. When `freeze_bn` is set, the BatchNorm
        layers of `m` are put in eval mode afterwards.
        """
        if self.freeze_at >= 1:
            _logger.debug("Freezing stem")
            # freeze the stem of the model
            self.freeze_block(m[0])

        # freeze the residual stages according to freeze_at; the stage index
        # starts at 2 because index 1 is the stem
        for idx, stage in enumerate(m[1], start=2):
            if self.freeze_at >= idx:
                _logger.debug(f"Freezing ResBlock {idx - 2 }")
                for block in stage.children():
                    self.freeze_block(block)

        if self.freeze_bn:
            set_bn_eval(m)

    def output_shape(self) -> ShapeSpec:
        return ShapeSpec(self.num_features, None, None)

<IPython.core.display.Javascript object>

**Arguments to `ResNetBackbone`**:
- `input_shape` (ShapeSpec): Shape of the Inputs
- `model_name` (str): name of model to instantiate.
- `act` (str): name of the activation function to use. If None uses the default activations else the name must be in `ACTIVATION_REGISTRY`.
- `lr` (float): learning rate for the modules.
- `lr_div` (int, float): factor for discriminative lrs.   
- `wd` (float): weight decay for the modules.
- `freeze_at` (int): Freeze the first several stages of the ResNet. Commonly used in fine-tuning. `1` means freezing the stem. `2` means freezing the stem and one residual stage, etc.
- `freeze_bn` (bool): freeze the batch normalization layers of the model.
- `pretrained` (bool): load pretrained ImageNet-1k weights if true.
- `drop_block_rate` (float): Drop block rate.
- `drop_path_rate` (float): Drop path rate.
- `bn_tf` (bool): Use Tensorflow BatchNorm defaults for models that support it.
- `kwargs` (optional): Optional kwargs passed onto `timm.create_model()`

`ResNetBackbone` is an `ImageClassificationBackbone` class that is responsible for converting ResNet-based models into an appropriate backbone for Image Classification tasks.

Note that each ResNet model has 1 stem and 4 residual stages. You can freeze some or all of these by setting `freeze_at`. If `0`, the whole model is trainable; `1` freezes only the stem, `2` freezes the stem and the first stage, and so on. We also train the ResNet model using discriminative LRs for fine-tuning.
So stages 3 and 4 are trained with a learning rate of `lr`, and the stem and stages 1 and 2 are trained with a learning rate of `lr`/`lr_div`. Weight decay `wd` is applied to the whole model.

In [None]:
show_doc(ResNetBackbone.prepare_model)

<h4 id="ResNetBackbone.prepare_model" class="doc_header"><code>ResNetBackbone.prepare_model</code><a href="__main__.py#L82" class="source_link" style="float:right">[source]</a></h4>

> <code>ResNetBackbone.prepare_model</code>(**`m`**:`Module`)

Freeze the first several stages of the `ResNet`. Commonly used in fine-tuning.

<IPython.core.display.Javascript object>

In [None]:
show_doc(ResNetBackbone.freeze_block)

<h4 id="ResNetBackbone.freeze_block" class="doc_header"><code>ResNetBackbone.freeze_block</code><a href="__main__.py#L74" class="source_link" style="float:right">[source]</a></h4>

> <code>ResNetBackbone.freeze_block</code>(**`m`**:`Module`)

Make this block `m` not trainable.

<IPython.core.display.Javascript object>

### Dataclass

This class can be instantiated from a config as follows - 

In [None]:
# export
@dataclass
class ResNetBackboneDataClass:
    """
    Base config file for `ResNetBackbone`. You need to pass in a
    `model_name`; the other parameters are optional.
    """

    model_name: str = MISSING  # required: name of the timm ResNet model
    act: Optional[str] = None  # name in ACTIVATION_REGISTRY; None keeps the model's default
    lr: Any = 1e-03
    lr_div: Any = 10  # discriminative-lr factor (NB: ResNetBackbone's own default is 100)
    wd: Any = 0.0
    freeze_at: int = 0  # 0: train all, 1: freeze stem, 2: stem + first stage, ...
    pretrained: bool = True
    drop_block_rate: Optional[float] = None
    drop_path_rate: Optional[float] = None
    bn_tf: bool = False
    # appended last so positional construction of existing fields is unchanged;
    # mirrors `ResNetBackbone(freeze_bn=...)` and `TimmBackboneDataClass.freeze_bn`
    freeze_bn: bool = False

<IPython.core.display.Javascript object>

In [None]:
# create config from OmegaConf using `ResNetBackboneConfig` dataclass
# create config from OmegaConf using the `ResNetBackboneDataClass` dataclass
conf = OmegaConf.structured(ResNetBackboneDataClass(model_name="resnet34"))
# instantiate cls from config; `input_shape` is not part of the config and is passed explicitly
m = ResNetBackbone.from_config_dict(conf, input_shape=input_shape)

Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth)


<IPython.core.display.Javascript object>

## Export -

In [None]:
# hide
notebook2script("04_classification.models.backbones.ipynb")

Converted 04_classification.models.backbones.ipynb.


<IPython.core.display.Javascript object>