# Imports

In [1]:
import torch
import torch.nn as nn

from math import ceil
from torchsummary import summary

# Configurations

In [2]:
# One row per stage of the backbone. Columns, in order:
#   expand_ratio | channels | kernel_size | stride | repeats
model_configs = [
    [1,  16, 3, 1, 1],
    [6,  24, 3, 2, 2],
    [6,  40, 5, 2, 2],
    [6,  80, 3, 2, 3],
    [6, 112, 5, 1, 3],
    [6, 192, 5, 2, 4],
    [6, 320, 3, 1, 1],
]

# Per-version settings, keyed by version name.
# Each value is a tuple: (phi_value, resolution, drop_rate).
# phi_value is the scaling exponent (original note: depth = alpha ** phi).
phi_values = {
    "b0": (0,   224, 0.2),
    "b1": (0.5, 240, 0.2),
    "b2": (1,   260, 0.3),
    "b3": (2,   300, 0.3),
    "b4": (3,   380, 0.4),
    "b5": (4,   456, 0.4),
    "b6": (5,   528, 0.5),
    "b7": (6,   600, 0.5),
}

# Model Architecture

In [3]:
class cnn_bn(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding applied by the convolution.
        groups: convolution groups (1 = ordinary convolution).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1):
        super().__init__()
        # The convolution carries no bias; BatchNorm2d right after it has its own.
        self.cnn = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        convolved = self.cnn(x)
        normalised = self.bn(convolved)
        return self.relu(normalised)

In [4]:
class sne_conv(nn.Module):
    """Squeeze-and-excitation gate: rescales each channel of the input.

    Args:
        in_channels: channels of the feature map being gated.
        reduced_dim: channels of the bottleneck between the two 1x1 convs.
    """

    def __init__(self, in_channels, reduced_dim):
        super().__init__()
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),  # C x H x W -> C x 1 x 1
            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),
            nn.Sigmoid(),  # per-channel weight in (0, 1)
        ]
        self.Squeeze_and_Excitation = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate (broadcast over H x W)."""
        channel_gate = self.Squeeze_and_Excitation(x)
        return x * channel_gate

In [5]:
class mb_conv(nn.Module):
    """Inverted-residual block: 1x1 expand -> depthwise conv -> SE gate -> 1x1 project.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        padding: padding of the depthwise convolution.
        expand_ratio: multiplier giving the hidden width (``in_channels * expand_ratio``).
        reduction: divisor applied to ``in_channels`` to size the SE bottleneck.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            expand_ratio,
            reduction=4 # for squeeze excitation
    ):
        super().__init__()
        hidden_dim = in_channels * expand_ratio
        squeeze_dim = int(in_channels / reduction)  # SE bottleneck width

        # The skip connection is only valid when input and output shapes match.
        self.flag_residual = stride == 1 and in_channels == out_channels
        # With expand_ratio == 1 the hidden width equals the input width: no expansion conv.
        self.flag_expand = hidden_dim != in_channels

        if self.flag_expand:
            self.expand_conv = cnn_bn(in_channels, hidden_dim, kernel_size=1, stride=1, padding=0)

        depthwise = cnn_bn(hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim)
        squeeze_excite = sne_conv(hidden_dim, squeeze_dim)
        # The projection has no activation after its batch norm.
        project = nn.Conv2d(hidden_dim, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        project_bn = nn.BatchNorm2d(out_channels)
        self.conv = nn.Sequential(depthwise, squeeze_excite, project, project_bn)

    def forward(self, inputs):
        """Run the block; add ``inputs`` back when the residual flag is set."""
        features = inputs
        if self.flag_expand:
            features = self.expand_conv(features)

        out = self.conv(features)
        if self.flag_residual:
            out = out + inputs
        return out

In [6]:
# Scratch check: round a channel count up to the nearest multiple of 4.
channels = 32
out_channels = 4 * ceil(channels / 4)
print(channels, out_channels, sep="\n")

32
32


In [7]:
class EfficientNet(nn.Module):
    """EfficientNet-style classifier whose width and depth follow ``version``.

    The stage layout comes from ``model_configs``. ``version`` selects a
    ``(phi, resolution, drop_rate)`` entry of ``phi_values``; channel widths are
    multiplied by ``beta ** phi`` and stage repeats by ``alpha ** phi``. For
    "b0" (phi = 0) both multipliers are 1, so the unscaled baseline is built.
    ``resolution`` is not used here: the adaptive average pool accepts any
    spatial size, so the caller chooses the input resolution.

    Args:
        version: key of ``phi_values`` ("b0" ... "b7").
        num_classes: number of outputs of the final linear layer.
        alpha: base of the depth multiplier (1.2 in the EfficientNet paper).
        beta: base of the width multiplier (1.1 in the EfficientNet paper).

    Raises:
        ValueError: if ``version`` is not a key of ``phi_values``.
    """

    def __init__(self, version, num_classes, alpha=1.2, beta=1.1):
        super(EfficientNet, self).__init__()
        width_factor, depth_factor, drop_rate = self.calculate_factors(version, alpha, beta)
        last_channels = ceil(1280 * width_factor)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.backbone = self.model_layers(last_channels, width_factor, depth_factor)
        # Dropout has no parameters, so state_dict keys are unaffected by it.
        self.dropout = nn.Dropout(drop_rate)
        self.fc = nn.Linear(last_channels, num_classes) # classification layer / last layer

    @staticmethod
    def calculate_factors(version, alpha=1.2, beta=1.1):
        """Return ``(width_factor, depth_factor, drop_rate)`` for ``version``."""
        if version not in phi_values:
            raise ValueError(
                f"Unknown EfficientNet version {version!r}; expected one of {sorted(phi_values)}"
            )
        phi, _resolution, drop_rate = phi_values[version]
        depth_factor = alpha ** phi
        width_factor = beta ** phi
        return width_factor, depth_factor, drop_rate

    def model_layers(self, last_channels, width_factor=1.0, depth_factor=1.0):
        """Build the backbone: stem conv, scaled mb_conv stages, final 1x1 conv.

        Args:
            last_channels: output channels of the final 1x1 convolution.
            width_factor: multiplier applied to every stage's channel count.
            depth_factor: multiplier applied to every stage's repeat count.

        Returns:
            ``nn.Sequential`` containing all backbone layers in order.
        """
        channels = int(32 * width_factor)
        layers = [cnn_bn(3, channels, 3, stride=2, padding=1)]
        in_channels = channels

        for expand_ratio, channels, kernel_size, stride, repeats in model_configs:
            # Round the scaled width up to a multiple of 4, matching mb_conv's
            # default squeeze-excitation reduction of 4.
            out_channels = 4 * ceil(int(channels * width_factor) / 4)
            layers_repeats = ceil(repeats * depth_factor)

            for layer in range(layers_repeats):
                layers.append(
                    mb_conv(
                        in_channels,
                        out_channels,
                        expand_ratio=expand_ratio,
                        # Only the first block of a stage downsamples.
                        stride=stride if layer == 0 else 1,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2, # if k=1:pad=0, k=3:pad=1, k=5:pad=2
                    )
                )
                in_channels = out_channels

        layers.append(cnn_bn(in_channels, last_channels, kernel_size=1, stride=1, padding=0))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for image batch ``x``."""
        x = self.backbone(x)
        x = self.pool(x)
        x = x.view(x.shape[0], -1)  # (batch, C, 1, 1) -> (batch, C)
        x = self.dropout(x)
        x = self.fc(x)
        return x

# Test Model

In [8]:
def test():
    """Smoke test: push one random batch through the model, then print a layer summary."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    version, batch_size, num_classes = "b0", 64, 1
    _, resolution, _ = phi_values[version]

    dummy_batch = torch.randn((batch_size, 3, resolution, resolution)).to(device)
    net = EfficientNet(version=version, num_classes=num_classes).to(device)

    logits = net(dummy_batch)
    print(logits.shape) # (num_examples, num_classes)
    summary(net, input_size=(3, resolution, resolution), device=device)

test()


torch.Size([64, 1])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
              ReLU-3         [-1, 32, 112, 112]               0
            cnn_bn-4         [-1, 32, 112, 112]               0
            Conv2d-5         [-1, 32, 112, 112]             288
       BatchNorm2d-6         [-1, 32, 112, 112]              64
              ReLU-7         [-1, 32, 112, 112]               0
            cnn_bn-8         [-1, 32, 112, 112]               0
 AdaptiveAvgPool2d-9             [-1, 32, 1, 1]               0
           Conv2d-10              [-1, 8, 1, 1]             264
             ReLU-11              [-1, 8, 1, 1]               0
           Conv2d-12             [-1, 32, 1, 1]             288
          Sigmoid-13             [-1, 32, 1, 1]               0
         sne_conv-1

In [None]:
# Build the model on the default device and print its module tree.
VERSION, NUM_CLASSES = "b0", 1
model = EfficientNet(num_classes=NUM_CLASSES, version=VERSION)
print(model)