# UNet Visualization

In this notebook, we'll be experimenting with the UNet, trying to visualize what happens at each layer. I'll be hijacking the MONAI UNet architecture to output some nice images.

In [1]:
import torch
import torch.nn as nn

In [3]:
# modelled after the original UNet
class UNet(nn.Module):
    """U-Net style encoder/decoder built from size-preserving convolutions.

    Every 3x3 convolution uses a reflect padding of 1, so the convolutions
    never change the spatial size. Only the 2x2 max-pool layers (halve) and
    the transposed convolutions (double) do, which lets the network return an
    output with the same height and width as its input.

    The input height and width must be divisible by 16 (there are four 2x
    poolings); otherwise the upsampled tensors do not line up with the skip
    connections and ``torch.cat`` raises. They must also be at least 32 so
    that reflect padding is still valid on the pooled bottleneck input.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by the final 1x1 conv.
        base_channels: Width of the first block; it doubles at every
            contraction level and halves again on the way up.

    The defaults (1, 2, 64) reproduce the original fixed architecture.
    """

    def conv_relu(self, in_channels, out_channels, kernel_size=3, padding=1, padding_mode='reflect'):
        """Return a Conv2d + ReLU pair.

        With the default ``kernel_size=3`` / ``padding=1`` the spatial size
        of the input is preserved.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                padding_mode=padding_mode
            ),
            nn.ReLU()
        )

    def conv_transpose(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
        """Return a ConvTranspose2d that, with the defaults, doubles H and W.

        Output size is (x - 1) * stride - 2 * padding + (kernel_size - 1)
        + output_padding + 1, which is exactly 2 * x for the defaults.
        """
        return nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding
        )

    def first_block(self, in_channels, out_channels):
        """Two conv/ReLU layers without pooling. Output size: x."""
        return nn.Sequential(
            self.conv_relu(in_channels, out_channels),
            self.conv_relu(out_channels, out_channels)
        )

    def contract_block(self, in_channels, out_channels):
        """Max-pool followed by two conv/ReLU layers. Output size: x / 2."""
        # NOTE: adding BatchNorm2d(out_channels) after the ReLU layers was an
        # experiment idea; it is not enabled here.
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            self.conv_relu(in_channels, out_channels),
            self.conv_relu(out_channels, out_channels)
        )

    def bottleneck_block(self, in_channels, mid_channels, out_channels):
        """Pool, two conv/ReLU layers, then upsample. Output size: x."""
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            self.conv_relu(in_channels, mid_channels),
            self.conv_relu(mid_channels, mid_channels),
            self.conv_transpose(mid_channels, out_channels)
        )

    def expand_block(self, in_channels, mid_channels, out_channels):
        """Two conv/ReLU layers followed by an upsample. Output size: 2 * x."""
        return nn.Sequential(
            self.conv_relu(in_channels, mid_channels),
            self.conv_relu(mid_channels, mid_channels),
            self.conv_transpose(mid_channels, out_channels)
        )

    def final_block(self, in_channels, mid_channels, out_channels):
        """Two conv/ReLU layers and a 1x1 projection. Output size: x."""
        return nn.Sequential(
            self.conv_relu(in_channels, mid_channels),
            self.conv_relu(mid_channels, mid_channels),
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
        )

    def __init__(self, in_channels=1, out_channels=2, base_channels=64):
        super().__init__()
        c = base_channels

        # Spatial sizes in the comments are an example for a 288x288 input.
        self.contraction = nn.ModuleList([
            # 288
            self.first_block(in_channels, c),
            # 288
            self.contract_block(c, 2 * c),
            # 144
            self.contract_block(2 * c, 4 * c),
            # 72
            self.contract_block(4 * c, 8 * c),
            # 36
        ])

        # 36 -> 18 (pool) -> 36 (transposed conv)
        self.bottleneck = self.bottleneck_block(8 * c, 16 * c, 8 * c)

        # Each block's input is the previous output concatenated with the
        # matching skip tensor, hence in_channels is twice the previous
        # block's out_channels.
        self.expansion = nn.ModuleList([
            # 36
            self.expand_block(16 * c, 8 * c, 4 * c),
            # 72
            self.expand_block(8 * c, 4 * c, 2 * c),
            # 144
            self.expand_block(4 * c, 2 * c, c),
            # 288
            self.final_block(2 * c, c, out_channels)
            # 288
        ])

        # Kept only so existing code that reads this attribute keeps working.
        # forward() holds the skip activations in a local list instead, so
        # this list always stays empty.
        self.contraction_outputs = []

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the network on ``image`` and return the final block's output.

        Skip activations live in a local list rather than on the module, so
        a forward pass that raises (for example on an input whose size is not
        divisible by 16) cannot leave stale tensors behind and break every
        later call.
        """
        skips = []
        for layer in self.contraction:
            image = layer(image)
            skips.append(image)

        image = self.bottleneck(image)

        # Deepest skip first: pair each expansion stage with the contraction
        # output that has the same spatial size.
        for skip, layer in zip(reversed(skips), self.expansion):
            image = torch.cat((skip, image), dim=1)
            image = layer(image)
        return image