In [None]:
import os, math, time, csv
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from PIL import Image
plt.switch_backend("agg")

# --- Data / output locations --------------------------------------------------
TRAIN_ROOT = "cleaned/Training"   # ImageFolder root, one sub-directory per class; split into train/val below
TEST_ROOT = "cleaned/Testing"     # ImageFolder root for the held-out test set
OUTDIR = "runs/vit64"             # checkpoint, history CSV and plots are written here
SAVE_NAME = "vit64_best.pt"       # file name (inside OUTDIR) of the best-val-loss checkpoint

# --- Optimisation -------------------------------------------------------------
EPOCHS = 30              # upper bound; early stopping may end training sooner
BATCH_SIZE = 64
LR = 3e-4                # AdamW learning rate
WEIGHT_DECAY = 1e-4      # AdamW weight decay
WORKERS = 4              # DataLoader worker processes
PATIENCE = 8             # epochs without val-loss improvement before early stopping
SEED = 0                 # passed to set_seed() and to the stratified train/val split

In [None]:
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each one.

    Implemented as a bias-free strided convolution (kernel == stride == ``patch``),
    followed by flattening the patch grid into a token axis.

    Args:
        in_ch: number of input channels.
        dim: embedding width of every patch token.
        patch: side length of a square patch, in pixels.

    Shape:
        ``(B, in_ch, H, W)`` -> ``(B, (H // patch) * (W // patch), dim)``.
    """

    def __init__(self, in_ch=3, dim=128, patch=8):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch, bias=False)

    def forward(self, x):
        x = self.proj(x)                      # (B, dim, H // patch, W // patch)
        return x.flatten(2).transpose(1, 2)   # (B, num_patches, dim)

class ViT64(nn.Module):
    """Small Vision Transformer classifier (defaults: 64x64 inputs, 8x8 patches).

    Pipeline: per-image/per-channel z-score -> patch embedding -> prepend a learnable
    [CLS] token -> add learnable positional embeddings -> pre-norm Transformer encoder
    -> LayerNorm -> linear head on the [CLS] token.

    Args:
        num_classes: number of output logits.
        img_size: side length the positional-embedding table is sized for; inputs
            must produce exactly ``(img_size // patch) ** 2`` patches.
        patch: patch side length in pixels; must divide ``img_size``.
        dim: token embedding width.
        depth: number of encoder layers.
        heads: attention heads per layer (PyTorch requires it to divide ``dim``).
        mlp_ratio: feed-forward width as a multiple of ``dim``.
        drop: dropout after the positional embedding and inside each encoder layer.
        in_ch: number of input channels.

    Raises:
        ValueError: if ``patch`` does not divide ``img_size``.
    """

    def __init__(self, num_classes=4, img_size=64, patch=8, dim=128, depth=6, heads=4, mlp_ratio=4, drop=0.1, in_ch=3):
        super().__init__()
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if img_size % patch != 0:
            raise ValueError(f"img_size ({img_size}) must be divisible by patch ({patch})")
        self.num_patches = (img_size // patch) ** 2
        self.dim = dim

        self.patch_embed = PatchEmbed(in_ch, dim, patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # One learned position vector per patch plus one for the [CLS] slot.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, dim) * 0.02)
        self.pos_drop = nn.Dropout(drop)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=dim * mlp_ratio,
            dropout=drop, batch_first=True, activation="gelu", norm_first=True
        )
        # norm_first=True already rules out the nested-tensor fast path; disabling it
        # explicitly just avoids PyTorch's "enable_nested_tensor is True, but ..." warning.
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=depth, enable_nested_tensor=False)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.head = nn.Linear(dim, num_classes)
        # Non-persistent buffer: excluded from state_dict but moved along with .to(device).
        self.register_buffer("eps", torch.tensor(1e-6), persistent=False)

    def zscore(self, x):
        """Standardise each image channel to zero mean / unit std over dims 2 and 3.

        ``eps`` is added to the std so constant channels do not divide by zero.
        """
        mean = x.mean(dim=[2, 3], keepdim=True)
        std = x.std(dim=[2, 3], keepdim=True) + self.eps
        return (x - mean) / std

    def forward(self, x):
        """Return logits of shape ``(B, num_classes)`` for images ``x`` of shape ``(B, in_ch, H, W)``.

        Raises:
            ValueError: if the input resolution does not yield ``self.num_patches``
                patches (the positional-embedding table has a fixed length).
        """
        in_hw = tuple(x.shape[-2:])
        x = self.zscore(x)
        x = self.patch_embed(x)
        B, N, _ = x.shape
        if N != self.num_patches:
            raise ValueError(
                f"input of spatial size {in_hw} yields {N} patches, "
                f"but the model was built for {self.num_patches}"
            )
        cls = self.cls_token.expand(B, 1, self.dim)
        x = torch.cat([cls, x], dim=1)          # (B, 1 + N, dim); [CLS] at index 0
        x = x + self.pos_embed
        x = self.pos_drop(x)
        x = self.encoder(x)
        x = self.norm(x)
        return self.head(x[:, 0])               # classify from the [CLS] token

def set_seed(seed=42):
    """Seed every RNG used here and force deterministic cuDNN behaviour.

    Seeds Python's ``random``, NumPy's legacy global RNG and torch (CPU plus all CUDA
    devices), then turns cuDNN deterministic mode on and its benchmark autotuner off.
    """
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

def make_transforms(img_size=None):
    """Build the ``(train, eval)`` torchvision pipelines.

    Args:
        img_size: if given, every image is first resized to ``img_size x img_size``.
            ViT64's positional-embedding table fits a single resolution, so pass the
            model's ``img_size`` when the images on disk are not already that size.
            ``None`` (default) applies no resizing.

    Returns:
        ``(train_tf, val_tf)``. The train pipeline adds a random horizontal flip
        (p=0.5) and a random rotation in [-10, 10] degrees. Both end with
        ``ToTensor`` (PIL image -> float tensor, scaled to [0, 1] for 8-bit images).
    """
    resize = [transforms.Resize((img_size, img_size))] if img_size is not None else []
    train_tf = transforms.Compose([
        *resize,
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ])
    val_tf = transforms.Compose([
        *resize,
        transforms.ToTensor(),
    ])
    return train_tf, val_tf

def stratified_split_indices(targets, train_frac=0.8, seed=42):
    """Split sample indices into train/val lists while keeping per-class proportions.

    For every label (in sorted order), that label's indices are shuffled, and the
    first ``round(count * train_frac)`` go to train and the rest to val. Both
    resulting lists are then shuffled. The same ``seed`` always yields the same split.

    Returns:
        ``(train_idx, val_idx)``: Python lists of NumPy integer indices into ``targets``.
    """
    labels = np.array(targets)
    rng = np.random.default_rng(seed)
    train_idx = []
    val_idx = []
    for label in np.unique(labels):
        members = np.flatnonzero(labels == label)
        rng.shuffle(members)
        cut = int(round(len(members) * train_frac))
        train_idx += list(members[:cut])
        val_idx += list(members[cut:])
    rng.shuffle(train_idx)
    rng.shuffle(val_idx)
    return train_idx, val_idx

def class_weights_from_indices(targets, idx, num_classes=None):
    """Inverse-frequency class weights for the samples ``targets[idx]``, normalised to mean 1.

    Suitable for ``nn.CrossEntropyLoss(weight=...)``: class ``c`` gets a weight
    proportional to ``len(idx) / count_c``. Weights are rescaled so that the classes
    present in the subset average to 1.

    Args:
        targets: integer class label of every sample in the full dataset
            (e.g. ``ImageFolder.targets``).
        idx: indices into ``targets`` selecting the subset to count (e.g. the train split).
        num_classes: length of the returned vector. Defaults to ``max(targets) + 1``
            over the *full* label list, so a class missing from the subset still gets a slot.

    Returns:
        float32 tensor of shape ``(num_classes,)``. Classes absent from the subset get
        weight 0, rather than a huge value that would swamp every other class after
        normalisation.

    Raises:
        ValueError: if ``idx`` selects no samples, or a selected label is >= ``num_classes``.
    """
    labels = np.asarray(targets, dtype=np.int64)
    picked = torch.tensor(labels[np.asarray(idx, dtype=np.int64)])
    if picked.numel() == 0:
        raise ValueError("class_weights_from_indices: idx selects no samples")
    if num_classes is None:
        num_classes = int(labels.max()) + 1
    if int(picked.max()) >= num_classes:
        raise ValueError(f"label {int(picked.max())} out of range for num_classes={num_classes}")

    counts = torch.bincount(picked, minlength=num_classes).float()
    present = counts > 0
    weights = torch.zeros_like(counts)
    weights[present] = counts.sum() / counts[present]
    return weights / weights[present].mean()

@torch.no_grad()
def evaluate(model, loader, device):
    """Run ``model`` over ``loader`` without gradients and collect classification metrics.

    Args:
        model: module mapping an input batch to ``(B, num_classes)`` logits.
        loader: iterable of ``(inputs, integer_labels)`` batches (e.g. a DataLoader).
        device: device both tensors are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy, y_true, y_pred)``. ``mean_loss`` is the per-sample
        average cross-entropy and ``accuracy`` is a fraction in [0, 1]. ``y_true`` and
        ``y_pred`` are 1-D NumPy arrays in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.

    Note: switches ``model`` to eval mode and leaves it there; a training loop must
    call ``model.train()`` again afterwards.
    """
    model.eval()
    ce = nn.CrossEntropyLoss(reduction="sum")   # sum per batch, divide by sample count once at the end
    loss_sum, correct, total = 0.0, 0, 0
    all_y, all_p = [], []
    for x, y in loader:
        # non_blocking overlaps host->GPU copies when the loader pins memory; no-op on CPU.
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        loss_sum += ce(logits, y).item()
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.numel()
        all_y.append(y.cpu())
        all_p.append(pred.cpu())
    if total == 0:
        raise ValueError("evaluate() got a loader that yielded no samples")
    y_true = torch.cat(all_y).numpy()
    y_pred = torch.cat(all_p).numpy()
    return loss_sum / total, correct / total, y_true, y_pred

def plot_curves(history, outdir):
    """Write a training history to ``<outdir>/history.csv`` and plot it.

    ``history`` maps ``epoch``, ``train_loss``, ``train_acc``, ``val_loss`` and
    ``val_acc`` to per-epoch lists; one CSV row is written per entry of
    ``history["epoch"]``. Two figures are saved next to the CSV: ``loss.png`` and
    ``acc.png``, each showing the train and val curves against epoch. ``outdir``
    is created if it does not exist.
    """
    columns = ["epoch", "train_loss", "train_acc", "val_loss", "val_acc"]
    target = Path(outdir)
    target.mkdir(parents=True, exist_ok=True)

    with open(target / "history.csv", "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(columns)
        writer.writerows(
            [history[col][row] for col in columns]
            for row in range(len(history["epoch"]))
        )

    epochs = history["epoch"]
    panels = (
        ("train_loss", "val_loss", "loss", "loss.png"),
        ("train_acc", "val_acc", "accuracy", "acc.png"),
    )
    for train_key, val_key, ylabel, filename in panels:
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.plot(epochs, history[train_key], label="train")
        ax.plot(epochs, history[val_key], label="val")
        ax.set_xlabel("epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.tight_layout()
        fig.savefig(target / filename)
        plt.close(fig)

def plot_confusion(y_true, y_pred, class_names, outpath):
    """Save an annotated confusion-matrix heatmap (rows = true, columns = predicted).

    Args:
        y_true, y_pred: integer class indices (any array-like sklearn accepts).
        class_names: tick labels; entry ``i`` names class ``i``. Its length fixes the
            matrix size, so classes that never occur still get a row and a column.
        outpath: destination image path.

    Each count is drawn in white on the dark half of the colour range and in black
    on the light half. This keeps every cell legible: with the default colormap
    (viridis unless rcParams override it), low counts are nearly black.
    """
    from sklearn.metrics import confusion_matrix

    n = len(class_names)
    ticks = list(range(n))
    cm = confusion_matrix(y_true, y_pred, labels=ticks).astype(np.int32)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(cm, interpolation="nearest")
    ax.set_title("Confusion Matrix")
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)

    # imshow maps [cm.min(), cm.max()] onto the colormap, so its midpoint splits dark from light.
    thresh = (cm.min() + cm.max()) / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, str(cm[i, j]), ha="center", va="center",
                    color="white" if cm[i, j] < thresh else "black")

    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    fig.savefig(outpath)
    plt.close(fig)

In [None]:
set_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

train_tf, val_tf = make_transforms()

# --- Data ---------------------------------------------------------------------
# TRAIN_ROOT is split 80/20 (stratified) into train/val; TEST_ROOT is only used at the end.
full_train = datasets.ImageFolder(TRAIN_ROOT, transform=train_tf)
test_ds    = datasets.ImageFolder(TEST_ROOT,  transform=val_tf)
classes = full_train.classes
print("Classes:", classes)
# ImageFolder numbers classes by sorted folder name, so a missing or extra folder in either
# root would silently shift the test labels relative to the trained classifier head.
if test_ds.classes != classes:
    raise ValueError(f"class folders differ: train {classes} vs test {test_ds.classes}")

train_idx, val_idx = stratified_split_indices(full_train.targets, train_frac=0.8, seed=SEED)
train_ds = Subset(full_train, train_idx)
# Same files via a second ImageFolder so validation samples use val_tf (no augmentation).
val_base = datasets.ImageFolder(TRAIN_ROOT, transform=val_tf)
val_ds = Subset(val_base, val_idx)

# Pinned host memory only helps CUDA copies (and warns on CPU-only machines).
# Persistent workers keep the train/val worker processes alive across epochs instead of
# re-spawning them every epoch, which is costly under the "spawn" start method (e.g. Windows).
use_cuda = device.type == "cuda"
keep_workers = WORKERS > 0
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=WORKERS, pin_memory=use_cuda,
                          persistent_workers=keep_workers)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=WORKERS, pin_memory=use_cuda,
                          persistent_workers=keep_workers)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=WORKERS, pin_memory=use_cuda)

# --- Model, loss, optimiser -----------------------------------------------------
model = ViT64(num_classes=len(classes), in_ch=3).to(device)
criterion = nn.CrossEntropyLoss()
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Halve the LR once val loss has stopped improving for more than 3 epochs.
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=3)

outdir = Path(OUTDIR); outdir.mkdir(parents=True, exist_ok=True)
ckpt_path = outdir / SAVE_NAME

best_val = math.inf
patience_ctr = 0
saved_this_run = False   # guards against evaluating a stale checkpoint left by an earlier run
history = {"epoch":[], "train_loss":[], "train_acc":[], "val_loss":[], "val_acc":[]}

# Mixed precision only on CUDA; on CPU both the scaler and autocast are disabled.
scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)

# --- Training loop with early stopping on val loss -------------------------------
for epoch in range(1, EPOCHS + 1):
    model.train()   # evaluate() leaves the model in eval mode, so re-enable training mode each epoch
    t0 = time.time()
    run_loss, run_correct, run_total = 0.0, 0, 0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=use_cuda):
            logits = model(x)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        run_loss += loss.item() * y.size(0)   # criterion returns a batch mean; re-weight by batch size
        run_correct += (logits.argmax(1) == y).sum().item()
        run_total += y.size(0)

    train_loss = run_loss / run_total
    train_acc  = run_correct / run_total

    val_loss, val_acc, _, _ = evaluate(model, val_loader, device)
    sched.step(val_loss)

    history["epoch"].append(epoch)
    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    print(f"Epoch {epoch:03d} | {time.time()-t0:5.1f}s | "
          f"train {train_loss:.4f}/{train_acc:.4f} | "
          f"val {val_loss:.4f}/{val_acc:.4f}")
    # Rewrite CSV + curves every epoch so an interrupted run still leaves its history on disk.
    plot_curves(history, outdir)

    if val_loss + 1e-6 < best_val:
        best_val = val_loss
        patience_ctr = 0
        torch.save({"model": model.state_dict(), "classes": classes}, ckpt_path)
        saved_this_run = True
        print(f"  → saved best to {ckpt_path}")
    else:
        patience_ctr += 1
        if patience_ctr >= PATIENCE:
            print(f"Early stopping (best val loss {best_val:.4f})")
            break

if not saved_this_run:
    # No improving (finite) val loss was ever recorded, e.g. NaN from epoch 1. Keep this
    # run's final weights rather than silently loading whatever an earlier run left behind.
    torch.save({"model": model.state_dict(), "classes": classes}, ckpt_path)
    print(f"No improving epoch recorded; saved final weights to {ckpt_path}")

# --- Test-set evaluation with the best checkpoint ---------------------------------
# The checkpoint holds only tensors and a list of str, so the restricted unpickler suffices.
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(ckpt["model"])
test_loss, test_acc, y_true, y_pred = evaluate(model, test_loader, device)
print(f"TEST   loss {test_loss:.4f}  acc {test_acc:.4f}")

plot_confusion(y_true, y_pred, classes, outdir/"confusion_test.png")
print(f"Saved confusion matrix to {outdir/'confusion_test.png'}")
print(f"History CSV / curves saved to {outdir}")


Device: cuda
Classes: ['glioma', 'meningioma', 'notumor', 'pituitary']
Epoch 001 |  75.9s | train 1.1898/0.4348 | val 1.0123/0.5464
  → saved best to runs\vit64\vit64_best.pt
Epoch 002 |  60.4s | train 0.9262/0.5875 | val 0.9178/0.6532
  → saved best to runs\vit64\vit64_best.pt
Epoch 003 |  64.1s | train 0.7803/0.6757 | val 0.8366/0.6454
  → saved best to runs\vit64\vit64_best.pt
Epoch 004 |  62.0s | train 0.7328/0.6958 | val 0.6680/0.7469
  → saved best to runs\vit64\vit64_best.pt
Epoch 005 |  58.3s | train 0.6322/0.7455 | val 0.5591/0.7863
  → saved best to runs\vit64\vit64_best.pt
Epoch 006 |  58.1s | train 0.6383/0.7530 | val 0.6363/0.7452
Epoch 007 |  58.4s | train 0.5925/0.7659 | val 0.6041/0.7338
Epoch 008 |  65.0s | train 0.5537/0.7875 | val 0.4403/0.8389
  → saved best to runs\vit64\vit64_best.pt
Epoch 009 |  58.1s | train 0.5421/0.7790 | val 0.4144/0.8415
  → saved best to runs\vit64\vit64_best.pt
Epoch 010 |  58.0s | train 0.5429/0.7862 | val 0.4904/0.8074
Epoch 011 |  58.1s

In [11]:
from collections import Counter
import numpy as np

def count_split(dataset, class_names):
    """Count samples per class in ``dataset``, which may be wrapped in (nested) Subsets.

    Args:
        dataset: an object with a ``targets`` sequence of integer labels, or a
            ``torch.utils.data.Subset`` chain ending in one. Nested Subsets are
            resolved by composing their index maps.
        class_names: names for labels ``0 .. len(class_names) - 1``.

    Returns:
        ``{class_name: count}`` as plain ints, including zero counts. Labels outside
        ``range(len(class_names))`` are not reported.
    """
    base = dataset
    indices = None   # indices into `base.targets`, composed while unwrapping Subsets
    while isinstance(base, Subset):
        level = np.asarray(base.indices, dtype=np.int64)
        indices = level if indices is None else level[indices]
        base = base.dataset

    targets = np.asarray(base.targets)
    if indices is not None:
        targets = targets[indices]
    counts = Counter(targets.tolist())
    return {name: counts.get(i, 0) for i, name in enumerate(class_names)}

# Per-class image counts for each split, as a sanity check of class balance.
print("Image counts per class:")
for split_label, split_ds in (
    ("  Train:", train_ds),
    ("  Val:  ", val_ds),
    ("  Test: ", test_ds),
):
    print(split_label, count_split(split_ds, classes))


Image counts per class:
  Train: {'glioma': 1057, 'meningioma': 1071, 'notumor': 1276, 'pituitary': 1166}
  Val:   {'glioma': 264, 'meningioma': 268, 'notumor': 319, 'pituitary': 291}
  Test:  {'glioma': 300, 'meningioma': 306, 'notumor': 405, 'pituitary': 300}


In [None]:
# Model size: all parameters vs. those with requires_grad=True.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


Total parameters: 1,223,428
Trainable parameters: 1,223,428
