In [None]:
#default_exp classification.modelling.body

In [None]:
# hide
import warnings
warnings.filterwarnings("ignore")

In [None]:
# hide
from nbdev.showdoc import *
from nbdev.export import *
from nbdev.imports import Config as NbdevConfig

# Directory that holds the sample data stored next to the notebooks
nbdev_path = str(NbdevConfig().path("nbs_path").joinpath('data'))
nbdev_path

'/Users/ayushman/Desktop/lightning_cv/nbs/data'

# Model Body for Image Classification
> Convenience functions to prepare a model body for vision applications

In [None]:
# export
from typing import *
import importlib

import timm
import torch
from torch import nn

import re
from omegaconf import DictConfig
from fastcore.all import use_kwargs_dict

from torchvision import models

from lightning_cv.core.layers import *
from lightning_cv.core.common import Registry
from lightning_cv.core.layers import ActivationCatalog

In [None]:
# hide
from omegaconf import OmegaConf
from fastcore.all import *
from lightning_cv.core.layers import Mish

## Cut a pretrained model

In [None]:
# export
def _is_pool_type(l): 
    return re.search(r'Pool[123]d$', l.__class__.__name__)

In [None]:
#hide
m = nn.Sequential(
    nn.AdaptiveAvgPool2d(5),  # pooling
    nn.Linear(2, 3),          # not pooling
    nn.Conv2d(2, 3, 1),       # not pooling
    nn.MaxPool3d(5),          # pooling
)
pool_flags = [_is_pool_type(child) is not None for child in m.children()]
test_eq(pool_flags, [True, False, False, True])

By default, the LightningCV library cuts a pretrained model just before its last top-level child that is, or contains, a pooling layer (similar to the fastai library). The function below detects whether a module is, or contains, a pooling layer.

In [None]:
# export
def has_pool_type(m):
    "Return `True` if `m` is a pooling layer or has one in its children"
    # `m` itself first, then a depth-first search through its children (stops at the first hit)
    return bool(_is_pool_type(m)) or any(has_pool_type(child) for child in m.children())

In [None]:
layers = [nn.AdaptiveAvgPool2d(5), nn.Linear(2, 3), nn.Conv2d(2, 3, 1), nn.MaxPool3d(5)]
m = nn.Sequential(*layers)
# the container counts as well, because it *contains* pooling layers
assert has_pool_type(m)
test_eq(list(map(has_pool_type, m.children())), [True, False, False, True])

In [None]:
#export
def create_body(model: nn.Module, cut: Optional[Union[int, Callable]] = None):
    """Cut off the body of a `model` as determined by `cut`.

    - `cut=None`: keep the top-level children that come before the *last* child which is,
      or contains, a pooling layer (see `has_pool_type`).
    - `cut` an `int`: keep `list(model.children())[:cut]` wrapped in an `nn.Sequential`
      (negative values count from the end, as with list slicing).
    - `cut` a callable: return `cut(model)` unchanged.

    Raises `ValueError` when `cut is None` but no top-level child contains a pooling layer,
    and `TypeError` when `cut` is neither an int nor a callable.
    """
    children = list(model.children())
    if cut is None:
        # search from the end so the cut lands just before the last pooling stage
        cut = next((i for i, child in reversed(list(enumerate(children))) if has_pool_type(child)), None)
        if cut is None:
            raise ValueError("could not find a pooling layer to cut at; pass `cut` explicitly")
    if isinstance(cut, int):
        return nn.Sequential(*children[:cut])
    if callable(cut):
        return cut(model)
    raise TypeError("cut must be either integer or a function")

In [None]:
tst = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5), nn.AvgPool2d(1), nn.Linear(3, 4))

# default cut: everything before the last pooling child (Conv2d + BatchNorm2d)
test_eq(len(create_body(tst)), 2)

# an explicit integer cut keeps the first `cut` children
m = create_body(tst, cut=3)
test_eq(len(m), 3)

In [None]:
#hide
tst = timm.create_model("resnet18", pretrained=False, num_classes=0, global_pool='')

# the auto-detected cut and an explicit `cut=-2` must both keep the first 8 children
for cut in (None, -2):
    m = create_body(tst, cut=cut)
    test_eq(len(m), 8)

In [None]:
# export
class TimmCnnBody(nn.Module):
    "default `nn.Module` to create a body for vision applications from `timm`"

    # Keyword defaults forwarded to `timm.create_model`. `use_kwargs_dict` only rewrites the
    # documented signature (`__signature__`); it never injects these values at call time, so
    # `__init__` merges them into `kwargs` explicitly. Caller-supplied values take precedence.
    _timm_defaults = dict(pretrained=False, num_classes=0, global_pool="")

    @use_kwargs_dict(keep=True, **_timm_defaults)
    def __init__(self, model_name: str, cut=None, act_layer: str=None, **kwargs):
        """Build the timm model `model_name` and keep only its body.

        - `cut`: forwarded to `create_body` (`None` cuts before the last pooling child).
        - `act_layer`: name of an activation registered in `ActivationCatalog`; `None` keeps
          the model's default activation.
        - `**kwargs`: extra keyword arguments for `timm.create_model` (e.g. `in_chans`).
        """
        super().__init__()
        # resolve the activation name to the registered object; None -> timm's default
        if act_layer is not None:
            act_layer = ActivationCatalog.get(act_layer)

        timm_kwargs = {**self._timm_defaults, **kwargs}
        net = timm.create_model(model_name, act_layer=act_layer, **timm_kwargs)

        # prepare body
        self.net = create_body(net, cut)

    def forward(self, xb):
        "Run `xb` through the truncated timm backbone."
        return self.net(xb)

    @classmethod
    def from_config(cls, config: DictConfig):
        "create from a `Omegaconf/ Hydra` config"
        return cls(**config)

In [None]:
# Render the API documentation (signature + docstring) of TimmCnnBody
show_doc(TimmCnnBody)

<h2 id="TimmCnnBody" class="doc_header"><code>class</code> <code>TimmCnnBody</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>TimmCnnBody</code>(**`model_name`**:`str`, **`cut`**=*`None`*, **`act_layer`**:`str`=*`None`*, **`pretrained`**=*`False`*, **`num_classes`**=*`0`*, **`global_pool`**=*`''`*, **\*\*`kwargs`**) :: `Module`

default `nn.Module` to create a body for vision applications from `timm`

In [None]:
# reference 1: the full timm model, read out before pooling/head via `forward_features`
m1 = timm.create_model("resnet18", pretrained=True, act_layer=None)
# reference 2: the timm model built without pooling or classifier
m2 = timm.create_model("resnet18", pretrained=True, act_layer=None, global_pool='', num_classes=0)
# model under test
tst = TimmCnnBody(model_name="resnet18", cut=-2, act_layer=None, pretrained=True)

i = torch.randn(2, 3, 299, 299)
with torch.no_grad():
    o1, o2, o3 = m1.forward_features(i), m2(i), tst(i)

for ref in (o1, o2):
    test_eq(ref, o3)

> Note: You can use the `act_layer` argument to change the activation layer of `TimmCnnBody`. `act_layer` is a string that names an `obj` in the `ActivationCatalog`. If you want an activation function that is not in the `ActivationCatalog`, register the `obj` first. Also note that timm requires the activation function to accept an `inplace` argument.

In [None]:
# same comparison as before, but every model uses `Mish` activations
shared = dict(pretrained=True, act_layer=Mish)
m1 = timm.create_model("resnet18", **shared)
m2 = timm.create_model("resnet18", global_pool='', num_classes=0, **shared)
tst = TimmCnnBody(model_name="resnet18", cut=-2, act_layer="Mish", pretrained=True)

i = torch.randn(2, 3, 299, 299)
with torch.no_grad():
    o1, o2, o3 = m1.forward_features(i), m2(i), tst(i)

for ref in (o1, o2):
    test_eq(ref, o3)

In [None]:
# Mish activations + single-channel input: `in_chans=1` is forwarded to timm
shared = dict(pretrained=True, act_layer=Mish, in_chans=1)
m1 = timm.create_model("resnet18", **shared)
m2 = timm.create_model("resnet18", global_pool='', num_classes=0, **shared)
tst = TimmCnnBody(model_name="resnet18", cut=-2, act_layer="Mish", pretrained=True, in_chans=1)

i = torch.randn(2, 1, 299, 299)  # one input channel to match `in_chans=1`
with torch.no_grad():
    o1, o2, o3 = m1.forward_features(i), m2(i), tst(i)

for ref in (o1, o2):
    test_eq(ref, o3)

In [None]:
# export
def _get_first_layer(m):
    "Access first layer of a model"
    c,p,n = m,None,None  # child, parent, name
    for n in next(m.named_parameters())[0].split('.')[:-1]:
        p,c=c,getattr(c,n)
    return c,p,n

In [None]:
#export
def _update_first_layer(model, n_in):
    "Change first layer based on number of input channels"
    if n_in == 3: return
    first_layer, parent, name = _get_first_layer(model)
    assert isinstance(first_layer, nn.Conv2d), f'Change of input channels only supported with Conv2d, found {first_layer.__class__.__name__}'
    assert getattr(first_layer, 'in_channels') == 3, f'Unexpected number of input channels, found {getattr(first_layer, "in_channels")} while expecting 3'
    params = {attr:getattr(first_layer, attr) for attr in 'out_channels kernel_size stride padding dilation groups padding_mode'.split()}
    params['bias'] = getattr(first_layer, 'bias') is not None
    params['in_channels'] = n_in
    new_layer = nn.Conv2d(**params)
    setattr(parent, name, new_layer)

In [None]:
# export
class TorchvisionBody(nn.Module):
    "default `nn.Module` to create a body for vision applications from `torchvision.models`"

    def __init__(self, model_name: str, in_chans: int = 3, pretrained: bool = True, cut=None):
        super().__init__()
        # look the constructor up by name, e.g. "resnet18" -> torchvision.models.resnet18
        tv_models = importlib.import_module('torchvision.models')
        build_fn = getattr(tv_models, model_name)
        backbone = build_fn(pretrained=pretrained)

        # swap the first conv when the input is not 3-channel (no-op for in_chans == 3)
        _update_first_layer(backbone, n_in=in_chans)
        self.net = create_body(backbone, cut)

    def forward(self, xb):
        "Run `xb` through the truncated torchvision backbone."
        return self.net(xb)

    @classmethod
    def from_config(cls, config: DictConfig):
        "create from a `Omegaconf/ Hydra` config"
        return cls(**config)

In [None]:
# hide
# an explicit `cut=-2` and the auto-detected cut should give the same body
m1, m2 = TorchvisionBody("resnet18", cut=-2), TorchvisionBody("resnet18", cut=None, in_chans=3)

i = torch.randn(2, 3, 299, 299)
with torch.no_grad():
    o1, o2 = m1(i), m2(i)

test_eq(o1, o2)

## ModelBody Registry

In [None]:
# export
ModelBody = Registry("CNN_Body")
# register the built-in body classes
for _body_cls in (TimmCnnBody, TorchvisionBody):
    ModelBody.register(_body_cls)
del _body_cls

In [None]:
# hide-input
# Display the names and classes registered in the CNN_Body registry
ModelBody

Registry of CNN_Body:
╒═════════════════╤════════════════════════════════════╕
│ Names           │ Objects                            │
╞═════════════════╪════════════════════════════════════╡
│ TimmCnnBody     │ <class '__main__.TimmCnnBody'>     │
├─────────────────┼────────────────────────────────────┤
│ TorchvisionBody │ <class '__main__.TorchvisionBody'> │
╘═════════════════╧════════════════════════════════════╛

In [None]:
# export
def create_cnn_body(cfg: DictConfig) -> nn.Module:
    "Instantiate the `ModelBody` entry named `cfg.MODEL.BODY.NAME` from `cfg.MODEL.BODY.ARGUMENTS`"
    body_cfg = cfg.MODEL.BODY
    body_cls = ModelBody.get(body_cfg.NAME)
    return body_cls.from_config(body_cfg.ARGUMENTS)

In [None]:
from lightning_cv.config import get_cfg

cfg = get_cfg(strict=False)
# render the model-body section of the config as YAML
body_yaml = OmegaConf.to_yaml(cfg.MODEL.BODY)
print(body_yaml)

NAME: TimmCnnBody
ARGUMENTS:
  model_name: resnet18
  cut: -2
  act_layer: null
  pretrained: true



In [None]:
# body built from the config shown above vs. two equivalent timm references
tst = create_cnn_body(cfg)
ref_kwargs = dict(pretrained=True, act_layer=None)
m1 = timm.create_model("resnet18", **ref_kwargs)
m2 = timm.create_model("resnet18", global_pool='', num_classes=0, **ref_kwargs)

i = torch.randn(2, 3, 299, 299)
with torch.no_grad():
    o1, o2, o3 = m1.forward_features(i), m2(i), tst(i)

for ref in (o1, o2):
    test_eq(ref, o3)

In [None]:
# for a different activation
# Derive a new config instead of mutating `cfg` in place: `OmegaConf.merge` returns a fresh
# copy, so re-running the cells above still sees the original `act_layer: null`.
cfg_mish = OmegaConf.merge(cfg, {"MODEL": {"BODY": {"ARGUMENTS": {"act_layer": "Mish"}}}})

tst = create_cnn_body(cfg_mish)
m1  = timm.create_model("resnet18", pretrained=True, act_layer=Mish)
m2  = timm.create_model("resnet18", pretrained=True, act_layer=Mish, global_pool='', num_classes=0)


with torch.no_grad():
    i  = torch.randn(2, 3, 299, 299)
    o1 = m1.forward_features(i)
    o2 = m2(i)
    o3 = tst(i)

test_eq(o1, o3)
test_eq(o2, o3)

In [None]:
# for a torchvision model

# First we derive a config that selects the torchvision body (`OmegaConf.merge` returns a copy,
# so the shared `cfg` used by the cells above is left untouched)
tv_cfg = OmegaConf.merge(cfg, {"MODEL": {"BODY": {"NAME": "TorchvisionBody"}}})

# Replace ARGUMENTS wholesale. `OmegaConf.update` merges dict values by default in
# OmegaConf >= 2.1, which would keep timm-only keys such as `act_layer` and make
# `TorchvisionBody(**ARGUMENTS)` fail with an unexpected keyword argument.
tv_cfg.MODEL.BODY.ARGUMENTS = dict(model_name="resnet18", pretrained=True, in_chans=3, cut=None)

tst = create_cnn_body(tv_cfg)
m1  = TorchvisionBody("resnet18", pretrained=True, in_chans=3, cut=None)

with torch.no_grad():
    i  = torch.randn(2, 3, 299, 299)
    o1 = m1(i)
    o2 = tst(i)

test_eq(o1, o2)

> Note: For `create_cnn_body` to work, your `obj` must be registered in the `ModelBody` registry and must have a `from_config` `classmethod`.

In [None]:
#hide
# Export the `# export` cells of the project's notebooks into the library modules
notebook2script()

Converted 00_config.ipynb.
Converted 00a_core.common.ipynb.
Converted 00b_core.data_utils.ipynb.
Converted 00c_core.optim.ipynb.
Converted 00d_core.schedules.ipynb.
Converted 00e_core.layers.ipynb.
Converted 01a_classification.data.transforms.ipynb.
Converted 01b_classification.data.datasets.ipynb.
Converted 01c_classification.modelling.body.ipynb.
Converted index.ipynb.
