In [1]:
!pip install -q einops

In [8]:
!pip install timm



In [2]:
import os
import math
import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets

from einops import rearrange

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)


Device: cuda


In [3]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0, "Image size must be divisible by patch size"
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerEncoderBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0., attn_dropout=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads, attn_dropout, proj_dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, int(dim * mlp_ratio), dropout)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


In [4]:
class ViT(nn.Module):
    def __init__(self, *, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=256, depth=6, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        cls = x[:, 0]
        out = self.head(cls)
        return out


In [9]:
from torchvision.transforms import RandAugment

mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)

def get_dataloaders(batch_size=128, num_workers=2):
    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=2, magnitude=9),   # added strong augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_transforms
    )
    test_ds = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=test_transforms
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    return train_loader, test_loader

train_loader, test_loader = get_dataloaders(batch_size=128)


In [10]:
from torch.cuda.amp import autocast, GradScaler

def evaluate(model, dataloader, device):
    model.eval()
    total = 0
    correct = 0
    loss_meter = 0.0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            loss = criterion(out, y)
            loss_meter += float(loss) * x.size(0)
            preds = out.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return loss_meter / total, 100.0 * correct / total

def train_one_epoch(model, optimizer, loader, device, scaler, epoch, print_every=100):
    model.train()
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    running_loss = 0.0
    pbar = tqdm(enumerate(loader), total=len(loader))
    for i, (x, y) in pbar:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        with autocast():
            out = model(x)
            loss = criterion(out, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += float(loss.item()) * x.size(0)
        if (i + 1) % print_every == 0:
            pbar.set_description(f"Epoch {epoch} iter {i+1} loss {running_loss/((i+1)*x.size(0)):.4f}")
    return running_loss / len(loader.dataset)


In [12]:
import math
from timm.scheduler.cosine_lr import CosineLRScheduler
img_size = 32
patch_size = 4
embed_dim = 256
depth = 8
num_heads = 8
mlp_ratio = 4.0
dropout = 0.1

epochs = 65
batch_size = 128
lr = 3e-4
weight_decay = 0.05
model = ViT(
    img_size=img_size,
    patch_size=patch_size,
    embed_dim=embed_dim,
    depth=depth,
    num_heads=num_heads,
    mlp_ratio=mlp_ratio,
    dropout=dropout
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * epochs

scheduler = CosineLRScheduler(
    optimizer,
    t_initial=total_steps,
    lr_min=1e-6,
    warmup_t=steps_per_epoch * 5,
    warmup_lr_init=1e-6,
    t_in_epochs=False,
)
scaler = GradScaler()
best_acc = 0.0
global_step = 0
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

for epoch in range(1, epochs + 1):
    model.train()
    running_loss = 0.0

    for i, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        with autocast():
            out = model(x)
            loss = criterion(out, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * x.size(0)


        scheduler.step_update(global_step)
        global_step += 1

    train_loss = running_loss / len(train_loader.dataset)
    val_loss, val_acc = evaluate(model, test_loader, device)
    print(f"Epoch {epoch} TrainLoss {train_loss:.4f} ValLoss {val_loss:.4f} ValAcc {val_acc:.2f}%")

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_vit_cifar10.pth")

print("Best Val Acc: {:.2f}%".format(best_acc))


  scaler = GradScaler()
  with autocast():


Epoch 1 TrainLoss 2.1444 ValLoss 1.8939 ValAcc 28.72%
Epoch 2 TrainLoss 1.9777 ValLoss 1.8098 ValAcc 34.27%
Epoch 3 TrainLoss 1.8091 ValLoss 1.4925 ValAcc 45.97%
Epoch 4 TrainLoss 1.6832 ValLoss 1.3190 ValAcc 52.51%
Epoch 5 TrainLoss 1.6087 ValLoss 1.2730 ValAcc 53.92%
Epoch 6 TrainLoss 1.5529 ValLoss 1.1802 ValAcc 58.77%
Epoch 7 TrainLoss 1.5075 ValLoss 1.1403 ValAcc 59.83%
Epoch 8 TrainLoss 1.4704 ValLoss 1.1355 ValAcc 59.57%
Epoch 9 TrainLoss 1.4408 ValLoss 1.0352 ValAcc 63.48%
Epoch 10 TrainLoss 1.4118 ValLoss 1.0197 ValAcc 64.29%
Epoch 11 TrainLoss 1.3826 ValLoss 0.9911 ValAcc 65.86%
Epoch 12 TrainLoss 1.3577 ValLoss 1.0028 ValAcc 65.40%
Epoch 13 TrainLoss 1.3370 ValLoss 0.9262 ValAcc 67.91%
Epoch 14 TrainLoss 1.3107 ValLoss 0.9154 ValAcc 68.48%
Epoch 15 TrainLoss 1.2856 ValLoss 0.8557 ValAcc 70.95%
Epoch 16 TrainLoss 1.2658 ValLoss 0.8575 ValAcc 70.45%
Epoch 17 TrainLoss 1.2491 ValLoss 0.8359 ValAcc 71.66%
Epoch 18 TrainLoss 1.2274 ValLoss 0.8117 ValAcc 72.64%
Epoch 19 TrainLoss 