In [21]:
import torch
import torch.nn as nn


def double_conv(in_c, out_c):
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True)
    )
    return conv


def crop_image(target, inputs):
    target_size = target.size()[2]
    inputs_size = inputs.size()[2]
    diff = inputs_size - target_size
    diff = diff//2
    return inputs[:,:,diff:inputs_size-diff, diff:inputs_size-diff ]
    
    
    
class UNET(nn.Module):
    def __init__(self):
        super(UNET, self).__init__()
        self.max_pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(1,64)
        self.down_conv_2 = double_conv(64,128)
        self.down_conv_3 = double_conv(128,256)
        self.down_conv_4 = double_conv(256,512)
        self.down_conv_5 = double_conv(512,1024)
        
        
        self.uptrans_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        
        self.up_conv1_1 = double_conv(1024, 512)

        self.uptrans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        
        self.up_conv1_2 = double_conv(512, 256)
        
        
        self.uptrans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        
        self.up_conv1_3 = double_conv(256, 128)
        
        
        self.uptrans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        
        self.up_conv1_4 = double_conv(128, 64)
        
        self.out = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)
        
    def forward(self,image):
        
        x1 = self.down_conv_1(image)
        print(x1.shape)
        x2 = self.max_pool_2(x1)
        print(x2.shape)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2(x7)
        x9 = self.down_conv_5(x8)
        x10 = self.uptrans_1(x9)
        x11 = crop_image(x10, x7)
        x = self.up_conv1_1(torch.cat([x10,x11], 1))
        
        x2 = self.uptrans_2(x)
        x = crop_image(x2, x5)
        x = self.up_conv1_2(torch.cat([x2,x], 1))
        
        x10 = self.uptrans_3(x)
        x11 = crop_image(x10, x3)
        x = self.up_conv1_3(torch.cat([x10,x11], 1))
        
        x10 = self.uptrans_4(x)
        x11 = crop_image(x10, x1)
        x = self.up_conv1_4(torch.cat([x10,x11], 1))
        print(x.shape)
        
        return self.out(x)
    
if __name__=="__main__":
    rand_image = torch.rand((1,1,572,572))
    model = UNET()
    image = model(rand_image)
    print(image.shape)

torch.Size([1, 64, 568, 568])
torch.Size([1, 64, 284, 284])
torch.Size([1, 64, 388, 388])
torch.Size([1, 2, 388, 388])
