<a href="https://colab.research.google.com/github/mathusalini/amc-multimodal-ablation/blob/main/ablation_study_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ablation Study: Multi-Modal Fusion for Automatic Modulation Classification

## Experimental Design

| Config        | IQ | Constellation | Spectrogram |
|---------------|----|---------------|-------------|
| IQ only       | ✓  |               |             |
| Const only    |    | ✓             |             |
| Spec only     |    |               | ✓           |
| IQ + Const    | ✓  | ✓             |             |
| IQ + Spec     | ✓  |               | ✓           |
| Const + Spec  |    | ✓             | ✓           |
| Full Fusion   | ✓  | ✓             | ✓           |

Each configuration is trained **3 times with different random seeds**.  
Results are reported as **mean ± std** overall accuracy.  
Learning curves, saved model checkpoints, and per-SNR analysis are produced for every configuration.

## 1. Install & Import Dependencies

In [2]:
!pip install torch torchvision numpy matplotlib scikit-learn tqdm scipy -q

import os, time, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from scipy.signal import spectrogram as scipy_spectrogram
from scipy.ndimage import zoom
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import itertools

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ── Directory for saved models ────────────────────────────────────────────────
SAVE_DIR = "saved_models"
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"Models will be saved to: {os.path.abspath(SAVE_DIR)}/")

Using device: cpu
Models will be saved to: /content/saved_models/


## 2. Signal Generation & Dataset

In [3]:
def add_awgn(signal, snr_db):
    snr_lin = 10 ** (snr_db / 10.0)
    pwr     = np.mean(np.abs(signal) ** 2)
    noise   = np.sqrt(pwr / snr_lin / 2) * (
        np.random.randn(*signal.shape) + 1j * np.random.randn(*signal.shape))
    return signal + noise

def gen_bpsk(n):   return np.random.choice([-1, 1], n).astype(complex)
def gen_qpsk(n):   r=np.random.choice([-1,1],n); i=np.random.choice([-1,1],n); return (r+1j*i)/np.sqrt(2)
def gen_8psk(n):   return np.exp(1j * np.random.randint(0,8,n) * np.pi/4)
def gen_16qam(n):  v=np.array([-3,-1,1,3]); return (np.random.choice(v,n)+1j*np.random.choice(v,n))/np.sqrt(10)
def gen_64qam(n):  v=np.array([-7,-5,-3,-1,1,3,5,7]); return (np.random.choice(v,n)+1j*np.random.choice(v,n))/np.sqrt(42)

MODULATIONS = {"BPSK": gen_bpsk, "QPSK": gen_qpsk, "8PSK": gen_8psk,
               "16QAM": gen_16qam, "64QAM": gen_64qam}
MOD_NAMES = list(MODULATIONS.keys())

def generate_dataset(samples_per_mod_snr=500, num_symbols=128,
                     snrs=list(range(-20, 22, 2))):
    X, y, snr_arr = [], [], []
    for label, (name, func) in enumerate(MODULATIONS.items()):
        for snr in snrs:
            for _ in range(samples_per_mod_snr):
                noisy = add_awgn(func(num_symbols), snr)
                X.append(np.stack([np.real(noisy), np.imag(noisy)], axis=0))
                y.append(label); snr_arr.append(snr)
    return (np.array(X, dtype=np.float32),
            np.array(y, dtype=np.int64),
            np.array(snr_arr))

SAMPLES_PER_MOD_SNR = 500
NUM_SYMBOLS         = 128
SNRs                = list(range(-20, 22, 2))

print("Generating dataset …")
X, y, snr_vals = generate_dataset(SAMPLES_PER_MOD_SNR, NUM_SYMBOLS, SNRs)
print(f"  Shape  : {X.shape}")
print(f"  Total  : {len(X):,} samples")
print(f"  SNR    : {SNRs[0]} → {SNRs[-1]} dB")

Generating dataset …
  Shape  : (52500, 2, 128)
  Total  : 52,500 samples
  SNR    : -20 → 20 dB


## 3. Image Generation & Dataset Class

In [4]:
IMG_SIZE = 64

def iq_to_constellation(iq, img_size=IMG_SIZE):
    i, q = iq[0].numpy(), iq[1].numpy()
    hist, _, _ = np.histogram2d(i, q, bins=img_size, range=[[-2,2],[-2,2]])
    hist = hist.astype(np.float32)
    if hist.max() > 0: hist /= hist.max()
    return torch.from_numpy(np.stack([hist, hist, hist]))

def iq_to_spectrogram(iq, img_size=IMG_SIZE, nperseg=64, noverlap=32):
    cx = iq[0].numpy() + 1j * iq[1].numpy()
    _, _, Sxx = scipy_spectrogram(cx, fs=1.0, nperseg=nperseg,
                                  noverlap=noverlap, mode="magnitude")
    Sxx_db   = 10 * np.log10(np.abs(Sxx) + 1e-10)
    Sxx_norm = (Sxx_db - Sxx_db.min()) / (Sxx_db.max() - Sxx_db.min() + 1e-8)
    zf       = (img_size / Sxx_norm.shape[0], img_size / Sxx_norm.shape[1])
    resized  = zoom(Sxx_norm.astype(np.float32), zf, order=1)
    return torch.from_numpy(np.stack([resized, resized, resized]))


class MultiModalDataset(Dataset):
    def __init__(self, iq_tensors, labels):
        self.iq     = iq_tensors
        self.labels = torch.as_tensor(labels, dtype=torch.long)
        print("  Pre-computing images …")
        consts, specs = [], []
        for i in tqdm(range(len(iq_tensors))):
            iq = iq_tensors[i]
            consts.append(iq_to_constellation(iq))
            specs.append(iq_to_spectrogram(iq))
        self.const_imgs = torch.stack(consts)
        self.spec_imgs  = torch.stack(specs)

    def __len__(self): return len(self.iq)

    def __getitem__(self, idx):
        return (self.iq[idx], self.const_imgs[idx],
                self.spec_imgs[idx], self.labels[idx])


In [5]:

X_train, X_test, y_train, y_test, snr_train, snr_test = train_test_split(
    X, y, snr_vals, test_size=0.3, random_state=42, stratify=y)

X_train_t = torch.tensor(X_train, dtype=torch.float32)
X_test_t  = torch.tensor(X_test,  dtype=torch.float32)

print("Building train dataset …")
train_ds = MultiModalDataset(X_train_t, y_train)
print("\nBuilding test dataset …")
test_ds  = MultiModalDataset(X_test_t,  y_test)
print(f"\nTrain : {len(train_ds):,}  |  Test : {len(test_ds):,}")


Building train dataset …
  Pre-computing images …


  0%|          | 0/36750 [00:00<?, ?it/s]

  _, _, Sxx = scipy_spectrogram(cx, fs=1.0, nperseg=nperseg,



Building test dataset …
  Pre-computing images …


  0%|          | 0/15750 [00:00<?, ?it/s]


Train : 36,750  |  Test : 15,750


## 4. Model Architecture

In [6]:
class IQEncoder(nn.Module):
    def __init__(self, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(2,   64,  3, padding=1), nn.BatchNorm1d(64),  nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(64,  128, 3, padding=1), nn.BatchNorm1d(128), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(128, 256, 3, padding=1), nn.BatchNorm1d(256), nn.ReLU(), nn.MaxPool1d(2),
            nn.AdaptiveAvgPool1d(1))
        self.fc = nn.Linear(256, out_dim)
    def forward(self, x): return self.fc(self.net(x).squeeze(-1))


class ImageEncoder(nn.Module):
    def __init__(self, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3,  32,  3, padding=1), nn.BatchNorm2d(32),  nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64,  3, padding=1), nn.BatchNorm2d(64),  nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1))
        self.fc = nn.Linear(128, out_dim)
    def forward(self, x): return self.fc(self.net(x).squeeze(-1).squeeze(-1))


class FusionModel(nn.Module):
    """
    Flexible fusion model.
    Active branches are controlled by use_iq / use_const / use_spec.
    The fusion head input dimension adjusts automatically.
    """
    def __init__(self, num_classes=5, feat_dim=128,
                 use_iq=True, use_const=True, use_spec=True):
        super().__init__()
        assert any([use_iq, use_const, use_spec]), "At least one branch required."
        self.use_iq = use_iq; self.use_const = use_const; self.use_spec = use_spec

        if use_iq:    self.iq_enc    = IQEncoder(out_dim=feat_dim)
        if use_const: self.const_enc = ImageEncoder(out_dim=feat_dim)
        if use_spec:  self.spec_enc  = ImageEncoder(out_dim=feat_dim)

        n = sum([use_iq, use_const, use_spec])
        self.fusion = nn.Sequential(
            nn.Linear(feat_dim * n, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes))

    def forward(self, iq, const, spec):
        feats = []
        if self.use_iq:    feats.append(self.iq_enc(iq))
        if self.use_const: feats.append(self.const_enc(const))
        if self.use_spec:  feats.append(self.spec_enc(spec))
        return self.fusion(torch.cat(feats, dim=1))

##5. Training & Evaluation Helpers

In [7]:
NUM_EPOCHS  = 100
BATCH_SIZE  = 256
NUM_RUNS    = 3
SEEDS       = list(range(NUM_RUNS))
LR          = 1e-3

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Master results store — populated by each config cell below
results = {}


def set_seed(seed):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)


def train_one_run(model, train_loader, test_loader, num_epochs, lr, snr_array):
    """
    Train model for num_epochs.

    Returns
    -------
    train_losses : list[float]   – per-epoch avg training loss
    val_accs     : list[float]   – per-epoch test accuracy
    best_state   : dict          – state_dict at best validation accuracy
    snr_acc      : dict          – {snr: acc} at the best epoch
    """
    opt  = optim.Adam(model.parameters(), lr=lr)
    sched = optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="max", patience=7, factor=0.5, verbose=False)
    crit = nn.CrossEntropyLoss()

    train_losses, val_accs = [], []
    best_acc, best_state, best_snr_acc = 0.0, None, {}

    for epoch in range(num_epochs):
        # ── Train ──────────────────────────────────────────────────────────
        model.train()
        total_loss = 0
        for iq, const, spec, by in train_loader:
            iq, const, spec, by = (iq.to(device), const.to(device),
                                   spec.to(device), by.to(device))
            opt.zero_grad()
            loss = crit(model(iq, const, spec), by)
            loss.backward(); opt.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        train_losses.append(avg_loss)

        # ── Validate ───────────────────────────────────────────────────────
        model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for iq, const, spec, by in test_loader:
                iq, const, spec = iq.to(device), const.to(device), spec.to(device)
                preds.extend(model(iq, const, spec).argmax(1).cpu().numpy())
                labels.extend(by.numpy())
        preds, labels = np.array(preds), np.array(labels)
        acc = accuracy_score(labels, preds)
        val_accs.append(acc)
        sched.step(acc)

        # ── Track best ─────────────────────────────────────────────────────
        if acc > best_acc:
            best_acc   = acc
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            best_snr_acc = {}
            for snr in np.unique(snr_array):
                idx = np.where(snr_array == snr)[0]
                best_snr_acc[snr] = accuracy_score(labels[idx], preds[idx])

    return train_losses, val_accs, best_state, best_snr_acc

def plot_learning_curves(cfg_name, all_train_losses, all_val_accs):
    """Plot loss + accuracy learning curves for all runs of one configuration."""
    fig, axes = plt.subplots(1, 2, figsize=(13, 4))
    palette = ["#e15759", "#4e79a7", "#59a14f"]

    for run_i, (losses, accs) in enumerate(zip(all_train_losses, all_val_accs)):
        epochs = range(1, len(losses) + 1)
        axes[0].plot(epochs, losses, color=palette[run_i],
                     label=f"Seed {run_i}", linewidth=1.5)
        axes[1].plot(epochs, [a * 100 for a in accs], color=palette[run_i],
                     label=f"Seed {run_i}", linewidth=1.5)

    for ax in axes:
        ax.set_xlabel("Epoch"); ax.legend(fontsize=8.5); ax.grid(True, alpha=0.4)
    axes[0].set_ylabel("Cross-Entropy Loss")
    axes[0].set_title(f"{cfg_name} — Training Loss")
    axes[1].set_ylabel("Test Accuracy (%)")
    axes[1].set_title(f"{cfg_name} — Validation Accuracy")

    plt.suptitle(f"Learning Curves: {cfg_name}", fontsize=12, y=1.01)
    plt.tight_layout(); plt.show()


def run_config(cfg_name, use_iq, use_const, use_spec):
    """
    Train NUM_RUNS seeds for a single configuration.
    Saves the best checkpoint of each run.
    Populates results[cfg_name].
    """
    print(f"\n{'='*62}")
    print(f"  Config : {cfg_name}")
    print(f"  Active : IQ={use_iq}  Const={use_const}  Spec={use_spec}")
    print(f"{'='*62}")

    all_train_losses, all_val_accs, all_best_accs, all_snr_accs = [], [], [], []

    for seed in SEEDS:
        print(f"\n  ── Run {seed+1}/{NUM_RUNS}  (seed={seed}) ──")
        set_seed(seed)
        model = FusionModel(num_classes=len(MOD_NAMES),
                            use_iq=use_iq, use_const=use_const, use_spec=use_spec
                            ).to(device)

        t0 = time.time()
        train_losses, val_accs, best_state, snr_acc = train_one_run(
            model, train_loader, test_loader, NUM_EPOCHS, LR, snr_test)
        elapsed = time.time() - t0

        best_acc = max(val_accs)
        all_train_losses.append(train_losses)
        all_val_accs.append(val_accs)
        all_best_accs.append(best_acc)
        all_snr_accs.append(snr_acc)

        # ── Save checkpoint ────────────────────────────────────────────────
        safe_name = cfg_name.lower().replace(" ", "_").replace("+", "plus")
        ckpt_path = os.path.join(SAVE_DIR, f"{safe_name}_seed{seed}.pt")
        torch.save({
            "config":      cfg_name,
            "use_iq":      use_iq,
            "use_const":   use_const,
            "use_spec":    use_spec,
            "seed":        seed,
            "best_acc":    best_acc,
            "state_dict":  best_state,
            "train_losses": train_losses,
            "val_accs":    val_accs,
            "snr_acc":     snr_acc,
        }, ckpt_path)

        print(f"     Best acc : {best_acc*100:.2f}%  |  Time: {elapsed:.0f}s")
        print(f"     Saved  → {ckpt_path}")

    # ── Learning curves for this config ────────────────────────────────────
    plot_learning_curves(cfg_name, all_train_losses, all_val_accs)

    # ── Store in global results dict ───────────────────────────────────────
    results[cfg_name] = {
        "accs":      all_best_accs,
        "snr_accs":  all_snr_accs,
    }

    mean_acc, std_acc = np.mean(all_best_accs) * 100, np.std(all_best_accs) * 100
    print(f"\n  Summary  →  {mean_acc:.2f}% ± {std_acc:.2f}%")
    return results[cfg_name]



##6. Run Each Configuration
Each cell below is fully self-contained — you can re-run any single configuration independently without re-running the others.

#6.1 — IQ only
Active branches: IQ

In [8]:
run_config("IQ only", use_iq=True, use_const=False, use_spec=False)


  Config : IQ only
  Active : IQ=True  Const=False  Spec=False

  ── Run 1/3  (seed=0) ──


TypeError: ReduceLROnPlateau.__init__() got an unexpected keyword argument 'verbose'

##6.2 — Const only
Active branches: Constellation

In [None]:
run_config("Const only", use_iq=False, use_const=True, use_spec=False)

##6.3 — Spec only
Active branches: Spectrogram

In [None]:
run_config("Spec only", use_iq=False, use_const=False, use_spec=True)

##6.4 — IQ + Const
Active branches: IQ + Constellation

In [None]:
run_config("IQ + Const", use_iq=True, use_const=True, use_spec=False)

##6.5 — IQ + Spec
Active branches: IQ + Spectrogram

In [None]:
run_config("IQ + Spec", use_iq=True, use_const=False, use_spec=True)

##6.6 — Const + Spec
Active branches: Constellation + Spectrogram

In [None]:
run_config("Const + Spec", use_iq=False, use_const=True, use_spec=True)

##6.7 — Full Fusion
Active branches: IQ + Constellation + Spectrogram

In [None]:
run_config("Full Fusion", use_iq=True, use_const=True, use_spec=True)

##7. Aggregate Results
Run this section only after all 7 configuration cells above have completed.

In [None]:
# ── Summary table ─────────────────────────────────────────────────────────────
cfg_names = [c[0] for c in [
    ("IQ only",      True,  False, False),
    ("Const only",   False, True,  False),
    ("Spec only",    False, False, True),
    ("IQ + Const",   True,  True,  False),
    ("IQ + Spec",    True,  False, True),
    ("Const + Spec", False, True,  True),
    ("Full Fusion",  True,  True,  True),
]]

print(f"{'Configuration':<18}  {'Mean Acc':>9}  {'Std':>7}  {'Min':>7}  {'Max':>7}")
print("-" * 55)
summary = {}
for name in cfg_names:
    if name not in results:
        print(f"{name:<18}  (not run yet)"); continue
    accs = results[name]["accs"]
    mean, std = np.mean(accs) * 100, np.std(accs) * 100
    summary[name] = (mean, std)
    print(f"{name:<18}  {mean:>8.2f}%  {std:>6.2f}%  "
          f"{min(accs)*100:>6.2f}%  {max(accs)*100:>6.2f}%")

In [None]:
# ── Bar chart ─────────────────────────────────────────────────────────────────
names  = [n for n in cfg_names if n in summary]
means  = [summary[n][0] for n in names]
stds   = [summary[n][1] for n in names]
colours = ["#7f7f7f"]*3 + ["#4e8ecb"]*3 + ["#e8a838"]

fig, ax = plt.subplots(figsize=(11, 5))
bars = ax.bar(names, means, yerr=stds, capsize=5,
              color=colours, edgecolor="black", linewidth=0.6,
              width=0.55, error_kw=dict(elinewidth=1.5, ecolor="black"))

for bar, m, s in zip(bars, means, stds):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + s + 0.3,
            f"{m:.1f}%", ha="center", va="bottom", fontsize=8.5, fontweight="bold")

ax.set_ylabel("Test Accuracy (%)", fontsize=11)
ax.set_title(f"Ablation Study — Mean ± Std Accuracy ({NUM_RUNS} runs each)", fontsize=12)
ax.set_ylim(0, min(105, max(means) + max(stds) + 8))
ax.grid(axis="y", linestyle="--", alpha=0.5)
ax.set_xticklabels(names, rotation=15, ha="right")

from matplotlib.patches import Patch
ax.legend(handles=[
    Patch(facecolor="#7f7f7f", edgecolor="black", label="Single branch"),
    Patch(facecolor="#4e8ecb", edgecolor="black", label="Two branches"),
    Patch(facecolor="#e8a838", edgecolor="black", label="Full fusion"),
], loc="lower right")

plt.tight_layout(); plt.show()

In [None]:
# ── Per-SNR accuracy curves ────────────────────────────────────────────────────
snrs_unique = sorted(np.unique(snr_test))
palette     = plt.cm.tab10(np.linspace(0, 1, len(cfg_names)))

fig, ax = plt.subplots(figsize=(11, 6))
for name, colour in zip(cfg_names, palette):
    if name not in results: continue
    runs      = results[name]["snr_accs"]
    mean_c    = np.array([np.mean([r[s] for r in runs]) for s in snrs_unique])
    std_c     = np.array([np.std( [r[s] for r in runs]) for s in snrs_unique])
    ls        = "-" if "IQ" in name else "--"
    ax.plot(snrs_unique, mean_c*100, label=name, color=colour,
            linewidth=2, linestyle=ls, marker="o", markersize=4)
    ax.fill_between(snrs_unique, (mean_c-std_c)*100, (mean_c+std_c)*100,
                    alpha=0.12, color=colour)

ax.axhline(20, color="black", linestyle=":", linewidth=1, alpha=0.5, label="Random (20%)")
ax.set_xlabel("SNR (dB)", fontsize=11); ax.set_ylabel("Accuracy (%)", fontsize=11)
ax.set_title(f"Per-SNR Accuracy — All Configurations\n(shaded = ±1 std, {NUM_RUNS} runs)",
             fontsize=12)
ax.legend(loc="upper left", fontsize=8.5, ncol=2)
ax.grid(True, linestyle="--", alpha=0.4)
ax.set_xticks(snrs_unique)
plt.tight_layout(); plt.show()

In [None]:
# ── Heat-map ──────────────────────────────────────────────────────────────────
snrs_unique  = sorted(np.unique(snr_test))
active_names = [n for n in cfg_names if n in results]

heatmap = np.zeros((len(active_names), len(snrs_unique)))
for ci, name in enumerate(active_names):
    for si, snr in enumerate(snrs_unique):
        heatmap[ci, si] = np.mean([r[snr] for r in results[name]["snr_accs"]])

fig, ax = plt.subplots(figsize=(14, 4))
im = ax.imshow(heatmap * 100, aspect="auto", cmap="RdYlGn", vmin=0, vmax=100)
fig.colorbar(im, ax=ax, pad=0.02).set_label("Accuracy (%)")
ax.set_xticks(range(len(snrs_unique))); ax.set_xticklabels(snrs_unique, fontsize=8)
ax.set_yticks(range(len(active_names))); ax.set_yticklabels(active_names, fontsize=9)
ax.set_xlabel("SNR (dB)"); ax.set_title("Mean Accuracy Heat-map per Configuration & SNR")

for ci in range(len(active_names)):
    for si in range(len(snrs_unique)):
        v = heatmap[ci, si] * 100
        ax.text(si, ci, f"{v:.0f}", ha="center", va="center",
                fontsize=6.5, color="black" if 20 < v < 80 else "white")
plt.tight_layout(); plt.show()

##8. Statistical Significance Testing

In [None]:
from scipy.stats import ttest_ind

baseline = "Full Fusion"
if baseline in results:
    base_accs = results[baseline]["accs"]
    print(f"Pairwise t-tests vs '{baseline}'  (two-sided, α = 0.05)\n")
    print(f"{'Config':<18}  {'Δ Acc':>8}  {'p-value':>10}  {'Sig?':>6}")
    print("-" * 48)
    for name in cfg_names:
        if name == baseline or name not in results: continue
        _, p     = ttest_ind(base_accs, results[name]["accs"])
        delta    = (np.mean(base_accs) - np.mean(results[name]["accs"])) * 100
        print(f"{name:<18}  {delta:>+7.2f}%  {p:>10.4f}  {'YES ✓' if p < 0.05 else 'no':>6}")
    print("\nNote: increase NUM_RUNS to ≥5 for publication-quality significance.")
else:
    print("Full Fusion results not available yet — run all config cells first.")

##9. Load Saved Models
Use the cell below to reload any checkpoint — useful for resuming analysis or running inference without retraining.

In [None]:
def load_checkpoint(cfg_name, seed):
    """
    Reload a saved checkpoint and return a ready-to-use model.

    Parameters
    ----------
    cfg_name : str   e.g. "Full Fusion"
    seed     : int   0, 1, or 2

    Returns
    -------
    model : FusionModel   loaded on `device`, set to eval mode
    ckpt  : dict          full checkpoint dict (includes curves, snr_acc, etc.)
    """
    safe_name = cfg_name.lower().replace(" ", "_").replace("+", "plus")
    path      = os.path.join(SAVE_DIR, f"{safe_name}_seed{seed}.pt")
    ckpt      = torch.load(path, map_location=device)

    model = FusionModel(
        num_classes=len(MOD_NAMES),
        use_iq=ckpt["use_iq"],
        use_const=ckpt["use_const"],
        use_spec=ckpt["use_spec"],
    ).to(device)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()

    print(f"Loaded  : {path}")
    print(f"Config  : {ckpt['config']}  |  Seed: {ckpt['seed']}")
    print(f"Best acc: {ckpt['best_acc']*100:.2f}%")
    return model, ckpt


# ── Example usage ─────────────────────────────────────────────────────────────
# model, ckpt = load_checkpoint("Full Fusion", seed=0)
#
# Replay the learning curves from the saved data:
# plot_learning_curves(ckpt["config"], [ckpt["train_losses"]], [ckpt["val_accs"]])

print("load_checkpoint() is ready.  Uncomment the example lines above to use it.")
print("\nSaved checkpoints:")
for f in sorted(os.listdir(SAVE_DIR)):
    path = os.path.join(SAVE_DIR, f)
    print(f"  {f}  ({os.path.getsize(path)/1024:.0f} KB)")