In [1]:
from __future__ import annotations
from pathlib import Path
import time

import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import math

try:
    import pandas as pd
except ModuleNotFoundError:
    pd = None  

CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]

LOSS_YLIM   = (0.0, 1.2)
LOSS_YTICKS = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2]

ACC_YLIM    = (0.7, 1.0)
ACC_YTICKS  = [0.70, 0.75, 0.80, 0.85, 0.90, 0.95, 1.00] 

In [2]:
class ConvBNReLU(nn.Module):
    
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False) # no Bias because we 'll use batchNorm
        self.bn = nn.BatchNorm2d(out_ch) #Normalize all feature maps
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x))) # Conv2d → BatchNorm2d → ReLU


class Model2CNN(nn.Module):
    
    """ Fashion-MNIST CNN with BatchNorm, Dropout, GAP """
    
    def __init__(self, num_classes=10, dropout=0.20):
        super().__init__()
        self.features = nn.Sequential( # A container that outputs each layer inside
            ConvBNReLU(1, 32),
            ConvBNReLU(32, 32),
            nn.MaxPool2d(2),   # 28 -> 14
            nn.Dropout2d(0.05),

            ConvBNReLU(32, 64),
            ConvBNReLU(64, 64),
            nn.MaxPool2d(2),   # 14 -> 7
            nn.Dropout2d(0.10),

            ConvBNReLU(64, 128),
            ConvBNReLU(128, 128), 
            nn.Dropout2d(0.15),
        )

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.20),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = self.classifier(x)
        return x


def build_model():
    
    return Model2CNN(num_classes=10, dropout=0.20)

In [3]:
def get_device():
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def ensure_dirs(base_dir: Path):
    out_dir = base_dir / "outputs"
    ckpt_dir = out_dir / "ckpt"
    fig_dir = out_dir / "figures"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    fig_dir.mkdir(parents=True, exist_ok=True)
    return out_dir, ckpt_dir, fig_dir


@torch.no_grad()
def evaluate(model, loader, device, criterion):
    model.eval()
    
    total_loss, correct, total = 0.0, 0, 0
    
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)

        total_loss += loss.item() * x.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    return total_loss / total, correct / total


def train_one_epoch(model, loader, device, optimizer, criterion):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    return total_loss / total, correct / total


def save_grid(feature_maps, title, out_path, cols=8, figsize=(8, 4)):
    import matplotlib.pyplot as plt

    n_maps = len(feature_maps)
    rows = (n_maps + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=figsize)

    for i, ax in enumerate(axes.flat):
        if i < n_maps:
            ax.imshow(feature_maps[i], cmap="gray")
        ax.axis("off")

    plt.suptitle(title, fontsize=10)
    plt.tight_layout()
    plt.subplots_adjust(top=0.88)
    plt.savefig(out_path, dpi=120)
    plt.close()

def unnormalize(x):
    return x * 0.5 + 0.5


In [4]:
def run_training_model2(base_dir: Path | None = None):

    print("==== Training Model 2 ====")

    if base_dir is None:
        try:
            base_dir = Path(__file__).resolve().parent
        except NameError:
            base_dir = Path().resolve()

    _, ckpt_dir, fig_dir = ensure_dirs(base_dir)
    device = get_device()
    print(f"Device: {device}")

    train_tf = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    val_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    data_root = base_dir / "data"

    full_train = datasets.FashionMNIST(
        root=data_root,
        train=True,
        download=True,
        transform=train_tf,
    )

    val_ratio = 0.1
    val_size = int(len(full_train) * val_ratio)
    train_size = len(full_train) - val_size
    train_set, val_set = random_split(full_train, [train_size, val_size])

    batch_size = 128
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,  num_workers=0)
    val_loader   = DataLoader(val_set,   batch_size=batch_size, shuffle=False, num_workers=0)


    # Model, loss, optimizer, scheduler
    model = build_model().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

    
    # Training loop
    epochs = 12
    best_val_acc = 0.0
    best_path = ckpt_dir / "model2_best.pt"

    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, epochs + 1):
        t0 = time.time()

        train_loss, train_acc = train_one_epoch(model, train_loader, device, optimizer, criterion)
        val_loss,   val_acc   = evaluate(model, val_loader, device, criterion)

        scheduler.step(val_loss)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        dt = time.time() - t0
        print(
            f"Epoch {epoch:02d}/{epochs} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} | "
            f"time={dt:.1f}s"
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({"model_state": model.state_dict(), "val_acc": best_val_acc}, best_path)

    print(f"\nSaved best checkpoint to: {best_path}")
    
    epochs_arr = np.arange(1, len(history["train_loss"]) + 1)

    # Loss curve
    plt.figure(figsize=(8, 5))
    plt.plot(epochs_arr, history["train_loss"], label="train_loss")
    plt.plot(epochs_arr, history["val_loss"],   label="val_loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Model-2 Loss Curve")
    plt.ylim(LOSS_YLIM)
    plt.yticks(LOSS_YTICKS)
    plt.legend()
    plt.tight_layout()
    plt.savefig(fig_dir / "training_loss_model2.png", dpi=120)
    plt.close()

    # Accuracy curve
    plt.figure(figsize=(8, 5))
    plt.plot(epochs_arr, history["train_acc"], label="train_acc")
    plt.plot(epochs_arr, history["val_acc"],   label="val_acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Model-2 Accuracy Curve")
    plt.ylim(ACC_YLIM)
    plt.yticks(ACC_YTICKS)
    plt.legend()
    plt.tight_layout()
    plt.savefig(fig_dir / "training_accuracy_model2.png", dpi=120)
    plt.close()

    # Save metadata just like Model 1
    meta = {
        "best_val_acc": best_val_acc,
        "device": str(device),
        "epochs": epochs,
    }
    (ckpt_dir / "train_meta_model2.json").write_text(
        json.dumps(meta, indent=2),
        encoding="utf-8"
    )
    print(f"\nSaved metadata: {ckpt_dir / 'train_meta_model2.json'}")

    return history


In [5]:
@torch.no_grad()
def run_full_evaluation_model2(base_dir: Path | None = None, ckpt_name: str = "model2_best.pt", batch_size: int = 256):

    print("\n==== Full Evaluation: Model 2 ====\n")

    if base_dir is None:
        try:
            base_dir = Path(__file__).resolve().parent
        except NameError:
            base_dir = Path().resolve()

    _, ckpt_dir, fig_dir = ensure_dirs(base_dir)
    device = get_device()
    print(f"Device: {device}")

    
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    test_set = datasets.FashionMNIST(
        root=base_dir / "data",
        train=False,
        download=True,
        transform=tf
    )
    loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

    
    ckpt_path = ckpt_dir / ckpt_name
    model = build_model().to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    
    all_preds, all_labels = [], []
    for x, y in loader:
        x = x.to(device)
        preds = model(x).argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(y.numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    acc = (all_labels == all_preds).mean()
    print(f"\nTest Accuracy: {acc:.4f}\n")

    (fig_dir / "test_accuracy_model2.txt").write_text(f"{acc:.4f}")

    
    # Classification report
    report = classification_report(all_labels, all_preds, target_names=CLASS_NAMES, digits=4)
    print(report)
    (fig_dir / "classification_report_model2.txt").write_text(report)

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds, normalize='true')

    plt.figure(figsize=(9, 7))
    sns.heatmap(cm, annot=True, fmt=".2f", cmap="Blues",
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                vmin=0.0, vmax=1.0)
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.title("Normalized Confusion Matrix - Model 2")
    plt.tight_layout()
    plt.savefig(fig_dir / "confusion_matrix_model2_normalized.png", dpi=160)
    plt.close()

    print(f"Saved normalized confusion matrix to: {fig_dir / 'confusion_matrix_model2_normalized.png'}")


In [6]:
@torch.no_grad()
def run_visualizations_model2(base_dir: Path | None = None,
                              ckpt_name: str = "model2_best.pt",
                              sample_idx: int = 123):

    print("\n==== Visualization Model 2 ====\n")

    if base_dir is None:
        try:
            base_dir = Path(__file__).resolve().parent
        except NameError:
            base_dir = Path().resolve()

    _, _, fig_dir = ensure_dirs(base_dir)
    ckpt_path = base_dir / "outputs" / "ckpt" / ckpt_name

    device = get_device()
    print(f"Device: {device}")

    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    test_ds = datasets.FashionMNIST(
        root=base_dir / "data",
        train=False,
        download=True,
        transform=tf
    )

    model = build_model().to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    idxs = list(range(10))
    plt.figure(figsize=(5, 3))
    for i, idx in enumerate(idxs):
        x, y = test_ds[idx]
        plt.subplot(2, 5, i + 1)
        plt.imshow(unnormalize(x).squeeze(), cmap="gray")
        plt.title(CLASS_NAMES[int(y)], fontsize=8)
        plt.axis("off")

    plt.suptitle("Sample Images - Model 2")
    plt.tight_layout()
    out1 = fig_dir / "sample_images_model2.png"
    plt.savefig(out1, dpi=120)
    plt.close()
    print(f"Saved: {out1}")

    print("\nExtracting STAGE-BASED activation maps...\n")

    stage_layers = {
        "stage1": 3,
        "stage2": 7,
        "stage3": 10
    }

    activations = {}

    def get_activation(name):
        def hook_fn(module, inp, out):
            activations[name] = out.detach().cpu()
        return hook_fn

    hooks = []

    for stage_name, layer_idx in stage_layers.items():
        h = model.features[layer_idx].register_forward_hook(get_activation(stage_name))
        hooks.append(h)
        print(f"  ✔ Hook registered for {stage_name} at model.features[{layer_idx}]")

    x, y_true = test_ds[sample_idx]
    x_batch = x.unsqueeze(0).to(device)
    _ = model(x_batch)

    h = model.features[1].register_forward_hook(get_activation("stage1_mid"))
    hooks.append(h)
    print("Hook registered for stage1_mid at model.features[1]")

    for h in hooks:
        h.remove()

    print("\nSaving activation map grids for each stage...\n")

    for stage_name, act in activations.items():
        act = act.squeeze(0).numpy()
        maps = act[:32]

        out_path = fig_dir / f"activation_maps_{stage_name}_idx{sample_idx}.png"
        save_grid(
            maps,
            f"Activation Maps - {stage_name.upper()} (Model 2)",
            out_path,
            cols=8,
        )
        
        print(f"  → saved: {out_path}")

    print("\nStage-based activation map visualizations completed.\n")
    print("Check outputs/figures for results!\n")


In [7]:
run_training_model2()
run_full_evaluation_model2()
run_visualizations_model2()

==== Training Model 2 ====
Device: mps
Epoch 01/12 | train_loss=0.7750 train_acc=0.7158 | val_loss=0.4986 val_acc=0.8230 | time=14.9s
Epoch 02/12 | train_loss=0.4659 train_acc=0.8289 | val_loss=0.3867 val_acc=0.8570 | time=9.9s
Epoch 03/12 | train_loss=0.3977 train_acc=0.8560 | val_loss=0.3287 val_acc=0.8747 | time=9.9s
Epoch 04/12 | train_loss=0.3576 train_acc=0.8695 | val_loss=0.3188 val_acc=0.8782 | time=9.6s
Epoch 05/12 | train_loss=0.3296 train_acc=0.8800 | val_loss=0.3166 val_acc=0.8798 | time=9.5s
Epoch 06/12 | train_loss=0.3096 train_acc=0.8880 | val_loss=0.2849 val_acc=0.8920 | time=9.5s
Epoch 07/12 | train_loss=0.2966 train_acc=0.8929 | val_loss=0.2967 val_acc=0.8845 | time=9.6s
Epoch 08/12 | train_loss=0.2814 train_acc=0.8977 | val_loss=0.2545 val_acc=0.9038 | time=9.6s
Epoch 09/12 | train_loss=0.2733 train_acc=0.9016 | val_loss=0.2574 val_acc=0.9035 | time=10.7s
Epoch 10/12 | train_loss=0.2639 train_acc=0.9033 | val_loss=0.2368 val_acc=0.9093 | time=10.0s
Epoch 11/12 | trai