In [38]:
import torch
import torch.nn as nn

def double_conv(in_c, out_c):
  conv = nn.Sequential(
      nn.Conv2d(in_c, out_c, kernel_size = 3),
      nn.ReLU(inplace = True),
      nn.Conv2d(out_c, out_c, kernel_size = 3),
      nn.ReLU(inplace = True)
  )

  return conv

def crop_img(tensor, target_tensor):
  target_size = target_tensor.size()[2]
  tensor_size = tensor.size()[2]

  delta = tensor_size - target_size
  delta = delta // 2

  return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]

class UNet(nn.Module):
  def __init__(self):
    super(UNet, self).__init__()

    self.max_pool_2x2 = nn.MaxPool2d(kernel_size = 2, stride = 2)
    self.down_conv_1 = double_conv(1, 64)
    self.down_conv_2 = double_conv(64, 128)
    self.down_conv_3 = double_conv(128, 256)
    self.down_conv_4 = double_conv(256, 512) 
    self.down_conv_5 = double_conv(512, 1024)

    self.up_trans_1 = nn.ConvTranspose2d(in_channels = 1024, 
                                         out_channels = 512,
                                         kernel_size = 2, 
                                         stride = 2)

    self.up_conv_1 = double_conv(1024, 512)

    self.up_trans_2 = nn.ConvTranspose2d(in_channels = 512, 
                                         out_channels = 256,
                                         kernel_size = 2, 
                                         stride = 2)

    self.up_conv_2 = double_conv(512, 256)

    self.up_trans_3 = nn.ConvTranspose2d(in_channels = 256, 
                                         out_channels = 128,
                                         kernel_size = 2, 
                                         stride = 2)

    self.up_conv_3 = double_conv(256, 128)

    self.up_trans_4 = nn.ConvTranspose2d(in_channels = 128, 
                                         out_channels = 64,
                                         kernel_size = 2, 
                                         stride = 2)

    self.up_conv_4 = double_conv(128, 64)

    self.out = nn.Conv2d(in_channels = 64,
                         out_channels = 2,
                         kernel_size = 1)

  def forward(self, image):
    # Encoder
    x1 = self.down_conv_1(image)  
    x2 = self.max_pool_2x2(x1)
    x3 = self.down_conv_2(x2) 
    x4 = self.max_pool_2x2(x3)
    x5 = self.down_conv_3(x4) 
    x6 = self.max_pool_2x2(x5)
    x7 = self.down_conv_4(x6) 
    x8 = self.max_pool_2x2(x7)
    x9 = self.down_conv_5(x8)

    # Decoder
    x = self.up_trans_1(x9)
    y = crop_img(x7, x)
    x = self.up_conv_1(torch.cat([x, y], 1))

    x = self.up_trans_2(x)
    y = crop_img(x5, x)
    x = self.up_conv_2(torch.cat([x, y], 1))

    x = self.up_trans_3(x)
    y = crop_img(x3, x)
    x = self.up_conv_3(torch.cat([x, y], 1))

    x = self.up_trans_4(x)
    y = crop_img(x1, x)
    x = self.up_conv_4(torch.cat([x, y], 1))

    x = self.out(x)
    print(x.size())
    
    return x

In [39]:
image = torch.rand((1, 1, 572, 572)) # batch_size x chanels x height x width
model = UNet()
print(model(image))

torch.Size([1, 2, 388, 388])
tensor([[[[-0.0117, -0.0122, -0.0130,  ..., -0.0149, -0.0132, -0.0152],
          [-0.0123, -0.0137, -0.0150,  ..., -0.0154, -0.0135, -0.0152],
          [-0.0111, -0.0157, -0.0116,  ..., -0.0160, -0.0154, -0.0157],
          ...,
          [-0.0097, -0.0130, -0.0162,  ..., -0.0125, -0.0156, -0.0133],
          [-0.0160, -0.0163, -0.0147,  ..., -0.0103, -0.0166, -0.0137],
          [-0.0129, -0.0144, -0.0131,  ..., -0.0161, -0.0119, -0.0155]],

         [[ 0.0716,  0.0711,  0.0723,  ...,  0.0737,  0.0740,  0.0721],
          [ 0.0749,  0.0732,  0.0697,  ...,  0.0725,  0.0724,  0.0723],
          [ 0.0698,  0.0728,  0.0753,  ...,  0.0722,  0.0742,  0.0712],
          ...,
          [ 0.0749,  0.0709,  0.0726,  ...,  0.0720,  0.0754,  0.0732],
          [ 0.0742,  0.0731,  0.0764,  ...,  0.0735,  0.0722,  0.0696],
          [ 0.0694,  0.0740,  0.0741,  ...,  0.0738,  0.0693,  0.0705]]]],
       grad_fn=<ThnnConv2DBackward>)
