In [2]:
import torch
import torch.nn as nn
from torchinfo import summary
# Select the compute device. is_available must be *called*: the bare function
# object is always truthy, which would pick "cuda" even on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [8]:

class conv_block(nn.Module):
    """Two 3x3 convolutions, each followed by batch normalisation and ReLU.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so the spatial
    size of the input is preserved; only the channel count changes
    (``in_c`` -> ``out_c`` -> ``out_c``).
    """

    def __init__(self, in_c, out_c):
        """Create the layers.

        Args:
            in_c: Number of channels of the incoming feature map.
            out_c: Number of channels produced by both convolutions.
        """
        super().__init__()
        # First conv changes the channel count, second one keeps it.
        self.conv1 = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=out_c)
        self.conv2 = nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=out_c)
        # A single (stateless) ReLU instance is shared by both stages.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Run conv -> batch norm -> ReLU twice and return the result."""
        stage_one = self.conv1(inputs)
        stage_one = self.bn1(stage_one)
        stage_one = self.relu(stage_one)

        stage_two = self.conv2(stage_one)
        stage_two = self.bn2(stage_two)
        return self.relu(stage_two)

class encoder_block(nn.Module):
    """One encoder level: a ``conv_block`` followed by 2x2 max pooling.

    ``forward`` returns both the un-pooled features (for use as a skip
    connection) and the pooled features (input to the next level).
    """

    def __init__(self, in_c, out_c):
        """Create the layers.

        Args:
            in_c: Number of channels of the incoming feature map.
            out_c: Number of channels produced by the conv block.
        """
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, inputs):
        """Return ``(features, pooled)`` for the given input."""
        features = self.conv(inputs)
        return features, self.pool(features)

class decoder_block(nn.Module):
    """One decoder level: upsample, concatenate the skip, then ``conv_block``.

    The stride-2 transposed convolution maps ``in_c`` channels to ``out_c``;
    after concatenation with the skip along the channel axis the conv block
    receives ``2 * out_c`` channels, so the skip is expected to carry
    ``out_c`` channels as well.
    """

    def __init__(self, in_c, out_c):
        """Create the layers.

        Args:
            in_c: Number of channels of the feature map coming from below.
            out_c: Number of channels produced by this decoder level.
        """
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=2, stride=2)
        self.conv = conv_block(out_c * 2, out_c)

    def forward(self, inputs, skip):
        """Upsample ``inputs``, join it with ``skip`` on dim 1 and convolve."""
        upsampled = self.up(inputs)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)

class segUnet(nn.Module):
    """U-Net style encoder/decoder network for segmentation.

    The model stacks ``depth`` encoder levels, a bottleneck, and ``depth``
    decoder levels, followed by a 1x1 convolution that maps to
    ``num_classes`` channels. Encoder level ``k`` produces
    ``start_filts * 2**k`` channels and the bottleneck produces
    ``start_filts * 2**depth``.

    Every encoder level halves the spatial size with 2x2 max pooling and
    every decoder level doubles it with a stride-2 transposed convolution,
    so the input height and width should be divisible by ``2**depth`` for
    the skip concatenations to line up.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        in_channels: Number of channels of the input fed to the first encoder.
        depth: Number of encoder (and decoder) levels; must be at least 1.
        start_filts: Number of channels produced by the first encoder level.

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """

    def __init__(self, num_classes, in_channels=3, depth=5, start_filts=64):
        super().__init__()
        # With depth < 1 the bottleneck width start_filts * 2**(depth - 1)
        # becomes fractional and layer construction fails with an opaque
        # error, so reject it up front with a clear message.
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        # widths[k] is the channel count at level k; widths[depth] is the
        # bottleneck output.
        widths = [start_filts * (2 ** level) for level in range(depth + 1)]

        # Encoders: in_channels -> widths[0] -> ... -> widths[depth - 1]
        encoder_inputs = [in_channels] + widths[:-2]
        encoder_outputs = widths[:-1]
        self.encoders = nn.ModuleList(
            [encoder_block(c_in, c_out) for c_in, c_out in zip(encoder_inputs, encoder_outputs)]
        )

        # Bottleneck: widths[depth - 1] -> widths[depth]
        self.bottleneck = conv_block(widths[-2], widths[-1])

        # Decoders: widths[depth] -> widths[depth - 1] -> ... -> widths[0]
        self.decoders = nn.ModuleList(
            [decoder_block(widths[level], widths[level - 1]) for level in range(depth, 0, -1)]
        )

        # Classifier: 1x1 convolution from widths[0] (== start_filts) to num_classes
        self.outputs = nn.Conv2d(start_filts, num_classes, kernel_size=1)

    def forward(self, inputs):
        """Run the full encoder -> bottleneck -> decoder -> classifier path.

        Args:
            inputs: Tensor fed to the first encoder's ``nn.Conv2d``
                (presumably shaped ``(N, in_channels, H, W)`` - confirm
                against callers).

        Returns:
            The output of the final 1x1 convolution, with ``num_classes``
            channels. No activation is applied to it here.
        """
        skips = []
        x = inputs
        for encoder in self.encoders:
            # Each encoder yields (un-pooled features, pooled features).
            skip, x = encoder(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # Decoders consume the skips deepest-first, i.e. in reverse order.
        for decoder, skip in zip(self.decoders, reversed(skips)):
            x = decoder(x, skip)

        return self.outputs(x)


In [9]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the network, then place its parameters on the selected device.
model = segUnet(num_classes=2, in_channels=3, depth=5, start_filts=64)
model = model.to(device)

# Layer-by-layer report (output shapes and parameter counts) from torchinfo
# for a batch of 64 three-channel 64x64 inputs.
summary(model, input_size=(64, 3, 64, 64), device=device)


Layer (type:depth-idx)                   Output Shape              Param #
segUnet                                  [64, 2, 64, 64]           --
├─ModuleList: 1-1                        --                        --
│    └─encoder_block: 2-1                [64, 64, 64, 64]          --
│    │    └─conv_block: 3-1              [64, 64, 64, 64]          38,976
│    │    └─MaxPool2d: 3-2               [64, 64, 32, 32]          --
│    └─encoder_block: 2-2                [64, 128, 32, 32]         --
│    │    └─conv_block: 3-3              [64, 128, 32, 32]         221,952
│    │    └─MaxPool2d: 3-4               [64, 128, 16, 16]         --
│    └─encoder_block: 2-3                [64, 256, 16, 16]         --
│    │    └─conv_block: 3-5              [64, 256, 16, 16]         886,272
│    │    └─MaxPool2d: 3-6               [64, 256, 8, 8]           --
│    └─encoder_block: 2-4                [64, 512, 8, 8]           --
│    │    └─conv_block: 3-7              [64, 512, 8, 8]           3,54

In [10]:
def print_model_structure(model, indent=0):
    """Recursively print the module tree of ``model``.

    One line is printed per child module, in ``named_children()`` order, as
    ``<name>: <ClassName>``; each nesting level is indented by two spaces.

    Args:
        model: Object exposing ``named_children()`` (e.g. an ``nn.Module``).
        indent: Nesting level used for the indentation of this call's lines.
    """
    prefix = '  ' * indent
    for child_name, child in model.named_children():
        print(f'{prefix}{child_name}: {child.__class__.__name__}')
        # Descend so grandchildren appear directly under their parent.
        print_model_structure(child, indent + 1)

# Print the nested module hierarchy of `model`, one "<name>: <ClassName>" line
# per submodule, starting at indentation level 0.
print_model_structure(model)

encoders: ModuleList
  0: encoder_block
    conv: conv_block
      conv1: Conv2d
      bn1: BatchNorm2d
      conv2: Conv2d
      bn2: BatchNorm2d
      relu: ReLU
    pool: MaxPool2d
  1: encoder_block
    conv: conv_block
      conv1: Conv2d
      bn1: BatchNorm2d
      conv2: Conv2d
      bn2: BatchNorm2d
      relu: ReLU
    pool: MaxPool2d
  2: encoder_block
    conv: conv_block
      conv1: Conv2d
      bn1: BatchNorm2d
      conv2: Conv2d
      bn2: BatchNorm2d
      relu: ReLU
    pool: MaxPool2d
  3: encoder_block
    conv: conv_block
      conv1: Conv2d
      bn1: BatchNorm2d
      conv2: Conv2d
      bn2: BatchNorm2d
      relu: ReLU
    pool: MaxPool2d
  4: encoder_block
    conv: conv_block
      conv1: Conv2d
      bn1: BatchNorm2d
      conv2: Conv2d
      bn2: BatchNorm2d
      relu: ReLU
    pool: MaxPool2d
bottleneck: conv_block
  conv1: Conv2d
  bn1: BatchNorm2d
  conv2: Conv2d
  bn2: BatchNorm2d
  relu: ReLU
decoders: ModuleList
  0: decoder_block
    up: ConvTransp