In [None]:
#default_exp lookbacknet

In [None]:
#export
from wong.imports import *
from wong.core import *


In [None]:
#export
class LookbackBlock(nn.Module):
    """Basic block of lookback ResNet.

    Owns `fold` units, each built as `Unit(ni, stride=1, **kwargs)`. In `forward`,
    branch `i` is replaced by itself plus unit `i` applied to branch `i-1`; for
    `i == 0` the index `-1` selects the last tensor that was passed in.
    """
    def __init__(self, Unit:nn.Module, ni:int, fold:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni = ni
        self.fold = fold
        self.stride = stride  # only recorded; every unit is created with stride=1
        self.units = nn.ModuleList([Unit(ni, stride=1, **kwargs) for _ in range(fold)])

    def forward(self, *xs):
        "Return `xs` as a list whose first `fold` entries were updated in order."
        feats = list(xs)
        # Updates are sequential: branch `idx` sees the already-updated branch `idx-1`.
        for idx, unit in enumerate(self.units):
            feats[idx] = feats[idx] + unit(feats[idx - 1])
        return feats

In [None]:
# Smoke test: build a LookbackBlock with `fold` branches and inspect its outputs.
fold=2
m = LookbackBlock(mbconv, ni=16, fold=fold, nh=32)

# NOTE: list multiplication repeats the *same* tensor object `fold` times.
xs = [torch.randn(2,16,32,32)] * fold

# The block takes its inputs as separate positional arguments.
xs2 = m(*xs)

# Number of returned tensors and the shape of each one.
len(xs2), [x.shape for x in xs2]

(2, [torch.Size([2, 16, 32, 32]), torch.Size([2, 16, 32, 32])])

In [None]:
#export
class ExpandBlock(nn.Module):
    """Expand block of lookback ResNet.

    Changes the number of feature maps from `fold1` to `fold2`. When `stride == 2`
    every input is first passed through a 3x3, stride-2 max pool. If `fold2` is not
    larger than `fold1` the inputs are truncated to the first `fold2`; otherwise
    `fold2 - fold1` new maps are appended, each computed from the current last map.
    """
    def __init__(self, Unit:nn.Module, ni:int, fold1:int, fold2:int, stride:int=1, **kwargs):
        super().__init__()
        self.ni = ni
        self.fold1 = fold1
        self.fold2 = fold2
        self.stride = stride
        # One unit per map to add; the range is empty when fold2 <= fold1.
        self.units = nn.ModuleList([Unit(ni, stride=1, **kwargs) for _ in range(fold2 - fold1)])
        if stride == 2:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, *xs):
        "Return a list of feature maps after optional pooling and expansion/truncation."
        if self.stride == 2:
            feats = [self.pool(x) for x in xs]
        else:
            feats = list(xs)
        if self.fold2 <= self.fold1:
            return feats[:self.fold2]
        for unit in self.units:
            last = feats[-1]
            # Residual step: each new map is unit(last) + last.
            feats.append(unit(last) + last)
        return feats

In [None]:
# Expand from `fold` inputs to 3 outputs, with stride=2.
m = ExpandBlock(mbconv, ni=16, fold1=fold, fold2=3, stride=2, nh=32)

In [None]:
# Run the block on the `xs` list, unpacked into positional arguments.
xs2 = m(*xs)

In [None]:
# Number of returned tensors and the shape of each one.
len(xs2), [x.shape for x in xs2]

(3,
 [torch.Size([2, 16, 16, 16]),
  torch.Size([2, 16, 16, 16]),
  torch.Size([2, 16, 16, 16])])

In [None]:
#export
class LookbackNet(nn.Module):
    """A lookback resnet, using Expand.

    The network is `stem -> units -> classifier`. Each stage `i` contributes
    `num_nodes[i]` blocks: the first is an `ExpandBlock` going from `folds[i-1]`
    feature maps (1 for the first stage) to `folds[i]`, the rest are
    `LookbackBlock`s operating on `folds[i]` maps.

    Args:
        Stem: factory called as `Stem(c_in, no=ni)` to build the stem module.
        Unit: factory handed to `ExpandBlock` / `LookbackBlock` to build their units.
        folds: number of feature maps per stage; any sequence (list or tuple)
            with at least `len(num_nodes)` entries.
        ni: channel count of every feature map.
        num_nodes: number of blocks in each stage.
        bottle_scale: `nh = int(ni * bottle_scale)` is forwarded to the blocks as `nh`.
        first_downsample: if False the first stage uses stride 1; every other
            stage always uses stride 2.
        tail_all: if True the classifier receives all final feature maps concatenated
            along dim 1, otherwise only the first one.
        c_in: number of input channels given to `Stem`.
        c_out: output size given to `Classifier`.
        **kwargs: forwarded unchanged to every block.
    """
    def __init__(self, Stem, Unit, folds:tuple, ni:int, num_nodes:tuple,
                 bottle_scale:int=1, first_downsample:bool=False, tail_all:bool=True,
                 c_in:int=3, c_out:int=10, **kwargs):
        super().__init__()
        num_stages = len(num_nodes)
        nh = int(ni * bottle_scale)
        strides = [1 if i==0 and not first_downsample else 2 for i in range(num_stages)]
        # `list(...)` so that a tuple (as annotated) works too; `[1] + tuple` raises
        # TypeError. The leading 1 is the single feature map produced by the stem.
        folds = [1] + list(folds)

        self.stem = Stem(c_in, no=ni)

        units = []
        for i, (nu, stride) in enumerate(zip(num_nodes, strides)):
            for j in range(nu):
                if j == 0: # the first node(layer) of each stage
                    units += [ExpandBlock(Unit, ni, fold1 = folds[i], fold2=folds[i+1], stride=stride, nh=nh, **kwargs)]
                else:
                    units += [LookbackBlock(Unit, ni, fold=folds[i+1], stride=1, nh=nh, **kwargs)]

        self.units = nn.ModuleList(units)

        if tail_all:
            # Width of the concatenated output of the last stage. Indexing by
            # `num_stages` (rather than -1) stays correct if `folds` has extra entries.
            self.classifier = Classifier(ni*folds[num_stages], c_out)
        else:
            self.classifier = Classifier(ni, c_out)
        self.folds = folds
        self.num_nodes = num_nodes
        self.tail_all = tail_all
        init_cnn(self)

    def forward(self, x):
        "Run the stem, thread the list of feature maps through all blocks, then classify."
        x = self.stem(x)
        xs = [x]
        for unit in self.units:
            xs = unit(*xs)
        if self.tail_all:
            x = torch.cat(xs,1)
        else:
            x = xs[0]

        x = self.classifier(x)
        return x
        

In [None]:
# Four stages: blocks per stage and feature maps (folds) per stage.
num_nodes = [8,9,9,9]
folds = [4,4,4,4]
ni = 64
bottle_scale = 1
# `ks` and `zero_bn` are not LookbackNet parameters; they travel through **kwargs to the blocks.
model = LookbackNet(Stem=conv_bn, Unit=mbconv, folds=folds, ni=ni, num_nodes=num_nodes, 
                bottle_scale=bottle_scale, tail_all=True, ks=3, c_out=100, zero_bn=True)
# presumably reports the model's parameter count (project helper) — confirm
num_params(model)

tensor(1189860)

In [None]:
# Random input of shape (2, 3, 64, 64); the size-3 dim matches the model's default c_in=3.
x = torch.randn(2,3,64,64)

In [None]:
# One forward and backward pass with autograd anomaly detection switched on.
with torch.autograd.set_detect_anomaly(True):
    out = model(x)
    out.mean().backward()

In [None]:
# Display the model's module tree (its repr).
model

LookbackNet(
  (stem): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (units): ModuleList(
    (0): ExpandBlock(
      (units): ModuleList(
        (0): Sequential(
          (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
          (6): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): Sequential(
          (0): Conv2d(64, 64, kernel_size=(