# EfficientNet - Implementation from scratch in PyTorch

In [None]:
import time
import math
import numpy as np
import matplotlib.pyplot as plt
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Report the size of a model in terms of its trainable parameters
def count_parameters(model) -> None:
    """
    Print how many trainable scalar values the model holds.

    Only parameters with requires_grad set are counted. The result is printed,
    not returned.
    """
    total_params = 0
    for param in model.parameters():
        if param.requires_grad:
            total_params += param.numel()
    print(f'Number of parameters: {total_params}')

# Define configs for different EfficientNet versions

# Architecture Implementation

In [None]:
class ConvLayer(nn.Module):
    """
    Convolution building block: Conv2d, then BatchNorm2d, then an activation.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size, stride, padding: forwarded unchanged to nn.Conv2d.
        groups: forwarded to nn.Conv2d (1 unless specified).
        bias: whether the convolution carries a bias term (off unless specified).
        activation: module applied after normalization; a fresh nn.SiLU is
            created when None is given.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1,
                 bias=False, activation=None) -> None:
        super().__init__()
        # Fall back to SiLU when the caller does not supply an activation
        if activation is None:
            activation = nn.SiLU()

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        # Normalizes the convolution output per channel
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        self.activation = activation

    def forward(self, x):
        """Apply convolution, batch normalization and activation, in that order."""
        return self.activation(self.norm(self.conv(x)))


In [None]:
class SqueezeExcitation(nn.Module):
    """
    Squeeze-and-Excitation channel gating.

    The input is average-pooled to a single value per channel, passed through a
    1x1 convolution down to `squeezed_dim` channels, an activation, and a 1x1
    convolution back to `in_channels`. A sigmoid turns the result into
    per-channel weights that multiply the original input.

    Args:
        in_channels: number of channels of the input (and of the output).
        squeezed_dim: number of channels in the bottleneck.
        activation: module applied between the two 1x1 convolutions; a fresh
            nn.SiLU is created when None is given.
    """

    def __init__(self, in_channels, squeezed_dim, activation=None) -> None:
        super().__init__()
        # Fall back to SiLU when the caller does not supply an activation
        if activation is None:
            activation = nn.SiLU()

        # Squeeze: collapse every channel's spatial map to one value
        self.average_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        # Bottleneck: shrink, activate, then restore the channel count
        self.conv1 = nn.Conv2d(in_channels, squeezed_dim, kernel_size=1)
        self.activation = activation
        self.conv2 = nn.Conv2d(squeezed_dim, in_channels, kernel_size=1)
        # Gate: maps the restored descriptor to weights in (0, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return x multiplied by its per-channel gating weights."""
        pooled = self.average_pool(x)
        bottleneck = self.activation(self.conv1(pooled))
        gate = self.sigmoid(self.conv2(bottleneck))
        return x * gate


In [None]:
class MBConv(nn.Module):
    """
    Implements a Mobile Inverted Residual Block.

    Pipeline: optional 1x1 expansion -> depthwise convolution -> Squeeze-and-Excitation
    -> 1x1 projection with BatchNorm (no activation after the projection).
    A residual (skip) connection around the whole block is used if stride==1 and
    input/output channels match; in training mode the residual branch is subject to
    stochastic depth.

    Args:
        in_channels: number of channels of the block input.
        out_channels: number of channels of the block output.
        kernel_size, stride, padding: settings of the depthwise convolution.
        expand_ratio: the expanded width is in_channels * expand_ratio; with a ratio
            of 1 the expansion convolution is skipped.
        reduction: the SE bottleneck width is in_channels // reduction (at least 1).
        survival_prob: probability of keeping the residual branch for a sample
            during training.
        bias: whether the convolutions built here carry a bias term.
        activation: activation module shared by the sub-layers; nn.SiLU when None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, expand_ratio,
                 reduction=4, survival_prob=0.8, bias=False, activation=None) -> None:
        super(MBConv, self).__init__()
        # Activation function -- SiLU is the default in EfficientNet
        activation= nn.SiLU() if activation is None else activation
        hidden_dim= in_channels * expand_ratio
        # For squeeze and excitation module. Clamp to at least one channel so that
        # narrow inputs (in_channels < reduction) do not yield an empty bottleneck.
        reduced_dim= max(1, in_channels // reduction)

        self.use_residual= in_channels == out_channels and stride == 1
        # Determine if expansion is needed
        self.expand= in_channels != hidden_dim
        # For stochastic depth
        self.survival_prob= survival_prob

        if self.expand:
            # Optional expansion phase (1x1 conv + BN + Activation)
            self.expand_conv= ConvLayer(
                in_channels, hidden_dim, kernel_size=1, stride=1, padding=0, bias=bias,
                activation=activation
            )

        self.conv= nn.Sequential(
            # Depthwise convolution
            ConvLayer(
                hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim,
                bias=bias, activation=activation
            ),
            # Squeeze-and-Excitation
            SqueezeExcitation(hidden_dim, reduced_dim, activation=activation),
            # Projection phase (1x1 conv + BN) -- reduce channels to out_channels
            nn.Conv2d(hidden_dim, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels),
        )


    def stochastic_depth(self, x):
        """
        Implements stochastic depth regularization.
        During training, randomly drops the output of the block with probability
        (1 - survival_prob), independently for every sample of the batch.
        In evaluation mode the input is returned unchanged.
        """
        if not self.training:
            return x
        # One keep/drop decision per sample, broadcast over channels and spatial dimensions
        binary_tensor= torch.rand(
            x.shape[0], 1, 1, 1, device=x.device, dtype=x.dtype
        ) < self.survival_prob
        # Scale the output to maintain expected value and apply the mask
        return torch.div(x, self.survival_prob) * binary_tensor


    def forward(self, x):
        """
        Run the block on x and return the result.
        The skip connection adds the ORIGINAL block input, not the expanded tensor:
        the expanded tensor has hidden_dim channels and would not match the
        projected output.
        """
        out= self.expand_conv(x) if self.expand else x
        out= self.conv(out)
        if self.use_residual:
            return x + self.stochastic_depth(out)

        return out


# Building the EfficientNet

In [None]:
class EfficientNet(nn.Module):
    """
    Placeholder for the EfficientNet model (work in progress).

    The constructor accepts model_version, num_classes, bias and activation but
    does not build any layers yet, and forward() hands its input back unchanged.
    """

    def __init__(self, model_version, num_classes, bias=False, activation=None) -> None:
        super().__init__()
        # Default to SiLU; nothing consumes the activation yet
        if activation is None:
            activation = nn.SiLU()

    def forward(self, x):
        """Identity for now: return x untouched."""
        return x


# Training an EfficientNet model from scratch

# Trainer Function

# Training setup using TF32 and Fused AdamW

In [None]:
# https://medium.com/@aniketthomas27/efficientnet-implementation-from-scratch-in-pytorch-a-step-by-step-guide-a7bb96f2bdaa
# https://medium.com/technological-singularity/efficientnet-revolutionizing-deep-learning-through-model-efficiency-0ed5485f9a6f