In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
!pip install -q torch-summary
from torchsummary import summary
from torchvision.transforms import CenterCrop

In [93]:
class Unet(nn.Module):
    """U-Net (Ronneberger et al., 2015) with unpadded ("valid") 3x3 convolutions.

    The contracting path has five double-conv levels separated by 2x2 max
    pooling; the expansive path has four steps, each a 2x2 stride-2 transposed
    convolution followed by concatenation with the center-cropped encoder map
    of the matching level and a double-conv block. Because the convolutions
    are unpadded, the output is spatially smaller than the input (a 572x572
    input yields a 388x388 output).

    Args:
        in_channels: channels of the input image (default 1).
        num_classes: channels of the output map (default 2).
        base_channels: width of the first encoder level; each deeper level
            doubles it (default 64 -> 64, 128, 256, 512, 1024).

    Accepts an unbatched ``(C, H, W)`` or a batched ``(N, C, H, W)`` tensor
    and returns a tensor with the same number of dimensions.
    """

    def __init__(self, in_channels=1, num_classes=2, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        # Contracting path.
        self.down_block1 = self.conv_down_block(in_channels, c1)
        self.down_block2 = self.conv_down_block(c1, c2)
        self.down_block3 = self.conv_down_block(c2, c3)
        self.down_block4 = self.conv_down_block(c3, c4)
        self.down_block5 = self.conv_down_block(c4, c5)

        # Expansive path: each transposed conv halves the channels, and the
        # following block sees twice that after concatenation with the skip.
        self.conv_tr1 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.up_block1 = self.conv_up_block(c5, c4)

        self.conv_tr2 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.up_block2 = self.conv_up_block(c4, c3)

        self.conv_tr3 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.up_block3 = self.conv_up_block(c3, c2)

        self.conv_tr4 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.up_block4 = self.conv_up_block(c2, c1)

        # 1x1 convolution mapping features to per-pixel class scores.
        self.final_conv = nn.Conv2d(c1, num_classes, kernel_size=1)

    def forward(self, input):
        """Run the network on ``input`` of shape (C, H, W) or (N, C, H, W)."""
        # Encoder: each level works on the 2x max-pooled output of the previous one.
        d1 = self.down_block1(input)
        d2 = self.down_block2(F.max_pool2d(d1, 2))
        d3 = self.down_block3(F.max_pool2d(d2, 2))
        d4 = self.down_block4(F.max_pool2d(d3, 2))
        d5 = self.down_block5(F.max_pool2d(d4, 2))

        # Decoder: upsample, fuse with the matching encoder level, refine.
        u1 = self._decode_step(d5, d4, self.conv_tr1, self.up_block1)
        u2 = self._decode_step(u1, d3, self.conv_tr2, self.up_block2)
        u3 = self._decode_step(u2, d2, self.conv_tr3, self.up_block3)
        u4 = self._decode_step(u3, d1, self.conv_tr4, self.up_block4)

        return self.final_conv(u4)

    def _decode_step(self, x, skip, conv_tr, up_block):
        """Upsample ``x``, append the center-cropped ``skip`` channels, apply ``up_block``."""
        up = conv_tr(x)
        skip = self._center_crop(skip, up.shape[-2:])
        # Channels sit at dim -3 for both (C, H, W) and (N, C, H, W) tensors;
        # concatenating along dim 0 would stack batch entries for batched input.
        return up_block(torch.cat((up, skip), dim=-3))

    @staticmethod
    def _center_crop(tensor, size):
        """Crop the last two dims of ``tensor`` to ``size`` = (height, width) around the center.

        Offsets use the same rounding as torchvision's ``CenterCrop``
        (``int(round(diff / 2))``). Assumes ``size`` is no larger than the
        tensor, which holds for the skip connections in this network.
        """
        height, width = int(size[0]), int(size[1])
        top = int(round((tensor.shape[-2] - height) / 2.0))
        left = int(round((tensor.shape[-1] - width) / 2.0))
        return tensor[..., top:top + height, left:left + width]

    def conv_down_block(self, fan_in, fan_out):
        """Return two unpadded 3x3 conv + ReLU layers mapping ``fan_in`` -> ``fan_out`` channels.

        Each unpadded 3x3 conv trims 2 pixels from height and width, so the
        block trims 4. The block is returned rather than stored on ``self``:
        the caller assigns it to a named attribute, and an extra
        ``self.block`` alias would register a duplicate submodule and add
        duplicate ``block.*`` entries to ``state_dict()``.
        """
        return nn.Sequential(
            nn.Conv2d(fan_in, fan_out, 3),
            nn.ReLU(),
            nn.Conv2d(fan_out, fan_out, 3),
            nn.ReLU(),
        )

    def conv_up_block(self, fan_in, fan_out):
        """Return the decoder block: a ReLU, then two unpadded 3x3 conv + ReLU layers.

        ``fan_in`` is the channel count after concatenating the upsampled map
        with the skip connection; the block outputs ``fan_out`` channels.

        NOTE(review): the leading ReLU is not part of the original paper,
        which applies the two 3x3 convolutions directly to the concatenated
        map. It is kept so outputs and ``state_dict`` keys
        (``up_blockN.1.*``, ``up_blockN.3.*``) stay unchanged; confirm intent.
        """
        return nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(fan_in, fan_out, 3),
            nn.ReLU(),
            nn.Conv2d(fan_out, fan_out, 3),
            nn.ReLU(),
        )
    
    

In [94]:
# Sanity check: a single-channel 572x572 image should map to a 2x388x388
# output, as reported in the U-Net paper. The input here is unbatched (C, H, W).
torch.manual_seed(0)  # reproducible weights and input
model = Unet()
a = torch.rand((1, 572, 572))
with torch.no_grad():  # shape check only; no need to record autograd history
    out = model(a)
assert out.shape == (2, 388, 388), out.shape
out.shape


torch.Size([2, 388, 388])

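The output shape (2 × 388 × 388 for a 1 × 572 × 572 input) matches the one reported in the original U-Net paper (Ronneberger et al., 2015). This confirms that the spatial bookkeeping of the implementation is right: the unpadded convolutions, pooling, up-convolutions and center crops. A shape check alone cannot establish that the rest of the implementation is correct.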