# U-NET

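This is an implementation of the original U-Net architecture as published in [Ronneberger et al., 2015](https://arxiv.org/abs/1505.04597) (not to be confused with U²-Net, [arXiv:2005.09007](https://arxiv.org/abs/2005.09007)).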

In [1]:
import torch
from torch.utils.data import DataLoader
from torch import nn
import torchvision 

  from .autonotebook import tqdm as notebook_tqdm


U-Net implementation structure


<img src="/home/oem/Documents/coding/personal/computer_vision_toolkit/assets/U2net.png"  width="600" height="500">

In [2]:
def double_conv(in_ch, out_ch):
    """Build two unpadded 3x3 convolutions, each followed by an in-place ReLU.

    This corresponds to the blue arrows in the architecture figure. No padding is
    used, so each convolution trims 1 pixel from every border. The whole block
    therefore shrinks height and width by 4.

    Args:
        in_ch: number of channels entering the first convolution.
        out_ch: number of channels produced by both convolutions.

    Returns:
        nn.Sequential laid out as [Conv2d, ReLU, Conv2d, ReLU].
    """
    layers = []
    # First conv maps in_ch -> out_ch, second keeps out_ch -> out_ch.
    for conv_in in (in_ch, out_ch):
        layers += [nn.Conv2d(conv_in, out_ch, kernel_size=3), nn.ReLU(inplace=True)]
    return nn.Sequential(*layers)

def crop_img(tensor, target_tensor):
    """Centre-crop ``tensor`` spatially to the height and width of ``target_tensor``.

    Used for the skip connections (white blocks in the figure). The feature map
    from the contracting path is larger than the upsampled map because the
    convolutions are unpadded, so it is cropped before concatenation.

    Args:
        tensor: tensor to crop, indexed as (N, C, H, W).
        target_tensor: tensor whose spatial size (dims 2 and 3) is the target.

    Returns:
        A slice of ``tensor`` with shape (N, C, target_H, target_W).

    Raises:
        ValueError: if ``target_tensor`` is larger than ``tensor`` in either
            spatial dimension.
    """
    target_h, target_w = target_tensor.shape[2], target_tensor.shape[3]
    h, w = tensor.shape[2], tensor.shape[3]
    if target_h > h or target_w > w:
        raise ValueError(
            f"cannot crop spatial size {(h, w)} to larger size {(target_h, target_w)}"
        )
    # Height and width are handled independently, so non-square maps work.
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    # Slice exactly target_h/target_w elements. Trimming the same amount from
    # both ends would leave one extra pixel when the size difference is odd.
    return tensor[:, :, top:top + target_h, left:left + target_w]

In [7]:
class UNet(nn.Module):
    """U-Net with unpadded 3x3 convolutions.

    Contracting path (left side):
    - Applies ``double_conv`` blocks separated by 2x2 max pooling.
    - Doubles the channel count at each level, from 64 up to 1024.

    Expansive path (right side):
    - Upsamples with 2x2 transposed convolutions.
    - Concatenates the centre-cropped feature map from the matching contracting
      level.
    - Applies another ``double_conv``.

    A final 1x1 convolution maps the 64 feature channels to ``out_channels``.

    Because every 3x3 convolution is unpadded, the output is spatially smaller
    than the input (a 572x572 input gives a 388x388 output).

    Args:
        in_channels: number of channels of the input tensor (default 1).
        out_channels: number of channels of the output map (default 2).
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()
        # left side (contracting path)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)
        # right side (expansive path). Each up_conv takes twice the channels
        # because the cropped skip connection is concatenated onto the
        # upsampled map.
        self.up_trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv_1 = double_conv(1024, 512)
        self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv_2 = double_conv(512, 256)
        self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv_3 = double_conv(256, 128)
        self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv_4 = double_conv(128, 64)
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _expand(x, skip, up_trans, up_conv):
        """Upsample ``x``, concatenate the centre-cropped ``skip`` and convolve."""
        x = up_trans(x)
        # Skip features go first along the channel axis (dim 1).
        return up_conv(torch.cat([crop_img(skip, x), x], 1))

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: input batch, expected as (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H_out, W_out), spatially smaller
            than ``x`` because of the unpadded convolutions.
        """
        # Contracting path: keep each pre-pooling map as a skip connection.
        skip1 = self.down_conv_1(x)
        skip2 = self.down_conv_2(self.max_pool(skip1))
        skip3 = self.down_conv_3(self.max_pool(skip2))
        skip4 = self.down_conv_4(self.max_pool(skip3))
        bottleneck = self.down_conv_5(self.max_pool(skip4))

        # Expansive path: consume the skips from deepest to shallowest.
        x = self._expand(bottleneck, skip4, self.up_trans_1, self.up_conv_1)
        x = self._expand(x, skip3, self.up_trans_2, self.up_conv_2)
        x = self._expand(x, skip2, self.up_trans_3, self.up_conv_3)
        x = self._expand(x, skip1, self.up_trans_4, self.up_conv_4)
        return self.out(x)



In [8]:
# Smoke test: push one random 572x572 single-channel image through the network.
torch.manual_seed(0)  # make the random input reproducible on re-run
sample_input = torch.rand((1, 1, 572, 572))
model = UNet()

with torch.no_grad():  # inference only; no autograd graph needed
    output = model(sample_input)

# Unpadded convolutions shrink 572x572 to 388x388; the final 1x1 conv gives 2 channels.
assert output.shape == (1, 2, 388, 388)
output.shape

torch.Size([1, 2, 388, 388])


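## Difference between `nn.Upsample` and `nn.ConvTranspose2d`

- Upsampling (`nn.Upsample`, roughly the opposite of max pooling) enlarges the spatial dimensions of the image by interpolating the existing values; it has no learnable weights.
- `ConvTranspose2d` performs upsampling and a convolution in a single learnable operation.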