
# ML-JET — Loss-Weight Sweep Aggregator (Notebook)

This notebook scans your sweep outputs under:

```
experiments/exp_loss_weight_sweep/training_output/
```

It collects per-run metrics from `training_summary.json`, builds a table with the columns needed for your paper, and then reports the **best run per model family** (ConvNeXt, EfficientNet, ViT, Swin, Mamba).

**Columns generated:**

- **Model (Initialization)** — inferred from `model_tag` / folder name
- **Batch** — `batch_size`
- **LR** — `learning_rate`
- **Total_Acc** — `best_accuracy`
- **Energy (Acc, Prec, Rec, F1)**
- **αs (Acc, Prec, Rec, F1)**
- **Q0 (Acc, Prec, Rec, F1)**

Within each family, runs are ranked by `Total_Acc`, with `Q0_F1` as the tie-breaker. You can customize the ranking logic if needed.


In [None]:

# --- CONFIG ---
# Both paths are plain strings here; relative values are resolved against the
# current working directory.

# Directory that is scanned for run folders (one sub-folder per run).
ROOT = "experiments/exp_loss_weight_sweep/training_output"
# Directory that receives the aggregated CSV files.
OUT_DIR = "experiments/exp_loss_weight_sweep"

# If running from a different CWD, you can set absolute paths:
# ROOT = "/wsu/home/.../experiments/exp_loss_weight_sweep/training_output"
# OUT_DIR = "/wsu/home/.../experiments/exp_loss_weight_sweep"


In [None]:

import os, json, re
from pathlib import Path
import pandas as pd

def load_json(p):
    """Parse the JSON file at ``p`` and return the decoded object.

    The file is always decoded as UTF-8 so the result does not depend on the
    locale of the machine running the notebook.

    Args:
        p: Path (``str`` or ``pathlib.Path``) of the JSON file.

    Returns:
        The decoded JSON value, or ``None`` when the file cannot be opened or
        does not contain valid JSON. In that case a ``[WARN]`` line is printed
        so the caller can simply skip the run.
    """
    try:
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[WARN] Failed to read {p}: {e}")
        return None

# Ordered (pattern, family) pairs; the first pattern that matches wins.
FAMILY_PATTERNS = [
    (re.compile(r'efficientnet', re.I), 'EfficientNet'),
    (re.compile(r'convnext', re.I),     'ConvNeXt'),
    # "vit" must stand alone as a token. r'\bvit\b' is not usable because "_"
    # counts as a word character, so it would miss tags such as "vit_base".
    (re.compile(r'(?<![a-z0-9])vit(?![a-z0-9])', re.I), 'ViT'),
    (re.compile(r'comer', re.I),        'ViT'),      # ViT-CoMer
    (re.compile(r'swin', re.I),         'Swin'),
    (re.compile(r'mamba', re.I),        'Mamba'),
]

def infer_family(model_tag: str, fallback: str) -> str:
    """Map a run to its model family name.

    ``model_tag`` is tried first; ``fallback`` (typically the run folder name)
    is consulted only when the tag is empty or matches no known family.

    Args:
        model_tag: Model identifier from the run summary; may be empty/None.
        fallback: Alternative text to classify; may be empty/None.

    Returns:
        The family label of the first matching entry in ``FAMILY_PATTERNS``,
        or ``"Other"`` when neither text matches.
    """
    for candidate in (model_tag, fallback):
        if not candidate:
            continue
        # str() guards against a non-string tag coming out of the JSON summary.
        key = str(candidate).lower()
        for pat, fam in FAMILY_PATTERNS:
            if pat.search(key):
                return fam
    return "Other"

def get_nested(d, path, default=None):
    """Look up a dotted key path (e.g. ``"a.b.c"``) inside nested dicts.

    Walks ``d`` one path segment at a time. As soon as the current node is not
    a dict, or lacks the next segment, ``default`` is returned instead.
    """
    node = d
    for segment in path.split("."):
        if isinstance(node, dict) and segment in node:
            node = node[segment]
        else:
            return default
    return node

def tuple4(d, base):
    """Fetch the four metrics stored under the dotted path ``base``.

    Returns a 4-tuple ``(accuracy, precision, recall, f1)``; each entry is
    ``None`` when the corresponding key is absent.
    """
    metric_names = ("accuracy", "precision", "recall", "f1")
    return tuple(get_nested(d, f"{base}.{name}") for name in metric_names)

# Turn the configured path strings into Path objects, then make sure the
# output directory (including any missing parents) exists.
ROOT, OUT_DIR = Path(ROOT), Path(OUT_DIR)
OUT_DIR.mkdir(parents=True, exist_ok=True)


In [None]:

# Column order of the aggregated table. Passing it to the DataFrame explicitly
# keeps the schema intact even when no run is found, so later cells can still
# select these columns instead of failing with a KeyError.
RESULT_COLUMNS = [
    "family", "model_tag", "run_dir", "batch", "lr", "Total_Acc",
    "Energy_Acc", "Energy_Prec", "Energy_Rec", "Energy_F1",
    "Alpha_Acc", "Alpha_Prec", "Alpha_Rec", "Alpha_F1",
    "Q0_Acc", "Q0_Prec", "Q0_Rec", "Q0_F1",
]

# One row per run folder under ROOT that contains a readable summary file.
rows = []
for run_dir in sorted(ROOT.iterdir()):
    if not run_dir.is_dir():
        continue
    summary_path = run_dir / "training_summary.json"
    if not summary_path.exists():
        continue

    js = load_json(summary_path)
    if js is None:
        # load_json already printed a warning for this file.
        continue
    if not isinstance(js, dict):
        # Valid JSON whose top level is not an object has no keys to read.
        print(f"[WARN] Skipping {summary_path}: expected a JSON object, "
              f"got {type(js).__name__}")
        continue

    # Fall back to the folder name when the summary carries no model tag.
    model_tag = js.get("model_tag") or run_dir.name
    family = infer_family(model_tag, run_dir.name)

    # Per-head (accuracy, precision, recall, f1); entries are None if missing.
    e_acc, e_prec, e_rec, e_f1 = tuple4(js, "best_model_metrics.energy")
    a_acc, a_prec, a_rec, a_f1 = tuple4(js, "best_model_metrics.alpha")
    q_acc, q_prec, q_rec, q_f1 = tuple4(js, "best_model_metrics.q0")

    rows.append({
        "family": family,
        "model_tag": model_tag,
        "run_dir": str(run_dir),
        "batch": js.get("batch_size"),
        "lr": js.get("learning_rate"),
        "Total_Acc": js.get("best_accuracy"),
        "Energy_Acc": e_acc, "Energy_Prec": e_prec, "Energy_Rec": e_rec, "Energy_F1": e_f1,
        "Alpha_Acc": a_acc,  "Alpha_Prec": a_prec,  "Alpha_Rec": a_rec,  "Alpha_F1": a_f1,
        "Q0_Acc": q_acc,     "Q0_Prec": q_prec,     "Q0_Rec": q_rec,     "Q0_F1": q_f1,
    })

df = pd.DataFrame(rows, columns=RESULT_COLUMNS)
print(f"[OK] Collected {len(df)} runs from {ROOT}")
df.head()


In [None]:

# Persist the complete per-run table next to the sweep outputs.
full_csv = OUT_DIR / "aggregate_results_detailed.csv"
df.to_csv(full_csv, index=False)
print(f"[OK] wrote {full_csv}")

# Display (if running in a notebook UI that supports DataFrames).
# The helper is optional, so this step is best-effort and never aborts the
# cell; a missing helper and a failing helper are reported separately.
try:
    from caas_jupyter_tools import display_dataframe_to_user
    display_dataframe_to_user("Loss-Weight Sweep — Full Results", df)
except ImportError:
    print("[INFO] display_dataframe_to_user not available in this environment.")
except Exception as exc:
    print(f"[WARN] display_dataframe_to_user failed: {exc}")


In [None]:

# Rank by Total_Acc, tie-break on Q0_F1 (most challenging head)
rank_cols = ["Total_Acc", "Q0_F1"]
report_cols = ["family", "model_tag", "batch", "lr", "Total_Acc",
               "Energy_Acc", "Energy_Prec", "Energy_Rec", "Energy_F1",
               "Alpha_Acc", "Alpha_Prec", "Alpha_Rec", "Alpha_F1",
               "Q0_Acc", "Q0_Prec", "Q0_Rec", "Q0_F1"]

df_numeric = df.copy()
if df.empty:
    # Nothing was collected: keep an empty table rather than failing on
    # columns that may not exist.
    best_df = df.copy()
else:
    for c in rank_cols:
        df_numeric[c] = pd.to_numeric(df_numeric[c], errors="coerce")

    # Missing or non-numeric scores count as -1 so they rank below real ones.
    rank_keys = df_numeric[["family"] + rank_cols].fillna({c: -1 for c in rank_cols})

    # Sort on the two numeric columns instead of taking idxmax over a column
    # of (acc, f1) tuples: pandas rejects arg-max on object-dtype data.
    ranked = rank_keys.sort_values(rank_cols, ascending=False)
    best_idx = ranked.drop_duplicates("family", keep="first").index

    # Report the original (un-coerced) rows of the winners, ordered by family.
    best_df = df.loc[best_idx].sort_values("family")

best_csv = OUT_DIR / "best_per_family_detailed.csv"
best_df.to_csv(best_csv, index=False)
print(f"[OK] wrote {best_csv}")
# reindex selects the report columns and tolerates an empty, column-less table.
best_df.reindex(columns=report_cols)


In [None]:

def latex_tuple(a,p,r,f):
    """Format four metrics as ``"(acc, prec, rec, f1)"`` for a table cell.

    Each value is rendered with four decimals. A value that is ``None``, NaN
    (of any float type, or the text ``"nan"``), or not convertible to a float
    is rendered as ``"-"``.

    Args:
        a, p, r, f: Accuracy, precision, recall and F1 values.

    Returns:
        The formatted string, e.g. ``"(0.9000, 0.8000, -, 0.8500)"``.
    """
    import math

    def fmt(x):
        try:
            value = float(x)
        except (TypeError, ValueError, OverflowError):
            return "-"
        # Test NaN after conversion so non-builtin float types are covered too.
        return "-" if math.isnan(value) else f"{value:.4f}"

    return f"({fmt(a)}, {fmt(p)}, {fmt(r)}, {fmt(f)})"

# Build the LaTeX table for the best run of each family.
# NOTE: the generated markup uses \toprule/\midrule/\bottomrule (booktabs)
# and \makecell (makecell), so the paper preamble must load both packages.
#
# Every LaTeX fragment is a raw string: in a normal string "\b", "\t" and "\a"
# are control characters and "\\" collapses to a single backslash, which
# corrupts commands such as \begin, \toprule, \textbf, \alpha and the row
# terminator.

# Characters that are special in LaTeX text mode and can occur in model tags.
_TEX_ESCAPES = str.maketrans({"_": r"\_", "&": r"\&", "%": r"\%", "#": r"\#"})

lines = [
    r"\begin{table}[ht]",
    r"\centering",
    r"\caption{Best result per model family on loss-weight sweep.}",
    r"\label{tab:best_family}",
    r"\scriptsize",
    r"\begin{tabular}{llccccc}",
    r"\toprule",
    r"\textbf{Model (Family)} & \textbf{Batch} & \textbf{LR} & Total$_{\mathrm{Acc}}$ & "
    r"\makecell{E-loss \\ (Acc,~Prec,~Rec,~F1)} & "
    r"\makecell{$\alpha_s$ \\ (Acc,~Prec,~Rec,~F1)} & "
    r"\makecell{$Q_0$ \\ (Acc,~Prec,~Rec,~F1)} \\",
    r"\midrule",
]

for _, r in best_df.iterrows():
    model_name = str(r["model_tag"]).translate(_TEX_ESCAPES)
    fam = r["family"]
    batch = r["batch"]
    lr = r["lr"]
    # A missing total accuracy is shown as "-" rather than a made-up 0.0000.
    total = r["Total_Acc"]
    tot = "-" if pd.isna(total) else f"{float(total):.4f}"
    e = latex_tuple(r["Energy_Acc"], r["Energy_Prec"], r["Energy_Rec"], r["Energy_F1"])
    a = latex_tuple(r["Alpha_Acc"],  r["Alpha_Prec"],  r["Alpha_Rec"],  r["Alpha_F1"])
    q = latex_tuple(r["Q0_Acc"],     r["Q0_Prec"],     r["Q0_Rec"],     r["Q0_F1"])

    lines.append(rf"{model_name} ({fam}) & {batch} & {lr} & {tot} & {e} & {a} & {q} \\")

lines.extend([
    r"\bottomrule",
    r"\end{tabular}",
    r"\end{table}",
])

latex_str = "\n".join(lines)
print(latex_str)
