**UNet**


---



In [None]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchsummary import summary

UNet is built from two kinds of blocks:

`contract_block` halves the spatial size (height and width) and, as used here, increases the number of channels.

`expand_block` doubles the spatial size and, as used here, decreases the number of channels.

In [None]:
def contract_block(in_channels, out_channels, kernel_size, padding):
    """Build a downsampling UNet stage.

    The stage is two (Conv2d -> BatchNorm2d -> ReLU) units followed by a
    3x3 max-pool with stride 2 and padding 1.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        kernel_size: kernel size of both convolutions (stride is fixed at 1).
        padding: padding of both convolutions; choose it so the convolutions
            preserve height/width (e.g. kernel_size=3, padding=1).

    Returns:
        nn.Sequential whose layers sit at indices 0-6 (this layout fixes the
        state_dict keys). With size-preserving convolutions the pool maps an
        H x W input to ceil(H/2) x ceil(W/2).
    """
    def conv_bn_relu(channels_in):
        # Stride-1 convolution, then normalization and activation.
        return [
            nn.Conv2d(channels_in, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    layers = conv_bn_relu(in_channels) + conv_bn_relu(out_channels)
    # floor((H + 2*1 - 3) / 2) + 1 == ceil(H / 2): exactly half for even H.
    layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    return nn.Sequential(*layers)

def expand_block(in_channels, out_channels, kernel_size, padding):
    """Build an upsampling UNet stage.

    The stage is two (Conv2d -> BatchNorm2d -> ReLU) units followed by a
    3x3 transposed convolution with stride 2, padding 1 and output_padding 1.

    Args:
        in_channels: channels of the incoming feature map (in this notebook
            this includes any concatenated skip channels).
        out_channels: channels produced by every convolution in the stage.
        kernel_size: kernel size of the two regular convolutions (stride 1).
        padding: padding of the two regular convolutions; choose it so they
            preserve height/width (e.g. kernel_size=3, padding=1).

    Returns:
        nn.Sequential whose layers sit at indices 0-6 (this layout fixes the
        state_dict keys). With size-preserving convolutions the transposed
        convolution maps an H x W input to 2H x 2W.
    """
    def conv_bn_relu(channels_in):
        # Stride-1 convolution, then normalization and activation.
        return [
            nn.Conv2d(channels_in, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    layers = conv_bn_relu(in_channels) + conv_bn_relu(out_channels)
    # (H - 1)*2 - 2*1 + (3 - 1) + 1 + 1 == 2H: an exact doubling of height/width.
    layers.append(
        nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
    )
    return nn.Sequential(*layers)


In [None]:
# Test the blocks - kernel_size and padding should be properly chosen
# (kernel_size=3 with padding=1 keeps the convolutions size-preserving, so only
# the max-pool / transposed-conv layers change height and width).
block1 = contract_block(3,32,3,1)
xb = torch.randn((1,3,128,128))
print('xb:',xb.shape)
c1 = block1(xb)
print('c1:',c1.shape)
# contract: channels 3 -> 32, height/width halved
assert c1.shape == (1, 32, 64, 64), c1.shape

block2 = expand_block(32,3,3,1)
c2 = block2(c1)
print('c2:',c2.shape)
# expand: channels 32 -> 3, height/width doubled back to the input size
assert c2.shape == xb.shape, c2.shape

xb: torch.Size([1, 3, 128, 128])
c1: torch.Size([1, 32, 64, 64])
c2: torch.Size([1, 3, 128, 128])


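UNet contracts and then expands its input. While expanding, it concatenates each upsampled feature map with the output of the matching contracting block (a skip connection) along the channel dimension.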

In [None]:
class UNET(nn.Module):
    """Three-level UNet built from `contract_block` / `expand_block`.

    The contracting path has 32, 64 and 128 channels. Each expanding step
    concatenates its input with the output of the matching contracting block
    along the channel dimension (skip connection).

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the final expanding block.

    Input height/width need not be divisible by 8. Each contracting block
    rounds odd sizes up, so an upsampled tensor can be one pixel larger than
    its skip tensor. It is therefore cropped to the skip tensor's size before
    concatenation, and the final output is cropped to the input's height/width.
    For sizes divisible by 8 the crops are no-ops.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Contracting path: each block halves height/width (rounding up).
        self.conv1 = contract_block(in_channels, 32, 3, 1)
        self.conv2 = contract_block(32, 64, 3, 1)
        self.conv3 = contract_block(64, 128, 3, 1)

        # Expanding path: each block doubles height/width.
        self.upconv3 = expand_block(128, 64, 3, 1)
        self.upconv2 = expand_block(64 * 2, 32, 3, 1)  # 64 skip + 64 upsampled channels
        self.upconv1 = expand_block(32 * 2, out_channels, 3, 1)  # 32 skip + 32 upsampled channels

    @staticmethod
    def _crop_to(t, height, width):
        """Crop the last two (spatial) dimensions of `t` to at most height x width."""
        return t[..., :height, :width]

    def forward(self, x):
        height, width = x.shape[-2:]

        # downsample
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        # upsample; 2 * ceil(n / 2) >= n, so each upsampled tensor is at least
        # as large as its skip tensor and cropping (never padding) suffices.
        upconv3 = self._crop_to(self.upconv3(conv3), *conv2.shape[-2:])
        upconv2 = self.upconv2(torch.cat([conv2, upconv3], 1))
        upconv2 = self._crop_to(upconv2, *conv1.shape[-2:])
        upconv1 = self.upconv1(torch.cat([conv1, upconv2], 1))

        return self._crop_to(upconv1, height, width)

In [None]:
# test the model
unet = UNET(4,128)

print(unet)


xb = torch.randn((1,4,384,384))
out = unet(xb)
print(out.shape)
# 384 -> 192 -> 96 -> 48 on the way down, then back up to 384; channels 4 -> 128
assert out.shape == (1, 128, 384, 384), out.shape
