In [1]:
import torch
from torch import nn
from torchvision import transforms

## Homework 6 - U-Net
*Lorenzo Basile*

This notebook contains the implementation of a U-Net style CNN, with the following specifications:
- All convolutions inside the contracting and expanding modules use a 3 x 3 kernel and leave the spatial dimensions of their input untouched (this is obtained by using a padding of 1);
- Downsampling in the contracting part is performed via maxpooling with a 2 x 2 kernel and stride of 2;
- Upsampling is performed by a transposed convolution ("deconvolution") with a 2 x 2 kernel and stride of 2;
- The final layer, applied after the expanding part, is a 1 x 1 convolution with a single output channel.

The final point means that this network could be used to perform binary segmentation.

The network is made of two kinds of submodules: contracting and expanding.  
A contracting module consists of a max pooling layer (omitted in the first module, which receives the input image directly) followed by two convolutional layers with ReLU activations.  
The expanding part is slightly more complex: the input is passed through an upsampling layer which doubles its dimensions and then concatenated with a cropped copy of the output of the corresponding contracting module. Then, the result of this concatenation is fed into the series of two convolutional layers with ReLU activations.

If pooling is never applied to an odd sized image (in this specific case this condition translates to having input images with dimensions which are multiples of 16), this network leaves the dimensions of the input untouched.

In [2]:
class contracting_module(torch.nn.Module):
    """One stage of the contracting (encoder) path.

    The stage at position ``index`` produces ``2 ** (6 + index)`` feature maps.
    Stage 0 takes a single-channel input and applies no pooling; every later
    stage first halves the spatial size with a 2 x 2 max pooling and expects
    half as many input channels as it outputs. Both 3 x 3 convolutions use a
    padding of 1, so they do not change the spatial size.
    """

    def __init__(self, index):
        """Build the stage located at depth ``index`` (0 is the input stage)."""
        super().__init__()
        is_input_stage = index == 0
        self.output_channels = 2 ** (6 + index)
        input_channels = 1 if is_input_stage else self.output_channels // 2

        # Slot 0 of the Sequential is always occupied (Identity for stage 0),
        # so the convolutions sit at positions 1 and 3 in every stage.
        downsample = nn.Identity() if is_input_stage else nn.MaxPool2d(kernel_size=2)
        first_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=self.output_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        second_conv = nn.Conv2d(
            in_channels=self.output_channels,
            out_channels=self.output_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.layers = nn.Sequential(
            downsample,
            first_conv,
            nn.ReLU(),
            second_conv,
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply (optional) pooling followed by the two conv + ReLU pairs."""
        features = self.layers(x)
        return features

class expanding_module(torch.nn.Module):
    """One stage of the expanding (decoder) path.

    The stage at position ``index`` outputs ``2 ** (6 + index)`` channels. Its
    input (with twice as many channels) is upsampled by a 2 x 2 transposed
    convolution with stride 2, concatenated along the channel axis with a
    centre-cropped copy of the matching contracting-stage output, and passed
    through two 3 x 3 conv + ReLU pairs (padding 1, spatial size preserved).
    """

    def __init__(self, index):
        """Build the decoder stage located at depth ``index``."""
        super().__init__()
        self.output_channels=2**(6+index)
        self.upsampling=nn.ConvTranspose2d(in_channels=self.output_channels*2, out_channels=self.output_channels, kernel_size=2, stride=2)
        self.layers=nn.Sequential(
            # The concatenation doubles the channel count again, hence `*2`.
            nn.Conv2d(in_channels=self.output_channels*2, out_channels=self.output_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.output_channels, out_channels=self.output_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

    @staticmethod
    def _center_crop(feature_map, height, width):
        """Return the central ``height`` x ``width`` window of ``feature_map``.

        Only the last two dimensions are touched. If ``feature_map`` is smaller
        than the requested window along a dimension it is zero-padded first, so
        the result always has exactly the requested spatial size.
        """
        src_h, src_w = feature_map.shape[-2:]
        pad_h = max(height - src_h, 0)
        pad_w = max(width - src_w, 0)
        if pad_h or pad_w:
            # functional.pad takes (left, right, top, bottom) for the last two dims.
            feature_map = nn.functional.pad(
                feature_map,
                (pad_w // 2, (pad_w + 1) // 2, pad_h // 2, (pad_h + 1) // 2),
            )
            src_h, src_w = feature_map.shape[-2:]
        top = int(round((src_h - height) / 2.0))
        left = int(round((src_w - width) / 2.0))
        return feature_map[..., top:top + height, left:left + width]

    def forward(self, x, copy):
        """Upsample ``x``, merge it with the skip tensor ``copy`` and convolve.

        ``copy`` is cropped to the height AND width of the upsampled ``x``.
        Cropping to a square of side ``x.shape[-1]`` would break the
        concatenation for any non-square input.
        """
        x=self.upsampling(x)
        cropped_copy=self._center_crop(copy, x.shape[-2], x.shape[-1])
        x=torch.cat([cropped_copy, x], dim=1)
        return self.layers(x)

class UNet(torch.nn.Module):
    """U-Net style network built from five contracting and four expanding stages.

    The input passes through contracting stages 0..4; the outputs of stages
    0..3 are kept as skip connections. The expanding stages are then applied
    from the deepest (3) to the shallowest (0), each receiving the skip tensor
    of the contracting stage with the same index. A final 1 x 1 convolution
    maps the 64 resulting channels to a single output channel.
    """

    def __init__(self):
        super().__init__()
        # nn.ModuleList (not a plain Python list) is required so the stages are
        # registered as sub-modules: otherwise their weights are missing from
        # parameters()/state_dict() and are ignored by .to(), .train()/.eval()
        # and by any optimizer built from net.parameters().
        self.contracting_modules=nn.ModuleList([contracting_module(i) for i in range(5)])
        self.expanding_modules=nn.ModuleList([expanding_module(i) for i in range(4)])
        self.final_layer=nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1)

    def forward(self, x):
        """Run the full encoder/decoder pass and return the 1-channel output map."""
        depth=len(self.expanding_modules)

        # Encoder: every stage but the deepest one feeds a skip connection.
        skips=[]
        for level in range(depth):
            x=self.contracting_modules[level](x)
            skips.append(x)
        x=self.contracting_modules[depth](x)

        # Decoder: walk back up, pairing each stage with the same-index skip.
        for level in reversed(range(depth)):
            x=self.expanding_modules[level](x, skips[level])

        return self.final_layer(x)

In [3]:
# Smoke test: a 512 x 512 single-channel image (512 is a multiple of 16)
# must come out of the network with exactly the same shape.
net=UNet()
# Named `sample_input` rather than `input` so the builtin input() is not
# shadowed for the rest of the kernel session.
sample_input=torch.rand((1,1,512,512))
with torch.no_grad():  # shape check only, no need to build the autograd graph
    output=net(sample_input)
print(sample_input.shape)
print(output.shape)
assert output.shape == sample_input.shape, "UNet should preserve the input shape"

torch.Size([1, 1, 512, 512])
torch.Size([1, 1, 512, 512])
