In [1]:
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from timm.models.vision_transformer import vit_small_patch16_224
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl

torch.set_float32_matmul_precision('high')

# Set random seeds for reproducibility
torch.manual_seed(31)
random.seed(31)
np.random.seed(31)

In [2]:
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


In [3]:
def getRowsAndCols(image, patch_size):
    _, H, W = image.shape
    num_patches_y = H // patch_size
    num_patches_x = W // patch_size
    token_rows, token_cols = [], []
    for i in range(num_patches_y):
        for j in range(num_patches_x):
            center_row = i * patch_size + patch_size / 2
            center_col = j * patch_size + patch_size / 2
            token_rows.append(center_row)
            token_cols.append(center_col)
    return np.array(token_rows), np.array(token_cols)

def makeMasks(token_rows, token_cols):
    num_patches = len(token_rows)
    left_mask = np.zeros((num_patches, num_patches))
    right_mask = np.zeros((num_patches, num_patches))
    up_mask = np.zeros((num_patches, num_patches))
    down_mask = np.zeros((num_patches, num_patches))
    for i in range(num_patches):
        for j in range(num_patches):
            if token_cols[i] < token_cols[j]: left_mask[i, j] = 1
            if token_cols[i] > token_cols[j]: right_mask[i, j] = 1
            if token_rows[i] < token_rows[j]: up_mask[i, j] = 1
            if token_rows[i] > token_rows[j]: down_mask[i, j] = 1
    identity_mask = np.eye(num_patches)
    return left_mask, right_mask, up_mask, down_mask, identity_mask

def create_mask_list_for_image(image, patch_size, num_heads, device, dtype):
    token_rows, token_cols = getRowsAndCols(image, patch_size)
    left_np, right_np, up_np, down_np, identity_np = makeMasks(token_rows, token_cols)
    def pad_mask(m): return np.pad(m, ((1,0),(1,0)), mode='constant', constant_values=1)
    left_np, right_np, up_np, down_np, identity_np = map(pad_mask, [left_np, right_np, up_np, down_np, identity_np])
    def to_mask(m): return torch.tensor(m, device=device, dtype=dtype).unsqueeze(0).unsqueeze(0)
    left_mask, right_mask, up_mask, down_mask, identity_mask = map(to_mask, [left_np, right_np, up_np, down_np, identity_np])
    N = left_mask.shape[-1]
    masks = [m.expand(1, num_heads, N, N) for m in [left_mask, right_mask, up_mask, down_mask, identity_mask]]
    return tuple(masks)


In [4]:
class OnTheFlyMaskedFood101(torch.utils.data.Dataset):
    def __init__(self, root, split, transform, patch_size, download=True):
        self.dataset = torchvision.datasets.Food101(root=root, split=split, transform=transform, download=download)
        self.patch_size = patch_size
        #self.num_heads = None
        self.dtype = torch.float32
        self.device = torch.device("cpu")  # mask creation is light-weight

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        mask_list = create_mask_list_for_image(image, self.patch_size, self.num_heads, device=self.device, dtype=self.dtype)
        return image, label, mask_list

    def __len__(self):
        return len(self.dataset)


In [5]:
class CustomAttentionMultipleFiveSpatial(nn.Module):
    def __init__(self, orig_attn: nn.Module, patch_size=16, img_size=224):
        super().__init__()
        self.num_heads = orig_attn.num_heads
        self.embed_dim = orig_attn.qkv.in_features
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.q_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kA_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kB_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kC_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kD_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.kE_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.proj = orig_attn.proj

        qkv_weight = orig_attn.qkv.weight.clone()
        qkv_bias = orig_attn.qkv.bias.clone()
        self.q_linear.weight.data.copy_(qkv_weight[:self.embed_dim, :])
        for k_lin in [self.kA_linear, self.kB_linear, self.kC_linear, self.kD_linear, self.kE_linear]:
            k_lin.weight.data.copy_(qkv_weight[self.embed_dim:2*self.embed_dim, :])
        self.v_linear.weight.data.copy_(qkv_weight[2*self.embed_dim:, :])
        if qkv_bias is not None:
            self.q_linear.bias.data.copy_(qkv_bias[:self.embed_dim])
            for k_lin in [self.kA_linear, self.kB_linear, self.kC_linear, self.kD_linear, self.kE_linear]:
                k_lin.bias.data.copy_(qkv_bias[self.embed_dim:2*self.embed_dim])
            self.v_linear.bias.data.copy_(qkv_bias[2*self.embed_dim:])

    def forward(self, query, attn_mask=None):
        if not hasattr(self, "current_mask_list"):
            raise ValueError("current_mask_list attribute not set in CustomAttentionMultipleFiveSpatial!")
        left_mask, right_mask, up_mask, down_mask, identity_mask = self.current_mask_list
        B, N, _ = query.shape
        q = self.q_linear(query).reshape(B, N, self.num_heads, -1).transpose(1, 2) * self.scale
        v = self.v_linear(query).reshape(B, N, self.num_heads, -1).transpose(1, 2)
        ks = [self.kA_linear, self.kB_linear, self.kC_linear, self.kD_linear, self.kE_linear]
        masks = [left_mask, right_mask, up_mask, down_mask, identity_mask]
        out = sum((q @ k(query).reshape(B, N, self.num_heads, -1).transpose(1, 2).transpose(-2, -1)).softmax(-1) * m @ v for k, m in zip(ks, masks))
        if left_mask.shape[1] != self.num_heads:
            left_mask = left_mask[:, :self.num_heads, :, :]
            right_mask = right_mask[:, :self.num_heads, :, :]
            up_mask = up_mask[:, :self.num_heads, :, :]
            down_mask = down_mask[:, :self.num_heads, :, :]
            identity_mask = identity_mask[:, :self.num_heads, :, :]

        return self.proj(out.transpose(1, 2).reshape(B, N, -1))


In [6]:
class ViTCustom(nn.Module):
    def __init__(self, num_blocks_to_keep, patch_size=16, img_size=224):
        super().__init__()
        full_model = vit_small_patch16_224(pretrained=False, num_classes=101, drop_rate=0.3, drop_path_rate=0.1)
        self.patch_embed = full_model.patch_embed
        self.cls_token = full_model.cls_token
        self.pos_embed = full_model.pos_embed
        self.pos_drop = full_model.pos_drop
        self.blocks = nn.Sequential()
        for i, block in enumerate(full_model.blocks[:num_blocks_to_keep]):
            block.attn = CustomAttentionMultipleFiveSpatial(block.attn, patch_size=patch_size, img_size=img_size)
            print(f"Injected custom attention in block {i}")
            self.blocks.append(block)
        self.norm = full_model.norm
        self.head = full_model.head

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        return self.head(x[:, 0])


In [7]:
class Food101DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32, patch_size=16):
        super().__init__()
        self.batch_size = batch_size
        self.patch_size = patch_size
        #self.num_heads = num_heads # removed handled in LightningModule

    def setup(self, stage=None):
        self.train_dataset = OnTheFlyMaskedFood101("./data", "train", train_transform, self.patch_size)
        self.val_dataset = OnTheFlyMaskedFood101("./data", "test", val_transform, self.patch_size)

    def train_dataloader(self):
        #return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, persistent_workers=True) 
        #same issue with pyTorch - looks like a bug in Jupyter
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, persistent_workers=False)

    def val_dataloader(self):
        #return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, persistent_workers=True)
        #same issue with pyTorch - looks like a bug in Jupyter
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, persistent_workers=False)


In [8]:
class LitViT(pl.LightningModule):
    def __init__(self, num_blocks_to_keep=10, patch_size=16, img_size=224, num_classes=101,
                 base_lr=1e-5, peak_lr=1e-4, final_lr_frac=0.1, warmup_epochs=15,
                 rampup_epochs=15, num_epochs=60, weight_decay=0.5):
        super().__init__()
        self.save_hyperparameters()
        self.model = ViTCustom(num_blocks_to_keep, patch_size, img_size)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    def forward(self, x):
        return self.model(x)

    def on_train_epoch_start(self):
        self.epoch_start_time = time.time()
        print(f"\nStarting epoch {self.current_epoch + 1}/{self.hparams.num_epochs}")

    def on_train_epoch_end(self):
        print(f"Epoch {self.current_epoch + 1} completed in {time.time() - self.epoch_start_time:.2f} sec")

    def training_step(self, batch, batch_idx):
        x, y, masks = batch
        batch_masks = tuple(m.squeeze(1) for m in masks)
        for blk in self.model.blocks:
            blk.attn.current_mask_list = batch_masks
        out = self(x)
        loss = self.criterion(out, y)
        acc = (torch.argmax(out, dim=1) == y).float().mean()
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, masks = batch
        batch_masks = tuple(m.squeeze(1) for m in masks)
        for blk in self.model.blocks:
            blk.attn.current_mask_list = batch_masks
        out = self(x)
        loss = self.criterion(out, y)
        acc = (torch.argmax(out, dim=1) == y).float().mean()
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.base_lr, weight_decay=self.hparams.weight_decay)
        def lr_lambda(step):
            total_steps = self.trainer.estimated_stepping_batches
            warmup_steps = total_steps * self.hparams.warmup_epochs / self.hparams.num_epochs
            rampup_steps = total_steps * self.hparams.rampup_epochs / self.hparams.num_epochs
            decay_steps = total_steps - warmup_steps - rampup_steps
            if step < warmup_steps:
                return 1.0
            elif step < warmup_steps + rampup_steps:
                progress = (step - warmup_steps) / rampup_steps
                scaled_lr = self.hparams.base_lr + progress * (self.hparams.peak_lr - self.hparams.base_lr)
                return scaled_lr / self.hparams.base_lr
            else:
                progress = (step - warmup_steps - rampup_steps) / max(1, decay_steps)
                cosine = 0.5 * (1 + math.cos(math.pi * progress))
                scaled_lr = self.hparams.final_lr_frac * self.hparams.peak_lr + (1 - self.hparams.final_lr_frac) * self.hparams.peak_lr * cosine
                return scaled_lr / self.hparams.base_lr
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]


In [9]:
# DataModule
data_module = Food101DataModule(batch_size=32, patch_size=16)

# Model
lit_model = LitViT(num_blocks_to_keep=10, patch_size=16, img_size=224)

# TensorBoard Logger
logger = TensorBoardLogger("lightning_logs", name="food101_experiment")

# ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    filename="best-vit-{epoch:02d}-{val_acc:.4f}"
)

# Trainer
trainer = pl.Trainer(
    max_epochs=60,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    logger=logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=10
)

# Train
trainer.fit(lit_model, datamodule=data_module)

print(f"✅ Best model saved at: {checkpoint_callback.best_model_path}")


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


Injected custom attention in block 0
Injected custom attention in block 1
Injected custom attention in block 2
Injected custom attention in block 3
Injected custom attention in block 4
Injected custom attention in block 5
Injected custom attention in block 6
Injected custom attention in block 7
Injected custom attention in block 8
Injected custom attention in block 9


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
c:\Users\MILLAC24\Miniconda3\envs\pytorch_env\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | ViTCustom        | 24.1 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
24.1 M    Trainable params
0         Non-trainable params
24.1 M    Total params
96.277    Total estimated model params size (MB)
239       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\MILLAC24\Miniconda3\envs\pytorch_env\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:484: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
c:\Users\MILLAC24\Miniconda3\envs\pytorch_env\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.


AttributeError: 'OnTheFlyMaskedFood101' object has no attribute 'num_heads'