In [None]:
# default_exp core

In [None]:
#export
from wong.imports import *

In [None]:
#hide
from nbdev.showdoc import *

# Core
> All the basic functions and classes

## Convolutional Operations

In [None]:
#export
from enum import Enum

In [None]:
#export
class OprtType(Enum):
    """Kinds of operators a `conv_unit` can be assembled from.

    `Nothing` is a placeholder meaning "no operator at this position".
    """
    Nothing     = 0  # placeholder, no operator
    Conv2d      = 1  # 2-D convolution
    ReLU        = 2  # ReLU activation
    BatchNorm2d = 3  # 2-D batch normalization

In [None]:
#export
def conv_unit(ni:int, no:int, seq:tuple, ks:int=3, stride:int=1, groups:int=1, zero_bn:bool=False, act_inplace=True):
    """
    The basic convolutional operation: the operators listed in `seq` (conv, bn, relu, ...) chained in order.

    Parameters:
    - `ni` : number of input channels
    - `no` : number of output channels of the conv operator
    - `seq` : sequence of `OprtType` members, applied in the given order; `OprtType.Nothing` entries are skipped
    - `ks` : kernel size of the conv operator (padding is `ks//2`)
    - `stride` : stride of the conv operator
    - `groups` : number of groups of the conv operator
    - `zero_bn` : if True, a batch norm placed after the conv gets its weight initialized to 0 (otherwise 1)
    - `act_inplace` : whether the ReLU operators run in-place

    Return:
    - an `nn.Sequential` containing the operators in order

    Raises:
    - `ValueError` if `seq` contains an entry that is not an `OprtType` member
    """
    unit = []
    has_conv = False  # set once a conv is added; decides which channel count a later BN uses
    for e in seq:
        if e == OprtType.Nothing:  # placeholder, no operator
            continue
        elif e == OprtType.Conv2d:
            has_conv = True
            unit.append(nn.Conv2d(ni, no, ks, stride=stride, padding=ks//2, groups=groups, bias=False))
        elif e == OprtType.ReLU:
            # an in-place ReLU overwrites its input tensor; in folded resnet, inplace has to be false
            unit.append(nn.ReLU(inplace=act_inplace))
        elif e == OprtType.BatchNorm2d:
            if has_conv:
                # after the conv there are `no` channels; zero_bn only applies to a BN after a conv
                bn = nn.BatchNorm2d(no)
                nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
                unit.append(bn)
            else:
                # no conv yet, so the feature map still has `ni` channels
                unit.append(nn.BatchNorm2d(ni))
        else:
            # unknown entries used to be dropped silently, producing a wrong (possibly empty) unit
            raise ValueError(f"conv_unit: unsupported operator {e!r} in seq; expected OprtType members")
    return nn.Sequential(*unit)


Parameters:
- `ni` : number of input channels
- `no` : number of output channels
- `seq` : sequence of operators, a tuple of `OprtType` operator types
- `ks` : kernel size of conv operator
- `stride` : stride size of conv operator
- `groups` : number of groups of conv operator
- `zero_bn` : whether to initialize to zero the weight of the batch norm operator that follows the conv operator
- `act_inplace` : whether the activations are computed in-place.

Return:
- an `nn.Sequential` object



In [None]:
#export
# Ready-made conv units; each name spells out the order of its operators.
# `OprtType.Nothing` only pads the shorter sequences to length 3 and is skipped by `conv_unit`.
relu_conv_bn = partial(conv_unit, seq=(OprtType.ReLU, OprtType.Conv2d, OprtType.BatchNorm2d))      # ReLU --> Conv --> BN
conv_bn_relu = partial(conv_unit, seq=(OprtType.Conv2d, OprtType.BatchNorm2d, OprtType.ReLU))      # Conv --> BN --> ReLU
bn_relu_conv = partial(conv_unit, seq=(OprtType.BatchNorm2d, OprtType.ReLU, OprtType.Conv2d))      # BN --> ReLU --> Conv
relu_conv    = partial(conv_unit, seq=(OprtType.ReLU, OprtType.Conv2d, OprtType.Nothing))         # ReLU --> Conv
conv_bn      = partial(conv_unit, seq=(OprtType.Conv2d, OprtType.BatchNorm2d, OprtType.Nothing))  # Conv --> BN

In [None]:
#export
def resnet_basicblock(ni, no, nh, stride:int=1, zero_bn:bool=False):
    """
    Basic Unit in Residual Networks: two ReLU --> Conv(3x3) --> BN units in a row.

    Parameters:
    - `ni` : number of input channels
    - `no` : number of output channels
    - `nh` : number of channels between the two units
    - `stride` : stride of the first conv (the second one always uses stride 1)
    - `zero_bn` : if True, the weight of the last BN is initialized to 0 (same option as
      `resnet_bottleneck`); defaults to False, which keeps the previous behaviour

    Return:
    - an `nn.Sequential` holding the operators of both units, flattened

    Reference:
    Deep Residual Learning for Image Recognition:
    https://arxiv.org/abs/1512.03385
    """
    return nn.Sequential(*relu_conv_bn(ni, nh, stride=stride).children(),
                         *relu_conv_bn(nh, no, zero_bn=zero_bn).children())

def resnet_bottleneck(ni, no, nh, stride:int=1, groups:int=1, zero_bn=True):
    """
    Bottleneck Unit in Residual Networks: three ReLU --> Conv --> BN units,
    1x1 (ni -> nh), then 3x3 (nh -> nh, strided and grouped), then 1x1 (nh -> no).

    Parameters:
    - `ni`, `no` : number of input / output channels
    - `nh` : number of channels inside the bottleneck
    - `stride` : stride of the middle 3x3 conv
    - `groups` : number of groups of the middle 3x3 conv
    - `zero_bn` : if True, the weight of the last BN is initialized to 0

    Reference:
    Deep Residual Learning for Image Recognition:
    https://arxiv.org/abs/1512.03385
    """
    stages = (
        relu_conv_bn(ni, nh, ks=1),                            # 1x1: ni -> nh
        relu_conv_bn(nh, nh, stride=stride, groups=groups),    # 3x3: carries the stride
        relu_conv_bn(nh, no, ks=1, zero_bn=zero_bn),           # 1x1: nh -> no
    )
    return nn.Sequential(*[op for stage in stages for op in stage.children()])

In [None]:
#export
def xception(ni:int, no:int, nh:int, ks:int=3, stride:int=1, zero_bn:bool=False):
    """
    Basic unit in xception networks: ReLU --> grouped conv (groups=ni) --> 1x1 conv --> BN.

    Parameters:
    - `ni`, `no` : number of input / output channels
    - `nh` : channels produced by the grouped conv (PyTorch requires it to be divisible by `ni`)
    - `ks`, `stride` : kernel size and stride of the grouped conv
    - `zero_bn` : if True, the weight of the final BN is initialized to 0

    Reference:
    Xception: Deep Learning with Depthwise Separable Convolutions:
    https://arxiv.org/abs/1610.02357
    """
    spatial = relu_conv(ni, nh, ks=ks, stride=stride, groups=ni)  # one group per input channel
    pointwise = conv_bn(nh, no, ks=1, zero_bn=zero_bn)            # 1x1 channel mixing + BN
    layers = list(spatial.children()) + list(pointwise.children())
    return nn.Sequential(*layers)

## Stem Stage

In [None]:
#export
def resnet_stem(ni:int=3, no:int=64, deep_stem:bool=False):
    """
    Stem of a ResNet, which downsamples the input by 4 overall.

    - `deep_stem=False`: Conv 7x7 (stride 2) --> BN --> ReLU
    - `deep_stem=True` : three Conv 3x3 --> BN --> ReLU units; only the first one has stride 2
    Both variants end with a 3x3 max pool of stride 2.

    Parameters:
    - `ni` : number of input channels
    - `no` : number of output channels of every conv in the stem

    Return:
    - an `nn.Sequential` with all operators flattened
    """
    if not deep_stem:
        layers = [nn.Conv2d(ni, no, kernel_size=7, stride=2, padding=3, bias=False),
                  nn.BatchNorm2d(no),
                  nn.ReLU(inplace=True)]
    else:
        layers = []
        for c_in, s in ((ni, 2), (no, 1), (no, 1)):  # only the first conv downsamples
            layers.extend(conv_bn_relu(c_in, no, stride=s).children())
    layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    return nn.Sequential(*layers)

In [None]:
(resnet_stem(deep_stem=False), resnet_stem(deep_stem=True))

(Sequential(
   (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (2): ReLU(inplace=True)
   (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
 ), Sequential(
   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
   (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (2): ReLU()
   (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (5): ReLU()
   (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (8): ReLU()
   (9): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
 ))

## Identity mapping


In [None]:
#export
class IdentityMapping2(nn.Module):
    """
    Shortcut branch that downsamples with pooling, four cases:
    1.  stride == 1 and ni == no
        input == output
    2.  stride == 1 and ni != no
        1x1 conv and bn
    3.  stride == 2 and ni == no
        maxpool or avgpool (3x3 kernel, padding 1)
    4.  stride == 2 and ni != no
        (maxpool or avgpool) and 1x1 conv and bn

    Parameters:
    - `ni`, `no` : number of input / output channels
    - `stride` : 1 or 2
    - `pooling_type` : 'max' or 'avg', the pooling used when stride == 2
    """
    def __init__(self, ni:int, no:int, stride:int=1, pooling_type:str='max'):
        # Bug fix: this used to call super(IdentityMapping, self), i.e. the *other* class,
        # so constructing IdentityMapping2 always failed with NameError or TypeError.
        super().__init__()
        assert stride == 1 or stride == 2
        assert pooling_type == 'max' or pooling_type == 'avg'
        unit = []
        if stride == 2:
            pool = nn.MaxPool2d if pooling_type == 'max' else nn.AvgPool2d
            unit.append(pool(kernel_size=3, stride=stride, padding=1))
        if ni != no:
            # the pooling already downsampled, so the 1x1 projection keeps stride 1
            unit += conv_bn(ni, no, ks=1).children()
        self.unit = nn.Sequential(*unit)

    def forward(self, x):
        # an empty nn.Sequential returns its input unchanged (case 1)
        return self.unit(x)

In [None]:
#export
class IdentityMapping(nn.Module):
    """Shortcut branch of a ResNet block.

    Returns the input untouched when `ni == no` and `stride == 1`; otherwise
    projects it with a 1x1 conv (using `stride`) followed by BN, changing the
    channel count and/or the resolution.

    Parameters:
    - `ni`, `no` : number of input / output channels
    - `stride` : 1 or 2
    """
    def __init__(self, ni:int, no:int, stride:int=1):
        super().__init__()
        assert stride in (1, 2)
        needs_projection = ni != no or stride != 1
        layers = list(conv_bn(ni, no, ks=1, stride=stride).children()) if needs_projection else []
        self.unit = nn.Sequential(*layers)

    def forward(self, x):
        # with no projection, the empty nn.Sequential hands back x as-is
        return self.unit(x)

In [None]:
(IdentityMapping(16, 32, stride=2), IdentityMapping(16, 32, stride=1),
 IdentityMapping(16, 16, stride=1), IdentityMapping(16, 16, stride=2))

(IdentityMapping(
   (unit): Sequential(
     (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
 ), IdentityMapping(
   (unit): Sequential(
     (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
 ), IdentityMapping(
   (unit): Sequential()
 ), IdentityMapping(
   (unit): Sequential(
     (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(2, 2), bias=False)
     (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
 ))

## Classifier

In [None]:
#export
class Classifier(nn.Module):
    """
    Head that usually sits at the end of an image network (classification, object detection, ...).

    It applies:
    - an adaptive average pooling that shrinks every feature map to 1x1, then
    - a linear layer mapping the `ni` pooled features to `no` outputs.

    Parameters:
    - `ni` : number of input channels (features after pooling)
    - `no` : number of outputs, e.g. classes
    """
    def __init__(self, ni, no):
        super().__init__()
        self.adaptivepool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(ni, no)

    def forward(self, x):
        # assumes x is (N, ni, H, W)
        pooled = self.adaptivepool(x)               # (N, ni, 1, 1)
        flat = pooled.view(pooled.size(0), -1)     # (N, ni)
        return self.fc(flat)                        # (N, no)

## Helper functions

In [None]:
#export
def init_cnn(m):
    """
    Initialize `m` and all of its submodules in place, recursively: every non-None `bias`
    is set to 0, and Conv2d/Linear weights get Kaiming-normal initialization.

    Adapted from https://github.com/fastai/fastai/blob/master/fastai/vision/models/xresnet.py
    """
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    for child in m.children():
        init_cnn(child)


In [None]:
#export
def num_params(net:nn.Module):
    """
    Number of parameters of a neural network.

    Each parameter tensor contributes `numel()` elements. A parameter shared by several
    submodules is counted once, because `net.parameters()` yields it once.

    Return:
    - a 0-dim int64 `torch.Tensor`, kept as a tensor for compatibility with earlier
      versions; use `int(...)` or `.item()` to get a Python int
    """
    # numel() replaces torch.prod(torch.tensor(p.size())): for a 0-dim parameter that
    # product was a float 1.0, which turned the whole total into a float tensor, and a
    # parameterless net used to return the Python int 0 instead of a tensor.
    total = sum(p.numel() for p in net.parameters())
    return torch.tensor(total)