In [1]:
import sys
sys.path.append('zero')

In [2]:
#export
from imports import *
from core import *

### Residual Network

In [3]:
#export
class ResStage(nn.Module):
    """
    One stage of a residual network, i.e. a run of residual units that is
    usually grouped together because they work at the same feature (image)
    resolution.

    The first unit maps ``ni -> no`` channels with the given ``stride`` and is
    summed with ``IdentityMapping(ni, no, stride=stride)`` applied to the same
    input. Each of the remaining ``nu - 1`` units maps ``no -> no`` at stride 1
    and is summed with its own input.

    Parameters:
    -----------
    ni : number of input channels of the stage
    no : number of output channels of the stage
    nh : number of hidden channels of the basic units in the stage
    nu : number of basic units in the stage (the first unit is always built,
        ``nu - 1`` further units follow it)
    stride : stride size of the conv op in the first unit
    Unit : class of the basic unit, Unit class has calling format:
        Unit(ni:int, no:int, nh:int, stride:int=1, **kwargs)
    kwargs : extra keyword arguments forwarded to every Unit call

    """
    def __init__(self, ni:int, no:int, nh:int, nu:int, stride:int, Unit:nn.Module, **kwargs):
        super().__init__()
        # Entry unit: the only one that may change channel count / resolution.
        self.unit0 = Unit(ni, no, nh, stride=stride, **kwargs)
        # Shortcut branch matching the entry unit's channel and stride change.
        self.idmapping0 = IdentityMapping(ni, no, stride=stride)
        # Remaining units keep the shape, so they can use a plain skip connection.
        self.units = nn.ModuleList(
            [Unit(no, no, nh, stride=1, **kwargs) for _ in range(nu - 1)]
        )

    def forward(self, x):
        out = self.unit0(x) + self.idmapping0(x)
        for block in self.units:
            out = block(out) + out
        return out
        

In [4]:
#export
def resnet_stem(ni:int=3, nf:int=64):
    """Stem stage in resnet.

    Chains the child layers of three ``conv_bn_relu`` blocks
    (``ni -> 32`` with stride 2, ``32 -> 32`` and ``32 -> nf`` with stride 1)
    followed by a 3x3 max-pool with stride 2, all flattened into a single
    ``nn.Sequential``.

    ni : number of input channels
    nf : number of channels produced by the last conv block
    """
    # (in_channels, out_channels, stride) of each conv block, in order
    conv_specs = ((ni, 32, 2),   # downsample
                  (32, 32, 1),
                  (32, nf, 1))
    layers = []
    for c_from, c_to, step in conv_specs:
        # flatten: keep the block's child layers, not the block container itself
        layers.extend(conv_bn_relu(c_from, c_to, stride=step).children())
    layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))  # downsample
    return nn.Sequential(*layers)

In [5]:
#export
class ResNet(nn.Sequential):
    """
    Residual Network

    Assembled as ``stem -> stage 0 -> ... -> stage k -> classifier`` and then
    passed through ``init_cnn``.

    Parameters:
    -----------
    nhs : number of hidden channels for all stages; ``nhs[0]`` is also used as
        the output width of the stem (and so as the input width of stage 0)
    nos : number of output channels for all stages; each stage takes the
        previous stage's output width as its input width
    nus : number of units of all stages
    strides : stride sizes of all stages
    Stage : class of the stages, Stage class has calling format:
        Stage(ni:int, no:int, nh:int, nu:int, stride:int, Unit:nn.Module, **kwargs)
    Unit : class of the basic units
    c_in : number of input channels of the whole network
    c_out : number of output channels (features) of the whole network
    kwargs : additional args to Unit class
    """
    def __init__(self, nhs, nos, nus, strides, Stage:nn.Module, Unit:nn.Module,
                 c_in:int=3, c_out:int=1000, **kwargs):
        # All sub-modules are built as plain locals first, so the nn.Sequential
        # base only needs to be initialised once, with the final layer list.
        stem = resnet_stem(c_in, nhs[0])
        stages = []
        ni = nhs[0]
        for i, nh in enumerate(nhs):
            stages.append(Stage(ni, nos[i], nh, nus[i], strides[i], Unit, **kwargs))
            ni = nos[i]  # the next stage consumes this stage's output
        classifier = Classifier(nos[-1], c_out)
        super().__init__(
            stem,
            *stages,
            classifier
        )
        init_cnn(self)
        
        
def resnet50(c_in:int=3, c_out:int=1000):
    """Build a ``ResNet`` of ``ResStage`` stages with (3, 4, 6, 3) ``resnet_bottleneck`` units.

    c_in : number of input channels of the network
    c_out : number of output features of the network
    """
    stage_config = dict(
        nhs=[64, 128, 256, 512],
        nos=[256, 512, 1024, 2048],
        nus=[3, 4, 6, 3],
        strides=[1, 2, 2, 2],
    )
    return ResNet(Stage=ResStage, Unit=resnet_bottleneck,
                  c_in=c_in, c_out=c_out, **stage_config)

def resnet101(c_in:int=3, c_out:int=1000):
    """Build a ``ResNet`` of ``ResStage`` stages with (3, 4, 23, 3) ``resnet_bottleneck`` units.

    c_in : number of input channels of the network
    c_out : number of output features of the network
    """
    hidden_channels = [64, 128, 256, 512]
    output_channels = [256, 512, 1024, 2048]
    units_per_stage = [3, 4, 23, 3]
    stage_strides = [1, 2, 2, 2]
    return ResNet(
        nhs=hidden_channels,
        nos=output_channels,
        nus=units_per_stage,
        strides=stage_strides,
        Stage=ResStage,
        Unit=resnet_bottleneck,
        c_in=c_in,
        c_out=c_out,
    )

In [6]:
# Build a ResNet-50 with the default arguments (c_in=3, c_out=1000)
model = resnet50()

In [7]:
# Display the model's module tree as the cell output
model

ResNet(
  (0): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (1): ResStage(
    (unit0): Sequential(
      (0): ReLU(inplace)
      (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace)
      (4): Conv2d(64, 64, kernel_

### Our folded stage classes

In [None]:
#export
class DualStage(nn.Module):
    """
    Stage in a residual network, usually the units in a residual network are divided into
    stages according to feature (image) resolution.

    The first unit maps ``ni -> no`` channels with the given ``stride`` and is
    summed with ``IdentityMapping(ni, no, stride=stride)`` applied to the same
    input. Each of the remaining ``nu - 1`` units maps ``no -> no`` at stride 1
    and is summed with its own input.

    NOTE(review): the structure and forward pass are the same as ``ResStage``;
    presumably this is the starting point for a "folded" variant -- confirm.

    Parameters:
    -----------
    ni : number of input channels of the stage
    no : number of output channels of the stage
    nh : number of hidden channels of the basic units in the stage
    nu : number of basic units in the stage (the first unit is always built,
        ``nu - 1`` further units follow it)
    stride : stride size of the conv op in the first unit
    Unit : class of the basic unit, Unit class has calling format:
        Unit(ni:int, no:int, nh:int, stride:int=1, **kwargs)
    kwargs : extra keyword arguments forwarded to every Unit call

    """
    def __init__(self, ni:int, no:int, nh:int, nu:int, stride:int, Unit:nn.Module, **kwargs):
        # Zero-argument super(): naming another class here (e.g. ResStage)
        # raises TypeError because DualStage does not inherit from it.
        super().__init__()
        # the first unit, stride size determines whether it downsamples or not
        self.unit0 = Unit(ni, no, nh, stride=stride, **kwargs)
        self.idmapping0 = IdentityMapping(ni, no, stride=stride)
        # the remaining units keep channel count and stride fixed
        self.units = nn.ModuleList(
            [Unit(no, no, nh, stride=1, **kwargs) for _ in range(nu - 1)]
        )

    def forward(self, x):
        x = self.unit0(x) + self.idmapping0(x)
        for unit in self.units:
            x = unit(x) + x
        return x
        