In [71]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import time
from collections import Counter

In [72]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device Set is {device}")

Device Set is cuda


In [73]:
DATA_DIR = './data'
MODEL_PATH = './simple_cifar10.pth'
BATCH_SIZE = 32
EPOCH = 10
LR = 1e-3
NUM_CLASSES = 10
MODEL_PATH = '/mnt/tmp/simple_cfar10.pth'
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
MEAN = (0.4194,0.4852,0.4465)
STD = (0.2470,0.2435,0.2616)

In [74]:
raw = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=T.ToTensor())
test_raw = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=T.ToTensor())

img, lbl = raw[0]
print(f"Train samples:{len(raw)}")
print(f"Test Samples :{len(test_raw)}")
print(f"Image Shape : {img.shape}")

Train samples:50000
Test Samples :10000
Image Shape : torch.Size([3, 32, 32])


In [75]:
counts = Counter(raw.targets)
for k,v in counts.items():
    print(f"{CLASSES[k]} : {v}")

frog : 5000
truck : 5000
deer : 5000
automobile : 5000
bird : 5000
horse : 5000
ship : 5000
cat : 5000
dog : 5000
airplane : 5000


In [76]:
train_tfm = T.Compose([
      T.RandomCrop(32, padding=4, padding_mode="reflect"),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    T.RandomRotation(10),
    T.ToTensor(),
    T.Normalize(MEAN, STD),
    T.RandomErasing(p=0.15, scale=(0.02, 0.2)),
])

val_tfm = T.Compose([
        T.ToTensor(),
    T.Normalize(MEAN, STD),
])

train_ds = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_tfm)
val_ds = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=val_tfm)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

imgs, lbls = next(iter(train_loader))

print(f"Train Batches : {len(train_loader)}")
print(f"Val Batches : {len(val_loader)}")
print(f"Train Batch Shape : {tuple(imgs.shape)}")


for t in train_tfm.transforms:
  print(f"  {t}")

Train Batches : 1563
Val Batches : 313
Train Batch Shape : (32, 3, 32, 32)
  RandomCrop(size=(32, 32), padding=4)
  RandomHorizontalFlip(p=0.5)
  ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=(-0.05, 0.05))
  RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
  ToTensor()
  Normalize(mean=(0.4194, 0.4852, 0.4465), std=(0.247, 0.2435, 0.2616))
  RandomErasing(p=0.15, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0, inplace=False)


In [77]:
# Simple CNN Architecure

In [78]:
class SimpleCNN(nn.Module):
    """
    Custom CNN for CIFAR-10  (32×32 RGB → 10 classes)
    ──────────────────────────────────────────────────
    Input → Block1 → Block2 → Block3 → Classifier

    Each block:
        Conv2d → BatchNorm → ReLU → Conv2d → BatchNorm → ReLU → MaxPool → Dropout

    Classifier:
        Flatten → FC(512) → BN → ReLU → Dropout → FC(10)
    ──────────────────────────────────────────────────
    """
    def __init__(self, num_classes: int = 10):
        super().__init__()

        # ── Block 1:  3 → 32 channels  |  32×32 → 16×16 ──────────
        self.block1 = nn.Sequential(
            nn.Conv2d(3,  32, kernel_size=3, padding=1, bias=False),  # 32×32
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),  # 32×32
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),                                        # 16×16
            nn.Dropout2d(p=0.2),
        )

        # ── Block 2: 32 → 64 channels  |  16×16 → 8×8 ────────────
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),  # 16×16
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),  # 16×16
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),                                        # 8×8
            nn.Dropout2d(p=0.3),
        )

        # ── Block 3: 64 → 128 channels  |  8×8 → 4×4 ─────────────
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False), # 8×8
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),# 8×8
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),                                        # 4×4
            nn.Dropout2d(p=0.4),
        )

        # ── Classifier head ───────────────────────────────────────
        # After block3: (B, 128, 4, 4) → flatten → 128*4*4 = 2048
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        )

        # ── Weight init (Kaiming He) ──────────────────────────────
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias,   0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)       # (B,  3, 32, 32) → (B,  32, 16, 16)
        x = self.block2(x)       # (B, 32, 16, 16) → (B,  64,  8,  8)
        x = self.block3(x)       # (B, 64,  8,  8) → (B, 128,  4,  4)
        x = self.classifier(x)   # → (B, 10)
        return x

In [79]:
model = SimpleCNN(num_classes=NUM_CLASSES).to(device)
print(model)

SimpleCNN(
  (block1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Dropout2d(p=0.2, inplace=False)
  )
  (block2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=Tru

In [80]:
total_p = { sum(p.numel() for p in model.parameters()) }
train_p = { sum(p.numel() for p in model.parameters() if p.requires_grad) }
print(f"Total Params : {total_p}")
print(f"Trainable Params : {train_p}")

Total Params : {1342698}
Trainable Params : {1342698}


In [81]:
# Training utilities
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=LR)

In [82]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    loss_sum , correct , total = 0.0 , 0 , 0
    for img,lables in train_loader:
        img,lables = img.to(device) , lables.to(device)
        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output,lables)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        correct += (output.argmax(1) == lables).sum().item()
        total += lables.size(0)
    return loss_sum/total , correct/total

In [83]:
@torch.no_grad()
def evaluate(model, val_loader, criterion, device):
    model.eval()
    loss_sum , correct , total = 0.0 , 0 , 0
    for img,lables in val_loader:
        img,lables = img.to(device) , lables.to(device)
        output = model(img)
        loss = criterion(output,lables)
        loss_sum += loss.item()
        correct += (output.argmax(1) == lables).sum().item()
        total += lables.size(0)
    return loss_sum / len(val_loader) , correct / total


In [84]:
MODEL_PATH = './simple_cifar10.pth'

In [85]:
history = dict(train_loss=[], train_acc=[], val_loss=[], val_acc=[])
best_acc = 0.0
print(f"\n{'Epoch':>6}  {'Tr Loss':>8}  {'Tr Acc':>8}  "
      f"{'Va Loss':>8}  {'Va Acc':>8}  {'LR':>9}  {'Time':>6}")
for epoch in range(1,EPOCH + 1):
  t0 = time.time()
  tr_loss , tr_acc = train_epoch(model, train_loader, criterion, optimizer, device)
  va_loss , va_acc = evaluate(model, val_loader,criterion,device)
  history['train_loss'] = tr_loss
  history['train_acc'] = tr_acc
  history['val_loss'] = va_loss
  history['val_acc'] = va_acc

  if va_acc > best_acc:
    best_acc = va_acc
    torch.save({"Epoch": epoch, "State": model.state_dict(), "Val_acc": va_acc},MODEL_PATH )

  print(f"{epoch:>6}  {tr_loss:>8.4f}  {tr_acc:>7.2f}%  "
          f"{va_loss:>8.4f}  {va_acc:>7.2f}%  "
          f"{time.time()-t0:>5.1f}s")


 Epoch   Tr Loss    Tr Acc   Va Loss    Va Acc         LR    Time
     1    0.0602     0.31%    1.3490     0.50%   56.6s
     2    0.0479     0.44%    1.1134     0.59%   55.4s
     3    0.0425     0.51%    0.9421     0.67%   53.3s
     4    0.0393     0.55%    0.8897     0.69%   50.5s
     5    0.0370     0.58%    0.8216     0.71%   49.7s
     6    0.0354     0.60%    0.7783     0.72%   51.2s
     7    0.0341     0.61%    0.7354     0.75%   51.7s
     8    0.0330     0.63%    0.7015     0.76%   50.7s
     9    0.0322     0.64%    0.6692     0.77%   51.1s
    10    0.0313     0.65%    0.6402     0.79%   50.9s


In [87]:
"""
Simple CNN on CIFAR-10  —  Built from Scratch (No Pretrained Weights)
======================================================================
Architecture  :  Custom 3-block CNN
Dataset       :  CIFAR-10  (60 000 RGB 32x32 images, 10 classes)
Framework     :  PyTorch
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import time
from collections import Counter

# ──────────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATA_DIR    = "./data"
BATCH_SIZE  = 64
EPOCHS      = 10
LR          = 1e-3
NUM_CLASSES = 10
CKPT_PATH   = "./simple_cnn_cifar10.pth"

CLASSES = ["plane","car","bird","cat","deer",
           "dog","frog","horse","ship","truck"]

MEAN = (0.4914, 0.4822, 0.4465)
STD  = (0.2470, 0.2435, 0.2616)

sep = lambda t: print(f"\n{'='*62}\n  {t}\n{'='*62}")


# ══════════════════════════════════════════════════════════════════
# MODEL DEFINITION  (must be importable at module level on Windows)
# ══════════════════════════════════════════════════════════════════
class SimpleCNN(nn.Module):
    """
    Custom CNN for CIFAR-10  (32x32 RGB -> 10 classes)
    Input -> Block1 -> Block2 -> Block3 -> Classifier
    Each block: Conv->BN->ReLU->Conv->BN->ReLU->MaxPool->Dropout
    Classifier: Flatten->FC(512)->BN->ReLU->Dropout->FC(10)
    """
    def __init__(self, num_classes: int = 10):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(3,  32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2), nn.Dropout2d(p=0.2),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2), nn.Dropout2d(p=0.3),
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(64,  128, 3, padding=1, bias=False),
            nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1, bias=False),
            nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2), nn.Dropout2d(p=0.4),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 512),
            nn.BatchNorm1d(512), nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1); nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return self.classifier(x)


# ══════════════════════════════════════════════════════════════════
# TRAINING UTILITIES  (defined at module level — importable)
# ══════════════════════════════════════════════════════════════════
def train_epoch(model, loader, criterion, optimizer, scaler, scheduler=None):
    model.train()
    loss_sum, correct, total = 0.0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type,
                                 enabled=(device.type == "cuda")):
            out  = model(imgs)
            loss = criterion(out, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if scheduler is not None:
            scheduler.step()   # OneCycleLR steps once per BATCH
        loss_sum += loss.item() * imgs.size(0)
        correct  += (out.argmax(1) == labels).sum().item()
        total    += imgs.size(0)
    return loss_sum / total, 100.0 * correct / total


@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        out   = model(imgs)
        loss  = criterion(out, labels)
        preds = out.argmax(1)
        loss_sum += loss.item() * imgs.size(0)
        correct  += (preds == labels).sum().item()
        total    += imgs.size(0)
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())
    return (loss_sum / total,
            100.0 * correct / total,
            torch.cat(all_preds),
            torch.cat(all_labels))


# ══════════════════════════════════════════════════════════════════
# MAIN — required on Windows for multiprocessing in DataLoader
# ══════════════════════════════════════════════════════════════════
if __name__ == "__main__":

    # ── 1. Dataset download & exploration ────────────────────────
    sep("1. DATASET — CIFAR-10")

    raw      = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True,
                                             download=True, transform=T.ToTensor())
    test_raw = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False,
                                             download=True, transform=T.ToTensor())
    img0, lbl0 = raw[0]
    print(f"Train samples   :  {len(raw):,}")
    print(f"Test  samples   :  {len(test_raw):,}")
    print(f"Image shape     :  {tuple(img0.shape)}   (C x H x W)")
    print(f"Classes ({NUM_CLASSES})     :  {CLASSES}")
    print(f"Device          :  {device}")

    print("\nClass distribution:")
    counts = Counter(raw.targets)
    for i, cls in enumerate(CLASSES):
        bar = "=" * (counts[i] // 200)
        print(f"  {i}  {cls:<7}  {counts[i]:,}  {bar}")

    # ── 2. DataLoaders ───────────────────────────────────────────
    sep("2. DATA AUGMENTATION & DATALOADERS")

    train_tfm = T.Compose([
        T.RandomCrop(32, padding=4, padding_mode="reflect"),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
        T.RandomRotation(10),
        T.ToTensor(),
        T.Normalize(MEAN, STD),
        T.RandomErasing(p=0.15, scale=(0.02, 0.2)),
    ])
    val_tfm = T.Compose([T.ToTensor(), T.Normalize(MEAN, STD)])

    train_ds = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True,
                                             download=False, transform=train_tfm)
    val_ds   = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False,
                                             download=False, transform=val_tfm)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                              shuffle=True,  num_workers=2, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE,
                              shuffle=False, num_workers=2, pin_memory=True)

    imgs, lbls = next(iter(train_loader))
    print(f"Train batches  :  {len(train_loader)}")
    print(f"Val   batches  :  {len(val_loader)}")
    print(f"Batch shape    :  {tuple(imgs.shape)}")

    # ── 3. Model ─────────────────────────────────────────────────
    sep("3. SIMPLE CNN ARCHITECTURE")

    model       = SimpleCNN(num_classes=NUM_CLASSES).to(device)
    total_p     = sum(p.numel() for p in model.parameters())
    trainable_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print(f"\nTotal parameters     :  {total_p:,}")
    print(f"Trainable parameters :  {trainable_p:,}")
    print(f"Model size (FP32)    :  {total_p * 4 / 1e6:.2f} MB")

    # ── 4. Training setup ────────────────────────────────────────
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR * 10,
        epochs=EPOCHS, steps_per_epoch=len(train_loader),
        pct_start=0.2, anneal_strategy="cos",
    )
    scaler = torch.amp.GradScaler(device.type, enabled=(device.type == "cuda"))

    # ── 5. Training loop ─────────────────────────────────────────
    sep("4. TRAINING LOOP")

    history  = dict(tr_loss=[], tr_acc=[], va_loss=[], va_acc=[])
    best_acc = 0.0

    print(f"\n{'Epoch':>6}  {'Tr Loss':>8}  {'Tr Acc':>8}  "
          f"{'Va Loss':>8}  {'Va Acc':>8}  {'LR':>9}  {'Time':>6}")
    print("-" * 66)

    for epoch in range(1, EPOCHS + 1):
        t0 = time.time()
        tr_loss, tr_acc       = train_epoch(model, train_loader,
                                            criterion, optimizer, scaler, scheduler)
        va_loss, va_acc, _, _ = evaluate(model, val_loader, criterion)

        history["tr_loss"].append(tr_loss)
        history["tr_acc"].append(tr_acc)
        history["va_loss"].append(va_loss)
        history["va_acc"].append(va_acc)

        flag = " *" if va_acc > best_acc else ""
        if va_acc > best_acc:
            best_acc = va_acc
            torch.save({"epoch": epoch, "state": model.state_dict(),
                        "val_acc": va_acc}, CKPT_PATH)

        lr_now = optimizer.param_groups[0]["lr"]
        print(f"{epoch:>6}  {tr_loss:>8.4f}  {tr_acc:>7.2f}%  "
              f"{va_loss:>8.4f}  {va_acc:>7.2f}%  "
              f"{lr_now:>9.6f}  {time.time()-t0:>5.1f}s{flag}")

    print(f"\nBest validation accuracy :  {best_acc:.2f}%")

    # ── 6. Per-class accuracy ─────────────────────────────────────
    sep("5. PER-CLASS ACCURACY")

    _, _, all_preds, all_labels = evaluate(model, val_loader, criterion)
    cls_correct = torch.zeros(NUM_CLASSES)
    cls_total   = torch.zeros(NUM_CLASSES)
    for p, l in zip(all_preds, all_labels):
        cls_correct[l] += (p == l).item()
        cls_total[l]   += 1

    print(f"\n{'Class':<8}  {'Correct':>8}  {'Total':>6}  {'Acc':>8}   Bar")
    print("-" * 55)
    for i, cls in enumerate(CLASSES):
        acc = 100.0 * cls_correct[i] / cls_total[i]
        bar = "=" * int(acc / 5)
        print(f"{cls:<8}  {int(cls_correct[i]):>8}  "
              f"{int(cls_total[i]):>6}  {acc:>7.2f}%   {bar}")

    # ── 7. Confusion matrix ───────────────────────────────────────
    sep("6. CONFUSION MATRIX  (rows=true, cols=predicted)")

    conf = torch.zeros(NUM_CLASSES, NUM_CLASSES, dtype=torch.long)
    for p, l in zip(all_preds, all_labels):
        conf[l][p] += 1

    print(f"{'':>7}" + "".join(f"{c[:5]:>6}" for c in CLASSES))
    for i, cls in enumerate(CLASSES):
        row = f"{cls[:6]:>7}"
        for j in range(NUM_CLASSES):
            v = conf[i][j].item()
            row += f"[{v:>3}]" if i == j else f" {v:>4} "
        print(row)

    # ── 8. Single-image inference ─────────────────────────────────
    sep("7. SINGLE IMAGE INFERENCE  (first 10 test images)")

    model.eval()
    print(f"\n{'#':>3}  {'True':>8}  {'Predicted':>10}  {'Confidence':>11}  Result")
    print("-" * 48)
    with torch.no_grad():
        for i in range(10):
            img_t, true_lbl = val_ds[i]
            logits = model(img_t.unsqueeze(0).to(device))
            probs  = F.softmax(logits, dim=1).squeeze()
            pred   = probs.argmax().item()
            conf_p = probs[pred].item() * 100
            symbol = "OK" if pred == true_lbl else "X"
            print(f"{i+1:>3}  {CLASSES[true_lbl]:>8}  "
                  f"{CLASSES[pred]:>10}  {conf_p:>10.2f}%   {symbol}")

    # ── 9. Training history ───────────────────────────────────────
    sep("8. TRAINING HISTORY")

    print(f"\n{'Ep':>4}  {'Train Acc':>10}  {'Val Acc':>10}  Progress")
    print("-" * 55)
    for ep, (tr, va) in enumerate(zip(history["tr_acc"], history["va_acc"]), 1):
        bar = "=" * int(va / 5)
        print(f"{ep:>4}  {tr:>9.2f}%  {va:>9.2f}%  {bar}")

    print(f"\nFinal  train acc  :  {history['tr_acc'][-1]:.2f}%")
    print(f"Final  val   acc  :  {history['va_acc'][-1]:.2f}%")
    print(f"Best   val   acc  :  {best_acc:.2f}%")
    print(f"Overfit gap       :  "
          f"{history['tr_acc'][-1] - history['va_acc'][-1]:.2f}%")

    # ── 10. Save & reload ─────────────────────────────────────────
    sep("9. SAVE & RELOAD BEST CHECKPOINT")

    ckpt   = torch.load(CKPT_PATH, map_location=device)
    model2 = SimpleCNN(num_classes=NUM_CLASSES).to(device)
    model2.load_state_dict(ckpt["state"])
    model2.eval()

    _, reload_acc, _, _ = evaluate(model2, val_loader, criterion)
    print(f"Saved at epoch    :  {ckpt['epoch']}")
    print(f"Saved val acc     :  {ckpt['val_acc']:.2f}%")
    print(f"Reloaded val acc  :  {reload_acc:.2f}%  OK")

    # ── 11. Architecture summary ──────────────────────────────────
    sep("ARCHITECTURE SUMMARY")

    summary = [
        ("Input",        "—",                                "3x32x32",   "—"),
        ("Block 1",      "Conv(3->32) BN ReLU x2 MaxPool",  "32x16x16",  "~18K"),
        ("Block 2",      "Conv(32->64) BN ReLU x2 MaxPool", "64x8x8",    "~74K"),
        ("Block 3",      "Conv(64->128) BN ReLU x2 MaxPool","128x4x4",   "~295K"),
        ("Flatten",      "—",                                "2048",      "—"),
        ("FC(2048->512)","Linear BN ReLU Dropout(0.5)",      "512",       "~1.0M"),
        ("FC(512->10)",  "Linear",                           "10",        "~5K"),
    ]
    print(f"\n{'Stage':<16} {'Layers':<38} {'Output':<12} {'Params'}")
    print("-" * 78)
    for stage, layers, out, params in summary:
        print(f"{stage:<16} {layers:<38} {out:<12} {params}")

    print(f"\nTotal trainable parameters :  {trainable_p:,}  "
          f"(~{trainable_p/1e6:.2f} M)")
    print("\n" + "=" * 62)
    print("  Simple CNN + CIFAR-10 complete")
    print("=" * 62)



  1. DATASET — CIFAR-10
Train samples   :  50,000
Test  samples   :  10,000
Image shape     :  (3, 32, 32)   (C x H x W)
Classes (10)     :  ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Device          :  cuda

Class distribution:

  2. DATA AUGMENTATION & DATALOADERS
Train batches  :  782
Val   batches  :  157
Batch shape    :  (64, 3, 32, 32)

  3. SIMPLE CNN ARCHITECTURE
SimpleCNN(
  (block1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Dropout2d(p=0.2, inplace=False)
  )
  (block2): Seq



     1       nan    28.64%       nan    10.00%   0.005205   55.5s *
     2       nan    41.80%       nan    10.00%   0.010000   55.2s
     3       nan    48.62%       nan    10.00%   0.009618   53.4s
     4       nan    52.25%       nan    10.00%   0.008534   53.9s
     5       nan    54.15%       nan    10.00%   0.006911   54.0s
     6       nan    56.70%       nan    10.00%   0.004998   52.6s
     7       nan    59.37%       nan    10.00%   0.003084   52.9s
     8       nan    62.94%       nan    10.00%   0.001463   53.0s
     9       nan    66.35%       nan    10.00%   0.000380   53.9s
    10       nan    68.51%       nan    10.00%   0.000000   54.0s

Best validation accuracy :  10.00%

  5. PER-CLASS ACCURACY

Class      Correct   Total       Acc   Bar
-------------------------------------------------------
car              0    1000     0.00%   
bird             0    1000     0.00%   
cat              0    1000     0.00%   
deer             0    1000     0.00%   
dog              