This notebook was made for fun and experimentation; it is not a tutorial. For a structured approach, see the other notebook.

In [33]:
# TODO: optimize the implementation further and reproduce the results

In [1]:
import torch
import torch.nn as nn

In [7]:
class ConvLayer(nn.Module):
    """Two back-to-back ``Conv2d -> ReLU`` stages that share one kernel size.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps the channel count at ``out_channels``. No stride or padding is passed
    to ``nn.Conv2d``, so its defaults apply.
    """
    # This will look better as nn.Sequential
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # conv_1 is created before conv_2 so that seeded initialisation and
        # state_dict keys stay exactly as they were.
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size)
        self.relu_1 = nn.ReLU()
        self.relu_2 = nn.ReLU()

    def forward(self, x):
        """Apply conv_1 -> relu_1 -> conv_2 -> relu_2.

        Args:
            x: (B,C,H,W)
        Return:
            Tensor with ``out_channels`` channels. With nn.Conv2d's default
            stride and padding and an integer ``kernel_size``, each
            convolution trims ``kernel_size - 1`` from both H and W.
        """
        hidden = self.relu_1(self.conv_1(x))
        return self.relu_2(self.conv_2(hidden))


In [8]:
class DownBlock(nn.Module):
    """Downsampling stage consisting of a single 2-D max-pooling layer.

    Args:
        kernel_size: pooling window passed to ``nn.MaxPool2d``.
        stride: pooling stride passed to ``nn.MaxPool2d``.
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size, stride=stride)

    def forward(self, x):
        """Return ``x`` after max pooling."""
        return self.max_pool(x)

In [None]:
# F.interpolate or nn.Upsample could be used for upsampling instead of nn.ConvTranspose2d
class UpBlock(nn.Module):
    """Upsample with a stride-2 transposed convolution, then merge a skip tensor.

    Args:
        in_channels: channels of the tensor being upsampled.
        out_channels: channels produced by the transposed convolution.
        kernel_size: kernel size of the transposed convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv_transpose_1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=2)

    def forward(self, x, x_skip_connection):
        """Upsample ``x`` and concatenate a centre crop of the skip tensor.

        Args:
            x: 4-D tensor to upsample.
            x_skip_connection: 4-D tensor whose height/width are cropped
                around the centre to the upsampled height/width.
        Return:
            The upsampled tensor followed by the cropped skip tensor,
            concatenated along dim 1 (channels).
        """
        upsampled = self.conv_transpose_1(x)
        _, _, up_h, up_w = upsampled.shape
        _, _, skip_h, skip_w = x_skip_connection.shape

        # Floor-divided margins: any odd leftover row/column is dropped at the
        # bottom/right edge of the skip tensor.
        top = (skip_h - up_h) // 2
        left = (skip_w - up_w) // 2
        cropped_skip = x_skip_connection[:, :, top:top + up_h, left:left + up_w]

        return torch.cat((upsampled, cropped_skip), dim=1)

In [26]:
# The ConvLayer that follows every DownBlock/UpBlock is repeated boilerplate; it could have been built into those block classes instead
class UNet(nn.Module):
    """U-Net style encoder/decoder assembled from ConvLayer, DownBlock and UpBlock.

    The encoder alternates ConvLayer and DownBlock four times, doubling the
    channel count at each level, and ends in a bottleneck ConvLayer. The
    decoder mirrors it with four UpBlock + ConvLayer pairs, each UpBlock
    receiving the matching encoder output as its skip connection. A final 1x1
    convolution maps the features to ``num_classes`` channels.

    Args:
        in_channels: channels of the input image.
        out_channels: feature width of the first encoder level (NOT the number
            of channels the network outputs). Deeper levels use 2x, 4x, 8x
            and 16x this value.
        kernel_size_conv: kernel size used by every ConvLayer.
        kernel_size_down_block: pooling window used by every DownBlock.
        kernel_size_up_block: transposed-convolution kernel size used by
            every UpBlock.
        stride: pooling stride used by every DownBlock.
        num_classes: channels produced by the final 1x1 convolution. Defaults
            to 2, the value that used to be hard-coded, so existing callers
            are unaffected.
    """

    def __init__(self, in_channels, out_channels, kernel_size_conv, kernel_size_down_block, kernel_size_up_block, stride, num_classes=2):
        super().__init__()
        base = out_channels

        # Encoder: attribute names and creation order are unchanged so that
        # state_dict keys and seeded initialisation match earlier versions.
        self.conv_layer_1 = ConvLayer(in_channels=in_channels, out_channels=base, kernel_size=kernel_size_conv)
        self.down_block_1 = DownBlock(kernel_size=kernel_size_down_block, stride=stride)
        self.conv_layer_2 = ConvLayer(in_channels=base, out_channels=base * 2, kernel_size=kernel_size_conv)
        self.down_block_2 = DownBlock(kernel_size=kernel_size_down_block, stride=stride)
        self.conv_layer_3 = ConvLayer(in_channels=base * 2, out_channels=base * 4, kernel_size=kernel_size_conv)
        self.down_block_3 = DownBlock(kernel_size=kernel_size_down_block, stride=stride)
        self.conv_layer_4 = ConvLayer(in_channels=base * 4, out_channels=base * 8, kernel_size=kernel_size_conv)
        self.down_block_4 = DownBlock(kernel_size=kernel_size_down_block, stride=stride)
        self.conv_layer_5 = ConvLayer(in_channels=base * 8, out_channels=base * 16, kernel_size=kernel_size_conv)

        # Decoder: each UpBlock halves the channels, then concatenating the
        # skip tensor doubles them again, hence each following ConvLayer takes
        # twice the UpBlock's output width as input.
        self.up_layer_1 = UpBlock(in_channels=base * 16, out_channels=base * 8, kernel_size=kernel_size_up_block)
        self.conv_layer_6 = ConvLayer(in_channels=base * 16, out_channels=base * 8, kernel_size=kernel_size_conv)
        self.up_layer_2 = UpBlock(in_channels=base * 8, out_channels=base * 4, kernel_size=kernel_size_up_block)
        self.conv_layer_7 = ConvLayer(in_channels=base * 8, out_channels=base * 4, kernel_size=kernel_size_conv)
        self.up_layer_3 = UpBlock(in_channels=base * 4, out_channels=base * 2, kernel_size=kernel_size_up_block)
        self.conv_layer_8 = ConvLayer(in_channels=base * 4, out_channels=base * 2, kernel_size=kernel_size_conv)
        self.up_layer_4 = UpBlock(in_channels=base * 2, out_channels=base, kernel_size=kernel_size_up_block)
        self.conv_layer_9 = ConvLayer(in_channels=base * 2, out_channels=base, kernel_size=kernel_size_conv)

        # 1x1 convolution head: one output channel per class.
        self.final_conv_layer = nn.Conv2d(in_channels=base, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck, decoder and 1x1 head.

        Args:
            x: input batch with ``in_channels`` channels, laid out (B,C,H,W).
        Return:
            Tensor with ``num_classes`` channels. ConvLayer uses nn.Conv2d
            without padding, so the spatial size is generally smaller than
            the input's.
        """
        # Encoder: keep each level's ConvLayer output for the skip connections.
        enc_1 = self.conv_layer_1(x)
        enc_2 = self.conv_layer_2(self.down_block_1(enc_1))
        enc_3 = self.conv_layer_3(self.down_block_2(enc_2))
        enc_4 = self.conv_layer_4(self.down_block_3(enc_3))
        bottleneck = self.conv_layer_5(self.down_block_4(enc_4))

        # Decoder: deepest skip first.
        x_out = self.conv_layer_6(self.up_layer_1(bottleneck, enc_4))
        x_out = self.conv_layer_7(self.up_layer_2(x_out, enc_3))
        x_out = self.conv_layer_8(self.up_layer_3(x_out, enc_2))
        x_out = self.conv_layer_9(self.up_layer_4(x_out, enc_1))

        return self.final_conv_layer(x_out)

In [31]:
# Smoke test: push one random single-channel 572x572 image through the network.
model = UNet(in_channels=1, out_channels=64, kernel_size_conv=3, kernel_size_down_block=2, kernel_size_up_block=2, stride=2)
dummy_input = torch.randn(1, 1, 572, 572)
# Only the output shape is inspected, so skip autograd bookkeeping: this avoids
# holding every intermediate activation of the 572x572 forward pass in memory.
with torch.no_grad():
    output = model(dummy_input)
# Unpadded 3x3 convolutions shrink the 572x572 input to 388x388; the head emits 2 channels.
assert output.shape == (1, 2, 388, 388), output.shape

In [32]:
# Expected: torch.Size([1, 2, 388, 388]) -> batch of 1, 2 output channels, 388x388 map.
output.shape

torch.Size([1, 2, 388, 388])