In [None]:
#default_exp core.layers

In [None]:
# hide
import warnings
warnings.filterwarnings("ignore")

In [None]:
# hide
from nbdev.showdoc import *
from nbdev.export import *
from nbdev.imports import Config as NbdevConfig

# String path of the `data` folder under the `nbs_path` entry of the nbdev config.
nbdev_path = str(NbdevConfig().path("nbs_path")/'data')
nbdev_path

'/Users/ayushman/Desktop/lightning_cv/nbs/data'

# Layers
> Custom layers and basic functions to grab them

In [None]:
# export
from enum import Enum
from fastcore.all import delegates
from functools import partial

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module
from torch.jit import script

from lightning_cv.core.utils.common import Registry

In [None]:
# hide
from fastcore.all import *

## Basic manipulations and resize

In [None]:
# export
class Identity(Module):
    """Pass-through layer: `forward` hands back its input untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Nothing is computed or copied; the very same object comes back.
        passthrough = x
        return passthrough

In [None]:
# `Identity` must hand back exactly what it is given.
test_eq(Identity()(1), 1)

## Pooling layers

In [None]:
# export
class AdaptiveConcatPool2d(Module):
    """
    Adaptive max pooling and adaptive average pooling, concatenated along dim 1.

    The max-pooled activations come first and the average-pooled ones second.
    A falsy `size` (the default `None`) falls back to an output size of 1.
    From : https://github.com/fastai/fastai/blob/master/fastai/layers.py
    """

    def __init__(self, size=None):
        super().__init__()
        self.size = size if size else 1
        # Registration order (`ap`, then `mp`) is kept so `children()` is unchanged.
        self.ap = nn.AdaptiveAvgPool2d(self.size)
        self.mp = nn.AdaptiveMaxPool2d(self.size)

    def forward(self, x):
        max_pooled = self.mp(x)
        avg_pooled = self.ap(x)
        return torch.cat((max_pooled, avg_pooled), dim=1)

In [None]:
# Default size: 5 input channels become 10 (max-pooled half, then average-pooled half).
tst = AdaptiveConcatPool2d()
x = torch.randn(10,5,4,4)
test_eq(tst(x).shape, [10,10,1,1])
# Reference maximum over the two spatial dims, reduced one dim at a time.
max1 = torch.max(x,    dim=2, keepdim=True)[0]
maxp = torch.max(max1, dim=3, keepdim=True)[0]
test_eq(tst(x)[:,:5], maxp)
test_eq(tst(x)[:,5:], x.mean(dim=[2,3], keepdim=True))
# An explicit size sets the spatial shape of the output.
tst = AdaptiveConcatPool2d(2)
test_eq(tst(x).shape, [10,10,2,2])

## BatchNorm layers

In [None]:
# export
# Supported normalisation flavours; members are numbered from 1 in this order.
NormType = Enum(
    'NormType',
    ['Batch', 'BatchZero', 'Weight', 'Spectral', 'Instance', 'InstanceZero'],
)

In [None]:
# export
def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
    "Norm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    assert 1 <= ndim <= 3
    bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs)
    if bn.affine:
        bn.bias.data.fill_(1e-3)
        bn.weight.data.fill_(0. if zero else 1.)
    return bn

In [None]:
#export
@delegates(nn.BatchNorm2d)
def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
    """
    Create a `BatchNorm{ndim}d` layer over `nf` features through `_get_norm`.

    The zero-weight initialisation is requested only when `norm_type` is
    `NormType.BatchZero`; remaining `kwargs` are passed through unchanged.
    From : https://github.com/fastai/fastai/blob/master/fastai/layers.py
    """
    start_at_zero = (norm_type == NormType.BatchZero)
    return _get_norm('BatchNorm', nf, ndim, zero=start_at_zero, **kwargs)

In [None]:
with torch.no_grad():
    # Default `norm_type`: a 2d layer whose weight starts at one.
    tst = BatchNorm(15)
    assert isinstance(tst, nn.BatchNorm2d)
    test_eq(tst.weight, torch.ones(15))
    # `BatchZero` starts the weight at zero instead.
    tst = BatchNorm(15, norm_type=NormType.BatchZero)
    test_eq(tst.weight, torch.zeros(15))
    # `ndim` picks the 1d / 3d variants.
    tst = BatchNorm(15, ndim=1)
    assert isinstance(tst, nn.BatchNorm1d)
    tst = BatchNorm(15, ndim=3)
    assert isinstance(tst, nn.BatchNorm3d)

In [None]:
# export
class LinBnDrop(nn.Sequential):
    """
    Sequential stack of optional `BatchNorm1d`, optional `Dropout` and a `Linear` layer.

    With `lin_first=False` the order is batch norm -> dropout -> linear -> `act`;
    with `lin_first=True` it is linear -> `act` -> batch norm -> dropout.
    Dropout is only added when `p` is non-zero, `act` only when it is not None,
    and the linear layer only carries a bias when `bn` is False.
    From : https://github.com/fastai/fastai/blob/master/fastai/layers.py
    """

    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        norm_drop = []
        if bn:
            # Batch norm sees `n_out` features when it follows the linear layer.
            bn_features = n_out if lin_first else n_in
            norm_drop.append(BatchNorm(bn_features, ndim=1))
        if p != 0:
            norm_drop.append(nn.Dropout(p))

        linear_act = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None:
            linear_act.append(act)

        if lin_first:
            ordered = linear_act + norm_drop
        else:
            ordered = norm_drop + linear_act
        super().__init__(*ordered)

In [None]:
with torch.no_grad():
    # Defaults: batch norm followed by the linear layer.
    tst = LinBnDrop(10, 20)
    mods = list(tst.children())
    test_eq(len(mods), 2)
    assert isinstance(mods[0], nn.BatchNorm1d)
    assert isinstance(mods[1], nn.Linear)

    # Non-zero `p` inserts dropout between batch norm and linear.
    tst = LinBnDrop(10, 20, p=0.1)
    mods = list(tst.children())
    test_eq(len(mods), 3)
    assert isinstance(mods[0], nn.BatchNorm1d)
    assert isinstance(mods[1], nn.Dropout)
    assert isinstance(mods[2], nn.Linear)

    # `lin_first=True` puts linear + activation ahead of batch norm.
    tst = LinBnDrop(10, 20, act=nn.ReLU(), lin_first=True)
    mods = list(tst.children())
    test_eq(len(mods), 3)
    assert isinstance(mods[0], nn.Linear)
    assert isinstance(mods[1], nn.ReLU)
    assert isinstance(mods[2], nn.BatchNorm1d)

    # `bn=False` leaves only the linear layer.
    tst = LinBnDrop(10, 20, bn=False)
    mods = list(tst.children())
    test_eq(len(mods), 1)
    assert isinstance(mods[0], nn.Linear)

## Activations

In [None]:
#export
#hide
# Mish Activation Function
# Source code : https://github.com/fastai/fastai/blob/master/fastai/layers.py
@script
def _mish_jit_fwd(x):
    # mish(x) = x * tanh(softplus(x))
    return x.mul(torch.tanh(F.softplus(x)))

@script
def _mish_jit_bwd(x, grad_output):
    # d/dx [x * tanh(sp(x))] = tanh(sp(x)) + x * sigmoid(x) * (1 - tanh(sp(x))^2),
    # scaled by the incoming gradient.
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))

class MishJitAutoFn(torch.autograd.Function):
    """
    Mish as a custom autograd function with a hand-written backward pass.

    `forward` saves only the input `x`; `backward` recomputes what it needs
    from `x` through the scripted `_mish_jit_bwd`.
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor; `saved_variables` is deprecated.
        x = ctx.saved_tensors[0]
        return _mish_jit_bwd(x, grad_output)

def mish(x):
    "Apply the Mish activation `x * tanh(softplus(x))` to `x`."
    return MishJitAutoFn.apply(x)

In [None]:
#export
class Mish(Module):
    """Mish activation as a module; `forward` applies `MishJitAutoFn` to its input."""

    def __init__(self, inplace=True):
        # `inplace` is accepted purely for signature compatibility with `timm`
        # and is never read.
        super().__init__()

    def forward(self, x):
        activated = MishJitAutoFn.apply(x)
        return activated

## Registry of Common Activation Functions



To add an activation function to the registry, simply do:


```python
SomeActivation() # activation function
ACTIVATION_REGISTERY.register(SomeActivation)
# access it via
act_func = ACTIVATION_REGISTERY.get("SomeActivation")
```

In [None]:
# export
ACTIVATION_REGISTERY = Registry("ACTIVATIONS")
ACTIVATION_REGISTERY.__doc__ = "Registery of Activation Functions"
# Registered in this order: the custom `Mish` first, then the stock `torch.nn` activations.
for _activation in (
    Mish,
    torch.nn.LeakyReLU,
    torch.nn.ReLU,
    torch.nn.GELU,
    torch.nn.Sigmoid,
    torch.nn.SiLU,
    torch.nn.Tanh,
    torch.nn.LogSoftmax,
    torch.nn.Softmax,
):
    ACTIVATION_REGISTERY.register(_activation)
# Drop the loop variable so no extra module-level name is left behind.
del _activation

In [None]:
#hide-input
# Show the registry's printed form (a table of registered names and objects).
print(ACTIVATION_REGISTERY)

Registry of ACTIVATIONS:
╒════════════╤══════════════════════════════════════════════════╕
│ Names      │ Objects                                          │
╞════════════╪══════════════════════════════════════════════════╡
│ Mish       │ <class '__main__.Mish'>                          │
├────────────┼──────────────────────────────────────────────────┤
│ LeakyReLU  │ <class 'torch.nn.modules.activation.LeakyReLU'>  │
├────────────┼──────────────────────────────────────────────────┤
│ ReLU       │ <class 'torch.nn.modules.activation.ReLU'>       │
├────────────┼──────────────────────────────────────────────────┤
│ GELU       │ <class 'torch.nn.modules.activation.GELU'>       │
├────────────┼──────────────────────────────────────────────────┤
│ Sigmoid    │ <class 'torch.nn.modules.activation.Sigmoid'>    │
├────────────┼──────────────────────────────────────────────────┤
│ SiLU       │ <class 'torch.nn.modules.activation.SiLU'>       │
├────────────┼─────────────────────────────────────

## Features

In [None]:
#export
@torch.no_grad()
def num_features_model(m, ch_int: int = 3):
    """
    Return the number of output features (size of dim 1 of the output) for `m`.

    A dummy batch of zeros with `ch_int` channels is pushed through `m` in eval
    mode. Starting at 32x32, the spatial size is doubled after every failed
    forward pass; the error from the 2048x2048 attempt is re-raised.

    The dummy batch is created on the device of `m`'s first parameter (CPU when
    `m` has none), and every sub-module's train/eval flag is restored before
    returning, so probing a model does not silently leave it in eval mode.
    """
    first_param = next(m.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    # Remember each sub-module's mode individually: a blanket `m.train()` would
    # also flip layers that were deliberately kept in eval mode.
    modes = [(mod, mod.training) for mod in m.modules()]
    sz = 32
    try:
        m.eval()
        while True:
            try:
                x = torch.zeros((8, ch_int, sz, sz), device=device)
                return m(x).shape[1]
            except Exception:
                # Assume the input was too small for the model and retry larger.
                sz *= 2
                if sz > 2048:
                    raise
    finally:
        for mod, was_training in modes:
            mod.training = was_training

In [None]:
# The last convolution has 11 output channels, so 11 features are expected.
m = nn.Sequential(nn.Conv2d(3,5,3), nn.Conv2d(5,11,3))
test_eq(num_features_model(m, ch_int=3), 11)

## Model Init

In [None]:
# export
# hide
# Note : Functions are taken directly from : https://github.com/fastai/fastai/blob/master/fastai/torch_core.py

In [None]:
# export
def requires_grad(m):
    """Tell whether `m`'s first parameter requires grad; `False` when `m` has no parameters."""
    first = next(iter(m.parameters()), None)
    if first is None:
        return False
    return first.requires_grad

In [None]:
# A fresh linear layer is trainable; freezing every parameter flips the answer.
tst = nn.Linear(4,5)
assert requires_grad(tst)
for p in tst.parameters(): p.requires_grad_(False)
assert not requires_grad(tst)

In [None]:
#export
def init_default(m, func=nn.init.kaiming_normal_):
    """
    Initialize `m.weight` with `func` and fill `m.bias` with zeros, then return `m`.

    Nothing is touched when `func` is falsy. A missing or `None` weight/bias
    (for example a layer built with `affine=False` or `bias=False`) is skipped
    instead of being handed to `func`.
    """
    if func:
        weight = getattr(m, 'weight', None)
        # `weight` can be registered as None; passing that to `func` would raise.
        if weight is not None: func(weight)
        bias = getattr(m, 'bias', None)
        if bias is not None and hasattr(bias, 'data'): bias.data.fill_(0.)
    return m

In [None]:
with torch.no_grad():
    # Start from random values so the initialisation is observable.
    tst = nn.Linear(4,5)
    tst.weight.data.uniform_(-1,1)
    tst.bias.data.uniform_(-1,1)
    # `func` fills the weight with ones; the bias is expected to be zeroed.
    tst = init_default(tst, func = lambda x: x.data.fill_(1.))
    test_eq(tst.weight, torch.ones(5,4))
    test_eq(tst.bias, torch.zeros(5))

In [None]:
#export
# Normalisation layers that `cond_init` leaves untouched.
norm_types = (
    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
    nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
    nn.LayerNorm,
)

def cond_init(m, func):
    """
    Run `init_default(m, func)` unless `m` is one of `norm_types`.

    Modules for which `requires_grad(m)` is false are skipped as well.
    """
    if isinstance(m, norm_types):
        return
    if requires_grad(m):
        init_default(m, func)

In [None]:
with torch.no_grad():
    # A trainable linear layer gets initialised.
    tst = nn.Linear(4,5)
    tst.weight.data.uniform_(-1,1)
    tst.bias.data.uniform_(-1,1)
    cond_init(tst, func = lambda x: x.data.fill_(1.))
    test_eq(tst.weight, torch.ones(5,4))
    test_eq(tst.bias, torch.zeros(5))

    # A batch-norm layer must keep its original weight and bias.
    tst = nn.BatchNorm2d(5)
    init = [tst.weight.clone(), tst.bias.clone()]
    cond_init(tst, func = lambda x: x.data.fill_(1.))
    test_eq(tst.weight, init[0])
    test_eq(tst.bias, init[1])

In [None]:
#export
def apply_leaf(m, f):
    """Call `f` on `m` (when it is an `nn.Module`) and then, recursively, on each of its children."""
    kids = m.children()
    if isinstance(m, nn.Module):
        f(m)
    for kid in kids:
        apply_leaf(kid, f)

In [None]:
with torch.no_grad():
    # Every linear layer, including the nested ones, should be initialised.
    tst = nn.Sequential(nn.Linear(4,5), nn.Sequential(nn.Linear(4,5), nn.Linear(4,5)))
    apply_leaf(tst, partial(init_default, func=lambda x: x.data.fill_(1.)))
    for l in [tst[0], *tst[1]]: test_eq(l.weight, torch.ones(5,4))
    for l in [tst[0], *tst[1]]: test_eq(l.bias,   torch.zeros(5))

In [None]:
#export
def apply_init(m, func=nn.init.kaiming_normal_):
    """Walk `m` with `apply_leaf`, running `cond_init` with `func` on every module visited."""
    conditional_init = partial(cond_init, func=func)
    apply_leaf(m, conditional_init)

In [None]:
with torch.no_grad():
    tst = nn.Sequential(nn.Linear(4,5), nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(5)))
    # Snapshot the batch-norm parameters; they must survive `apply_init` unchanged.
    init = [tst[1][1].weight.clone(), tst[1][1].bias.clone()]
    apply_init(tst, func=lambda x: x.data.fill_(1.))
    for l in [tst[0], tst[1][0]]: test_eq(l.weight, torch.ones(5,4))
    for l in [tst[0], tst[1][0]]: test_eq(l.bias,   torch.zeros(5))
    test_eq(tst[1][1].weight, init[0])
    test_eq(tst[1][1].bias,   init[1])

In [None]:
# export
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval(m: Module):
    """
    Put frozen batch-norm layers into eval mode, recursing through all children of `m`.

    A batch-norm layer counts as frozen when its first parameter does not
    require grad; such a layer is switched to eval mode and all of its
    parameters get `requires_grad = False`. Batch-norm layers without any
    parameters (`affine=False`) are left untouched rather than raising
    `StopIteration`.
    """
    for l in m.children():
        if isinstance(l, bn_types):
            # Default of None covers parameter-free layers.
            first_param = next(l.parameters(), None)
            if first_param is not None and not first_param.requires_grad:
                l.eval()
                for param in l.parameters():
                    param.requires_grad = False
        set_bn_eval(l)

## Model Parameters

In [None]:
#export
def trainable_params(m):
    """List the parameters of `m` whose `requires_grad` flag is set, in `m.parameters()` order."""
    return list(filter(lambda p: p.requires_grad, m.parameters()))

In [None]:
#export
def params(m):
    """List every parameter of `m`, in `m.parameters()` order."""
    return list(m.parameters())

In [None]:
with torch.no_grad():
    # Both parameters are trainable at first; freezing the weight leaves only the bias.
    m = nn.Linear(4,5)
    test_eq(trainable_params(m), [m.weight, m.bias])
    m.weight.requires_grad_(False)
    test_eq(trainable_params(m), [m.bias])

In [None]:
#hide
# Run nbdev's notebook export; the output lists each converted notebook.
notebook2script()

Converted 00_config.ipynb.
Converted 00a_core.common.ipynb.
Converted 00b_core.data_utils.ipynb.
Converted 00c_core.optim.ipynb.
Converted 00d_core.schedules.ipynb.
Converted 00e_core.layers.ipynb.
Converted 01a_classification.data.transforms.ipynb.
Converted 01b_classification.data.datasets.ipynb.
Converted 01c_classification.modelling.body.ipynb.
Converted index.ipynb.
