# In [3]:
import torch.functional as F
from torch import nn
import torch

# In [None]:
class FireModule(nn.Module):
    """Fire module: a 1x1 squeeze conv feeding parallel 3x3 and 1x1 expand convs.

    The input is squeezed to ``in_c // 2`` channels by a point-wise
    convolution, then expanded by two independent branches (3x3 and 1x1)
    whose outputs are concatenated along the channel dimension, so the
    module emits exactly ``out_c`` channels.

    Every convolution uses ``stride=2``, so for an input of spatial size
    ``H`` the squeeze conv yields ``floor((H - 1) / 2) + 1`` and the expand
    convs apply the same reduction once more (roughly a 4x reduction
    overall).

    Args:
        in_c: Number of input channels (must be at least 2).
        out_c: Number of output channels (must be at least 2).

    Raises:
        ValueError: If ``in_c`` or ``out_c`` is smaller than 2, which would
            leave a convolution with zero channels.
    """

    def __init__(self, in_c: int, out_c: int):
        # nn.Module.__init__ accepts no extra positional arguments; passing
        # ``self`` a second time raises TypeError.
        super().__init__()
        if in_c < 2 or out_c < 2:
            raise ValueError(
                f"in_c and out_c must both be >= 2, got in_c={in_c}, out_c={out_c}"
            )

        squeeze_c = in_c // 2
        expand_1x1_c = out_c // 2
        # Give the remainder to the 3x3 branch so an odd out_c still yields
        # exactly out_c channels after concatenation.
        expand_3x3_c = out_c - expand_1x1_c

        # first point-wise convolution: 1x1
        self.conv1 = nn.Conv2d(in_channels=in_c, out_channels=squeeze_c, kernel_size=(1, 1), stride=2)
        # first independent conv: 3x3 (padding=1 keeps its output the same
        # spatial size as the 1x1 branch so the two can be concatenated)
        self.conv2 = nn.Conv2d(
            in_channels=squeeze_c, out_channels=expand_3x3_c, kernel_size=(3, 3), stride=2, padding=1
        )
        # second independent (pointwise) conv: 1x1
        self.conv3 = nn.Conv2d(in_channels=squeeze_c, out_channels=expand_1x1_c, kernel_size=(1, 1), stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply squeeze -> ReLU -> parallel expands -> concat -> ReLU."""
        # torch.relu applies the activation; nn.ReLU(tensor) would only
        # construct a module object and return it.
        x = torch.relu(self.conv1(x))
        x1 = self.conv2(x)
        x2 = self.conv3(x)
        # Concatenate on the channel axis (dim=1 of the NCHW conv output).
        return torch.relu(torch.cat([x1, x2], dim=1))

# In [None]:
class TransposedFireModule(nn.Module):
    """Transposed-convolution counterpart of the fire module.

    A 1x1 transposed conv maps the input to ``out_c * 2`` channels; two
    parallel 1x1 transposed convs then each map that to ``out_c`` channels
    and their outputs are summed.

    All three layers use ``kernel_size=(1, 1)`` and ``stride=2``, so each
    layer maps a spatial size ``H`` to ``2 * H - 1``; the module as a whole
    maps ``H`` to ``4 * H - 3``.

    Args:
        out_c: Number of output channels.
        *args: Forwarded to ``nn.Module.__init__``.
        in_c: Number of input channels. When omitted, the first layer is a
            ``nn.LazyConvTranspose2d`` that infers the channel count from
            the first tensor passed to ``forward``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, out_c: int, *args, in_c: "int | None" = None, **kwargs):
        super().__init__(*args, **kwargs)
        mid_c = out_c * 2

        if in_c is None:
            # The input channel count is unknown at construction time, so
            # defer it to the first forward pass.
            self.conv1 = nn.LazyConvTranspose2d(out_channels=mid_c, kernel_size=(1, 1), stride=2)
        else:
            self.conv1 = nn.ConvTranspose2d(in_channels=in_c, out_channels=mid_c, kernel_size=(1, 1), stride=2)
        self.conv2 = nn.ConvTranspose2d(in_channels=mid_c, out_channels=out_c, kernel_size=(1, 1), stride=2)
        self.conv3 = nn.ConvTranspose2d(in_channels=mid_c, out_channels=out_c, kernel_size=(1, 1), stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sum of the two parallel branches applied to ``conv1(x)``."""
        x = self.conv1(x)
        x1 = self.conv2(x)
        x2 = self.conv3(x)

        return x1 + x2


# In [None]:
class UpSample(nn.Module):
    """Upsampling step: transposed fire module, skip addition, two fire modules.

    The input is passed through ``TransposedFireModule(out_c=out_c * 4)``,
    the stored skip tensor ``x_concat`` is added element-wise, and the sum
    is fed through two ``FireModule`` layers configured as
    ``out_c * 4 -> out_c * 2`` and ``out_c * 2 -> out_c``.

    Args:
        x_concat: Skip-connection value added to the transposed fire module
            output. It is fixed at construction time, so it must be
            broadcast-compatible with that output on every ``forward`` call.
        out_c: Channel count requested from the last fire module.
        *args: Forwarded to ``nn.Module.__init__``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, x_concat: torch.Tensor, out_c: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.x_concat = x_concat

        self.transposed_fire_mod = TransposedFireModule(out_c=out_c * 4)
        self.fire_mods = nn.Sequential(
            FireModule(in_c=out_c * 4, out_c=out_c * 2),
            FireModule(in_c=out_c * 2, out_c=out_c),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, add the stored skip value and apply the fire modules."""
        x_transposed_fire = self.transposed_fire_mod(x)
        x = x_transposed_fire + self.x_concat
        # The result has to be returned explicitly; without the return the
        # module would yield None to its caller.
        return self.fire_mods(x)


# In [None]:
class SqueezeUnet(nn.Module):
    """Squeeze-U-Net-style segmentation network built from fire modules.

    Layout as written: two 3x3 stem convolutions, a fire-module downsample
    block applied four times, three ``UpSample`` steps that consume the
    downsample outputs as skip values, and a transposed-convolution head
    ending in ``num_classes`` channels.

    NOTE(review): this class still has open design issues that cannot be
    resolved without knowing the intended channel plan; they are marked
    with ``NOTE(review)`` comments where they occur.

    Args:
        num_classes: Number of output channels of the final layer.
    """

    def __init__(self, num_classes: int):
        # nn.Module.__init__ accepts no extra positional arguments; passing
        # ``self`` a second time raises TypeError.
        super().__init__()
        # first two 3x3 conv layers (3, 3, 64)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=2)

        # downsample block: 2 fire modules
        self.downsample = nn.Sequential(
            FireModule(in_c=64, out_c=128),
            FireModule(in_c=128, out_c=256),
        )

        # NOTE(review): this factory builds a brand-new UpSample on every
        # call. Modules created inside forward() are never assigned to this
        # module, so their parameters are absent from self.parameters() and
        # are re-initialised on each forward pass. They most likely need to
        # be created once here, with the skip tensor supplied at call time.
        self.up_sample = lambda x, out_c: UpSample(x, out_c)

        self.t_conv = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(2, 2), stride=2)
        self.conv3 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=2)
        self.conv4 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=2)
        self.conv5 = nn.ConvTranspose2d(in_channels=64, out_channels=num_classes, kernel_size=(1, 1), stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem, the down/up-sampling path and the output head."""
        # torch.relu applies the activation; nn.ReLU(tensor) would only
        # construct a module object. The stem layer is ``conv1`` (there is
        # no ``self.conv`` attribute).
        x = torch.relu(self.conv1(x))
        x_conv1 = torch.relu(self.conv2(x))  # skip connection to last concat after upsample

        # 3 downsampling steps
        # NOTE(review): the same block, declared as 64 -> 128 -> 256
        # channels, is re-applied to its own output below; that looks
        # channel-inconsistent. Presumably each step needs its own block
        # with matching channel counts - confirm the intended plan.
        x_ds1 = self.downsample(x_conv1)  # skip connection to UpSample(US) #3
        x_ds2 = self.downsample(x_ds1)  # skip connection to US #2
        x_ds3 = self.downsample(x_ds2)  # skip connection to US #1
        x_ds4 = self.downsample(x_ds3)

        # 3 upsampling steps
        us1 = self.up_sample(x_ds3, out_c=4096)
        x_us1 = us1(x_ds4)

        us2 = self.up_sample(x_ds2, out_c=1024)
        x_us2 = us2(x_us1)

        us3 = self.up_sample(x=x_ds1, out_c=64)
        x_us3 = us3(x_us2)

        # post (after downsampling + upsampling)
        # NOTE(review): t_conv is declared with in_channels=128 while the
        # last UpSample is requested with out_c=64 - confirm which is right.
        x = torch.relu(self.t_conv(x_us3))
        x = x + x_conv1
        x = torch.relu(self.conv3(x))
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))

        return x

