# Summarize ablation results (small models)

Loads `history.json` from each run produced by `run_ablation_small_models.sh` (directories matching `checkpoints_tcn_ddp_original/ablation_L*_H*/`) and builds a summary table and plots.

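**Requires:** completed ablation runs, with checkpoint directories named `ablation_L{levels}_H{nhid}`, each containing a `history.json`. Run the notebook from the project root: `checkpoints_tcn_ddp_original` is resolved relative to the current working directory.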

In [None]:
import json
import re
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# The notebook is expected to run from the project root: every checkpoint
# path below is resolved relative to the current working directory.
ROOT = Path.cwd()
CKPT_BASE = ROOT.joinpath("checkpoints_tcn_ddp_original")

# Run directories look like "ablation_L4_H80" or "ablation_L3_H40";
# group 1 captures the number of levels, group 2 the hidden width.
ABLATION_PATTERN = re.compile(r"^ablation_L(\d+)_H(\d+)$")

def find_ablation_dirs():
    """Collect completed ablation runs found directly under ``CKPT_BASE``.

    A sub-directory counts as a run when its name matches
    ``ABLATION_PATTERN`` and it contains a ``history.json`` file.

    Returns:
        A list of ``(dir_path, levels, nhid)`` tuples ordered by ``levels``
        and then ``nhid``; an empty list when ``CKPT_BASE`` does not exist.
    """
    if not CKPT_BASE.exists():
        return []
    runs = []
    for entry in CKPT_BASE.iterdir():
        match = ABLATION_PATTERN.match(entry.name) if entry.is_dir() else None
        if match is None or not (entry / "history.json").exists():
            continue
        levels, nhid = (int(group) for group in match.groups())
        runs.append((entry, levels, nhid))
    # Order by (levels, nhid); list.sort is stable like sorted().
    runs.sort(key=lambda run: run[1:])
    return runs

dirs = find_ablation_dirs()
n_found = len(dirs)
print(f"Found {n_found} ablation run(s) under {CKPT_BASE}")
# Each entry is (dir_path, levels, nhid); only the directory name is shown.
for run in dirs:
    print(f"  {run[0].name}")

In [None]:
# Param count (same formula as ablation_model_sizes.py)
def count_params(levels: int, nhid: int, input_channels: int = 160, kernel_size: int = 15) -> int:
    """Return the analytic parameter count for a model of the given shape.

    The count is made of three parts:

    * a first block: two ``kernel_size``-wide convolutions
      (``input_channels -> nhid`` then ``nhid -> nhid``, each with bias),
      each preceded by ``2 * channels`` per-channel parameters (presumably a
      norm layer's scale and shift — confirm against ablation_model_sizes.py),
      plus a 1x1 convolution with bias when ``input_channels != nhid``;
    * ``levels - 1`` further blocks with the same layout at width ``nhid``;
    * a final layer of ``nhid`` weights plus one bias.

    Args:
        levels: Number of stacked blocks (the formula assumes ``levels >= 1``).
        nhid: Hidden channel width.
        input_channels: Channels fed into the first block.
        kernel_size: Width of every non-1x1 convolution.

    Returns:
        Total parameter count.
    """
    def conv(c_in: int, c_out: int, width: int) -> int:
        # Convolution weights plus one bias per output channel.
        return c_in * c_out * width + c_out

    def affine(channels: int) -> int:
        # Two parameters per channel.
        return 2 * channels

    first_block = (
        affine(input_channels)
        + conv(input_channels, nhid, kernel_size)
        + affine(nhid)
        + conv(nhid, nhid, kernel_size)
    )
    if input_channels != nhid:
        first_block += conv(input_channels, nhid, 1)
    later_block = 2 * (affine(nhid) + conv(nhid, nhid, kernel_size))
    head = conv(nhid, 1, 1)
    return first_block + (levels - 1) * later_block + head

def _last_value(hist: dict, key: str) -> float:
    """Return the final entry of ``hist[key]`` as a float, or NaN if unavailable.

    NaN is returned when the key is missing, its list is empty, or the last
    entry is JSON ``null``. This keeps one incomplete history from aborting
    the whole summary.
    """
    values = hist.get(key) or []
    if not values or values[-1] is None:
        return float("nan")
    return float(values[-1])


def load_summary(ckpt_dir: Path, levels: int, nhid: int) -> "dict | None":
    """Summarize one ablation run from its ``history.json``.

    Args:
        ckpt_dir: Run directory containing ``history.json``.
        levels: Number of levels parsed from the directory name.
        nhid: Hidden width parsed from the directory name.

    Returns:
        A flat dict containing:

        * the run config and its ``count_params`` estimate;
        * the best validation F1 and its 1-based epoch;
        * the final-epoch metrics.

        Returns ``None`` when the history holds no ``val_f1`` entries.
        A ``final_*`` metric is NaN when its series is missing or empty.
        If every ``val_f1`` value is NaN, ``best_val_f1`` is NaN and
        ``best_epoch`` is 0.
    """
    with open(ckpt_dir / "history.json", encoding="utf-8") as f:
        hist = json.load(f)

    val_f1 = hist.get("val_f1") or []
    n_epochs = len(val_f1)
    if n_epochs == 0:
        return None

    # Coerce to float so JSON nulls become NaN, then pick the best epoch while
    # ignoring NaN: np.argmax treats NaN as the maximum and would report a
    # NaN epoch as the best one.
    scores = np.asarray(val_f1, dtype=float)
    if np.isnan(scores).all():
        best_f1, best_epoch = float("nan"), 0
    else:
        best_idx = int(np.nanargmax(scores))
        best_f1, best_epoch = float(scores[best_idx]), best_idx + 1

    return {
        "levels": levels,
        "nhid": nhid,
        "n_params": count_params(levels, nhid),
        "best_val_f1": best_f1,
        "best_epoch": best_epoch,
        "n_epochs": n_epochs,
        "final_train_loss": _last_value(hist, "train_loss"),
        "final_val_loss": _last_value(hist, "val_loss"),
        "final_val_f1": float(scores[-1]),
        "final_val_acc": _last_value(hist, "val_acc"),
        "dir": str(ckpt_dir.name),
    }

# Summarize every discovered run. load_summary returns None when a history
# has no val_f1 entries, and those runs are dropped here.
summaries = (load_summary(ckpt_dir, L, H) for ckpt_dir, L, H in dirs)
rows = [row for row in summaries if row]

df = pd.DataFrame(rows)
if not df.empty:
    # Largest models first.
    df = df.sort_values("n_params", ascending=False).reset_index(drop=True)
    display(df)
else:
    print("No history found. Run run_ablation_small_models.sh first.")

## Summary table (by param count)

Rows are sorted by `n_params`, largest first. Key columns: `n_params`, `best_val_f1`, `best_epoch` (1-based), and the `final_*` metrics taken from the last epoch.

In [None]:
if not df.empty:
    config_cols = ["levels", "nhid", "n_params", "best_val_f1", "best_epoch"]
    print("Best val F1 per config:")
    print(df[config_cols].to_string(index=False))
    print()
    # Row with the highest best_val_f1 (Series.idxmax skips NaN by default).
    winner = df.loc[df["best_val_f1"].idxmax()]
    print(
        f"Best overall val F1: {winner['best_val_f1']:.4f} "
        f"(L={winner['levels']}, H={winner['nhid']}, {winner['n_params']:,} params)"
    )

## Plots

In [None]:
if not df.empty:
    # Build one "L{levels}H{nhid}" label per run, once, from the integer
    # columns, so scatter annotations and bar tick labels always match.
    # The old code used iterrows, which would give "L4.0H80.0" whenever a
    # row was upcast to float.
    labels = [f"L{int(lv)}H{int(h)}" for lv, h in zip(df["levels"], df["nhid"])]
    positions = np.arange(len(df))

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Left: best val F1 against parameter count (log-scaled x axis).
    ax = axes[0]
    ax.scatter(df["n_params"], df["best_val_f1"], s=80, alpha=0.8)
    for label, x, y in zip(labels, df["n_params"], df["best_val_f1"]):
        ax.annotate(label, (x, y), textcoords="offset points",
                    xytext=(0, 6), ha="center", fontsize=8)
    ax.set_xscale("log")
    ax.set_xlabel("Number of parameters")
    ax.set_ylabel("Best val F1")
    ax.set_title("Best validation F1 vs model size")
    ax.grid(True, alpha=0.3)

    # Right: one bar per config, in the DataFrame's row order.
    ax = axes[1]
    ax.bar(positions, df["best_val_f1"], color="steelblue", alpha=0.8)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel("Best val F1")
    ax.set_title("Best validation F1 by config")
    ax.grid(True, axis="y", alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
# Learning curves: val_f1 vs epoch for each config (optional).
# Drawn only when there are at most max_curves runs, to keep the legend readable.
max_curves = 12
if not df.empty and len(dirs) <= max_curves:
    fig, ax = plt.subplots(figsize=(10, 5))
    for ckpt_dir, L, H in dirs:
        with open(ckpt_dir / "history.json", encoding="utf-8") as f:
            hist = json.load(f)
        # Tolerate histories without val_f1 (load_summary also uses .get
        # for it) instead of raising KeyError; such runs get no curve.
        val_f1 = hist.get("val_f1") or []
        if not val_f1:
            continue
        epochs = range(1, len(val_f1) + 1)
        ax.plot(epochs, val_f1, label=f"L{L} H{H}", alpha=0.8)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Val F1")
    ax.set_title("Validation F1 over training")
    ax.legend(loc="lower right", fontsize=8)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
elif not df.empty:
    # Previously this case was skipped silently.
    print(f"Skipping learning curves: {len(dirs)} runs exceed the limit of {max_curves}.")