In [11]:
import torch
import torch.nn as nn
import torchvision

In [2]:
class Block(nn.Module):
    """Two 3x3 convolutions, each followed by a ReLU.

    The convolutions use no padding, so every convolution trims one pixel from
    each border: the output is 4 pixels smaller than the input in both height
    and width.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3)
        # ReLU holds no parameters, so one instance serves both activations.
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU and return the result."""
        first_stage = self.relu(self.conv1(x))
        return self.relu(self.conv2(first_stage))

In [4]:
# Check the block: run one random single-channel 572x572 image through a
# 1 -> 64 channel Block and report the resulting shape.
encoder_block = Block(in_channels=1, out_channels=64)
x = torch.rand(1, 1, 572, 572)
block_output = encoder_block(x)
print(f'Shape of the encoder block: {block_output.shape}')

Shape of the encoder block: torch.Size([1, 64, 568, 568])


In [9]:
class Encoder(nn.Module):
    """Contracting path: a chain of ``Block`` stages separated by 2x2 max pooling.

    Args:
        channels: Channel sizes along the path. Stage ``i`` maps
            ``channels[i]`` input channels to ``channels[i + 1]`` output
            channels, so ``n`` entries create ``n - 1`` stages.
    """

    def __init__(self, channels=(1, 64, 128, 256, 512, 1024)):
        super().__init__()
        self.encoder_blocks = nn.ModuleList(
            [Block(channels[i], channels[i + 1]) for i in range(len(channels) - 1)]
        )
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Run every stage and collect the pre-pooling output of each one.

        Args:
            x: Input tensor fed to the first stage.

        Returns:
            A list with one tensor per stage, in stage order (shallowest
            first). The last entry is the output of the deepest stage.
        """
        block_outputs = []
        last_index = len(self.encoder_blocks) - 1
        for index, block in enumerate(self.encoder_blocks):
            x = block(x)
            block_outputs.append(x)
            # Pool only between stages. Nothing consumes a pooled copy of the
            # deepest output, and pooling it would raise when that feature map
            # is already smaller than the 2x2 pooling window.
            if index < last_index:
                x = self.pool(x)
        return block_outputs

In [10]:
# Check the encoder: feed one random 572x572 image through the default
# encoder and list the shape of every stage output. `encoder_outputs` is kept
# as a notebook-level name because the decoder check reuses it.
encoder = Encoder()
x = torch.rand(1, 1, 572, 572)
encoder_outputs = encoder(x)

for stage_output in encoder_outputs:
    print(f'Shape of the encoder output: {stage_output.shape}')

Shape of the encoder output: torch.Size([1, 64, 568, 568])
Shape of the encoder output: torch.Size([1, 128, 280, 280])
Shape of the encoder output: torch.Size([1, 256, 136, 136])
Shape of the encoder output: torch.Size([1, 512, 64, 64])
Shape of the encoder output: torch.Size([1, 1024, 28, 28])


In [12]:
class Decoder(nn.Module):
    """Expanding path: up-convolve, concatenate a cropped skip tensor, refine.

    Args:
        channels: Channel sizes along the path. Step ``i`` up-convolves from
            ``channels[i]`` to ``channels[i + 1]`` channels, concatenates the
            skip tensor along the channel axis, and applies a
            ``Block(channels[i], channels[i + 1])``. The skip tensor therefore
            has to contribute ``channels[i] - channels[i + 1]`` channels.
    """

    def __init__(self, channels=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.channels = channels
        step_count = len(channels) - 1
        self.decoder_blocks = nn.ModuleList(
            [Block(channels[i], channels[i + 1]) for i in range(step_count)]
        )
        self.upconvolution = nn.ModuleList(
            [
                nn.ConvTranspose2d(channels[i], channels[i + 1], kernel_size=2, stride=2)
                for i in range(step_count)
            ]
        )

    def forward(self, x, encoder_outputs):
        """Decode ``x`` using one skip tensor per step.

        Args:
            x: Tensor entering the first up-convolution.
            encoder_outputs: Indexable collection of skip tensors;
                ``encoder_outputs[i]`` is consumed by step ``i``.

        Returns:
            The output of the last decoder block.
        """
        steps = zip(self.upconvolution, self.decoder_blocks)
        for step, (upsample, refine) in enumerate(steps):
            x = upsample(x)
            skip = self.crop(encoder_outputs[step], x)
            x = refine(torch.cat([x, skip], dim=1))
        return x

    def crop(self, encoder_output, tensor):
        """Center-crop ``encoder_output`` to the height and width of ``tensor``.

        Following the paper, the encoder output is cropped to match the shape
        of the decoder output before the two are concatenated.
        """
        _batch, _channels, height, width = tensor.shape
        return torchvision.transforms.functional.center_crop(encoder_output, [height, width])

In [13]:
# Check the decoder: start from a random bottleneck-sized tensor and decode it
# with the skip tensors produced by the encoder check above.
decoder = Decoder()
x = torch.rand(1, 1024, 28, 28)
# Pass the encoder outputs in reverse order (deepest first) and drop the
# bottleneck output itself. The forward pass runs once and its result is
# reused for the printout.
decoder_output = decoder(x, encoder_outputs[::-1][1:])
print(f'Shape of the decoder output: {decoder_output.shape}')

Shape of the decoder output: torch.Size([1, 64, 388, 388])


### UNet model

In [23]:
class UNet(nn.Module):
    """Encoder, decoder, and a 1x1 convolution head producing per-class maps.

    Args:
        encoder_channels: Channel sizes handed to ``Encoder``.
        decoder_channels: Channel sizes handed to ``Decoder``.
        num_classes: Number of output channels of the 1x1 head.
        retain_dim: When true, the head output is resized to ``output_size``.
        output_size: Target ``(height, width)`` used when ``retain_dim`` is set.
    """

    def __init__(
        self,
        encoder_channels=(1, 64, 128, 256, 512, 1024),
        decoder_channels=(1024, 512, 256, 128, 64),
        num_classes=5,
        retain_dim=False,
        output_size=(572, 572),
    ):
        super().__init__()
        self.encoder = Encoder(encoder_channels)
        self.decoder = Decoder(decoder_channels)
        self.head = nn.Conv2d(decoder_channels[-1], num_classes, kernel_size=1)
        self.retain_dim = retain_dim
        self.output_size = output_size

    def forward(self, x):
        """Encode, decode with skip connections, apply the head, optionally resize."""
        encoder_outputs = self.encoder(x)
        # The deepest encoder output seeds the decoder; the remaining outputs
        # become skip connections, ordered deepest first.
        bottleneck = encoder_outputs[-1]
        skips = encoder_outputs[:-1][::-1]
        out = self.head(self.decoder(bottleneck, skips))
        if not self.retain_dim:
            return out
        return nn.functional.interpolate(out, size=self.output_size)

In [26]:
# Check the model: with retain_dim=True the head output is resized to the
# default output_size, so the printed spatial size should be 572x572.
model = UNet(retain_dim=True)
input_shape = (1, 1, 572, 572)
x = torch.rand(*input_shape)
out = model(x)
print(f'Shape of the model output: {out.shape}')

Shape of the model output: torch.Size([1, 5, 572, 572])


### Model Summary

In [45]:
from torchinfo import summary

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Build the network and move it to the selected device in one step.
unet_model = UNet(num_classes=5, retain_dim=True, output_size=(572, 572)).to(DEVICE)

# Layer-by-layer report for one single-channel 572x572 image. This stays the
# last expression so the notebook renders the returned summary.
summary_options = {
    "input_size": (1, 1, 572, 572),
    "col_names": ["input_size", "output_size", "num_params", "trainable"],
    "col_width": 20,
    "row_settings": ["var_names"],
    "depth": 5,
}
summary(model=unet_model, **summary_options)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
UNet (UNet)                              [1, 1, 572, 572]     [1, 5, 572, 572]     --                   True
├─Encoder (encoder)                      [1, 1, 572, 572]     [1, 64, 568, 568]    --                   True
│    └─ModuleList (encoder_blocks)       --                   --                   (recursive)          True
│    │    └─Block (0)                    [1, 1, 572, 572]     [1, 64, 568, 568]    --                   True
│    │    │    └─Conv2d (conv1)          [1, 1, 572, 572]     [1, 64, 570, 570]    640                  True
│    │    │    └─ReLU (relu)             [1, 64, 570, 570]    [1, 64, 570, 570]    --                   --
│    │    │    └─Conv2d (conv2)          [1, 64, 570, 570]    [1, 64, 568, 568]    36,928               True
│    │    │    └─ReLU (relu)             [1, 64, 568, 568]    [1, 64, 568, 568]    --                   --
│    └─MaxPool2d (