In [13]:
from torch import nn
import torch

In [14]:
class FireModule(nn.Module):
    def __init__(self, in_c: int, squeeze_c: int, out_c: int, stride: int = 1):
        super().__init__()
        self.relu = nn.ReLU()
        # for all layers below, stride is default = 1

        # first point-wise convolution: 1x1
        self.squeeze = nn.Conv2d(in_channels=in_c, out_channels=squeeze_c, kernel_size=1, stride=stride)
        # first independent conv: 3x3
        self.expand1x1 = nn.Conv2d(in_channels=squeeze_c, out_channels=out_c // 2, kernel_size=3, padding=1)
        # second independent (pointwise) conv: 1x1
        self.expand3x3 = nn.Conv2d(in_channels=squeeze_c, out_channels=out_c // 2, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.squeeze(x))
        x = torch.cat([
            self.relu(self.expand1x1(x)),
            self.relu(self.expand3x3(x)),
        ], dim=1)

        return x

In [15]:
class DownSample(nn.Module):

    def __init__(self, in_c: int, squeeze_c: int, out_c: int):
        self.__init__()
        self.fire_mods = nn.Sequential(
            FireModule(in_c=in_c, squeeze_c=squeeze_c, out_c=out_c // 2, stride=2),
            FireModule(in_c=out_c // 2, squeeze_c=squeeze_c * 2, out_c=out_c, stride=2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fire_mods(x)

        return x

In [16]:
class TransposedFireModule(nn.Module):
    def __init__(self, in_c: int, squeeze_c: int, out_c: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.relu = nn.ReLU()
        self.t_squeeze_1x1 = nn.ConvTranspose2d(in_channels=in_c, out_channels=squeeze_c, kernel_size=1)
        self.t_expand_1x1 = nn.ConvTranspose2d(in_channels=squeeze_c, out_channels=out_c // 2, kernel_size=1)
        self.t_expand_2x2 = nn.ConvTranspose2d(in_channels=out_c * 2, out_channels=out_c // 2, kernel_size=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.t_squeeze_1x1(x))  # squeeze
        # inception stage
        x = torch.cat([
            self.relu(self.t_expand_1x1(x)),
            self.relu(self.t_expand_2x2(x)),
        ], dim=1)

        return x


In [17]:
class UpSample(nn.Module):
    def __init__(self, in_c: int, t_fire_mod_out_c: int, out_c: int):
        super().__init__()
        self.t_fire_mod = TransposedFireModule(in_c=in_c, squeeze_c=in_c // 2, out_c=t_fire_mod_out_c)
        self.fire_mods = nn.Sequential(
            FireModule(in_c=t_fire_mod_out_c * 2,  # twice since transpose fire mod output is concatenated
                       squeeze_c=t_fire_mod_out_c, # squeeze with half the input size
                       out_c=out_c * 2,
                       stride=1),
            FireModule(in_c=out_c * 2, squeeze_c=out_c, out_c=out_c, stride=1),
        )

    def forward(self, x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        x = self.t_fire_mod(x)
        x = torch.cat([x, x1,], dim=1)
        x = self.fire_mods(x)

        return x

In [18]:
class SqueezeUnet(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.relu = nn.ReLU()

        # first two 3x3 conv layers (3, 3, 64)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        # contracting step
        self.ds1 = DownSample(in_c=64, squeeze_c=32, out_c=256)
        self.ds2 = DownSample(in_c=256, squeeze_c=48, out_c=1_024)
        self.ds3 = DownSample(in_c=1_024, squeeze_c=64, out_c=4_096)
        self.ds4 = DownSample(in_c=4_096, squeeze_c=80, out_c=16_384)

        # expanding step
        self.us1 = UpSample(in_c=16_384, t_fire_mod_out_c=4_096, out_c=2_048)
        self.us2 = UpSample(in_c=2_048, t_fire_mod_out_c=1_024, out_c=512)
        self.us3 = UpSample(in_c=512, t_fire_mod_out_c=256, out_c=128)

        self.t_conv = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.outc = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pre steps
        x = self.relu(self.conv1(x))
        x_conv1 = self.relu(self.conv2(x)) # skip connection to last concat after upsample

        # 4 downsampling steps
        x_ds1 = self.ds1(x=x_conv1)  # skip connection to UpSample(US) #3
        x_ds2 = self.ds2(x=x_ds1)  # skip connection to US #2
        x_ds3 = self.ds3(x=x_ds2)  # skip connection to US #1
        x_ds4 = self.ds4(x=x_ds3)

        # 3 upsampling steps
        x = self.us1(x_ds4, x_ds3)
        x = self.us2(x, x_ds2)
        x = self.us3(x, x_ds1)

        # post steps
        x = self.relu(self.t_conv(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))

        logits = nn.ReLU(self.outc(x))

        return logits


In [19]:
t = torch.rand(1, 3, 512, 512)
sq_unet = SqueezeUnet(num_classes=2)
sq_unet(t).shape

# first = nn.Sequential(
#     nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
#     nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
# )
# first_out = first(t)
# print(first_out.shape)
# fire = FireModule(in_c=64, squeeze_c=32, expand_c=64)
# fire2 = FireModule(in_c=128, squeeze_c=64, expand_c=128)
# fire_out = fire(first_out)
# print(fire_out.shape)
# fire2_out = fire2(fire_out)
# fire2_out.shape
# x_ds1 = self.down_sample(x=x_conv1, in_c=64, squeeze_c=32, expand_c=64)  # skip connection to UpSample(US) #3
# x_ds2 = self.down_sample(x=x_ds1, in_c=128, squeeze_c=48, expand_c=128)  # skip connection to US #2
# x_ds3 = self.down_sample(x=x_ds2, in_c=256, squeeze_c=64, expand_c=256)  # skip connection to US #1
# x_ds4 = self.down_sample(x=x_ds3, in_c=512, squeeze_c=80, expand_c=512)

TypeError: DownSample.__init__() missing 3 required positional arguments: 'in_c', 'squeeze_c', and 'out_c'

In [None]:
"""
# 3 up sampling modules

# US 1
t_fire_mod1 = TransposedFireModule(in_c=16_384, squeeze_c=8192, out_c=4096)
# concatenation step
# torch.cat((t_fire_mod1(x_ds4), x_ds3), dim=1)
# fire modules
mods = nn.Sequential(
            FireModule(in_c=8192, squeeze_c=4096, out_c=4096, stride=1),
            FireModule(in_c=4096, squeeze_c=4096 // 2, out_c=2048, stride=1),
)

# US 2
t_fire_mod2 = TransposedFireModule(in_c=2048, squeeze_c=1024, out_c=1024)
# concatenation step
# torch.cat((t_fire_mod2(x_ds4), x_ds3), dim=1)
# fire modules
mods = nn.Sequential(
            FireModule(in_c=2048, squeeze_c=1024, out_c=1024, stride=1),
            FireModule(in_c=1024, squeeze_c=1024 // 2, out_c=512, stride=1),
)

# US 3
t_fire_mod3 = TransposedFireModule(in_c=512, squeeze_c=256, out_c=256)
# concatenation step
# torch.cat((t_fire_mod1(x_ds4), x_ds3), dim=1)
# fire modules
mods = nn.Sequential(
            FireModule(in_c=512, squeeze_c=256, out_c=256, stride=1),
            FireModule(in_c=256, squeeze_c=256 // 2, out_c=128, stride=1),
)

# transposed conv layer
t_conv = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=1)
"""

'\n# 3 up sampling modules\n\n# US 1\nt_fire_mod1 = TransposedFireModule(in_c=16_384, squeeze_c=8192, out_c=4096)\n# concatenation step\n# torch.cat((t_fire_mod1(x_ds4), x_ds3), dim=1)\n# fire modules\nmods = nn.Sequential(\n            FireModule(in_c=8192, squeeze_c=4096, out_c=4096, stride=1),\n            FireModule(in_c=4096, squeeze_c=4096 // 2, out_c=2048, stride=1),\n)\n\n# US 2\nt_fire_mod2 = TransposedFireModule(in_c=2048, squeeze_c=1024, out_c=1024)\n# concatenation step\n# torch.cat((t_fire_mod2(x_ds4), x_ds3), dim=1)\n# fire modules\nmods = nn.Sequential(\n            FireModule(in_c=2048, squeeze_c=1024, out_c=1024, stride=1),\n            FireModule(in_c=1024, squeeze_c=1024 // 2, out_c=512, stride=1),\n)\n\n# US 3\nt_fire_mod3 = TransposedFireModule(in_c=512, squeeze_c=256, out_c=256)\n# concatenation step\n# torch.cat((t_fire_mod1(x_ds4), x_ds3), dim=1)\n# fire modules\nmods = nn.Sequential(\n            FireModule(in_c=512, squeeze_c=256, out_c=256, stride=1),\n      