In [46]:
from torch import nn
import torch

In [47]:
class FireModule(nn.Module):
    def __init__(self, in_c: int, squeeze_c: int, expand_c: int):
        super().__init__()
        self.relu = nn.ReLU()
        # for all layers below, stride is default = 1

        # first point-wise convolution: 1x1
        self.squeeze = nn.Conv2d(in_channels=in_c, out_channels=squeeze_c, kernel_size=1, stride=2)
        # first independent conv: 3x3
        self.expand1x1 = nn.Conv2d(in_channels=squeeze_c, out_channels=expand_c, kernel_size=3, padding=1)
        # second independent (pointwise) conv: 1x1
        self.expand3x3 = nn.Conv2d(in_channels=squeeze_c, out_channels=expand_c, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.squeeze(x))
        x = torch.cat([
            self.relu(self.expand1x1(x)),
            self.relu(self.expand3x3(x)),
        ], dim=1)

        return x

In [48]:
class TransposedFireModule(nn.Module):
    def __init__(self, out_c: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv1 = nn.ConvTranspose2d(in_channels=128, out_channels=out_c * 2, kernel_size=(1, 1), stride=2)
        self.conv2 = nn.ConvTranspose2d(in_channels=out_c * 2, out_channels=out_c, kernel_size=(1, 1), stride=2)
        self.conv3 = nn.ConvTranspose2d(in_channels=out_c * 2, out_channels=out_c, kernel_size=(1, 1), stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = nn.ReLU(self.conv1(x))
        x = torch.cat([
            nn.ReLU(self.conv2(x)),
            nn.ReLU(self.conv3(x)),
        ], dim=1)

        return x


In [49]:
class UpSample(nn.Module):
    def __init__(self, x_concat: int, out_c: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.x_concat = x_concat

        self.transposed_fire_mod = TransposedFireModule(out_c = out_c * 4)
        self.fire_mods = nn.Sequential(
            FireModule(in_c= out_c * 4, squeeze_c= ,out_c= out_c * 2),
            FireModule(in_c= out_c * 2, squeeze_c= ,out_c=out_c),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_transposed_fire = self.transposed_fire_mod(x)
        x = x_transposed_fire + self.x_concat
        x = self.fire_mods(x)


In [68]:
class SqueezeUnet(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.relu = nn.ReLU()
        # first two 3x3 conv layers (3, 3, 64)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.t_conv = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.outc = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def down_sample(self, x: torch.Tensor, in_c: int, squeeze_c: int, out_c: int) -> torch.Tensor:
        mods = nn.Sequential(
            FireModule(in_c=in_c, squeeze_c=squeeze_c, expand_c=out_c // 4),
            FireModule(in_c=out_c // 2, squeeze_c=squeeze_c * 2, expand_c=out_c // 2),
        )
        x = mods(x)

        return x

    def up_sample(x: torch.Tensor, out_c: int) -> nn.Module:
        return UpSample(x, out_c=out_c)

    # def up_sample(self, x: torch.Tensor, out_c: int) -> torch.Tensor:
    #     us_mod = UpSample(x, out_c)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.conv1(x))
        x_conv1 = self.relu(self.conv2(x)) # skip connection to last concat after upsample

        # 4 downsampling steps
        x_ds1 = self.down_sample(x=x_conv1, in_c=64, squeeze_c=32, out_c=256)  # skip connection to UpSample(US) #3
        x_ds2 = self.down_sample(x=x_ds1, in_c=256, squeeze_c=48, out_c=1024)  # skip connection to US #2
        x_ds3 = self.down_sample(x=x_ds2, in_c=1024, squeeze_c=64, out_c=4096)  # skip connection to US #1
        x_ds4 = self.down_sample(x=x_ds3, in_c=4096, squeeze_c=80, out_c=16384)

        # 3 upsampling steps
        us1 = self.up_sample(x_ds3, out_c = 4096)
        x_us1 = us1(x_ds4)

        us2 = self.up_sample(x_ds2, out_c = 1024)
        x_us2 = us2(x_us1)

        us3 = self.up_sample(x = x_ds1, out_c = 64)
        x_us3 = us3(x_us2)

        # # post (after downsampling + upsampling)
        # x = nn.ReLU(self.t_conv(x_us3))
        # x = torch.cat([x + x_conv1], dim=1)
        # x = nn.ReLU(self.conv3(x))
        # x = nn.ReLU(self.conv4(x))

        # logits = nn.ReLU(self.outc(x))

        # return logits


In [67]:
t = torch.rand(1, 3, 512, 512)
sq_unet = SqueezeUnet(num_classes=2)
sq_unet(t).shape

# first = nn.Sequential(
#     nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
#     nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
# )
# first_out = first(t)
# print(first_out.shape)
# fire = FireModule(in_c=64, squeeze_c=32, expand_c=64)
# fire2 = FireModule(in_c=128, squeeze_c=64, expand_c=128)
# fire_out = fire(first_out)
# print(fire_out.shape)
# fire2_out = fire2(fire_out)
# fire2_out.shape
# x_ds1 = self.down_sample(x=x_conv1, in_c=64, squeeze_c=32, expand_c=64)  # skip connection to UpSample(US) #3
# x_ds2 = self.down_sample(x=x_ds1, in_c=128, squeeze_c=48, expand_c=128)  # skip connection to US #2
# x_ds3 = self.down_sample(x=x_ds2, in_c=256, squeeze_c=64, expand_c=256)  # skip connection to US #1
# x_ds4 = self.down_sample(x=x_ds3, in_c=512, squeeze_c=80, expand_c=512)

TypeError: FireModule.__init__() got an unexpected keyword argument 'out_c'