In [1]:
import torch
from torch import nn

In [2]:
class FirstFeature(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3, 
                padding=1, 
                stride=1
            ), 
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.conv(x)
    

class FeatureOut(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3, 
                padding=1, 
                stride=1
            ),
            nn.Tanh()
        )

    def forward(self, x):
        return self.conv(x)

In [3]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels, 
                out_channels=out_channels,
                kernel_size=3, 
                padding=1, 
                stride=1
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=out_channels, 
                out_channels=out_channels,
                kernel_size=3,
                padding=1, 
                stride=1
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.conv(x)

In [4]:
class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            ConvBlock(in_channels, out_channels)
        )

    def forward(self, x):
        return self.encoder(x)
    
    
class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.trans_conv = nn.ConvTranspose2d(
            in_channels=in_channels, 
            out_channels=in_channels,
            kernel_size=4, 
            padding=1, 
            stride=2
        )
        self.conv_block = ConvBlock(in_channels, out_channels)
    
    def forward(self, x):
        up_sample = self.trans_conv(x)
        return self.conv_block(up_sample)

In [5]:
class UNetArchitecture(nn.Module):
    def __init__(self, in_channels, n_classes):
        super().__init__()
        self.first_feature = FirstFeature(in_channels, 64)
        self.conv = ConvBlock(64, 64)
        
        self.encoder1 = Encoder(64, 128)
        self.encoder2 = Encoder(128, 256)
        self.encoder3 = Encoder(256, 512)
        self.encoder4 = Encoder(512, 1024)

        self.decoder1 = Decoder(1024, 512)
        self.decoder2 = Decoder(512, 256)
        self.decoder3 = Decoder(256, 128)
        self.decoder4 = Decoder(128, 64)

        self.out_conv = FeatureOut(64, n_classes)
    
    def forward(self, x):
        x = self.first_feature(x)
        print(f'shape first feature: {x.shape}')
        x1 = self.conv(x)
        print(f'shape x1: {x1.shape}')

        x2 = self.encoder1(x1)
        print(f'shape encoder x2: {x2.shape}')
        x3 = self.encoder2(x2)
        print(f'shape encoder x3: {x3.shape}')
        x4 = self.encoder3(x3)
        print(f'shape encoder x4: {x4.shape}')
        x5 = self.encoder4(x4)
        print(f'shape encoder x5: {x5.shape}')

        x = self.decoder1(x5)
        print(f'shape decoder 1: {x.shape}')
        x = self.decoder2(x)
        print(f'shape decoder 2: {x.shape}')
        x = self.decoder3(x)
        print(f'shape decoder 3: {x.shape}')
        x = self.decoder4(x)
        print(f'shape decoder 4: {x.shape}')
        x = self.out_conv(x)
        print(f'shape out conv: {x.shape}')

        return x

In [6]:
input = torch.randint(
    5, (4, 1, 256, 256), dtype=torch.float32
)

model = UNetArchitecture(in_channels=1, n_classes=2)
model(input)

shape first feature: torch.Size([4, 64, 256, 256])
shape x1: torch.Size([4, 64, 256, 256])
shape encoder x2: torch.Size([4, 128, 128, 128])
shape encoder x3: torch.Size([4, 256, 64, 64])
shape encoder x4: torch.Size([4, 512, 32, 32])
shape encoder x5: torch.Size([4, 1024, 16, 16])
shape decoder 1: torch.Size([4, 512, 32, 32])
shape decoder 2: torch.Size([4, 256, 64, 64])
shape decoder 3: torch.Size([4, 128, 128, 128])
shape decoder 4: torch.Size([4, 64, 256, 256])
shape out conv: torch.Size([4, 2, 256, 256])


tensor([[[[ 0.0253, -0.0214,  0.2855,  ...,  0.0953,  0.1650,  0.1644],
          [-0.0123, -0.0993,  0.1138,  ..., -0.1638, -0.0937, -0.0487],
          [-0.0048, -0.0139,  0.2900,  ...,  0.0469, -0.0540, -0.4672],
          ...,
          [ 0.0048, -0.1110,  0.0106,  ..., -0.0488,  0.1347,  0.0731],
          [-0.0621,  0.0398, -0.0699,  ..., -0.3813, -0.1609,  0.0224],
          [-0.1835, -0.0652, -0.0258,  ..., -0.1151, -0.0927, -0.1107]],

         [[ 0.1043,  0.1514, -0.2359,  ...,  0.0943, -0.1244, -0.0408],
          [-0.0726, -0.0120, -0.0153,  ..., -0.0402, -0.2646, -0.2223],
          [-0.1543,  0.0498, -0.2880,  ..., -0.4365,  0.0688, -0.3326],
          ...,
          [-0.0424,  0.1078, -0.0850,  ..., -0.1062, -0.3943, -0.0931],
          [ 0.0198, -0.1610, -0.1873,  ..., -0.0808, -0.2692, -0.2104],
          [ 0.1952, -0.1796,  0.0207,  ..., -0.1026, -0.2629, -0.2109]]],


        [[[-0.0381, -0.0102,  0.2879,  ..., -0.0223,  0.3266,  0.1423],
          [ 0.0360, -0.1108,

In [7]:
from torchsummary import summary

model = UNetArchitecture(in_channels=3, n_classes=3)
summary(model, (3, 256, 256))

shape first feature: torch.Size([2, 64, 256, 256])
shape x1: torch.Size([2, 64, 256, 256])
shape encoder x2: torch.Size([2, 128, 128, 128])
shape encoder x3: torch.Size([2, 256, 64, 64])
shape encoder x4: torch.Size([2, 512, 32, 32])
shape encoder x5: torch.Size([2, 1024, 16, 16])
shape decoder 1: torch.Size([2, 512, 32, 32])
shape decoder 2: torch.Size([2, 256, 64, 64])
shape decoder 3: torch.Size([2, 128, 128, 128])
shape decoder 4: torch.Size([2, 64, 256, 256])
shape out conv: torch.Size([2, 3, 256, 256])
Layer (type:depth-idx)                   Output Shape              Param #
├─FirstFeature: 1-1                      [-1, 64, 256, 256]        --
|    └─Sequential: 2-1                   [-1, 64, 256, 256]        --
|    |    └─Conv2d: 3-1                  [-1, 64, 256, 256]        1,792
|    |    └─ReLU: 3-2                    [-1, 64, 256, 256]        --
├─ConvBlock: 1-2                         [-1, 64, 256, 256]        --
|    └─Sequential: 2-2                   [-1, 64, 256, 256

Layer (type:depth-idx)                   Output Shape              Param #
├─FirstFeature: 1-1                      [-1, 64, 256, 256]        --
|    └─Sequential: 2-1                   [-1, 64, 256, 256]        --
|    |    └─Conv2d: 3-1                  [-1, 64, 256, 256]        1,792
|    |    └─ReLU: 3-2                    [-1, 64, 256, 256]        --
├─ConvBlock: 1-2                         [-1, 64, 256, 256]        --
|    └─Sequential: 2-2                   [-1, 64, 256, 256]        --
|    |    └─Conv2d: 3-3                  [-1, 64, 256, 256]        36,928
|    |    └─BatchNorm2d: 3-4             [-1, 64, 256, 256]        128
|    |    └─ReLU: 3-5                    [-1, 64, 256, 256]        --
|    |    └─Conv2d: 3-6                  [-1, 64, 256, 256]        36,928
|    |    └─BatchNorm2d: 3-7             [-1, 64, 256, 256]        128
|    |    └─ReLU: 3-8                    [-1, 64, 256, 256]        --
├─Encoder: 1-3                           [-1, 128, 128, 128]       --
| 