In [1]:
#default_exp decoder

# Decoder
> Functions to set up the dynamic decoders for use in the dynamic U-Net architecture. In particular, we utilize PyTorch hooks here.

In [1]:
#hide
from nbdev.showdoc import *

In [12]:
#export
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

We need to automatically construct the Decoder using the architecture given in the encoder. To do so, we'll define some helper layers as `nn.Module`s.

* `ConvLayer` is just a general form of a convolution, ReLU, and batch normalization layer in sequence, with some empirical bets practices (e.g. initializing using $\frac{1}{\sqrt{5}}$ for all the weights in the convolutional layer, as per the FastAI course).
* `ConcatLayer` is just a thin wrapper on the `torch.cat` function that concatenates all inputs along the channel dimension, assuming inputs are image batches, i.e. they have shape (batch size, num channels, height, width).
* `LambdaLayer` is just a thin wrapper of a generic lambda function
* `upconv2x2` is a utility function for setting up convolutions that upsample an image. As mentioned above, in the U-Net architecture, we first concatenate the encoder output with the corresponding decoder input, so that when we upsample an image (i.e. from $(h, w)$ in size to $(2h, 2w)$ in size), we always have 2 times the amount of information (in this case, from having two times the number of channels). Accordingly, we will always convolve using an atrous convolution (where we dilate the kernel, rather than inserting 0s in the input to the convolutional layer), followed by the actual upsampling operation (using bilinear upsampling).

Note: these functions remain exposed for now, but the goal is to not need them to be exposed.

In [3]:
#export
class ConvLayer(nn.Module):
    def __init__(self, num_inputs, num_filters, bn=True, kernel_size=3, stride=1,
                 padding=None, transpose=False, dilation=1):
        super(ConvLayer, self).__init__()
        if padding is None:
            padding = (kernel_size-1)//2 if transpose is not None else 0
        if transpose:
            self.layer = nn.ConvTranspose2d(num_inputs, num_filters, kernel_size=kernel_size,
                                            stride=stride, padding=padding, dilation=dilation)
        else:
            self.layer = nn.Conv2d(num_inputs, num_filters, kernel_size=kernel_size,
                                   stride=stride, padding=padding)
        nn.init.kaiming_uniform_(self.layer.weight, a=np.sqrt(5))
        self.bn_layer = nn.BatchNorm2d(num_filters) if bn else None

    def forward(self, x):
        out = self.layer(x)
        out = F.relu(out)
        return out if self.bn_layer is None else self.bn_layer(out)
    
class ConcatLayer(nn.Module):
    def forward(self, x, dim=1):
        return torch.cat(list(x.values()), dim=dim)
    
class LambdaLayer(nn.Module):
    def __init__(self, f):
        super(LambdaLayer, self).__init__()
        self.f = f

    def forward(self, x):
        return self.f(x)

def upconv2x2(inplanes, outplanes, size=None, stride=1):
    if size is not None:
        return [
            ConvLayer(inplanes, outplanes, kernel_size=2, dilation=2, stride=stride),
            nn.Upsample(size=size, mode='bilinear', align_corners=True)
        ] 
    else:
        return [
            ConvLayer(inplanes, outplanes, kernel_size=2, dilation=2, stride=stride),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        ]

Some specifics in how the decoder is coordinated (here, the first layer means the input encoding layer of the encoder, and the last layer indicates the last layer in the encoder). These details are not super important, and are probably understandable if you inspect the U-Net architecture image more closely.

* The first layer's output passed along, concatenated, fed through `conv3x3` before upsampling, then fed through a regular `conv3x3` two times, then a `conv1x1` to output the right number of channels for segmentation output
* The middle layers output all are passed along, concatenated, and fed through a `conv3x3` that first halves number of channels to upsample, then a regular `conv3x3`
* The last layer output's takes two pathways:
    - Going down in the figure, the output goes through: max-pool (2x2), conv3x3, conv3x3, upconv2x2. These operations are encompassed in the `DecoderConnect` class.
    - Going across, assed across and concatenated to the result of above step
    
    
Again, these details don't particularly matter, unless you're implementing the architecture yourself. The important point is that upsampling always happens after a concatenation of the encoder's output with the corresponding input to the corresponding level of the decoder.

In [4]:
#export
class DecoderConnect(nn.Module):
    def __init__(self, inplanes, output_size):
        super(DecoderConnect, self).__init__()
        self.bottom_process = nn.Sequential(
            ConvLayer(inplanes, inplanes * 2, kernel_size=3),
            ConvLayer(inplanes * 2, inplanes * 2, kernel_size=3),
            *upconv2x2(inplanes * 2, inplanes, size=output_size)
        )
        self.concat_process = nn.Sequential(
            ConcatLayer(),
            ConvLayer(inplanes * 2, inplanes * 2, kernel_size=1),
            ConvLayer(inplanes * 2, inplanes, kernel_size=3),
            ConvLayer(inplanes, inplanes, kernel_size=3)
        )
        
    def forward(self, x):
        decoder_input = self.bottom_process(x)
        return self.concat_process({0: x, 1: decoder_input})

As a quick sanity check, we can initialize this to make sure everything is in order.

In [10]:
#example
model = DecoderConnect(512, (7, 7))

In [11]:
#example
model

DecoderConnect(
  (bottom_process): Sequential(
    (0): ConvLayer(
      (layer): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_layer): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ConvLayer(
      (layer): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_layer): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ConvLayer(
      (layer): Conv2d(1024, 512, kernel_size=(2, 2), stride=(1, 1))
      (bn_layer): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): Upsample(size=(7, 7), mode=bilinear)
  )
  (concat_process): Sequential(
    (0): ConcatLayer()
    (1): ConvLayer(
      (layer): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
      (bn_layer): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ConvLayer(
      (layer): Con