In [1]:
import time
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torchvision.datasets as datasets

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [2]:
IMG = 48
PATCH = 4
DIM = 512
DEPTH = 8
HEADS = 8
NUM_CLASSES = 10

EPOCHS = 10

LR = 3e-4
BS = 128

In [3]:
class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim*3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        H = self.heads
        qkv = self.qkv(x).reshape(B, N, 3, H, C//H)
        q, k, v = qkv.unbind(2)
        q = q.permute(0,3,1,2)
        k = k.permute(0,3,1,2)
        v = v.permute(0,3,1,2)

        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(-1)

        out = (attn @ v).transpose(1,2).reshape(B, N, C)
        return self.proj(out), attn
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim*4)
        self.fc2 = nn.Linear(dim*4, dim)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

class Block(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp  = MLP(dim)

    def forward(self, x, return_attn=False):
        a, attn = self.attn(self.norm1(x))
        x = x + a
        x = x + self.mlp(self.norm2(x))
        if return_attn:
            return x, attn
        return x
class ViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, DIM, PATCH, PATCH)
        N = (IMG // PATCH)**2
        self.cls = nn.Parameter(torch.zeros(1,1,DIM))
        self.pos = nn.Parameter(torch.zeros(1,1+N,DIM))
        self.blocks = nn.ModuleList([Block(DIM, HEADS) for _ in range(DEPTH)])
        self.norm = nn.LayerNorm(DIM)
        self.head = nn.Linear(DIM, NUM_CLASSES)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x).flatten(2).transpose(1,2)
        x = torch.cat([self.cls.expand(B,-1,-1), x], dim=1)
        x = x + self.pos[:, :x.size(1), :]
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.head(x[:,0])

In [4]:
train_tf = T.Compose([
    T.RandomResizedCrop(48, scale=(0.8,1.0)),
    T.RandomHorizontalFlip(),
    T.ToTensor()
])
test_tf = T.Compose([T.Resize(48), T.ToTensor()])

train_set = datasets.CIFAR10(
    root="./data", 
    train=True, 
    download=True, 
    transform=train_tf
)

test_set = datasets.CIFAR10(
    root="./data", 
    train=False, 
    download=True, 
    transform=test_tf
)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader  = DataLoader(test_set,  batch_size=128, shuffle=False, num_workers=2)

print(len(train_set), len(test_set))


100%|██████████| 170M/170M [00:01<00:00, 86.4MB/s]


50000 10000


In [5]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# ---------------------------------------
#  BASELINE MODEL + OPTIMIZER
# ---------------------------------------
model = ViT().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)


# ---------------------------------------
#  TRAIN ONE EPOCH  (with tqdm)
# ---------------------------------------
def train_one_epoch_baseline(model, loader):
    model.train()
    total_loss = 0
    total_samples = 0

    pbar = tqdm(loader, desc="[BASELINE] Training", leave=False)
    for x, y in pbar:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        total_samples += x.size(0)

        pbar.set_postfix(loss=loss.item())

    return total_loss / total_samples


# ---------------------------------------
#  TEST ACCURACY
# ---------------------------------------
@torch.no_grad()
def eval_acc_baseline(model, loader):
    model.eval()
    correct = 0
    total = 0

    for x, y in tqdm(loader, desc="[BASELINE] Eval", leave=False):
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(1)
        correct += (preds == y).sum().item()
        total += y.size(0)

    return correct / total


# ---------------------------------------
#  FULL TRAINING LOOP (MATCHES PRUNED FORMAT)
# ---------------------------------------
print("\n================ BASELINE TRAINING ================\n")

overall_start = time.time()

for epoch in range(1, 11):
    epoch_start = time.time()

    train_loss = train_one_epoch_baseline(model, train_loader)
    val_acc = eval_acc_baseline(model, test_loader)

    epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch}/10 summary: time={epoch_time/60:.2f} min  train_loss={train_loss:.4f}  val_acc={val_acc*100:.2f}%\n")

overall_time = time.time() - overall_start
print(f"Total training time: {overall_time/60:.2f} min")

torch.save(model.state_dict(), "vit_baseline.pth")
print("Saved baseline model → vit_baseline.pth")






                                                                

Epoch 1/10 summary: time=5.64 min  train_loss=1.8552  val_acc=46.89%



                                                                

Epoch 2/10 summary: time=5.95 min  train_loss=1.3116  val_acc=57.61%



                                                                

Epoch 3/10 summary: time=5.96 min  train_loss=1.1278  val_acc=62.50%



                                                                

Epoch 4/10 summary: time=5.97 min  train_loss=1.0178  val_acc=63.55%



                                                                

Epoch 5/10 summary: time=5.98 min  train_loss=0.9595  val_acc=66.44%



                                                                

Epoch 6/10 summary: time=5.98 min  train_loss=0.8908  val_acc=68.75%



                                                                

Epoch 7/10 summary: time=5.99 min  train_loss=0.8536  val_acc=67.67%



                                                                

Epoch 8/10 summary: time=5.99 min  train_loss=0.8144  val_acc=68.49%



                                                                

Epoch 9/10 summary: time=5.99 min  train_loss=0.7761  val_acc=71.00%



                                                                

Epoch 10/10 summary: time=5.99 min  train_loss=0.7477  val_acc=71.68%

Total training time: 59.45 min
Saved baseline model → vit_baseline.pth
