In [1]:
from __future__ import annotations

import os
import json
import random
from dataclasses import dataclass, asdict
from pathlib import Path
from collections import Counter

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix

# Fashion-MNIST label names; the list position is the integer class id,
# so CLASS_NAMES[label] turns a dataset target into readable text.
CLASS_NAMES = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

# Fixed y-axis window and tick positions for the loss curve figure.
LOSS_YLIM = (0.0, 1.2)
LOSS_YTICKS = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2]

# Fixed y-axis window and tick positions for the accuracy curve figure.
ACC_YLIM = (0.7, 1.0)
ACC_YTICKS = [0.70, 0.75, 0.80, 0.85, 0.90, 0.95, 1.00]


In [2]:
class FashionCNN(nn.Module):
    """Baseline convolutional classifier for Fashion-MNIST.

    Layout: two 3x3 convolutions with padding 1 (1 -> 32 -> 64 channels),
    one 2x2 max-pool, a 128-unit hidden linear layer and a ``num_classes``
    logit layer. One ReLU module (in-place) and one Dropout(p=0.25) module
    are shared by every activation / dropout site.

    ``fc1`` expects a flattened 64 x 14 x 14 volume, i.e. a 28 x 28 input
    halved once by the pooling layer.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()

        # Convolutional stage: a 3x3 kernel with padding=1 keeps H and W at 28.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2, stride=2)  # 28x28 -> 14x14
        self.dropout = nn.Dropout(0.25)

        # Classifier head on the flattened pooled features.
        flat_features = 64 * 14 * 14
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of single-channel images."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)  # in-place: overwrites the conv1 output buffer
        hidden = self.conv2(hidden)
        hidden = self.relu(hidden)  # in-place: overwrites the conv2 output buffer
        hidden = self.dropout(self.pool(hidden))

        hidden = hidden.flatten(1)  # keep the batch dimension, merge the rest
        hidden = self.dropout(self.relu(self.fc1(hidden)))
        return self.fc2(hidden)


In [3]:
def get_device() -> torch.device:
    """Pick the compute device, preferring MPS, then CUDA, then CPU.

    MPS (Apple Silicon GPU) is chosen only when the backend is both
    available and built into this PyTorch install.
    """
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        device_name = "mps"
    elif torch.cuda.is_available():
        device_name = "cuda"  # NVIDIA GPU
    else:
        device_name = "cpu"
    return torch.device(device_name)


def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (incl. CUDA) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def ensure_dirs(base_dir: Path) -> tuple[Path, Path]:
    """Create the output folders under ``base_dir`` and return their paths.

    Returns:
        ``(ckpt_dir, fig_dir)`` = ``base_dir/outputs/ckpt`` and
        ``base_dir/outputs/figures``. Existing folders are left as they are.
    """
    outputs_root = base_dir / "outputs"
    ckpt_dir, fig_dir = (outputs_root / leaf for leaf in ("ckpt", "figures"))
    for folder in (ckpt_dir, fig_dir):
        folder.mkdir(parents=True, exist_ok=True)
    return ckpt_dir, fig_dir


def unnormalize(img_tensor: torch.Tensor, mean: float = 0.2860, std: float = 0.3530) -> np.ndarray:
    """Undo ``Normalize(mean, std)`` and return the image as a NumPy array.

    Computes ``img_tensor * std + mean``, clamps the result to [0, 1], drops
    dimension 0 if it has size 1, and converts to a CPU NumPy array.

    Args:
        img_tensor: Normalized image tensor. It is not modified.
        mean: Mean that was subtracted during normalization. The default
            matches the value passed to ``transforms.Normalize`` in this file.
        std: Standard deviation that was divided out during normalization.
            The default matches the value used in this file.

    Returns:
        The de-normalized image values in [0, 1] as ``np.ndarray``.
    """
    # detach() lets tensors that track gradients be converted with .numpy();
    # the arithmetic allocates a new tensor, so no explicit clone is needed.
    restored = img_tensor.detach() * std + mean
    restored = restored.clamp(0, 1)
    return restored.squeeze(0).cpu().numpy()


def save_grid(images: np.ndarray, title: str, out_path: Path, cols: int = 4) -> None:
    """Save a tiled grayscale figure of 2-D maps to ``out_path``.

    Args:
        images: Array of shape (N, H, W); one tile is drawn per entry.
        title: Figure-level title.
        out_path: Destination image file.
        cols: Number of tiles per row; the row count is ``ceil(N / cols)``.
    """
    count = images.shape[0]
    rows = int(np.ceil(count / cols))

    fig = plt.figure(figsize=(cols * 1.5, rows * 1.5))

    for slot, tile in enumerate(images, start=1):
        ax = fig.add_subplot(rows, cols, slot)
        # Rescale each tile to [0, 1] on its own so low-contrast maps stay
        # visible; the epsilon avoids dividing by zero on constant tiles.
        lo, hi = tile.min(), tile.max()
        ax.imshow((tile - lo) / (hi - lo + 1e-8), cmap="gray")
        ax.axis("off")

    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)


In [4]:
@dataclass
class TrainConfig:
    """Hyper-parameters read by ``run_training``."""

    # Seeds the global RNGs and the train/validation split generator.
    seed: int = 42
    # Mini-batch size for both the training and validation loaders.
    batch_size: int = 128
    # Number of passes over the training split.
    epochs: int = 10
    # Adam learning rate.
    lr: float = 0.001
    # Adam weight decay.
    weight_decay: float = 0.0
    # Fraction of the training set held out for validation.
    val_ratio: float = 0.1
    # DataLoader worker count.
    num_workers: int = 0


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> tuple[float, float]:
    """Run one evaluation pass and return ``(average_loss, accuracy)``.

    The model is switched to eval mode and left in that mode on return.
    Loss is the per-sample mean cross-entropy over everything the loader
    yields; accuracy is the fraction of arg-max predictions equal to the
    target.

    Args:
        model: Network that maps a batch ``x`` to class logits.
        loader: Iterable yielding ``(x, y)`` batches.
        device: Device the batches are moved to before the forward pass.

    Raises:
        ValueError: If ``loader`` yields no samples, so no average exists.
    """
    model.eval()
    total, correct = 0, 0
    loss_sum = 0.0
    criterion = nn.CrossEntropyLoss()

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        batch_size = x.size(0)

        # criterion returns the batch mean; weight by batch size so a smaller
        # final batch does not skew the overall average.
        loss_sum += criterion(logits, y).item() * batch_size

        preds = torch.argmax(logits, dim=1)
        correct += (preds == y).sum().item()
        total += batch_size

    if total == 0:
        # Previously this fell through to an opaque ZeroDivisionError.
        raise ValueError("evaluate() received a loader that yielded no samples")

    return loss_sum / total, correct / total


def plot_curves(history: dict, out_path: Path) -> None:
    """Write the loss and accuracy curves as two PNG files.

    Args:
        history: Mapping with per-epoch lists under ``train_loss``,
            ``val_loss``, ``train_acc`` and ``val_acc``.
        out_path: Only its parent folder is used. The figures are saved next
            to it as ``training_loss.png`` and ``training_accuracy.png``;
            nothing is written to ``out_path`` itself.
    """
    epochs = np.arange(1, len(history["train_loss"]) + 1)

    # One row per figure: metric suffix, y label, title, y window, y ticks, file name.
    panels = (
        ("loss", "Loss", "Model-1 Loss Curve", LOSS_YLIM, LOSS_YTICKS, "training_loss.png"),
        ("acc", "Accuracy", "Model-1 Accuracy Curve", ACC_YLIM, ACC_YTICKS, "training_accuracy.png"),
    )

    for metric, y_label, title, y_window, y_ticks, file_name in panels:
        fig, ax = plt.subplots(figsize=(8, 5))
        for split in ("train", "val"):
            series = f"{split}_{metric}"
            ax.plot(epochs, history[series], label=series)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.set_ylim(y_window)
        ax.set_yticks(y_ticks)
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path.with_name(file_name), dpi=120)
        plt.close(fig)


In [5]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy)."""
    model.train()  # enable dropout for the training pass
    seen, hits, loss_total = 0, 0, 0.0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_n = x.size(0)
        loss_total += loss.item() * batch_n  # batch mean -> batch sum
        hits += (logits.argmax(dim=1) == y).sum().item()
        seen += batch_n

    return loss_total / seen, hits / seen


def run_training(cfg: TrainConfig, base_dir: Path | None = None):
    """Train ``FashionCNN`` on Fashion-MNIST and write checkpoint, curves and metadata.

    Steps: seed the RNGs, download/load the training set under
    ``base_dir/data``, hold out ``cfg.val_ratio`` of it for validation, train
    with Adam for ``cfg.epochs`` epochs, and save a checkpoint to
    ``outputs/ckpt/model1_best.pt`` whenever validation accuracy improves.
    Afterwards the curves are plotted and ``outputs/ckpt/train_meta.json`` is
    written.

    Args:
        cfg: Training hyper-parameters.
        base_dir: Root for ``data/`` and ``outputs/``. When ``None``, the
            folder of ``__file__`` is used, or the current working directory
            if ``__file__`` is undefined.

    Returns:
        ``(model, best_path, history, fig_dir)``. ``model`` holds the weights
        after the final epoch; it is not reloaded from the best checkpoint.
    """
    if base_dir is None:
        try:
            base_dir = Path(__file__).resolve().parent
        except NameError:
            # __file__ is undefined here, so fall back to the working directory.
            base_dir = Path().resolve()

    ckpt_dir, fig_dir = ensure_dirs(base_dir)

    set_seed(cfg.seed)
    device = get_device()
    print(f"Device: {device}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,)),
    ])

    # Dataset and train/validation split (the split has its own seeded generator).
    full_train = datasets.FashionMNIST(
        root=str(base_dir / "data"), train=True, download=True, transform=transform
    )
    n_total = len(full_train)
    val_size = int(n_total * cfg.val_ratio)
    train_size = n_total - val_size
    split_rng = torch.Generator().manual_seed(cfg.seed)
    train_ds, val_ds = random_split(full_train, [train_size, val_size], generator=split_rng)

    loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=cfg.num_workers)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    # Model, loss, optimizer.
    model = FashionCNN(num_classes=10).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )

    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    best_val_acc = -1.0
    best_path = ckpt_dir / "model1_best.pt"

    # Training and validation loop.
    for epoch in range(1, cfg.epochs + 1):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = evaluate(model, val_loader, device)

        epoch_metrics = (
            ("train_loss", train_loss),
            ("train_acc", train_acc),
            ("val_loss", val_loss),
            ("val_acc", val_acc),
        )
        for key, value in epoch_metrics:
            history[key].append(value)

        print(
            f"Epoch {epoch:02d}/{cfg.epochs} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}"
        )

        # Keep only the checkpoint with the strictly best validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            snapshot = {
                "model_state": model.state_dict(),
                "config": asdict(cfg),
                "val_acc": best_val_acc,
            }
            torch.save(snapshot, best_path)

    plot_curves(history, fig_dir / "training_curves.png")

    meta = {
        "best_val_acc": best_val_acc,
        "device": str(device),
        "config": asdict(cfg),
    }
    meta_path = ckpt_dir / "train_meta.json"
    meta_path.write_text(json.dumps(meta, indent=2), encoding="utf-8")

    print(f"\nSaved best checkpoint: {best_path}")
    print(f"Saved training curves to: {fig_dir}")

    return model, best_path, history, fig_dir


In [6]:
@torch.no_grad()
def run_test_eval(base_dir: Path | None = None, ckpt_name: str = "model1_best.pt"):
    """Evaluate a saved checkpoint on the Fashion-MNIST test split.

    Prints test accuracy and the per-class classification report, writes the
    report to ``outputs/figures/classification_report_model1.txt`` and saves
    a row-normalized confusion matrix to
    ``outputs/figures/confusion_matrix_normalized_model1.png``.

    Args:
        base_dir: Root for ``data/`` and ``outputs/``. When ``None``, the
            folder of ``__file__`` is used, or the current working directory
            if ``__file__`` is undefined.
        ckpt_name: File name inside ``outputs/ckpt`` to load.
    """
    if base_dir is None:
        try:
            base_dir = Path(__file__).resolve().parent
        except NameError:
            base_dir = Path().resolve()

    _, fig_dir = ensure_dirs(base_dir)
    ckpt_path = base_dir / "outputs" / "ckpt" / ckpt_name

    device = get_device()
    print(f"Device: {device}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,)),
    ])
    test_ds = datasets.FashionMNIST(
        root=str(base_dir / "data"), train=False, download=True, transform=transform
    )
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=0)

    # Rebuild the architecture, then restore the stored weights.
    model = FashionCNN(num_classes=10).to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    true_chunks, pred_chunks = [], []
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        batch_preds = model(x).argmax(dim=1)
        true_chunks.append(y.cpu().numpy())
        pred_chunks.append(batch_preds.cpu().numpy())

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)

    acc = np.mean(y_true == y_pred)
    print(f"\nTest Accuracy: {acc:.4f}\n")

    # Classification report
    report = classification_report(y_true, y_pred, target_names=CLASS_NAMES, digits=4)
    print(report)
    report_path = fig_dir / "classification_report_model1.txt"
    report_path.write_text(report, encoding="utf-8")

    # Confusion matrix; each row (true class) is divided by its row total.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(10)))
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = cm / (row_totals + 1e-12)  # epsilon guards against an empty row

    plt.figure(figsize=(5, 5))
    plt.imshow(cm_norm, cmap="Blues")
    plt.title("Normalized Confusion Matrix (Model 1 - PyTorch)")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    class_ticks = range(10)
    plt.xticks(class_ticks, CLASS_NAMES, rotation=45, ha="right")
    plt.yticks(class_ticks, CLASS_NAMES)

    # Annotate every cell; imshow puts the column on x and the row on y.
    for row, col in np.ndindex(10, 10):
        plt.text(col, row, f"{cm_norm[row, col]:.2f}",
                 ha="center", va="center", fontsize=7)

    plt.tight_layout()
    out_path = fig_dir / "confusion_matrix_normalized_model1.png"
    plt.savefig(out_path, dpi=120)
    plt.close()

    print(f"Saved normalized confusion matrix to: {out_path}")


In [7]:
@torch.no_grad()
def run_visualizations(base_dir: Path | None = None, ckpt_name: str = "model1_best.pt", sample_idx: int = 123):
    """Save qualitative figures for a trained checkpoint into ``outputs/figures``.

    Figures written:
      * a 2 x 5 grid of the first ten test images with their labels,
      * the first (up to) 16 ``conv1`` filters,
      * a preview of test image ``sample_idx`` with true and predicted label,
      * the first (up to) 16 activation maps captured at ``conv1``,
        ``conv2`` and ``pool`` for that image.

    Args:
        base_dir: Root for ``data/`` and ``outputs/``. When ``None``, the
            folder of ``__file__`` is used, or the current working directory
            if ``__file__`` is undefined.
        ckpt_name: File name inside ``outputs/ckpt`` to load.
        sample_idx: Index into the test split used for the activation maps.
    """
    if base_dir is None:
        try:
            base_dir = Path(__file__).resolve().parent
        except NameError:
            base_dir = Path().resolve()

    _, fig_dir = ensure_dirs(base_dir)
    ckpt_path = base_dir / "outputs" / "ckpt" / ckpt_name

    device = get_device()
    print(f"Device: {device}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,)),
    ])
    test_ds = datasets.FashionMNIST(
        root=str(base_dir / "data"), train=False, download=True, transform=transform
    )

    model = FashionCNN(num_classes=10).to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    # ---- Sample images: the first ten test items in a 2 x 5 grid ----
    plt.figure(figsize=(5, 3))
    for slot in range(10):
        image, label = test_ds[slot]
        plt.subplot(2, 5, slot + 1)
        plt.imshow(unnormalize(image), cmap="gray")
        plt.title(CLASS_NAMES[int(label)], fontsize=9)
        plt.axis("off")

    plt.suptitle("Sample Images from Fashion-MNIST Dataset")
    plt.tight_layout()
    out_path = fig_dir / "sample_images_fashion_mnist_pytorch.png"
    plt.savefig(out_path, dpi=120)
    plt.close()
    print(f"Saved Figure 1 to: {out_path}")

    # ---- Filters visualization (conv1) ----
    kernels = model.conv1.weight.detach().cpu().numpy()  # (outC, inC, k, k)
    kernels = kernels[:, 0, :, :]  # single input channel -> (outC, k, k)
    filters = kernels[:min(16, kernels.shape[0])]

    filters_path = fig_dir / "filters_conv1_model1.png"
    save_grid(filters, "Learned Filters (conv1) - Model 1", filters_path, cols=4)
    print(f"Saved filters to: {filters_path}")

    # ---- Activation maps via forward hooks ----
    activations: dict[str, torch.Tensor] = {}

    def make_recorder(key):
        def _record(module, inputs, output):
            # NOTE(review): detach() shares storage with the layer output, and
            # FashionCNN applies its in-place ReLU directly to the conv1/conv2
            # outputs, so the tensors kept for those two layers hold post-ReLU
            # values by the time they are plotted. Add .clone() here if
            # pre-activation maps are wanted.
            activations[key] = output.detach()
        return _record

    hook_handles = [
        getattr(model, layer).register_forward_hook(make_recorder(layer))
        for layer in ("conv1", "conv2", "pool")
    ]

    x, y_true = test_ds[sample_idx]
    logits = model(x.unsqueeze(0).to(device))
    y_pred = int(torch.argmax(logits, dim=1).item())

    # The hooks are only needed for this single forward pass.
    for handle in hook_handles:
        handle.remove()

    # Sample preview
    true_id = int(y_true)
    plt.figure(figsize=(3, 3))
    plt.imshow(unnormalize(x), cmap="gray")
    plt.axis("off")
    plt.title(
        f"idx={sample_idx} | True: {CLASS_NAMES[true_id]} "
        f"| Pred: {CLASS_NAMES[y_pred]}"
    )
    sample_path = fig_dir / f"sample_{sample_idx}_true_{true_id}_pred_{y_pred}_model1.png"
    plt.tight_layout()
    plt.savefig(sample_path, dpi=120)
    plt.close()
    print(f"Saved sample preview to: {sample_path}")

    # Activation maps: (1, C, H, W) -> first N channels of the single sample.
    # Only the two conv layers are announced on stdout; the pool figure is
    # saved without a message.
    for layer_name, announce in (("conv1", True), ("conv2", True), ("pool", False)):
        layer_act = activations[layer_name][0].cpu().numpy()  # (C, H, W)
        maps = layer_act[:min(16, layer_act.shape[0])]

        out_path = fig_dir / f"activation_{layer_name}_idx{sample_idx}_model1.png"
        save_grid(
            maps,
            f"Activation Maps ({layer_name}) - idx {sample_idx} - Model 1",
            out_path,
            cols=4,
        )
        if announce:
            print(f"Saved activation maps to: {out_path}")

    print("Done. Check outputs/figures/ folder.")


In [8]:
# Train for 12 epochs (overriding the TrainConfig default of 10); every other
# hyper-parameter keeps its default. base_dir is left unset, so run_training
# resolves it on its own.
cfg = TrainConfig(epochs=12)
model, best_path, history, fig_dir = run_training(cfg)

# Both calls use their default ckpt_name "model1_best.pt", the same file name
# run_training writes its best checkpoint to.
run_test_eval()
run_visualizations(sample_idx=123)

Device: mps
Epoch 01/12 | train_loss=0.4344 train_acc=0.8449 | val_loss=0.2903 val_acc=0.8978
Epoch 02/12 | train_loss=0.2771 train_acc=0.8986 | val_loss=0.2435 val_acc=0.9092
Epoch 03/12 | train_loss=0.2281 train_acc=0.9160 | val_loss=0.2229 val_acc=0.9193
Epoch 04/12 | train_loss=0.1939 train_acc=0.9289 | val_loss=0.1957 val_acc=0.9278
Epoch 05/12 | train_loss=0.1676 train_acc=0.9381 | val_loss=0.2105 val_acc=0.9282
Epoch 06/12 | train_loss=0.1447 train_acc=0.9462 | val_loss=0.2063 val_acc=0.9282
Epoch 07/12 | train_loss=0.1267 train_acc=0.9525 | val_loss=0.2081 val_acc=0.9278
Epoch 08/12 | train_loss=0.1104 train_acc=0.9585 | val_loss=0.2026 val_acc=0.9328
Epoch 09/12 | train_loss=0.0976 train_acc=0.9645 | val_loss=0.2152 val_acc=0.9305
Epoch 10/12 | train_loss=0.0854 train_acc=0.9678 | val_loss=0.2279 val_acc=0.9302
Epoch 11/12 | train_loss=0.0783 train_acc=0.9704 | val_loss=0.2282 val_acc=0.9340
Epoch 12/12 | train_loss=0.0687 train_acc=0.9740 | val_loss=0.2382 val_acc=0.9322

Sav