In [1]:
import torch
import torch.nn as nn

In [3]:
def double_conv(in_c, out_c):
    """Build two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Args:
        in_c: number of channels fed to the first convolution.
        out_c: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` laid out as Conv2d -> ReLU -> Conv2d -> ReLU.
    """
    stages = []
    # The first conv maps in_c -> out_c, the second keeps out_c -> out_c.
    # No padding is requested, so every conv trims 2 pixels off height and width.
    for channels_in in (in_c, out_c):
        stages.append(nn.Conv2d(channels_in, out_c, kernel_size=3))
        stages.append(nn.ReLU(inplace=True))

    return nn.Sequential(*stages)

In [31]:
def crop_img(original_tensor, target_tensor):
    """Center-crop ``original_tensor`` to the spatial size of ``target_tensor``.

    Both tensors are indexed as 4-D with height on dim 2 and width on dim 3.
    Height and width are handled independently, and the result always has
    exactly the target's height and width: when the size difference is odd the
    extra row/column is dropped from the bottom/right edge.

    Args:
        original_tensor: tensor to crop; dims 0 and 1 are returned untouched.
        target_tensor: tensor whose dim-2 and dim-3 sizes define the crop size.

    Returns:
        A view of ``original_tensor`` with the target's height and width.

    Raises:
        ValueError: if the target is larger than the original along height or
            width, since cropping cannot enlarge a tensor.
    """
    target_h, target_w = target_tensor.size()[2:4]
    original_h, original_w = original_tensor.size()[2:4]

    if target_h > original_h or target_w > original_w:
        raise ValueError(
            f"cannot crop a {original_h}x{original_w} tensor "
            f"to the larger size {target_h}x{target_w}"
        )

    # Floor division centres the window; slicing by explicit length (instead of
    # `size - delta`) guarantees the exact target size even for odd differences.
    top = (original_h - target_h) // 2
    left = (original_w - target_w) // 2
    return original_tensor[:, :, top:top + target_h, left:left + target_w]

In [36]:
def up_transpose(in_channel, out_channel):
    """Create a transposed convolution with a 2x2 kernel and stride 2.

    Args:
        in_channel: number of channels of the incoming feature map.
        out_channel: number of channels of the produced feature map.

    Returns:
        An ``nn.ConvTranspose2d`` layer; with this kernel/stride pair it
        doubles the height and width of its input.
    """
    upsample_kwargs = {"kernel_size": 2, "stride": 2}
    return nn.ConvTranspose2d(in_channel, out_channel, **upsample_kwargs)

In [78]:
from re import X
class UNet(nn.Module):
    """U-Net style encoder/decoder built from unpadded double convolutions.

    The encoder applies five ``double_conv`` blocks separated by 2x2 max
    pooling. The decoder upsamples four times with ``up_transpose``; after each
    upsampling the result is concatenated (along the channel dim) with a
    centre-cropped copy of the matching encoder output and passed through
    another ``double_conv``. A final 1x1 convolution maps the 64 decoder
    channels to ``out_channels``.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels produced by the final 1x1 convolution
            (default 2).
        verbose: when True (default) ``forward`` prints the tensor shape after
            every stage; pass False to silence the trace, e.g. while training.
    """

    _SEPARATOR = "--------------------------------------------------------"

    def __init__(self, in_channels=1, out_channels=2, verbose=True):
        super(UNet, self).__init__()
        self.verbose = verbose

        # Sub-modules are created in a fixed order so that parameter names and
        # seeded initialisation stay stable.
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride = 2)
        self.down_conv1 = double_conv(in_channels, 64)
        self.down_conv2 = double_conv(64, 128)
        self.down_conv3 = double_conv(128, 256)
        self.down_conv4 = double_conv(256, 512)
        self.down_conv5 = double_conv(512, 1024)

        # Each up_conv takes twice the channels of its transpose conv output
        # because the cropped encoder map is concatenated onto it first.
        self.transpose_conv_1 = up_transpose(1024, 512)
        self.up_conv_1 = double_conv(1024, 512)
        self.transpose_conv_2 = up_transpose(512, 256)
        self.up_conv_2 = double_conv(512, 256)
        self.transpose_conv_3 = up_transpose(256, 128)
        self.up_conv_3 = double_conv(256, 128)
        self.transpose_conv_4 = up_transpose(128, 64)
        self.up_conv_4 = double_conv(128, 64)

        self.out = nn.Conv2d(
            in_channels = 64,
            out_channels=out_channels,
            kernel_size = 1
        )

    def _log(self, message):
        """Print ``message`` only when shape tracing is enabled."""
        if self.verbose:
            print(message)

    def forward(self, image):
        """Run the encoder/decoder and return the per-pixel output map.

        Args:
            image: 4-D input tensor whose dim 1 has ``in_channels`` entries.
                Because the convolutions are unpadded, the output is spatially
                smaller than the input (a 572x572 input yields 388x388).

        Returns:
            The output of the final 1x1 convolution, with ``out_channels``
            channels.
        """
        separator = self._SEPARATOR

        ## Encoder
        self._log(f"Shape of image is : {image.shape}")
        self._log(separator)
        x1 = self.down_conv1(image)
        self._log(f"Shape after first convolution : {x1.shape}")
        x2 = self.max_pool_2x2(x1)
        self._log(f"Shape after first maxpooling : {x2.shape}")
        self._log(separator)
        x3 = self.down_conv2(x2)
        self._log(f"Shape after second convolution : {x3.shape}")
        x4 = self.max_pool_2x2(x3)
        self._log(f"Shape after second maxpooling : {x4.shape}")
        self._log(separator)
        x5 = self.down_conv3(x4)
        self._log(f"Shape after third convolution : {x5.shape}")
        x6 = self.max_pool_2x2(x5)
        self._log(f"Shape after third maxpooling : {x6.shape}")
        self._log(separator)
        x7 = self.down_conv4(x6)
        self._log(f"Shape after fourth convolution : {x7.shape}")
        x8 = self.max_pool_2x2(x7)
        self._log(f"Shape after fourth maxpooling : {x8.shape}")
        self._log(separator)
        x9 = self.down_conv5(x8)
        self._log(f"Shape after fifth convolution : {x9.shape}")
        self._log("We reached the end of Encoder.")
        self._log(separator)

        ## Decoder
        self._log(f"Shape of input to decoder is : {x9.shape}")
        i1 = self.transpose_conv_1(x9)
        self._log(f"Shape after first transpose convolution is : {i1.shape}")
        self._log(separator)
        # The encoder map is larger than the upsampled one, so crop it before
        # concatenating along the channel dimension.
        c1 = crop_img(x7, i1)
        self._log(f"Shape of residual before cropping : {x7.shape}")
        self._log(f"Shape of residual after cropping : {c1.shape}")
        self._log(f"Shape before residual connection is : {i1.shape}")
        r1 = torch.cat([i1, c1], 1)
        self._log(f"Shape after residual connection : {r1.shape}")
        y1 = self.up_conv_1(r1)
        self._log(f"Shape after first up convolution : {y1.shape}")
        self._log(separator)
        i2 = self.transpose_conv_2(y1)
        c2 = crop_img(x5, i2)
        self._log(f"Shape of residual before cropping : {x5.shape}")
        self._log(f"Shape of residual after cropping : {c2.shape}")
        self._log(f"Shape before residual connection is : {i2.shape}")
        r2 = torch.cat([i2, c2], 1)
        self._log(f"Shape after residual connection : {r2.shape}")
        y2 = self.up_conv_2(r2)
        self._log(f"Shape after second up convolution : {y2.shape}")
        self._log(separator)
        i3 = self.transpose_conv_3(y2)
        c3 = crop_img(x3, i3)
        self._log(f"Shape of residual before cropping : {x3.shape}")
        self._log(f"Shape of residual after cropping : {c3.shape}")
        self._log(f"Shape before residual connection is : {i3.shape}")
        r3 = torch.cat([i3, c3], 1)
        self._log(f"Shape after residual connection : {r3.shape}")
        y3 = self.up_conv_3(r3)
        self._log(f"Shape after third up convolution : {y3.shape}")
        self._log(separator)
        i4 = self.transpose_conv_4(y3)
        c4 = crop_img(x1, i4)
        # The tensor being cropped here is x1 (first encoder stage), so that is
        # the shape to report.
        self._log(f"Shape of residual before cropping : {x1.shape}")
        self._log(f"Shape of residual after cropping : {c4.shape}")
        self._log(f"Shape before residual connection is : {i4.shape}")
        r4 = torch.cat([i4, c4], 1)
        self._log(f"Shape after residual connection : {r4.shape}")
        y4 = self.up_conv_4(r4)
        self._log(f"Shape after fourth up convolution : {y4.shape}")
        self._log(separator)
        out = self.out(y4)
        self._log(f"Output shape : {out.shape}")

        # Hand the result back to the caller; without this the module call
        # evaluates to None.
        return out






In [79]:
# Random single-channel 572x572 input with a batch of one.
image = torch.rand(1, 1, 572, 572)
print(f"Original Shape : {image.shape}")

# Instantiate the network with freshly initialised weights.
model = UNet()

Original Shape : torch.Size([1, 1, 572, 572])


In [80]:
# Run one forward pass (the model prints the shape after every stage as it
# goes) and then print whatever value the call returns.
print(model(image))

Shape of image is : torch.Size([1, 1, 572, 572])
--------------------------------------------------------
Shape after first convolution : torch.Size([1, 64, 568, 568])
Shape after first maxpooling : torch.Size([1, 64, 284, 284])
--------------------------------------------------------
Shape after second convolution : torch.Size([1, 128, 280, 280])
Shape after second maxpooling : torch.Size([1, 128, 140, 140])
--------------------------------------------------------
Shape after third convolution : torch.Size([1, 256, 136, 136])
Shape after third maxpooling : torch.Size([1, 256, 68, 68])
--------------------------------------------------------
Shape after fourth convolution : torch.Size([1, 512, 64, 64])
Shape after fourth maxpooling : torch.Size([1, 512, 32, 32])
--------------------------------------------------------
Shape after fifth convolution : torch.Size([1, 1024, 28, 28])
We reached the end of Encoder.
--------------------------------------------------------
Shape of input to de