In [1]:
import os, time, math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
print("GPUs:", torch.cuda.device_count())

GPUs: 2


In [2]:
IMG = 48
PATCH = 4
DIM = 512
DEPTH = 8
HEADS = 8
NUM_CLASSES = 10

EPOCHS = 10
WARMUP_EPOCHS = 4

R_MAX = 0.6
ALPHA = 2.0
MIN_TOKENS = 8

LR = 3e-4
BS = 128

In [3]:
class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim*3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        H = self.heads
        qkv = self.qkv(x).reshape(B, N, 3, H, C//H)
        q, k, v = qkv.unbind(2)
        q = q.permute(0,3,1,2)
        k = k.permute(0,3,1,2)
        v = v.permute(0,3,1,2)

        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(-1)

        out = (attn @ v).transpose(1,2).reshape(B, N, C)
        return self.proj(out), attn


In [4]:
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim*4)
        self.fc2 = nn.Linear(dim*4, dim)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

class Block(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp  = MLP(dim)

    def forward(self, x, return_attn=False):
        a, attn = self.attn(self.norm1(x))
        x = x + a
        x = x + self.mlp(self.norm2(x))
        if return_attn:
            return x, attn
        return x


In [5]:
class ViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, DIM, PATCH, PATCH)
        N = (IMG // PATCH)**2
        self.cls = nn.Parameter(torch.zeros(1,1,DIM))
        self.pos = nn.Parameter(torch.zeros(1,1+N,DIM))
        self.blocks = nn.ModuleList([Block(DIM, HEADS) for _ in range(DEPTH)])
        self.norm = nn.LayerNorm(DIM)
        self.head = nn.Linear(DIM, NUM_CLASSES)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x).flatten(2).transpose(1,2)
        x = torch.cat([self.cls.expand(B,-1,-1), x], dim=1)
        x = x + self.pos[:, :x.size(1), :]
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.head(x[:,0])


In [6]:
class SimplePrunedViT(nn.Module):
    def __init__(self, vit_model):
        super().__init__()
        self.m = vit_model
        self.L = len(vit_model.blocks)

    def forward(self, x, epoch=None):
        B = x.size(0)

        # patch embed
        x = self.m.patch_embed(x).flatten(2).transpose(1,2)
        cls = self.m.cls.expand(B,-1,-1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.m.pos[:, :x.size(1), :]

        for l, blk in enumerate(self.m.blocks):
            # no pruning in warmup
            if epoch is None or epoch <= WARMUP_EPOCHS:
                x = blk(x)
                continue

            # compute attention
            x_norm = blk.norm1(x)
            _, attn = blk.attn(x_norm)

            N = x.size(1) - 1
            if N <= 1:
                x = blk(x)
                continue

            score = attn.mean(1)[:, 0, 1:1+N].mean(0)   # [N]

            drop = R_MAX * ((l+1) / self.L)**ALPHA
            keep = int(N * (1 - drop))
            keep = max(1, min(N, keep))
            keep = max(MIN_TOKENS, keep)

            _, idx = torch.topk(score, keep)
            idx = idx.sort().values + 1

            keep_idx = torch.cat([
                torch.tensor([0], device=x.device, dtype=torch.long),
                idx
            ])

            x = x[:, keep_idx]

            x = x + blk.attn(blk.norm1(x))[0]
            x = x + blk.mlp(blk.norm2(x))

        x = self.m.norm(x)
        return self.m.head(x[:,0])


In [7]:
# CELL 4 - data loaders
data_root = "data/cifar-10" if os.path.exists("data/cifar-10") else "/kaggle/working"
train_tf = transforms.Compose([
    transforms.Resize(48),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
test_tf = transforms.Compose([transforms.Resize(48), transforms.ToTensor()])

train_set = datasets.CIFAR10(data_root, train=True, download=True , transform=train_tf)
test_set  = datasets.CIFAR10(data_root, train=False, download=True , transform=test_tf)

batch_size = 128
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

print("Train samples:", len(train_set), "Test samples:", len(test_set))


100%|██████████| 170M/170M [00:02<00:00, 77.5MB/s]


Train samples: 50000 Test samples: 10000


In [8]:
base = ViT().to(device)
model = SimplePrunedViT(base)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

model = model.to(device)

opt = torch.optim.AdamW(model.parameters(), lr=LR)


In [9]:
import time

def format_time(t):
    return f"{t/60:.2f} min" if t > 60 else f"{t:.1f} sec"


@torch.no_grad()
def evaluate(epoch=None):
    model.eval()
    c, t = 0, 0
    for x, y in tqdm(test_loader, desc="Eval", leave=False):
        x, y = x.to(device), y.to(device)
        pred = model(x, epoch=epoch).argmax(1)
        c += (pred == y).sum().item()
        t += y.size(0)
    return c / t


total_train_time = 0

for epoch in range(1, EPOCHS + 1):
    model.train()
    epoch_start = time.time()
    running_loss = 0
    seen = 0

    pbar = tqdm(train_loader, desc=f"[PRUNED] Epoch {epoch}/{EPOCHS}")

    for x, y in pbar:
        x, y = x.to(device), y.to(device)

        opt.zero_grad()
        logits = model(x, epoch=epoch)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        opt.step()

        running_loss += loss.item() * x.size(0)
        seen += x.size(0)

        pbar.set_postfix(loss=running_loss / seen)

    epoch_time = time.time() - epoch_start
    total_train_time += epoch_time

    # --- Validation ---
    val_acc = evaluate(epoch)

    print(
        f"\nEpoch {epoch} summary:"
        f" time={format_time(epoch_time)}"
        f"  train_loss={running_loss/seen:.4f}"
        f"  val_acc={val_acc*100:.2f}%\n"
    )

print(f"Total training time: {format_time(total_train_time)}")


[PRUNED] Epoch 1/10: 100%|██████████| 391/391 [05:05<00:00,  1.28it/s, loss=1.89]
                                                     


Epoch 1 summary: time=5.09 min  train_loss=1.8912  val_acc=44.23%



[PRUNED] Epoch 2/10: 100%|██████████| 391/391 [05:18<00:00,  1.23it/s, loss=1.29]
                                                     


Epoch 2 summary: time=5.31 min  train_loss=1.2913  val_acc=57.85%



[PRUNED] Epoch 3/10: 100%|██████████| 391/391 [05:19<00:00,  1.22it/s, loss=1.09]
                                                     


Epoch 3 summary: time=5.33 min  train_loss=1.0899  val_acc=61.35%



[PRUNED] Epoch 4/10: 100%|██████████| 391/391 [05:20<00:00,  1.22it/s, loss=0.97]
                                                     


Epoch 4 summary: time=5.34 min  train_loss=0.9702  val_acc=62.83%



[PRUNED] Epoch 5/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.892]
                                                     


Epoch 5 summary: time=3.72 min  train_loss=0.8920  val_acc=66.41%



[PRUNED] Epoch 6/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.83]
                                                     


Epoch 6 summary: time=3.72 min  train_loss=0.8296  val_acc=68.43%



[PRUNED] Epoch 7/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.771]
                                                     


Epoch 7 summary: time=3.73 min  train_loss=0.7708  val_acc=68.97%



[PRUNED] Epoch 8/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.727]
                                                     


Epoch 8 summary: time=3.73 min  train_loss=0.7269  val_acc=70.12%



[PRUNED] Epoch 9/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.681]
                                                     


Epoch 9 summary: time=3.73 min  train_loss=0.6814  val_acc=70.26%



[PRUNED] Epoch 10/10: 100%|██████████| 391/391 [03:43<00:00,  1.75it/s, loss=0.641]
                                                     


Epoch 10 summary: time=3.73 min  train_loss=0.6412  val_acc=69.97%

Total training time: 43.43 min


