# LIS With MobileViTs

## 1. Our MobileViTs

In [None]:
!pip install lightning torchmetrics einops av

In [None]:
import lightning as pl
import torch
import torch.nn as nn
import torch.functional as F
import wandb
from einops import rearrange
import torchvision
from torch.utils.data import DataLoader
import os
import gc
import numpy as np
from torchmetrics.text import WordErrorRate
from tqdm import tqdm
import gzip
import pickle
from lightning.pytorch.utilities.model_summary import summarize
from lightning.fabric.utilities import measure_flops

In [None]:
# Optimized adaptation of this implementation:
# https://github.com/chinhsuanwu/mobilevit-pytorch/blob/master/mobilevit.py


class Conv2DBlock(nn.Module):
    """Conv2d followed by optional BatchNorm, SiLU activation and dropout.

    The layers live in an ``nn.Sequential`` called ``block`` under the names
    ``conv2d``, ``norm``, ``activation`` and ``dropout`` (only the enabled
    ones are added), so state-dict keys look like ``block.conv2d.weight``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        groups=1,
        bias=True,
        norm=True,
        activation=True,
        dropout=0.1,
    ):
        """__init__ Constructor for Conv2DBlock

        Parameters
        ----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        kernel_size : int
            Size of the kernel
        stride : int
            Stride of the convolutional layer
        padding : int
            Padding of the convolutional layer
        groups : int
            Number of groups (``groups == in_channels`` gives a depthwise conv)
        bias : bool
            Whether the convolution has a bias term
        norm : bool
            Whether to append ``BatchNorm2d(out_channels)``
        activation : bool
            Whether to append a SiLU activation
        dropout : float
            Dropout rate; a falsy value (e.g. 0) adds no dropout layer
        """

        super(Conv2DBlock, self).__init__()

        self.block = nn.Sequential()

        self.block.add_module(
            "conv2d",
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                bias=bias,
            ),
        )

        if norm:
            self.block.add_module("norm", nn.BatchNorm2d(out_channels))

        if activation:
            # SiLU a.k.a. Swish: x * sigmoid(x)
            self.block.add_module("activation", nn.SiLU())

        if dropout:
            self.block.add_module("dropout", nn.Dropout(dropout))

    def forward(self, x):
        """Apply the configured conv/norm/activation/dropout stack to ``x``."""
        return self.block(x)


class MobileBlockV2(nn.Module):
    """MobileNetV2 inverted-residual block.

    Layout: optional 1x1 expansion (only when ``expand_ratio != 1``), a 3x3
    depthwise conv, and a linear 1x1 projection to ``out_channels``. A skip
    connection is added when the input and output shapes match
    (``in_channels == out_channels`` and ``stride == 1``).
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        """__init__ Constructor for MobileBlockV2


        Parameters
        ----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        stride : int
            Stride of the depthwise convolution (1 or 2)
        expand_ratio : int
            Expansion ratio of the block (hidden width = in_channels * ratio)
        """

        super(MobileBlockV2, self).__init__()

        assert stride in [1, 2], "Stride must be either 1 or 2"

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.hidden_dim = int(round(in_channels * expand_ratio))

        self.mbv2 = nn.Sequential()
        # Skip connection only when the block preserves the tensor shape.
        self.uses_inverse_residual = (
            self.in_channels == self.out_channels and self.stride == 1
        )

        if self.expand_ratio != 1:
            # 1x1 expansion to the hidden width. With expand_ratio == 1 the
            # hidden width already equals in_channels, so it is skipped.
            self.mbv2.add_module(
                "pointwise_1x1",
                Conv2DBlock(
                    in_channels=in_channels,
                    out_channels=self.hidden_dim,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                    bias=False,
                    norm=True,
                    activation=True,
                ),
            )

        self.mbv2.add_module(
            "depthwise_3x3",
            Conv2DBlock(  # Depthwise Convolution
                in_channels=self.hidden_dim,
                out_channels=self.hidden_dim,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=self.hidden_dim,
                bias=False,
                norm=True,
                activation=True,
            ),
        )
        # Linear (no activation) dense 1x1 projection to out_channels. Both
        # branches need it, including expand_ratio == 1.
        self.mbv2.add_module(
            "pointwise-linear_1x1",
            Conv2DBlock(
                in_channels=self.hidden_dim,
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                bias=False,
                norm=True,
                activation=False,
            ),
        )

    def forward(self, x):
        """Run the block, adding the input back when shapes allow it."""
        if self.uses_inverse_residual:
            return x + self.mbv2(x)
        else:
            return self.mbv2(x)


class MobileViTBlock(nn.Module):
    """Spatio-temporal MobileViT block on ``(B, T, C, H, W)`` tensors.

    Frames first pass through a local convolution (kxk then 1x1 to
    ``hidden_dim``). The feature map is then cut into ``ph x pw`` patches. For
    every patch location, a transformer attends over all ``T * ph * pw``
    positions of that location across time. Finally the result is projected
    back to ``channels`` and fused with the block input.
    """

    def __init__(
        self,
        hidden_dim,
        depth,
        channels,
        kernel_size,
        patch_size,
        mlp_dim,
        dropout=0.1,
        heads=8,
    ):
        """__init__ Constructor for MobileViTBlock

        Parameters
        ----------
        hidden_dim : int
            Transformer width (must be divisible by ``heads``)
        depth : int
            Number of transformer encoder layers
        channels : int
            Number of input/output channels of the block
        kernel_size : int
            Kernel size of the local and fusion convolutions (odd sizes keep
            the spatial resolution)
        patch_size : tuple of int
            ``(ph, pw)``; must divide the height/width of the input feature map
        mlp_dim : int
            Feed-forward width of each transformer layer
        dropout : float
            Dropout rate of the transformer layers
        heads : int
            Number of attention heads
        """

        super(MobileViTBlock, self).__init__()

        self.hidden_dim = hidden_dim
        self.depth = depth
        self.channels = channels
        self.kernel_size = kernel_size
        self.patch_size = patch_size
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.heads = heads

        # "Same" padding for odd kernels (equals 1 for the default 3x3).
        padding = kernel_size // 2

        self.local_conv = nn.Sequential(
            Conv2DBlock(
                in_channels=channels,
                out_channels=channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
                norm=True,
                activation=True,
                bias=False,
            ),
            Conv2DBlock(
                in_channels=channels,
                out_channels=hidden_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                norm=True,
                activation=True,
                bias=False,
            ),
        )

        self.global_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=heads,
                dim_feedforward=mlp_dim,
                dropout=dropout,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=depth,
            norm=nn.LayerNorm(hidden_dim),
        )

        self.fusion_conv_preres = Conv2DBlock(
            in_channels=hidden_dim,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            norm=True,
            activation=True,
        )

        # Fuses the projected global features with the block input
        # (concatenated along channels, hence 2 * channels).
        self.fusion_conv_postres = Conv2DBlock(
            in_channels=2 * channels,
            out_channels=channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            bias=False,
            norm=True,
            activation=True,
        )

    def forward(self, x):
        """Map ``(B, T, channels, H, W)`` to a tensor of the same shape.

        Raises
        ------
        ValueError
            If the feature map height/width is not divisible by the patch size.
        """

        B, T = x.shape[:2]
        x = rearrange(x, "b t c h w -> (b t) c h w")
        # No clone needed: no layer below modifies its input in place.
        x_res = x

        # local_repr
        x = self.local_conv(x)

        ph, pw = self.patch_size
        _, _, h, w = x.shape
        if h % ph or w % pw:
            raise ValueError(
                f"Feature map {h}x{w} is not divisible by patch size {ph}x{pw}"
            )

        # global_repr: one sequence of length T * ph * pw per patch location
        x = rearrange(
            x,
            "(b t) d (h ph) (w pw) -> (b h w) (t ph pw) d",
            ph=ph,
            pw=pw,
            b=B,
            t=T,
        )
        x = self.global_transformer(x)
        x = rearrange(
            x,
            "(b h w) (t ph pw) d -> (b t) d (h ph) (w pw)",
            h=h // ph,
            w=w // pw,
            ph=ph,
            pw=pw,
            b=B,
            t=T,
        )

        # fusion
        x = self.fusion_conv_preres(x)
        x = torch.cat([x, x_res], dim=1)
        x = self.fusion_conv_postres(x)

        return rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T)


class MobileViT(pl.LightningModule):
    """MobileViT video model trained with CTC to recognise gloss sequences.

    The input is a ``(B, T, 3, H, W)`` clip. Frames are encoded by a MobileViT
    backbone whose transformer blocks also attend across time. Each frame is
    then classified, and the scores are returned time-major for
    ``nn.CTCLoss`` (class 0 is the CTC blank).

    Parameters
    ----------
    dims : sequence of int
        Transformer width of the three MobileViT blocks.
    conv_channels : sequence of int
        Channel widths. Indices 0-9 feed the stem/MobileNetV2/MobileViT
        stages; the last two entries are the input/output widths of the final
        1x1 convolution (the configs in this notebook use 11 entries).
    num_classes : int
        Number of output classes, including the blank at index 0.
    vocabulary : object
        Must provide ``decode_from_ids(list[int])``; used to compute WER.
    expand_ratio : int
        Expansion ratio of the MobileNetV2 blocks.
    patch_size : tuple of int
        ``(ph, pw)`` patch size of the MobileViT blocks; it must divide the
        feature maps they receive (input size / 8, / 16 and / 32).
    """

    def __init__(
        self,
        dims,
        conv_channels,
        num_classes,
        vocabulary,
        expand_ratio=4,
        patch_size=(2, 2),
    ):
        super(MobileViT, self).__init__()

        self.dims = dims
        self.conv_channels = conv_channels
        self.num_classes = num_classes
        self.expand_ratio = expand_ratio
        self.patch_size = patch_size
        self.kernel_size = 3

        L = [2, 4, 3]  # transformer depth of each MobileViT block

        # Stem: 3x3 stride-2 convolution on the RGB frames.
        self.in_conv = Conv2DBlock(
            in_channels=3,
            out_channels=conv_channels[0],
            kernel_size=self.kernel_size,
            stride=2,
            padding=1,
            norm=True,
            activation=True,
            bias=False,
        )

        # (input index, output index, stride) into conv_channels for each
        # MobileNetV2 block, in execution order. The fourth block repeats the
        # third, so its input is the previous output width conv_channels[3].
        mv2_layout = [
            (0, 1, 1),
            (1, 2, 2),
            (2, 3, 1),
            (3, 3, 1),
            (3, 4, 2),
            (5, 6, 2),
            (7, 8, 2),
        ]
        self.mv2_blocks = nn.ModuleList(
            [
                MobileBlockV2(
                    conv_channels[src],
                    conv_channels[dst],
                    stride,
                    expand_ratio=expand_ratio,
                )
                for src, dst, stride in mv2_layout
            ]
        )

        # (conv_channels index, MLP ratio) of each MobileViT block; block k
        # uses dims[k] as transformer width and L[k] as depth.
        mvit_layout = [(5, 2), (7, 4), (9, 4)]
        self.mvit_blocks = nn.ModuleList(
            [
                MobileViTBlock(
                    dims[k],
                    L[k],
                    conv_channels[ch],
                    self.kernel_size,
                    patch_size,
                    int(dims[k] * mlp_ratio),
                )
                for k, (ch, mlp_ratio) in enumerate(mvit_layout)
            ]
        )

        self.final_pw = Conv2DBlock(
            in_channels=conv_channels[-2],
            out_channels=conv_channels[-1],
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            bias=False,
            norm=True,
            activation=False,
        )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(conv_channels[-1], num_classes, bias=False)

        ## Training-related members
        self.criterion = nn.CTCLoss(blank=0, zero_infinity=True)
        self.wer = WordErrorRate()

        self.vocabulary = vocabulary

        self.apply(self.init_weights)  # Initialize weights

    def init_weights(self, m):
        """Kaiming-init Conv2d/Linear weights, zero biases, reset BatchNorm.

        NOTE: the exact ``type(m) ==`` checks skip subclasses of these layers
        (e.g. the Linear subclass PyTorch's MultiheadAttention uses for its
        output projection), which keep their default initialisation.
        """

        if type(m) == nn.Conv2d:
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif type(m) == nn.BatchNorm2d:
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif type(m) == nn.Linear:
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a ``(B, T, 3, H, W)`` clip to ``(T, B, num_classes)`` logits.

        The output is time-major, as expected by ``nn.CTCLoss``.
        """

        B, T, C, H, W = x.shape
        # Convolutional blocks work on frames folded into the batch dimension.
        x = rearrange(x, "b t c h w -> (b t) c h w")

        x = self.in_conv(x)
        for block in self.mv2_blocks[:5]:
            x = block(x)

        # MobileViT blocks attend across time, so they take (B, T, C, H, W).
        for mvit, mv2 in zip(self.mvit_blocks[:2], self.mv2_blocks[5:]):
            x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T)
            x = mvit(x)
            x = rearrange(x, "b t c h w -> (b t) c h w")
            x = mv2(x)

        x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T)
        x = self.mvit_blocks[2](x)
        x = rearrange(x, "b t c h w -> (b t) c h w")

        x = self.final_pw(x)
        x = self.pool(x).flatten(1)
        x = self.classifier(x)
        return rearrange(x, "(b t) p -> t b p", b=B, t=T)  # time major for CTC

    def step(self, batch, phase):
        """Compute and log the CTC loss (plus WER outside training) for a batch."""
        x, y = batch

        B, T, C, H, W = x.shape
        B, N = y.shape

        # No padding is done (batch size 1), so every sequence is full length.
        input_lengths = torch.full((B,), T, dtype=torch.long)
        target_lengths = torch.full((B,), N, dtype=torch.long)

        logits = self(x)

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        # PhoenixVocabulary already assigns gloss ids 1..V and reserves 0 for
        # the blank, so targets are used unchanged. Shifting them would push
        # the largest id to num_classes (out of range) and misalign decoding.
        loss = self.criterion(log_probs, y, input_lengths, target_lengths)

        self.log(f"{phase}/loss", loss, prog_bar=True)

        if phase != "train":
            word_error_rate = self.calculate_wer(logits, y)
            self.log(f"{phase}/wer", word_error_rate, prog_bar=True)

        return loss

    def calculate_wer(self, logits, y):
        """Word error rate between greedy CTC predictions and the targets."""
        pred_strings = self.string_from_logits(logits)
        target_strings = self.string_from_ground_truth(y)

        pred_strings = [" ".join(pred) for pred in pred_strings]
        target_strings = [" ".join(target) for target in target_strings]

        return self.wer(pred_strings, target_strings)

    def string_from_logits(self, logits):
        """Greedy-decode ``(T, B, C)`` logits into per-sample gloss lists."""
        decoded = self.ctc_decode(logits)

        return self.string_from_ground_truth(decoded)

    def ctc_decode(self, logits, blank=0):
        """Greedy CTC decoding of ``(T, B, C)`` logits into id lists.

        Takes the argmax class per frame, collapses repeats and drops blanks.
        """
        # Move the (B, T) best path to Python ints once instead of syncing
        # with the device for every frame.
        best_paths = torch.argmax(logits, dim=-1).T.cpu().tolist()
        decoded = []
        for path in best_paths:
            result = []
            prev_token = blank
            for token in path:
                if token != prev_token and token != blank:
                    result.append(token)
                prev_token = token
            decoded.append(result)
        return decoded

    def string_from_ground_truth(self, y):
        """Decode each id sequence in ``y`` (tensor rows or lists) to glosses."""
        target_strings = []
        for target in y:
            if isinstance(target, torch.Tensor):
                target = target.tolist()
            target_strings.append(self.vocabulary.decode_from_ids(target))
        return target_strings

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")

    def configure_optimizers(self):
        """AdamW (amsgrad) with cosine annealing warm restarts, stepped per
        optimizer step every 32 steps."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=2e-4, amsgrad=True, weight_decay=1e-4
        )
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=10, T_mult=2, eta_min=1e-6
            ),
            "interval": "step",
            "frequency": 32,
        }

        return [optimizer], [scheduler]

## 2. Our Datamodules

In [None]:
from dataclasses import dataclass
import random


@dataclass
class PhoenixFiles:
    """Paths to the gzip-compressed pickle annotation files of each split."""

    train: str  # training split annotations
    dev: str  # validation (dev) split annotations
    test: str  # test split annotations


@dataclass
class PhoenixDataHyperparameters:
    """Data-loading and frame-sampling settings for the Phoenix datasets."""

    # NOTE(review): not read by the dataset/datamodule code here; the datamodule
    # takes its worker count as a separate constructor argument.
    num_workers: int
    # NOTE(review): not read by the dataset/datamodule code here; the
    # dataloaders hard-code batch_size=1.
    batch_size: int
    sample_random: bool  # keep a random, time-sorted subset of frames
    sample_uniform: bool  # keep evenly spaced frames (np.linspace)
    sample_percentage: float  # fraction of each clip's frames kept when sampling
    # Applied to the stacked frames of a clip, shaped (T, C, H, W).
    transforms: torchvision.transforms.Compose


class PhoenixVocabulary:
    """Word-level gloss vocabulary built from the train/dev/test annotations.

    Glosses get ids 1..V in order of first appearance (train, then dev, then
    test). Id 0 is reserved for the CTC blank token ``"<blank>"``. The order
    is deterministic, so the same files always produce the same ids. This
    keeps model checkpoints decodable across sessions.
    """

    def __init__(self, filenames: PhoenixFiles):
        self.filenames = filenames
        self.vocab = self.build_vocab()
        self.vocab_inversed = {v: k for k, v in self.vocab.items()}

    def build_vocab(self):
        """Return a ``{gloss: id}`` dict over all splits, with ``<blank>`` -> 0."""
        vocab = {}

        files = [self.filenames.train, self.filenames.dev, self.filenames.test]

        for file in files:
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            with gzip.open(file, "rb") as f:
                annotations = pickle.load(f)

            for ann in tqdm(
                annotations,
                desc=f"Extracting tokens from {os.path.basename(file)}",
            ):
                # Ids are arbitrary labels for the classifier, so their order
                # does not affect training; iterating deterministically keeps
                # the mapping stable between runs.
                for word in ann["gloss"].split():
                    if word not in vocab:
                        vocab[word] = len(vocab) + 1

        vocab["<blank>"] = 0

        return vocab

    def __len__(self):
        """Number of entries, including ``<blank>``."""
        return len(self.vocab)

    def __getitem__(self, idx):
        """Return the id of gloss ``idx``."""
        return self.vocab[idx]

    def encode_as_ids(self, sentence):
        """Map a sequence of glosses to ids (KeyError on unknown glosses)."""
        return [self.vocab[word] for word in sentence]

    def decode_from_ids(self, ids):
        """Map a sequence of ids back to glosses."""
        return [self.vocab_inversed[tkn] for tkn in ids]


class Phoenix14TDataset(torch.utils.data.Dataset):
    """One split of PHOENIX-2014-T: decoded video clips paired with gloss ids.

    Call ``setup`` before indexing; it loads the annotation file and encodes
    the glosses. Videos are read from
    ``<directory of path>/videos_phoenix/videos/<name>.mp4``.

    Parameters
    ----------
    path : str
        Gzip-compressed pickle with the split's annotations (records with
        "signer", "name", "gloss" and "text" keys).
    vocabulary : PhoenixVocabulary
        Vocabulary used to encode the gloss strings.
    hyperparameters : PhoenixDataHyperparameters
        Frame-sampling settings and the transform applied to each clip.
    """

    def __init__(self, path: str, vocabulary: PhoenixVocabulary, hyperparameters):
        self.metadata_path = path
        self.base_dir = os.path.dirname(path)
        self.data = None
        self.vocab: PhoenixVocabulary = vocabulary
        self.data_hparams = hyperparameters
        self.transform = hyperparameters.transforms

        # The two sampling modes are mutually exclusive; disabling both keeps
        # every frame.
        assert not (
            self.data_hparams.sample_random and self.data_hparams.sample_uniform
        ), "Both random and uniform sampling cannot be enabled at the same time"

    def setup(self, stage):
        """Load the annotation file and encode every gloss sequence."""
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with gzip.open(self.metadata_path, "rb") as f:
            self.data = pickle.load(f)

        self.signers = [d["signer"] for d in self.data]
        self.video_names = [d["name"] for d in self.data]
        self.annotations = [d["gloss"] for d in self.data]
        self.text = [d["text"] for d in self.data]
        self.targets = [
            self.vocab.encode_as_ids(ann.split())
            for ann in tqdm(
                self.annotations,
                desc=(
                    f"Encoding annotations for {stage}"
                    if stage
                    else "Encoding annotations"
                ),
            )
        ]

    def __len__(self):
        return len(self.data)

    def _frame_indices(self, num_frames):
        """Return the sorted integer frame indices to keep from a clip."""
        hp = self.data_hparams
        if not (hp.sample_random or hp.sample_uniform):
            return list(range(num_frames))

        # Keep at least one frame so short clips never become empty.
        num_kept = min(max(1, int(num_frames * hp.sample_percentage)), num_frames)

        if hp.sample_random:
            return sorted(random.sample(range(num_frames), num_kept))

        # Tensor indexing needs integers, not the floats linspace produces.
        return (
            np.linspace(0, num_frames - 1, num_kept).round().astype(np.int64).tolist()
        )

    def __getitem__(self, idx):
        """Return ``(clip, targets)``: the transformed ``(T, C, H, W)`` frames
        and a long tensor of gloss ids."""

        video_path = os.path.join(
            self.base_dir, "videos_phoenix", "videos", self.video_names[idx] + ".mp4"
        )

        video = torchvision.io.read_video(
            video_path,
            pts_unit="sec",
            output_format="TCHW",
        )[0]

        # The transform is applied to all selected frames at once.
        clip = self.transform(video[self._frame_indices(len(video))])

        return clip, torch.tensor(self.targets[idx], dtype=torch.long)

    def get_metadata(self, idx):
        """Return the raw annotation fields of sample ``idx``."""
        return {
            "signer": self.signers[idx],
            "video_name": self.video_names[idx],
            "annotation": self.annotations[idx],
            "text": self.text[idx],
        }


class Phoenix14TDatamodule(pl.LightningDataModule):
    """Lightning datamodule exposing the PHOENIX-2014-T train/dev/test splits.

    Every dataloader uses ``batch_size=1``. There is no padding collate
    function, and clips differ in length.
    """

    def __init__(
        self,
        files: PhoenixFiles,
        vocabulary: PhoenixVocabulary,
        num_workers: int,
        data_hyperparameters: PhoenixDataHyperparameters,
    ):
        super(Phoenix14TDatamodule, self).__init__()
        self.metadata_paths = files
        self.vocabulary = vocabulary
        self.workers = num_workers
        self.train = None
        self.dev = None
        self.test = None
        self.data_hparams = data_hyperparameters

        # The two sampling modes are mutually exclusive; disabling both keeps
        # every frame.
        assert not (
            self.data_hparams.sample_random and self.data_hparams.sample_uniform
        ), "Both random and uniform sampling cannot be enabled at the same time"

    def _build_split(self, path, stage):
        """Create a dataset for one annotation file and run its setup."""
        dataset = Phoenix14TDataset(path, self.vocabulary, self.data_hparams)
        dataset.setup(stage)
        return dataset

    def setup(self, stage):
        """(Re)build all three splits, regardless of ``stage``."""
        self.train = self._build_split(self.metadata_paths.train, "train")
        self.dev = self._build_split(self.metadata_paths.dev, "dev")
        self.test = self._build_split(self.metadata_paths.test, "test")

    def _loader(self, dataset, shuffle):
        """Single-sample DataLoader shared by all splits."""
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=1,
            pin_memory=True,
            num_workers=self.workers,
        )

    def train_dataloader(self):
        return self._loader(self.train, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.dev, shuffle=False)

    def test_dataloader(self):
        return self._loader(self.test, shuffle=False)
        


def animate_torch_tensor(video):
    
    # Adapted from https://stackoverflow.com/a/57275596
    
    %matplotlib inline
    from matplotlib import pyplot as plt
    from matplotlib import animation
    from IPython.display import HTML

    
    video_np = video.numpy().transpose(0, 2, 3, 1)

    fig = plt.figure()
    im = plt.imshow(video_np[0,:,:,:])

    plt.close() # this is required to not display the generated image

    def init():
        im.set_data(video_np[0,:,:,:])

    def animate(i):
        im.set_data(video_np[i,:,:,:])
        return im

    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video_np.shape[0],
                                interval=1000//6)
    return HTML(anim.to_html5_video())

In [None]:
def xxs_mvit(vocab):
    """Build the MobileViT-XXS variant sized for ``vocab`` (one class per entry).

    Uses 2x2 patches, which divide the stage feature maps of 192x256 frames.
    """
    config = dict(
        dims=[64, 80, 96],
        conv_channels=[16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320],
        expand_ratio=2,
        patch_size=(2, 2),
    )
    return MobileViT(num_classes=len(vocab), vocabulary=vocab, **config)


def xs_mvit(vocab):
    """Build the MobileViT-XS variant sized for ``vocab`` (one class per entry).

    NOTE(review): the MobileViT blocks see feature maps of input/8, /16 and
    /32. With 192x256 frames the last one is 6x8, which (4, 4) patches do
    not divide, so this variant needs e.g. 256x256 input.
    """
    config = dict(
        dims=[96, 120, 144],
        conv_channels=[16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
        expand_ratio=4,
        patch_size=(4, 4),
    )
    return MobileViT(num_classes=len(vocab), vocabulary=vocab, **config)


def s_mvit(vocab):
    """Build the MobileViT-S variant sized for ``vocab`` (one class per entry).

    NOTE(review): the MobileViT blocks see feature maps of input/8, /16 and
    /32. With 192x256 frames the second one is 12x16, which (8, 8) patches
    do not divide. Each transformer sequence is also T * 64 tokens long.
    """
    config = dict(
        dims=[144, 192, 240],
        conv_channels=[16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640],
        expand_ratio=4,
        patch_size=(8, 8),
    )
    return MobileViT(num_classes=len(vocab), vocabulary=vocab, **config)

In [None]:
import wandb

# Log in to Weights & Biases using the API key stored in the Kaggle secret
# named "wandb_api" (the training cell below logs to wandb).
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_api")


wandb.login(key=secret_value_0)

In [None]:
files = PhoenixFiles(
    train="/kaggle/input/phoenixweather2014t-3rd-attempt/phoenix14t.pami0.train.annotations_only.gzip",
    dev="/kaggle/input/phoenixweather2014t-3rd-attempt/phoenix14t.pami0.dev.annotations_only.gzip",
    test="/kaggle/input/phoenixweather2014t-3rd-attempt/phoenix14t.pami0.test.annotations_only.gzip",
)

transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(
            (192, 256), interpolation=torchvision.transforms.InterpolationMode.BICUBIC
        ),
        # uint8 [0, 255] -> float32 [0, 1] (same as x / 255.0). Unlike a
        # lambda, this is picklable, so it also works with spawn-based
        # DataLoader workers.
        torchvision.transforms.ConvertImageDtype(torch.float32),
    ]
)

data_hparams = PhoenixDataHyperparameters(
    num_workers=3,
    batch_size=1,
    sample_random=False,
    sample_uniform=True,
    sample_percentage=1 / 3,
    transforms=transforms,
)

vocab = PhoenixVocabulary(files)

dm = Phoenix14TDatamodule(files, vocab, data_hparams.num_workers, data_hparams)

In [None]:
# Build and encode all splits now so a clip can be previewed below
# (Lightning's Trainer.fit calls setup again itself).
dm.setup(stage="fit")

## Check out a video

In [None]:
# Preview the first training clip; its gloss ids are not shown.
sample_clip, sample_targets = dm.train[0]
animate_torch_tensor(sample_clip)

In [None]:
# flush cuda cache
torch.cuda.empty_cache()

import lightning.pytorch as lp

# setup training and wandb
wandb.init(project="mobilevit")

BATCH_SIZE = 1  # we don't do padding, so batch size must be 1

# Reuses the wandb run started above.
wandb_logger = lp.loggers.WandbLogger()

model = xxs_mvit(vocab)


trainer = pl.Trainer(
    max_epochs=100,
    accelerator="cuda",
    logger=wandb_logger,
    callbacks=[
        lp.callbacks.ModelCheckpoint(
            monitor="val/loss",
            filename="best_model",
            save_top_k=1,
            mode="min",
        ),
        lp.callbacks.EarlyStopping(
            monitor="val/loss",
            patience=3,
            mode="min",
        ),
        lp.callbacks.LearningRateMonitor(logging_interval="step"),
    ],
    limit_train_batches=0.25,
    accumulate_grad_batches=32,
)

wandb_logger.watch(model, log_graph=False)

try:
    trainer.fit(model, datamodule=dm)
finally:
    # Close the wandb run even if training fails, so the next run starts clean.
    wandb.finish()