# DSPSVDD — Deep Structure-Preserving SVDD for One-Class Anomaly Detection
**Authors:** Hamza Oukhacha & OUDRAIA Abdelouahad

This notebook trains a DSPSVDD model for anomaly detection:
1) Pretrain an AutoEncoder on normal samples
2) Compute the hypersphere center in latent space
3) Joint training (SVDD distance + reconstruction)
4) Evaluation (ROC-AUC, PR-AUC, FPR@95%TPR, F1, MCC) and figures

**Datasets:** MNIST / Fashion-MNIST  
**Normal Class:** digit k (0..9)  

---

## Table of Contents
1. Environment & Runtime Check  
2. Mount Drive
3. Configuration (CFG) & Utilities  
4. Dataset & Dataloaders  
5. Model: Convolutional AutoEncoder  
6. Training: AE Warm-up  
7. Hypersphere Center  
8. Training: Joint DSPSVDD  
9. Evaluation & Metrics  
10. Plotting: ROC/PR/CM/Curves/Recons  
11. Extra Figures (placeholders)  
12. Save Artifacts  
13. Single-Run Driver (save to Drive)  
14. Multi-Run Driver (MNIST 1→9)  
15. Aggregation to Summary CSV  
16. Re-load a Run & Re-make Figures  
17. Notes & Next Steps


In [None]:
# @title Environment & Runtime Check
import torch, torchvision, sys
# Report interpreter and library versions so a run can be tied to its environment.
print("Python:", sys.version)
print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("Torchvision:", torchvision.__version__)
# Only query the device name when CUDA is usable; index 0 is the first visible GPU.
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))


In [None]:
# @title Mount Drive (optional but recommended)
# google.colab is supplied by the Colab runtime; skip this cell when running elsewhere.
from google.colab import drive
# Mount Google Drive at /content/drive, the root of the default output paths used by the run drivers.
drive.mount('/content/drive')


In [None]:
# @title Configuration (CFG) & Utilities
import os, time, json, random, warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

class Config:
    """Hyperparameters and runtime settings, shared by all cells through the global ``CFG``."""
    # Data
    DATASET      = "mnist"       # "mnist" or "fashion_mnist"
    NORMAL_CLASS = 1             # class index (0..9) treated as normal; change it to test other classes
    IMG_SIZE     = 32            # images are resized to IMG_SIZE x IMG_SIZE
    BATCH_SIZE   = 128
    NUM_WORKERS  = 2

    # Model
    LATENT_DIM   = 128           # size of the latent vector produced by the encoder

    # Training
    AE_EPOCHS    = 15            # autoencoder warm-up epochs
    JOINT_EPOCHS = 20            # joint (distance + reconstruction) epochs
    LR           = 1e-3          # learning rate for the warm-up phase
    LR_JOINT     = 5e-4          # learning rate for the joint phase
    WEIGHT_DECAY = 1e-5
    GAMMA        = 0.1           # weight of the reconstruction term in the joint loss
    NU           = 0.05          # thresholds are the (1 - NU) quantile of the training scores

    # System
    DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"
    SEED         = 42
    OUTPUT_DIR   = "./dspsvdd_results"

# Single shared instance; the run drivers mutate its attributes in place.
CFG = Config()

def cfg_to_dict():
    """Snapshot the global ``CFG`` as a plain dict.

    Only attributes whose names are upper-case and do not start with an
    underscore are included, i.e. the configuration constants.
    """
    snapshot = {}
    for attr in dir(CFG):
        if attr.startswith("_") or not attr.isupper():
            continue
        snapshot[attr] = getattr(CFG, attr)
    return snapshot

def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Value passed to every random number generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        # Seed the generator of every visible GPU.
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


In [None]:
# @title Dataset & Dataloaders
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets, transforms

class OneClassDataset(Dataset):
    """Wrap a labelled dataset so that it yields one-class (binary) targets.

    A sample whose label equals ``normal_class`` gets target 0 (normal);
    every other sample gets target 1 (anomaly).
    """

    def __init__(self, dataset, normal_class: int):
        self.dataset = dataset
        self.nc = int(normal_class)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, label = self.dataset[idx]
        is_anomaly = int(label) != self.nc
        return sample, int(is_anomaly)

def _get_targets(dataset):
    """Return the labels of ``dataset`` as a NumPy array.

    Uses the ``targets`` attribute when the dataset has one (tensor or
    sequence); otherwise iterates over the dataset and collects the second
    element of each item as an int.
    """
    if not hasattr(dataset, "targets"):
        # No label attribute: fall back to reading every sample.
        return np.array([int(label) for _, label in dataset])
    raw = dataset.targets
    if isinstance(raw, torch.Tensor):
        return raw.cpu().numpy()
    return np.array(raw)

def get_data_loaders():
    """Build the training and test loaders described by the global ``CFG``.

    The training loader contains only samples of ``CFG.NORMAL_CLASS`` and is
    shuffled; the test loader covers the full test split with binary labels
    (0 = normal, 1 = anomaly) and keeps its order.

    Returns:
        Tuple ``(train_loader, test_loader)``.

    Raises:
        ValueError: If ``CFG.DATASET`` is neither "mnist" nor "fashion_mnist".
    """
    side = CFG.IMG_SIZE
    preprocess = transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor()
    ])

    name = CFG.DATASET.lower()
    if name == "mnist":
        dataset_cls = datasets.MNIST
    elif name == "fashion_mnist":
        dataset_cls = datasets.FashionMNIST
    else:
        raise ValueError("DATASET must be 'mnist' or 'fashion_mnist'.")
    train_dataset = dataset_cls("./data", train=True, download=True, transform=preprocess)
    test_dataset = dataset_cls("./data", train=False, download=True, transform=preprocess)

    # Keep only the indices of the normal class for training.
    normal_mask = _get_targets(train_dataset) == int(CFG.NORMAL_CLASS)
    train_subset = Subset(train_dataset, np.flatnonzero(normal_mask).tolist())

    shared = dict(
        batch_size=CFG.BATCH_SIZE,
        num_workers=CFG.NUM_WORKERS,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_subset, shuffle=True, **shared)
    test_loader = DataLoader(OneClassDataset(test_dataset, CFG.NORMAL_CLASS), shuffle=False, **shared)
    return train_loader, test_loader


In [None]:
# @title Model: Convolutional AutoEncoder
class ConvolutionalAutoEncoder(nn.Module):
    """Compact convolutional autoencoder for single-channel 32x32 images.

    Three stride-2 convolutions reduce the input to a 128x4x4 feature map,
    which a linear layer maps to the latent vector; the decoder mirrors this
    with a linear layer and three transposed convolutions ending in a sigmoid.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        flat_features = 128 * 4 * 4
        halve = dict(kernel_size=4, stride=2, padding=1)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, **halve), nn.ReLU(inplace=True),    # 32x16x16
            nn.Conv2d(32, 64, **halve), nn.ReLU(inplace=True),   # 64x8x8
            nn.Conv2d(64, 128, **halve), nn.ReLU(inplace=True),  # 128x4x4
        )
        self.encoder_fc = nn.Linear(flat_features, latent_dim)
        self.decoder_fc = nn.Linear(latent_dim, flat_features)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, **halve), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, **halve), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, **halve), nn.Sigmoid(),
        )
        self.apply(self._init)

    def _init(self, m):
        """Xavier-uniform weights and zero biases for conv / linear layers."""
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            return
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def encode(self, x):
        """Map an image batch to latent vectors."""
        features = self.encoder(x)
        return self.encoder_fc(features.view(x.size(0), -1))

    def decode(self, z):
        """Map latent vectors back to image space."""
        features = self.decoder_fc(z)
        return self.decoder(features.view(z.size(0), 128, 4, 4))

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for the batch ``x``."""
        z = self.encode(x)
        return z, self.decode(z)


In [None]:
# @title Training: AE Warm-up
def train_autoencoder(model, loader, epochs, lr, weight_decay, device):
    """Pre-train the autoencoder on reconstruction loss only.

    Args:
        model: Module whose forward pass returns ``(latent, reconstruction)``.
        loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        epochs: Number of passes over ``loader``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        device: Device the input batches are moved to.

    Returns:
        List with the mean per-batch MSE of each epoch.
    """
    print("Training AutoEncoder...")
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.MSELoss()
    losses = []
    for ep in range(1, epochs+1):
        batch_losses = []
        for inputs, _ in loader:
            inputs = inputs.to(device)
            _, recon = model(inputs)
            loss = criterion(recon, inputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        # max(..., 1) keeps the average defined when the loader yields nothing.
        losses.append(sum(batch_losses) / max(len(batch_losses), 1))
        print(f"Epoch [{ep}/{epochs}] AE Loss: {losses[-1]:.6f}")
    return losses


In [None]:
# @title Hypersphere Center
@torch.no_grad()
def compute_center(model, loader, device):
    """Compute the hypersphere center as the mean latent vector over ``loader``.

    Coordinates whose magnitude is below ``eps`` (1e-6) are replaced by
    ``eps * sign(c + eps)``, so no coordinate of the center is (near) zero.

    Args:
        model: Module whose forward pass returns ``(latent, reconstruction)``.
        loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device used for the forward passes and for the returned tensor.

    Returns:
        1-D tensor holding the center, on ``device``.
    """
    print("Computing hypersphere center...")
    model.eval()
    # Latents are gathered on the CPU and only the final mean goes back to the device.
    latents = [model(inputs.to(device))[0].detach().cpu() for inputs, _ in loader]
    mean_latent = torch.cat(latents, dim=0).mean(dim=0).to(device)
    eps = 1e-6
    too_small = mean_latent.abs() < eps
    replacement = eps * torch.sign(mean_latent + eps)
    return torch.where(too_small, replacement, mean_latent)


In [None]:
# @title Training: Joint DSPSVDD
def train_dspsvdd(model, loader, center, epochs, gamma, lr, weight_decay, device):
    """Jointly train the encoder/decoder with the DSPSVDD objective.

    The per-batch loss is ``mean(||z - c||^2) + gamma * MSE(x, x_hat)``.

    Args:
        model: Module whose forward pass returns ``(latent, reconstruction)``.
        loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        center: Hypersphere center ``c``, broadcastable against the latents.
        epochs: Number of passes over ``loader``.
        gamma: Weight of the reconstruction term.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        device: Device the input batches are moved to.

    Returns:
        Tuple ``(total, distance, reconstruction)`` of per-epoch mean losses.
    """
    print("Training DSPSVDD...")
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    mse = nn.MSELoss()
    losses, dist_losses, recon_losses = [], [], []
    for ep in range(1, epochs+1):
        s_tot = s_d = s_r = 0.0
        n = 0
        for x, _ in loader:
            x = x.to(device)
            z, xhat = model(x)
            dloss = ((z - center) ** 2).sum(dim=1).mean()
            rloss = mse(xhat, x)
            loss = dloss + gamma * rloss
            opt.zero_grad()
            loss.backward()
            opt.step()
            s_tot += loss.item()
            s_d += dloss.item()
            s_r += rloss.item()
            n += 1
        # Same guard as train_autoencoder: an empty loader reports 0.0 instead
        # of raising ZeroDivisionError.
        denom = max(n, 1)
        losses.append(s_tot / denom)
        dist_losses.append(s_d / denom)
        recon_losses.append(s_r / denom)
        print(f"Epoch [{ep}/{epochs}] Total: {losses[-1]:.6f} | Dist: {dist_losses[-1]:.6f} | Recon: {recon_losses[-1]:.6f}")
    return losses, dist_losses, recon_losses


In [None]:
# @title Evaluation & Metrics
from sklearn.metrics import (roc_auc_score, average_precision_score, roc_curve,
                             precision_recall_curve, confusion_matrix, f1_score, matthews_corrcoef)

@torch.no_grad()
def evaluate_model(model, loader, center, device):
    """Score every sample in ``loader``.

    Args:
        model: Module whose forward pass returns ``(latent, reconstruction)``.
        loader: Iterable of ``(inputs, labels)`` batches.
        center: Hypersphere center, broadcastable against the latents.
        device: Device the input batches are moved to.

    Returns:
        Three NumPy arrays, one entry per sample: squared distance of the
        latent to ``center``, mean squared reconstruction error, and the label
        delivered by the loader.
    """
    print("Evaluating model...")
    model.eval()
    dist_chunks, recon_chunks, label_chunks = [], [], []
    for inputs, targets in loader:
        inputs = inputs.to(device)
        latent, recon = model(inputs)
        sq_dist = ((latent - center) ** 2).sum(dim=1)
        # Average the element-wise error over channel, height and width.
        per_sample_mse = F.mse_loss(recon, inputs, reduction="none").mean(dim=[1,2,3])
        dist_chunks.append(sq_dist.cpu().numpy())
        recon_chunks.append(per_sample_mse.cpu().numpy())
        label_chunks.append(targets.numpy())
    distances = np.concatenate(dist_chunks)
    recon_errors = np.concatenate(recon_chunks)
    labels = np.concatenate(label_chunks)
    return distances, recon_errors, labels

def compute_metrics(labels, scores, threshold=None):
    """Compute ranking metrics and, optionally, thresholded classification metrics.

    Args:
        labels: Binary ground truth, 0 = normal and 1 = anomaly.
        scores: Anomaly scores; higher means more anomalous.
        threshold: If given, samples with ``score > threshold`` are predicted
            as anomalies and accuracy / F1 / MCC / confusion counts are added.

    Returns:
        Dict with ``roc_auc``, ``pr_auc`` and ``fpr_at_95_tpr``; plus
        ``accuracy``, ``f1_score``, ``mcc``, ``tn``, ``fp``, ``fn``, ``tp``
        when ``threshold`` is not None.
    """
    # Accept plain sequences as well as arrays; the comparisons below need arrays.
    labels = np.asarray(labels)
    scores = np.asarray(scores)

    m = {}
    m["roc_auc"] = float(roc_auc_score(labels, scores))
    m["pr_auc"]  = float(average_precision_score(labels, scores))

    # FPR at the first ROC operating point whose TPR reaches 95%.
    fpr, tpr, _ = roc_curve(labels, scores)
    idx = np.flatnonzero(tpr >= 0.95)
    m["fpr_at_95_tpr"] = float(fpr[idx[0]]) if idx.size else 1.0

    if threshold is not None:
        pred = (scores > threshold).astype(int)
        m["accuracy"] = float((pred == labels).mean())
        m["f1_score"] = float(f1_score(labels, pred, zero_division=0))
        m["mcc"]      = float(matthews_corrcoef(labels, pred))
        # Pin the label set so the matrix is always 2x2 and the four-way
        # unpacking cannot fail when a class is absent from labels and pred.
        tn, fp, fn, tp = confusion_matrix(labels, pred, labels=[0, 1]).ravel()
        m["tn"], m["fp"], m["fn"], m["tp"] = int(tn), int(fp), int(fn), int(tp)
    return m


In [None]:
# @title Plotting: ROC/PR/CM/Curves/Recons
def _savefig(path):
    """Save the current matplotlib figure to ``path`` and close it.

    The parent directory is created when ``path`` has one; a bare file name
    is written to the current working directory.
    """
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises, so only create a directory that is actually named.
        os.makedirs(parent, exist_ok=True)
    try:
        plt.savefig(path, dpi=300, bbox_inches="tight")
    finally:
        # Close even when saving fails so figures do not accumulate.
        plt.close()

def plot_roc_pr(distances, recon_errors, labels, outdir):
    """Save ROC and precision-recall curves for both anomaly scores.

    Writes ``roc.png`` and ``pr.png`` into ``outdir``; each figure compares the
    distance score with the reconstruction-error score.
    """
    # --- ROC figure ---
    plt.figure(figsize=(6,5))
    auc_d = roc_auc_score(labels, distances)
    auc_r = roc_auc_score(labels, recon_errors)
    fpr_d, tpr_d, _ = roc_curve(labels, distances)
    plt.plot(fpr_d, tpr_d, label=f"Distance (AUC={auc_d:.3f})")
    fpr_r, tpr_r, _ = roc_curve(labels, recon_errors)
    plt.plot(fpr_r, tpr_r, label=f"Recon (AUC={auc_r:.3f})")
    # Diagonal reference line.
    plt.plot([0,1],[0,1],"k--", alpha=0.4)
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.title("ROC Curves")
    plt.legend()
    plt.grid(alpha=0.25)
    _savefig(os.path.join(outdir, "roc.png"))

    # --- Precision-recall figure ---
    plt.figure(figsize=(6,5))
    ap_d = average_precision_score(labels, distances)
    ap_r = average_precision_score(labels, recon_errors)
    pre_d, rec_d, _ = precision_recall_curve(labels, distances)
    plt.plot(rec_d, pre_d, label=f"Distance (AP={ap_d:.3f})")
    pre_r, rec_r, _ = precision_recall_curve(labels, recon_errors)
    plt.plot(rec_r, pre_r, label=f"Recon (AP={ap_r:.3f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("PR Curves")
    plt.legend()
    plt.grid(alpha=0.25)
    _savefig(os.path.join(outdir, "pr.png"))

def plot_training_curves(ae_losses, joint_losses, dist_losses, recon_losses, outdir):
    """Plot warm-up and joint-phase losses on one epoch axis.

    The warm-up curve covers epochs ``1..len(ae_losses)``; the three joint
    curves continue from the next epoch. Saves ``training_curves.png`` in
    ``outdir``.
    """
    n_warmup = len(ae_losses)
    warmup_epochs = range(1, n_warmup+1)
    joint_start = n_warmup + 1
    xs = range(joint_start, joint_start + len(joint_losses))

    plt.figure(figsize=(7,5))
    plt.plot(warmup_epochs, ae_losses, label="AE Loss")
    plt.plot(xs, joint_losses, label="DSPSVDD Total")
    plt.plot(xs, dist_losses,  label="Distance Loss")
    plt.plot(xs, recon_losses, label="Recon Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Curves")
    plt.legend()
    plt.grid(alpha=0.25)
    _savefig(os.path.join(outdir, "training_curves.png"))

def plot_cm(labels, scores, thr, outdir):
    """Save the confusion matrix of ``scores > thr`` as ``cm.png`` in ``outdir``.

    Args:
        labels: Binary ground truth, 0 = normal and 1 = anomaly.
        scores: Anomaly scores; values above ``thr`` are predicted as anomalies.
        thr: Decision threshold.
        outdir: Directory the figure is written to.
    """
    pred = (np.asarray(scores) > thr).astype(int)
    # Fix the label set so the matrix is always 2x2 and lines up with the
    # two hard-coded tick labels, even if a class is missing from the data.
    cm = confusion_matrix(labels, pred, labels=[0, 1])
    plt.figure(figsize=(4.5,4))
    plt.imshow(cm, cmap="Blues")
    plt.title("Confusion Matrix @ threshold")
    plt.colorbar()
    # Write the count into each cell (x = column, y = row).
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.xticks([0,1], ["Normal","Anomaly"])
    plt.yticks([0,1], ["Normal","Anomaly"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    _savefig(os.path.join(outdir, "cm.png"))

@torch.no_grad()
def save_sample_reconstructions(model, loader, device, outdir, num=8):
    """Save a grid of inputs (top row) and their reconstructions (bottom row).

    Uses the first batch of ``loader``. If that batch holds fewer than ``num``
    samples, only the available ones are drawn.

    Args:
        model: Module whose forward pass returns ``(latent, reconstruction)``.
        loader: Iterable of ``(inputs, labels)`` batches with labels 0 = normal.
        device: Device the batch is moved to.
        outdir: Directory that receives ``samples_recon.png``.
        num: Maximum number of samples to show.
    """
    model.eval()
    x, y = next(iter(loader))
    # Clamp to the batch size; indexing past it used to raise IndexError.
    num = min(int(num), x.size(0))
    if num < 1:
        return
    x = x[:num].to(device)
    y = y[:num]
    _, xhat = model(x)
    plt.figure(figsize=(2*num, 4))
    for i in range(num):
        ax = plt.subplot(2, num, i+1)
        ax.imshow(x[i,0].cpu().numpy(), cmap="gray")
        ax.axis("off")
        ax.set_title("Normal" if int(y[i])==0 else "Anomaly")
        ax = plt.subplot(2, num, num+i+1)
        ax.imshow(xhat[i,0].cpu().numpy(), cmap="gray")
        ax.axis("off")
        ax.set_title("Recon")
    plt.suptitle("Original (top) vs Reconstruction (bottom)")
    _savefig(os.path.join(outdir, "samples_recon.png"))


## Extra Figures
_Placeholder for additional visualizations (histograms, PR/F1 vs. threshold, t-SNE, error grids)._


In [None]:
# @title Save Artifacts
def save_artifacts(CFG, ae_losses, joint_losses, dist_losses, recon_losses,
                   t_train, test_d2, test_mse, y_test, train_d2, train_mse, thr_d2, thr_mse,
                   m_dist, m_recon, model, center):
    """Write the results, training history and checkpoint of one run.

    Creates three files in ``CFG.OUTPUT_DIR``: ``results.json`` (config,
    final losses, metrics, thresholds), ``training_history.csv`` (per-epoch
    losses, NaN where a phase was not active) and ``model.pth`` (weights,
    center, thresholds, config).

    Args:
        CFG: Configuration object; its upper-case attributes are serialised.
        ae_losses: Per-epoch warm-up losses.
        joint_losses, dist_losses, recon_losses: Per-epoch joint-phase losses.
        t_train: Training time in seconds.
        test_d2, test_mse, y_test, train_d2, train_mse: Accepted for call
            compatibility; not written by this function.
        thr_d2, thr_mse: Decision thresholds for the two scores.
        m_dist, m_recon: Metric dicts for the two scores.
        model: Trained module whose ``state_dict`` is saved.
        center: Hypersphere center tensor.
    """
    def _last(values):
        # Empty histories (e.g. zero epochs) are stored as null instead of raising.
        return float(values[-1]) if len(values) else None

    ae_hist = list(ae_losses)
    joint_hist = list(joint_losses)
    dist_hist = list(dist_losses)
    recon_hist = list(recon_losses)

    # Serialise the configuration that was passed in, not whatever the
    # module-level global happens to hold.
    config_snapshot = {k: getattr(CFG, k) for k in dir(CFG)
                       if k.isupper() and not k.startswith("_")}

    results = {
        "config": config_snapshot,
        "training": {
            "time_seconds": float(t_train),
            "final_ae_loss": _last(ae_hist),
            "final_joint_loss": _last(joint_hist),
            "final_distance_loss": _last(dist_hist),
            "final_recon_loss": _last(recon_hist),
        },
        "evaluation": {
            "distance_metrics": m_dist,
            "reconstruction_metrics": m_recon,
            "thresholds": {"distance": thr_d2, "reconstruction": thr_mse}
        }
    }
    os.makedirs(CFG.OUTPUT_DIR, exist_ok=True)
    with open(os.path.join(CFG.OUTPUT_DIR, "results.json"), "w") as f:
        json.dump(results, f, indent=2)

    # Size the table from the histories themselves so that it stays
    # rectangular even when they differ from the configured epoch counts.
    n_ae, n_joint = len(ae_hist), len(joint_hist)
    hist_df = pd.DataFrame({
        "epoch": list(range(1, n_ae + n_joint + 1)),
        "ae_loss":        ae_hist + [np.nan]*n_joint,
        "joint_total":    [np.nan]*n_ae + joint_hist,
        "joint_distance": [np.nan]*n_ae + dist_hist,
        "joint_recon":    [np.nan]*n_ae + recon_hist,
    })
    hist_df.to_csv(os.path.join(CFG.OUTPUT_DIR, "training_history.csv"), index=False)

    ckpt = {
        "model_state_dict": model.state_dict(),
        "center": center.detach().cpu(),
        "threshold_distance": thr_d2,
        "threshold_reconstruction": thr_mse,
        "config": config_snapshot
    }
    torch.save(ckpt, os.path.join(CFG.OUTPUT_DIR, "model.pth"))


In [None]:
# @title Single-Run Driver (save to Drive)
def set_drive_output(normal_cls, dataset="mnist", base_dir="/content/drive/MyDrive/Master/DSPSVDD"):
    """Point the global ``CFG`` at a fresh run folder and create it.

    Sets ``CFG.DATASET`` and ``CFG.NORMAL_CLASS``, then sets ``CFG.OUTPUT_DIR``
    to ``<base_dir>/<dataset>/nc_<class>_<YYYYmmdd_HHMM>``.
    """
    CFG.DATASET = dataset
    CFG.NORMAL_CLASS = int(normal_cls)

    stamp = time.strftime("%Y%m%d_%H%M")
    run_folder = f"nc_{CFG.NORMAL_CLASS}_{stamp}"
    CFG.OUTPUT_DIR = os.path.join(base_dir, dataset, run_folder)

    os.makedirs(CFG.OUTPUT_DIR, exist_ok=True)
    print("[Output]", CFG.OUTPUT_DIR)

def main():
    """Run one full experiment using the current global ``CFG``.

    Steps: seed the RNGs, build the loaders, warm up the autoencoder, compute
    the hypersphere center, train jointly, score the test and training sets,
    derive thresholds from the training scores, then write figures and
    artifacts to ``CFG.OUTPUT_DIR``.
    """
    print("DSPSVDD Anomaly Detection")
    print(f"Dataset: {CFG.DATASET} | Normal class: {CFG.NORMAL_CLASS} | Device: {CFG.DEVICE}")
    print("-"*55)
    set_seed(CFG.SEED)
    os.makedirs(CFG.OUTPUT_DIR, exist_ok=True)

    train_loader, test_loader = get_data_loaders()
    print(f"Training samples: {len(train_loader.dataset)} | Test samples: {len(test_loader.dataset)}")
    model = ConvolutionalAutoEncoder(CFG.LATENT_DIM).to(CFG.DEVICE)

    # Timed section: warm-up, center computation and joint training.
    t0 = time.time()
    ae_losses = train_autoencoder(model, train_loader, CFG.AE_EPOCHS, CFG.LR, CFG.WEIGHT_DECAY, CFG.DEVICE)
    center = compute_center(model, train_loader, CFG.DEVICE)
    joint_losses, dist_losses, recon_losses = train_dspsvdd(
        model, train_loader, center, CFG.JOINT_EPOCHS, CFG.GAMMA, CFG.LR_JOINT, CFG.WEIGHT_DECAY, CFG.DEVICE
    )
    t_train = time.time() - t0

    # Score both splits; the labels returned for the training split are unused.
    test_d2,  test_mse,  y_test  = evaluate_model(model, test_loader, center, CFG.DEVICE)
    train_d2, train_mse, _y_tr   = evaluate_model(model, train_loader, center, CFG.DEVICE)

    # Thresholds are the (1 - NU) quantile of the training scores.
    thr_d2  = float(np.quantile(train_d2,  1.0 - CFG.NU))
    thr_mse = float(np.quantile(train_mse, 1.0 - CFG.NU))

    m_dist  = compute_metrics(y_test, test_d2,  thr_d2)
    m_recon = compute_metrics(y_test, test_mse, thr_mse)

    # Figures (the confusion matrix uses the distance score only)
    plot_roc_pr(test_d2, test_mse, y_test, CFG.OUTPUT_DIR)
    plot_training_curves(ae_losses, joint_losses, dist_losses, recon_losses, CFG.OUTPUT_DIR)
    plot_cm(y_test, test_d2, thr_d2, CFG.OUTPUT_DIR)
    save_sample_reconstructions(model, test_loader, CFG.DEVICE, CFG.OUTPUT_DIR, num=8)

    # Save results.json, training_history.csv and model.pth
    save_artifacts(CFG, ae_losses, joint_losses, dist_losses, recon_losses,
                   t_train, test_d2, test_mse, y_test, train_d2, train_mse, thr_d2, thr_mse,
                   m_dist, m_recon, model, center)

    print("\nSaved results to:", CFG.OUTPUT_DIR)
    print("Training completed successfully!")

def run_one(normal_cls, dataset="mnist", base_dir="/content/drive/MyDrive/Master/DSPSVDD"):
    """Configure the output folder for one normal class, run ``main`` and return that folder."""
    set_drive_output(normal_cls, dataset=dataset, base_dir=base_dir)
    main()
    output_dir = CFG.OUTPUT_DIR
    return output_dir


In [None]:
# @title Multi-Run Driver (MNIST 1→9)
def run_all_mnist_1_to_9(base_dir="/content/drive/MyDrive/Master/DSPSVDD"):  # adjust this path to your own setup
    """Run one MNIST experiment per normal class 1..9 and return the output folders in run order."""
    folders = []
    for k in range(1, 10):
        print(f"\n=== Running normal class = {k} ===")
        out_dir = run_one(k, dataset="mnist", base_dir=base_dir)
        folders.append(out_dir)

    print("\nFinished. Folders:")
    for folder in folders:
        print(" -", folder)
    return folders



In [None]:
# @title Aggregation to Summary CSV
import glob

def _parse_normal_class(run_dir):
    """Extract ``k`` from a folder named ``nc_<k>_...``; fall back to the parent folder name."""
    candidates = (os.path.basename(run_dir), os.path.basename(os.path.dirname(run_dir)))
    for name in candidates:
        parts = name.split("_")
        if len(parts) >= 2 and parts[0] == "nc":
            try:
                return int(parts[1])
            except ValueError:
                # Not a numeric class id; try the next candidate.
                continue
    return None


def aggregate_runs(dataset='mnist', base_dir="/content/drive/MyDrive/Master/DSPSVDD"):  # "mnist" or "fashion_mnist"
    """Collect every run's ``results.json`` under ``<base_dir>/<dataset>`` into one CSV.

    Run folders are matched as ``nc_*_*`` (flat layout) or ``nc_*/*`` (nested
    layout); folders without a ``results.json`` are skipped.

    Args:
        dataset: Sub-folder of ``base_dir`` to scan; also stored in each row.
        base_dir: Root folder that holds the per-dataset run folders.

    Returns:
        Path of the written ``summary_<dataset>_<timestamp>.csv``, or None
        when no run folder or no ``results.json`` was found.
    """
    root = os.path.join(base_dir, dataset)
    pattern_flat   = glob.glob(os.path.join(root, "nc_*_*"))
    pattern_nested = glob.glob(os.path.join(root, "nc_*", "*"))
    run_dirs = sorted(p for p in (pattern_flat + pattern_nested) if os.path.isdir(p))
    if not run_dirs:
        print("No runs found under:", root)
        return None

    rows = []
    for rd in run_dirs:
        rjson = os.path.join(rd, "results.json")
        if not os.path.exists(rjson):
            continue
        with open(rjson, "r") as f:
            R = json.load(f)
        # Tolerate partially written results: missing sections become empty dicts.
        evaluation = R.get("evaluation", {})
        dist = evaluation.get("distance_metrics", {})
        rec  = evaluation.get("reconstruction_metrics", {})
        thr  = evaluation.get("thresholds", {})
        trn  = R.get("training", {})
        cfg  = R.get("config", {})

        rows.append({
            "dataset": dataset, "normal_class": _parse_normal_class(rd),
            "ae_epochs": cfg.get("AE_EPOCHS"), "joint_epochs": cfg.get("JOINT_EPOCHS"),
            "roc_auc_distance": dist.get("roc_auc"), "pr_auc_distance": dist.get("pr_auc"),
            "fpr_at_95_distance": dist.get("fpr_at_95_tpr"),
            "accuracy_distance": dist.get("accuracy"), "f1_distance": dist.get("f1_score"),
            "mcc_distance": dist.get("mcc"),
            "roc_auc_recon": rec.get("roc_auc"), "pr_auc_recon": rec.get("pr_auc"),
            "thr_distance": thr.get("distance"), "thr_reconstruction": thr.get("reconstruction"),
            "train_time_s": trn.get("time_seconds"),
            "final_ae_loss": trn.get("final_ae_loss"),
            "final_joint_loss": trn.get("final_joint_loss"),
            "output_dir": rd
        })

    if not rows:
        # Folders exist but none holds results; sorting an empty frame by
        # column name would raise KeyError.
        print("No results.json found under:", root)
        return None

    df = pd.DataFrame(rows).sort_values(["normal_class", "output_dir"])
    stamp = time.strftime("%Y%m%d_%H%M")
    out_csv = os.path.join(root, f"summary_{dataset}_{stamp}.csv")
    df.to_csv(out_csv, index=False)
    print("Summary saved:", out_csv)
    try:
        # Optional rich display; fall back to a plain preview when unavailable.
        from caas_jupyter_tools import display_dataframe_to_user
        display_dataframe_to_user(f"DSPSVDD summary ({dataset})", df)
    except Exception:
        print(df.head())
    return out_csv


In [None]:
# @title Re-load a Run & Re-make Figures (this step is optional)
def load_latest_run_for_class(normal_cls, base_dir="/content/drive/MyDrive/Master/DSPSVDD/mnist"):
    """Reload the most recently modified ``nc_<k>_*`` run and recompute its scores.

    Updates the global ``CFG`` (normal class, output folder, and the dataset /
    image size / latent size stored in the checkpoint) and publishes the
    rebuilt loaders, model, center, score vectors and thresholds as notebook
    globals.

    Args:
        normal_cls: Normal class ``k`` whose latest run should be loaded.
        base_dir: Folder that directly contains the ``nc_<k>_*`` run folders.

    Raises:
        FileNotFoundError: If no run folder exists for the class.
    """
    import glob
    global train_loader, test_loader
    global model, center, test_d2, test_mse, y_test, train_d2, train_mse, thr_d2, thr_mse

    CFG.NORMAL_CLASS = int(normal_cls)
    matches = sorted(glob.glob(os.path.join(base_dir, f"nc_{CFG.NORMAL_CLASS}_*")), key=os.path.getmtime)
    if not matches:
        # Explicit raise instead of assert, which is stripped under `python -O`.
        raise FileNotFoundError(f"No runs found for class {CFG.NORMAL_CLASS}")
    CFG.OUTPUT_DIR = matches[-1]

    # reload ckpt
    ckpt = torch.load(os.path.join(CFG.OUTPUT_DIR, "model.pth"), map_location=CFG.DEVICE)

    # Restore the settings that define the data and the architecture so the
    # loaders and model match the checkpoint rather than stale globals.
    saved_cfg = ckpt.get("config", {})
    for key in ("DATASET", "IMG_SIZE", "LATENT_DIM"):
        if key in saved_cfg:
            setattr(CFG, key, saved_cfg[key])

    # rebuild loaders
    train_loader, test_loader = get_data_loaders()

    model = ConvolutionalAutoEncoder(CFG.LATENT_DIM).to(CFG.DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    center = ckpt["center"].to(CFG.DEVICE)

    # recompute score vectors and the (1 - NU) quantile thresholds
    test_d2,  test_mse,  y_test  = evaluate_model(model, test_loader, center, CFG.DEVICE)
    train_d2, train_mse, _y_tr   = evaluate_model(model, train_loader, center, CFG.DEVICE)
    thr_d2  = float(np.quantile(train_d2,  1.0 - CFG.NU))
    thr_mse = float(np.quantile(train_mse, 1.0 - CFG.NU))
    print(f"[OK] Reloaded class {CFG.NORMAL_CLASS} from:", CFG.OUTPUT_DIR)

## Notes & Next Steps
- Keep `NUM_WORKERS` modest (2–4) on Colab to avoid dataloader hangs.  
- Always rebuild loaders + reload checkpoint when switching classes to avoid stale globals.  
- For faster debugging, lower AE/Joint epochs; restore defaults for final results.

**Next Steps**
- Add fused scoring (distance + scaled reconstruction).  
- Evaluate across all digits (1→9) and Fashion-MNIST; report mean/median AUC.  
- Add per-anomaly-digit ROC analysis for MNIST to discuss difficulty differences.
