# U-Net Model
In this notebook I implement the U-Net model, a convolutional network based on an encoder-decoder architecture with skip connections. It is widely used for image segmentation tasks. The overall architecture is shown in the following image.
![](https://production-media.paperswithcode.com/methods/Screen_Shot_2020-07-07_at_9.08.00_PM_rpNArED.png)

The model can be split into two parts: the down part (or encoder) and the up part (or decoder). At each encoder level we apply two convolutions that increase the number of channels, followed by a max pool that halves the image dimensions. This is repeated until we reach the bottleneck of the network. The decoder then "undoes" the steps of the encoder: a transposed convolution reverses the downsampling of each max pool, and then we apply two convolutions again.
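As a rough illustration of one encoder step and the matching upsampling, here is a minimal sketch that tracks tensor shapes. The channel counts (1 to 8) and the 64x64 input are arbitrary values chosen for the example, not the ones used by the model below.

```python
import torch
from torch import nn

x = torch.randn(1, 1, 64, 64)  # (batch, channels, height, width)

# Encoder step: two 3x3 convolutions raise the channel count; padding=1 keeps H and W
convs = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.ReLU(),
)
features = convs(x)

# The max pool halves H and W
pooled = nn.MaxPool2d(kernel_size=2, stride=2)(features)

# Decoder side: a transposed convolution with kernel 2 and stride 2 doubles H and W
upsampled = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2)(pooled)

print(features.shape)   # torch.Size([1, 8, 64, 64])
print(pooled.shape)     # torch.Size([1, 8, 32, 32])
print(upsampled.shape)  # torch.Size([1, 8, 64, 64])
```

Note that the transposed convolution restores the spatial size but not the values lost by pooling, which is why "undo" is in quotes.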

After each transposed convolution, we concatenate its output with a copy of the encoder output from the same level. These skip connections help maintain high-resolution details.
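The concatenation happens along the channel dimension, so the channel count doubles while the spatial size stays the same. A small sketch with made-up tensors (the names and sizes are only for illustration):

```python
import torch

decoder_features = torch.randn(1, 64, 128, 128)  # stands in for a transposed convolution output
encoder_features = torch.randn(1, 64, 128, 128)  # stands in for the saved encoder output of the same level

# dim=1 is the channel dimension in the (N, C, H, W) layout
merged = torch.cat([decoder_features, encoder_features], dim=1)
print(merged.shape)  # torch.Size([1, 128, 128, 128])
```

This is why the two convolutions that follow the concatenation take twice as many input channels as the transposed convolution produces.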

Finally, we apply a 1x1 convolution to map the feature channels to the desired number of output channels.
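A 1x1 convolution acts like a small linear layer applied independently at every pixel: it mixes channels but leaves the spatial size untouched. A minimal sketch, assuming 64 feature channels and a hypothetical 3-class output:

```python
import torch
from torch import nn

features = torch.randn(1, 64, 32, 32)
num_classes = 3  # hypothetical value, chosen only for this example

head = nn.Conv2d(64, num_classes, kernel_size=1)
logits = head(features)

print(logits.shape)  # torch.Size([1, 3, 32, 32])
```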

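Also note that since this network is composed only of convolutions and pooling, it is not tied to a fixed input size. In this implementation the height and width should be divisible by 16 (one factor of 2 per max pool, with four levels), otherwise the tensors being concatenated will not have matching sizes.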
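One practical caveat: each max pool rounds odd dimensions down, so with four pooling steps the decoder only reproduces the encoder's shapes when height and width are multiples of 16. One possible workaround is to pad the input first. The helper below, `pad_to_multiple`, is a hypothetical name introduced for this sketch and is not part of the notebook's model.

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple=16):
    """Zero-pad the bottom and right of an (N, C, H, W) tensor so that
    H and W become multiples of `multiple`. Returns the padded tensor
    and the original (H, W) so the output can be cropped back."""
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad lists the last dimension first: (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h)), (h, w)

x = torch.randn(1, 1, 100, 75)
padded, (h, w) = pad_to_multiple(x)
print(padded.shape)  # torch.Size([1, 1, 112, 80])

# After running a network on `padded`, crop its output back to the original size
cropped = padded[..., :h, :w]
print(cropped.shape)  # torch.Size([1, 1, 100, 75])
```

Zero padding is only one choice; reflection padding or resizing the image are alternatives with different trade-offs at the borders.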

Let's start with the code.

In [1]:
import torch
from torch import nn

Here we are going to implement the `DualConv` block, which performs the two convolutions of each level.

In [2]:
class DualConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()  # super(self.__class__, self) recurses forever if the class is subclassed
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), # No bias because of batchnorm
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        
    def forward(self, x):
        return self.block(x)
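The "no bias because of batchnorm" comment can be checked numerically. In training mode, batch normalization subtracts the per-channel batch mean, so a constant per-channel bias added by the preceding convolution cancels out. A minimal sketch with arbitrary sizes:

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 3, 16, 16)

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)  # modules start in training mode, so batch statistics are used

with torch.no_grad():
    out_with_bias = bn(conv(x))
    conv.bias.zero_()  # remove the bias and run the same input again
    out_without_bias = bn(conv(x))

# The two outputs agree up to floating-point error
same = torch.allclose(out_with_bias, out_without_bias, atol=1e-4)
print(same)
```

Dropping the bias therefore saves a few parameters without changing what the block can represent, since `BatchNorm2d` already has its own learnable shift.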

Now let's implement the network. We will build the encoder and the decoder as general module lists, and we also need the bottleneck and the last convolution. For the forward pass, we iterate over the encoder and save the result of each level in a list. Then we iterate over the decoder and concatenate each upsampled result with the matching encoder result. Finally, we apply the last convolution.

In [3]:
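class Unet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, sizes=(64, 128, 256, 512)):
        super().__init__()
        self.sizes = sizes
        self.down = nn.ModuleList()
        self.up = nn.ModuleList()

        # Encoder: one DualConv per level, increasing the number of channels
        for size in sizes:
            self.down.append(DualConv(in_channels, size))
            in_channels = size
        self.bottleneck = DualConv(sizes[-1], sizes[-1]*2)
        # Decoder: alternate a transposed convolution (upsampling) and a DualConv
        in_channels = sizes[-1]*2
        for size in reversed(sizes):
            self.up.append(
                nn.ConvTranspose2d(in_channels=in_channels, out_channels=size, kernel_size=2, stride=2),
            )
            # The DualConv input is the upsampled tensor concatenated with the skip connection
            self.up.append(DualConv(2*size, size))
            in_channels = size

        self.last = nn.Conv2d(sizes[0], out_channels, kernel_size=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # created once instead of on every forward pass

    def forward(self, x):
        res = []  # encoder outputs, saved for the skip connections
        for block in self.down:
            x = block(x)
            res.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        for idx, block in enumerate(self.up):
            x = block(x)
            if idx % 2 == 0:
                # Even indices are transposed convolutions: concatenate the matching encoder output
                x = torch.cat([x, res.pop()], dim=1)

        return self.last(x)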

Let's look at the model summary.

In [4]:
net = Unet()
print(net)

Unet(
  (down): ModuleList(
    (0): DualConv(
      (block): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DualConv(
      (block): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True

Let's create random data and feed it into the model.

In [5]:
X = torch.randn(1, 1, 512, 512)
Y = net(X)
Y.shape

torch.Size([1, 1, 512, 512])

As we can see, the model runs and produces an output with the same spatial size as the input. It still needs to be trained, but that is outside the scope of this notebook.