#**U-Net Implementation**
Importing necessary libraries

In [1]:
import os
import torch
import torch.nn as nn

U-Net contains downsampling and upsampling path. Downsampling path consists of 2 3x3 convolution layers each followed by ReLU activation function. It is then followed by max-pool layer with stride of 2. 

In the upsampling path, cropping is required due to loss of border due to every convolution. Along with cropping of image, two 3x3 convolution operation is performed along with ReLU which is followed by max-pool with stride 2 is performed. It is then followed by up-convolution. 

Final layer of the network consists of 1x1 convolution that map each 64 component feature vector into desired number of classes.
U-Net architecture consists of total of 23 convolution layers.



![alt text](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

In [7]:
def double_convolution(input_channels, output_channels):
  conv = nn.Sequential(
      nn.Conv2d(input_channels, output_channels, kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(output_channels, output_channels, kernel_size=3),
      nn.ReLU(inplace=True)
  )
  return conv

In [18]:
def cropimg(start_tensor, end_tensor):
  ssize = start_tensor.size()[2]
  esize = end_tensor.size()[2]

  delta = ssize - esize
  delta = delta//2

  return start_tensor[:,:,delta:ssize-delta, delta:ssize-delta]

In [25]:
class UNet(nn.Module):
  def __init__(self):
    super().__init__()

    self.maxpool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
    self.dconv1 = double_convolution(1,64)
    self.dconv2 = double_convolution(64, 128)
    self.dconv3 = double_convolution(128,256)
    self.dconv4 = double_convolution(256,512)
    self.dconv5 = double_convolution(512, 1024)

    # Upsampling
    self.uptrans1 = nn.ConvTranspose2d(
        in_channels = 1024,
        out_channels = 512,
        kernel_size = 2,
        stride = 2
    )

    self.upconv1 = double_convolution(1024, 512)
    
    self.uptrans2 = nn.ConvTranspose2d(
        in_channels = 512,
        out_channels = 256,
        kernel_size = 2,
        stride = 2
    )

    self.upconv2 = double_convolution(512, 256)

    self.uptrans3 = nn.ConvTranspose2d(
        in_channels = 256,
        out_channels = 128,
        kernel_size = 2,
        stride = 2
    )

    self.upconv3 = double_convolution(256, 128)

    self.uptrans4 = nn.ConvTranspose2d(
        in_channels = 128,
        out_channels = 64,
        kernel_size = 2,
        stride = 2
    )

    self.upconv4 = double_convolution(128, 64)

    self.out = nn.Conv2d(
        in_channels=64,
        out_channels = 1,
        kernel_size = 1
    )

  def forward(self, image):
    # Contracting path
    x1 = self.dconv1(image)
    x2 = self.maxpool_2x2(x1)
    x3 = self.dconv2(x2)
    x4 = self.maxpool_2x2(x3)
    x5 = self.dconv3(x4)
    x6 = self.maxpool_2x2(x5)
    x7 = self.dconv4(x6)
    x8 = self.maxpool_2x2(x7)
    x9 = self.dconv5(x8)

    # Expanding path
    x = self.uptrans1(x9)
    y = cropimg(x7,x)
    x = self.upconv1(torch.cat([x,y],1))

    x = self.uptrans2(x)
    y = cropimg(x5, x)
    x = self.upconv2(torch.cat([x,y],1))

    x = self.uptrans3(x)
    y = cropimg(x3, x)
    x = self.upconv3(torch.cat([x,y],1))

    x = self.uptrans4(x)
    y = cropimg(x1, x)
    x = self.upconv4(torch.cat([x,y],1))

    x = self.out(x)
    print(x.size())
    return x


In [48]:
if __name__ == '__main__':
  image = torch.rand((1,1,572,572))
  model = UNet()
  x = model(image)

torch.Size([1, 1, 388, 388])


In [57]:
print(model)

UNet(
  (maxpool_2x2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dconv1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (dconv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (dconv3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (dconv4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (dconv5): Sequential(
    (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1)