# Training vit_small_patch16_224 from scratch (no pre-trained weights) on the Food-101 dataset with the multi-key attention block #
## Author - Thomas O'Sullivan ##

### This notebook keeps 10 of the 12 attention blocks and uses a drop rate of 0.3, a batch size of 32, and a weight decay of 1e-4. Training starts at a base learning rate of 1e-5, which then follows the LR schedule defined in cell 7. ###

### This cell imports libraries for deep learning, data handling, and visualization, and sets random seeds for reproducibility. ###

In [1]:
import time
import copy
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import random

# Seed the PyTorch, Python `random` and NumPy generators with one shared value
# so runs are repeatable.
SEED = 31
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)

### This cell defines image transformations for training and validation. Training data is augmented with cropping, flipping, and color jittering, while validation data is resized and center cropped. Both are converted to tensors and normalized using ImageNet statistics. ###

In [2]:
# ImageNet per-channel statistics, shared by both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop covering 80-100% of the image area (resized to
# 224), random horizontal flip and colour jitter, then tensor conversion and
# normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Evaluation pipeline: deterministic resize to 256 then a 224 centre crop.
val_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

### This cell defines functions that build spatial attention masks for the Vision Transformer's patch tokens. It computes the centre position of every patch, builds four directional masks (left, right, up, down) plus an identity mask, and pads each mask with a leading row and column of ones so the CLS token can attend to, and be attended by, every token. The final output is a tuple of 5 torch masks, each expanded across the attention-head dimension. ###

In [3]:
def getRowsAndCols(image, patch_size):
    """
    Return the pixel-space centre of every non-overlapping patch.

    ``image`` only needs a 3-element ``.shape`` of the form (C, H, W).
    Patches are enumerated row-major (all columns of patch-row 0, then
    patch-row 1, ...). Partial patches at the bottom/right edge are ignored,
    because the grid size uses floor division.

    Returns:
        (rows, cols): two 1-D NumPy arrays of equal length holding the centre
        row and centre column of each patch.
    """
    _, height, width = image.shape
    grid_h = height // patch_size
    grid_w = width // patch_size
    # Each patch-row centre repeats once per patch-column; the column centres
    # cycle once per patch-row.
    rows = [r * patch_size + patch_size / 2 for r in range(grid_h) for _ in range(grid_w)]
    cols = [c * patch_size + patch_size / 2 for _ in range(grid_h) for c in range(grid_w)]
    return np.array(rows), np.array(cols)

def makeMasks(token_rows, token_cols):
    """
    Generate the five pairwise spatial masks for a set of patch centres.

    For patches i (mask row) and j (mask column):
      left_mask[i, j]     = 1 if token_cols[i] < token_cols[j]
      right_mask[i, j]    = 1 if token_cols[i] > token_cols[j]
      up_mask[i, j]       = 1 if token_rows[i] < token_rows[j]
      down_mask[i, j]     = 1 if token_rows[i] > token_rows[j]
      identity_mask[i, j] = 1 if i == j
    All other entries are 0.

    Args:
        token_rows, token_cols: 1-D sequences of equal length with the centre
            row/column of each patch (e.g. the output of getRowsAndCols).

    Returns:
        (left_mask, right_mask, up_mask, down_mask, identity_mask), each a
        float64 array of shape (num_patches, num_patches).
    """
    rows = np.asarray(token_rows)
    cols = np.asarray(token_cols)
    # Comparing a column vector against a row vector evaluates every (i, j)
    # pair in one vectorised step. This replaces a nested Python loop that ran
    # num_patches**2 times for every dataset sample.
    left_mask = (cols[:, None] < cols[None, :]).astype(np.float64)
    right_mask = (cols[:, None] > cols[None, :]).astype(np.float64)
    up_mask = (rows[:, None] < rows[None, :]).astype(np.float64)
    down_mask = (rows[:, None] > rows[None, :]).astype(np.float64)
    identity_mask = np.eye(len(rows))
    return left_mask, right_mask, up_mask, down_mask, identity_mask

def create_mask_list_for_image(image, patch_size, num_heads, device, dtype):
    """
    Build the five spatial attention masks for one image.

    Only ``image.shape`` is used (via getRowsAndCols). Each patch-level mask
    from makeMasks gets a leading row and column of ones for the CLS token,
    becomes a tensor on ``device`` with ``dtype``, and is expanded (without
    copying) to shape (1, num_heads, N, N), where N = num_patches + 1.

    Returns:
        (left, right, up, down, identity) tensors, in makeMasks order.
    """
    centre_rows, centre_cols = getRowsAndCols(image, patch_size)
    expanded_masks = []
    for patch_mask in makeMasks(centre_rows, centre_cols):
        # Row 0 / column 0 correspond to the CLS token: allow all interactions.
        with_cls = np.pad(patch_mask, ((1, 0), (1, 0)), mode='constant', constant_values=1)
        as_tensor = torch.tensor(with_cls, device=device, dtype=dtype)[None, None]
        side = as_tensor.shape[-1]
        expanded_masks.append(as_tensor.expand(1, num_heads, side, side))
    return tuple(expanded_masks)


### This cell defines a dataset class that returns Food101 images, labels, and their corresponding spatial attention masks. It initializes the training and validation datasets and wraps them in data loaders. Finally, it prints the number of samples in each set. ###

In [None]:
class OnTheFlyMaskedFood101(torch.utils.data.Dataset):
    """Food-101 wrapper yielding ``(image, label, mask_list)`` triples.

    ``mask_list`` is the 5-tuple returned by ``create_mask_list_for_image`` for
    the transformed image. That function reads only the image's shape, never
    its pixel values, so the masks are computed once per distinct
    (shape, patch_size, num_heads, dtype, device) key and reused. Previously
    they were rebuilt for every sample.
    """

    def __init__(self, root, split, transform, patch_size, num_heads, download=True):
        self.dataset = torchvision.datasets.Food101(
            root=root, split=split, transform=transform, download=download
        )
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dtype = torch.float32
        self.device = torch.device("cpu")
        # Mask tuples keyed by everything create_mask_list_for_image depends on.
        # With num_workers > 0, each DataLoader worker builds its own cache.
        self._mask_cache = {}

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        cache_key = (tuple(image.shape), self.patch_size, self.num_heads, self.dtype, self.device)
        mask_list = self._mask_cache.get(cache_key)
        if mask_list is None:
            mask_list = create_mask_list_for_image(
                image, self.patch_size, self.num_heads, device=self.device, dtype=self.dtype
            )
            self._mask_cache[cache_key] = mask_list
        # The cached tensors are shared across samples; callers must not modify
        # them in place (default collation stacks them into new tensors).
        return image, label, mask_list

    def __len__(self):
        return len(self.dataset)


patch_size = 16  # Experimental
batch_size = 32  # Experimental
# Head count baked into each generated mask. The masks are identical for every
# head, and CustomAttentionMultipleFiveSpatial truncates extra head copies to
# the model's own head count (6 for vit_small_patch16_224). A single copy that
# broadcasts across all heads therefore gives the same attention, while
# collating and copying far less mask data per batch.
num_heads = 1

train_dataset = OnTheFlyMaskedFood101(
    root='./data', split='train', transform=train_transform,
    patch_size=patch_size, num_heads=num_heads, download=True
)

# NOTE(review): the Food-101 'test' split doubles as the validation set and is
# also used to pick the best checkpoint, so the reported best accuracy is
# optimistically biased. Consider holding out part of 'train' for validation.
val_dataset = OnTheFlyMaskedFood101(
    root='./data', split='test', transform=val_transform,
    patch_size=patch_size, num_heads=num_heads, download=True
)

# Page-locked host buffers speed up the per-batch copies to the GPU; they are
# only useful when CUDA is available.
use_pinned_memory = torch.cuda.is_available()
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=use_pinned_memory
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=use_pinned_memory
)

print(f"Train samples: {len(train_dataset)}")
print(f"Test/Val samples: {len(val_dataset)}")


### This cell defines a custom attention module with five different key projections, each paired with one spatial mask (left, right, up, down, identity). It initializes its weights by copying them from a standard ViT attention layer (all five key projections start from the same key weights) and computes five masked attention maps in the forward pass. The outputs of the five masked attention branches are summed and passed through a final projection layer. ###

In [5]:
import math
import torch.nn as nn

class CustomAttentionMultipleFiveSpatial(nn.Module):
    """Multi-key spatial attention that replaces a ViT attention module.

    One query projection and one value projection are shared by five key
    projections (kA..kE). Each key's softmax attention map is multiplied
    element-wise by one spatial mask (kA: left, kB: right, kC: up, kD: down,
    kE: identity). The five outputs are summed and passed through the original
    output projection.

    Masks are not forward arguments: assign a 5-tuple
    ``(left, right, up, down, identity)`` to ``self.current_mask_list`` before
    each forward pass. Each mask must broadcast against the
    (B, num_heads, N, N) attention maps; a head dimension larger than
    ``num_heads`` is truncated to ``num_heads``.

    Args:
        orig_attn: module exposing ``num_heads``, a fused ``qkv`` Linear whose
            output rows are ordered q, k, v, and ``proj``. Its q/k/v weights
            initialise the new projections (all five keys start from the same
            k weights), ``proj`` is reused, and its ``attn_drop``/``proj_drop``
            modules are reused when present (nn.Identity otherwise).
        patch_size, img_size: stored for reference; not used by ``forward``.
    """

    # Key projections, in the order they pair with current_mask_list entries.
    _KEY_LAYERS = ("kA_linear", "kB_linear", "kC_linear", "kD_linear", "kE_linear")

    def __init__(self, orig_attn: nn.Module, patch_size=16, img_size=224):
        super().__init__()
        self.num_heads = orig_attn.num_heads
        self.embed_dim = orig_attn.qkv.in_features
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.patch_size = patch_size
        self.img_size = img_size

        # Creation order is kept as before so parameter order, state_dict keys
        # and RNG consumption (default Linear init) are unchanged.
        self.q_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kA_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kB_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kC_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kD_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kE_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.proj = orig_attn.proj
        # Reuse the replaced layer's dropouts; previously they were silently discarded.
        self.attn_drop = getattr(orig_attn, "attn_drop", nn.Identity())
        self.proj_drop = getattr(orig_attn, "proj_drop", nn.Identity())

        d = self.embed_dim
        qkv = orig_attn.qkv
        with torch.no_grad():
            self.q_linear.weight.copy_(qkv.weight[:d])
            self.v_linear.weight.copy_(qkv.weight[2 * d:])
            for name in self._KEY_LAYERS:
                getattr(self, name).weight.copy_(qkv.weight[d:2 * d])
            if qkv.bias is not None:
                self.q_linear.bias.copy_(qkv.bias[:d])
                self.v_linear.bias.copy_(qkv.bias[2 * d:])
                for name in self._KEY_LAYERS:
                    getattr(self, name).bias.copy_(qkv.bias[d:2 * d])

    def forward(self, query, key=None, value=None, key_padding_mask=None, need_weights=True, attn_mask=None):
        """Apply the five masked attention branches.

        Args:
            query: (B, N, embed_dim) tensor.
            key, value: optional tensors used for the key / value projections;
                default to ``query`` (self-attention).
            key_padding_mask, need_weights, attn_mask: accepted for call
                compatibility and ignored.

        Returns:
            (B, N, embed_dim) tensor.

        Raises:
            ValueError: if ``current_mask_list`` has not been assigned.
        """
        if key is None:
            key = query
        if value is None:
            value = query
        if not hasattr(self, 'current_mask_list'):
            raise ValueError("current_mask_list attribute not set in CustomAttentionMultipleFiveSpatial!")
        left_mask, right_mask, up_mask, down_mask, identity_mask = self.current_mask_list
        masks = (left_mask, right_mask, up_mask, down_mask, identity_mask)
        if left_mask.shape[1] != self.num_heads:
            # Drop surplus head copies; a head dim of 1 stays 1 and broadcasts.
            masks = tuple(m[:, :self.num_heads] for m in masks)

        B, N, _ = query.shape

        def split_heads(x):
            # (B, L, embed_dim) -> (B, num_heads, L, head_dim)
            return x.reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_linear(query)) * self.scale
        v = split_heads(self.v_linear(value))

        # NOTE(review): the softmax runs over all tokens and the mask is applied
        # afterwards, so each branch's weights are not renormalised over its
        # allowed region (standard masked attention fills -inf before the
        # softmax). Kept as-is because it may be intentional.
        out = 0
        for name, mask in zip(self._KEY_LAYERS, masks):
            k = split_heads(getattr(self, name)(key))
            attn = (q @ k.transpose(-2, -1)).softmax(dim=-1) * mask
            out = out + self.attn_drop(attn) @ v

        out = out.transpose(1, 2).reshape(B, N, self.embed_dim)
        return self.proj_drop(self.proj(out))


### This cell defines a custom Vision Transformer that keeps a configurable number of transformer blocks and replaces the standard attention in every retained block with CustomAttentionMultipleFiveSpatial (the five-mask spatial attention). The model is then instantiated, moved to the GPU if one is available, and printed. ###

In [None]:
from timm.models.vision_transformer import vit_small_patch16_224
import torch.nn as nn

class ViTCustom(nn.Module):
    """Truncated ViT-S/16 whose attention layers use CustomAttentionMultipleFiveSpatial.

    Builds timm's ``vit_small_patch16_224`` from random weights with 101
    output classes, keeps its first ``num_blocks_to_keep`` transformer blocks,
    and swaps each kept block's ``attn`` for the five-mask spatial attention.
    Before each forward pass the caller must set ``current_mask_list`` on
    every ``self.blocks[i].attn``.

    Args:
        num_blocks_to_keep: number of leading transformer blocks to retain.
        patch_size, img_size: forwarded to the custom attention modules. The
            backbone itself is always the 16-pixel-patch, 224-pixel variant.
    """

    def __init__(self, num_blocks_to_keep, patch_size=16, img_size=224):
        super().__init__()

        full_model = vit_small_patch16_224(pretrained=False, num_classes=101, drop_rate=0.3, drop_path_rate=0.1)  # Experimental

        self.patch_embed = full_model.patch_embed
        self.cls_token = full_model.cls_token
        self.pos_embed = full_model.pos_embed
        self.pos_drop = full_model.pos_drop

        self.blocks = nn.Sequential()
        for i, block in enumerate(full_model.blocks[:num_blocks_to_keep]):
            if not hasattr(block, 'attn'):
                raise AttributeError(f"Block {i} has no attention module.")
            block.attn = CustomAttentionMultipleFiveSpatial(block.attn, patch_size=patch_size, img_size=img_size)
            print(f"Block {i}: Custom attention injected.")
            self.blocks.append(block)

        self.norm = full_model.norm
        # Recent timm versions apply `drop_rate` only through `head_drop`, a
        # dropout on the pooled features just before the classifier. Without
        # carrying it over, drop_rate=0.3 had no effect. Older timm versions
        # have no such module, hence the Identity fallback. Parameter-free, so
        # state_dict keys are unchanged.
        self.head_drop = getattr(full_model, 'head_drop', nn.Identity())
        self.head = full_model.head

    def forward(self, x):
        """Return class logits of shape (B, 101) for an input batch ``x``."""
        B = x.shape[0]
        x = self.patch_embed(x)
        # Prepend the learned CLS token, then add position embeddings.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        # Classify from the CLS token (position 0).
        return self.head(self.head_drop(x[:, 0]))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep 10 of the 12 ViT-S blocks (experimental setting) and place the model on `device`.
model = ViTCustom(num_blocks_to_keep=10, patch_size=16, img_size=224).to(device)

print(device)
print(model)
print("Model created with custom attention and reduced layers.")


### This cell defines a three-phase learning-rate schedule over 60 epochs: a 15-epoch flat warmup at the base learning rate, a 15-epoch linear ramp-up to the peak learning rate, and a cosine decay to 10% of the peak over the remaining 30 epochs. Progress is tracked per optimizer step, and the function returns the factor by which the base learning rate is multiplied, for use with PyTorch's LambdaLR scheduler. ###

In [7]:
# --- Learning rates (experimental) ---
base_learning_rate = 1e-5   # LR during the flat warmup; also the optimizer's initial LR
peak_learning_rate = 1e-4   # LR reached at the end of the linear ramp-up
final_lr_fraction = 0.10    # cosine decay ends at this fraction of the peak LR

# --- Phase lengths in epochs (experimental) ---
num_epochs = 60
warmup_epochs = 15
rampup_epochs = 15
decay_epochs = num_epochs - (warmup_epochs + rampup_epochs)

# --- Phase lengths in optimizer steps (one step per training batch) ---
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * num_epochs
warmup_steps = steps_per_epoch * warmup_epochs
rampup_steps = steps_per_epoch * rampup_epochs
decay_steps = total_steps - warmup_steps - rampup_steps

weight_decay = 1e-4  # Experimental

# Experimental schedule shape.
def lr_lambda(step):
    """Return the factor LambdaLR applies to the optimizer's initial LR at ``step``.

    Reads the module-level schedule settings. The optimizer is created with
    ``lr=base_learning_rate``, so the effective LR is the factor times
    ``base_learning_rate``:
      * steps [0, warmup_steps): base_learning_rate (factor 1.0)
      * next ``rampup_steps`` steps: linear from base_learning_rate towards
        peak_learning_rate
      * remaining steps: cosine from peak_learning_rate down to
        final_lr_fraction * peak_learning_rate (reached when
        step == warmup_steps + rampup_steps + decay_steps)
    """
    if step < warmup_steps:
        return 1.0

    ramp_end = warmup_steps + rampup_steps
    if step < ramp_end:
        ramp_progress = (step - warmup_steps) / rampup_steps
        ramp_lr = base_learning_rate + ramp_progress * (peak_learning_rate - base_learning_rate)
        return ramp_lr / base_learning_rate

    decay_progress = (step - warmup_steps - rampup_steps) / max(1, decay_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * decay_progress))
    floor_lr = final_lr_fraction * peak_learning_rate
    decayed_lr = floor_lr + (1 - final_lr_fraction) * peak_learning_rate * cosine
    return decayed_lr / base_learning_rate


### This cell sets up the training components: cross-entropy loss with label smoothing, the AdamW optimizer, and a LambdaLR scheduler that applies the custom learning rate schedule. ###

In [8]:
# Loss: cross-entropy with 0.1 label smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# AdamW starts at base_learning_rate; LambdaLR sets the LR to that initial
# value times lr_lambda(step) each time scheduler.step() is called.
model_params = model.parameters()
optimizer = optim.AdamW(model_params, lr=base_learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

### This cell runs the training and validation loop, applying spatial masks to each attention block during forward passes. It logs metrics to TensorBoard, updates the learning rate scheduler per step, and tracks the best model weights based on validation accuracy. The best model is restored at the end. ###

In [None]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/food101_multiK_experiment')
best_val_acc = 0.0
best_model_wts = None

global_step = 0


def _prepare_batch(inputs, labels, batch_mask_list):
    """Move a batch to `device` and install its spatial masks on every attention block.

    Each sample's masks are (1, heads, N, N), so the collated masks are
    (B, 1, heads, N, N). squeeze(1) makes them broadcast against the
    (B, heads, N, N) attention maps.
    """
    masks = tuple(m.squeeze(1).to(device) for m in batch_mask_list)
    for block in model.blocks:
        block.attn.current_mask_list = masks
    return inputs.to(device), labels.to(device)


try:
    for epoch in range(num_epochs):
        epoch_start = time.time()
        print(f"Epoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        # ---- training ----
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels, batch_mask_list in tqdm(train_loader, desc="Training", leave=False):
            inputs, labels = _prepare_batch(inputs, labels, batch_mask_list)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels).item()

            # Log the LR that was used for this step, before the scheduler advances it.
            writer.add_scalar('Loss/train_step', batch_loss, global_step)
            writer.add_scalar('LR', scheduler.get_last_lr()[0], global_step)
            scheduler.step()
            global_step += 1

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects / len(train_dataset)
        print(f"Train Loss: {epoch_loss:.4f}  Train Acc: {epoch_acc:.4f}")

        # ---- validation ----
        model.eval()
        val_running_loss = 0.0
        val_running_corrects = 0

        with torch.no_grad():
            for inputs, labels, batch_mask_list in tqdm(val_loader, desc="Validation", leave=False):
                inputs, labels = _prepare_batch(inputs, labels, batch_mask_list)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_running_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                val_running_corrects += torch.sum(preds == labels).item()

        val_epoch_loss = val_running_loss / len(val_dataset)
        val_epoch_acc = val_running_corrects / len(val_dataset)
        print(f"Val Loss: {val_epoch_loss:.4f}  Val Acc: {val_epoch_acc:.4f}")

        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_acc, epoch)
        writer.add_scalar('Loss/val', val_epoch_loss, epoch)
        writer.add_scalar('Accuracy/val', val_epoch_acc, epoch)

        if val_epoch_acc > best_val_acc:
            best_val_acc = val_epoch_acc
            # Snapshot on the CPU so the best weights do not hold a second copy
            # of the model in GPU memory.
            best_model_wts = {
                name: tensor.detach().to('cpu', copy=True)
                for name, tensor in model.state_dict().items()
            }

        epoch_duration = time.time() - epoch_start
        print(f"Epoch {epoch+1} completed in {epoch_duration:.2f} seconds\n")
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

if best_model_wts is not None:
    model.load_state_dict(best_model_wts)
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
