As in CVR_C, using vit_small_patch

Shape example (for 224x224, patch_size=16):
Image x: (B,3,224,224)
After patch_embed: (B, 196, D) - 14x14 = 196 patches; a CLS token is then prepended, giving (B, 197, D)

Starting with seeding for reproducibility

In [1]:
#standard imports
import time
import math
import torch
import random
import numpy as np

#PyTorch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

#PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

#torchvision
import torchvision
from torchvision import transforms

#logging (Lightning handles TensorBoard automatically)

# Seed PyTorch, NumPy and Python's `random` with one shared value so that
# repeated runs of the notebook draw the same random numbers.
SEED = 31
for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    seed_rng(SEED)

Data transforms.  Define augmentation for training and resizing/cropping for validation.  Lightning will use these transforms in the DataModule.

In [2]:
# Per-channel mean/std used to normalise the RGB tensors (the values commonly
# used with ImageNet-pretrained torchvision/timm models).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)
# Side length (pixels) of the square crop fed to the network.
_CROP_SIZE = 224


def _tensor_and_normalize():
    """Return the tail shared by both pipelines: PIL image -> normalised tensor."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]


# Training: random crop/flip/colour augmentation, then tensor conversion.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(_CROP_SIZE, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
    + _tensor_and_normalize()
)

# Validation: deterministic resize + centre crop, then tensor conversion.
val_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(_CROP_SIZE),
    ]
    + _tensor_and_normalize()
)

Generates the five masks for the custom attention module (four spatial relations - left, right, up, down - plus identity).

In [3]:
def getRowsAndCols(image, patch_size):
    """Return the centre row and centre column (in pixels) of every patch.

    Args:
        image: object whose ``.shape`` unpacks to ``(channels, H, W)``.
        patch_size: side length of a square patch, in pixels.

    Returns:
        ``(token_rows, token_cols)``: two 1-D float arrays of length
        ``(H // patch_size) * (W // patch_size)``, ordered row-major (all
        patches of the first patch row, then the second, ...). Pixels that
        do not fill a whole patch at the bottom/right edge are ignored.
    """
    _, height, width = image.shape
    half_patch = patch_size / 2
    row_centers = np.arange(height // patch_size) * patch_size + half_patch
    col_centers = np.arange(width // patch_size) * patch_size + half_patch
    # Row-major flattening of the patch grid: each row centre is repeated
    # once per patch column, and the column centres cycle once per patch row.
    token_rows = np.repeat(row_centers, len(col_centers))
    token_cols = np.tile(col_centers, len(row_centers))
    return token_rows, token_cols

def makeMasks(token_rows, token_cols):
    """Build the four directional masks and the identity mask for a patch grid.

    For tokens ``i`` (matrix row) and ``j`` (matrix column):

    * ``left_mask[i, j]  = 1`` where ``token_cols[i] < token_cols[j]``
    * ``right_mask[i, j] = 1`` where ``token_cols[i] > token_cols[j]``
    * ``up_mask[i, j]    = 1`` where ``token_rows[i] < token_rows[j]``
    * ``down_mask[i, j]  = 1`` where ``token_rows[i] > token_rows[j]``
    * ``identity_mask`` is the identity matrix.

    Args:
        token_rows: 1-D sequence with the centre row of each token.
        token_cols: 1-D sequence with the centre column of each token.

    Returns:
        Tuple ``(left, right, up, down, identity)`` of float arrays, each of
        shape ``(num_tokens, num_tokens)`` holding 0.0 / 1.0.
    """
    rows = np.asarray(token_rows)
    cols = np.asarray(token_cols)
    num_patches = len(rows)

    # Broadcasting a column vector against a row vector evaluates every
    # (i, j) pair at once, replacing the Python-level double loop.
    cols_i, cols_j = cols[:, None], cols[None, :]
    rows_i, rows_j = rows[:, None], rows[None, :]

    left_mask = (cols_i < cols_j).astype(float)
    right_mask = (cols_i > cols_j).astype(float)
    up_mask = (rows_i < rows_j).astype(float)
    down_mask = (rows_i > rows_j).astype(float)

    identity_mask = np.eye(num_patches)
    return left_mask, right_mask, up_mask, down_mask, identity_mask

def create_mask_list_for_image(image, patch_size, num_heads, device, dtype):
    """Build the five attention masks for one image, sized for CLS + heads.

    The patch-level masks from ``makeMasks`` are each extended by one leading
    row and one leading column filled with 1 (the slot of the CLS token, so
    the CLS row/column is enabled in every mask), converted to a tensor and
    expanded over the head dimension.

    Args:
        image: object whose ``.shape`` is ``(channels, H, W)``.
        patch_size: side length of a square patch, in pixels.
        num_heads: size of the head dimension of the returned masks.
        device: torch device on which the masks are created.
        dtype: torch dtype of the masks.

    Returns:
        Tuple ``(left, right, up, down, identity)`` of tensors with shape
        ``(1, num_heads, T, T)`` where ``T = num_patches + 1``. The head
        dimension is an ``expand`` view, not a copy.
    """
    rows, cols = getRowsAndCols(image, patch_size)

    expanded_masks = []
    for patch_mask in makeMasks(rows, cols):
        # Prepend one row and one column of ones for the CLS token.
        with_cls = np.pad(patch_mask, ((1, 0), (1, 0)), constant_values=1)
        side_rows, side_cols = with_cls.shape
        as_tensor = torch.tensor(with_cls, device=device, dtype=dtype)
        # Two leading singleton dims, then broadcast over the heads.
        expanded_masks.append(
            as_tensor[None, None].expand(1, num_heads, side_rows, side_cols)
        )

    return tuple(expanded_masks)


This dataset generates attention masks on the fly for each image.  This mirrors custom attention setup.

In [4]:
class OnTheFlyMaskedFood101(torch.utils.data.Dataset):
    """Food-101 samples paired with the five spatial attention masks.

    Each item is ``(image, label, mask_list)`` where ``mask_list`` is the
    tuple returned by ``create_mask_list_for_image`` for that image.

    Args:
        root: directory passed to ``torchvision.datasets.Food101``.
        split: dataset split passed to ``Food101`` (e.g. ``'train'``/``'test'``).
        transform: transform applied by ``Food101``; it must yield an object
            with a ``(C, H, W)`` ``.shape`` (e.g. a tensor).
        patch_size: patch side length used to build the masks.
        num_heads: head dimension of the generated masks.
        download: forwarded to ``Food101``.
    """

    def __init__(self, root, split, transform, patch_size, num_heads, download=True):
        self.dataset = torchvision.datasets.Food101(
            root=root,
            split=split,
            transform=transform,
            download=download
        )
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dtype = torch.float32
        # Masks are created on the CPU and handed off; this avoids keeping
        # large data in GPU memory. Only the batch needed is moved to the
        # GPU during training.
        self.device = torch.device("cpu")
        # The masks depend only on the image size and the settings above,
        # not on pixel values, so they are built once per distinct
        # configuration instead of once per sample.
        self._mask_cache = {}

    def __getitem__(self, idx):
        """Return ``(image, label, mask_list)`` for sample ``idx``."""
        image, label = self.dataset[idx]
        cache_key = (
            tuple(image.shape[-2:]),
            self.patch_size,
            self.num_heads,
            self.device,
            self.dtype,
        )
        mask_list = self._mask_cache.get(cache_key)
        if mask_list is None:
            mask_list = create_mask_list_for_image(
                image,
                self.patch_size,
                self.num_heads,
                device=self.device,
                dtype=self.dtype)
            self._mask_cache[cache_key] = mask_list
        return image, label, mask_list

    def __len__(self):
        """Return the number of samples in the wrapped Food-101 split."""
        return len(self.dataset)

Implements five-spatial-masks attention for ViT

In [5]:
class CustomAttentionMultipleFiveSpatial(nn.Module):
    """Self-attention with one Q/V projection and five mask-specific K projections.

    The module is built from an existing attention layer that exposes
    ``num_heads``, a fused ``qkv`` linear layer and a ``proj`` layer. Q and V
    take the corresponding slices of the fused weights; all five K layers
    start from the same K slice. ``proj`` is reused (shared object).

    Before calling ``forward`` the caller must set ``current_mask_list`` to a
    5-tuple ``(left, right, up, down, identity)`` of masks that broadcast
    against the ``(B, num_heads, N, N)`` attention weights.
    """

    def __init__(self, orig_attn: nn.Module, patch_size=16, img_size=224):
        super().__init__()
        dim = orig_attn.qkv.in_features
        self.num_heads = orig_attn.num_heads
        self.embed_dim = dim
        self.head_dim = dim // self.num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # Stored for reference only; not used by forward().
        self.patch_size = patch_size
        self.img_size = img_size

        # Separate Q, V and five K projections. The creation order
        # (q, v, kA..kE, proj) fixes the parameter order of the module.
        self.q_linear = nn.Linear(dim, dim)
        self.v_linear = nn.Linear(dim, dim)
        for tag in "ABCDE":
            setattr(self, f"k{tag}_linear", nn.Linear(dim, dim))
        self.proj = orig_attn.proj

        # Row ranges of the fused qkv parameters: [ Q | K | V ].
        q_rows = slice(0, dim)
        k_rows = slice(dim, 2 * dim)
        v_rows = slice(2 * dim, None)
        targets = [(self.q_linear, q_rows), (self.v_linear, v_rows)]
        targets += [(k_layer, k_rows) for k_layer in self._key_layers()]

        fused_weight = orig_attn.qkv.weight.clone()
        fused_bias = orig_attn.qkv.bias
        if fused_bias is not None:
            fused_bias = fused_bias.clone()

        for layer, rows in targets:
            layer.weight.data.copy_(fused_weight[rows, :])
            # Without a fused bias the new layers keep their default bias init.
            if fused_bias is not None:
                layer.bias.data.copy_(fused_bias[rows])

    def _key_layers(self):
        """Return the five K projections in mask order (left, right, up, down, identity)."""
        return [self.kA_linear, self.kB_linear, self.kC_linear,
                self.kD_linear, self.kE_linear]

    def forward(self, query, key=None, value=None, key_padding_mask=None, need_weights=True, attn_mask=None):
        """Apply the five masked attentions to ``query`` and project the sum.

        Only ``query`` (shape ``(B, N, embed_dim)``) is used; ``key``,
        ``value``, ``key_padding_mask``, ``need_weights`` and ``attn_mask``
        are accepted but ignored.

        Returns:
            Tensor of shape ``(B, N, embed_dim)``.

        Raises:
            ValueError: if ``current_mask_list`` has not been set.
        """
        if not hasattr(self, 'current_mask_list'):
            raise ValueError("current_mask_list not set!")

        left_mask, right_mask, up_mask, down_mask, identity_mask = self.current_mask_list
        batch, tokens, _ = query.shape

        def split_heads(projected):
            # (B, N, embed_dim) -> (B, num_heads, N, head_dim)
            return projected.reshape(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

        scaled_q = split_heads(self.q_linear(query)) * self.scale
        v = split_heads(self.v_linear(query))
        masks = (left_mask, right_mask, up_mask, down_mask, identity_mask)

        out = 0
        for k_layer, mask in zip(self._key_layers(), masks):
            k = split_heads(k_layer(query))
            # NOTE(review): the mask is applied after the softmax, so masked
            # rows are not renormalised to sum to 1 - confirm this is intended.
            weights = (scaled_q @ k.transpose(-2, -1)).softmax(dim=-1) * mask
            out = out + weights @ v

        merged = out.transpose(1, 2).reshape(batch, tokens, self.embed_dim)
        return self.proj(merged)


Creates ViT with custom attention and fewer blocks to reduce computation.

In [6]:
from timm.models.vision_transformer import vit_small_patch16_224

class ViTCustom(nn.Module):
    """Truncated ViT-Small whose attention layers use the five-mask attention.

    A ``vit_small_patch16_224`` (101 classes, not pretrained) is created and
    its patch embedding, CLS token, positional embedding, positional dropout,
    final norm and head are reused. Only the first ``num_blocks_to_keep``
    transformer blocks are kept, each with its ``attn`` replaced by
    ``CustomAttentionMultipleFiveSpatial``.

    NOTE(review): only the sub-modules listed above are applied in
    ``forward``; any other layers the timm model may define are skipped -
    confirm against the installed timm version.

    Args:
        num_blocks_to_keep: number of leading transformer blocks to keep.
        patch_size: forwarded to the custom attention modules.
        img_size: forwarded to the custom attention modules.
    """

    def __init__(self, num_blocks_to_keep, patch_size=16, img_size=224):
        super().__init__()
        backbone = vit_small_patch16_224(pretrained=False, num_classes=101, drop_rate=0.3, drop_path_rate=0.1)

        self.patch_embed = backbone.patch_embed
        self.cls_token = backbone.cls_token
        self.pos_embed = backbone.pos_embed
        self.pos_drop = backbone.pos_drop

        kept_blocks = []
        for block in backbone.blocks[:num_blocks_to_keep]:
            if hasattr(block, 'attn'):
                # Swap in the custom attention, seeded from the block's own weights.
                block.attn = CustomAttentionMultipleFiveSpatial(block.attn, patch_size, img_size)
            kept_blocks.append(block)
        self.blocks = nn.Sequential(*kept_blocks)

        self.norm = backbone.norm
        self.head = backbone.head

    def forward(self, x):
        """Classify a batch of images and return the head output for the CLS token."""
        batch_size = x.shape[0]
        patches = self.patch_embed(x)
        # One CLS token per sample, prepended to the patch sequence.
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, patches), dim=1) + self.pos_embed
        tokens = self.pos_drop(tokens)
        encoded = self.norm(self.blocks(tokens))
        # Token 0 is the CLS token.
        return self.head(encoded[:, 0])

# # Instantiate model
# patch_size = 16
# num_blocks_to_keep = 10
# model = ViTCustom(num_blocks_to_keep=num_blocks_to_keep, patch_size=patch_size)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)


DataModule

    Lightning DataModules separate data preparation and data loaders.

    setup initializes the train and validation datasets.
    train_dataloader and val_dataloader return DataLoaders with persistent workers for speed.

    Also, in a Jupyter notebook, code with num_workers > 0 could hang without persistent workers.

In [7]:
class Food101DataModule(pl.LightningDataModule):
    """Lightning DataModule serving Food-101 with on-the-fly attention masks.

    The ``'train'`` split uses ``train_transform`` and the ``'test'`` split
    (used here for validation) uses ``val_transform``.

    Args:
        data_dir: root directory for the dataset (downloaded if missing).
        batch_size: batch size of both loaders.
        patch_size: patch side length used for the masks.
        num_heads: head dimension of the masks. It must match the number of
            attention heads of the model that consumes them.
        num_workers: DataLoader worker processes; ``0`` loads in the main
            process.
    """

    def __init__(self,
                 data_dir='./data',
                 batch_size=32,
                 patch_size=16,
                 num_heads=10,
                 num_workers=2):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.num_workers = num_workers

    def _make_dataset(self, split, transform):
        """Build the masked Food-101 dataset for one split."""
        return OnTheFlyMaskedFood101(
            root=self.data_dir,
            split=split,
            transform=transform,
            patch_size=self.patch_size,
            num_heads=self.num_heads,
            download=True
        )

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the module's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only enable it when workers exist.
            persistent_workers=self.num_workers > 0
        )

    def setup(self, stage=None):
        """Create the training and validation datasets (``stage`` is ignored)."""
        self.train_dataset = self._make_dataset('train', train_transform)
        self.val_dataset = self._make_dataset('test', val_transform)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return self._make_loader(self.val_dataset, shuffle=False)

Next, training_step and validation_step replace manual loops.
self.log() automatically logs to TensorBoard.
Optimizer + learning rate schedule are defined in configure_optimizers.
Custom attention masks are attached to every attention block for each batch, just as in the manual PyTorch training loop from the initial code.

In [8]:
class LitViT(pl.LightningModule):
    """LightningModule that trains ``ViTCustom`` on batches of (image, label, masks).

    Learning-rate schedule (expressed in optimizer steps):

    1. hold ``base_lr`` for ``warmup_epochs`` epochs,
    2. ramp linearly from ``base_lr`` to ``peak_lr`` over ``rampup_epochs``,
    3. cosine-decay from ``peak_lr`` to ``final_lr_frac * peak_lr`` over the
       remaining steps.

    Args:
        num_blocks_to_keep: transformer blocks kept in ``ViTCustom``.
        patch_size: forwarded to ``ViTCustom``.
        img_size: forwarded to ``ViTCustom``.
        num_classes: stored as a hyperparameter only; it is not passed to
            ``ViTCustom``.
        base_lr: optimizer learning rate and floor of the schedule.
        peak_lr: learning rate reached at the end of the ramp.
        final_lr_frac: final learning rate as a fraction of ``peak_lr``.
        warmup_epochs: epochs spent at ``base_lr``.
        rampup_epochs: epochs spent ramping to ``peak_lr``.
        num_epochs: total epochs assumed by the schedule; keep it equal to
            the Trainer's ``max_epochs``.
        weight_decay: AdamW weight decay.
    """

    def __init__(self, num_blocks_to_keep=10, patch_size=16, img_size=224,
                 num_classes=101, base_lr=1e-5, peak_lr=1e-4, final_lr_frac=0.1,
                 warmup_epochs=15, rampup_epochs=15, num_epochs=60, weight_decay=0.5):
        super().__init__()
        self.save_hyperparameters()
        self.model = ViTCustom(num_blocks_to_keep=num_blocks_to_keep,
                               patch_size=patch_size, img_size=img_size)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.epoch_start_time = None

    def on_train_epoch_start(self):
        """Record the epoch start time and announce the epoch."""
        self.epoch_start_time = time.time()
        print(f"\nStarting epoch {self.current_epoch + 1}/{self.hparams.num_epochs}")

    def on_train_epoch_end(self):
        """Print the wall-clock duration of the epoch that just finished."""
        duration = time.time() - self.epoch_start_time
        print(f"Epoch {self.current_epoch + 1} completed in {duration:.2f} seconds")

    def forward(self, x):
        """Run the wrapped model; the attention masks must already be set."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Run one batch, log ``{stage}_loss`` / ``{stage}_acc`` and return the loss."""
        x, y, mask_list = batch
        # Collation stacks the per-sample (1, heads, N, N) masks into
        # (B, 1, heads, N, N); drop the singleton dim.
        batch_mask_list = tuple(m.squeeze(1) for m in mask_list)
        # Every attention block reads its masks from this attribute.
        for blk in self.model.blocks:
            blk.attn.current_mask_list = batch_mask_list
        out = self(x)
        loss = self.criterion(out, y)
        acc = (out.argmax(dim=1) == y).float().mean()
        self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
        self.log(f'{stage}_acc', acc, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy."""
        self._shared_step(batch, 'val')

    def configure_optimizers(self):
        """Create AdamW and the per-step hold / ramp / cosine LR schedule."""
        hp = self.hparams
        optimizer = optim.AdamW(self.parameters(), lr=hp.base_lr, weight_decay=hp.weight_decay)

        # Phase boundaries in optimizer steps, computed once.
        total_steps = self.trainer.estimated_stepping_batches
        steps_per_epoch = total_steps / hp.num_epochs
        warmup_steps = steps_per_epoch * hp.warmup_epochs
        rampup_steps = steps_per_epoch * hp.rampup_epochs
        decay_steps = total_steps - warmup_steps - rampup_steps

        def lr_lambda(step):
            # LambdaLR multiplies the optimizer's initial lr (base_lr) by the
            # returned factor, hence the division by base_lr below.
            if step < warmup_steps:
                return 1.0
            if step < warmup_steps + rampup_steps:
                progress = (step - warmup_steps) / rampup_steps
                scaled_lr = hp.base_lr + progress * (hp.peak_lr - hp.base_lr)
            else:
                progress = (step - warmup_steps - rampup_steps) / max(1, decay_steps)
                # Clamp so the cosine cannot rise again if training runs
                # past the estimated number of steps.
                progress = min(1.0, progress)
                cosine = 0.5 * (1 + math.cos(math.pi * progress))
                scaled_lr = hp.final_lr_frac * hp.peak_lr + (1 - hp.final_lr_frac) * hp.peak_lr * cosine
            return scaled_lr / hp.base_lr

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        # The schedule is defined in optimizer steps, so the scheduler must be
        # advanced every step. Lightning's default interval is "epoch", which
        # would keep the step counter below `warmup_steps` for the whole run
        # and leave the learning rate stuck at base_lr.
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]


Lightning automatically handles epoch/step loops, device placement, logging, checkpointing.
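No logger is passed to the Trainer below, so Lightning's default logger is used and logs are written under lightning_logs/ rather than tb_logs/food101_vit; pass a logger such as TensorBoardLogger("tb_logs", name="food101_vit") to the Trainer to change the location.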
Model checkpoints are saved automatically based on validation accuracy.

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Shared settings. NUM_EPOCHS feeds both the Trainer and the LR schedule in
# LitViT, which derives its phase lengths from it, so the two must agree.
NUM_EPOCHS = 60
PATCH_SIZE = 16

# Initialize LightningModule
lit_model = LitViT(num_blocks_to_keep=10, patch_size=PATCH_SIZE, img_size=224,
                   num_epochs=NUM_EPOCHS, base_lr=1e-5, peak_lr=1e-4, final_lr_frac=0.1)

# Initialize DataModule.
# The masks are multiplied element-wise with the (B, heads, N, N) attention
# weights, so their head dimension must equal the model's head count. Read it
# from the model instead of hard-coding it (vit_small_patch16_224 uses 6
# heads, so the previous value of 10 could not broadcast).
num_heads = lit_model.model.blocks[0].attn.num_heads
data_module = Food101DataModule(batch_size=32, patch_size=PATCH_SIZE, num_heads=num_heads)

# Set up checkpoint callback to save the best model based on validation accuracy
checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",        # the metric to monitor
    mode="max",               # save the model with max val_acc
    save_top_k=1,             # save only the best model
    filename="best-vit-{epoch:02d}-{val_acc:.4f}"  # optional formatting
)

# Initialize Trainer
trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS,
    accelerator='auto',  # uses a GPU when one is available, otherwise the CPU
    devices=1,
    callbacks=[checkpoint_callback],
    log_every_n_steps=10
)

# Start training
trainer.fit(lit_model, datamodule=data_module)

# Path to best checkpoint
print(f"Best model saved at: {checkpoint_callback.best_model_path}")