# LIS With MobileViTs

## 1. Our MobileViTs

In [8]:
import lightning as pl
import torch
import torch.nn as nn
import torch.functional as F
import wandb
from einops import rearrange
import torchvision

In [None]:
# Optimize this impl:
# https://github.com/chinhsuanwu/mobilevit-pytorch/blob/master/mobilevit.py


class Conv2DBlock(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        groups=1,
        bias=True,
        norm=True,
        activation=True,
    ):
        """__init__ Constructor for Conv2DBlock

        Parameters
        ----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        kernel_size : int
            Size of the kernel
        stride : int
            Stride of the convolutional layer
        padding : int
            Padding of the convolutional layer
        groups : int
            Number of groups
        bias : bool
            Whether to use bias
        """

        super(Conv2DBlock, self).__init__()

        self.block = nn.Sequential()

        self.block.add_module(
            "conv2d",
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                bias=bias,
            ),
        )

        if norm:
            self.block.add_module("norm", nn.BatchNorm2d(out_channels))

        if activation:
            self.block.add_module("activation", nn.SiLU())

        self.block = torch.compile(self.block)

    def forward(self, x):
        return self.block(x)


class MobileBlockV2(nn.Module):

    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        """__init__ Constructor for MobileBlockV2


        Parameters
        ----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        stride : int
            Stride of the convolutional layer
        expand_ratio : int
            Expansion ratio of the block
        """

        super(MobileBlockV2, self).__init__()

        assert stride in [1, 2], "Stride must be either 1 or 2"

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.hidden_dim = int(round(in_channels * expand_ratio))

        self.mbv2 = nn.Sequential()
        self.uses_inverse_residual = (
            self.in_channels == self.out_channels and self.stride == 1
        )

        if self.expand_ratio == 1:
            self.mbv2.add_module(
                "depthwise_3x3",
                Conv2DBlock(  # Depthwise Convolution
                    in_channels=self.hidden_dim,
                    out_channels=self.hidden_dim,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=self.hidden_dim,
                    bias=False,
                    norm=True,
                    activation=True,
                ),
            )
            self.mbv2.add_module(
                "pointwise-linear_1x1",
                Conv2DBlock(  # Pointwise-Linear Convolution
                    in_channels=self.hidden_dim,
                    out_channels=self.hidden_dim,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=self.hidden_dim,
                    bias=False,
                    norm=True,
                    activation=False,
                ),
            )
        else:
            self.mbv2.add_module(
                "pointwise_1x1",
                Conv2DBlock(  # Pointwise-Linear Convolution
                    in_channels=in_channels,
                    out_channels=self.hidden_dim,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                    bias=False,
                    norm=True,
                    activation=True,
                ),
            )
            self.mbv2.add_module(
                "depthwise_3x3",
                Conv2DBlock(  # Depthwise Convolution
                    in_channels=self.hidden_dim,
                    out_channels=self.hidden_dim,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=self.hidden_dim,
                    bias=False,
                    norm=True,
                    activation=True,
                ),
            )
            self.mbv2.add_module(
                "pointwise-linear_1x1",
                Conv2DBlock(  # Pointwise-Linear Convolution
                    in_channels=self.hidden_dim,
                    out_channels=self.out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                    bias=False,
                    norm=True,
                    activation=False,
                ),
            )

        self.mbv2 = torch.compile(self.mbv2)

    def forward(self, x):
        if self.uses_inverse_residual:
            return x + self.mbv2(x)
        else:
            return self.mbv2(x)


class MobileViTBlock(nn.Module):

    def __init__(
        self, hidden_dim, depth, channels, kernel_size, patch_size, mlp_dim, dropout=0.0
    ):

        super(MobileViTBlock, self).__init__()

        self.hidden_dim = hidden_dim
        self.depth = depth
        self.channels = channels
        self.kernel_size = kernel_size
        self.patch_size = patch_size
        self.mlp_dim = mlp_dim
        self.dropout = dropout

        self.local_conv = torch.compile(
            nn.Sequential(
                Conv2DBlock(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=1,
                    norm=True,
                    activation=True,
                    bias=False,
                ),
                Conv2DBlock(
                    in_channels=channels,
                    out_channels=hidden_dim,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    norm=True,
                    activation=True,
                    bias=False,
                ),
            )
        )

        # def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        self.global_transformer = torch.compile(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=hidden_dim,
                    nhead=8,
                    dim_feedforward=mlp_dim,
                    dropout=dropout,
                    batch_first=True,
                    activation="gelu",
                ),
                num_layers=depth,
                norm=nn.LayerNorm(hidden_dim),
            )
        )

        self.fusion_conv_preres = torch.compile(
            Conv2DBlock(
                in_channels=hidden_dim,
                out_channels=channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                norm=True,
                activation=True,
            )
        )

        self.fusion_conv_postres = torch.compile(
            Conv2DBlock(
                in_channels=2 * channels,
                out_channels=channels,
                kernel_size=kernel_size,
                stride=1,
                padding=1,
                bias=False,
                norm=True,
                activation=True,
            )
        )

    def forward(self, x):

        x_res = x.clone()

        # local_repr
        x = self.local_conv(x)

        ph, pw = self.patch_size

        # global_repr
        _, _, h, w = x.shape
        x = rearrange(  # reshape the image into patches for ViT input
            x,
            "b d (h ph) (w pw) -> (b h w) (ph pw) d",
            ph=ph,
            pw=pw,
        )
        x = self.global_transformer(x)
        x = rearrange(
            x,
            "(b h w) (ph pw) d -> b d (h ph) (w pw)",
            h=h // ph,
            w=w // pw,
            ph=ph,
            pw=pw,
        )

        # fusion
        x = self.fusion_conv_preres(x)
        x = torch.cat([x, x_res], dim=1)
        x = self.fusion_conv_postres(x)

        return x


class MobileViT(pl.LightningModule):

    def __init__(
        self,
        dims,
        conv_channels,
        num_classes,
        expand_ratio=4,
        patch_size=(2, 2),
    ):
        super(MobileViT, self).__init__()

        self.dims = dims
        self.conv_channels = conv_channels
        self.num_classes = num_classes
        self.expand_ratio = expand_ratio
        self.patch_size = patch_size
        self.kernel_size = 3

        L = [2, 4, 3]

        self.in_conv = Conv2DBlock(
            in_channels=3,
            out_channels=conv_channels[0],
            kernel_size=self.kernel_size,
            stride=2,
            padding=1,
            norm=True,
            activation=True,
            bias=False,
        )

        self.mv2_blocks = nn.ModuleList([])

        self.mv2_blocks.append(
            MobileBlockV2(
                conv_channels[0], conv_channels[1], 1, expand_ratio=expand_ratio
            )
        )
        self.mv2_blocks.append(
            MobileBlockV2(
                conv_channels[1], conv_channels[2], 2, expand_ratio=expand_ratio
            )
        )
        self.mv2_blocks.append(
            MobileBlockV2(
                conv_channels[2], conv_channels[3], 1, expand_ratio=expand_ratio
            )
        )
        self.mv2_blocks.append(
            MobileBlockV2(
                conv_channels[2], conv_channels[3], 1, expand_ratio=expand_ratio
            )
        )
        self.mv2_blocks.append(
            MobileBlockV2(
                conv_channels[3], conv_channels[4], 2, expand_ratio=expand_ratio
            )
        )
        self.mv2_blocks.append(
            MobileBlockV2(
                conv_channels[5], conv_channels[6], 2, expand_ratio=expand_ratio
            )
        )
        self.mv2_blocks.append(
            MobileBlockV2(
                conv_channels[7], conv_channels[8], 2, expand_ratio=expand_ratio
            )
        )

        self.mvit_blocks = nn.ModuleList([])

        self.mvit_blocks.append(
            MobileViTBlock(
                dims[0],
                L[0],
                conv_channels[5],
                self.kernel_size,
                patch_size,
                int(dims[0] * 2),
            )
        )
        self.mvit_blocks.append(
            MobileViTBlock(
                dims[1],
                L[1],
                conv_channels[7],
                self.kernel_size,
                patch_size,
                int(dims[1] * 4),
            )
        )
        self.mvit_blocks.append(
            MobileViTBlock(
                dims[2],
                L[2],
                conv_channels[9],
                self.kernel_size,
                patch_size,
                int(dims[2] * 4),
            )
        )

        self.final_pw = Conv2DBlock(
            in_channels=conv_channels[-2],
            out_channels=conv_channels[-1],
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            bias=False,
            norm=True,
            activation=False,
        )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(conv_channels[-1], num_classes, bias=False)

        ## Training-related members
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        self.apply(self.init_weights)  # Initialize weights

    def init_weights(self, m):

        if type(m) == nn.Conv2d:
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif type(m) == nn.BatchNorm2d:
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif type(m) == nn.Linear:
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):

        x = self.in_conv(x)

        for i in range(5):
            x = self.mv2_blocks[i](x)

        x = self.mvit_blocks[0](x)
        x = self.mv2_blocks[5](x)
        x = self.mvit_blocks[1](x)
        x = self.mv2_blocks[6](x)
        x = self.mvit_blocks[2](x)
        x = self.final_pw(x)
        x = self.pool(x)
        x = x.flatten(1)
        x = self.classifier(x)
        return x

    def compute_metrics(self, y_hat, y):
        _, preds = torch.max(y_hat, dim=1)
        acc = (preds == y).float().mean()
        tp = ((preds == y) & (y == 1)).sum().float()
        fp = ((preds != y) & (preds == 1)).sum().float()
        fn = ((preds != y) & (y == 1)).sum().float()

        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)

        return acc, precision, recall, f1

    def log_metrics(self, phase, loss, acc, prec, rec, f1):
        self.log(f"{phase}/loss", loss, prog_bar=True)
        self.log(f"{phase}/accuracy", acc, prog_bar=True)
        self.log(f"{phase}/precision", prec, prog_bar=True)
        self.log(f"{phase}/recall", rec, prog_bar=True)
        self.log(f"{phase}/f1", f1, prog_bar=True)

    def step(self, batch, phase):
        x, y = batch
        y_hat = self(x)

        loss = self.criterion(y_hat, y)
        acc, prec, rec, f1 = self.compute_metrics(y_hat, y)
        self.log_metrics(phase, loss, acc, prec, rec, f1)

        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=2e-4, amsgrad=True, weight_decay=0.01
        )
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=10, T_mult=2, eta_min=1e-6
            ),
            "interval": "step",
        }

        return [optimizer], [scheduler]

In [17]:
from lightning.pytorch.utilities.model_summary import summarize
from lightning.fabric.utilities import measure_flops

In [18]:
def xxs_mvit(classes: int = 10):
    return MobileViT(
        dims=[64, 80, 96],
        conv_channels=[16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320],
        num_classes=classes,
        expand_ratio=2,
        patch_size=(2, 2),
    )


def xs_mvit(classes: int = 10):
    return MobileViT(
        dims=[96, 120, 144],
        conv_channels=[16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
        num_classes=classes,
        expand_ratio=4,
        patch_size=(4, 4),
    )


def s_mvit(classes: int = 10):
    return MobileViT(
        dims=[144, 192, 240],
        conv_channels=[16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640],
        num_classes=classes,
        expand_ratio=4,
        patch_size=(8, 8),
    )

In [19]:
# CIFAR-10 dataset
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
    ]
)

LFWPeople_trainval = torchvision.datasets.LFWPeople(
    root="./data", split="train", download=True, transform=transform
)
LFWPeople_test = torchvision.datasets.LFWPeople(
    root="./data", split="test", download=True, transform=transform
)

from torch.utils.data import DataLoader

train_size = int(0.8 * len(LFWPeople_trainval))
val_size = len(LFWPeople_trainval) - train_size


LFWPeople_train, LFWPeople_val = torch.utils.data.random_split(
    LFWPeople_trainval,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
import wandb

from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_api")

wandb.login(key=secret_value_0)

In [None]:
wandb.init(project="mobilevit")

In [None]:
import lightning.pytorch as lp

# setup training and wandb

N_LABELS = len(LFWPeople_trainval.class_to_idx)
BATCH_SIZE = 4

wandb_logger = lp.loggers.WandbLogger()

model = xs_mvit(N_LABELS)


train_loader = DataLoader(
    LFWPeople_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=3
)
val_loader = DataLoader(
    LFWPeople_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=3
)
test_loader = DataLoader(LFWPeople_test, batch_size=BATCH_SIZE, shuffle=False)

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cuda",
    logger=wandb_logger,
    callbacks=[
        lp.callbacks.ModelCheckpoint(
            monitor="val/loss",
            filename="best_model",
            save_top_k=1,
            mode="min",
        ),
        lp.callbacks.EarlyStopping(
            monitor="val/ loss",
            patience=3,
            mode="min",
        ),
    ],
)

wandb_logger.watch(model, log_graph=False)

trainer.fit(model, train_loader, val_loader)

wandb.finish()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/dario/repos/lis-vit/.venv/lib/python3.11/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type              | Params | Mode 
----------------------------------------------------------
0 | in_conv     | Conv2DBlock       | 464    | train
1 | mv2_blocks  | ModuleList        | 183 K  | train
2 | mvit_blocks | ModuleList        | 2.2 M  | train
3 | final_pw    | Conv2DBlock       | 37.6 K | train
4 | pool        | AdaptiveAvgPool2d | 0      | train
5 | classifier  | Linear            | 2.2 M  | train
6 | criterion   | CrossEntropyLoss  | 0      | train
----------------------------------------------------------
4.6 M    

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

W1203 13:05:05.929000 24231 torch/_dynamo/convert_frame.py:844] [21/8] torch._dynamo hit config.cache_size_limit (8)
W1203 13:05:05.929000 24231 torch/_dynamo/convert_frame.py:844] [21/8]    function: 'inner' (/home/dario/repos/lis-vit/.venv/lib/python3.11/site-packages/torch/_dynamo/external_utils.py:38)
W1203 13:05:05.929000 24231 torch/_dynamo/convert_frame.py:844] [21/8]    last reason: 21/0: len(L['fn']) == 3                                           
W1203 13:05:05.929000 24231 torch/_dynamo/convert_frame.py:844] [21/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1203 13:05:05.929000 24231 torch/_dynamo/convert_frame.py:844] [21/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.


Epoch 0:   0%|          | 0/1905 [00:00<?, ?it/s]                          

W1203 13:05:27.161000 24231 torch/_dynamo/convert_frame.py:844] [22/8] torch._dynamo hit config.cache_size_limit (8)
W1203 13:05:27.161000 24231 torch/_dynamo/convert_frame.py:844] [22/8]    function: 'forward' (/tmp/ipykernel_24231/740721331.py:64)
W1203 13:05:27.161000 24231 torch/_dynamo/convert_frame.py:844] [22/8]    last reason: 22/0: GLOBAL_STATE changed: grad_mode 
W1203 13:05:27.161000 24231 torch/_dynamo/convert_frame.py:844] [22/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1203 13:05:27.161000 24231 torch/_dynamo/convert_frame.py:844] [22/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.


Epoch 0: 100%|██████████| 1905/1905 [05:29<00:00,  5.79it/s, v_num=52k4, train/loss=8.370, train/accuracy=0.000, train/precision=0.000, train/recall=0.000, train/f1=0.000, val/loss=8.380, val/accuracy=0.0373, val/precision=0.000, val/recall=0.000, val/f1=0.000]

RuntimeError: Early stopping conditioned on metric `val_loss` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: `train/loss`, `train/accuracy`, `train/precision`, `train/recall`, `train/f1`, `val/loss`, `val/accuracy`, `val/precision`, `val/recall`, `val/f1`

In [None]:
# flush cuda cache
torch.cuda.empty_cache()


wandb.finish()