In [23]:
import torch.nn.functional as F
from torch import nn
import torch

In [24]:
class FireModule(nn.Module):
    """Fire block: a 1x1 squeeze convolution feeding two parallel expand convolutions.

    The squeeze convolution maps the input to ``squeeze_c`` channels. Its
    ReLU-activated output is passed to a 3x3 branch (padding 1) and a 1x1
    branch; both branch outputs are ReLU-activated and concatenated along the
    channel dimension, 3x3 branch first.

    Args:
        in_c: Number of input channels.
        squeeze_c: Number of channels produced by the squeeze convolution.
        out_c: Total number of output channels after concatenation (>= 2).
        stride: Stride of the squeeze convolution only. Both expand
            convolutions always use stride 1, so any spatial downsampling
            happens in the squeeze step.

    Raises:
        ValueError: If ``out_c`` is smaller than 2, since each of the two
            branches needs at least one output channel.
    """

    def __init__(self, in_c: int, squeeze_c: int, out_c: int, stride: int = 1):
        super().__init__()
        if out_c < 2:
            raise ValueError(f"out_c must be >= 2 to feed two expand branches, got {out_c}")

        # Split the output channels so the two branches always add up to out_c.
        # For even out_c this is the usual half/half split; for odd out_c the
        # second branch takes the extra channel instead of silently dropping it.
        first_branch_c = out_c // 2
        second_branch_c = out_c - first_branch_c

        # point-wise squeeze convolution: 1x1, the only layer that uses `stride`
        self.squeeze = nn.Conv2d(in_channels=in_c, out_channels=squeeze_c, kernel_size=1, stride=stride)

        # NOTE: the attribute names below are swapped relative to the kernel
        # sizes (`expand1x1` is the 3x3 conv, `expand3x3` is the 1x1 conv).
        # They are kept as-is so parameter names in state_dict do not change.
        # 3x3 expand convolution; padding=1 keeps the spatial size
        self.expand1x1 = nn.Conv2d(in_channels=squeeze_c, out_channels=first_branch_c, kernel_size=3, padding=1)
        # point-wise expand convolution: 1x1
        self.expand3x3 = nn.Conv2d(in_channels=squeeze_c, out_channels=second_branch_c, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Squeeze ``x``, run both expand branches and concatenate them on dim 1."""
        x = F.relu(self.squeeze(x))
        x = torch.cat([
            F.relu(self.expand1x1(x)),
            F.relu(self.expand3x3(x)),
        ], dim=1)

        return x

In [25]:
class DownSample(nn.Module):
    """Contracting stage made of two stride-2 fire modules applied in sequence.

    The first fire module maps ``in_c`` channels to ``out_c // 2`` using
    ``squeeze_c`` squeeze channels; the second maps ``out_c // 2`` channels to
    ``out_c`` using ``squeeze_c * 2`` squeeze channels. Both are built with
    ``stride=2``.

    Args:
        in_c: Number of input channels.
        squeeze_c: Squeeze width of the first fire module (doubled for the second).
        out_c: ``out_c`` value passed to the second fire module.
    """

    def __init__(self, in_c: int, squeeze_c: int, out_c: int):
        super().__init__()
        mid_c = out_c // 2  # channel width between the two fire modules
        stages = [
            FireModule(in_c=in_c, squeeze_c=squeeze_c, out_c=mid_c, stride=2),
            FireModule(in_c=mid_c, squeeze_c=squeeze_c * 2, out_c=out_c, stride=2),
        ]
        self.fire_mods = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through both fire modules and return the result."""
        return self.fire_mods(x)

In [26]:
class TransposedFireModule(nn.Module):
    """Upsampling counterpart of a fire block built from transposed convolutions.

    A 1x1 transposed squeeze convolution (stride 2, output_padding 1) doubles
    the spatial size and maps the input to ``squeeze_c`` channels. Its
    ReLU-activated output feeds two parallel 2x2, stride-2 transposed
    convolutions that each double the size again. Their ReLU-activated outputs
    are concatenated along the channel dimension, so an input of spatial size
    ``H x W`` comes out as ``4H x 4W`` with ``out_c`` channels.

    Args:
        in_c: Number of input channels.
        squeeze_c: Number of channels produced by the squeeze step.
        out_c: Total number of output channels after concatenation (>= 2).
        *args: Forwarded unchanged to ``nn.Module.__init__``.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

    Raises:
        ValueError: If ``out_c`` is smaller than 2, since each of the two
            branches needs at least one output channel.
    """

    def __init__(self, in_c: int, squeeze_c: int, out_c: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if out_c < 2:
            raise ValueError(f"out_c must be >= 2 to feed two expand branches, got {out_c}")

        # Split the output channels so the two branches always add up to out_c.
        # For even out_c this is the usual half/half split; for odd out_c the
        # second branch takes the extra channel instead of silently dropping it.
        first_branch_c = out_c // 2
        second_branch_c = out_c - first_branch_c

        # Output size: (H - 1) * 2 + (1 - 1) + 1 (output_padding) + 1 = 2H.
        # NOTE(review): with a 1x1 kernel and stride 2, only every second
        # output position receives input; the others hold just the bias.
        # Confirm this is the intended upsampling behaviour.
        self.t_squeeze_1x1 = nn.ConvTranspose2d(in_channels=in_c, out_channels=squeeze_c, kernel_size=1, stride=2, padding=0, output_padding=1)

        # Output size for both branches: (H' - 1) * 2 + (2 - 1) + 1 = 2H'.
        # NOTE: despite its name, `t_expand_1x1` uses a 2x2 kernel, so the two
        # branches have the same configuration. The attribute names are kept
        # so parameter names in state_dict do not change.
        self.t_expand_1x1 = nn.ConvTranspose2d(in_channels=squeeze_c, out_channels=first_branch_c, kernel_size=2, stride=2, padding=0, output_padding=0)
        self.t_expand_2x2 = nn.ConvTranspose2d(in_channels=squeeze_c, out_channels=second_branch_c, kernel_size=2, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` through the squeeze step and both expand branches."""
        # squeeze
        x = F.relu(self.t_squeeze_1x1(x))
        # two parallel expand branches, concatenated on the channel dimension
        x = torch.cat([
            F.relu(self.t_expand_1x1(x)),
            F.relu(self.t_expand_2x2(x)),
        ], dim=1)

        return x


In [27]:
class UpSample(nn.Module):
    """Expanding stage: transposed fire module, skip concatenation, two fire modules.

    The input is first passed through a ``TransposedFireModule`` producing
    ``t_fire_mod_out_c`` channels, then concatenated on the channel dimension
    with the skip tensor and refined by two stride-1 fire modules that end at
    ``out_c`` channels.

    Args:
        in_c: Number of channels of the tensor to upsample.
        t_fire_mod_out_c: Output channels of the transposed fire module. The
            first fire module is built for ``2 * t_fire_mod_out_c`` input
            channels, so the skip tensor is expected to carry
            ``t_fire_mod_out_c`` channels as well.
        out_c: Number of channels returned by this stage.
    """

    def __init__(self, in_c: int, t_fire_mod_out_c: int, out_c: int):
        super().__init__()
        self.t_fire_mod = TransposedFireModule(in_c=in_c, squeeze_c=in_c // 2, out_c=t_fire_mod_out_c)

        merged_c = t_fire_mod_out_c * 2  # upsampled features + skip features
        wide_c = out_c * 2  # intermediate width between the two fire modules
        self.fire_mods = nn.Sequential(
            # squeeze down to half of the concatenated width
            FireModule(in_c=merged_c, squeeze_c=t_fire_mod_out_c, out_c=wide_c, stride=1),
            FireModule(in_c=wide_c, squeeze_c=out_c, out_c=out_c, stride=1),
        )

    def forward(self, x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, concatenate the skip tensor ``x1`` on dim 1 and refine."""
        upsampled = self.t_fire_mod(x)
        merged = torch.cat((upsampled, x1), dim=1)
        return self.fire_mods(merged)

In [28]:
class SqueezeUnet(nn.Module):
    """U-Net style segmentation network built from fire and transposed fire modules.

    Layout: two 3x3 convolutions, four ``DownSample`` stages, three
    ``UpSample`` stages fed with skip connections from the matching
    contracting stages, a final transposed convolution back to the input
    resolution, two more 3x3 convolutions and a 1x1 classifier.

    Each ``DownSample`` shrinks the spatial size by a factor of four and each
    upsampling step grows it by a factor of four, so the input height and
    width must be multiples of 256 for the skip connections to line up.

    Args:
        num_classes: Number of output channels (one per class).
        in_channels: Number of channels of the input image. Defaults to 3.
    """

    def __init__(self, num_classes: int, in_channels: int = 3):
        super().__init__()

        # first two 3x3 conv layers, 64 channels each
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        # contracting step
        self.ds1 = DownSample(in_c=64, squeeze_c=32, out_c=256)
        self.ds2 = DownSample(in_c=256, squeeze_c=48, out_c=1_024)
        self.ds3 = DownSample(in_c=1_024, squeeze_c=64, out_c=4_096)
        self.ds4 = DownSample(in_c=4_096, squeeze_c=80, out_c=16_384)

        # expanding step
        self.us1 = UpSample(in_c=16_384, t_fire_mod_out_c=4_096, out_c=2_048)
        self.us2 = UpSample(in_c=2_048, t_fire_mod_out_c=1_024, out_c=512)
        self.us3 = UpSample(in_c=512, t_fire_mod_out_c=256, out_c=128)

        # Final 4x upsampling back to the input resolution:
        # (H - 1) * 4 + (2 - 1) + 2 (output_padding) + 1 = 4H.
        self.t_conv = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=4, output_padding=2)
        # 128 input channels = 64 from t_conv + 64 from the conv2 skip connection
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        # 1x1 classifier head
        self.outc = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel class logits with the same spatial size as ``x``.

        The returned tensor holds raw (unnormalised) scores with
        ``num_classes`` channels; no activation is applied to it.
        """
        # pre steps
        x = F.relu(self.conv1(x))
        x_conv1 = F.relu(self.conv2(x))  # skip connection to last concat after upsample

        # 4 downsampling steps
        x_ds1 = self.ds1(x=x_conv1)  # skip connection to UpSample(US) #3
        x_ds2 = self.ds2(x=x_ds1)  # skip connection to US #2
        x_ds3 = self.ds3(x=x_ds2)  # skip connection to US #1
        x_ds4 = self.ds4(x=x_ds3)

        # 3 upsampling steps
        x = self.us1(x_ds4, x_ds3)
        x = self.us2(x, x_ds2)
        x = self.us3(x, x_ds1)

        # post steps
        x = F.relu(self.t_conv(x))
        x = torch.cat([x, x_conv1], dim=1)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        # No ReLU here: clamping the classifier output to be non-negative
        # would distort any softmax / cross-entropy computed on these logits.
        logits = self.outc(x)

        return logits


In [29]:
# Smoke test: push one random 512x512 RGB image through the network and check
# that the output keeps the batch size and spatial size with one channel per class.
t = torch.rand(1, 3, 512, 512)
sq_unet = SqueezeUnet(num_classes=2)
with torch.no_grad():  # shape check only; avoid building the autograd graph
    out_shape = sq_unet(t).shape
assert out_shape == (t.shape[0], 2, *t.shape[2:]), f"unexpected output shape: {out_shape}"
out_shape

torch.Size([1, 2, 512, 512])