In [21]:
import torch.nn as nn
import torch


class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

        self.residual_block = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(2, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(2, out_channels),
            nn.SiLU()
        )

    def forward(self, x):
        out = self.conv1(x)
        residual = self.residual_block(out)
        out = out + residual
        return out


class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        # Scale the data size back up
        self.upsample = nn.ConvTranspose2d(
            in_channels*2,
            in_channels*2,
            kernel_size=2, stride=2)

        # Adjusting channels after concatenation
        self.adjust_channels = nn.Conv2d(in_channels*2, out_channels, kernel_size=1)

        # Residual block similar to DownBlock but with adjusted channels due to skip connection
        self.residual_block = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(2, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(2, out_channels),
            nn.SiLU()
        )

    def forward(self, x, skip_connection):
        x = torch.cat([x, skip_connection], dim=1)
        print(x.shape)
        x = self.upsample(x)
        print(x.shape)
        x = self.adjust_channels(x)
        print(x.shape)
        residual = self.residual_block(x)
        x = x + residual
        return x

In [22]:
x = torch.randn(1, 128, 180, 360)
model = DownBlock(128, 64)
x = model(x)
# print(x.shape)
up = UpBlock(64, 70)
up(x, x).shape

torch.Size([1, 128, 90, 180])
torch.Size([1, 128, 180, 360])
torch.Size([1, 70, 180, 360])


torch.Size([1, 70, 180, 360])