In [None]:
#default_exp unet

# U-Net
> We expose the functionality for setting up the dynamic U-Net using our encoder and decoder packages.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import numpy as np
import torch
import torch.nn as nn

from dynamic_unet.decoder import ConcatLayer, ConvLayer, DecoderConnect, LambdaLayer, upconv2x2

The crux of constructing the decoder happens in the `setup_decoder` method below, and subsequently in `construct_decoder`. The details are hard to extract from the code alone, so we break them down as follows, tracing the code in `setup_decoder` first.

### Getting Shapes Using Hooks

We're going to gather the shapes of the input and output tensors of every layer in the ResNet encoder whose name has the prefix "layer". To do so, we'll use hooks. A hook is a callback function that's passed as an argument when registering it for a specific layer in our network; here it is also a closure, since it appends to lists defined in the enclosing method. You can see here that
* `shape_hook` is the function passed when registering a hook
* We call `register_forward_hook` on any `child` layer of our network whose name `startswith` "layer", e.g. `self.layer0`, `self.layer1`, and so on.

Note the signature of `shape_hook`, and generally of any function passed to `register_forward_hook`: it receives the module, the input, and the output of the layer the hook is registered on (the input arrives as a tuple of the layer's positional arguments, and the output can itself be a tuple). In our case, we only care about their shapes, as we'll need them to determine the shape of the decoder's input, and accordingly the number of filters the convolutional layers in the previous decoder layer need to output.

Accordingly, we'll take those shapes and append them to our `input_sizes` and `output_sizes` lists, to keep track of the input and output shapes as the network processes an input. To actually populate these lists, we have to do exactly that: process an input. Thus, we'll make a dummy input (in the code, `test_input`) that we pass through our encoder, and after it finishes processing that input, our `input_sizes` and `output_sizes` lists will be populated!
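As a minimal sketch of the shape-gathering step described above, the snippet below registers forward hooks on a toy network. `ToyEncoder` and its layer sizes are hypothetical stand-ins for the ResNet encoder, chosen only so the example runs on its own; the hook logic mirrors `shape_hook`.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the ResNet encoder (not part of dynamic_unet)."""

    def __init__(self):
        super().__init__()
        self.layer0 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        self.layer1 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.head = nn.Identity()  # name lacks the "layer" prefix, so it gets no hook

    def forward(self, x):
        return self.head(self.layer1(self.layer0(x)))


encoder = ToyEncoder()
input_sizes, output_sizes = [], []


def shape_hook(module, inputs, output):
    # `inputs` is a tuple of the positional arguments passed to the module.
    input_sizes.append(tuple(inputs[0].shape))
    output_sizes.append(tuple(output.shape))


# Register the hook only on children whose name starts with "layer".
handles = [
    child.register_forward_hook(shape_hook)
    for name, child in encoder.named_children()
    if name.startswith("layer")
]

try:
    with torch.no_grad():
        encoder(torch.randn(1, 3, 32, 32))  # dummy input, like `test_input`
finally:
    # Always remove the hooks so later forward passes are unaffected.
    for handle in handles:
        handle.remove()

print(input_sizes)   # shapes seen at the input of layer0 and layer1
print(output_sizes)  # shapes produced by layer0 and layer1
```

Removing the handles in a `finally` block keeps the encoder free of leftover hooks even if the dummy forward pass raises.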

### Constructing the Decoder

Now that we have the input and output sizes of any tensors passing through the blocks of our ResNet encoder, we can construct our decoder level by level. To do so, we'll just look at the following things:
* How much we need to upsample the size of the image (determined by looking at the ratio of the input image size and the output image size)
* How the number of channels changes between the input and output of the corresponding encoder level (determined by looking at the ratio of input channels to output channels)

Looking at both of these gives us a sense of the operation we need to do to reverse what the encoder did. Specifically, we can abide by the following assumptions when constructing the decoder:
* The shape of the input to the level of the decoder we're working on will be the same as the shape of the output of the corresponding level of the encoder
* The shape of the output of this level of the decoder will be the same shape as the shape of the input of the corresponding level of the encoder

With these assumptions in mind, and using the details above for constructing the operations for each level of the decoder, we can just use case work for actually constructing the decoder, depending on whether we're looking at the last layer of the encoder, one of the middle layers, or the first layer of the encoder.
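The two quantities and the case split described above can be sketched in plain Python. The helper names `level_factors` and `level_kind` and the example shapes are illustrative assumptions; the arithmetic follows the expressions used in `construct_decoder`.

```python
def level_factors(input_size, output_size):
    """Return (spatial upsampling factor, channel ratio) for one encoder level.

    Both sizes are (batch, channels, height, width) tuples.
    """
    size_factor = int(input_size[-1] / output_size[-1])
    channel_factor = input_size[-3] / output_size[-3]
    return size_factor, channel_factor


def level_kind(layer_index, num_levels):
    # Same order of checks as `construct_decoder`: last level first, then first level.
    if layer_index == num_levels - 1:
        return "last"
    if layer_index == 0:
        return "first"
    return "middle"


# Hypothetical shapes for three encoder levels.
input_sizes = [(1, 3, 64, 64), (1, 64, 32, 32), (1, 64, 16, 16)]
output_sizes = [(1, 64, 32, 32), (1, 64, 16, 16), (1, 128, 8, 8)]

for i, (in_size, out_size) in enumerate(zip(input_sizes, output_sizes)):
    print(level_kind(i, len(input_sizes)), level_factors(in_size, out_size))
```

A size factor greater than 1, or a channel ratio other than 1, is what triggers the extra up-convolution in the middle and last cases.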

Since we're starting from the inputs and outputs of the first layer of the encoder, we append each constructed decoder level as we inspect the encoder's input and output shapes, and then reverse the list of constructed levels when finalizing the decoder architecture. This ensures that the decoder goes from the last output shape of the encoder back to the first input shape of the encoder, which is (generally) what we want to output for segmentation. (This doesn't necessarily have to be true, in which case the 1x1 convolution at the end of the decoder produces the right number of output channels, specified when constructing the class as `num_output_channels`.)

Note that we maintain the decoder as a list of modules, i.e. an `nn.ModuleList`. This is an intentional choice, as we'll need to perform the operations of our network in sequence by level, as each level requires getting the corresponding output of the encoder, and processing it alongside the corresponding input of the decoder.
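To illustrate why a module list suits this pattern, here is a small sketch with hypothetical classes (`TwoInputLevel`, `ToyDecoder`) that are not part of the package. Each level needs two tensors, so the levels are iterated manually while `nn.ModuleList` still registers their parameters.

```python
import torch
import torch.nn as nn


class TwoInputLevel(nn.Module):
    """Hypothetical decoder level that needs a skip tensor and the previous output."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels * 2, channels, kernel_size=1)

    def forward(self, skip, prev):
        # Concatenate along the channel dimension, then mix with a 1x1 convolution.
        return self.conv(torch.cat([skip, prev], dim=1))


class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers the parameters but leaves the call order to us.
        self.levels = nn.ModuleList([TwoInputLevel(4), TwoInputLevel(4)])

    def forward(self, skips, x):
        for skip, level in zip(skips, self.levels):
            x = level(skip, x)
        return x


decoder = ToyDecoder()
skips = [torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)]
out = decoder(skips, torch.randn(1, 4, 8, 8))
print(out.shape)
print(len(list(decoder.parameters())))  # weights and biases of both levels
```

An `nn.Sequential` would pass only a single value from one level to the next, so it could not feed each level its own encoder output without extra wrapping.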

### Model Forward Using Hooks

The last part of setting up our dynamic U-Net architecture is to specify the `forward` function. In order to do so, we need to keep track of the outputs of each level of our encoder. Since we've encompassed the encoder as one module when constructing our U-Net, the easiest way to get the outputs for each level of the encoder is to just use hooks again.

The setup for these hooks is very similar to how we set up the shape hooks above, except that we only keep track of the outputs, and we want the actual output tensor, not its shape. This is handled by the `encoder_output_hook` hook in the `forward` function below. Again, we register the hook for all layers in our encoder whose names start with "layer".

To actually use these outputs, we only need to keep track of the input we are passing into the current level of the decoder. This is convenient because we left the decoder as an `nn.ModuleList`: we need only iterate over the encoder outputs (in reverse) together with the decoder levels, passing each level its encoder output along with the previous decoder output. This is done in the following loop in the `forward` function:

```
prev_output = None
for reo, rdl in zip(reversed(encoder_outputs), self.decoder):
    if prev_output is not None:
        prev_output = rdl({0: reo, 1: prev_output})
    else:
        prev_output = rdl(reo)
```

Note that for the first layer of the decoder (the one tied to the last layer of the encoder), there's no previous output. This is because the first layer of the decoder has the additional pathway (seen at the bottom of the U-Net architecture figure) that is concatenated with the output of the last layer of the encoder. For all other layers, on the other hand, the encoder output (`reo`) and the decoder input (`prev_output`) are concatenated together in a single pathway (explicitly, via the `ConcatLayer` forward function).
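To make the pairing order concrete, the loop can be traced with string stand-ins instead of tensors. The names `enc0`..`enc2`, `dec0`..`dec2`, and the `make_level` helper are purely illustrative assumptions; only the loop body is taken from `forward`.

```python
# Encoder outputs in the order the hooks recorded them (first level first).
encoder_outputs = ["enc0", "enc1", "enc2"]


def make_level(name):
    # Each fake level records what it was called with instead of running convolutions.
    def level(x):
        if isinstance(x, dict):
            return f"{name}({x[0]}, {x[1]})"
        return f"{name}({x})"
    return level


# The decoder list is stored deepest level first, as after `reversed(decoder_layers)`.
decoder = [make_level("dec2"), make_level("dec1"), make_level("dec0")]

prev_output = None
for reo, rdl in zip(reversed(encoder_outputs), decoder):
    if prev_output is not None:
        prev_output = rdl({0: reo, 1: prev_output})
    else:
        prev_output = rdl(reo)

print(prev_output)  # dec0(enc0, dec1(enc1, dec2(enc2)))
```

The nesting shows that the deepest encoder output is consumed first and that every later level receives both its matching encoder output and the result of the level before it.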

In [None]:
#export
class DynamicUNet(nn.Module):
    def __init__(self, encoder, input_size=(224, 224), num_output_channels=None, verbose=0):
        super(DynamicUNet, self).__init__()
        self.encoder = encoder
        self.verbose = verbose
        self.input_size = input_size
        self.num_input_channels = 3  # This must be 3 because we're using a ResNet encoder
        self.num_output_channels = num_output_channels
        
        self.decoder = self.setup_decoder()
        
    def forward(self, x):
        encoder_outputs = []
        def encoder_output_hook(module, input, output):
            # Record each hooked layer's output tensor, in the order the layers run.
            encoder_outputs.append(output)

        handles = [
            child.register_forward_hook(encoder_output_hook) for name, child in self.encoder.named_children()
            if name.startswith('layer')
        ]

        try:
            self.encoder(x)
        finally:
            if self.verbose >= 1:
                print("Removing all forward handles")
            for handle in handles:
                handle.remove()

        prev_output = None
        for reo, rdl in zip(reversed(encoder_outputs), self.decoder):
            if prev_output is not None:
                prev_output = rdl({0: reo, 1: prev_output})
            else:
                prev_output = rdl(reo)
        return prev_output
                
    def setup_decoder(self):
        input_sizes = []
        output_sizes = []
        def shape_hook(module, input, output):
            # `input` is a tuple of the layer's positional arguments; we want the first one.
            input_sizes.append(input[0].shape)
            output_sizes.append(output.shape)

        handles = [
            child.register_forward_hook(shape_hook) for name, child in self.encoder.named_children()
            if name.startswith('layer')
        ]    

        self.encoder.eval()
        test_input = torch.randn(1, self.num_input_channels, *self.input_size)
        try:
            self.encoder(test_input)
        finally:
            if self.verbose >= 1:
                print("Removing all shape hook handles")
            for handle in handles:
                handle.remove()
        decoder = self.construct_decoder(input_sizes, output_sizes, num_output_channels=self.num_output_channels)
        return decoder
        
    def construct_decoder(self, input_sizes, output_sizes, num_output_channels=None):
        decoder_layers = []
        for layer_index, (input_size, output_size) in enumerate(zip(input_sizes, output_sizes)):
            upsampling_size_factor = int(input_size[-1] / output_size[-1])
            upsampling_channel_factor = input_size[-3] / output_size[-3]
            next_layer = []
            bs, c, h, w = input_size
            ops = []
            if layer_index == len(input_sizes) - 1:
                last_layer_ops = DecoderConnect(output_size[-3], output_size[2:])
                last_layer_ops_input = torch.randn(*output_size)
                last_layer_concat_ops_output = last_layer_ops(last_layer_ops_input)
                next_layer.extend([last_layer_ops])
                if upsampling_size_factor > 1 or upsampling_channel_factor != 1:
                    last_layer_concat_upconv_op = upconv2x2(output_size[-3], input_size[-3], size=input_size[2:])
                    last_layer_concat_upconv_op_output = nn.Sequential(*last_layer_concat_upconv_op)(
                        last_layer_concat_ops_output
                    )
                    next_layer.extend(last_layer_concat_upconv_op)
            elif layer_index == 0:
                first_layer_concat_ops = [
                    ConcatLayer(),
                    ConvLayer(output_size[-3] * 2, output_size[-3] * 2, kernel_size=1),
                    *upconv2x2(
                        output_size[-3] * 2,
                        output_size[-3],
                        size=[dim * upsampling_size_factor for dim in output_size[2:]]
                    ),
                    ConvLayer(output_size[-3], output_size[-3], kernel_size=3),
                    ConvLayer(
                        output_size[-3],
                        input_size[-3] if self.num_output_channels is None else self.num_output_channels,
                        kernel_size=1
                    ),
                ]
                first_layer_concat_ops_output = nn.Sequential(*first_layer_concat_ops)(
                    {0: torch.randn(*output_size), 1: torch.randn(*output_size)}
                )
                next_layer.extend(first_layer_concat_ops)
            else:
                middle_layer_concat_ops = [
                    ConcatLayer(),
                    ConvLayer(output_size[-3] * 2, output_size[-3] * 2, kernel_size=1),
                    ConvLayer(output_size[-3] * 2, output_size[-3], kernel_size=3),
                    ConvLayer(output_size[-3], output_size[-3], kernel_size=3)
                ]
                middle_layer_concat_ops_output = nn.Sequential(*middle_layer_concat_ops)(
                    {0: torch.randn(*output_size), 1: torch.randn(*output_size)}
                )
                next_layer.extend(middle_layer_concat_ops)
                if upsampling_size_factor > 1 or upsampling_channel_factor != 1:
                    middle_layer_concat_upconv_op = upconv2x2(output_size[-3], input_size[-3], size=input_size[2:])
                    middle_layer_concat_upconv_op_output = nn.Sequential(*middle_layer_concat_upconv_op)(
                        middle_layer_concat_ops_output
                    )
                    next_layer.extend(middle_layer_concat_upconv_op)
            decoder_layers.append(nn.Sequential(*next_layer))
        return nn.ModuleList(reversed(decoder_layers))

Now, we can construct our network.

In [None]:
#example
from dynamic_unet.encoder import resnet34 

model = DynamicUNet(resnet34(), num_output_channels=32, input_size=(360, 480))

In [None]:
#example
model

DynamicUNet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (layer1): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d