In [None]:
#default_exp classification.modelling.body

In [None]:
# hide
import warnings
# Install an "ignore" filter so that every warning raised from here on is suppressed.
warnings.filterwarnings("ignore")

In [None]:
# hide
from nbdev.showdoc import *
from nbdev.export import *
from nbdev.imports import Config as NbdevConfig

# String path of the `data` folder located under the nbdev `nbs_path` setting.
nbdev_path = str(NbdevConfig().path("nbs_path")/'data')
# Bare expression: displays the resolved path as the cell output.
nbdev_path

'/Users/ayushman/Desktop/lightning_cv/nbs/data'

# Model Body for Image Classification
> Convenience functions to prepare a Model for Vision applications

In [None]:
# export
from typing import *

import timm
import torch
from torch import nn

import re
from omegaconf import DictConfig
from fastcore.all import use_kwargs_dict

from lightning_cv.core.layers import *
from lightning_cv.core.common import Registry
from lightning_cv.core.layers import ActivationCatalog

In [None]:
# hide
from omegaconf import OmegaConf
from fastcore.all import *
from lightning_cv.core.layers import Mish

## Cut a pretrained model

In [None]:
# export
def _is_pool_type(l): 
    return re.search(r'Pool[123]d$', l.__class__.__name__)

In [None]:
#hide
# Of the four layers only the first (`AdaptiveAvgPool2d`) and the last (`MaxPool3d`)
# have a class name ending in `Pool<N>d`, so only those two must be matched.
m = nn.Sequential(nn.AdaptiveAvgPool2d(5), nn.Linear(2,3), nn.Conv2d(2,3,1), nn.MaxPool3d(5))
test_eq([bool(_is_pool_type(m_)) for m_ in m.children()], [True,False,False,True])

By default, the LightningCV library cuts a pretrained model at the pooling layer (similar to the fastai library). The function below helps detect it.

In [None]:
# export
def has_pool_type(m):
    """Tell whether `m` itself, or any module nested below it, is a pooling layer.

    `m` is checked first with `_is_pool_type`; if that fails, every entry of
    `m.children()` is searched recursively and the search stops at the first hit.
    Always returns a plain `bool`.
    """
    if _is_pool_type(m):
        return True
    return any(has_pool_type(child) for child in m.children())

In [None]:
# The container counts as "having a pool" because two of its children are pooling layers;
# checked child by child, only the first and the last one are.
m = nn.Sequential(nn.AdaptiveAvgPool2d(5), nn.Linear(2,3), nn.Conv2d(2,3,1), nn.MaxPool3d(5))
assert has_pool_type(m)
test_eq([has_pool_type(m_) for m_ in m.children()], [True,False,False,True])

In [None]:
#export
def create_body(model: nn.Module, cut: Optional[Union[int, Callable]] = None):
    """Cut off the body of a `model` as determined by `cut`.

    Args:
        model: module whose direct children are considered for cutting.
        cut: how to cut the model.
            - `None`: cut right before the last direct child that is, or contains,
              a pooling layer (as reported by `has_pool_type`).
            - `int`: keep `list(model.children())[:cut]`, wrapped in `nn.Sequential`.
            - callable: the result of `cut(model)` is returned as is.

    Returns:
        An `nn.Sequential` of the kept children, or whatever `cut(model)` returns.

    Raises:
        ValueError: `cut` is `None` and no direct child of `model` has a pooling layer.
        NameError: `cut` is neither `None`, an integer nor a callable.
    """
    if cut is None:
        layers = list(model.children())
        # walk backwards so that the *last* child with a pooling layer decides the cut
        cut = next((idx for idx in reversed(range(len(layers))) if has_pool_type(layers[idx])), None)
        if cut is None:
            # previously surfaced as a bare, message-less `StopIteration`
            raise ValueError("cut is None but no pooling layer was found in the children of model")
    if isinstance(cut, int):
        return nn.Sequential(*list(model.children())[:cut])
    if callable(cut):
        return cut(model)
    # The original referenced the undefined name `NamedError`, which itself blew up as a
    # `NameError` carrying an unrelated message. The exception type callers observe is
    # kept; the intended message is now actually delivered.
    raise NameError("cut must be either integer or a function")

In [None]:
# With no `cut` given, the model is cut at its pooling layer (index 2), leaving 2 layers.
tst = nn.Sequential(nn.Conv2d(3,5,3), nn.BatchNorm2d(5), nn.AvgPool2d(1), nn.Linear(3,4))
m = create_body(tst)
test_eq(len(m), 2)

# An integer `cut` keeps the first `cut` children.
m = create_body(tst, cut=3)
test_eq(len(m), 3)

In [None]:
#hide
# Same check on a timm resnet18 built with `num_classes=0` and `global_pool=''`:
# the automatic cut is expected to keep 8 children.
tst = timm.create_model("resnet18", pretrained=False, num_classes=0, global_pool='')
m = create_body(tst)
test_eq(len(m), 8)

# `cut=-2` drops the last two children and is expected to give the same length.
m = create_body(tst, cut=-2)
test_eq(len(m), 8)

In [None]:
# export
class CnnBody(nn.Module):
    "default `nn.Module` to create a body for vision applications from `timm`"

    @use_kwargs_dict(keep=True, pretrained=False, num_classes=0, global_pool="")
    def __init__(self, model_name: str, cut=None, act_layer: str=None, **kwargs):
        super().__init__()
        # `act_layer` is a key looked up in `ActivationCatalog`; when it is `None`
        # nothing is looked up and the model's default activation function is used
        activation = None if act_layer is None else ActivationCatalog.get(act_layer)

        backbone = timm.create_model(model_name, act_layer=activation, **kwargs)
        # keep the timm config around, only the cut-off body is stored as a submodule
        self._cfg = backbone.default_cfg
        self.net = create_body(backbone, cut)

    @classmethod
    def from_config(cls, config: DictConfig):
        "build an instance by unpacking an `Omegaconf/ Hydra` config as keyword arguments"
        return cls(**config)

    @property
    def default_cfg(self):
        "the `default_cfg` of the underlying timm model (useful e.g. for timm's TestTimePool)"
        return self._cfg

    @default_cfg.setter
    def default_cfg(self, x: Dict):
        self._cfg = x

    def forward(self, xb):
        "run `xb` through the cut-off body"
        return self.net(xb)

In [None]:
show_doc(CnnBody)

<h2 id="CnnBody" class="doc_header"><code>class</code> <code>CnnBody</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>CnnBody</code>(**`model_name`**:`str`, **`cut`**=*`None`*, **`act_layer`**:`str`=*`None`*, **`pretrained`**=*`False`*, **`num_classes`**=*`0`*, **`global_pool`**=*`''`*, **\*\*`kwargs`**) :: `Module`

default `nn.Module` to create a body for vision applications from `timm`

In [None]:
# Three ways of obtaining resnet18 features that are expected to agree:
#   m1  - the full timm model, queried through `forward_features`
#   m2  - the timm model created with `global_pool=''` and `num_classes=0`
#   tst - `CnnBody` with the last two children cut off
# NOTE(review): `pretrained=True` presumably needs the pretrained weights to be
# downloadable or already cached - confirm before running offline.
m1  = timm.create_model("resnet18", pretrained=True, act_layer=None)
m2  = timm.create_model("resnet18", pretrained=True, act_layer=None, global_pool='', num_classes=0)
tst = CnnBody(model_name="resnet18", cut=-2, act_layer=None, pretrained=True)


# the same random batch is fed to all three models, without tracking gradients
with torch.no_grad():
    i  = torch.randn(2, 3, 299, 299)
    o1 = m1.forward_features(i)
    o2 = m2(i)
    o3 = tst(i)

# outputs and the stored timm config must match
test_eq(o1, o3)
test_eq(o2, o3)
test_eq(m1.default_cfg, tst.default_cfg)
test_eq(m2.default_cfg, tst.default_cfg)

> Note: You can use the `act_layer` argument to change the activation layer of the `CnnBody`. `act_layer` is a string that corresponds to an `obj` in the `ActivationCatalog`. If you are using an activation function that is not in the `ActivationCatalog`, be sure to register the `obj` first. Also, timm requires that the activation function accept an `inplace` argument.

In [None]:
# Same comparison as for the default activation, this time with Mish: the timm models
# receive the `Mish` class directly, while `CnnBody` receives its catalog name "Mish".
m1  = timm.create_model("resnet18", pretrained=True, act_layer=Mish)
m2  = timm.create_model("resnet18", pretrained=True, act_layer=Mish, global_pool='', num_classes=0)
tst = CnnBody(model_name="resnet18", cut=-2, act_layer="Mish", pretrained=True)


# the same random batch is fed to all three models, without tracking gradients
with torch.no_grad():
    i  = torch.randn(2, 3, 299, 299)
    o1 = m1.forward_features(i)
    o2 = m2(i)
    o3 = tst(i)

# outputs and the stored timm config must match
test_eq(o1, o3)
test_eq(o2, o3)
test_eq(m1.default_cfg, tst.default_cfg)
test_eq(m2.default_cfg, tst.default_cfg)

In [None]:
# hide
tst

CnnBody(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): Mish()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Mish()
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    

## ModelBody Registry

In [None]:
# export
# Registry named "CNN_Body" that holds the model bodies which can be built from a config.
ModelBody = Registry("CNN_Body")
# Add `CnnBody` to the registry; `create_cnn_body` retrieves entries through `ModelBody.get`.
ModelBody.register(CnnBody)

In [None]:
# hide-input
ModelBody

Registry of CNN_Body:
╒═════════╤════════════════════════════╕
│ Names   │ Objects                    │
╞═════════╪════════════════════════════╡
│ CnnBody │ <class '__main__.CnnBody'> │
╘═════════╧════════════════════════════╛

## Helpers 

In [None]:
# export
def create_cnn_body(cfg: DictConfig) -> nn.Module:
    """Instantiate an object from the `ModelBody` registry using a lightning_cv config.

    `cfg.MODEL.BODY.NAME` selects the registered object and `cfg.MODEL.BODY.ARGUMENTS`
    is passed to that object's `from_config` classmethod.
    """
    body_cfg = cfg.MODEL.BODY
    return ModelBody.get(body_cfg.NAME).from_config(body_cfg.ARGUMENTS)

In [None]:
from lightning_cv.config import get_cfg

# Fetch the config returned by `get_cfg` and print its `MODEL.BODY` section as YAML.
cfg = get_cfg()
print(OmegaConf.to_yaml(cfg.MODEL.BODY))

NAME: CnnBody
ARGUMENTS:
  model_name: resnet18
  cut: -2
  act_layer: null
  pretrained: true



In [None]:
# The body built from the config must behave like the equivalent timm models:
#   m1 - the full resnet18, queried through `forward_features`
#   m2 - resnet18 created with `global_pool=''` and `num_classes=0`
tst = create_cnn_body(cfg)
m1  = timm.create_model("resnet18", pretrained=True, act_layer=None)
m2  = timm.create_model("resnet18", pretrained=True, act_layer=None, global_pool='', num_classes=0)


# the same random batch is fed to all three models, without tracking gradients
with torch.no_grad():
    i  = torch.randn(2, 3, 299, 299)
    o1 = m1.forward_features(i)
    o2 = m2(i)
    o3 = tst(i)

# outputs and the stored timm config must match
test_eq(o1, o3)
test_eq(o2, o3)
test_eq(m1.default_cfg, tst.default_cfg)
test_eq(m2.default_cfg, tst.default_cfg)

In [None]:
# for a different activation
# NOTE: this assignment modifies `cfg` in place, so `cfg` keeps `act_layer = "Mish"`
# for any code that uses it after this cell.
cfg.MODEL.BODY.ARGUMENTS.act_layer = "Mish"
tst = create_cnn_body(cfg)
m1  = timm.create_model("resnet18", pretrained=True, act_layer=Mish)
m2  = timm.create_model("resnet18", pretrained=True, act_layer=Mish, global_pool='', num_classes=0)


# the same random batch is fed to all three models, without tracking gradients
with torch.no_grad():
    i  = torch.randn(2, 3, 299, 299)
    o1 = m1.forward_features(i)
    o2 = m2(i)
    o3 = tst(i)

# outputs and the stored timm config must match
test_eq(o1, o3)
test_eq(o2, o3)
test_eq(m1.default_cfg, tst.default_cfg)
test_eq(m2.default_cfg, tst.default_cfg)

> Note: For `create_cnn_body` to work, your `obj` must be registered in the `ModelBody` registry and must have a `from_config` `classmethod`.

In [None]:
#hide
# Run nbdev's notebook-to-script conversion; the output lists every notebook that was converted.
notebook2script()

Converted 00_config.ipynb.
Converted 00a_core.common.ipynb.
Converted 00b_core.data_utils.ipynb.
Converted 00c_core.optim.ipynb.
Converted 00d_core.schedules.ipynb.
Converted 00e_core.layers.ipynb.
Converted 01a_classification.data.transforms.ipynb.
Converted 01b_classification.data.datasets.ipynb.
Converted 01c_classification.modelling.body.ipynb.
Converted index.ipynb.
