# Training vit_small_patch16_224 from scratch (no pre-trained weights) on the Food-101 dataset with the default attention block #
## Author - Thomas O'Sullivan ##

### This notebook reduces the model to 10 of its 12 transformer (attention) blocks and uses a drop rate of 0.3, a batch size of 32, and a weight decay of 1e-4. Training starts from a base learning rate of 1e-5 and follows the LR schedule defined in cell 5 (flat warmup, linear ramp-up to a peak of 1e-4, then cosine decay). ###

### This cell imports libraries for deep learning, data handling, and visualization, and sets random seeds for reproducibility. ###

In [1]:
import time
import copy
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import random


# Reproducibility: seed every RNG this notebook draws from (PyTorch, Python's
# `random`, NumPy) with the same value, in that order.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(31)
del _seed_rng

### This cell defines image transformations for training and validation. Training data is augmented with random resized cropping, horizontal flipping, and color jittering, while validation data is resized and center-cropped. Both are converted to tensors and normalized using ImageNet statistics. ###

In [2]:
# ImageNet per-channel mean/std; one Normalize instance is shared by both
# pipelines so the statistics are defined in a single place.
_imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])

# Training pipeline: random resized crop to 224 (scale 0.8-1.0), random
# horizontal flip and mild colour jitter, then tensor conversion + normalisation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    _imagenet_normalize,
])

# Evaluation pipeline: deterministic resize to 256 followed by a 224 centre crop.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _imagenet_normalize,
])


### This cell loads the Food-101 dataset with the defined transforms and creates data loaders for training and validation; note that the official Food-101 test split is used as the validation set. It sets a batch size of 32 for both loaders and enables shuffling only for the training loader. It also prints the number of samples in each set. ###

In [3]:
# Both splits share the download location and flag; only the split name and
# the transform differ.
_food101_common = {"root": './data', "download": True}

train_dataset = torchvision.datasets.Food101(
    split='train', transform=train_transform, **_food101_common
)
# NOTE: the official Food-101 *test* split doubles as the validation set.
val_dataset = torchvision.datasets.Food101(
    split='test', transform=val_transform, **_food101_common
)

# Batch size 32 for both loaders; only the training loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)  ####### Expiremental #######
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)  ####### Expiremental #######

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples:   {len(val_dataset)}")


### This cell defines a custom Vision Transformer model using only the first 10 transformer blocks of vit_small_patch16_224. It builds the model without pretrained weights, sets the dropout and drop-path rates, and defines a custom forward pass. The model is then moved to the GPU if one is available and printed. ###

In [4]:
from timm import create_model
import torch.nn as nn

class ViTLayerReduction(nn.Module):
    """timm ``vit_small_patch16_224`` truncated to its first ``num_blocks`` blocks.

    A full model is built from scratch (``pretrained=False``); its patch
    embedding, class token, positional embedding, first ``num_blocks``
    transformer blocks, final norm and classification head are reused and
    the remaining blocks are discarded.

    Args:
        num_blocks: Number of leading transformer blocks to keep.
        num_classes: Output size of the classifier (101 for Food-101).
        drop_rate: Forwarded to ``timm.create_model``.
        drop_path_rate: Stochastic-depth rate forwarded to ``timm.create_model``.

    Raises:
        ValueError: If ``num_blocks`` is not between 1 and the number of
            blocks in the full model.

    NOTE(review): in recent timm releases (>= 0.9) ``drop_rate`` only sets
    ``head_drop`` (dropout right before the classifier); older releases fed it
    into per-block dropout instead. The forward pass below mirrors timm's own
    (including ``head_drop``) so ``drop_rate`` actually takes effect — confirm
    against the installed timm version. timm also spreads ``drop_path_rate``
    linearly over all 12 blocks before truncation, so the kept blocks only
    see the lower part of that ramp.
    """

    def __init__(self, num_blocks=10, num_classes=101, drop_rate=0.3, drop_path_rate=0.1):
        super().__init__()
        full_model = create_model(
            "vit_small_patch16_224",
            pretrained=False,
            num_classes=num_classes,  ####### Expiremental #######
            drop_rate=drop_rate,  ####### Expiremental #######
            drop_path_rate=drop_path_rate ####### Expiremental #######
        )
        total_blocks = len(full_model.blocks)
        if not 1 <= num_blocks <= total_blocks:
            raise ValueError(
                f"num_blocks must be in [1, {total_blocks}], got {num_blocks}"
            )

        self.patch_embed = full_model.patch_embed
        self.cls_token = full_model.cls_token
        self.pos_embed = full_model.pos_embed
        self.pos_drop = full_model.pos_drop
        # Stages timm applies between the embeddings and the blocks. They may
        # be missing in older timm versions, hence the Identity fallback;
        # presumably Identity for this config as well — confirm.
        self.patch_drop = getattr(full_model, "patch_drop", nn.Identity())
        self.norm_pre = getattr(full_model, "norm_pre", nn.Identity())

        self.blocks = nn.Sequential(*list(full_model.blocks[:num_blocks]))  ####### Expiremental #######

        self.norm = full_model.norm
        # Head-side stages of timm's forward; parameter-free for the default
        # token pooling, so state_dict keys are unchanged.
        self.fc_norm = getattr(full_model, "fc_norm", nn.Identity())
        self.head_drop = getattr(full_model, "head_drop", nn.Identity())
        self.head = full_model.head

    def forward(self, x):
        """Return classification logits for an image batch ``x``.

        ``x`` is presumably (B, 3, 224, 224) as produced by this notebook's
        transforms — TODO confirm for other inputs. Returns the output of
        ``self.head`` applied to the (normalised) class token, one row per
        sample.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # one class token per sample
        x = torch.cat((cls_tokens, x), dim=1)          # prepend class token
        x = x + self.pos_embed
        x = self.pos_drop(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        x = self.norm(x)
        x = self.fc_norm(x[:, 0])   # class-token pooling
        x = self.head_drop(x)       # the dropout configured by drop_rate
        return self.head(x)


# Pick the GPU when present, build the truncated ViT directly on that device,
# then show the chosen device and the module tree.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTLayerReduction().to(device)
print(device)
print(model)



### This cell defines the learning rate schedule: a flat warmup at the base learning rate, a linear ramp-up to the peak learning rate, and a cosine decay toward a fraction of the peak. It also sets hyperparameters such as the learning rates, epoch counts, and weight decay. The compute_scheduled_lr function returns the learning rate for a given epoch. ###

In [5]:
# --- Learning-rate schedule hyperparameters ---
base_learning_rate = 1e-5 ####### Expiremental #######
peak_learning_rate = 1e-4 ####### Expiremental #######
final_lr_fraction = 0.10 ####### Expiremental #######  (decay floor, as a fraction of the peak)

num_epochs = 60 ####### Expiremental #######
warmup_epochs = 15 ####### Expiremental #######
rampup_epochs = 15 ####### Expiremental #######
# Whatever is left after warmup and ramp-up is spent in cosine decay.
decay_epochs = num_epochs - (warmup_epochs + rampup_epochs)

weight_decay = 1e-4 ####### Expiremental #######

# One entry per epoch, appended by the training loop.
lr_history = []

####### Expiremental #######
def compute_scheduled_lr(epoch_step):
    """Return the learning rate for the 0-based epoch ``epoch_step``.

    Phases:
      1. flat warmup   -- ``base_learning_rate`` while ``epoch_step < warmup_epochs``;
      2. linear rampup -- base -> peak over the next ``rampup_epochs`` epochs;
      3. cosine decay  -- peak -> ``final_lr_fraction * peak_learning_rate``,
         reaching the floor only at ``epoch_step == num_epochs`` (so the last
         training epoch stays slightly above it).
    """
    if epoch_step < warmup_epochs:
        return base_learning_rate

    if epoch_step < warmup_epochs + rampup_epochs:
        ramp_frac = (epoch_step - warmup_epochs) / rampup_epochs
        return base_learning_rate + ramp_frac * (peak_learning_rate - base_learning_rate)

    # max(1, ...) keeps the division safe if no epochs are left for decay.
    decay_frac = (epoch_step - warmup_epochs - rampup_epochs) / max(1, decay_epochs)
    cosine_weight = 0.5 * (1 + math.cos(math.pi * decay_frac))
    floor_lr = final_lr_fraction * peak_learning_rate
    return floor_lr + (1 - final_lr_fraction) * peak_learning_rate * cosine_weight



### This cell trains the model over multiple epochs using the AdamW optimizer and label-smoothed cross-entropy loss. It logs training/validation loss, accuracy, and learning rate to TensorBoard, while applying the custom learning rate schedule. The model weights with the best validation accuracy are kept in memory and restored at the end. ###

In [6]:
optimizer = optim.AdamW(model.parameters(), lr=base_learning_rate, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/food101_vit_experiment')

# Best validation accuracy seen so far and an in-memory copy of the matching weights.
best_val_acc = 0.0
best_model_wts = None

# NOTE(review): `val_loader` iterates the Food-101 *test* split (split='test'
# in cell 3), so choosing the checkpoint by this accuracy tunes on the test set
# and the reported best accuracy is optimistically biased. A held-out slice of
# the training split would keep the test number unbiased.

try:
    for epoch in range(num_epochs):
        epoch_start = time.time()
        print(f"Epoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        # Apply this epoch's scheduled LR *before* training, so epoch `epoch`
        # trains with compute_scheduled_lr(epoch) and the value logged at step
        # `epoch` (and stored in lr_history) is the LR actually used.
        new_lr = compute_scheduled_lr(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
        lr_history.append(new_lr)
        writer.add_scalar('Learning Rate', new_lr, epoch)

        # ---- Training pass ----
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in tqdm(train_loader, desc="Training", leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns a batch mean; re-weight by batch size so the
            # epoch figure below is a per-sample average.
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels)

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects.double() / len(train_dataset)
        print(f"Train Loss: {epoch_loss:.4f}  Train Acc: {epoch_acc:.4f}")

        # ---- Validation pass (eval mode, no gradient tracking) ----
        model.eval()
        val_running_loss = 0.0
        val_running_corrects = 0

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation", leave=False):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                val_running_corrects += torch.sum(preds == labels)

        val_epoch_loss = val_running_loss / len(val_dataset)
        val_epoch_acc = val_running_corrects.double() / len(val_dataset)
        print(f"Val Loss: {val_epoch_loss:.4f}  Val Acc: {val_epoch_acc:.4f}")

        writer.add_scalar('Loss/Train', epoch_loss, epoch)
        writer.add_scalar('Accuracy/Train', epoch_acc.item(), epoch)
        writer.add_scalar('Loss/Validation', val_epoch_loss, epoch)
        writer.add_scalar('Accuracy/Validation', val_epoch_acc.item(), epoch)

        # deepcopy: state_dict() holds references to the live parameter tensors,
        # which keep changing as training continues.
        if val_epoch_acc > best_val_acc:
            best_val_acc = val_epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        epoch_duration = time.time() - epoch_start
        print(f"Epoch {epoch+1} completed in {epoch_duration:.2f} seconds\n")
finally:
    # Flush and release the TensorBoard writer even if training is interrupted.
    writer.close()

if best_model_wts is not None:
    model.load_state_dict(best_model_wts)
print(f"Best Validation Accuracy: {best_val_acc:.4f}")

