In [11]:
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from typing import Any, Optional, Type, TYPE_CHECKING, Union, Callable, Dict, List, Set, Tuple
from torchinfo import summary

# Define the Blocks

In [4]:
class UnetDoubleConvBlock(nn.Module):
    """Two stacked convolution stages, each paired with BatchNorm2d and ReLU.

    Stage 1 maps ``in_channels -> n_filters``; stage 2 maps ``n_filters -> out_channels``.
    Both stages use the same ``kernel_size`` and ``padding``.

    Args:
        in_channels: channels of the incoming tensor.
        n_filters: width of the intermediate feature map.
        out_channels: channels of the returned tensor.
        kernel_size: conv kernel size for both convolutions.
        padding: zero padding for both convolutions.
        batch_norm_first: if True each stage is ``Conv -> BN -> ReLU``;
            otherwise it is ``Conv -> ReLU -> BN``.
    """

    def __init__(self, in_channels: int, n_filters: int, out_channels: int, kernel_size: int, padding: int, batch_norm_first: bool=True):
        super().__init__()
        layers = []
        for stage_in, stage_out in ((in_channels, n_filters), (n_filters, out_channels)):
            conv = nn.Conv2d(in_channels=stage_in, out_channels=stage_out, kernel_size=kernel_size,
                             padding=padding, bias=True, padding_mode='zeros')
            norm = nn.BatchNorm2d(stage_out)
            act = nn.ReLU(inplace=True)
            # Only the position of BN relative to ReLU differs between the two variants.
            layers.extend([conv, norm, act] if batch_norm_first else [conv, act, norm])
        self.conv_proj = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply both conv stages in order."""
        projected = self.conv_proj(x)
        return projected

class UnetDownBlock(nn.Module):
    """Encoder stage: a UnetDoubleConvBlock optionally followed by 2x2 max-pooling.

    Constructor arguments are forwarded unchanged to ``UnetDoubleConvBlock``.
    """

    def __init__(self, in_channels: int, n_filters: int, out_channels: int, kernel_size: int, padding: int, batch_norm_first: bool=True):
        super().__init__()
        self.conv_proj = UnetDoubleConvBlock(in_channels, n_filters, out_channels, kernel_size, padding, batch_norm_first)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x: Tensor, pool=True):
        """Run the double conv and, if ``pool`` is truthy, also max-pool the result.

        Returns:
            ``(features, pooled)`` when ``pool`` is truthy, otherwise just ``features``.
            ``features`` is the pre-pool map (used as a skip connection by the caller).
        """
        features = self.conv_proj(x)
        if not pool:
            return features
        return features, self.pool(features)

class UnetUpBlock(nn.Module):
    """Decoder stage: double conv, 2x spatial upsampling, then concat with a skip tensor.

    Args:
        in_channels: channels of the deeper input ``x1``.
        n_filters: intermediate width of the double conv.
        out_channels: channels produced by the double conv (and kept by the upsampler).
        bilinear: if True upsample with fixed bilinear interpolation, otherwise with a
            learned stride-2 ``ConvTranspose2d``.
        **kwargs: forwarded to ``UnetDoubleConvBlock``; must contain ``kernel_size``
            (also used as the transposed-conv kernel when ``bilinear`` is False).
    """

    def __init__(self, in_channels: int, n_filters: int, out_channels: int, bilinear: bool=True, **kwargs):
        super().__init__()
        self.conv_proj = UnetDoubleConvBlock(in_channels, n_filters, out_channels, **kwargs)
        if bilinear:
            # align_corners=False is what the default (None) resolves to for bilinear mode;
            # stating it explicitly avoids the default-behaviour warning some PyTorch versions emit.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        else:
            kernel_size = kwargs['kernel_size']
            # Padding is derived from the kernel size so the output is exactly 2x for any
            # kernel, not only for kernel_size == 3 (the previous hard-coded padding=1,
            # output_padding=1 produced a mismatched size for other kernels).
            padding, output_padding = self._transpose_padding(kernel_size)
            self.up = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=kernel_size, stride=2, bias=False,
                                         padding=padding, output_padding=output_padding)

    @staticmethod
    def _transpose_padding(kernel_size: int) -> Tuple[int, int]:
        """Return ``(padding, output_padding)`` so a stride-2 ConvTranspose2d exactly doubles H and W.

        ``H_out = (H - 1) * 2 - 2 * padding + kernel_size + output_padding``, so doubling
        requires ``kernel_size + output_padding == 2 + 2 * padding``.
        """
        padding = (kernel_size - 1) // 2
        output_padding = 1 if kernel_size % 2 else 0
        return padding, output_padding

    def forward(self, x1: Tensor, x2: Tensor):
        """Project and upsample ``x1``, then concatenate it with ``x2`` along channels.

        ``x2`` must match the upsampled ``x1`` in every dimension except channels
        (a ``torch.cat`` requirement). The result has ``out_channels + x2.shape[1]`` channels.
        """
        x1 = self.conv_proj(x1)
        x1 = self.up(x1)
        x = torch.cat([x1, x2], dim=1)
        return x

class UnetOutputBlock(nn.Module):
    """Final 1x1 convolution mapping ``in_channels`` feature maps to ``n_classes`` channels."""

    def __init__(self, in_channels: int, n_classes: int):
        super().__init__()
        self.classifier = nn.Conv2d(in_channels=in_channels, out_channels=n_classes,
                                    kernel_size=1, padding_mode='zeros')

    def forward(self, x: Tensor) -> Tensor:
        """Apply the 1x1 conv; spatial size is unchanged."""
        scores = self.classifier(x)
        return scores

# Define Encoder

In [7]:
class UnetEncoder(nn.Module):
    """Contracting path of the U-Net: a stack of ``UnetDownBlock`` stages.

    Stage 0 maps ``in_channels -> out_channels[0]``; stage ``i > 0`` maps
    ``out_channels[i-1] -> out_channels[i]``. Every stage, including the last,
    is followed by its block's max-pool.

    Args:
        in_channels: channels of the input image.
        emb_sizes: intermediate width of each stage's double conv.
        out_channels: output width of each stage.
        kernel_sizes: conv kernel size per stage.
        paddings: conv padding per stage.
        batch_norm_first: forwarded to every stage.

    Raises:
        ValueError: if the per-stage lists are empty or differ in length
            (previously ``zip`` silently dropped the extra entries).
    """

    def __init__(self, in_channels: int, emb_sizes: List[int], out_channels: List[int], kernel_sizes: List[int], paddings: List[int], batch_norm_first: bool=True):
        super().__init__()
        per_stage = (emb_sizes, out_channels, kernel_sizes, paddings)
        n_stages = len(out_channels)
        if n_stages == 0 or any(len(seq) != n_stages for seq in per_stage):
            raise ValueError(
                "emb_sizes, out_channels, kernel_sizes and paddings must be non-empty and of equal length, "
                f"got lengths {[len(seq) for seq in per_stage]}"
            )
        # Each stage consumes the previous stage's output width; the first consumes the image.
        stage_in_channels = [in_channels, *out_channels[:-1]]
        self.blocks = nn.ModuleList([
            UnetDownBlock(c_in, emb, c_out, k, p, batch_norm_first)
            for c_in, emb, c_out, k, p in zip(stage_in_channels, emb_sizes, out_channels, kernel_sizes, paddings)
        ])

    def forward(self, img_input: Tensor) -> Tuple[Tensor, List[Tensor]]:
        """Run every stage.

        Returns:
            ``(img_input, levels)``: the unchanged input, and a list holding each stage's
            pre-pool feature map followed by the final pooled map, so
            ``len(levels) == len(self.blocks) + 1``.
        """
        x = img_input
        levels: List[Tensor] = []
        for block in self.blocks:
            skip, x = block(x, pool=True)
            levels.append(skip)
        levels.append(x)
        return img_input, levels

# Define Decoder

In [8]:
class UnetDecoder(nn.Module):
    """Expanding path of the U-Net: a stack of ``UnetUpBlock`` stages.

    Stage ``i`` is built from ``in_channels[i]``, ``emb_sizes[i]``, ``out_channels[i]``,
    ``kernel_sizes[i]`` and ``paddings[i]``; ``batch_norm_first`` and ``bilinear`` apply
    to every stage.

    Raises:
        ValueError: if the per-stage lists differ in length (previously ``zip``
            silently dropped the extra entries).
    """

    def __init__(self, in_channels: List[int], emb_sizes: List[int], out_channels: List[int], kernel_sizes: List[int], paddings: List[int], batch_norm_first: bool=True, bilinear: bool=True):
        super().__init__()
        per_stage = (in_channels, emb_sizes, out_channels, kernel_sizes, paddings)
        n_stages = len(in_channels)
        if any(len(seq) != n_stages for seq in per_stage):
            raise ValueError(
                "in_channels, emb_sizes, out_channels, kernel_sizes and paddings must have equal length, "
                f"got lengths {[len(seq) for seq in per_stage]}"
            )
        self.blocks = nn.ModuleList([
            UnetUpBlock(c_in, emb, c_out, bilinear, kernel_size=k, padding=p, batch_norm_first=batch_norm_first)
            for c_in, emb, c_out, k, p in zip(*per_stage)
        ])

    def forward(self, levels: List[Tensor]) -> Tensor:
        """Walk the encoder levels from deepest to shallowest.

        ``levels[-1]`` is the starting tensor; each block then combines the running tensor
        with the next-shallower level as its skip input.

        Raises:
            ValueError: if ``len(levels) != len(self.blocks) + 1``. An explicit raise is used
                instead of ``assert`` so the check survives ``python -O``.
        """
        if len(levels) != len(self.blocks) + 1:
            raise ValueError(
                "The size of downsampled results doesn't match the number of upward blocks: "
                f"got {len(levels)} levels for {len(self.blocks)} blocks (expected {len(self.blocks) + 1})"
            )
        deepest_first = levels[::-1]
        x = deepest_first[0]
        for skip, block in zip(deepest_first[1:], self.blocks):
            x = block(x, skip)
        return x

# Define UNet

In [9]:
class UNet(nn.Module):
    """U-Net assembled from an encoder, a decoder and a 1x1 output head.

    Args:
        downward_params: keyword arguments for ``UnetEncoder``.
        upward_params: keyword arguments for ``UnetDecoder``.
        output_params: keyword arguments for ``UnetOutputBlock``.
    """

    def __init__(self, downward_params, upward_params, output_params):
        super().__init__()
        self.encoder = UnetEncoder(**downward_params)
        self.decoder = UnetDecoder(**upward_params)
        self.classifier = UnetOutputBlock(**output_params)

    def forward(self, img_input: Tensor):
        """Encode, decode, then classify; the encoder's input passthrough is discarded."""
        _, levels = self.encoder(img_input)
        return self.classifier(self.decoder(levels))

# Experiment: forward pass and model summary on a random 3×288×288 input

In [10]:
# Encoder: five stages; each stage's intermediate width equals its output width,
# and the width doubles from one stage to the next.
downward_params = dict(
    in_channels=3,                                  # input image channels
    emb_sizes=[32 * 2 ** i for i in range(5)],      # 32, 64, 128, 256, 512
    out_channels=[32 * 2 ** i for i in range(5)],   # 32, 64, 128, 256, 512
    kernel_sizes=[3] * 5,
    paddings=[1] * 5,
    batch_norm_first=False,
)
# Decoder: five stages walking back up. Stage 0 consumes the 512-channel bottleneck;
# every later stage consumes the previous stage's upsampled output concatenated with the
# matching encoder skip, i.e. twice the previous stage's out_channels.
upward_params = dict(
    in_channels=[512] + [1024 >> i for i in range(4)],   # 512, 1024, 512, 256, 128
    emb_sizes=[1024 >> i for i in range(5)],             # 1024, 512, 256, 128, 64
    out_channels=[512 >> i for i in range(5)],           # 512, 256, 128, 64, 32
    kernel_sizes=[3] * 5,
    paddings=[1] * 5,
    batch_norm_first=False,
    bilinear=True,
)
# The head sees the last decoder stage's output (32 ch) concatenated with the first skip (32 ch).
output_params = dict(
    in_channels=64,
    n_classes=2,
)

In [12]:
# Seed first so both the random input and the weight initialisation are reproducible on re-run.
torch.manual_seed(0)
x = torch.rand(1, 3, 288, 288)
model = UNet(downward_params, upward_params, output_params)
# Smoke test only: eval() makes BatchNorm use (and not update) its running statistics,
# and no_grad() skips building an autograd graph we never backpropagate through.
model.eval()
with torch.no_grad():
    out = model(x)

[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.


In [13]:
# Sanity check: one channel per class, at the same batch size and spatial size as the input.
assert out.shape == (x.shape[0], output_params['n_classes'], *x.shape[-2:]), out.shape
out.shape

torch.Size([1, 2, 288, 288])

In [14]:
# Reuse the experiment input's shape instead of repeating the magic numbers.
summary(model, input_size=tuple(x.shape))

Layer (type:depth-idx)                             Output Shape              Param #
UNet                                               [1, 2, 288, 288]          --
├─UnetEncoder: 1-1                                 [1, 3, 288, 288]          --
│    └─ModuleList: 2-1                             --                        --
│    │    └─UnetDownBlock: 3-1                     [1, 32, 288, 288]         10,272
│    │    └─UnetDownBlock: 3-2                     [1, 64, 144, 144]         55,680
│    │    └─UnetDownBlock: 3-3                     [1, 128, 72, 72]          221,952
│    │    └─UnetDownBlock: 3-4                     [1, 256, 36, 36]          886,272
│    │    └─UnetDownBlock: 3-5                     [1, 512, 18, 18]          3,542,016
├─UnetDecoder: 1-2                                 [1, 64, 288, 288]         --
│    └─ModuleList: 2-2                             --                        --
│    │    └─UnetUpBlock: 3-6                       [1, 1024, 18, 18]         9,441,792
│  