In [None]:
# default_exp foldnet

In [None]:
#export
from wong.imports import *
from wong.core import *
from wong.config import cfg, assert_cfg


In [None]:
from fastcore.all import *  # test_eq

# FoldNet
> a folded ResNet

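Each block aggregates enough units to form one step of the folded net. Block arguments:
1. `Unit` : the unit operator (module class) instantiated inside the block
2. `ni` : number of input channels for `Unit`
3. `fold` : folding length
4. `stride` : stride of the block; 2 when the block crosses a stage boundary and downsamples, otherwise 1
5. `**kwargs` : extra arguments passed on to `Unit`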



In [None]:
# #export
# class FoldBlock(nn.Module):
#     "Basic block of folded ResNet"
#     def __init__(self, Unit:nn.Module, ni:int, fold:int, stride:int=1, **kwargs):
#         super(FoldBlock, self).__init__()
#         self.ni, self.fold, self.stride = ni, fold, stride
#         units = []
#         for i in range(max(1,fold-1)):
#             units += [Unit(ni, stride=1, **kwargs)]
#         self.units = nn.ModuleList(units)
        
#     def forward(self, *xs):
#         xs = list(xs)
#         if self.fold==1:
#             xs[0] = xs[0] + self.units[0](xs[0])
#             return xs
#         for i in range(self.fold-1):
#             xs[i+1] = xs[i] + self.units[i](xs[i+1])
#         xs.reverse()
#         return xs

In [None]:
#export
class FoldBlock(nn.Module):
    """Basic block of folded ResNet.

    Owns `max(1, fold-1)` instances of `Unit`, each built as
    `Unit(ni, stride=1, **kwargs)`. `stride` is only stored on the block; it is
    never forwarded to the units.

    `forward(*xs)` takes the folded tensors as positional arguments and returns a
    list:
    - `fold == 1`: a plain residual step on the first tensor,
      `xs[0] + units[0](xs[0])`; any further tensors are returned untouched and
      the order is kept.
    - otherwise: for `k = 1 .. fold-1`, `xs[k] = xs[k] + units[k-1](xs[k-1])`
      (each step consumes the freshly updated predecessor), after which the list
      is returned in reversed order.
    """
    def __init__(self, Unit:nn.Module, ni:int, fold:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni = ni
        self.fold = fold
        self.stride = stride
        # a fold of 1 still needs one unit for the plain residual step
        n_units = max(1, fold - 1)
        self.units = nn.ModuleList([Unit(ni, stride=1, **kwargs) for _ in range(n_units)])

    def forward(self, *xs):
        feats = list(xs)
        if self.fold == 1:
            head = feats[0]
            feats[0] = head + self.units[0](head)
            return feats
        # walk left to right: every tensor absorbs the unit output of its neighbour
        for k in range(1, self.fold):
            feats[k] = feats[k] + self.units[k - 1](feats[k - 1])
        return feats[::-1]

In [None]:
# Build a 4-fold block of `mbconv` units on 16 channels; `nh=32` reaches every unit through **kwargs.
m = FoldBlock(mbconv, 16, 4, nh=32)

# NOTE: list repetition does not copy, so `xs` holds four references to the same random tensor.
xs = [torch.randn(2,16,32,32)] * 4

# The block takes the folded tensors as separate positional arguments.
xs2 = m(*xs)

# Number of returned tensors and the shape of each one.
len(xs2), [x.shape for x in xs2]

(4,
 [torch.Size([2, 16, 32, 32]),
  torch.Size([2, 16, 32, 32]),
  torch.Size([2, 16, 32, 32]),
  torch.Size([2, 16, 32, 32])])

In [None]:
#export
class ExpandBlock(nn.Module):
    """Expand block of folded ResNet.

    Changes the number of folded tensors from `fold1` to `fold2` and, when
    `stride == 2`, halves the spatial size with a 3x3 max-pool (stride 2, padding 1).

    - `Unit`     : unit operator; built as `Unit(ni, stride=1, **kwargs)`
    - `ni`       : number of channels handed to `Unit`
    - `fold1`    : folding length expected on input
    - `fold2`    : folding length produced on output
    - `stride`   : 2 enables the max-pool; any other value leaves the spatial size alone
    - `**kwargs` : passed on to every `Unit`

    `forward(*xs)` returns a list:
    - `fold2 <= fold1`: the first `fold2` inputs (pooled if `stride == 2`); the
      block owns no units in this case.
    - `fold2 > fold1` : `fold2 - fold1` new tensors are put in front of the
      (pooled) inputs; each new tensor is `unit(t) + t`, where `t` is the current
      front tensor.
    """
    def __init__(self, Unit:nn.Module, ni:int, fold1:int, fold2:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni, self.fold1, self.fold2, self.stride = ni, fold1, fold2, stride
        # range() is empty when fold2 <= fold1, so a shrinking block holds no units
        self.units = nn.ModuleList([Unit(ni, stride=1, **kwargs) for _ in range(fold2 - fold1)])
        if stride == 2:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, *xs):
        xs = list(xs)
        shrinking = self.fold2 <= self.fold1
        if shrinking:
            # discard the surplus tensors *before* pooling so they are never processed
            xs = xs[:self.fold2]
        if self.stride == 2:
            xs = [self.pool(x) for x in xs]
        if shrinking:
            return xs
        # grow at the tail of the reversed list, i.e. at the front of the result
        xs.reverse()
        for unit in self.units:
            xs.append(unit(xs[-1]) + xs[-1])
        xs.reverse()
        return xs

In [None]:
# Shrinking configuration (fold 4 -> 3) combined with stride-2 downsampling.
m = ExpandBlock(mbconv, 16, fold1=4, fold2=3, stride=2, nh=32)

In [None]:
# Reuses the four-tensor list `xs` created in the FoldBlock demo cell.
xs2 = m(*xs)

In [None]:
# Number of returned tensors and the shape of each one.
len(xs2), [x.shape for x in xs2]

(3,
 [torch.Size([2, 16, 16, 16]),
  torch.Size([2, 16, 16, 16]),
  torch.Size([2, 16, 16, 16])])

In [None]:
#export
class FoldNet(nn.Module):
    """A folded resnet, using Expand.

    - `Stem`             : stem factory, called as `Stem(c_in, no=ni)`
    - `Unit`             : unit operator handed to every `ExpandBlock` / `FoldBlock`
    - `folds`            : folding length of each stage (any sequence: tuple or list)
    - `ni`               : number of channels used throughout the backbone
    - `num_nodes`        : number of blocks in each stage
    - `bottle_scale`     : the units receive `nh = int(ni * bottle_scale)`
    - `first_downsample` : if true the first stage also uses stride 2
    - `tail_all`         : classify on the channel-wise concatenation of all folded
                           tensors (`True`) or only on the first one (`False`)
    - `c_in`, `c_out`    : input channels of the stem / output size of the classifier
    - `**kwargs`         : passed on to every block (and from there to `Unit`)

    The first block of every stage is an `ExpandBlock` going from the previous
    stage's folding length (1 in front of the first stage) to the current one;
    the remaining blocks of the stage are `FoldBlock`s.
    """
    def __init__(self, Stem, Unit, folds:tuple, ni:int, num_nodes:tuple,
                 bottle_scale:int=1, first_downsample:bool=False, tail_all:bool=True,
                 c_in:int=3, c_out:int=10, **kwargs):
        super().__init__()
        num_stages = len(num_nodes)
        nh = int(ni * bottle_scale)
        strides = [1 if i==0 and not first_downsample else 2 for i in range(num_stages)]
        # list(...) so that a tuple (as the annotation suggests) works too;
        # the leading 1 is the folding length of the stem output
        folds = [1] + list(folds)

        self.stem = Stem(c_in, no=ni)

        units = []
        for i, (nu, stride) in enumerate(zip(num_nodes, strides)):
            for j in range(nu):
                if j == 0: # the first node(layer) of each stage
                    units += [ExpandBlock(Unit, ni, fold1=folds[i], fold2=folds[i+1], stride=stride, nh=nh, **kwargs)]
                else:
                    units += [FoldBlock(Unit, ni, fold=folds[i+1], stride=1, nh=nh, **kwargs)]
        self.units = nn.ModuleList(units)

        # with tail_all the classifier sees all folded tensors concatenated on dim 1
        if tail_all:
            self.classifier = Classifier(ni*folds[-1], c_out)
        else:
            self.classifier = Classifier(ni, c_out)
        self.folds = folds
        self.num_nodes = num_nodes
        self.tail_all = tail_all
        init_cnn(self)

    def forward(self, x):
        xs = [self.stem(x)]
        for unit in self.units:
            xs = unit(*xs)
        x = torch.cat(xs, 1) if self.tail_all else xs[0]
        return self.classifier(x)
        

In [None]:
# Four stages with 8/9/9/9 blocks and a folding length of 4 in every stage.
num_nodes = [8,9,9,9]
folds = [4,4,4,4]
# `ks=3` is not a FoldNet argument; it travels through **kwargs to the blocks.
model = FoldNet(Stem=conv_bn, Unit=mbconv, folds=folds, ni=64, num_nodes=num_nodes, bottle_scale=1, tail_all=True, ks=3, c_out=100)
# `num_params` is imported from elsewhere; presumably the total parameter count of the model.
num_params(model)

tensor(906148)

In [None]:
# Random input batch; dim 1 has size 3, matching FoldNet's default `c_in=3`.
x = torch.randn(2,3,64,64)

In [None]:
# One forward and one backward pass with autograd anomaly detection switched on.
with torch.autograd.set_detect_anomaly(True):
    out = model(x)
    out.mean().backward()

## Calculate number of params in FoldNet

In [None]:
#export
def num_units(folds, nodes):
    """Calculate the number of all units in the backbone of FoldNet.

    - `folds` : folding length of each stage
    - `nodes` : number of blocks of each stage

    Mirrors how `FoldNet` assembles a stage: one `ExpandBlock` going from the
    previous folding length (1 in front of the first stage) to the current one,
    followed by `nodes[i] - 1` `FoldBlock`s.
    - an `ExpandBlock` holds `max(0, fold - previous_fold)` units
    - a `FoldBlock` holds `max(1, fold - 1)` units (a fold of 1 still owns one unit)
    - a stage with no blocks contributes nothing

    Stages are paired with `zip`, so surplus entries of the longer sequence are ignored.
    """
    total = 0
    prev_fold = 1  # folding length in front of the first stage
    for fold, n_blocks in zip(folds, nodes):
        if n_blocks > 0:
            total += max(0, fold - prev_fold)           # ExpandBlock opening the stage
            total += (n_blocks - 1) * max(1, fold - 1)  # FoldBlocks filling the stage
        prev_fold = fold
    return total

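Suppose the folding length of stage $i$ is $d_i \ge 2$; then every `FoldBlock` of that stage holds $d_i-1$ units. Suppose stage $i$ has $b_i$ blocks. Each stage starts with an `ExpandBlock`, which holds $\max(0, d_i-d_{i-1})$ units (with a folding length of $1$ in front of the first stage, so the first `ExpandBlock` holds $d_0-1$ units), followed by $b_i-1$ `FoldBlock`s. The first stage therefore has $(d_0-1) * b_0$ units, and every later stage has $(d_i-1) * (b_i-1) + \max(0, d_i-d_{i-1})$ units; the last term is zero when the stage keeps (or reduces) the folding length of the previous stage, because such an `ExpandBlock` holds no unit.

Suppose $n$ stages exist; then the number of all units in the backbone of FoldNet is:
\begin{equation}
(d_0-1) * b_0 + \sum_{i=1}^{n-1} \left[ (d_i-1) * (b_i-1) + \max(0,\, d_i-d_{i-1}) \right]
\end{equation}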

In [None]:
#export
def cal_num_params(Stem, Unit, folds, nodes, ni, bottle_scale, tail_all, c_out, c_in=3, **kwargs):
    """Calculate the number of all params of FoldNet, according to hyper-params.

    Builds one stem, one unit and one classifier, then returns
    `params(stem) + params(unit) * num_units(folds, nodes) + params(classifier)`,
    i.e. every unit of the backbone is taken to have the same parameter count.

    - `Stem`, `Unit`  : the factories that would be given to `FoldNet`
    - `folds`, `nodes`: folding length / number of blocks of each stage
    - `ni`            : number of backbone channels
    - `bottle_scale`  : the unit receives `nh = int(ni * bottle_scale)`, as in `FoldNet`
    - `tail_all`      : classifier input is `ni*folds[-1]` if true, else `ni`
    - `c_out`         : output size of the classifier
    - `c_in`          : input channels of the stem (default 3)
    - `**kwargs`      : extra arguments for `Unit` (e.g. `ks`), as `FoldNet` would pass on
    """
    nh = int(ni * bottle_scale)  # same conversion as FoldNet.__init__
    stem = Stem(c_in, no=ni)
    unit = Unit(ni, nh=nh, **kwargs)
    head = Classifier(ni*folds[-1] if tail_all else ni, c_out)
    return num_params(stem) + num_params(unit) * num_units(folds, nodes) + num_params(head)

In [None]:
#folds = [fold]*4
# Same hyper-params as the FoldNet built above (reuses `folds` and `num_nodes`), so the two counts can be compared.
cal_num_params(Stem=conv_bn, Unit=mbconv, folds=folds, nodes=num_nodes, ni=64, bottle_scale=1, tail_all=True, c_out=100)

tensor(906148)

## Misc

In [None]:
#hide
class InitBlock(nn.Module):
    """Init block of folded ResNet.

    Turns a single tensor into `fold` tensors. It owns `fold-1` instances of
    `Unit`, each built as `Unit(ni, stride=stride, **kwargs)`.

    `forward(x)` starts from `[x]` and repeatedly appends `t + unit(t)`, where `t`
    is the most recently produced tensor; the list is returned newest first.
    """
    def __init__(self, Unit:nn.Module, ni:int, fold:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni = ni
        self.fold = fold
        self.units = nn.ModuleList([Unit(ni, stride=stride, **kwargs) for _ in range(fold - 1)])

    def forward(self, x):
        chain = [x]
        for step in range(self.fold - 1):
            latest = chain[-1]
            chain.append(latest + self.units[step](latest))
        return chain[::-1]

In [None]:
#hide
# Expand a single 16-channel tensor into 4 folded tensors.
m = InitBlock(mbconv, 16, 4, stride=1, nh=32)
x = torch.randn(2,16,32,32)
# NOTE: rebinds `xs`; the TransitionBlock demo further down feeds on this list.
xs = m(x)
len(xs), xs[0].shape

(4, torch.Size([2, 16, 32, 32]))

In [None]:
#hide
# try inner fold (fold before BN), may fail
class FoldBlock2(nn.Module):
    """Basic block of folded ResNet (experimental "inner fold" variant).

    Owns `fold-1` units, each built as `Unit(ni, stride=1, **kwargs)`, and as many
    `conv_bn(ni, ks=1, zero_bn=False)` aggregation modules. `stride` is only stored.

    `forward(*xs)` works in two passes and returns the tensors in reversed order:
    1. for `k = 0 .. fold-2`: `xs[k+1] = xs[k+1] + units[k](xs[k])`
    2. every updated tensor `xs[k+1]` is sent through `aggregates[k]`
    The passes are kept separate on purpose: the units of pass 1 read the sums
    *before* any aggregation is applied.
    """
    def __init__(self, Unit:nn.Module, ni:int, fold:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni = ni
        self.fold = fold
        self.stride = stride
        # create unit and aggregate in alternation, one pair per folded step
        pairs = [(Unit(ni, stride=1, **kwargs), conv_bn(ni, ks=1, zero_bn=False))
                 for _ in range(fold - 1)]
        self.units = nn.ModuleList([unit for unit, _ in pairs])
        self.aggregates = nn.ModuleList([agg for _, agg in pairs])

    def forward(self, *xs):
        feats = list(xs)
        for k, unit in enumerate(self.units):
            feats[k + 1] = feats[k + 1] + unit(feats[k])
        for pos, aggregate in enumerate(self.aggregates, start=1):
            feats[pos] = aggregate(feats[pos])
        return feats[::-1]

In [None]:
#hide
# try particular transition block, may fail
class TransitionBlock(nn.Module):
    """Transition block of folded ResNet.

    Moves the folded tensors from `ni` to `no` channels. It owns
    - `fold-1` units: the first is `Unit(ni, no=no, stride=stride, **kwargs)`, the
      others are `Unit(ni=no, no=no, stride=1, **kwargs)`;
    - `fold-1` shortcut modules `IdentityMappingMaxPool(ni, no=no, stride=stride)`;
    - one more such shortcut, `idmapping0`, for the first tensor.

    `forward(*xs)`: for `k = 0 .. fold-2`,
    `xs[k+1] = idmappings[k](xs[k+1]) + units[k](xs[k])`; afterwards `xs[0]` goes
    through `idmapping0` and the list is returned in reversed order.
    """
    def __init__(self, Unit:nn.Module, ni:int, no:int, fold:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni = ni
        self.no = no
        self.fold = fold
        self.stride = stride
        unit_list, shortcut_list = [], []
        for k in range(fold - 1):
            # only the first unit changes the channel count and carries the stride
            if k == 0:
                unit = Unit(ni, no=no, stride=stride, **kwargs)
            else:
                unit = Unit(ni=no, no=no, stride=1, **kwargs)
            unit_list.append(unit)
            shortcut_list.append(IdentityMappingMaxPool(ni, no=no, stride=stride))
        self.units = nn.ModuleList(unit_list)
        self.idmappings = nn.ModuleList(shortcut_list)
        self.idmapping0 = IdentityMappingMaxPool(ni, no=no, stride=stride)

    def forward(self, *xs):
        feats = list(xs)
        for k, (shortcut, unit) in enumerate(zip(self.idmappings, self.units)):
            feats[k + 1] = shortcut(feats[k + 1]) + unit(feats[k])
        feats[0] = self.idmapping0(feats[0])
        return feats[::-1]

In [None]:
#hide
# Transition from 16 to 32 channels with stride 2 on a 4-fold input.
m = TransitionBlock(mbconv, ni=16, no=32, fold=4, stride=2, nh=32)

# `xs` is the 4-tensor list produced by the InitBlock demo cell.
# NOTE: only the last expression of a cell is displayed, so this line and the next tuple show nothing.
[x.shape for x in xs]

xs2 = m(*xs)

len(xs2), [x.shape for x in xs2]

# A TransitionBlock is a separate class and not a kind of ExpandBlock.
isinstance(m, TransitionBlock), isinstance(m, ExpandBlock)

(True, False)

In [None]:
#hide
# try particular transition block, may fail
class ResNetXTransition(nn.Module):
    """A folded resnet using Transition.

    - `Stem`             : stem factory, called as `Stem(c_in, no=nis[0])`
    - `Unit`             : unit operator handed to every block
    - `fold`             : folding length, the same for every stage
    - `nis`              : number of channels of each stage
    - `num_nodes`        : number of `FoldBlock`s of each stage
    - `bottle_scale`     : stage `i` uses `nh = int(nis[i] * bottle_scale)`
    - `first_downsample` : if true the first stage also uses stride 2
    - `tail_all`         : classify on the concatenation of all folded tensors
                           (`True`) or only on the first one (`False`)
    - `c_in`, `c_out`    : input channels of the stem / output size of the classifier
    - `**kwargs`         : passed on to every block

    Layout: stem, one `ExpandBlock` (1 -> `fold`), then for every stage its
    `FoldBlock`s; every stage but the first is preceded by a `TransitionBlock`
    that switches from the previous stage's channels to the current ones.
    """
    def __init__(self, Stem, Unit, fold:int, nis:tuple, num_nodes:tuple,
                 bottle_scale:int=1, first_downsample:bool=False, tail_all:bool=True,
                 c_in:int=3, c_out:int=10, **kwargs):
        super().__init__()
        nhs = [int(channels * bottle_scale) for channels in nis]
        # stride 2 everywhere, except for the first stage unless it must downsample too
        strides = [2 if (first_downsample or i != 0) else 1 for i in range(len(num_nodes))]

        self.stem = Stem(c_in, no=nis[0])
        self.expand = ExpandBlock(Unit, ni=nis[0], fold1=1, fold2=fold, stride=strides[0], nh=nhs[0], **kwargs)

        blocks = []
        for stage, (n_blocks, stride) in enumerate(zip(num_nodes, strides)):
            if stage != 0:
                # the transition is configured with the previous stage's channels and hidden size
                blocks.append(TransitionBlock(Unit, ni=nis[stage-1], no=nis[stage], fold=fold,
                                              stride=stride, nh=nhs[stage-1], **kwargs))
            for _ in range(n_blocks):
                blocks.append(FoldBlock(Unit, ni=nis[stage], fold=fold, stride=1, nh=nhs[stage], **kwargs))
        self.units = nn.ModuleList(blocks)

        head_channels = nis[-1]*fold if tail_all else nis[-1]
        self.classifier = Classifier(head_channels, c_out)
        self.fold = fold
        self.num_nodes = num_nodes
        self.tail_all = tail_all
        init_cnn(self)

    def forward(self, x):
        feats = self.expand(self.stem(x))
        for block in self.units:
            feats = block(*feats)
        merged = torch.cat(feats, 1) if self.tail_all else feats[0]
        return self.classifier(merged)
        

In [None]:
#hide
# NOTE: rebinds `num_nodes` and `model` from the FoldNet demo cells above.
num_nodes = [4,4,4,4]
fold = 4
# One channel width per stage; the hidden width of stage i is int(nis[i] * bottle_scale).
nis = [32,64,96,128]
bottle_scale = 6
model = ResNetXTransition(Stem=conv_bn, Unit=mbconv, fold=fold, nis=nis, num_nodes=num_nodes,
                          bottle_scale=bottle_scale, tail_all=False, ks=3, c_out=100)
num_params(model)

tensor(5620548)