In [None]:
# default_exp classification.modelling.backbones

In [None]:
# hide
%load_ext nb_black
%load_ext autoreload
%autoreload 2

<IPython.core.display.Javascript object>

In [None]:
# hide
import warnings

from nbdev.export import *
from nbdev.showdoc import *

warnings.filterwarnings("ignore")

<IPython.core.display.Javascript object>

# Backbones 
> Backbones (feature extractors) for image classification tasks

In [None]:
# export
# @TODO: Add support for VisionTransformer Backbone

<IPython.core.display.Javascript object>

In [None]:
# export
import abc
import logging
import re
from typing import *

import torch
from fastcore.all import store_attr, use_kwargs_dict
from timm import create_model
from timm.models import ResNet
from torch import nn

from gale.core.classes import GaleModule
from gale.core.nn import ACTIVATION_REGISTRY
from gale.core.nn.shape_spec import ShapeSpec
from gale.core.nn.utils import set_bn_eval, trainable_params
from gale.core.utils.structures import IMAGE_CLASSIFIER_BACKBONES

_logger = logging.getLogger(__name__)

<IPython.core.display.Javascript object>

In [None]:
# export
# nbdev convention: names listed in `_all_` are added to the generated module's
# `__all__`, so the `IMAGE_CLASSIFIER_BACKBONES` registry imported above is
# re-exported from this module.
_all_ = ["IMAGE_CLASSIFIER_BACKBONES"]

<IPython.core.display.Javascript object>

In [None]:
import copy
from dataclasses import dataclass, field

from fastcore.test import *
from omegaconf import MISSING, DictConfig, OmegaConf

<IPython.core.display.Javascript object>

## Utility functions -

In [None]:
# export
# functions taken from: https://github.com/fastai/fastai/blob/master/fastai/vision/learner.py#L76
def _is_pool_type(l: nn.Module):
    "Truthy (a regex match object) when the class name of `l` ends in `Pool1d`, `Pool2d` or `Pool3d`"
    class_name = l.__class__.__name__
    return re.search(r"Pool[123]d$", class_name)


def has_pool_type(m: nn.Module):
    "Return `True` if `m` itself is a pooling layer or any of its (nested) children is one, else `False`"
    if _is_pool_type(m):
        return True
    # depth-first search over the children; `any` stops at the first subtree containing a pool
    return any(has_pool_type(child) for child in m.children())

<IPython.core.display.Javascript object>

In [None]:
# export
def prepare_backbone(model: nn.Module, cut=None):
    """
    Cut off the body of a typically pretrained `model` as determined by `cut`.

    Args:
        model: the model whose direct children are used to build the body.
        cut: where to cut `model`:
            - `None` (default): cut right before the last direct child that is,
              or contains, a pooling layer (as detected by `has_pool_type`).
            - `int`: keep `list(model.children())[:cut]` wrapped in an `nn.Sequential`
              (negative values count from the end, as with list slicing).
            - callable: return `cut(model)`.

    Returns:
        The body of `model`.

    Raises:
        ValueError: if `cut` is `None` and no direct child of `model` contains a pooling layer.
        TypeError: if `cut` is neither `None`, an `int` nor a callable.
    """
    if cut is None:
        indexed_children = list(enumerate(model.children()))
        # search from the end so that the *last* pooling child becomes the cut point
        cut = next((i for i, o in reversed(indexed_children) if has_pool_type(o)), None)
        if cut is None:
            raise ValueError(
                "could not infer `cut`: no child of `model` contains a pooling layer"
            )
    if isinstance(cut, int):
        return nn.Sequential(*list(model.children())[:cut])
    if callable(cut):
        return cut(model)
    raise TypeError("cut must be either integer or a function")

<IPython.core.display.Javascript object>

In [None]:
# fmt: off
tst = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5), nn.AvgPool2d(1), nn.Linear(3, 4))
# fmt: on

# (cut, expected number of kept layers):
#   None -> cut right before the AvgPool2d child (index 2)
#   3    -> keep the first three children
#   -1   -> drop only the last child
for cut, expected_len in [(None, 2), (3, 3), (-1, 3)]:
    m = prepare_backbone(tst, cut=cut)
    test_eq(len(m), expected_len)

<IPython.core.display.Javascript object>

In [None]:
# export
def filter_weight_decay(
    model: nn.Module,
    lr: float,
    weight_decay: float = 1e-5,
    skip_list=(),
) -> List[Dict]:
    """
    Filter out bias, bn and other 1d (or 0d) params from weight decay.
    Modified from: [timm](https://github.com/rwightman/pytorch-image-models/blob/e8a64fb88108b592da192e98054095b1ee25e96e/timm/optim/optim_factory.py)

    Args:
        model: module whose parameters are split. Parameters with
            `requires_grad=False` are left out of both groups.
        lr: learning rate assigned to both groups.
        weight_decay: weight decay of the "decay" group.
        skip_list: parameter names (as yielded by `model.named_parameters()`)
            that are always put into the no-decay group.

    Returns:
        `[no_decay_group, decay_group]`: two param-group dicts with the keys
        `params`, `weight_decay` and `lr`. The no-decay group always uses
        `weight_decay=0.0`. Either group may have an empty `params` list.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # `ndim <= 1` covers biases and norm-layer weights (1-d) as well as
        # 0-d scalar parameters, which must not be decayed either
        if param.ndim <= 1 or name.endswith(".bias") or name in skip_list:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": no_decay, "weight_decay": 0.0, "lr": lr},
        {"params": decay, "weight_decay": weight_decay, "lr": lr},
    ]

<IPython.core.display.Javascript object>

In [None]:
# export
class ImageClassificationBackbone(GaleModule, metaclass=abc.ABCMeta):
    """
    Abstract base class for image classification backbones (feature extractors).
    """

    def __init__(self):
        """
        Subclasses are free to define their own set of `__init__` arguments.
        """
        super().__init__()

    def get_lrs(self) -> List:
        """
        Return the learning rate (the `"lr"` entry) of every param group
        produced by `self.build_param_dicts()`, in the same order. Schedulers
        such as `torch.optim.lr_scheduler.OneCycleLR` need one max lr per
        param group.
        """
        return [param_group["lr"] for param_group in self.build_param_dicts()]

    @abc.abstractmethod
    def output_shape(self) -> ShapeSpec:
        """
        Return the output shape of the backbone. For most backbones only the
        number of channels of the output layer is filled in.
        """
        ...

<IPython.core.display.Javascript object>

In [None]:
# Render the API documentation of the abstract base class and of its constructor.
show_doc(ImageClassificationBackbone)
show_doc(ImageClassificationBackbone.__init__)

<h2 id="ImageClassificationBackbone" class="doc_header"><code>class</code> <code>ImageClassificationBackbone</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>ImageClassificationBackbone</code>() :: [`GaleModule`](/gale/core.classes.html#GaleModule)

Abstract class for ImageClassification BackBones

<h4 id="ImageClassificationBackbone.__init__" class="doc_header"><code>ImageClassificationBackbone.__init__</code><a href="__main__.py#L7" class="source_link" style="float:right">[source]</a></h4>

> <code>ImageClassificationBackbone.__init__</code>()

The `__init__` method of any subclass can specify its own set of arguments.

<IPython.core.display.Javascript object>

In [None]:
# Render the API documentation of the shared `get_lrs` helper and the abstract `output_shape`.
show_doc(ImageClassificationBackbone.get_lrs)
show_doc(ImageClassificationBackbone.output_shape)

<h4 id="ImageClassificationBackbone.get_lrs" class="doc_header"><code>ImageClassificationBackbone.get_lrs</code><a href="__main__.py#L13" class="source_link" style="float:right">[source]</a></h4>

> <code>ImageClassificationBackbone.get_lrs</code>()

Returns a List containining the Lrs' for
each parameter group. This is required to build schedulers
like `torch.optim.lr_scheduler.OneCycleScheduler` which needs
the max lrs' for all the Param Groups.

<h4 id="ImageClassificationBackbone.output_shape" class="doc_header"><code>ImageClassificationBackbone.output_shape</code><a href="__main__.py#L26" class="source_link" style="float:right">[source]</a></h4>

> <code>ImageClassificationBackbone.output_shape</code>()

Returns the output shape. For most backbones
this means it will contain the channels in the
output layer.

<IPython.core.display.Javascript object>

In [None]:
# export
# @IMAGE_CLASSIFIER_BACKBONES.register()
class TimmBackboneBase(ImageClassificationBackbone):
    "Create a model from `timm` and convert it into an Image Classification Backbone"

    @use_kwargs_dict(
        keep=True,
        pretrained=True,
        drop_block_rate=None,
        drop_path_rate=None,
        bn_tf=False,
    )
    def __init__(
        self,
        model_name: str,
        input_shape: ShapeSpec,
        act: str = None,
        lr: float = 1e-03,
        wd: float = 0,
        freeze_bn: bool = False,
        freeze_at: int = False,
        filter_wd: bool = False,
        **kwargs,
    ):
        super().__init__()

        store_attr("lr, wd, filter_wd, input_shape")

        # resolve the activation name through the registry; `None` is forwarded as-is
        act_layer = ACTIVATION_REGISTRY.get(act) if act is not None else None

        # request the timm model with `global_pool=""` and `num_classes=0`
        # so that it can be turned into a feature extractor below
        timm_model = create_model(
            model_name,
            act_layer=act_layer,
            global_pool="",
            num_classes=0,
            in_chans=input_shape.channels,
            **kwargs,
        )

        # keep some of the information exposed by the timm model
        self.num_features = timm_model.num_features
        self.timm_model_cfg = timm_model.default_cfg
        self.model = prepare_backbone(timm_model)

        if freeze_at:
            self.freeze_to(freeze_at)
        else:
            self.unfreeze()

        if freeze_bn:
            set_bn_eval(self.model)

    def forward(self, xb: torch.Tensor) -> torch.Tensor:
        return self.model(xb)

    def build_param_dicts(self):
        """
        A single param group (`lr`, `wd`) for the whole model or, when
        `filter_wd` is set, the groups produced by `filter_weight_decay`.
        """
        if not self.filter_wd:
            return [{"params": trainable_params(self.model), "lr": self.lr, "weight_decay": self.wd}]
        return filter_weight_decay(self.model, lr=self.lr, weight_decay=self.wd)

    def output_shape(self) -> ShapeSpec:
        "Only the number of channels (`num_features`) is set; height and width are `None`."
        return ShapeSpec(self.num_features, None, None)

<IPython.core.display.Javascript object>

In [None]:
# Render the API documentation (signature including the `use_kwargs_dict` extras) of `TimmBackboneBase`.
show_doc(TimmBackboneBase)

<h2 id="TimmBackboneBase" class="doc_header"><code>class</code> <code>TimmBackboneBase</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>TimmBackboneBase</code>(**`model_name`**:`str`, **`input_shape`**:`ShapeSpec`, **`act`**:`str`=*`None`*, **`lr`**:`float`=*`0.001`*, **`wd`**:`float`=*`0`*, **`freeze_bn`**:`bool`=*`False`*, **`freeze_at`**:`int`=*`False`*, **`filter_wd`**:`bool`=*`False`*, **`pretrained`**=*`True`*, **`drop_block_rate`**=*`None`*, **`drop_path_rate`**=*`None`*, **`bn_tf`**=*`False`*, **\*\*`kwargs`**) :: [`ImageClassificationBackbone`](/gale/classification.modelling.backbones.html#ImageClassificationBackbone)

Create a model from `timm` and converts it into a Image Classification Backbone

<IPython.core.display.Javascript object>

Arguments to `TimmBackboneBase`:
- `input_shape` (ShapeSpec): Shape of the Inputs
- `model_name` (str): name of model to instantiate.
- `act` (str): name of the activation function to use. If None uses the default activations else the name must be in `ACTIVATION_REGISTRY`.
- `lr` (float): learning rate for the modules.
- `wd` (float): weight decay for the modules.
- `freeze_bn` (bool): freeze the batch normalization layers of the model.
- `freeze_at` (int): freeze the layers of the backbone up to `freeze_at`; `False` (or `0`) keeps all layers trainable.
- `filter_wd` (bool): exclude biases, batch-norm parameters and other 1-d parameters from weight decay.
- `pretrained` (bool): load pretrained ImageNet-1k weights if true.
- `drop_block_rate` (float): Drop block rate.
- `drop_path_rate` (float): Drop path rate.
- `bn_tf` (bool): Use Tensorflow BatchNorm defaults for models that support it.
- `kwargs` (optional): Optional kwargs passed onto `timm.create_model()`

In [None]:
# describe the inputs actually fed to the backbone below (3 x 224 x 224)
input_shape = ShapeSpec(channels=3, height=224, width=224)
bk = TimmBackboneBase(model_name="resnet18", pretrained=True, input_shape=input_shape)
m = create_model("resnet18")

i = torch.randn(2, 3, input_shape.height, input_shape.width)
o1 = bk(i)
# no global pooling: the backbone returns a 512 x 7 x 7 feature map for a 224 x 224 input
test_eq(o1.shape, torch.Size([2, 512, 7, 7]))
test_eq(bk.output_shape().channels, m.num_features)

<IPython.core.display.Javascript object>

### Instantiation using config

The config for `TimmBackboneBase` is described by the `TimmBackboneBaseConfig` dataclass below. We convert the dataclass into an OmegaConf config with `OmegaConf.structured` and then use the `from_config_dict` method to instantiate the class.

In [None]:
@dataclass
class TimmBackboneBaseConfig:
    """
    Structured config mirroring the arguments of `TimmBackboneBase`.
    `input_shape` is not part of it and has to be passed separately.
    """

    model_name: str = MISSING
    act: Optional[str] = None
    lr: Any = 1e-03
    wd: Any = 0.0
    freeze_bn: bool = False
    freeze_at: Any = False
    filter_wd: bool = False
    pretrained: bool = True
    drop_block_rate: Optional[float] = None
    drop_path_rate: Optional[float] = None
    bn_tf: bool = False


# structured config for the same resnet18 backbone built above
conf = OmegaConf.structured(TimmBackboneBaseConfig(model_name="resnet18", pretrained=True))

# `input_shape` is not stored in the config, so it is passed explicitly
m = TimmBackboneBase.from_config_dict(conf, input_shape=input_shape)

o2 = m(i)
test_eq(o2.shape, torch.Size([2, 512, 7, 7]))

# the config-built backbone must reproduce the output of the directly-built one
test_eq(o1.data, o2.data)

<IPython.core.display.Javascript object>

In [None]:
# export
# fmt: off
# @IMAGE_CLASSIFIER_BACKBONES.register()
class ResNetBackbone(ImageClassificationBackbone):
    """
    A Backbone for ResNet based models from timm. This class supports the models
    listed [here](https://github.com/rwightman/pytorch-image-models/blob/e8a64fb88108b592da192e98054095b1ee25e96e/timm/models/resnet.py);
    models that are not instances of timm's `ResNet` are rejected.

    The network is split into a `stem` (conv1, bn1, act1, maxpool) and the four
    residual `stages` (layer1 ... layer4).
    """

    @use_kwargs_dict(
        keep=True,
        pretrained=True,
        drop_block_rate=None,
        drop_path_rate=None,
        bn_tf=False,
    )
    def __init__(
        self,
        model_name: str,
        input_shape: ShapeSpec,
        act: str = None,
        lr: float = 1e-03,
        lr_div: float = 10,
        wd: float = 0,
        freeze_at: int = 0,
        **kwargs
    ):
        super().__init__()
        store_attr("freeze_at, wd, lr, lr_div, input_shape", self)

        # resolve the activation name through the registry; `None` is forwarded as-is
        act_layer = None if act is None else ACTIVATION_REGISTRY.get(act)

        resnet = create_model(
            model_name,
            act_layer=act_layer,
            global_pool="",
            num_classes=0,
            in_chans=input_shape.channels,
            **kwargs
        )
        assert isinstance(resnet, ResNet), "ResNetBackbone supports only ResNet models"

        # keep some of the information exposed by the timm model
        self.num_features = resnet.num_features
        self.timm_model_cfg = resnet.default_cfg

        # split the network into a stem and the four residual stages
        self.stem = nn.Sequential(resnet.conv1, resnet.bn1, resnet.act1, resnet.maxpool)
        self.stages = nn.Sequential(resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4)

        self.prepare_model()

    def forward(self, xb: torch.Tensor) -> torch.Tensor:
        features = self.stem(xb)
        return self.stages(features)

    def build_param_dicts(self) -> Any:
        """
        Three param groups for discriminative learning rates: the stem and the
        first two stages use `lr / lr_div`, the last two stages use `lr`.
        All groups share the weight decay `wd`.
        """
        reduced_lr = self.lr / self.lr_div
        module_lrs = [
            (self.stem, reduced_lr),
            (self.stages[0:2], reduced_lr),
            (self.stages[2:], self.lr),
        ]
        return [
            {"params": trainable_params(module), "lr": group_lr, "weight_decay": self.wd}
            for module, group_lr in module_lrs
        ]

    def freeze_block(self, m: nn.Module):
        """
        Make the block `m` non-trainable: every parameter gets
        `requires_grad=False` and its BatchNorm layers are switched to eval mode.
        """
        for param in m.parameters():
            param.requires_grad = False
        set_bn_eval(m)

    def prepare_model(self):
        """
        Freeze the first several stages of the ResNet, as is common when fine-tuning:
        `freeze_at >= 1` freezes the stem, `freeze_at >= k + 1` additionally freezes
        residual stage `k` (1-based).
        """
        if self.freeze_at >= 1:
            self.freeze_block(self.stem)

        # residual stages are numbered from 2 because level 1 is the stem
        stages_to_freeze = [
            stage
            for level, stage in enumerate(self.stages, start=2)
            if self.freeze_at >= level
        ]
        for stage in stages_to_freeze:
            for block in stage.children():
                self.freeze_block(block)

    def output_shape(self) -> ShapeSpec:
        "Only the number of channels (`num_features`) is set; height and width are `None`."
        return ShapeSpec(self.num_features, None, None)
# fmt: on

<IPython.core.display.Javascript object>

In [None]:
# Render the API documentation of `ResNetBackbone` and its freezing helpers.
show_doc(ResNetBackbone)
show_doc(ResNetBackbone.prepare_model)
show_doc(ResNetBackbone.freeze_block)

<h2 id="ResNetBackbone" class="doc_header"><code>class</code> <code>ResNetBackbone</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>ResNetBackbone</code>(**`model_name`**:`str`, **`input_shape`**:`ShapeSpec`, **`act`**:`str`=*`None`*, **`lr`**:`float`=*`0.001`*, **`lr_div`**:`float`=*`10`*, **`wd`**:`float`=*`0`*, **`freeze_at`**:`int`=*`0`*, **`pretrained`**=*`True`*, **`drop_block_rate`**=*`None`*, **`drop_path_rate`**=*`None`*, **`bn_tf`**=*`False`*, **\*\*`kwargs`**) :: [`ImageClassificationBackbone`](/gale/classification.modelling.backbones.html#ImageClassificationBackbone)

A Backbone for ResNet based models from timm. Note: this class
does supports all the models listed [here](https://github.com/rwightman/pytorch-image-models/blob/e8a64fb88108b592da192e98054095b1ee25e96e/timm/models/resnet.py)

<h4 id="ResNetBackbone.prepare_model" class="doc_header"><code>ResNetBackbone.prepare_model</code><a href="__main__.py#L71" class="source_link" style="float:right">[source]</a></h4>

> <code>ResNetBackbone.prepare_model</code>()

Freeze the first several stages of the ResNet. Commonly used in fine-tuning.

<h4 id="ResNetBackbone.freeze_block" class="doc_header"><code>ResNetBackbone.freeze_block</code><a href="__main__.py#L61" class="source_link" style="float:right">[source]</a></h4>

> <code>ResNetBackbone.freeze_block</code>(**`m`**:`Module`)

Make this block `m` not trainable.
This method sets all parameters to `requires_grad=False`,
and convert all BatchNorm Layers in eval mode

<IPython.core.display.Javascript object>

Arguments to `ResNetBackbone`:
- `input_shape` (ShapeSpec): Shape of the Inputs
- `model_name` (str): name of model to instantiate.
- `act` (str): name of the activation function to use. If None uses the default activations else the name must be in `ACTIVATION_REGISTRY`.
- `lr` (float): learning rate for the modules.
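- `lr_div` (int, float): factor for discriminative learning rates: the stem and the first two residual stages use `lr / lr_div`, while the last two stages use `lr`.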
- `wd` (float): weight decay for the modules.
- `freeze_at` (int): Freeze the first several stages of the ResNet. Commonly used in fine-tuning. `1` means freezing the stem. `2` means freezing the stem and one residual stage, etc.
- `pretrained` (bool): load pretrained ImageNet-1k weights if true.
- `drop_block_rate` (float): Drop block rate.
- `drop_path_rate` (float): Drop path rate.
- `bn_tf` (bool): Use Tensorflow BatchNorm defaults for models that support it.
- `kwargs` (optional): Optional kwargs passed onto `timm.create_model()`

### Instantiation using config

In [None]:
@dataclass
class ResNetBackboneConfig:
    """
    Base config file for `ResNetBackbone`
    """

    model_name: str = MISSING
    act: Optional[str] = None
    lr: Any = 1e-03
    lr_div: Any = 10
    wd: Any = 0.0
    freeze_at: int = 0
    pretrained: bool = True
    drop_block_rate: Optional[float] = None
    drop_path_rate: Optional[float] = None
    bn_tf: bool = False


# create config from OmegaConf using `ResNetBackboneConfig` dataclass
conf = OmegaConf.structured(ResNetBackboneConfig(model_name="resnet34"))
# instantiate cls from config (`input_shape` is not part of the config)
m = ResNetBackbone.from_config_dict(conf, input_shape=input_shape)

# resnet34 ends with 512 channels and maps a 224 x 224 input to a 7 x 7 feature map
test_eq(m.output_shape().channels, 512)
test_eq(m(i).shape, torch.Size([2, 512, 7, 7]))
# stem + first two stages (lr / lr_div) and last two stages (lr)
test_eq(len(m.build_param_dicts()), 3)

<IPython.core.display.Javascript object>

## Export -

In [None]:
# hide
# Convert the `# export` cells of every notebook in the project into the library modules (see the output below).
notebook2script()

Converted 00_core.utils.logger.ipynb.
Converted 00a_core.utils.visualize.ipynb.
Converted 00b_core.utils.structures.ipynb.
Converted 01_core.nn.utils.ipynb.
Converted 01a_core.nn.losses.ipynb.
Converted 02_core.nn.optim.optimizers.ipynb.
Converted 02a_core.nn.optim.lr_schedulers.ipynb.
Converted 03_core.config.ipynb.
Converted 03a_core.classes.ipynb.
Converted 04_classification.modelling.backbones.ipynb.
Converted 04a_classification.modelling.heads.ipynb.
Converted 04b_classification.modelling.meta_arch.general.ipynb.
Converted 04b_classification.modelling.meta_arch.vit.ipynb.
Converted 05_collections.pandas.ipynb.
Converted 06a_collections.callbacks.notebook.ipynb.
Converted 06b_collections.callbacks.ema.ipynb.
Converted index.ipynb.


<IPython.core.display.Javascript object>