# Install additional dependencies

In [None]:
!pip install torchtyping

# Imports

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchtyping import TensorType

from torchvision.datasets import ImageFolder
from torchvision.transforms import v2

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

In [None]:
import os
from datetime import date

In [None]:
def ddp_setup(rank: int, world_size: int):
    """
    Prepares the calling process for distributed training.

    Args:
        rank (int): Index of this process; also used as the CUDA device index.
        world_size (int): Total number of processes in the group.
    """

    # Rendezvous address and port, exported before the process group is created.
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})

    torch.cuda.set_device(rank)
    init_process_group(world_size=world_size, rank=rank, backend="nccl")

# Vision Transformer Model

In [None]:
class PatchAndPositionEmbedding(nn.Module):
    """
    Combines patch and position embeddings for Vision Transformer (ViT) models.

    The image is cut into non-overlapping `patch_size` x `patch_size` patches, every patch is
    flattened and linearly projected to `embed_dim`, a learnable class token is prepended and
    learnable position embeddings are added.

    Args:
        image_size (int, optional): Input image size (square). Default is 224.
        patch_size (int, optional): Patch size. Default is 16.
        embed_dim (int, optional): Embedding dimension. Default is 192.
        n_channels (int, optional): Number of input channels (e.g., 3 for RGB). Default is 3.

    Attributes:
        n_patches (int): Number of patches in the image.
        linear_projection (nn.Linear): Projects patches to embedding dimension.
        cls_token (nn.Parameter): Learnable class token, shape (1, embed_dim).
        pos_embedding (nn.Parameter): Learnable position embeddings, shape (n_patches + 1, embed_dim).
        unfold (nn.Unfold): Extracts image patches.
    """

    def __init__(self, image_size: int = 224, patch_size: int = 16, embed_dim: int = 192, n_channels: int = 3) -> None:
        """
        Initializes the PatchAndPositionEmbedding module.

        Args:
            image_size (int, optional): Size of the input image (assumed square). Default is 224.
            patch_size (int, optional): Size of each patch to be extracted from the image. Default is 16.
            embed_dim (int, optional): Dimensionality of the output embeddings. Default is 192.
            n_channels (int, optional): Number of input channels (e.g., 3 for RGB). Default is 3.

        Raises:
            AssertionError: If `image_size` is not divisible by `patch_size`.
        """

        super().__init__()

        assert image_size % patch_size == 0, f"image_size of {image_size} is not divisible by patch_size of {patch_size}"

        self.n_patches = image_size * image_size // patch_size ** 2
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.n_channels = n_channels

        # each flattened patch holds patch_size * patch_size pixels for every channel
        self.linear_projection = nn.Linear(self.patch_size ** 2 * self.n_channels, self.embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, self.embed_dim))
        # one position per patch plus one for the class token
        self.pos_embedding = nn.Parameter(torch.zeros(self.n_patches + 1, self.embed_dim))
        # stride == kernel_size, so the extracted patches do not overlap
        self.unfold = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x: TensorType[torch.float32]) -> TensorType[torch.float32]:
        """
        Forward pass that computes patch embeddings, prepends the class token, and adds positional embeddings.

        Args:
            x (TensorType[torch.float32]): Batch of images of shape (batch_size, C, H, W), or a single
                              image of shape (C, H, W), which is treated as a batch of one.

        Returns:
            TensorType[torch.float32]: Tensor of shape (batch_size, n_patches + 1, embed_dim), where `n_patches + 1` includes
                          the class token, and `embed_dim` is the embedding dimension.
        """

        # a single image is promoted to a batch of one so the rest of the pass is uniform
        if x.dim() == 3:
            x = x.unsqueeze(0)

        # patch embedding: unfold yields (batch, C * P * P, n_patches); move patches to dim 1
        x = self.unfold(x).transpose(1, 2)
        x = self.linear_projection(x)

        # prepending class token: (1, embed_dim) -> (batch, 1, embed_dim)
        batch_size = x.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # positional embedding: (n_patches + 1, embed_dim) broadcasts over the batch dimension
        return x + self.pos_embedding

In [None]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) classifier.

    Images are converted into a token sequence by `PatchAndPositionEmbedding`, passed through a
    stack of transformer encoder layers, and the encoded class token is mapped to class logits.

    Args:
        image_size (int, optional): Input image size (square). Default is 224.
        patch_size (int, optional): Size of image patches. Default is 16.
        embed_dim (int, optional): Dimensionality of patch embeddings. Default is 192.
        n_layers (int, optional): Number of transformer encoder layers. Default is 12.
        n_heads (int, optional): Number of attention heads. Default is 4.
        mlp_size (int, optional): Feed-forward size inside each encoder layer. Default is 768.
        n_classes (int, optional): Number of output classes. Default is 10.
        n_channels (int, optional): Number of input channels (e.g., 3 for RGB). Default is 3.
        dropout (float, optional): Dropout rate. Default is 0.1.
        batch_first (bool, optional): If True, the encoder receives (batch, tokens, embed_dim)
            tensors; otherwise the tokens are transposed to (tokens, batch, embed_dim) first.
            Default is False.

    Attributes:
        embedding (PatchAndPositionEmbedding): Patch and position embedding layer.
        transformer_encoder (nn.TransformerEncoder): Transformer encoder with multiple layers.
        MLP (nn.Sequential): Classification head: layer normalization followed by a linear layer.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        embed_dim: int = 192,
        n_layers: int = 12,
        n_heads: int = 4,
        mlp_size: int = 768,
        n_classes: int = 10,
        n_channels: int = 3,
        dropout: float = 0.1,
        batch_first: bool = False
    ) -> None:
        """
        Builds the embedding layer, the encoder stack and the classification head.

        Args:
            image_size (int, optional): Size of the input image. Default is 224.
            patch_size (int, optional): Size of image patches. Default is 16.
            embed_dim (int, optional): Dimensionality of patch embeddings. Default is 192.
            n_layers (int, optional): Number of transformer encoder layers. Default is 12.
            n_heads (int, optional): Number of attention heads. Default is 4.
            mlp_size (int, optional): Feed-forward size inside each encoder layer. Default is 768.
            n_classes (int, optional): Number of output classes. Default is 10.
            n_channels (int, optional): Number of input channels. Default is 3.
            dropout (float, optional): Dropout rate. Default is 0.1.
            batch_first (bool, optional): If True, batch dimension comes first. Default is False.
        """

        super().__init__()

        self.batch_first = batch_first

        self.embedding = PatchAndPositionEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            n_channels=n_channels,
        )

        # a single layer serves as the template for the whole encoder stack
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=n_heads,
                dim_feedforward=mlp_size,
                dropout=dropout,
                activation='gelu',
                batch_first=batch_first,
            ),
            num_layers=n_layers,
        )

        self.MLP = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, n_classes))

        self._init_weights()

    def _init_weights(self):
        """
        Re-initializes the embedding layer: truncated-normal class token and position
        embeddings, Xavier-normal projection weights and a zero projection bias.
        """

        patch_embedding = self.embedding

        for parameter in (patch_embedding.cls_token, patch_embedding.pos_embedding):
            nn.init.trunc_normal_(parameter, std=0.02)

        projection = patch_embedding.linear_projection
        nn.init.xavier_normal_(projection.weight)
        nn.init.zeros_(projection.bias)

    def forward(self, x: TensorType[torch.float32]) -> TensorType[torch.float32]:
        """
        Computes class logits for a batch of images.

        Args:
            x (TensorType[torch.float32]): Batch of images of shape (batch_size, C, H, W), where C is
                              the number of channels, and H, W are the height and width of the image.

        Returns:
            TensorType[torch.float32]: Output logits of shape (batch_size, n_classes), representing the class scores.
        """

        tokens = self.embedding(x)

        if self.batch_first:
            encoded = self.transformer_encoder(tokens)
            # class token sits at sequence position 0 of every sample
            class_representation = encoded[:, 0, :]
        else:
            # the encoder expects (tokens, batch, embed_dim) in this mode
            encoded = self.transformer_encoder(tokens.transpose(0, 1))
            class_representation = encoded[0]

        return self.MLP(class_representation)

In [None]:
class Trainer():
    """
    Trainer class for training and testing a model using distributed data parallel (DDP).

    Args:
        model (nn.Module): The model to be trained.
        dataloaders (tuple[DataLoader, DataLoader]): Tuple containing the training and testing dataloaders in this order: (train_dataloader, test_dataloader).
        optimizer (optim.Optimizer): Optimizer for updating the model parameters.
        rank (int): Rank of the current process in DDP (GPU ID).

    Attributes:
        rank (int): GPU rank for DDP.
        _model (DDP): Distributed Data Parallel wrapped model.
        _train_dataloader (DataLoader): Dataloader for training data.
        _test_dataloader (DataLoader): Dataloader for testing data.
        _optimizer (optim.Optimizer): Optimizer for the model.
        _loss_fn (nn.CrossEntropyLoss): Loss function for classification tasks.
        _softmax (nn.Softmax): Softmax function to compute probabilities.
    """
    
    def __init__(
        self,
        model: nn.Module,
        dataloaders: tuple[DataLoader, DataLoader],
        optimizer: optim.Optimizer,
        rank: int
    ) -> None:
        """
        Initializes the Trainer with a model, dataloaders, optimizer, and rank.

        Args:
            model (nn.Module): The neural network model to be trained.
            dataloaders (tuple[DataLoader, DataLoader]): Tuple containing the training and testing dataloaders in this order: (train_dataloader, test_dataloader).
            optimizer (optim.Optimizer): Optimizer used for training the model.
            rank (int): Rank of the GPU device for DDP.
        """
        
        self.rank = rank
        self._model = nn.DataParallel(model.to(self.rank))
        self._train_dataloader, self._test_dataloader = dataloaders
        self._optimizer = optimizer
        self._loss_fn = nn.CrossEntropyLoss()
        self._softmax = nn.Softmax(dim=1)
        
    def _train_epoch(self) -> None:
        """
        Performs a single training epoch. Computes loss, backpropagates, and updates model weights.
        Prints average loss every 10 batches.
        """
        
        avg_loss = 0.0

        self._model.train()
        for batch_index, (X, y) in enumerate(self._train_dataloader, 1):
            X, y = X.to(self.rank), y.to(self.rank)

            prediction = self._model(X)
            loss = self._loss_fn(prediction, y)

            avg_loss += loss.item()

            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            if batch_index % 10 == 0:
                loss = loss.item()
                current_sample = batch_index * len(X)

                print(
                    f'[GPU-{self.rank}]Current loss: {loss:.5f}, ' \
                    f'Average loss across 10 batches: {avg_loss / 10:.5f} ' \
                    f'[{current_sample} / {len(self._train_dataloader.dataset)}]'
                )
                avg_loss = 0.0
                
    def _test_epoch(self) -> None:
        """
        Evaluates the model on the test dataset, computing the average loss and accuracy.
        """
        
        self._model.eval()
        
        avg_test_loss = 0.0
        n_correct = 0

        for X, y in self._test_dataloader:
            X, y = X.to(self.rank), y.to(self.rank)

            with torch.no_grad():
                prediction = self._model(X)

                loss = self._loss_fn(prediction, y)
                avg_test_loss += loss.item()

                prediction = self._softmax(prediction)
                n_correct += (torch.argmax(prediction, dim=1) == y).sum().item()

        print(
            f'Average test loss: {avg_test_loss / len(self._test_dataloader.dataset):.5f}, ' \
            f'Accuracy: {n_correct / len(self._test_dataloader.dataset) * 100:.2f}%'
        )
        
    def train(self, n_epochs: int, path: str | None) -> None:
        """
        Trains the model for a specified number of epochs and optionally saves the model.

        Args:
            n_epochs (int): Number of epochs to train the model.
            path (str | None): If provided, saves the model to the given path after training 
                               (only on rank 0). If `None`, the model is not saved.
        """
        
        for epoch in range(n_epochs):
            print(f"|--------------------------{epoch + 1}/{N_EPOCHS}--------------------------|")
            self._train_epoch()
            self._test_epoch()
            
        if self.rank == 0 and path:
            self.save_model(path)
            
            
    def save_model(self, path: str) -> None:
        """
        Saves the model's state dictionary to a specified path.

        Args:
            path (str): Path where the model will be saved.
        """
        
        state = self._model.module.state_dict()
        torch.save(state, path)
        print(f'Model saved in: {path}')

# Preparing the data

In [None]:
# Training pipeline: resize, random vertical flip (augmentation), then conversion
# to a float32 tensor with value scaling enabled.
train_transforms = v2.Compose(
    [
        v2.Resize(size=(224, 224)),
        v2.RandomVerticalFlip(p=0.5),
        v2.PILToTensor(),
        v2.ToDtype(dtype=torch.float32, scale=True),
    ]
)

# Test pipeline: same as above but without the random flip.
test_transforms = v2.Compose(
    [
        v2.Resize(size=(224, 224)),
        v2.PILToTensor(),
        v2.ToDtype(dtype=torch.float32, scale=True),
    ]
)

In [None]:
# Training images from the Kaggle input directory, passed through the training transforms.
train_dataset = ImageFolder(
    root='/kaggle/input/packed-fruits-and-vegetables-recognition-benchmark/train/train',
    transform=train_transforms
)
# Test images from the Kaggle input directory, passed through the test transforms.
test_dataset = ImageFolder(
    root='/kaggle/input/packed-fruits-and-vegetables-recognition-benchmark/test/test',
    transform=test_transforms
)

In [None]:
def get_dataloaders(batch_size: int = 256, num_workers: int = 4) -> tuple[DataLoader, DataLoader]:
    """
    Builds the training and testing dataloaders over the module-level datasets.

    Args:
        batch_size (int, optional): Number of samples per batch for both loaders. Default is 256.
        num_workers (int, optional): Number of worker processes for each loader. Default is 4.

    Returns:
        tuple[DataLoader, DataLoader]: Dataloaders in this order: (train_dataloader, test_dataloader).
        Only the training dataloader shuffles its data.
    """

    # settings shared by both loaders
    common = dict(batch_size=batch_size, pin_memory=True, num_workers=num_workers)

    return (
        DataLoader(train_dataset, shuffle=True, **common),
        DataLoader(test_dataset, shuffle=False, **common),
    )

In [None]:
# Take one (image, label) sample straight from the dataset for a quick visual check.
X, y = test_dataset[0]

In [None]:
# ImageFolder resolved the class names when it was constructed; reuse that list
# instead of scanning the dataset directory a second time. Position i holds the
# name of class index i.
idx_to_class = train_dataset.classes
idx_to_class

In [None]:
# Show the sample image (dimensions reordered with permute(1, 2, 0) for imshow;
# presumably (C, H, W) -> (H, W, C)) with its class name as the title.
plt.imshow(X.permute(1, 2, 0))
plt.title(idx_to_class[y])
plt.show()

In [None]:
# Display the shape of the sample image tensor.
X.shape

# Training the ViT

In [None]:
# One output logit per class found in the training dataset; the remaining
# hyperparameters keep the VisionTransformer defaults.
ViT = VisionTransformer(batch_first=True, n_classes=len(idx_to_class))

In [None]:
# Training hyperparameters.
N_EPOCHS = 10
LEARNING_RATE = 8e-5

# AdamW over all parameters of the ViT model.
optimizer = optim.AdamW(ViT.parameters(), lr=LEARNING_RATE)

In [None]:
# Integer GPU index, matching Trainer's `rank: int` contract (rank 0 is the one
# that saves the model after training).
rank = 0
trainer = Trainer(ViT, get_dataloaders(), optimizer, rank)
trainer.train(n_epochs=N_EPOCHS, path=f'ViT-1-{date.today().isoformat()}.pth')

# TODO:
- **DONE** replace `nn.DataParallel` with `nn.parallel.DistributedDataParallel` — **DONE** reverted this change afterwards, as multiprocessing does not work in Jupyter
- **DONE** use `X, y = test_dataset[0]` instead of `X, y = next(iter(test_dataloader))` to free some RAM
- **DONE** maybe change the learning rate to `8e-5` to see if model converges faster
- add more markdown cells explaining each step better

# Models' parameters:
 - ViT-1: 
     - image_size: 224,
     - patch_size: 16,
     - embed_dim: 192,
     - n_layers: 12,
     - n_heads: 4,
     - mlp_size: 768,
     - n_classes: 65,
     - n_channels: 3,
     - dropout: 0.1,
     - batch_first: True