Import libraries

In [1]:
import torch
import torch.nn as nn

Convolution block: two 3×3 Conv → BatchNorm → ReLU stages

In [2]:
def conv_block(in_channels, out_channels):
  """Build the double-convolution stage used throughout the network.

  The stage is two consecutive Conv2d -> BatchNorm2d -> ReLU groups. The
  first convolution maps `in_channels` to `out_channels`; the second keeps
  `out_channels`. Both use a 3x3 kernel with stride 1 and padding 1, so the
  spatial size of the input is preserved, and neither has a bias term.

  Args:
    in_channels: number of channels entering the stage.
    out_channels: number of channels produced by both convolutions.

  Returns:
    An `nn.Sequential` holding the six layers in execution order.
  """
  stages = []

  # first pass consumes the incoming channels, second pass keeps the width
  for stage_in in (in_channels, out_channels):
    stages += [
        nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]

  return nn.Sequential(*stages)

U-Net Class



In [3]:
class UNet(nn.Module):
  """U-Net style encoder/decoder producing one output map per class.

  The encoder runs four `conv_block` stages (64, 128, 256 and 512 channels),
  each followed by 2x2 max pooling with stride 2. A 1024-channel bridge block
  sits at the bottom. The decoder mirrors the encoder: every step doubles the
  resolution with a stride-2 transposed convolution, concatenates the encoder
  feature map of the same depth along the channel axis and applies another
  `conv_block`. A final 1x1 convolution maps the 64 decoder channels to
  `num_classes` channels; no activation is applied to that output.

  Args:
    in_channels: number of channels of the input image.
    num_classes: number of channels produced by the final layer.
  """

  def __init__(self, in_channels: int, num_classes: int):
    super().__init__()

    # assign number of channels for input image
    self.in_channels = in_channels

    # assign number of classes for output layer
    self.num_classes = num_classes

    # declare max pooling operation (stateless, shared by all encoder stages)
    self.max_pool = nn.MaxPool2d(2, 2)

    # declare all downward blocks of convolution+activation
    self.down_conv1 = conv_block(in_channels=self.in_channels, out_channels=64)
    self.down_conv2 = conv_block(in_channels=64, out_channels=128)
    self.down_conv3 = conv_block(in_channels=128, out_channels=256)
    self.down_conv4 = conv_block(in_channels=256, out_channels=512)

    # declare bridge convolution block
    self.bridge = conv_block(in_channels=512, out_channels=1024)

    # declare up-sampling transformations (each halves the channels and
    # doubles height and width)
    self.conv_trans1 = nn.ConvTranspose2d(
        in_channels=1024,
        out_channels=512,
        kernel_size=2,
        stride=2,
    )
    self.conv_trans2 = nn.ConvTranspose2d(
        in_channels=512,
        out_channels=256,
        kernel_size=2,
        stride=2,
    )
    self.conv_trans3 = nn.ConvTranspose2d(
        in_channels=256,
        out_channels=128,
        kernel_size=2,
        stride=2,
    )
    self.conv_trans4 = nn.ConvTranspose2d(
        in_channels=128,
        out_channels=64,
        kernel_size=2,
        stride=2,
    )

    # declare upward convolution+activation blocks; their input is the
    # upsampled map concatenated with the skip connection, hence twice the
    # channels of the transposed convolution output
    self.up_conv1 = conv_block(in_channels=1024, out_channels=512)
    self.up_conv2 = conv_block(in_channels=512, out_channels=256)
    self.up_conv3 = conv_block(in_channels=256, out_channels=128)
    self.up_conv4 = conv_block(in_channels=128, out_channels=64)

    # declare final layer
    self.final = nn.Conv2d(
        in_channels=64,
        out_channels=self.num_classes,
        kernel_size=1
    )

  @staticmethod
  def _match_size(upsampled: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    """Zero-pad `upsampled` so its height and width equal those of `skip`.

    Returns `upsampled` itself when the spatial sizes already agree.
    """
    pad_h = skip.size(2) - upsampled.size(2)
    pad_w = skip.size(3) - upsampled.size(3)
    if pad_h == 0 and pad_w == 0:
      return upsampled

    # pad order for the last two dims is (left, right, top, bottom); the
    # missing pixels are split as evenly as possible between both sides
    return nn.functional.pad(
        upsampled,
        (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the encoder, bridge and decoder on a batch of images.

    Args:
      x: tensor of shape (N, in_channels, H, W). H and W should be at least
        16 so that a non-empty map is left after the four 2x2 poolings.

    Returns:
      Tensor of shape (N, num_classes, H, W).

    When H or W is not a multiple of 16, pooling floors an odd size and the
    upsampled map comes back one pixel short of its skip connection. It is
    zero-padded to the skip's size before concatenation, so such inputs no
    longer fail in `torch.cat`. Inputs whose sizes are multiples of 16 are
    processed exactly as before.
    """
    encoder_stages = (
        self.down_conv1,
        self.down_conv2,
        self.down_conv3,
        self.down_conv4,
    )
    decoder_stages = (
        (self.conv_trans1, self.up_conv1),
        (self.conv_trans2, self.up_conv2),
        (self.conv_trans3, self.up_conv3),
        (self.conv_trans4, self.up_conv4),
    )

    skip_connections = []

    # encoding path: convolve, remember the pre-pooling map, downsample
    for encode in encoder_stages:
      x = encode(x)
      skip_connections.append(x)
      x = self.max_pool(x)

    # bridge
    x = self.bridge(x)

    # decoding path: upsample, concatenate, convolve; skips are consumed in
    # reverse order so each decoder step meets the encoder map of equal depth
    for upsample, decode in decoder_stages:
      skip = skip_connections.pop()
      x = self._match_size(upsample(x), skip)
      x = decode(torch.cat((skip, x), dim=1))

    return self.final(x)

Dimensions check: count the parameters and confirm the output keeps the input's spatial size

In [5]:
if __name__ == '__main__':
    # Dummy batch: a single 3-channel 512x512 image of uniform random values.
    input_image = torch.rand((1, 3, 512, 512))
    model = UNet(in_channels=3, num_classes=2)

    # Total parameters and trainable parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")

    # Only the output shape is inspected here, so do not record an autograd
    # graph for the 512x512 forward pass.
    with torch.no_grad():
        outputs = model(input_image)
    print(outputs.shape)

    # The network should keep batch size and spatial size, and emit one
    # channel per class.
    expected_shape = (
        input_image.shape[0], model.num_classes, *input_image.shape[2:])
    assert outputs.shape == expected_shape, (
        f"unexpected output shape {tuple(outputs.shape)}, expected {expected_shape}")

31,037,698 total parameters.
31,037,698 training parameters.
torch.Size([1, 2, 512, 512])
