In [2]:
import torch
import torch.nn as nn


In [3]:
def double_conv(in_channels, out_channels):
    conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size = 3),
        nn.ReLU(inplace = True),
        nn.Conv2d(out_channels, out_channels, kernel_size = 3),
        nn.ReLU(inplace = True)
    )
    return conv

In [9]:
def crop_image(original, target):
    target_size = target.size()[2]
    original_size = original.size()[2]
    delta = original_size - target_size
    assert(delta >= 0)
    delta = delta // 2
    return original[:, :, delta: original_size-delta, delta:original_size-delta]

In [24]:
class UNet(nn.Module):

    def __init__(self):
        super(UNet, self).__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.down_conv1 = double_conv(1, 64)
        self.down_conv2 = double_conv(64, 128)
        self.down_conv3 = double_conv(128, 256)
        self.down_conv4 = double_conv(256, 512)
        self.down_conv5 = double_conv(512, 1024)

        self.trans_conv1 = nn.ConvTranspose2d(
                in_channels = 1024,
                out_channels = 512,
                kernel_size = 2,
                stride = 2
            )
    
        self.up_conv1 = double_conv(1024, 512)

        self.trans_conv2 = nn.ConvTranspose2d(
                in_channels = 512,
                out_channels = 256,
                kernel_size = 2,
                stride = 2
            )
    
        self.up_conv2 = double_conv(512, 256)

        self.trans_conv3 = nn.ConvTranspose2d(
                in_channels = 256,
                out_channels = 128,
                kernel_size = 2,
                stride = 2
            )
    
        self.up_conv3 = double_conv(256, 128)

        self.trans_conv4 = nn.ConvTranspose2d(
                in_channels = 128,
                out_channels = 64,
                kernel_size = 2,
                stride = 2
            )
    
        self.up_conv4 = double_conv(128, 64)

        self.out = nn.Conv2d(
            in_channels = 64,
            out_channels = 2,    # increase according to number of classes
            kernel_size = 1
        )

    def forward(self, image):
        # image = [batch_size, channels, h, w]
        # Encoder part
        x1 = self.down_conv1(image) # use for concat
        x2 = self.max_pool_2x2(x1)

        x3 = self.down_conv2(x2) # use for concat
        x4 = self.max_pool_2x2(x3)

        x5 = self.down_conv3(x4) # use for concat
        x6 = self.max_pool_2x2(x5)

        x7 = self.down_conv4(x6) # use for concat
        x8 = self.max_pool_2x2(x7)
        
        x9 = self.down_conv5(x8)

        # image size after first part
        print(f'Image after first part : {x9.shape}')

        # Decoder part
        x = self.trans_conv1(x9)
        y = crop_image(x7, x)
        x = self.up_conv1(torch.cat([x, y], 1))

        x = self.trans_conv2(x)
        y = crop_image(x5, x)
        x = self.up_conv2(torch.cat([x, y], 1))

        x = self.trans_conv3(x)
        y = crop_image(x3, x)
        x = self.up_conv3(torch.cat([x, y], 1))

        x = self.trans_conv4(x)
        y = crop_image(x1, x)
        x = self.up_conv4(torch.cat([x, y], 1))

        x = self.out(x)
        print(x.size())
        return x

In [26]:
sample = torch.randn(1, 1, 572, 572)
print(sample.shape)
model = UNet()
print(model(sample)) # expected size according to paper = [1, 2, 388, 388]

torch.Size([1, 1, 572, 572])
Image after first part : torch.Size([1, 1024, 28, 28])
torch.Size([1, 2, 388, 388])
tensor([[[[ 0.0475,  0.0502,  0.0417,  ...,  0.0439,  0.0425,  0.0439],
          [ 0.0486,  0.0437,  0.0477,  ...,  0.0289,  0.0350,  0.0436],
          [ 0.0499,  0.0430,  0.0356,  ...,  0.0519,  0.0380,  0.0338],
          ...,
          [ 0.0382,  0.0460,  0.0478,  ...,  0.0526,  0.0222,  0.0196],
          [ 0.0428,  0.0301,  0.0499,  ...,  0.0401,  0.0342,  0.0371],
          [ 0.0316,  0.0365,  0.0528,  ...,  0.0428,  0.0544,  0.0505]],

         [[-0.0805, -0.0946, -0.0886,  ..., -0.0820, -0.0899, -0.0962],
          [-0.0995, -0.0774, -0.0865,  ..., -0.0999, -0.0961, -0.0926],
          [-0.0864, -0.1052, -0.1003,  ..., -0.0984, -0.0968, -0.0962],
          ...,
          [-0.0918, -0.0828, -0.0936,  ..., -0.0887, -0.0881, -0.0937],
          [-0.0861, -0.0995, -0.0926,  ..., -0.1038, -0.0891, -0.0805],
          [-0.0954, -0.0854, -0.0886,  ..., -0.0825, -0.0784, -