In [1]:
# Pruning ViT model, Starts Pruning after 4th epoch.# Basic imports and device setup
import os, time, math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print("GPUs:", torch.cuda.device_count())

GPUs: 2


In [2]:
# -----------------------------
# Hyperparameters / configuration
# -----------------------------

# Vision Transformer (ViT) architecture
IMG = 48          # input image resolution (CIFAR10 32x32 -> resized to 48x48)
PATCH = 4         # patch size (each patch is 4x4 pixels)
DIM = 512         # embedding dimension
DEPTH = 8         # number of transformer blocks
HEADS = 8         # number of attention heads
NUM_CLASSES = 10  # CIFAR10

# Training schedule
EPOCHS = 10
WARMUP_EPOCHS = 4  # number of epochs without pruning (full tokens)

# Token pruning configuration
R_MAX = 0.6        # maximum fraction of tokens to drop at the deepest layer
ALPHA = 2.0        # controls how drop rate increases with depth
MIN_TOKENS = 8     # minimum number of non-CLS tokens to keep per layer

# Optimization
LR = 3e-4
BS = 128

In [3]:
class Attention(nn.Module):
    """Multi-head self-attention module.

    Returns both the projected output and the attention weights.
    """
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        # Scale factor for dot-product attention (1/sqrt(d_k))
        self.scale = (dim // heads) ** -0.5
        # Single linear layer to generate query, key, value (Q, K, V)
        self.qkv = nn.Linear(dim, dim * 3)
        # Final projection after concatenating heads
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """x: [B, N, C] where
        B = batch size, N = number of tokens, C = embedding dim
        """
        B, N, C = x.shape
        H = self.heads

        # Compute Q, K, V and reshape to [B, N, 3, H, C//H]
        qkv = self.qkv(x).reshape(B, N, 3, H, C // H)
        # Split the 3-tuple dimension into separate tensors: [B, N, H, C//H]
        q, k, v = qkv.unbind(2)

        # Rearrange to [B, H, N, C//H] for batched attention computation
        q = q.permute(0, 3, 1, 2)
        k = k.permute(0, 3, 1, 2)
        v = v.permute(0, 3, 1, 2)

        # Scaled dot-product attention: [B, H, N, N]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(-1)

        # Apply attention to values and merge heads back to [B, N, C]
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out), attn


In [4]:
class MLP(nn.Module):
    """Feed-forward network used inside each transformer block."""
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * 4)
        self.fc2 = nn.Linear(dim * 4, dim)

    def forward(self, x):
        # GELU activation is standard for transformers
        return self.fc2(F.gelu(self.fc1(x)))


class Block(nn.Module):
    """Single transformer encoder block: LN -> MHA -> residual -> LN -> MLP -> residual."""
    def __init__(self, dim, heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp  = MLP(dim)

    def forward(self, x, return_attn=False):
        # Multi-head attention with pre-norm
        a, attn = self.attn(self.norm1(x))
        x = x + a
        # Feed-forward network with pre-norm
        x = x + self.mlp(self.norm2(x))
        if return_attn:
            # Optionally return attention map for analysis / pruning
            return x, attn
        return x


In [5]:
class ViT(nn.Module):
    """Baseline Vision Transformer for CIFAR-10.

    This version does not perform token pruning – it operates on all tokens.
    """
    def __init__(self):
        super().__init__()
        # Patch embedding: 3xHxW -> DIM x (H/PATCH) x (W/PATCH)
        self.patch_embed = nn.Conv2d(3, DIM, PATCH, PATCH)
        N = (IMG // PATCH) ** 2  # number of patches (tokens) per image

        # CLS token and positional embeddings (CLS + N patches)
        self.cls = nn.Parameter(torch.zeros(1, 1, DIM))
        self.pos = nn.Parameter(torch.zeros(1, 1 + N, DIM))

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            Block(DIM, HEADS) for _ in range(DEPTH)
        ])

        # Final layer norm and classification head
        self.norm = nn.LayerNorm(DIM)
        self.head = nn.Linear(DIM, NUM_CLASSES)

    def forward(self, x):
        B = x.size(0)

        # Patch embedding: [B, 3, H, W] -> [B, N, DIM]
        x = self.patch_embed(x).flatten(2).transpose(1, 2)

        # Prepend CLS token
        x = torch.cat([self.cls.expand(B, -1, -1), x], dim=1)

        # Add positional embeddings (truncate if tokens were fewer)
        x = x + self.pos[:, :x.size(1), :]

        # Pass through all transformer blocks
        for blk in self.blocks:
            x = blk(x)

        # CLS-based classification
        x = self.norm(x)
        return self.head(x[:, 0])


In [6]:
class SimplePrunedViT(nn.Module):
    """Wrapper around a ViT that performs simple, training-time token pruning.

    At each layer after warmup, we:
      1. Compute attention on the full set of tokens.
      2. Score tokens by how much other tokens attends to them based on the attention weights.
      3. Keep only the top-K tokens (plus CLS) and drop the rest.
      4. Run attention + MLP on the reduced token set.
    """
    def __init__(self, vit_model):
        super().__init__()
        self.m = vit_model
        self.L = len(vit_model.blocks)  # number of layers

    def forward(self, x, epoch=None):
        B = x.size(0)

        # --- Patch embedding and positional encodings ---
        # [B, 3, H, W] -> [B, N, DIM]
        x = self.m.patch_embed(x).flatten(2).transpose(1, 2)
        cls = self.m.cls.expand(B, -1, -1)

        # Concatenate CLS token and add positional encodings
        x = torch.cat([cls, x], dim=1)
        x = x + self.m.pos[:, :x.size(1), :]

        # --- Layer-wise pruning ---
        for l, blk in enumerate(self.m.blocks):
            # No pruning during warmup or when epoch is not provided
            if epoch is None or epoch <= WARMUP_EPOCHS:
                x = blk(x)
                continue

            # 1. Compute attention on current tokens (without changing x yet)
            x_norm = blk.norm1(x)
            _, attn = blk.attn(x_norm)  # attn: [B, H, N, N]

            # Number of non-CLS tokens currently present
            N = x.size(1) - 1
            if N <= 1:
                # Not enough tokens to prune; fall back to standard block
                x = blk(x)
                continue

            # 2. Compute token importance score from CLS attention
            #    - mean over heads, pick CLS row, drop CLS column, then mean over batch
            #    Result shape: [N]
            score = attn.mean(1)[:, 0, 1:1 + N].mean(0)

            # 3. Decide how many tokens to keep
            #    Drop ratio increases with layer depth up to R_MAX.
            drop = R_MAX * ((l + 1) / self.L) ** ALPHA
            keep = int(N * (1 - drop))

            # Ensure we keep at least 1 token and at most N
            keep = max(1, min(N, keep))
            # Additionally enforce global minimum token count
            keep = max(MIN_TOKENS, keep)

            # 4. Select top-K tokens according to scores (excluding CLS)
            _, idx = torch.topk(score, keep)  # indices in [0, N-1] for non-CLS tokens
            idx = idx.sort().values + 1       # shift by +1 to account for CLS at index 0

            # Build final list of token indices to keep: [0] (CLS) + selected tokens
            keep_idx = torch.cat([
                torch.tensor([0], device=x.device, dtype=torch.long),
                idx
            ])

            # Subsample tokens: [B, N+1, C] -> [B, 1+keep, C]
            x = x[:, keep_idx]

            # 5. Apply attention + MLP on the reduced sequence
            x = x + blk.attn(blk.norm1(x))[0]
            x = x + blk.mlp(blk.norm2(x))

        # Final normalization and classification head on CLS token
        x = self.m.norm(x)
        return self.m.head(x[:, 0])


In [7]:
# -----------------------------
# CIFAR-10 data loading
# -----------------------------

# Use local path if available; otherwise fall back to Kaggle working dir
data_root = "data/cifar-10" if os.path.exists("data/cifar-10") else "/kaggle/working"

# Basic data augmentation for training, simple resize for test
train_tf = transforms.Compose([
    transforms.Resize(48),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
test_tf = transforms.Compose([
    transforms.Resize(48),
    transforms.ToTensor()
])

# Download + create datasets
train_set = datasets.CIFAR10(data_root, train=True,  download=True, transform=train_tf)
test_set  = datasets.CIFAR10(data_root, train=False, download=True, transform=test_tf)

# Data loaders
batch_size = 128
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

print("Train samples:", len(train_set), "Test samples:", len(test_set))


100%|██████████| 170M/170M [00:02<00:00, 77.5MB/s]


Train samples: 50000 Test samples: 10000


In [8]:
# -----------------------------
# Model initialization and optimizer
# -----------------------------

# Base ViT and pruned wrapper
base = ViT().to(device)
model = SimplePrunedViT(base)

# Optionally use DataParallel if multiple GPUs are available
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

model = model.to(device)

# AdamW optimizer is standard for transformer models
opt = torch.optim.AdamW(model.parameters(), lr=LR)


In [9]:
import time

def format_time(t):
    """Pretty-print a duration in seconds as either seconds or minutes."""
    return f"{t/60:.2f} min" if t > 60 else f"{t:.1f} sec"


@torch.no_grad()
def evaluate(epoch=None):
    """Evaluate model accuracy on the test set.

    If `epoch` is provided, the model will use the corresponding pruning
    behavior (warmup vs pruned) during evaluation as well.
    """
    model.eval()
    correct, total = 0, 0
    for x, y in tqdm(test_loader, desc="Eval", leave=False):
        x, y = x.to(device), y.to(device)
        pred = model(x, epoch=epoch).argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return correct / total


# -----------------------------
# Training loop
# -----------------------------
total_train_time = 0

for epoch in range(1, EPOCHS + 1):
    model.train()
    epoch_start = time.time()
    running_loss = 0.0
    seen = 0

    # TQDM progress bar for training
    pbar = tqdm(train_loader, desc=f"[PRUNED] Epoch {epoch}/{EPOCHS}")

    for x, y in pbar:
        x, y = x.to(device), y.to(device)

        opt.zero_grad()
        # Pass epoch to control pruning schedule inside the model
        logits = model(x, epoch=epoch)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        opt.step()

        # Track average loss
        running_loss += loss.item() * x.size(0)
        seen += x.size(0)
        pbar.set_postfix(loss=running_loss / seen)

    epoch_time = time.time() - epoch_start
    total_train_time += epoch_time

    # --- Validation at the end of each epoch ---
    val_acc = evaluate(epoch)

    print(
        f"\nEpoch {epoch} summary:"
        f" time={format_time(epoch_time)}"
        f"  train_loss={running_loss/seen:.4f}"
        f"  val_acc={val_acc*100:.2f}%\n"
    )

print(f"Total training time: {format_time(total_train_time)}")


[PRUNED] Epoch 1/10: 100%|██████████| 391/391 [05:05<00:00,  1.28it/s, loss=1.89]
                                                     


Epoch 1 summary: time=5.09 min  train_loss=1.8912  val_acc=44.23%



[PRUNED] Epoch 2/10: 100%|██████████| 391/391 [05:18<00:00,  1.23it/s, loss=1.29]
                                                     


Epoch 2 summary: time=5.31 min  train_loss=1.2913  val_acc=57.85%



[PRUNED] Epoch 3/10: 100%|██████████| 391/391 [05:19<00:00,  1.22it/s, loss=1.09]
                                                     


Epoch 3 summary: time=5.33 min  train_loss=1.0899  val_acc=61.35%



[PRUNED] Epoch 4/10: 100%|██████████| 391/391 [05:20<00:00,  1.22it/s, loss=0.97]
                                                     


Epoch 4 summary: time=5.34 min  train_loss=0.9702  val_acc=62.83%



[PRUNED] Epoch 5/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.892]
                                                     


Epoch 5 summary: time=3.72 min  train_loss=0.8920  val_acc=66.41%



[PRUNED] Epoch 6/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.83]
                                                     


Epoch 6 summary: time=3.72 min  train_loss=0.8296  val_acc=68.43%



[PRUNED] Epoch 7/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.771]
                                                     


Epoch 7 summary: time=3.73 min  train_loss=0.7708  val_acc=68.97%



[PRUNED] Epoch 8/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.727]
                                                     


Epoch 8 summary: time=3.73 min  train_loss=0.7269  val_acc=70.12%



[PRUNED] Epoch 9/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.681]
                                                     


Epoch 9 summary: time=3.73 min  train_loss=0.6814  val_acc=70.26%



[PRUNED] Epoch 10/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.641]
                                                     


Epoch 10 summary: time=3.73 min  train_loss=0.6412  val_acc=69.97%

Total training time: 43.43 min


