# LIS With MobileViTs

## 1. Our MobileViTs

In [1]:
import lightning as pl
import torch
import torch.nn as nn
import torch.functional as F
import wandb

In [None]:
# Optimize this impl:
# https://github.com/chinhsuanwu/mobilevit-pytorch/blob/master/mobilevit.py


class MobileViTBlock(nn.Module):
    """Placeholder for the MobileViT block; no layers or forward pass yet."""


class Conv2DBlock(nn.Module):
    """2D convolution optionally followed by batch norm and a SiLU activation.

    The layers live in ``self.block`` (an ``nn.Sequential``) under the names
    ``"conv2d"``, ``"norm"`` and ``"activation"``; the latter two are only
    present when the corresponding constructor flag is true.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int,
        groups: int,
        bias: bool,
        norm: bool = True,
        activation: bool = True,
    ):
        """__init__ Constructor for Conv2DBlock

        Parameters
        ----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        kernel_size : int
            Size of the kernel
        stride : int
            Stride of the convolutional layer
        padding : int
            Padding of the convolutional layer
        groups : int
            Number of groups
        bias : bool
            Whether to use bias
        norm : bool, optional
            Whether to append ``nn.BatchNorm2d(out_channels)`` after the
            convolution, by default True
        activation : bool, optional
            Whether to append ``nn.SiLU()`` as the last layer, by default True
        """

        super().__init__()

        self.block = nn.Sequential()

        self.block.add_module(
            "conv2d",
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                bias=bias,
            ),
        )

        if norm:
            self.block.add_module("norm", nn.BatchNorm2d(out_channels))

        if activation:
            self.block.add_module("activation", nn.SiLU())

    def forward(self, x):
        """forward Apply the configured layers to ``x``

        Without this method the module cannot be called (``nn.Module`` raises
        ``NotImplementedError``), including when it is nested inside an
        ``nn.Sequential``.

        Parameters
        ----------
        x : torch.Tensor
            Input accepted by ``nn.Conv2d`` with ``in_channels`` channels

        Returns
        -------
        torch.Tensor
            Output of ``self.block``
        """
        return self.block(x)


class MobileBlockV2(nn.Module):
    """Inverted-residual block: optional 1x1 expansion, 3x3 depthwise
    convolution, then a linear 1x1 projection to ``out_channels``.

    The input is added back to the output only when ``in_channels ==
    out_channels`` and ``stride == 1``.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        """__init__ Constructor for MobileBlockV2


        Parameters
        ----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        stride : int
            Stride of the depthwise convolutional layer (1 or 2)
        expand_ratio : int
            Expansion ratio of the block; the hidden width is
            ``int(round(in_channels * expand_ratio))``. When it equals 1 the
            1x1 expansion layer is omitted.
        """

        super().__init__()

        assert stride in [1, 2], "Stride must be either 1 or 2"

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.hidden_dim = int(round(in_channels * expand_ratio))

        self.mbv2 = nn.Sequential()
        self.uses_inverse_residual = (
            self.in_channels == self.out_channels and self.stride == 1
        )

        if self.expand_ratio != 1:
            # Expansion is only needed when the hidden width differs from the
            # input width; with expand_ratio == 1, hidden_dim == in_channels.
            self.mbv2.add_module(
                "pointwise_1x1",
                Conv2DBlock(  # Pointwise (expansion) Convolution
                    in_channels=in_channels,
                    out_channels=self.hidden_dim,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                    bias=False,
                    norm=True,
                    activation=True,
                ),
            )

        self.mbv2.add_module(
            "depthwise_3x3",
            Conv2DBlock(  # Depthwise Convolution
                in_channels=self.hidden_dim,
                out_channels=self.hidden_dim,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=self.hidden_dim,
                bias=False,
                norm=True,
                activation=True,
            ),
        )

        # The projection must mix channels (groups=1) and map to out_channels
        # in both configurations; otherwise out_channels is silently ignored
        # when expand_ratio == 1.
        self.mbv2.add_module(
            "pointwise-linear_1x1",
            Conv2DBlock(  # Pointwise-Linear Convolution
                in_channels=self.hidden_dim,
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                bias=False,
                norm=True,
                activation=False,
            ),
        )

    def forward(self, x):
        """forward Run the block, adding the skip connection when enabled

        Parameters
        ----------
        x : torch.Tensor
            Input with ``in_channels`` channels

        Returns
        -------
        torch.Tensor
            ``x + self.mbv2(x)`` when ``uses_inverse_residual`` is true,
            otherwise ``self.mbv2(x)``
        """
        if self.uses_inverse_residual:
            return x + self.mbv2(x)
        return self.mbv2(x)


class MobileViT(pl.LightningModule):
    """Lightning module skeleton for the MobileViT model.

    Only the optimizer configuration is implemented; the network layers and
    the training/validation/test steps are still stubs.
    """

    def __init__(self):
        """__init__ Constructor for MobileViT"""
        # The parent constructor sets up the module's internal registries;
        # without it ``self.parameters()`` in ``configure_optimizers`` fails.
        super().__init__()

    def forward(self):
        """forward Not implemented yet; returns None"""
        pass

    def training_step(self, batch, batch_idx):
        """training_step Not implemented yet; returns None"""
        pass

    def validation_step(self, batch, batch_idx):
        """validation_step Not implemented yet; returns None"""
        pass

    def test_step(self, batch, batch_idx):
        """test_step Not implemented yet; returns None"""
        pass

    def configure_optimizers(self):
        """configure_optimizers Build the optimizer for this module

        Returns
        -------
        torch.optim.AdamW
            AdamW over ``self.parameters()`` with ``lr=1e-4`` and AMSGrad
            enabled
        """
        # NOTE(review): the module registers no layers yet, so the parameter
        # list is empty until the model is built.
        return torch.optim.AdamW(self.parameters(), lr=1e-4, amsgrad=True)

In [None]:
# Scratch cell: a 6-layer Transformer encoder stack.
# `norm` has to be a module *instance* that is applied to the final output;
# passing the `nn.LayerNorm` class would call `nn.LayerNorm(output)` on the
# first forward pass instead of normalising the activations.
d_model = 512
encoder = nn.TransformerEncoder(
    encoder_layer=nn.TransformerEncoderLayer(d_model=d_model, nhead=8),
    num_layers=6,
    norm=nn.LayerNorm(d_model),
)
encoder  # bare expression so the notebook cell still displays the module