### `U-Net: Convolutional Networks for Biomedical Image Segmentation`

![UNet Architecture](../u-net-architecture.png)

### Import torch

In [1]:
# Import PyTorch
import torch
from torch import nn

torch.__version__

'1.13.1'

### Helper Functions

In [2]:
def conv2x(in_channels: int, out_channels: int) -> nn.Sequential:
    """Build a block of two unpadded 3x3 convolutions, each followed by ReLU.

    Args:
        in_channels: Number of channels fed into the first convolution.
        out_channels: Number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` laid out as ``Conv2d -> ReLU -> Conv2d -> ReLU``.
        No padding is used, so each convolution trims 2 pixels from the
        height and width of its input.
    """
    stages = []
    stage_in = in_channels
    for _ in range(2):
        stages.append(
            nn.Conv2d(in_channels=stage_in, out_channels=out_channels, kernel_size=3)
        )
        stages.append(nn.ReLU(inplace=True))
        # The second convolution consumes what the first one produced.
        stage_in = out_channels
    return nn.Sequential(*stages)

def conv_trans2x(in_channels: int, out_channels: int) -> nn.ConvTranspose2d:
    """Build the 2x2, stride-2 transposed convolution used for up-sampling.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels of the up-sampled feature map.

    Returns:
        A single ``nn.ConvTranspose2d`` layer (not wrapped in a container).
        With ``kernel_size=2`` and ``stride=2`` it doubles the height and
        width of its input.
    """
    return nn.ConvTranspose2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=2,
        stride=2,
    )

def crop_img(tensor: torch.Tensor, target_tensor: torch.Tensor) -> torch.Tensor:
    """Center-crop ``tensor`` to the spatial size of ``target_tensor``.

    Both tensors are indexed as ``[batch, channel, height, width]``. Height
    and width are handled independently, so non-square inputs work, and the
    result always has exactly the target's height and width, even when the
    size difference is odd (the extra pixel is dropped from the bottom/right).

    Args:
        tensor: Feature map to crop.
        target_tensor: Tensor whose height and width define the crop size.

    Returns:
        A view of ``tensor`` with the height and width of ``target_tensor``.

    Raises:
        ValueError: If ``target_tensor`` is spatially larger than ``tensor``.
    """
    target_h, target_w = target_tensor.size()[2], target_tensor.size()[3]
    tensor_h, tensor_w = tensor.size()[2], tensor.size()[3]

    if target_h > tensor_h or target_w > tensor_w:
        raise ValueError(
            f"Cannot crop a {tensor_h}x{tensor_w} tensor to the larger size {target_h}x{target_w}"
        )

    top = (tensor_h - target_h) // 2
    left = (tensor_w - target_w) // 2
    # Slice by `start:start + target` rather than `delta:size - delta`, which
    # returns one pixel too many whenever the size difference is odd.
    return tensor[:, :, top:top + target_h, left:left + target_w]

### UNet Implementation with PyTorch

In [3]:
class UNet(nn.Module):
    """U-Net built from unpadded 3x3 convolutions.

    The encoder applies five double-convolution blocks separated by 2x2 max
    pooling. The decoder repeats four times: up-sample with a transposed
    convolution, concatenate the center-cropped encoder feature map of the
    same depth along the channel axis, then apply a double-convolution block.
    A final 1x1 convolution maps the 64 feature channels to 2 output channels.

    All layers are stored in ``nn.ModuleList`` containers so they are
    registered as submodules. With plain Python lists they would be missing
    from ``parameters()``, ``state_dict()``, ``.to(device)`` and
    ``.train()/.eval()``.
    """

    def __init__(self) -> None:
        super().__init__()

        # [in_channels, out_channels] of every encoder block, top to bottom.
        self.conv_c = [
            [1, 64],
            [64, 128],
            [128, 256],
            [256, 512],
            [512, 1024]
        ]

        # Decoder channel pairs: the encoder pairs (first one excluded) reversed
        # and flipped -> [[1024, 512], [512, 256], [256, 128], [128, 64]]
        self.reverse_conv_c = [conv_c[::-1] for conv_c in self.conv_c[:0:-1]]

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.down_convs = nn.ModuleList(
            [conv2x(channel[0], channel[1]) for channel in self.conv_c]
        )
        # Each decoder stage is [up-sampling layer, double convolution]. The
        # double convolution takes channel[0] inputs because the up-sampled map
        # (channel[1]) is concatenated with a skip connection of equal width.
        self.up_convs = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        conv_trans2x(channel[0], channel[1]),
                        conv2x(channel[0], channel[1]),
                    ]
                )
                for channel in self.reverse_conv_c
            ]
        )

        self.out = nn.Conv2d(
            in_channels=64,
            out_channels=2,
            kernel_size=1
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of single-channel images.

        Args:
            img: Tensor indexed as ``[batch, channel, height, width]`` with one
                channel. The shape comments below trace a
                ``[1, 1, 572, 572]`` input.

        Returns:
            The 2-channel output map; ``[1, 2, 388, 388]`` for the traced input.
        """
        # Encoder (Downward)
        # Batch Size, Channel, Height, Width
        # Img size: [1, 1, 572, 572]

        x1 = self.down_convs[0](img) # [1, 64, 568, 568]

        x2 = self.max_pool(x1) # [1, 64, 284, 284]
        x3 = self.down_convs[1](x2) # [1, 128, 280, 280]

        x4 = self.max_pool(x3) # [1, 128, 140, 140]
        x5 = self.down_convs[2](x4) # [1, 256, 136, 136]

        x6 = self.max_pool(x5) # [1, 256, 68, 68]
        x7 = self.down_convs[3](x6) # [1, 512, 64, 64]

        x8 = self.max_pool(x7) # [1, 512, 32, 32]
        x9 = self.down_convs[4](x8) # [1, 1024, 28, 28]


        # Decoder (Upward)
        x = self.up_convs[0][0](x9) # [1, 512, 56, 56]
        y = crop_img(x7, x) # [1, 512, 56, 56]
        x = self.up_convs[0][1](torch.cat([x, y], 1)) # [1, 512, 52, 52]

        x = self.up_convs[1][0](x) # [1, 256, 104, 104]
        y = crop_img(x5, x) # [1, 256, 104, 104]
        x = self.up_convs[1][1](torch.cat([x, y], 1)) # [1, 256, 100, 100]

        x = self.up_convs[2][0](x) # [1, 128, 200, 200]
        y = crop_img(x3, x) # [1, 128, 200, 200]
        x = self.up_convs[2][1](torch.cat([x, y], 1)) # [1, 128, 196, 196]

        x = self.up_convs[3][0](x) # [1, 64, 392, 392]
        y = crop_img(x1, x) # [1, 64, 392, 392]
        x = self.up_convs[3][1](torch.cat([x, y], 1)) # [1, 64, 388, 388]

        x = self.out(x) # [1, 2, 388, 388]

        return x

In [4]:
# Smoke test: push one random image through a freshly built UNet.
# Seed the global RNG first. The random input below draws from it, and
# presumably so does the layer weight initialisation, so the order of these
# statements affects the printed values.
torch.manual_seed(42)
# One single-channel 572x572 image: [batch, channel, height, width].
image = torch.rand((1, 1, 572, 572))
model = UNet()
# Per the shape comments in UNet.forward, the result should be [1, 2, 388, 388].
print(model(image))

tensor([[[[0.0520, 0.0471, 0.0536,  ..., 0.0474, 0.0412, 0.0463],
          [0.0484, 0.0485, 0.0556,  ..., 0.0482, 0.0504, 0.0503],
          [0.0488, 0.0547, 0.0505,  ..., 0.0475, 0.0454, 0.0451],
          ...,
          [0.0485, 0.0485, 0.0497,  ..., 0.0478, 0.0484, 0.0427],
          [0.0501, 0.0495, 0.0444,  ..., 0.0513, 0.0541, 0.0505],
          [0.0515, 0.0463, 0.0527,  ..., 0.0504, 0.0522, 0.0508]],

         [[0.1161, 0.1126, 0.1167,  ..., 0.1156, 0.1171, 0.1164],
          [0.1163, 0.1153, 0.1147,  ..., 0.1180, 0.1162, 0.1136],
          [0.1154, 0.1164, 0.1148,  ..., 0.1163, 0.1178, 0.1163],
          ...,
          [0.1176, 0.1167, 0.1154,  ..., 0.1147, 0.1136, 0.1148],
          [0.1201, 0.1134, 0.1154,  ..., 0.1144, 0.1158, 0.1130],
          [0.1154, 0.1170, 0.1159,  ..., 0.1166, 0.1147, 0.1158]]]],
       grad_fn=<ConvolutionBackward0>)
