In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.nn import functional as F

**EfficientNetB0 as encoder**

In [None]:
class EfficientNet(nn.Module):
    """Truncated EfficientNet-B0 backbone that returns three intermediate feature maps.

    Only entries 0-4 of ``torchvision.models.efficientnet_b0().features`` are
    kept; the remaining entries and the classifier are discarded.
    ``efficientnet_b0`` is called without any ``weights`` argument.

    NOTE(review): other cells in this notebook feed the decoder channel counts
    24/40/80 for these three outputs -- confirm against the shape-check cell.
    """

    def __init__(self):
        super().__init__()
        stages = torchvision.models.efficientnet_b0().features
        self.layer1 = stages[:3]  # entries 0, 1 and 2 grouped into one stage
        self.layer2 = stages[3]
        self.layer3 = stages[4]

    def forward(self, x):
        """Run the three stages back to back and return each stage's output.

        Args:
            x: input tensor passed to the first stage.

        Returns:
            A 3-tuple ``(x1, x2, x3)`` with the outputs of ``layer1``,
            ``layer2`` and ``layer3``, in that order.
        """
        collected = []
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
            collected.append(x)
        return tuple(collected)

In [None]:
# Build a standalone encoder instance; the shape-check cell below calls it on a dummy batch.
efficientnet = EfficientNet()

In [None]:
# Dummy batch holding a single random 3 x 224 x 224 input, used only for shape checks.
input_size = (3, 224, 224)
input_tensor = torch.randn((1, *input_size))

In [None]:
# Push the dummy batch through the encoder and display the shape of each returned feature map.
output = efficientnet(input_tensor)
tuple(stage_out.shape for stage_out in output)

**FPN as decoder**

In [None]:
# Momentum handed to every nn.BatchNorm2d created by ConvBNReLU.
BN_MOMENTUM = 0.1
# Keyword arguments for F.interpolate: bilinear resizing with aligned corners.
gpu_up_kwargs = dict(mode="bilinear", align_corners=True)
# Alternative interpolate settings using nearest-neighbour resizing.
mobile_up_kwargs = dict(mode="nearest")
# Passed as `inplace=` when ConvBNReLU instantiates its activation.
relu_inplace = True

In [None]:
class ConvBNReLU(nn.Module):
    """Bias-free 2-D convolution, then batch norm, then an optional activation.

    Args:
        in_chan: number of input channels of the convolution.
        out_chan: number of output channels (also the batch-norm width).
        ks: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        activation: callable invoked as ``activation(inplace=relu_inplace)`` to
            build the activation module; pass a falsy value (e.g. ``None``) to
            get a plain Conv+BN block.
        *args, **kwargs: accepted but not used.

    The batch-norm momentum comes from the module-level ``BN_MOMENTUM``.
    """

    def __init__(
        self,
        in_chan,
        out_chan,
        ks=3,
        stride=1,
        padding=1,
        activation=nn.ReLU,
        *args,
        **kwargs,
    ):
        super().__init__()
        conv = nn.Conv2d(
            in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False
        )
        norm = nn.BatchNorm2d(out_chan, momentum=BN_MOMENTUM)
        # Empty tuple when no activation is requested, so Sequential holds just conv + norm.
        tail = (activation(inplace=relu_inplace),) if activation else ()
        self.layers = nn.Sequential(conv, norm, *tail)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.layers(x)

In [None]:
class AdapterConv(nn.Module):
    """One 1x1 ``ConvBNReLU`` per feature level, remapping each level's channel count.

    Args:
        in_channels: channel count of each incoming feature map.
        out_channels: channel count each map is projected to; must have the
            same length as ``in_channels``.

    The list defaults are only read, never modified.
    """

    def __init__(self, in_channels=[256, 512, 1024, 2048], out_channels=[64, 128, 256, 512]):
        super().__init__()
        assert len(in_channels) == len(
            out_channels
        ), "Number of input and output branches should match"
        # Level i gets its own projection in_channels[i] -> out_channels[i].
        self.adapter_conv = nn.ModuleList(
            [
                ConvBNReLU(in_channels[level], out_channels[level], ks=1, stride=1, padding=0)
                for level in range(len(in_channels))
            ]
        )

    def forward(self, x):
        """Project ``x[i]`` with the i-th adapter.

        Args:
            x: indexable collection of feature maps, one per adapter.

        Returns:
            A list with one projected tensor per adapter, in adapter order.
        """
        return [adapter(x[level]) for level, adapter in enumerate(self.adapter_conv)]

In [None]:
class UpsampleCat(nn.Module):
    """Resize every feature map to the first map's spatial size and stack them on the channel axis.

    Args:
        upsample_kwargs: keyword arguments forwarded to ``F.interpolate``
            (for example ``mode`` and ``align_corners``).
    """

    def __init__(self, upsample_kwargs=gpu_up_kwargs):
        super(UpsampleCat, self).__init__()
        self._up_kwargs = upsample_kwargs

    def forward(self, x):
        """Upsample and concatenate feature maps.

        Args:
            x: list or tuple of 4-D tensors; ``x[0]`` defines the target
                height and width.

        Returns:
            ``x[0]`` followed by every other map resized to ``x[0]``'s
            spatial size, concatenated along ``dim=1``. When ``x`` has a
            single element, that element is returned unchanged.
        """
        assert isinstance(x, (list, tuple))
        reference = x[0]
        _, _, height, width = reference.size()
        resized = [
            F.interpolate(feat, (height, width), **self._up_kwargs) for feat in x[1:]
        ]
        if not resized:
            # Nothing to concatenate; hand back the very same tensor.
            return reference
        # One concatenation at the end instead of re-copying a growing
        # tensor on every loop iteration.
        return torch.cat([reference, *resized], dim=1)

In [None]:
class UpBranch(nn.Module):
    """Top-down fusion over three feature levels.

    The coarsest level is projected with a 1x1 conv, resized and added to the
    next finer level, and so on. Each level also gets a 3x3 "smoothing" conv
    that produces that level's output. The ``8``/``16``/``32`` suffixes are
    just names; this class does not check strides.

    Args:
        in_channels: channel counts of the three inputs, finest level first.
        out_channels: channel counts of the three outputs, finest level first.
        upsample_kwargs: keyword arguments forwarded to ``F.interpolate``.

    Attributes:
        high_level_ch: ``sum(out_channels)``, i.e. the channel count obtained
            when the three outputs are concatenated.
        out_channels: the ``out_channels`` argument as given.
    """

    def __init__(
        self,
        in_channels=[64, 128, 256],
        out_channels=[128, 128, 128],
        upsample_kwargs=gpu_up_kwargs,
    ):
        super(UpBranch, self).__init__()

        self._up_kwargs = upsample_kwargs

        # "*_sm": 3x3 conv producing the level's output.
        # "*_up": 1x1 conv matching the next finer level's channel count
        #         so the two maps can be added.
        self.fam_32_sm = ConvBNReLU(in_channels[2], out_channels[2], ks=3, stride=1, padding=1)
        self.fam_32_up = ConvBNReLU(in_channels[2], in_channels[1], ks=1, stride=1, padding=0)
        # Middle level uses out_channels[1]. It previously used out_channels[0],
        # which only worked because the defaults are all equal and contradicted
        # ``high_level_ch`` for non-uniform widths.
        self.fam_16_sm = ConvBNReLU(in_channels[1], out_channels[1], ks=3, stride=1, padding=1)
        self.fam_16_up = ConvBNReLU(in_channels[1], in_channels[0], ks=1, stride=1, padding=0)
        self.fam_8_sm = ConvBNReLU(in_channels[0], out_channels[0], ks=3, stride=1, padding=1)
        # A fourth, finer level (fam_8_up / fam_4) is intentionally not built here.

        self.high_level_ch = sum(out_channels)
        self.out_channels = out_channels

    def forward(self, x):
        """Fuse the three levels from coarsest to finest.

        Args:
            x: iterable of exactly three 4-D tensors ``(feat8, feat16, feat32)``
                whose channel counts match ``in_channels``.

        Returns:
            Tuple ``(smfeat_8, smfeat_16, smfeat_32)``; each output keeps the
            spatial size of its corresponding input.
        """
        feat8, feat16, feat32 = x

        smfeat_32 = self.fam_32_sm(feat32)
        upfeat_32 = self.fam_32_up(feat32)

        # Bring the projected coarse map to feat16's size, then merge by addition.
        _, _, H, W = feat16.size()
        x = F.interpolate(upfeat_32, (H, W), **self._up_kwargs) + feat16
        smfeat_16 = self.fam_16_sm(x)
        upfeat_16 = self.fam_16_up(x)

        # Same step one level finer.
        _, _, H, W = feat8.size()
        x = F.interpolate(upfeat_16, (H, W), **self._up_kwargs) + feat8
        smfeat_8 = self.fam_8_sm(x)

        return smfeat_8, smfeat_16, smfeat_32

In [None]:
class UpHeadA(nn.Module):
    """Decoder head: ``AdapterConv`` -> ``UpBranch`` -> ``UpsampleCat``.

    Args:
        in_chans: channel counts of the incoming feature maps.
        base_chans: channel counts the adapters project to; also used as the
            ``UpBranch`` input widths.
        upsample_kwargs: keyword arguments for ``F.interpolate``, applied to
            both ``UpBranch`` and ``UpsampleCat``.

    With the ``UpBranch`` default output widths (3 x 128) the concatenated
    output has 384 channels.
    """

    def __init__(
        self,
        in_chans,
        base_chans=[64, 128, 256],
        upsample_kwargs=gpu_up_kwargs,
    ):
        super().__init__()
        # Copy so UpBranch never shares the caller's (or the default) list.
        branch_chans = base_chans[:]
        self.layers = nn.Sequential(
            AdapterConv(in_chans, base_chans),
            # Forward the interpolation settings here too; previously UpBranch
            # always fell back to its own default, so a non-default
            # ``upsample_kwargs`` was only honoured by UpsampleCat.
            UpBranch(branch_chans, upsample_kwargs=upsample_kwargs),
            UpsampleCat(upsample_kwargs),
        )

    def forward(self, x):
        """Run the feature maps ``x`` through the three decoder stages.

        Args:
            x: list or tuple of feature maps, one per entry of ``in_chans``.

        Returns:
            A single tensor: all levels resized to the first level's spatial
            size and concatenated along the channel axis.
        """
        return self.layers(x)

In [None]:
# Random stand-ins for three feature maps: 24, 40 and 80 channels
# at 56x56, 28x28 and 14x14 respectively (batch size 1).
c1, c2, c3 = (
    torch.randn(1, channels, side, side)
    for channels, side in ((24, 56), (40, 28), (80, 14))
)

In [None]:
# Decoder smoke test: the adapter input widths match the channel counts of c1, c2, c3.
up_head_a = UpHeadA([24, 40, 80])

# Feed the three dummy feature maps (first one sets the output resolution) and report the shape.
out_A = up_head_a([c1, c2, c3])
print("output A: ", out_A.shape)

**EfficientNetB0 + FPN**

In [None]:
class EfficientNetFPN(nn.Module):
    """Encoder-decoder model: ``EfficientNet`` features -> ``UpHeadA`` -> 3x3 conv head.

    The decoder is configured for encoder feature widths 24/40/80. The final
    convolution expects 384 input channels (three decoder outputs of 128
    channels each with the ``UpBranch`` defaults) and produces a single
    output channel; ``padding="same"`` keeps the spatial size unchanged.
    """

    def __init__(self):
        super().__init__()
        self.encoder = EfficientNet()
        self.decoder = UpHeadA([24, 40, 80])
        self.final_conv = nn.Conv2d(384, 1, kernel_size=3, padding="same")

    def forward(self, x):
        """Encode ``x``, fuse the three feature maps and apply the final conv.

        Args:
            x: input batch passed to the encoder.

        Returns:
            The single-channel output of ``final_conv``.
        """
        feat_a, feat_b, feat_c = self.encoder(x)
        fused = self.decoder([feat_a, feat_b, feat_c])
        return self.final_conv(fused)

In [None]:
# Instantiate the full encoder + decoder + head model.
efficientnet_fpn = EfficientNetFPN()

In [None]:
# End-to-end shape check on the dummy batch. The result has one channel
# (final_conv has out_channels=1) at the spatial size of the first encoder feature map.
efficientnet_fpn(input_tensor).shape