In [1]:
import torch
import torch.nn as nn
from torch.optim import SGD

In [2]:
# Informational only: none of the cells shown here move the model or tensors to a GPU, so they run on CPU.
torch.cuda.is_available()

False

# U-Net Architecture

![U-Net architecture diagram](UNET_architecture.png)

In [56]:
def double_conv(in_channels, out_channels):
    """Two unpadded 3x3 convolutions, each followed by an in-place ReLU.

    The first conv maps `in_channels` -> `out_channels`; the second keeps
    `out_channels`. With no padding, each conv trims 1 pixel per border, so
    height and width shrink by 4 overall (e.g. 572 -> 568).

    Returns:
        nn.Sequential(Conv2d, ReLU, Conv2d, ReLU)
    """
    layers = []
    for conv_in in (in_channels, out_channels):
        layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

def up_trans(in_channels, out_channels):
    """Learned 2x spatial upsampling: a 2x2 transposed convolution with stride 2.

    Doubles height and width (e.g. 28 -> 56) while mapping `in_channels`
    to `out_channels`.
    """
    upsample = nn.ConvTranspose2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=2,
        stride=2,
    )
    return upsample

def crop(original_tensor, target_tensor):
    """Center-crop `original_tensor` to the spatial size of `target_tensor`.

    Used for the skip connections: the contracting-path feature map is larger
    than the upsampled expansive-path map (the convolutions are unpadded), so
    its center is cut out before channel-wise concatenation.

    Height (dim 2) and width (dim 3) are cropped independently, so non-square
    maps work. The slice length is taken from the target size, so an odd size
    difference still yields exactly the target size.

    Args:
        original_tensor: 4-D (batch, channel, height, width) tensor to crop.
        target_tensor: 4-D tensor whose height/width give the crop size;
            only its shape is read.

    Returns:
        A view of `original_tensor` with spatial size (target_h, target_w).

    Raises:
        ValueError: if the target is larger than the original in height or width.
    """
    original_h, original_w = original_tensor.shape[2], original_tensor.shape[3]
    target_h, target_w = target_tensor.shape[2], target_tensor.shape[3]
    if target_h > original_h or target_w > original_w:
        raise ValueError(
            f'cannot crop {tuple(original_tensor.shape)} to spatial size '
            f'({target_h}, {target_w}): target is larger than original'
        )
    # Floor division puts any odd leftover pixel at the bottom/right edge.
    top = (original_h - target_h) // 2
    left = (original_w - target_w) // 2
    return original_tensor[:, :, top:top + target_h, left:left + target_w]

In [97]:
class UNet(nn.Module):
    """U-Net built from unpadded double convolutions and center-cropped skips.

    Because every 3x3 convolution is unpadded, the output is spatially smaller
    than the input (with the defaults, a 1x1x572x572 input gives a
    1x2x388x388 output).

    Submodule attribute names (`down_conv_1` ... `up_conv_9`, `out`) and their
    creation order are part of the `state_dict` layout; keep them stable.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels produced by the final 1x1 conv (default 2).
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Contracting path: feature channels double at every level (64 -> 1024).
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Expansive path: each transposed conv halves the channels and doubles H/W.
        self.up_trans_6 = up_trans(1024, 512)
        self.up_trans_7 = up_trans(512, 256)
        self.up_trans_8 = up_trans(256, 128)
        self.up_trans_9 = up_trans(128, 64)

        # Inputs are (cropped skip + upsampled) concatenated, hence 2x channels.
        self.up_conv_6 = double_conv(1024, 512)
        self.up_conv_7 = double_conv(512, 256)
        self.up_conv_8 = double_conv(256, 128)
        self.up_conv_9 = double_conv(128, 64)

        # 1x1 conv mapping the 64 feature channels to the output channels.
        self.out = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, img, verbose=0):
        """Run the network on `img`.

        Args:
            img: 4-D input batch; assumed (batch, in_channels, H, W).
            verbose: if truthy, print each block's output shape.

        Returns:
            The output of the final 1x1 convolution (`out_channels` channels).
        """
        # Contracting path: block 1 convolves the input directly, blocks 2-5
        # max-pool first. Every block output is kept for the skip connections.
        down_convs = [self.down_conv_1, self.down_conv_2, self.down_conv_3,
                      self.down_conv_4, self.down_conv_5]
        skips = []
        x = img
        for block, down_conv in enumerate(down_convs, start=1):
            if block > 1:
                x = self.max_pool_2x2(x)
            x = down_conv(x)
            skips.append(x)
            if verbose:
                print(f'Contracting Block {block}: {x.shape}')

        # Expansive path (blocks 6-9): upsample, concatenate the center-cropped
        # output of contracting block 4, 3, 2, 1 respectively, then double-conv.
        expansive = zip(
            range(6, 10),
            [self.up_trans_6, self.up_trans_7, self.up_trans_8, self.up_trans_9],
            [self.up_conv_6, self.up_conv_7, self.up_conv_8, self.up_conv_9],
            reversed(skips[:-1]),
        )
        for block, up_trans_layer, up_conv_layer, skip in expansive:
            upsampled = up_trans_layer(x)
            x = up_conv_layer(torch.cat([crop(skip, upsampled), upsampled], dim=1))
            if verbose:
                print(f'Expansive Block {block}: {x.shape}')

        return self.out(x)
        

In [98]:
# Default construction: single-channel input, two output channels from the final 1x1 conv.
unet = UNet()

In [101]:
img = torch.rand(1, 1, 572, 572)  # batch_size, channel, height, width
# Call the module itself (not .forward) so nn.Module hooks run. no_grad: this is
# only a shape check, so building an autograd graph would just waste memory.
with torch.no_grad():
    output = unet(img, verbose=1)

Conntracting Block 1: torch.Size([1, 64, 568, 568])
Conntracting Block 2: torch.Size([1, 128, 280, 280])
Conntracting Block 3: torch.Size([1, 256, 136, 136])
Conntracting Block 4: torch.Size([1, 512, 64, 64])
Conntracting Block 5: torch.Size([1, 1024, 28, 28])
Expansive Block 6: torch.Size([1, 512, 52, 52])
Expansive Block 7: torch.Size([1, 256, 100, 100])
Expansive Block 8: torch.Size([1, 128, 196, 196])
Expansive Block 9: torch.Size([1, 64, 388, 388])


In [100]:
# Output spatial size is smaller than the 572x572 input because every conv is unpadded.
output.shape

torch.Size([1, 2, 388, 388])