In [1]:
import torch
from torch import nn
from torchvision import transforms

In [2]:
class FirstFeature(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3, 
                padding=1, 
                stride=1
            ), 
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.conv(x)
    

class FeatureOut(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3, 
                padding=1, 
                stride=1
            ),
            nn.Tanh()
        )

    def forward(self, x):
        return self.conv(x)

In [3]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels, 
                out_channels=out_channels,
                kernel_size=3, 
                padding=1, 
                stride=1
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=out_channels, 
                out_channels=out_channels,
                kernel_size=3,
                padding=1, 
                stride=1
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.conv(x)

In [4]:
class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            ConvBlock(in_channels, out_channels)
        )

    def forward(self, x):
        return self.encoder(x)
    
    
class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.trans_conv = nn.ConvTranspose2d(
            in_channels=in_channels, 
            out_channels=out_channels,
            kernel_size=4, 
            padding=1, 
            stride=2
        )
        self.conv_block = ConvBlock(in_channels, out_channels)
    
    def forward(self, x, skip):
        up_sample = self.trans_conv(x)
        concat = torch.concat([up_sample, skip], dim=1)
        return self.conv_block(concat)

In [5]:
LOW_IMG_WIDTH = 64
LOW_IMG_HEIGHT = 64

class UNetArchitecture(nn.Module):
    def __init__(self, in_channels, n_classes):
        super().__init__()
        self.resize = transforms.Resize((LOW_IMG_WIDTH*4, LOW_IMG_HEIGHT*4))
        self.first_feature = FirstFeature(in_channels, 64)
        self.conv = ConvBlock(64, 64)
        
        self.encoder1 = Encoder(64, 128)
        self.encoder2 = Encoder(128, 256)
        self.encoder3 = Encoder(256, 512)
        self.encoder4 = Encoder(512, 1024)

        self.decoder1 = Decoder(1024, 512)
        self.decoder2 = Decoder(512, 256)
        self.decoder3 = Decoder(256, 128)
        self.decoder4 = Decoder(128, 64)

        self.out_conv = FeatureOut(64, n_classes)
    
    def forward(self, x):
        x = self.resize(x)
        x = self.first_feature(x)
        print(f'shape first feature: {x.shape}')
        x1 = self.conv(x)
        print(f'shape x1: {x1.shape}')

        x2 = self.encoder1(x1)
        print(f'shape encoder x2: {x2.shape}')
        x3 = self.encoder2(x2)
        print(f'shape encoder x3: {x3.shape}')
        x4 = self.encoder3(x3)
        print(f'shape encoder x4: {x4.shape}')
        x5 = self.encoder4(x4)
        print(f'shape encoder x5: {x5.shape}')

        x = self.decoder1(x5, x4)
        print(f'shape decoder 1: {x.shape}')
        x = self.decoder2(x, x3)
        print(f'shape decoder 2: {x.shape}')
        x = self.decoder3(x, x2)
        print(f'shape decoder 3: {x.shape}')
        x = self.decoder4(x, x1)
        print(f'shape decoder 4: {x.shape}')
        x = self.out_conv(x)
        print(f'shape out conv: {x.shape}')

        return x

In [6]:
input = torch.randint(
    5, (4, 1, 64, 64), dtype=torch.float32
)

model = UNetArchitecture(in_channels=1, n_classes=2)
model(input)

shape first feature: torch.Size([4, 64, 256, 256])
shape x1: torch.Size([4, 64, 256, 256])
shape encoder x2: torch.Size([4, 128, 128, 128])
shape encoder x3: torch.Size([4, 256, 64, 64])
shape encoder x4: torch.Size([4, 512, 32, 32])
shape encoder x5: torch.Size([4, 1024, 16, 16])
shape decoder 1: torch.Size([4, 512, 32, 32])
shape decoder 2: torch.Size([4, 256, 64, 64])
shape decoder 3: torch.Size([4, 128, 128, 128])
shape decoder 4: torch.Size([4, 64, 256, 256])
shape out conv: torch.Size([4, 2, 256, 256])


tensor([[[[-0.1433, -0.2242, -0.5637,  ...,  0.1078,  0.1862,  0.7400],
          [-0.2610, -0.8039, -0.7532,  ..., -0.8516, -0.8647, -0.3065],
          [-0.0292, -0.3885, -0.0425,  ...,  0.2194, -0.3052,  0.3159],
          ...,
          [-0.1970, -0.4950, -0.6204,  ..., -0.9535, -0.7123, -0.1046],
          [-0.6517, -0.7018, -0.7071,  ..., -0.7859, -0.5010,  0.3958],
          [-0.1124, -0.3615, -0.0480,  ...,  0.4445, -0.3803, -0.7965]],

         [[ 0.0805,  0.3043,  0.0340,  ...,  0.6672,  0.8679,  0.5566],
          [-0.3556,  0.7558,  0.1532,  ...,  0.1599,  0.6567, -0.1257],
          [ 0.2095,  0.9350,  0.7475,  ..., -0.2127,  0.5512,  0.6668],
          ...,
          [ 0.4086,  0.8987,  0.5008,  ...,  0.8472,  0.8808,  0.3018],
          [ 0.3362,  0.6182,  0.0857,  ...,  0.6929,  0.9397,  0.6147],
          [-0.0857,  0.4706,  0.6132,  ...,  0.0412,  0.4414,  0.0296]]],


        [[[ 0.1883, -0.2245, -0.3318,  ..., -0.0749,  0.1616,  0.3922],
          [ 0.2487, -0.6526,

In [7]:
from torchsummary import summary

model = UNetArchitecture(in_channels=3, n_classes=3)
summary(model, (3, 64, 64))

  return torch._C._cuda_getDeviceCount() > 0


shape first feature: torch.Size([2, 64, 256, 256])
shape x1: torch.Size([2, 64, 256, 256])
shape encoder x2: torch.Size([2, 128, 128, 128])
shape encoder x3: torch.Size([2, 256, 64, 64])
shape encoder x4: torch.Size([2, 512, 32, 32])
shape encoder x5: torch.Size([2, 1024, 16, 16])
shape decoder 1: torch.Size([2, 512, 32, 32])
shape decoder 2: torch.Size([2, 256, 64, 64])
shape decoder 3: torch.Size([2, 128, 128, 128])
shape decoder 4: torch.Size([2, 64, 256, 256])
shape out conv: torch.Size([2, 3, 256, 256])
Layer (type:depth-idx)                   Output Shape              Param #
├─Resize: 1-1                            [-1, 3, 256, 256]         --
├─FirstFeature: 1-2                      [-1, 64, 256, 256]        --
|    └─Sequential: 2-1                   [-1, 64, 256, 256]        --
|    |    └─Conv2d: 3-1                  [-1, 64, 256, 256]        1,792
|    |    └─ReLU: 3-2                    [-1, 64, 256, 256]        --
├─ConvBlock: 1-3                         [-1, 64, 256, 256

Layer (type:depth-idx)                   Output Shape              Param #
├─Resize: 1-1                            [-1, 3, 256, 256]         --
├─FirstFeature: 1-2                      [-1, 64, 256, 256]        --
|    └─Sequential: 2-1                   [-1, 64, 256, 256]        --
|    |    └─Conv2d: 3-1                  [-1, 64, 256, 256]        1,792
|    |    └─ReLU: 3-2                    [-1, 64, 256, 256]        --
├─ConvBlock: 1-3                         [-1, 64, 256, 256]        --
|    └─Sequential: 2-2                   [-1, 64, 256, 256]        --
|    |    └─Conv2d: 3-3                  [-1, 64, 256, 256]        36,928
|    |    └─BatchNorm2d: 3-4             [-1, 64, 256, 256]        128
|    |    └─ReLU: 3-5                    [-1, 64, 256, 256]        --
|    |    └─Conv2d: 3-6                  [-1, 64, 256, 256]        36,928
|    |    └─BatchNorm2d: 3-7             [-1, 64, 256, 256]        128
|    |    └─ReLU: 3-8                    [-1, 64, 256, 256]        --
├─

: 