In [1]:
import torch
import torch.nn as nn
import torchvision

**Check the output shape of `efficientnet_b0` layer by layer**

In [2]:
# Build an EfficientNet-B0 and keep a handle on its stack of feature stages.
efficientnet = torchvision.models.efficientnet_b0()
features = efficientnet.features

# Dummy batch holding a single 3-channel 224x224 input.
input_size = (3, 224, 224)
input_tensor = torch.randn(1, *input_size)

# Push the batch through the stages one at a time and report every output shape.
output = input_tensor
for stage_name, stage in features.named_children():
    output = stage(output)
    print(f"Layer: {stage_name}, Output size: {output.size()}")

Layer: 0, Output size: torch.Size([1, 32, 112, 112])
Layer: 1, Output size: torch.Size([1, 16, 112, 112])
Layer: 2, Output size: torch.Size([1, 24, 56, 56])
Layer: 3, Output size: torch.Size([1, 40, 28, 28])
Layer: 4, Output size: torch.Size([1, 80, 14, 14])
Layer: 5, Output size: torch.Size([1, 112, 14, 14])
Layer: 6, Output size: torch.Size([1, 192, 7, 7])
Layer: 7, Output size: torch.Size([1, 320, 7, 7])
Layer: 8, Output size: torch.Size([1, 1280, 7, 7])


**EfficientNetB0 as encoder**

In [3]:
class EfficientNet(nn.Module):
    """EfficientNet-B0 encoder that exposes three intermediate feature maps.

    The torchvision ``efficientnet_b0`` feature stack is built with its default
    arguments (no weights argument is passed) and cut into three consecutive
    pieces: stages 0-2, stage 3 and stage 4. For a ``1 x 3 x 224 x 224`` input
    the recorded notebook output shows maps of shape ``24 x 56 x 56``,
    ``40 x 28 x 28`` and ``80 x 14 x 14``.
    """

    def __init__(self):
        super().__init__()
        stages = torchvision.models.efficientnet_b0().features
        # Stages 0-2 run as one unit; stages 3 and 4 each yield their own output.
        self.layer1, self.layer2, self.layer3 = stages[:3], stages[3], stages[4]

    def forward(self, x):
        """Return the outputs of ``layer1``, ``layer2`` and ``layer3`` as a tuple."""
        pyramid = []
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
            pyramid.append(x)
        return tuple(pyramid)

In [4]:
# NOTE: this rebinds `efficientnet`, which an earlier cell bound to the raw
# torchvision model, to the three-output encoder wrapper.
efficientnet = EfficientNet()

In [5]:
# Run the dummy batch through the encoder and show the shape of each returned map.
output = efficientnet(input_tensor)

tuple(feature_map.shape for feature_map in output)

(torch.Size([1, 24, 56, 56]),
 torch.Size([1, 40, 28, 28]),
 torch.Size([1, 80, 14, 14]))

In [6]:
class BiFPN(nn.Module):
    """One bidirectional feature-pyramid (BiFPN) layer over three pyramid levels.

    Every input level is first projected to a common channel width. The levels
    are then fused along a top-down path (coarse to fine, nearest-neighbour
    upsampling by 2) and afterwards along a bottom-up path (fine to coarse,
    2x2 max pooling). Each fusion is a weighted average with learnable scalar
    weights, followed by a depthwise 3x3 convolution, ReLU and batch norm.

    Args:
        fpn_sizes: Channel counts of the three inputs, ordered from the
            highest-resolution level to the lowest (``P4, P5, P6``).
        num_channels: Channel width used for every level inside the layer
            (stored as ``W_bifpn``).
        epsilon: Constant added to the sum of fusion weights so the
            normalisation never divides by zero.

    Notes:
        * The fusion weights are registered as ``nn.Parameter`` so that they
          appear in ``parameters()`` / ``state_dict()`` and follow the module
          across devices.
        * Neighbouring levels are aligned with a fixed factor-2 resample, so
          the spatial size must halve from one input level to the next.
        * The ``p3_out_*`` and ``p7_out_*`` layers, ``p3_downsample``,
          ``p4_upsample``, ``p6_td_w2`` and ``p4_out_w3`` are not used by
          ``forward``; they are kept only so existing attribute names remain
          available.
    """

    def __init__(self, fpn_sizes, num_channels=64, epsilon=1e-4):
        super().__init__()

        P4_channels, P5_channels, P6_channels = fpn_sizes
        self.W_bifpn = num_channels
        self.epsilon = epsilon

        # ---- Top-down path ------------------------------------------------
        self.p6_td_conv = self._projection(P6_channels)
        self.p6_td_conv_2 = self._depthwise()
        self.p6_td_act = nn.ReLU()
        self.p6_td_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p6_td_w1 = self._fusion_weight()
        self.p6_td_w2 = self._fusion_weight()

        self.p5_td_conv = self._projection(P5_channels)
        self.p5_td_conv_2 = self._depthwise()
        self.p5_td_act = nn.ReLU()
        self.p5_td_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p5_td_w1 = self._fusion_weight()
        self.p5_td_w2 = self._fusion_weight()

        self.p4_td_conv = self._projection(P4_channels)
        self.p4_td_conv_2 = self._depthwise()
        self.p4_td_act = nn.ReLU()
        self.p4_td_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p4_td_w1 = self._fusion_weight()
        self.p4_td_w2 = self._fusion_weight()
        self.p5_upsample = nn.Upsample(scale_factor=2, mode="nearest")

        # ---- Bottom-up path -----------------------------------------------
        # (p3_out_* is never referenced by forward; retained for compatibility.)
        self.p3_out_conv_2 = self._depthwise()
        self.p3_out_act = nn.ReLU()
        self.p3_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p3_out_w1 = self._fusion_weight()
        self.p3_out_w2 = self._fusion_weight()
        self.p3_downsample = nn.MaxPool2d(kernel_size=2)

        self.p4_out_conv = self._depthwise()
        self.p4_out_act = nn.ReLU()
        self.p4_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p4_out_w1 = self._fusion_weight()
        self.p4_out_w2 = self._fusion_weight()
        self.p4_out_w3 = self._fusion_weight()
        self.p4_upsample = nn.Upsample(scale_factor=2, mode="nearest")

        self.p5_out_conv = self._depthwise()
        self.p5_out_act = nn.ReLU()
        self.p5_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p5_out_w1 = self._fusion_weight()
        self.p5_out_w2 = self._fusion_weight()
        self.p5_out_w3 = self._fusion_weight()
        self.p4_downsample = nn.MaxPool2d(kernel_size=2)
        self.p5_downsample = nn.MaxPool2d(kernel_size=2)

        self.p6_out_conv = self._depthwise()
        self.p6_out_act = nn.ReLU()
        self.p6_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p6_out_w1 = self._fusion_weight()
        self.p6_out_w2 = self._fusion_weight()
        self.p6_out_w3 = self._fusion_weight()
        self.p6_upsample = nn.Upsample(scale_factor=2, mode="nearest")

        # (p7_out_* is never referenced by forward; retained for compatibility.)
        self.p7_out_conv_2 = self._depthwise()
        self.p7_out_act = nn.ReLU()
        self.p7_out_conv_bn = nn.BatchNorm2d(self.W_bifpn)
        self.p7_out_w1 = self._fusion_weight()
        self.p7_out_w2 = self._fusion_weight()

    def _projection(self, in_channels):
        """Build the 3x3 conv that maps an input level to ``W_bifpn`` channels."""
        return nn.Conv2d(
            in_channels, self.W_bifpn, kernel_size=3, stride=1, bias=True, padding=1
        )

    def _depthwise(self):
        """Build the shape-preserving depthwise 3x3 conv applied after a fusion."""
        return nn.Conv2d(
            self.W_bifpn,
            self.W_bifpn,
            kernel_size=3,
            stride=1,
            groups=self.W_bifpn,
            bias=True,
            padding=1,
        )

    @staticmethod
    def _fusion_weight():
        """Create one learnable scalar fusion weight, initialised to 1."""
        return nn.Parameter(torch.tensor(1, dtype=torch.float))

    def _fuse(self, *weighted_maps):
        """Return the normalised weighted average of several feature maps.

        Args:
            *weighted_maps: ``(weight, feature_map)`` pairs. All feature maps
                must have the same shape.

        Returns:
            ``sum(w_i * x_i) / (sum(w_i) + epsilon)`` where every ``w_i`` has
            been passed through ReLU so the denominator stays positive.
        """
        weights = [torch.relu(weight) for weight, _ in weighted_maps]
        feature_maps = [feature_map for _, feature_map in weighted_maps]

        numerator = weights[0] * feature_maps[0]
        denominator = weights[0]
        for weight, feature_map in zip(weights[1:], feature_maps[1:]):
            numerator = numerator + weight * feature_map
            denominator = denominator + weight
        return numerator / (denominator + self.epsilon)

    def forward(self, inputs):
        """Fuse three pyramid levels.

        Args:
            inputs: Sequence ``(P4, P5, P6)`` of feature maps ordered from the
                highest resolution to the lowest, with channel counts matching
                ``fpn_sizes``.

        Returns:
            List ``[P4_out, P5_out, P6_out]``; each entry has ``W_bifpn``
            channels and the spatial size of the corresponding input.
        """
        P4, P5, P6 = inputs

        # ---- Top-down path: coarse information flows into finer levels.
        P6_td_inp = self.p6_td_conv(P6)
        P6_td = self.p6_td_conv_2(self._fuse((self.p6_td_w1, P6_td_inp)))
        P6_td = self.p6_td_conv_bn(self.p6_td_act(P6_td))

        P5_td_inp = self.p5_td_conv(P5)
        P5_td = self.p5_td_conv_2(
            self._fuse(
                (self.p5_td_w1, P5_td_inp),
                (self.p5_td_w2, self.p6_upsample(P6_td)),
            )
        )
        P5_td = self.p5_td_conv_bn(self.p5_td_act(P5_td))

        P4_td_inp = self.p4_td_conv(P4)
        P4_td = self.p4_td_conv_2(
            self._fuse(
                (self.p4_td_w1, P4_td_inp),
                (self.p4_td_w2, self.p5_upsample(P5_td)),
            )
        )
        P4_td = self.p4_td_conv_bn(self.p4_td_act(P4_td))

        # ---- Bottom-up path: fine information flows back into coarser levels.
        P4_out = self.p4_out_conv(
            self._fuse((self.p4_out_w1, P4_td_inp), (self.p4_out_w2, P4_td))
        )
        P4_out = self.p4_out_conv_bn(self.p4_out_act(P4_out))

        # All three weights take part in the normalisation (the earlier version
        # left p5_out_w1 out of the denominator).
        P5_out = self.p5_out_conv(
            self._fuse(
                (self.p5_out_w1, P5_td_inp),
                (self.p5_out_w2, P5_td),
                (self.p5_out_w3, self.p4_downsample(P4_out)),
            )
        )
        P5_out = self.p5_out_conv_bn(self.p5_out_act(P5_out))

        P6_out = self.p6_out_conv(
            self._fuse(
                (self.p6_out_w1, P6_td_inp),
                (self.p6_out_w2, P6_td),
                (self.p6_out_w3, self.p5_downsample(P5_out)),
            )
        )
        P6_out = self.p6_out_conv_bn(self.p6_out_act(P6_out))

        return [P4_out, P5_out, P6_out]

In [27]:
# Stand-alone BiFPN check: the channel widths passed to the constructor match
# the channel counts of the three dummy feature maps created below.
bifpn = BiFPN([40, 112, 192])

# Random feature maps whose spatial size halves from level to level (64 -> 32 -> 16).
c1 = torch.randn([1, 40, 64, 64])
c2 = torch.randn([1, 112, 32, 32])
c3 = torch.randn([1, 192, 16, 16])

In [28]:
# Forward the three dummy feature maps through the BiFPN layer.
output = bifpn([c1, c2, c3])

In [32]:
# Shape of every fused pyramid level.
tuple(level.shape for level in output)

(torch.Size([1, 64, 64, 64]),
 torch.Size([1, 64, 32, 32]),
 torch.Size([1, 64, 16, 16]))

**Example model - EfficientNetB0 + BiFPN**

In [22]:
class EfficientDet(nn.Module):
    """Example dense-prediction model: EfficientNet-B0 encoder + BiFPN decoder.

    The encoder yields three feature maps, the decoder fuses them, the two
    coarser decoder outputs are upsampled to the resolution of the finest one,
    all three are concatenated along the channel axis and a 3x3 convolution
    reduces the result to a single channel.

    Args:
        in_features: Channel counts of the three encoder outputs, ordered from
            the highest-resolution map to the lowest. The default matches the
            ``EfficientNet`` encoder defined in this notebook.
    """

    def __init__(self, in_features=(24, 40, 80)):
        super().__init__()
        self.encoder = EfficientNet()
        self.decoder = BiFPN(in_features)
        self.upsample_4 = nn.Upsample(scale_factor=4, mode="nearest")
        self.upsample_2 = nn.Upsample(scale_factor=2, mode="nearest")
        # Three decoder levels are concatenated, each W_bifpn channels wide;
        # deriving the width keeps this layer in sync with the decoder.
        self.final_conv = nn.Conv2d(
            in_channels=3 * self.decoder.W_bifpn,
            out_channels=1,
            kernel_size=3,
            padding="same",
        )

    def upsample_cat(self, x):
        """Bring all decoder levels to the finest resolution and concatenate them.

        Args:
            x: Sequence ``(p4, p5, p6)`` of decoder outputs where ``p5`` and
                ``p6`` are 2x and 4x smaller than ``p4`` respectively.

        Returns:
            Tensor with the spatial size of ``p4`` and the channels of all
            three levels stacked along dimension 1.
        """
        p4, p5, p6 = x
        p6 = self.upsample_4(p6)
        p5 = self.upsample_2(p5)
        return torch.cat([p4, p5, p6], dim=1)

    def forward(self, x):
        """Return a single-channel map at the resolution of the finest level."""
        x1, x2, x3 = self.encoder(x)
        out = self.decoder([x1, x2, x3])
        cated = self.upsample_cat(out)
        return self.final_conv(cated)

In [23]:
# Build the example model with its default encoder channel widths.
model = EfficientDet()

In [24]:
# Reuses the dummy batch created in the first code cell.
input_tensor.shape

torch.Size([1, 3, 224, 224])

In [25]:
# End-to-end forward pass through encoder, BiFPN and the final convolution.
# NOTE(review): the shape lines recorded under this cell are not printed by the
# visible EfficientDet.forward; they look like leftover debug output from an
# earlier revision. Re-run the notebook to refresh the stored output.
output = model(input_tensor)

torch.Size([1, 24, 56, 56]) torch.Size([1, 40, 28, 28]) torch.Size([1, 80, 14, 14])
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 28, 28]) torch.Size([1, 64, 14, 14])
torch.Size([1, 192, 56, 56])


In [13]:
# Shape of the model prediction.
# NOTE(review): this cell's execution counter is lower than the preceding
# cell's, so the stored result may be stale. Verify with Restart & Run All.
output.shape

torch.Size([1, 1, 56, 56])