# Chest X‑ray Classification with a Vision Transformer (ViT)

**Task:** Binary classification (Normal vs Pneumonia) on the public Kaggle dataset:  
<https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia>

**Key requirements satisfied:**  
- Keep the **test** set **completely untouched** during training and tuning.  
- You may adjust the train/validation split; this notebook provides two options:
  1) Use the dataset-provided `val/` split as-is, or  
  2) Rebuild validation from `train/` (optionally also include the small original `val/`) **without** touching `test/`.

**Model:** A from-scratch **Vision Transformer (ViT)** implemented in PyTorch, adhering to the paper’s core design:  
- Split the image into non-overlapping patches and linearly project each patch to a latent dimension (Eq. 1).  
- Prepend a **learnable [CLS] token** whose final hidden state is used for classification (Eq. 4).  
- Add **learnable 1‑D positional embeddings** to tokens.  
- Use a **Transformer encoder** stack with pre-LayerNorm, **Multi-Head Self-Attention** (MSA) and an **MLP (GELU)** block with residual connections (Eqs. 2–3).  
- Linear classification head on top of the final [CLS] representation.

> The design follows **“An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale” (ViT)**, ICLR 2021.  
> Dosovitskiy et al., 2021. (See discussion in comments inside code.)
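
For reference, here are the paper's Eqs. (1) to (4) in its own notation. $\mathbf{x}_p^i$ is the $i$-th flattened $P \times P$ patch, $N = HW/P^2$ is the number of patches, $D$ is the latent width and $L$ is the depth. In this notebook:

- `PatchEmbed` plays the role of $\mathbf{E}$.
- `cls_token` plays the role of $\mathbf{x}_\text{class}$.
- `pos_embed` plays the role of $\mathbf{E}_{pos}$.
- Each `Block` computes Eqs. (2) and (3).
- The final `norm` applies Eq. (4), and `head` applies the linear classifier.

$$
\begin{aligned}
\mathbf{z}_0 &= [\mathbf{x}_\text{class};\ \mathbf{x}_p^1\mathbf{E};\ \mathbf{x}_p^2\mathbf{E};\ \cdots;\ \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}, \qquad \mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D},\ \mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D} && (1)\\
\mathbf{z}'_\ell &= \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, \qquad \ell = 1, \dots, L && (2)\\
\mathbf{z}_\ell &= \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, \qquad \ell = 1, \dots, L && (3)\\
\mathbf{y} &= \mathrm{LN}(\mathbf{z}_L^0) && (4)
\end{aligned}
$$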

---

### What you’ll get
- Data loader for the Kaggle folder layout (`train/`, `val/`, `test/`) and an option to rebuild validation from `train/` without touching `test/`.
- From-scratch **ViT** (patchify via `Conv2d` with stride=`patch_size`, class token, learned positional embeddings, pre-LN encoder blocks).
- **Training loop** with:
  - **AdamW** optimizer + cosine LR schedule with warmup
  - **AMP** (mixed precision) when CUDA is available
  - **WeightedRandomSampler** (inverse class-frequency sampling) to mitigate class imbalance
  - **Early stopping** on validation AUROC (default) or F1
- **Final evaluation on the untouched test set** (Accuracy, Precision, Recall, F1, AUROC, confusion matrix, ROC curve).
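- Optional: if `timm` is installed, setting `try_pretrained_vit=True` fine-tunes timm's pretrained `vit_base_patch16_224` instead. timm fetches the weights, so this needs network access or a cached copy. Otherwise the notebook falls back to the from-scratch ViT.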
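
Why does a `Conv2d` with `kernel_size = stride = patch_size` implement Eq. (1)? With no overlap, each output position sees exactly one patch. The convolution is therefore the same linear map as flattening every patch and multiplying by a $(P^2 C) \times D$ matrix.

Below is a minimal NumPy check of that equivalence, using random weights and toy sizes chosen for illustration. `patchify` and `conv_patch_embed` are hypothetical helper names, not notebook functions:

```python
import numpy as np

def patchify(x, p):
    """(C, H, W) -> (N, C*p*p): flatten each non-overlapping p x p patch, row-major over the grid."""
    c, h, w = x.shape
    gh, gw = h // p, w // p
    x = x.reshape(c, gh, p, gw, p)        # split H and W into (grid index, within-patch index)
    x = x.transpose(1, 3, 0, 2, 4)        # (gh, gw, C, p, p)
    return x.reshape(gh * gw, c * p * p)  # one row per patch

def conv_patch_embed(x, weight, bias, p):
    """Direct strided convolution with kernel = stride = p; weight has shape (D, C, p, p)."""
    d = weight.shape[0]
    _, h, w = x.shape
    gh, gw = h // p, w // p
    out = np.empty((d, gh, gw))
    for i in range(gh):
        for j in range(gw):
            window = x[:, i * p:(i + 1) * p, j * p:(j + 1) * p]  # (C, p, p)
            out[:, i, j] = np.tensordot(weight, window, axes=([1, 2, 3], [0, 1, 2])) + bias
    return out.reshape(d, gh * gw).T      # (N, D), like flatten(2).transpose(1, 2)

rng = np.random.default_rng(0)
C, H, W, P, D = 3, 32, 32, 8, 24
x = rng.standard_normal((C, H, W))
weight = rng.standard_normal((D, C, P, P))
bias = rng.standard_normal(D)

tokens_conv = conv_patch_embed(x, weight, bias, P)
tokens_mat = patchify(x, P) @ weight.reshape(D, -1).T + bias  # Eq. (1) as a plain matmul

print(tokens_conv.shape)                     # (16, 24): N = (32/8)**2 patches, D = 24
print(np.allclose(tokens_conv, tokens_mat))  # True
```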

> **Note:** Chest X‑rays are grayscale; by default we replicate them to 3 channels for ViT. To keep a single channel instead, set `in_channels = 1` in the config, and the patch embedding adapts (see the appendix).
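
Replicating the grayscale channel three times adds no information. On replicated input, a patch projection with weights $W \in \mathbb{R}^{D \times 3 \times P \times P}$ gives the same output as a 1-channel projection whose weights are $W$ summed over the channel axis.

The notebook's per-channel ImageNet normalization does not change this conclusion. It only rescales each channel's weights by $1/\text{std}_c$ and folds the means into the bias.

Below is a NumPy check of the identity on one raw (unnormalized) patch, with toy sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
D, P = 8, 4                                   # embed dim and patch size (toy values)
gray_patch = rng.random((1, P, P))            # one grayscale patch, C = 1
rgb_patch = np.repeat(gray_patch, 3, axis=0)  # replicate to C = 3, like Grayscale(num_output_channels=3)

w3 = rng.standard_normal((D, 3, P, P))        # 3-channel projection weights
w1 = w3.sum(axis=1, keepdims=True)            # equivalent 1-channel weights, shape (D, 1, P, P)

emb_rgb = w3.reshape(D, -1) @ rgb_patch.reshape(-1)
emb_gray = w1.reshape(D, -1) @ gray_patch.reshape(-1)
print(np.allclose(emb_rgb, emb_gray))         # True
```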

---

### How to use
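1. Run the **Setup & Data** section. If the dataset folder isn’t present, the notebook tries to download it with the Kaggle CLI. This needs a valid `kaggle.json`, either already in `~/.kaggle/` or placed in `./` or `/content/`, from where it is copied to `~/.kaggle/`.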
2. Configure the **Training Config** cell (image size, patch size, depth, heads, etc.).
3. Run **Train**.  
4. The notebook automatically **saves the best model** (by validation metric) and then runs a **one-time test evaluation**.
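
For step 2, an alternative to editing the dataclass defaults is `dataclasses.replace`, which returns a modified copy. The sketch below uses a trimmed, hypothetical stand-in for the notebook's `Config` so that it runs on its own.

In the notebook you would write `cfg = replace(cfg, ...)` right after `cfg = Config()`. It must come before the data and model cells, because they read `cfg` when they build their objects.

```python
from dataclasses import dataclass, replace

@dataclass
class Config:  # hypothetical, trimmed stand-in for the notebook's Config
    image_size: int = 224
    patch_size: int = 16
    embed_dim: int = 256
    depth: int = 6
    num_heads: int = 8
    epochs: int = 15

cfg = Config()
# replace() returns a new instance with the given fields changed; cfg itself is untouched.
cfg_small = replace(cfg, depth=4, num_heads=4, epochs=5)

# Sanity checks mirroring constraints in the model code.
assert cfg_small.embed_dim % cfg_small.num_heads == 0, "Attention asserts embed_dim % num_heads == 0"
assert cfg_small.image_size % cfg_small.patch_size == 0, "otherwise PatchEmbed silently drops border pixels"
print(cfg_small)
```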

---

*Attribution / Reference:* The ViT architecture elements (patch tokenization, [CLS] token, 1‑D positional embeddings, pre-LN MSA+MLP blocks, and fine‑tuning guidance) are based on the ICLR 2021 paper by Dosovitskiy et al., *An Image is Worth 16×16 Words*. See Eqs. (1)–(4), Sec. 3.1–3.2 for the core design and fine‑tuning notes.


In [None]:

# ==== Setup: Imports & Environment ====
import os, sys, math, time, random, shutil, subprocess
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Tuple, Optional, List, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, ConcatDataset, WeightedRandomSampler
from torchvision import transforms, datasets
from torchvision.utils import make_grid

try:
    import timm  # optional: for pretrained ViT fine-tuning if available locally
    _HAVE_TIMM = True
except Exception:
    _HAVE_TIMM = False

try:
    import sklearn
    from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, precision_recall_fscore_support
    _HAVE_SKLEARN = True
except ImportError:
    _HAVE_SKLEARN = False
    print("scikit-learn not found: AUROC will be reported as NaN and the classification report skipped (pip install scikit-learn).")

import matplotlib.pyplot as plt

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)
print("Torch:", torch.__version__)
print("CUDA:", torch.cuda.is_available())


In [None]:

# ==== Paths & Kaggle Download Helper ====
from pathlib import Path
import os, shutil, subprocess

# Assumes the notebook runs from a `notebooks/` folder, so the data lives at ../Data/chest_xray.
DATASET_DIR = Path.cwd().parent / "Data" / "chest_xray"
KAGGLE_ZIP = DATASET_DIR.parent / "chest-xray-pneumonia.zip"

def has_dataset_structure(root: Path) -> bool:
    return (root / "train").exists() and (root / "val").exists() and (root / "test").exists()

def try_kaggle_download():
    kaggle_json_candidates = [Path("./kaggle.json"), Path("/content/kaggle.json")]
    for cand in kaggle_json_candidates:
        if cand.exists():
            os.makedirs(Path.home() / ".kaggle", exist_ok=True)
            shutil.copy(str(cand), str(Path.home() / ".kaggle/kaggle.json"))
            os.chmod(Path.home() / ".kaggle/kaggle.json", 0o600)
            break

    try:
        print("Attempting Kaggle download...")
        DATASET_DIR.parent.mkdir(parents=True, exist_ok=True)
        subprocess.run(
            ["kaggle", "datasets", "download", "-d", "paultimothymooney/chest-xray-pneumonia", "-p", str(DATASET_DIR.parent)],
            check=True
        )
        if KAGGLE_ZIP.exists():
            print("Extracting zip...")
            shutil.unpack_archive(str(KAGGLE_ZIP), str(DATASET_DIR.parent))
        else:
            print("Download finished but zip not found at", KAGGLE_ZIP)
    except Exception as e:
        print("Kaggle download failed:", e)

if not has_dataset_structure(DATASET_DIR):
    print("Dataset folder not found at", DATASET_DIR.resolve())
    if shutil.which("kaggle") is None:
        print("Kaggle CLI not found. Install with `pip install kaggle` and rerun, or upload chest_xray/ manually.")
    else:
        try_kaggle_download()

print("Dataset present?", has_dataset_structure(DATASET_DIR), "->", DATASET_DIR.resolve())
if not has_dataset_structure(DATASET_DIR):
    print("❗ Please upload the `chest_xray/` folder (with train/val/test).")
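
Before training, it can help to confirm the layout `ImageFolder` expects (`<split>/<class>/<image>`) and to see how imbalanced each split is. Below is a sketch of a hypothetical `count_images` helper that uses only `pathlib`. Two notes:

- The extension list is an assumption.
- `ImageFolder` assigns class indices by sorted folder name, so `NORMAL` becomes 0 and `PNEUMONIA` becomes 1.

The demo at the bottom builds a tiny fake tree so the block runs anywhere. In the notebook you would call `count_images(DATASET_DIR)` instead.

```python
import tempfile
from pathlib import Path

IMG_EXTS = {".jpeg", ".jpg", ".png"}  # assumption: image extensions to count

def count_images(root: Path) -> dict:
    """Return {split: {class_name: n_images}} for root/<split>/<class>/<file>; missing splits are skipped."""
    counts = {}
    for split in ("train", "val", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            continue
        counts[split] = {
            cls_dir.name: sum(1 for f in cls_dir.iterdir() if f.suffix.lower() in IMG_EXTS)
            for cls_dir in sorted(split_dir.iterdir()) if cls_dir.is_dir()
        }
    return counts

# Self-contained demo on a tiny fake tree.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for split, cls, n in [("train", "NORMAL", 2), ("train", "PNEUMONIA", 3), ("test", "NORMAL", 1)]:
        (root / split / cls).mkdir(parents=True)
        for i in range(n):
            (root / split / cls / f"img{i}.jpeg").touch()
    (root / "train" / "NORMAL" / "notes.txt").touch()  # non-image file is ignored
    demo_counts = count_images(root)
print(demo_counts)  # {'train': {'NORMAL': 2, 'PNEUMONIA': 3}, 'test': {'NORMAL': 1}}
```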


In [None]:

# ==== Experiment Configuration ====

@dataclass
class Config:
    # Data
    dataset_dir: Path = DATASET_DIR
    use_original_val_split: bool = False  # If False, rebuild val from train (test remains untouched)
    include_small_original_val_into_train: bool = True  # If True and rebuilding val, include the small 'val/' set into the train pool before splitting.
    val_fraction_from_train: float = 0.1  # fraction for validation when rebuilding from train
    num_workers: int = 4
    batch_size: int = 32
    # Image & ViT
    image_size: int = 224  # training resolution; keep fixed to avoid pos-embed interpolation complexity
    in_channels: int = 3   # ViT expects 3; we'll replicate grayscale to 3. If you set to 1, patch embed will adapt.
    patch_size: int = 16   # typical ViT configuration (e.g., /16)
    embed_dim: int = 256   # smaller than ViT-B to keep training light
    depth: int = 6         # number of Transformer blocks
    num_heads: int = 8
    mlp_ratio: float = 4.0
    attn_dropout: float = 0.0
    drop_rate: float = 0.1
    # Optimization
    epochs: int = 15
    warmup_epochs: int = 2
    lr: float = 3e-4
    weight_decay: float = 0.05
    label_smoothing: float = 0.0
    # Class imbalance handling
    use_weighted_sampler: bool = True
    # Early stopping
    early_stop_metric: str = "auroc"  # one of: "auroc", "f1"
    early_stop_patience: int = 5
    # Optional: use local pretrained ViT if available (timm or torchvision). If not present, falls back to scratch ViT.
    try_pretrained_vit: bool = False
    # Repro
    seed: int = 42

cfg = Config()
print(cfg)
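
Here is what the default config implies, as back-of-the-envelope arithmetic. A 224×224 image with 16×16 patches gives $N = (224/16)^2 = 196$ patch tokens, so the encoder sees 197 tokens including [CLS]. Each head's self-attention therefore forms a 197×197 score matrix.

The sketch below (a hypothetical `vit_param_count` helper) tallies parameters layer by layer for the from-scratch `VisionTransformer` with default settings. Note that `num_heads` does not change the count, since heads only split `embed_dim`.

```python
def vit_param_count(img=224, patch=16, in_ch=3, dim=256, depth=6, mlp_ratio=4.0, n_cls=2):
    """Parameter count of the notebook's from-scratch VisionTransformer (weights + biases)."""
    def linear(i, o):
        return i * o + o                                  # nn.Linear weight + bias

    n_patches = (img // patch) ** 2
    hidden = int(dim * mlp_ratio)
    layernorm = 2 * dim                                   # gamma + beta
    patch_embed = in_ch * patch * patch * dim + dim       # Conv2d(in_ch, dim, patch, stride=patch)
    tokens = dim + (n_patches + 1) * dim                  # cls_token + pos_embed
    block = (2 * layernorm                                # norm1, norm2
             + linear(dim, 3 * dim) + linear(dim, dim)    # qkv, attention output proj
             + linear(dim, hidden) + linear(hidden, dim)) # MLP fc1, fc2
    return patch_embed + tokens + depth * block + layernorm + linear(dim, n_cls)

n_tokens = (224 // 16) ** 2 + 1
print(n_tokens)           # 197
print(vit_param_count())  # 4987138, i.e. about 4.99M
```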


In [None]:

# ==== Datasets & DataLoaders ====
from torchvision import transforms, datasets

def build_transforms(image_size: int, replicate_grayscale_to_3: bool = True):
    common = []
    if replicate_grayscale_to_3:
        common.append(transforms.Grayscale(num_output_channels=3))
    else:
        common.append(transforms.Grayscale(num_output_channels=1))

    train_tfms = transforms.Compose([
        *common,
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0), ratio=(1.0, 1.0)),
        transforms.ToTensor(),
        # Normalize using ImageNet stats is fine even from scratch; adjust if desired.
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) if replicate_grayscale_to_3
            else transforms.Normalize(mean=(0.5,), std=(0.25,)),
    ])
    eval_tfms = transforms.Compose([
        *common,
        transforms.Resize((image_size, image_size)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) if replicate_grayscale_to_3
            else transforms.Normalize(mean=(0.5,), std=(0.25,)),
    ])
    return train_tfms, eval_tfms

train_tfms, eval_tfms = build_transforms(cfg.image_size, replicate_grayscale_to_3=(cfg.in_channels==3))

def build_datasets_and_loaders(cfg: Config):
    root = cfg.dataset_dir
    if not (root / "train").exists():
        raise FileNotFoundError(f"Dataset directory not found: {root}")

    # ImageFolder expects a class-per-subfolder structure
    ds_train = datasets.ImageFolder(root / "train", transform=train_tfms)
    ds_val_orig = datasets.ImageFolder(root / "val", transform=eval_tfms)
    ds_test = datasets.ImageFolder(root / "test", transform=eval_tfms)

    class_to_idx = ds_train.class_to_idx
    idx_to_class = {v:k for k,v in class_to_idx.items()}
    print("Classes:", idx_to_class)

    if cfg.use_original_val_split:
        train_set = ds_train
        val_set = ds_val_orig
        full_pool = None
    else:
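        # Rebuild validation from training (optionally include original val into the pool).
        # Two parallel pools over the same files: the training subset is served with augmentation
        # (train_tfms), the validation subset with deterministic eval_tfms.
        pool_datasets = [ds_train]
        pool_eval = [datasets.ImageFolder(root / "train", transform=eval_tfms)]
        if cfg.include_small_original_val_into_train:
            pool_datasets.append(datasets.ImageFolder(root / "val", transform=train_tfms))
            pool_eval.append(ds_val_orig)
        full_pool = ConcatDataset(pool_datasets)
        full_pool_eval = ConcatDataset(pool_eval)

        n_total = len(full_pool)
        n_val = int(cfg.val_fraction_from_train * n_total)
        n_train = n_total - n_val
        train_set, val_split = random_split(full_pool, [n_train, n_val],
                                            generator=torch.Generator().manual_seed(cfg.seed))
        # Same validation indices, but read through the non-augmented pool
        val_set = torch.utils.data.Subset(full_pool_eval, val_split.indices)

        print(f"Rebuilt val from train pool: n_train={n_train}, n_val={n_val} (total={n_total})")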

    # Weighted sampler for imbalance (optional)
    sampler = None
    if cfg.use_weighted_sampler:
        # Compute class distribution on the training subset
        if isinstance(train_set, torch.utils.data.Subset):
            # Need to recover targets for the subset
            targets = []
            assert full_pool is not None, "Internal error: full_pool should not be None when using rebuilt split."
            for i in train_set.indices:
                # Map global index to sub-dataset
                running = 0
                for ds in full_pool.datasets:
                    if i < running + len(ds):
                        targets.append(ds.samples[i - running][1])
                        break
                    running += len(ds)
        else:
            targets = [s[1] for s in train_set.samples] if hasattr(train_set, "samples") else [y for _, y in train_set]

        class_sample_count = np.bincount(np.array(targets), minlength=len(class_to_idx))
        class_weights = 1.0 / (class_sample_count + 1e-6)
        sample_weights = np.array([class_weights[t] for t in targets], dtype=np.float32)
        sampler = WeightedRandomSampler(weights=torch.from_numpy(sample_weights),
                                        num_samples=len(sample_weights), replacement=True)
        print("Class counts:", class_sample_count, "-> Using WeightedRandomSampler")

    train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=(sampler is None),
                              sampler=sampler, num_workers=cfg.num_workers, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=cfg.num_workers, pin_memory=True)
    test_loader = DataLoader(ds_test, batch_size=cfg.batch_size, shuffle=False,
                             num_workers=cfg.num_workers, pin_memory=True)

    return train_loader, val_loader, test_loader, idx_to_class

train_loader, val_loader, test_loader, idx_to_class = build_datasets_and_loaders(cfg)
num_classes = len(idx_to_class)
assert num_classes == 2, f"Expected 2 classes, got {num_classes} -> {idx_to_class}"
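
Why do inverse-frequency weights balance the batches? `WeightedRandomSampler` draws index $i$ with probability $w_i / \sum_j w_j$. With $w_i = 1/n_{y_i}$, each class contributes $n_c \cdot (1/n_c) = 1$ to the total, so both classes are drawn with probability 1/2 in expectation. The cost is that minority-class images repeat within an epoch (`replacement=True`).

Below is a NumPy sketch of this, using hypothetical class counts rather than values read from the dataset:

```python
import numpy as np

counts = np.array([1300, 3900])            # hypothetical class counts: 0 = NORMAL, 1 = PNEUMONIA
targets = np.repeat(np.arange(2), counts)  # one label per training image
class_weights = 1.0 / counts
sample_weights = class_weights[targets]    # same construction as in the cell above

p = sample_weights / sample_weights.sum()  # per-sample draw probability
class_prob = np.array([p[targets == c].sum() for c in range(2)])
print(class_prob)                          # [0.5 0.5]

rng = np.random.default_rng(0)
draws = rng.choice(len(targets), size=len(targets), replace=True, p=p)  # one "epoch" of sampling
print(np.bincount(targets[draws]) / len(draws))  # close to [0.5 0.5]
```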


In [None]:

# ==== Vision Transformer (from scratch) ====
# Core ideas: patchify (Eq.1), [CLS] token (Eq.4), 1-D positional embeddings, pre-LN MSA + MLP (Eqs. 2–3).
# This matches the minimalist ViT design (no convs beyond patchify; no 2-D pos enc; pre-LN; GELU MLP).

class PatchEmbed(nn.Module):
    '''
    Splits image into non-overlapping patches, projects each to embed_dim.
    Implemented as a Conv2d with kernel_size=stride=patch_size.
    Input: (B, C, H, W) -> Output: (B, N, D) where N = (H/P)*(W/P).
    '''
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size // patch_size, img_size // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, C, H, W) -> (B, D, H/P, W/P) -> (B, D, N) -> (B, N, D)
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)  # (B, N, 3*C)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: (B, heads, N, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v  # (B, heads, N, head_dim)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, attn_drop=0.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=2,
                 embed_dim=256, depth=6, num_heads=8, mlp_ratio=4.0,
                 attn_drop=0.0, drop_rate=0.1):
        super().__init__()

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        # [CLS] token & positional embeddings (1-D learnable)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, attn_drop, drop=drop_rate)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        # Parameter init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, N, D)

        # prepend cls token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 1+N, D)

        # add pos embeddings
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # transformer
        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls = x[:, 0]  # [CLS]
        logits = self.head(cls)
        return logits

# Build model (or optionally try a local pretrained ViT via timm/torchvision)
def create_model(cfg: Config, num_classes: int):
    if cfg.try_pretrained_vit and '_HAVE_TIMM' in globals() and _HAVE_TIMM:
        try:
            print("Trying local pretrained ViT via timm...")
            model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=num_classes)
            return model
        except Exception as e:
            print("timm pretrained model not available locally or weights download failed:", e)
            print("Falling back to from-scratch ViT.")
    model = VisionTransformer(
        img_size=cfg.image_size, patch_size=cfg.patch_size, in_chans=cfg.in_channels,
        num_classes=num_classes, embed_dim=cfg.embed_dim, depth=cfg.depth, num_heads=cfg.num_heads,
        mlp_ratio=cfg.mlp_ratio, attn_drop=cfg.attn_dropout, drop_rate=cfg.drop_rate
    )
    return model

model = create_model(cfg, num_classes).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(model.__class__.__name__, "params:", f"{n_params/1e6:.2f}M")
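
The reshape/permute in `Attention.forward` is the easiest part to get wrong. Below is a NumPy sketch of the same computation, with random stand-in weights, biases omitted and a hypothetical function name `mhsa`. It checks two properties:

- Every row of the attention matrix is a probability distribution.
- Splitting into heads via reshape is equivalent to slicing the Q/K projections into `head_dim`-wide column blocks.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def mhsa(x, w_qkv, w_proj, num_heads):
    """x: (B, N, C); w_qkv: (C, 3C); w_proj: (C, C). Mirrors Attention.forward without biases or dropout."""
    B, N, C = x.shape
    hd = C // num_heads
    qkv = (x @ w_qkv).reshape(B, N, 3, num_heads, hd).transpose(2, 0, 3, 1, 4)  # (3, B, h, N, hd)
    q, k, v = qkv[0], qkv[1], qkv[2]
    attn = softmax(q @ k.transpose(0, 1, 3, 2) * hd ** -0.5)                     # (B, h, N, N)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)                      # concatenate heads
    return out @ w_proj, attn

rng = np.random.default_rng(0)
B, N, C, H = 2, 197, 64, 8  # 197 tokens = 196 patches + [CLS]
x = rng.standard_normal((B, N, C))
w_qkv = rng.standard_normal((C, 3 * C)) * 0.1
w_proj = rng.standard_normal((C, C)) * 0.1
out, attn = mhsa(x, w_qkv, w_proj, H)
print(out.shape, attn.shape)           # (2, 197, 64) (2, 8, 197, 197)
print(np.allclose(attn.sum(-1), 1.0))  # True

# Head 0 computed "by hand" from the first head_dim columns of the Q and K blocks.
hd = C // H
q0 = x @ w_qkv[:, 0:hd]
k0 = x @ w_qkv[:, C:C + hd]
attn0 = softmax(q0 @ k0.transpose(0, 2, 1) * hd ** -0.5)
print(np.allclose(attn0, attn[:, 0]))  # True
```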


In [None]:

# ==== Training Utilities ====

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.0):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        if self.smoothing <= 0.0:
            return F.cross_entropy(logits, target)
        n_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(log_probs)
            true_dist.fill_(self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=-1))

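def cosine_scheduler(optimizer, warmup_epochs, total_epochs, base_lr, train_loader_len):
    # Per-epoch schedule: call scheduler.step() once per epoch. base_lr and train_loader_len are
    # unused (LambdaLR scales the optimizer's own lr); they are kept for signature compatibility.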
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch + 1) / float(max(1, warmup_epochs))
        progress = (epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()
    all_logits = []
    all_targets = []
    correct = 0
    total = 0
    for images, targets in data_loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        logits = model(images)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)
        all_logits.append(logits.detach().cpu())
        all_targets.append(targets.detach().cpu())

    all_logits = torch.cat(all_logits, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    probs = all_logits.softmax(dim=1)[:, 1].numpy()
    y_true = all_targets.numpy()
    acc = correct / total
    # F1
    if 'sklearn' in sys.modules:
        from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, precision_recall_fscore_support
        prec, rec, f1, _ = precision_recall_fscore_support(y_true, np.argmax(all_logits.numpy(), axis=1), average="binary", zero_division=0)
        try:
            auroc = roc_auc_score(y_true, probs)
        except Exception:
            auroc = float("nan")
    else:
        preds = np.argmax(all_logits.numpy(), axis=1)
        tp = np.sum((preds == 1) & (y_true == 1))
        fp = np.sum((preds == 1) & (y_true == 0))
        fn = np.sum((preds == 0) & (y_true == 1))
        prec = tp / (tp + fp + 1e-9)
        rec = tp / (tp + fn + 1e-9)
        f1 = 2 * prec * rec / (prec + rec + 1e-9)
        auroc = float("nan")
    return {
        "acc": acc,
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "auroc": float(auroc),
        "logits": all_logits.numpy(),
        "targets": y_true,
    }

def train_one_epoch(model, loader, optimizer, criterion, device, scaler=None):
    model.train()
    running_loss = 0.0
    n = 0
    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None and device.type == "cuda":
            with torch.cuda.amp.autocast():
                logits = model(images)
                loss = criterion(logits, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

        running_loss += loss.item() * images.size(0)
        n += images.size(0)

    return running_loss / max(1, n)
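
`evaluate` reports AUROC as NaN when scikit-learn is missing. AUROC equals the probability that a randomly chosen positive is scored above a randomly chosen negative, with ties counting 1/2. That is the Mann-Whitney U statistic divided by $n_{pos} \cdot n_{neg}$.

Below is a NumPy-only sketch (hypothetical `auroc_rank` helper) that could fill the `float("nan")` in the fallback branch. It uses average ranks for tied scores.

```python
import numpy as np

def auroc_rank(y_true, scores):
    """AUROC via the Mann-Whitney U statistic; tied scores receive average ranks."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    n_pos, n_neg = y_true.sum(), (~y_true).sum()
    if n_pos == 0 or n_neg == 0:
        return float("nan")                       # undefined when only one class is present
    order = np.argsort(scores, kind="mergesort")
    ranks = np.empty(len(scores))
    ranks[order] = np.arange(1, len(scores) + 1)  # 1-based ranks
    # Replace ranks within each group of tied scores by the group's average rank.
    _, inverse = np.unique(scores, return_inverse=True)
    inverse = inverse.ravel()
    ranks = (np.bincount(inverse, weights=ranks) / np.bincount(inverse))[inverse]
    u = ranks[y_true].sum() - n_pos * (n_pos + 1) / 2
    return float(u / (n_pos * n_neg))

print(auroc_rank([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```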


In [None]:

# ==== Train Loop with Early Stopping ====
best_val_metric = -float("inf")
best_state_dict = None
history = {"train_loss": [], "val_acc": [], "val_f1": [], "val_auroc": []}

model = model.to(DEVICE)
criterion = LabelSmoothingCrossEntropy(smoothing=cfg.label_smoothing).to(DEVICE)

optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = cosine_scheduler(optimizer, cfg.warmup_epochs, cfg.epochs, cfg.lr, train_loader_len=len(train_loader))
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == "cuda"))

epochs_no_improve = 0
for epoch in range(cfg.epochs):
    t0 = time.time()
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE, scaler)
    scheduler.step()

    val_metrics = evaluate(model, val_loader, DEVICE)
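    val_metric = val_metrics["auroc"] if cfg.early_stop_metric == "auroc" else val_metrics["f1"]
    if math.isnan(val_metric):  # AUROC is NaN without scikit-learn; fall back to F1
        val_metric = val_metrics["f1"]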

    history["train_loss"].append(train_loss)
    history["val_acc"].append(val_metrics["acc"])
    history["val_f1"].append(val_metrics["f1"])
    history["val_auroc"].append(val_metrics["auroc"])

    took = time.time() - t0
    print(f"Epoch {epoch+1:02d}/{cfg.epochs} - "
          f"loss: {train_loss:.4f} | "
          f"val_acc: {val_metrics['acc']:.4f} | "
          f"val_f1: {val_metrics['f1']:.4f} | "
          f"val_auroc: {val_metrics['auroc']:.4f} | "
          f"time: {took:.1f}s")

    # Early stopping logic
    improved = val_metric > best_val_metric
    if improved:
        best_val_metric = val_metric
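        # clone(): on a CPU device .cpu() returns the same tensor, which later training steps would overwrite
        best_state_dict = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}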
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= cfg.early_stop_patience:
            print(f"Early stopping at epoch {epoch+1}. Best {cfg.early_stop_metric}: {best_val_metric:.4f}")
            break

# Save best checkpoint
OUT_DIR = Path("./outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)
best_ckpt_path = OUT_DIR / "vit_chestxray_best.pth"
# Store Path fields as str so the checkpoint loads with torch.load(weights_only=True), the default in recent PyTorch.
cfg_dict = {k: str(v) if isinstance(v, Path) else v for k, v in asdict(cfg).items()}
if best_state_dict is not None:
    torch.save({"state_dict": best_state_dict, "config": cfg_dict, "classes": idx_to_class}, best_ckpt_path)
    print("Saved best checkpoint to", best_ckpt_path.resolve())
else:
    print("No improvement recorded; saving last model state.")
    torch.save({"state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
                "config": cfg_dict, "classes": idx_to_class}, best_ckpt_path)


In [None]:

# ==== Training Curves ====
fig = plt.figure(figsize=(6,4))
plt.plot(history["train_loss"], label="train_loss")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Training Loss")
plt.legend()
plt.show()

fig = plt.figure(figsize=(6,4))
plt.plot(history["val_acc"], label="val_acc")
plt.plot(history["val_f1"], label="val_f1")
plt.plot(history["val_auroc"], label="val_auroc")
plt.xlabel("Epoch"); plt.ylabel("Score"); plt.title("Validation Metrics")
plt.legend()
plt.show()
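
The loss curve is shaped partly by the LR schedule. `cosine_scheduler` is stepped once per epoch, and `LambdaLR` multiplies the base `lr` by `lr_lambda(epoch)` for that epoch, starting at epoch 0 when the scheduler is created. The function below copies `lr_lambda` so that the multipliers for the defaults (`warmup_epochs=2`, `epochs=15`) can be inspected without building an optimizer:

```python
import math

def lr_multiplier(epoch, warmup_epochs=2, total_epochs=15):
    """Same formula as lr_lambda inside cosine_scheduler."""
    if epoch < warmup_epochs:
        return float(epoch + 1) / float(max(1, warmup_epochs))
    progress = (epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
    return 0.5 * (1.0 + math.cos(math.pi * progress))

mults = [lr_multiplier(e) for e in range(15)]
print([round(m, 3) for m in mults])
```

Reading the multipliers:

- The first warmup epoch already runs at half the base LR.
- Epochs 1 and 2 both run at the full base LR.
- The last epoch runs at roughly 1.5% of the base LR. It never reaches zero, because `progress` tops out at 12/13.

Early stopping may end training before the final epochs are reached.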


In [None]:

# ==== Final Evaluation on UNTOUCHED Test Set ====
# Load best checkpoint and evaluate once on test set.
ckpt = torch.load(best_ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
model = model.to(DEVICE)

test_metrics = evaluate(model, test_loader, DEVICE)
print({k: float(v) if isinstance(v, (np.floating,)) else v for k,v in test_metrics.items() if k in ["acc", "precision", "recall", "f1", "auroc"]})

# Classification report & confusion matrix
if 'sklearn' in sys.modules:
    from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
    y_true = test_metrics["targets"]
    y_pred = np.argmax(test_metrics["logits"], axis=1)
    print("\nClassification report (test):")
    print(classification_report(y_true, y_pred, target_names=[idx_to_class[0], idx_to_class[1]], digits=4))

    cm = confusion_matrix(y_true, y_pred)
    print("Confusion matrix (test):\n", cm)

    # Plot Confusion Matrix
    fig = plt.figure(figsize=(4,4))
    plt.imshow(cm, interpolation='nearest')
    plt.title("Confusion Matrix (Test)")
    plt.colorbar()
    tick_marks = np.arange(len(idx_to_class))
    plt.xticks(tick_marks, [idx_to_class[i] for i in range(len(idx_to_class))], rotation=45)
    plt.yticks(tick_marks, [idx_to_class[i] for i in range(len(idx_to_class))])
    plt.xlabel("Predicted"); plt.ylabel("True")
    plt.tight_layout()
    plt.show()

    # ROC Curve (positive class = index 1; roc_curve was imported above)
    probs = torch.from_numpy(test_metrics["logits"]).softmax(dim=1)[:, 1].numpy()
    fpr, tpr, _ = roc_curve(y_true, probs)
    fig = plt.figure(figsize=(5,4))
    plt.plot(fpr, tpr, linewidth=2)
    plt.plot([0,1], [0,1], linestyle="--")
    plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
    plt.title(f"ROC (AUROC={test_metrics['auroc']:.4f})")
    plt.show()
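
The headline test metrics all follow from the confusion matrix printed above. It uses scikit-learn's layout: rows are true classes and columns are predicted classes, with PNEUMONIA = 1 as the positive class.

- **Precision** is $TP/(TP+FP)$.
- **Recall** (sensitivity) is $TP/(TP+FN)$.
- **Specificity** is $TN/(TN+FP)$, which is the NORMAL-class recall in the classification report.
- **F1** is the harmonic mean of precision and recall.
- **Accuracy** is the trace divided by the total.

The sketch below uses a made-up matrix, not results from this notebook:

```python
import numpy as np

def binary_metrics(cm):
    """cm = [[TN, FP], [FN, TP]], as returned by sklearn.metrics.confusion_matrix for labels {0, 1}."""
    (tn, fp), (fn, tp) = np.asarray(cm)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)          # sensitivity; PNEUMONIA recall in the report
    specificity = tn / (tn + fp)     # NORMAL recall in the report
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    return dict(precision=precision, recall=recall, specificity=specificity, f1=f1, accuracy=accuracy)

cm_example = [[180, 54], [12, 378]]  # hypothetical counts for illustration only
m = binary_metrics(cm_example)
print({k: round(float(v), 4) for k, v in m.items()})
```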


In [None]:

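# ==== Qualitative: a few test predictions ====
# test_loader is unshuffled and ImageFolder orders samples by class, so its first batch would come
# from the first class only; draw a random subset instead (display only, no tuning).
model.eval()
images_shown = 12
test_ds = test_loader.dataset
idx = np.random.default_rng(SEED).choice(len(test_ds), size=images_shown, replace=False)
images = torch.stack([test_ds[int(i)][0] for i in idx]).to(DEVICE)
targets = torch.tensor([test_ds[int(i)][1] for i in idx]).to(DEVICE)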

with torch.no_grad():
    logits = model(images)
    probs = logits.softmax(dim=1)
    preds = probs.argmax(dim=1)

# Denormalize for display
def denorm(img):
    img = img.clone().detach().cpu()
    if cfg.in_channels == 3:
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
        std  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
    else:
        mean = torch.tensor([0.5]).view(1,1,1)
        std  = torch.tensor([0.25]).view(1,1,1)
    img = img * std + mean
    return img.clamp(0,1)

grid = make_grid(denorm(images), nrow=4)
plt.figure(figsize=(8,8))
plt.imshow(np.transpose(grid.numpy(), (1,2,0)))
plt.axis("off")
title_lines = []
for i in range(len(images)):
    title_lines.append(f"{i+1}) pred={idx_to_class[int(preds[i].item())]} ({float(probs[i, preds[i]].item()):.2f}), true={idx_to_class[int(targets[i].item())]}")
plt.title("\n".join(title_lines), fontsize=9)
plt.show()


In [None]:

# ==== Save brief results artifact ====
results_path = OUT_DIR / "test_metrics.json"
import json
with open(results_path, "w") as f:
    json.dump({k: (float(v) if isinstance(v, (np.floating,)) else v) for k,v in test_metrics.items() if k in ["acc","precision","recall","f1","auroc"]}, f, indent=2)
print("Saved test metrics to", results_path.resolve())

print("NOTE: Test set was never used during training or model selection. Validation guided early stopping; test was evaluated once at the end.")


### Appendix: Using 1‑channel input (optional)

By default, we replicate grayscale X‑rays to 3 channels to keep normalization and pretrained compatibility straightforward.
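If you prefer to keep images as **single‑channel (C=1)**, set `cfg.in_channels = 1` **before** running the Datasets & DataLoaders cell (which builds the transforms) and the model cell.
`build_transforms` then switches to `Grayscale(num_output_channels=1)` with a single-channel `Normalize(mean=(0.5,), std=(0.25,))`, and `PatchEmbed` is built with `in_chans=1`. These statistics are generic placeholders. If you replace them, update `denorm` in the qualitative cell to match. The timm pretrained path (`try_pretrained_vit=True`) is only set up for 3-channel input.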
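
If you replace the placeholder single-channel statistics, one option is to estimate them from the **training** split only, so the test set stays untouched. Below is a sketch of a streaming estimator. It assumes batches shaped `(B, 1, H, W)` with values in [0, 1], i.e. taken after `ToTensor()` but before `Normalize`. `stream_mean_std` is a hypothetical helper, and the random batches stand in for a DataLoader.

```python
import numpy as np

def stream_mean_std(batches):
    """Pixel mean/std over an iterable of (B, 1, H, W) arrays, accumulated in float64."""
    n, s, ss = 0, 0.0, 0.0
    for b in batches:
        b = np.asarray(b, dtype=np.float64)
        n += b.size
        s += b.sum()
        ss += np.square(b).sum()
    mean = s / n
    std = np.sqrt(ss / n - mean ** 2)  # population std: sqrt(E[x^2] - E[x]^2)
    return mean, std

rng = np.random.default_rng(0)
batches = [rng.random((8, 1, 32, 32)) for _ in range(5)]
mean, std = stream_mean_std(batches)
all_px = np.concatenate([b.ravel() for b in batches])
print(round(float(mean), 4), round(float(std), 4))
print(np.isclose(mean, all_px.mean()), np.isclose(std, all_px.std()))  # True True
```

In the notebook this would mean a loader over `root / "train"` whose transform stops at `ToTensor()` (no augmentation, no `Normalize`). You would then call, for example, `stream_mean_std(x.numpy() for x, _ in loader)`.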
