In [None]:
# ViT CIFAR-10 (PyTorch) - Colab-ready
# Requirements: torch, torchvision (Colab preinstalls these). Optionally tqdm.
# Run in Colab with GPU runtime selected.

import math
import os
import random
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# -------------------------
# Config / Hyperparameters
# -------------------------
class Config:
    seed = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = "./data"
    num_workers = 4
    batch_size = 256                # adjust if GPU memory small
    epochs = 120
    image_size = 32                 # CIFAR-10 images are 32x32
    patch_size = 4                  # for CIFAR-10: 4 -> 8x8 patches (64 tokens). 16 is too big for 32x32.
    in_channels = 3
    embed_dim = 256                 # model width
    depth = 12                      # number of transformer encoder blocks
    num_heads = 8
    mlp_ratio = 4.0
    dropout = 0.1
    attn_dropout = 0.0
    weight_decay = 0.05
    lr = 3e-4
    betas = (0.9, 0.999)
    warmup_epochs = 10
    use_amp = True                  # mixed precision
    save_dir = "./checkpoints"
    print_freq = 100

cfg = Config()

# -------------------------
# Reproducibility
# -------------------------
def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(cfg.seed)
os.makedirs(cfg.save_dir, exist_ok=True)

# -------------------------
# Patch Embedding
# -------------------------
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0, "Image size must be divisible by patch size."
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size * self.grid_size
        # Implement patchify with Conv2d: kernel & stride = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # conv output shape: (B, embed_dim, grid, grid) -> flatten to (B, num_patches, embed_dim)

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)  # (B, embed_dim, G, G)
        B, E, G, G2 = x.shape
        assert G == G2
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x

# -------------------------
# Transformer Blocks
# -------------------------
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, dropout=0.0):
        super().__init__()
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0, attn_dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=attn_dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(embed_dim, eps=1e-6)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = MLP(embed_dim, hidden_features=hidden_dim, dropout=dropout)

    def forward(self, x):
        # x: (B, N, E)
        y = self.norm1(x)
        y, _ = self.attn(y, y, y, need_weights=False)
        x = x + self.dropout1(y)
        x = x + self.mlp(self.norm2(x))
        return x

# -------------------------
# Vision Transformer
# -------------------------
class ViT(nn.Module):
    def __init__(self,
                 img_size=32,
                 patch_size=4,
                 in_chans=3,
                 num_classes=10,
                 embed_dim=256,
                 depth=12,
                 num_heads=8,
                 mlp_ratio=4.0,
                 dropout=0.1,
                 attn_dropout=0.0):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Learnable pos embedding for (num_patches + 1)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        self.pos_drop = nn.Dropout(p=dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout, attn_dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialization
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # initialize head
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)

    def forward(self, x):
        # x: (B, C, H, W)
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, N, E)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, E)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N+1, E)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls_out = x[:, 0]  # (B, E)
        logits = self.head(cls_out)
        return logits

# -------------------------
# Data: CIFAR-10 + Transforms
# -------------------------
def get_dataloaders(batch_size, num_workers=4):
    # Data augmentation tuned for CIFAR-10 training a ViT from scratch
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                             std=(0.2470, 0.2435, 0.2616)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                             std=(0.2470, 0.2435, 0.2616)),
    ])
    train_dataset = datasets.CIFAR10(cfg.data_dir, train=True, download=True, transform=train_transform)
    test_dataset = datasets.CIFAR10(cfg.data_dir, train=False, download=True, transform=test_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader

# -------------------------
# Utilities
# -------------------------
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append((correct_k.mul_(100.0 / batch_size)).item())
        return res

# -------------------------
# LR schedule: Cosine + linear warmup
# -------------------------
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0.0):
    warmup_schedule = []
    if warmup_epochs > 0:
        warmup_iters = warmup_epochs * niter_per_ep
        warmup_schedule = list(torch.linspace(start_warmup_value, base_value, warmup_iters))
    iters = epochs * niter_per_ep - len(warmup_schedule)
    if iters > 0:
        schedule = final_value + 0.5 * (base_value - final_value) * (1 + torch.cos(torch.linspace(0, math.pi, iters)))
        schedule = schedule.tolist()
    else:
        schedule = []
    return warmup_schedule + schedule

# -------------------------
# Train / Eval loops
# -------------------------
def train_one_epoch(model, optimizer, data_loader, device, epoch, scaler=None, scheduler=None):
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    pbar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Train Epoch {epoch}")
    for i, (images, targets) in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = F.cross_entropy(outputs, targets)
        else:
            outputs = model(images)
            loss = F.cross_entropy(outputs, targets)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        if scheduler is not None:
            scheduler()

        acc1 = accuracy(outputs, targets, topk=(1,))[0]
        running_loss = (running_loss * i + loss.item()) / (i + 1)
        running_acc = (running_acc * i + acc1) / (i + 1)
        if (i + 1) % cfg.print_freq == 0 or i == len(data_loader) - 1:
            pbar.set_postfix(loss=running_loss, acc=running_acc)
    return running_loss, running_acc

def evaluate(model, data_loader, device):
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    total = 0
    with torch.no_grad():
        for i, (images, targets) in enumerate(tqdm(data_loader, desc="Eval")):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = model(images)
            loss = F.cross_entropy(outputs, targets)
            acc1 = accuracy(outputs, targets, topk=(1,))[0]
            batch_size = images.size(0)
            total_loss = (total_loss * i + loss.item()) / (i + 1)
            total_acc = (total_acc * i + acc1) / (i + 1)
    return total_loss, total_acc

# -------------------------
# Main training entrypoint
# -------------------------
def main():
    device = torch.device(cfg.device)
    print("Device:", device)

    train_loader, test_loader = get_dataloaders(cfg.batch_size, cfg.num_workers)

    model = ViT(img_size=cfg.image_size,
                patch_size=cfg.patch_size,
                in_chans=cfg.in_channels,
                num_classes=10,
                embed_dim=cfg.embed_dim,
                depth=cfg.depth,
                num_heads=cfg.num_heads,
                mlp_ratio=cfg.mlp_ratio,
                dropout=cfg.dropout,
                attn_dropout=cfg.attn_dropout).to(device)

    # optimizer + scheduler (AdamW recommended for ViT)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=cfg.betas, weight_decay=cfg.weight_decay)

    niter_per_epoch = len(train_loader)
    total_iters = cfg.epochs * niter_per_epoch
    warmup_iters = cfg.warmup_epochs * niter_per_epoch
    # Create per-step LR schedule
    lr_schedule = cosine_scheduler(cfg.lr, 1e-5, cfg.epochs, niter_per_epoch, warmup_epochs=cfg.warmup_epochs, start_warmup_value=1e-6)
    # PyTorch requires scheduler.step() every optimizer step; we'll implement a lambda-based scheduler:
    lr_iter = iter(lr_schedule)

    def _scheduler_step():
        try:
            lr = next(lr_iter)
            for pg in optimizer.param_groups:
                pg['lr'] = lr
        except StopIteration:
            pass

    scheduler = _scheduler_step

    scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp and device.type == 'cuda')

    best_acc = 0.0
    for epoch in range(1, cfg.epochs + 1):
        train_loss, train_acc = train_one_epoch(model, optimizer, train_loader, device, epoch, scaler=scaler, scheduler=scheduler)
        val_loss, val_acc = evaluate(model, test_loader, device)
        print(f"Epoch {epoch}: Train loss {train_loss:.4f}, Train acc {train_acc:.2f}%, Val loss {val_loss:.4f}, Val acc {val_acc:.2f}%")

        # Save best
        if val_acc > best_acc:
            best_acc = val_acc
            ckpt_path = os.path.join(cfg.save_dir, "best_vit_cifar10.pth")
            torch.save({"epoch": epoch, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "val_acc": val_acc}, ckpt_path)
            print(f"Saved best model (val acc {best_acc:.2f}%) to {ckpt_path}")

    print(f"Training finished. Best Val Acc: {best_acc:.2f}%")

if __name__ == "__main__":
    main()


Device: cpu


  scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp and device.type == 'cuda')
  with torch.cuda.amp.autocast():
Train Epoch 1:   5%|▍         | 9/196 [04:37<1:32:41, 29.74s/it]

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

# ------------------------
# 1. Device setup
# ------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ------------------------
# 2. Data preprocessing
# ------------------------
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True,
                                             transform=transform_train, download=True)
test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False,
                                            transform=transform_test, download=True)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                          num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False,
                         num_workers=2, pin_memory=True)

# ------------------------
# 3. Simple ViT (lite version for Colab)
# ------------------------
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                # [B, embed_dim, H/patch, W/patch]
        x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=4, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio))
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3,
                 num_classes=10, embed_dim=256, depth=6, num_heads=4):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return self.head(x[:, 0])

# ------------------------
# 4. Training setup
# ------------------------
model = VisionTransformer().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

scaler = torch.amp.GradScaler("cuda", enabled=(device.type=="cuda"))

# ------------------------
# 5. Training loop
# ------------------------
def train_one_epoch(model, loader, optimizer, criterion, scaler):
    model.train()
    total_loss, total_correct = 0, 0
    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast("cuda", enabled=(device.type=="cuda")):
            outputs = model(images)
            loss = criterion(outputs, labels)
        if device.type == "cuda":
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        total_loss += loss.item() * images.size(0)
        total_correct += outputs.argmax(1).eq(labels).sum().item()
    return total_loss / len(loader.dataset), total_correct / len(loader.dataset)

def evaluate(model, loader, criterion):
    model.eval()
    total_loss, total_correct = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            total_correct += outputs.argmax(1).eq(labels).sum().item()
    return total_loss / len(loader.dataset), total_correct / len(loader.dataset)

# ------------------------
# 6. Run training
# ------------------------
epochs = 10
for epoch in range(1, epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, scaler)
    val_loss, val_acc = evaluate(model, test_loader, criterion)
    print(f"Epoch {epoch}: Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}")


Using device: cuda


100%|██████████| 170M/170M [01:15<00:00, 2.26MB/s]
Training: 100%|██████████| 391/391 [00:26<00:00, 14.61it/s]
Evaluating: 100%|██████████| 79/79 [00:02<00:00, 27.16it/s]


Epoch 1: Train Acc=0.3345, Val Acc=0.4334


Training: 100%|██████████| 391/391 [00:24<00:00, 15.73it/s]
Evaluating: 100%|██████████| 79/79 [00:02<00:00, 27.41it/s]


Epoch 2: Train Acc=0.4650, Val Acc=0.5221


Training: 100%|██████████| 391/391 [00:24<00:00, 15.66it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 21.38it/s]


Epoch 3: Train Acc=0.5199, Val Acc=0.5407


Training: 100%|██████████| 391/391 [00:24<00:00, 15.65it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 25.03it/s]


Epoch 4: Train Acc=0.5463, Val Acc=0.5750


Training: 100%|██████████| 391/391 [00:24<00:00, 15.79it/s]
Evaluating: 100%|██████████| 79/79 [00:02<00:00, 27.35it/s]


Epoch 5: Train Acc=0.5687, Val Acc=0.5809


Training: 100%|██████████| 391/391 [00:25<00:00, 15.47it/s]
Evaluating: 100%|██████████| 79/79 [00:05<00:00, 15.56it/s]


Epoch 6: Train Acc=0.5863, Val Acc=0.6066


Training: 100%|██████████| 391/391 [00:38<00:00, 10.22it/s]
Evaluating: 100%|██████████| 79/79 [00:05<00:00, 15.80it/s]


Epoch 7: Train Acc=0.6015, Val Acc=0.5995


Training: 100%|██████████| 391/391 [00:33<00:00, 11.74it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 25.22it/s]


Epoch 8: Train Acc=0.6134, Val Acc=0.6083


Training: 100%|██████████| 391/391 [00:24<00:00, 15.89it/s]
Evaluating: 100%|██████████| 79/79 [00:03<00:00, 20.67it/s]


Epoch 9: Train Acc=0.6275, Val Acc=0.6381


Training: 100%|██████████| 391/391 [00:28<00:00, 13.85it/s]
Evaluating: 100%|██████████| 79/79 [00:07<00:00, 11.25it/s]

Epoch 10: Train Acc=0.6400, Val Acc=0.6463



