In [None]:
# default_exp resnetx2

In [None]:
#export
from wong.imports import *
from wong.core import *
from wong.config import cfg, assert_cfg


In [None]:
from fastcore.all import *  # test_eq

# ResNetX2
> a folded ResNet

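`FoldBlock` chains enough units to build one block of the folded net. Its arguments:
1. Unit : unit operator
2. ni : number of input channels for `Unit`
3. fold : folding length
4. stride : stride of the block (whether it crosses a stage boundary or not)
5. `**kwargs` : extra arguments passed to `Unit`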



In [None]:
#export
class FoldBlock(nn.Module):
    "Basic block of folded ResNet"
    def __init__(self, Unit:nn.Module, ni:int, fold:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni = ni
        self.fold = fold
        # `stride` is only recorded here; every unit is built with stride=1.
        self.stride = stride
        # One unit per adjacent pair of inputs, i.e. `fold - 1` units.
        self.units = nn.ModuleList([Unit(ni, stride=1, **kwargs) for _ in range(fold-1)])

    def forward(self, *xs):
        "Add `unit(previous)` onto each following input, then return the list reversed."
        feats = list(xs)
        for k in range(1, self.fold):
            # feats[k-1] has already been updated by the previous iteration.
            feats[k] = feats[k] + self.units[k-1](feats[k-1])
        return feats[::-1]

In [None]:
# Fold block with 4 inputs of 16 channels each; `nh` is forwarded to `mbconv`.
m = FoldBlock(Unit=mbconv, ni=16, fold=4, nh=32)

In [None]:
# Four references to one random tensor (the list repeats the same object).
xs = 4 * [torch.randn(2, 16, 32, 32)]

In [None]:
# Run the fold block on the four inputs; `forward` returns the updated list in reversed order.
xs2 = m(*xs)

In [None]:
# Number of returned feature maps and the shape of each one.
len(xs2), [t.shape for t in xs2]

(4,
 [torch.Size([2, 16, 32, 32]),
  torch.Size([2, 16, 32, 32]),
  torch.Size([2, 16, 32, 32]),
  torch.Size([2, 16, 32, 32])])

In [None]:
#export
class ExpandBlock(nn.Module):
    "Expand block of folded ResNet"
    def __init__(self, Unit:nn.Module, ni:int, fold1:int, fold2:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni = ni
        self.fold1, self.fold2 = fold1, fold2
        self.stride = stride
        # One unit per extra feature map needed to go from `fold1` to `fold2`
        # (none at all when fold2 <= fold1).
        n_new = fold2 - fold1
        self.units = nn.ModuleList([Unit(ni, stride=1, **kwargs) for _ in range(n_new)])
        # The pooling layer only exists for stride 2.
        if stride == 2:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, *xs):
        "Optionally pool every input, then truncate to / grow up to `fold2` feature maps."
        feats = [self.pool(x) for x in xs] if self.stride == 2 else list(xs)
        if self.fold2 <= self.fold1:
            # Shrinking (or equal width): keep the leading `fold2` maps.
            return feats[:self.fold2]
        # Growing: work on the reversed list so new maps can be appended at the end.
        feats = feats[::-1]
        for k in range(self.fold2 - self.fold1):
            newest = feats[-1]
            feats.append(self.units[k](newest) + newest)
        return feats[::-1]

In [None]:
# Expand block that downsamples (stride=2) and narrows from 4 to 3 feature maps.
m = ExpandBlock(Unit=mbconv, ni=16, fold1=4, fold2=3, stride=2, nh=32)

In [None]:
# Apply the expand block to the same inputs: stride=2 pools each map and fold2=3 keeps the first three.
xs2 = m(*xs)

In [None]:
# Number of returned feature maps and the shape of each one.
len(xs2), [t.shape for t in xs2]

(3,
 [torch.Size([2, 16, 16, 16]),
  torch.Size([2, 16, 16, 16]),
  torch.Size([2, 16, 16, 16])])

In [None]:
#export
class ResNetX2(nn.Module):
    """A folded resnet.

    Stem : callable building the stem, invoked as `Stem(c_in, no=ni)`
    Unit : unit operator handed to every `ExpandBlock` / `FoldBlock`
    folds : folding length of each stage (any sequence: tuple or list)
    ni : number of channels produced by the stem and used by every unit
    num_nodes : number of blocks in each stage; its length sets the number of stages
    bottle_scale : `nh = int(ni * bottle_scale)` is passed to every block as `nh`
    first_downsample : if True the first stage also uses stride 2
    tail_all : if True the classifier sees all final feature maps concatenated
               on dim 1, otherwise only the first one
    c_in, c_out : passed to `Stem` and `Classifier` respectively
    **kwargs : forwarded to every block (and from there to `Unit`)
    """
    def __init__(self, Stem, Unit, folds:tuple, ni:int, num_nodes:tuple,
                 bottle_scale:int=1, first_downsample:bool=False, tail_all:bool=True,
                 c_in:int=3, c_out:int=10, **kwargs):
        super(ResNetX2, self).__init__()
        num_stages = len(num_nodes)
        nh = int(ni * bottle_scale)
        # Every stage uses stride 2, except the first one unless `first_downsample` is set.
        strides = [1 if i==0 and not first_downsample else 2 for i in range(num_stages)]
        # Prepend the stem's width (a single feature map). `list(...)` lets callers
        # pass a tuple, as the annotation advertises; `[1] + tuple` would raise TypeError.
        folds = [1] + list(folds)

        self.stem = Stem(c_in, no=ni)

        units = []
        for i, (nu, stride) in enumerate(zip(num_nodes, strides)):
            for j in range(nu):
                if j == 0: # the first node(layer) of each stage
                    # Changes the number of feature maps from the previous stage's
                    # fold to this stage's fold, applying the stage stride.
                    units += [ExpandBlock(Unit, ni, fold1 = folds[i], fold2=folds[i+1], stride=stride, nh=nh, **kwargs)]
                else:
                    units += [FoldBlock(Unit, ni, fold=folds[i+1], stride=1, nh=nh, **kwargs)]

        self.units = nn.ModuleList(units)

        if tail_all:
            # All feature maps of the last stage are concatenated before the classifier.
            self.classifier = Classifier(ni*folds[-1], c_out)
        else:
            self.classifier = Classifier(ni, c_out)
        self.folds = folds
        self.num_nodes = num_nodes
        self.tail_all = tail_all
        # Project helper; presumably initialises the weights of the whole network.
        init_cnn(self)

    def forward(self, x):
        "Run the stem, thread the list of feature maps through every block, then classify."
        x = self.stem(x)
        xs = [x]
        for unit in self.units:
            xs = unit(*xs)
        if self.tail_all:
            x = torch.cat(xs,1)
        else:
            x = xs[0]

        x = self.classifier(x)
        return x
        

In [None]:
# Four stages: 2 blocks in the first, 3 in each of the others; fold length 4 everywhere.
num_nodes = [2, 3, 3, 3]
folds = [4] * 4
model = ResNetX2(
    Stem=conv_bn,
    Unit=mbconv,
    folds=folds,
    ni=64,
    num_nodes=num_nodes,
    bottle_scale=1,
    tail_all=False,
    ks=3,
    c_out=100,
)
# Display the parameter count of the assembled model.
num_params(model)

tensor(228004)

In [None]:
# Random batch of two 3-channel 64x64 inputs for a smoke test of the model.
x = torch.randn((2, 3, 64, 64))

In [None]:
# Forward and backward pass with autograd anomaly detection switched on,
# so a failing backward op reports the forward call that created it.
with torch.autograd.set_detect_anomaly(mode=True):
    out = model(x)
    out.mean().backward()

## Misc

In [None]:
class InitBlock(nn.Module):
    "Init block of folded ResNet"
    def __init__(self, Unit:nn.Module, ni:int, fold:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni = ni
        self.fold = fold
        # `fold - 1` units, each built with the requested stride.
        self.units = nn.ModuleList([Unit(ni, stride=stride, **kwargs) for _ in range(fold-1)])

    def forward(self, x):
        "Expand a single input into `fold` feature maps, returned with the latest map first."
        feats = [x]
        for k in range(self.fold-1):
            prev = feats[k]
            # Each new map is the previous one plus its unit's output.
            feats.append(prev + self.units[k](prev))
        return feats[::-1]

In [None]:
# Build an init block with fold length 4 and expand one input into four feature maps.
m = InitBlock(Unit=mbconv, ni=16, fold=4, stride=1, nh=32)
x = torch.randn((2, 16, 32, 32))
xs = m(x)
# Number of feature maps and the shape of the first one.
len(xs), xs[0].shape

(4, torch.Size([2, 16, 32, 32]))