# Import

In [1]:
import torch
import torch.nn
import torchvision

import pandas as pd

from IPython.core.debugger import set_trace

# Utility

In [2]:
def leaf_modules(model):
    """Yield each module reachable from `model` that has no child modules.

    Modules are visited in the order produced by `model.modules()`, so `model`
    itself is yielded when it has no children.
    """
    # TODO: does pytorch have a built-in for this?
    for candidate in model.modules():
        has_children = any(True for _ in candidate.children())
        if not has_children:
            yield candidate

# Model summary

First, try to write a model summary using forward hooks: each leaf module records the shapes of its input and output during a dummy forward pass.

In [3]:
# VGG16 from torchvision, built with pretrained=True; this is the network summarised below.
# NOTE(review): newer torchvision releases reportedly deprecate `pretrained=` in favour of
# `weights=` - confirm against the installed version before changing.
model = torchvision.models.vgg16(pretrained=True)

In [4]:
def _hook(module, x, y): 
    module.sp_x = x[0].shape[1:]
    module.sp_y = y.shape[1:]

In [5]:
# Attach the shape-recording hook to every leaf module and keep the returned
# handles so the hooks can be removed once the summary has been built.
hooks = [leaf.register_forward_hook(_hook) for leaf in leaf_modules(model)]

In [6]:
# Input size in channels-last order; the dummy tensor built from it moves sz[2] to the channel axis.
sz = (300, 400, 3)

In [7]:
# All-zeros dummy batch with one sample, laid out as (1, sz[2], sz[0], sz[1]).
X = torch.zeros(1, sz[2], *sz[0:2])
X.shape

torch.Size([1, 3, 300, 400])

In [8]:
# Run the dummy batch through the network so every hooked leaf module records its
# shapes. The module is called (rather than `.forward` directly) so the pass goes
# through `Module.__call__`, which is what dispatches registered hooks.
y = model(X)

In [9]:
# Shape of the network output for the dummy batch.
y.shape

torch.Size([1, 1000])

In [10]:
# Collect one record per leaf module and build the frame in a single call.
# `DataFrame.append` was removed in pandas 2.0, and growing a frame row by row
# copies the whole frame on every iteration.
rows = []
for m in leaf_modules(model):
    # Count the parameters owned directly by this module. This also handles layers
    # created with bias=False (where `m.bias` is None) and parameters that are not
    # named `weight`/`bias`.
    own_params = list(m.parameters(recurse=False))
    rows.append({'name': str(m).split('(')[0],
                 'in_shape': m.sp_x,
                 'out_shape': m.sp_y,
                 # NaN rather than 0 for parameter-free modules, so they show as blanks
                 'num_params': sum(p.numel() for p in own_params) if own_params else float('nan')})
df = pd.DataFrame(rows, columns=['name', 'in_shape', 'out_shape', 'num_params'])
# Get summary row
df['num_params'] = df['num_params'].astype(float) # pandas has trouble with int + nans
row_total = df.sum(numeric_only=True, skipna=True)
row_total.name='Total' # Replaces numeric index
# One-row frame indexed by 'Total'; the non-numeric columns are filled with NaN by concat
df = pd.concat([df, row_total.to_frame().T])

In [11]:
# Display the per-layer summary table, including the appended 'Total' row.
df

Unnamed: 0,name,in_shape,out_shape,num_params
0,Conv2d,"(3, 300, 400)","(64, 300, 400)",1792.0
1,ReLU,"(64, 300, 400)","(64, 300, 400)",
2,Conv2d,"(64, 300, 400)","(64, 300, 400)",36928.0
3,ReLU,"(64, 300, 400)","(64, 300, 400)",
4,MaxPool2d,"(64, 300, 400)","(64, 150, 200)",
5,Conv2d,"(64, 150, 200)","(128, 150, 200)",73856.0
6,ReLU,"(128, 150, 200)","(128, 150, 200)",
7,Conv2d,"(128, 150, 200)","(128, 150, 200)",147584.0
8,ReLU,"(128, 150, 200)","(128, 150, 200)",
9,MaxPool2d,"(128, 150, 200)","(128, 75, 100)",


In [12]:
# The summary is finished: detach every forward hook that was registered for it.
for handle in hooks:
    handle.remove()

# Dynamic UNET

Get the encoder somehow... fastai, for example, finds the first module with "Pool" in it and cuts off the rest of the network.

In [13]:
def conv_layer(ic, oc, ks, s, p):
    """Build a Conv2d -> BatchNorm2d -> ReLU block.

    Args:
        ic: number of input channels.
        oc: number of output channels (also the BatchNorm feature count).
        ks: convolution kernel size.
        s: convolution stride.
        p: convolution padding.

    Returns:
        A `torch.nn.Sequential` holding the three layers in that order. The
        convolution has no bias term and the ReLU operates in place.
    """
    conv = torch.nn.Conv2d(ic, oc, kernel_size=ks, stride=s, padding=p, bias=False)
    norm = torch.nn.BatchNorm2d(oc)
    activation = torch.nn.ReLU(inplace=True)
    return torch.nn.Sequential(conv, norm, activation)

In [14]:
def up_conv_layer(ic, oc, ks, s, p, sz):
    """Build a `conv_layer` block followed by a bilinear resize to a fixed size.

    Args:
        ic, oc, ks, s, p: forwarded unchanged to `conv_layer`.
        sz: target spatial size handed to `torch.nn.Upsample(size=...)`.

    Returns:
        A `torch.nn.Sequential` of the convolution block and the upsampling layer
        (bilinear, `align_corners=True`).
    """
    conv_block = conv_layer(ic, oc, ks, s, p)
    resize = torch.nn.Upsample(size=sz, mode='bilinear', align_corners=True)
    return torch.nn.Sequential(conv_block, resize)

In [15]:
class DynamicUNET(torch.nn.Module):
    """U-Net whose decoder is generated from a dummy pass through `encoder`.

    Forward hooks on the encoder's leaf modules capture the input of every module
    whose output has a different spatial size (e.g. a max pool). Those tensors,
    deepest first, are concatenated with the upsampled decoder features on the way
    back up to the input resolution.

    Args:
        encoder: module applied to a tensor of shape (N, sz[2], sz[0], sz[1]).
        sz: input size in channels-last order, e.g. (300, 400, 3).
        oc: number of output channels produced by the final convolution.

    Note:
        The decoder's upsampling sizes are fixed from the dummy pass, so inputs
        are expected to have the spatial size given by `sz`.
    """

    def __init__(self, encoder, sz, oc):
        super().__init__()
        # Given an encoder, input size, and output channels we can generate a UNET
        self.encoder = encoder
        self.sz = sz
        self.oc = oc

        # Use hooks to store the input right before the size changes (i.e. before the max pool)
        # This might not work quite as intended for encoders with stride-2 convolution
        def _hook(module, x, y):
            if x[0].shape[2:] != y.shape[2:]:
                module.x = x[0]

        self.hooks = [m.register_forward_hook(_hook) for m in leaf_modules(self.encoder)]

        # Do dummy pass - only the shapes are needed, so skip building an autograd graph
        with torch.no_grad():
            X = torch.zeros(1, self.sz[2], *self.sz[0:2])
            Xs = self._forward_encoder(X)

        # Get decoder
        decoder = []
        sp_p = (0, Xs[0].shape[2], Xs[0].shape[3]) # Shape "prev"
        for X1, X2 in zip(Xs[:-1], Xs[1:]):
            sp_e = X1.shape[1:]                    # Shape "encoder"
            sp_n = X2.shape[1:]                    # Shape "next"
            decoder.append(up_conv_layer(sp_e[0]+sp_p[0], sp_n[0], 3, 1, 1, sp_n[1:]))
            sp_p = sp_n
        self.decoder = torch.nn.Sequential(*decoder)

        # Get last conv - since there is no upsampling for last convolution
        self.last_conv = torch.nn.Conv2d(Xs[-1].shape[1]+sp_p[0], self.oc, 3, 1, 1)

    def _forward_encoder(self, X):
        """Run the encoder and return [encoder output, hooked inputs deepest-first]."""
        # Call the module (not `.forward`) so the pass goes through `Module.__call__`
        out = self.encoder(X)
        # Get hooked outputs from this instance's own encoder (not a global of the same name)
        Xs = []
        for m in leaf_modules(self.encoder):
            if hasattr(m, 'x'):
                Xs.append(m.x)
                # Drop the reference so activations are not kept alive on the module
                # between calls and a stale tensor can never be picked up later
                del m.x
        Xs.append(out); Xs.reverse() # These typically need to be accessed in reverse order
        return Xs

    def forward(self, X):
        """Encode `X`, then decode while concatenating the captured encoder tensors."""
        Xs = self._forward_encoder(X)
        X_p = Xs[0][:,0:0,:,:] # Empty, but same size and dimension as last X
        for X_e, m in zip(Xs[:-1], self.decoder):
            X_p = m(torch.cat([X_e, X_p], dim=1))
        return self.last_conv(torch.cat([Xs[-1], X_p], dim=1))

    def __del__(self):
        # `hooks` is missing if __init__ failed before registering them
        for hook in getattr(self, 'hooks', ()):
            hook.remove()

In [16]:
# Encoder: the `features` part of a VGG16 built with default arguments.
encoder = torchvision.models.vgg16().features 
# Input size in channels-last order (same value as used for the summary above).
sz = (300, 400, 3)
# Number of channels the UNET should output.
out_channels = 5

In [17]:
# Build the UNET around the encoder. This rebinds `model`, which until now
# referred to the VGG16 network used for the summary.
model = DynamicUNET(encoder, sz, out_channels)

In [18]:
# Show the module tree of the generated UNET.
model

DynamicUNET(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dil

In [19]:
# Push an all-zeros dummy batch of the configured size through the UNET.
X = torch.zeros(1, sz[2], *sz[0:2])
# Call the module rather than `.forward` so the pass goes through `Module.__call__`,
# which is what dispatches any hooks registered on the model itself.
y_hat = model(X)

In [20]:
# Shape of the UNET output for the dummy batch.
y_hat.shape

torch.Size([1, 5, 300, 400])