# 09 · Ablation Study (SpectraMind V50)

Mission‑grade **ablation orchestrator** for V50 using the **CLI + Hydra** only (no ad‑hoc training code). This notebook:

1) Defines a small **grid of Hydra overrides** for symbolic/architecture flags (ablation candidates).
2) Invokes the official CLI ablation runner (e.g., `spectramind ablate ...`) to execute multiple runs.
3) Collects all **run metrics** from `outputs/` and compiles a **leaderboard** (CSV/MD/HTML).
4) Renders a comparison plot (top-k runs by Val GLL, or by Val Loss as a fallback) and exports it under `outputs/notebooks/09_ablation_study/`.
5) Optionally registers artifacts with **DVC** for reproducibility.

_Notebook contract_: **Thin orchestration** over CLI/Hydra; all artifacts written under `outputs/` and tracked as appropriate.

In [None]:
import os, sys, json, subprocess, platform, shutil, time
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('notebook'); sns.set_style('whitegrid')

ROOT = Path.cwd()
NB_OUT = ROOT / 'outputs' / 'notebooks' / '09_ablation_study'
NB_OUT.mkdir(parents=True, exist_ok=True)
OUT_ROOT = ROOT / 'outputs'
LOGS = ROOT / 'logs'
LOGS.mkdir(exist_ok=True, parents=True)

print('ROOT:', ROOT)
print('NB_OUT:', NB_OUT)
print('OUT_ROOT:', OUT_ROOT)
print('Python:', platform.python_version())

def sh(cmd, check=True, cwd=None, env=None):
    print('\n$', cmd)
    r = subprocess.run(cmd, shell=True, cwd=cwd or ROOT, env=env)
    if check and r.returncode != 0:
        raise RuntimeError(f'Command failed ({r.returncode}): {cmd}')
    return r.returncode

# Try to find the CLI
CLI = shutil.which('spectramind')
if CLI is None:
    if (ROOT/'spectramind.py').exists():
        CLI = f"{sys.executable} {ROOT/'spectramind.py'}"
    else:
        CLI = f"{sys.executable} -m spectramind"
print('CLI:', CLI)

## Parameters — Ablation Grid
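Edit these lists to control the set of overrides. This notebook combines them as a Cartesian product (every combination, so the run count is the product of the list lengths); a native ablation runner may instead accept an explicit list of configurations, depending on your CLI.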

**Examples**
- Toggle symbolic losses or weights
- Switch AIRS encoder (e.g., GAT vs baseline)
- Adjust the smoothness weight, the non-negativity cap, or minor batch-size/optimizer settings

> Keep the grid **small** when testing interactively (e.g., 3–6 configs). For a larger grid, run headless via CI.

In [None]:
from datetime import timezone  # datetime.utcnow() is deprecated since Python 3.12
RUN_TS = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
ABLATE_TAG = f"ablate_{RUN_TS}"
ABLATE_OUTDIR = OUT_ROOT / 'ablate' / ABLATE_TAG
ABLATE_OUTDIR.mkdir(parents=True, exist_ok=True)

# Grid candidates (Hydra overrides); keep short for demo
SYMBOLIC_ON = [True, False]                      # turn symbolic loss on/off
SMOOTH_W = [0.0, 0.1]                            # smoothness weight
ENCODER = ['airs_gnn_gat', 'airs_gnn_baseline']  # encoder variants (example names)

# Build override strings for each point
def build_override(symbolic_on, smooth_w, encoder):
    return [
        f"symbolic.enabled={'true' if symbolic_on else 'false'}",
        f"symbolic.smooth_w={smooth_w}",
        f"model.airs_encoder={encoder}",
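        # Demo only: fast smoke-test runs (drop for real ablations; key name depends on your config)
        "training.fast_dev_run=true",
        # Tag shared by every run in this batch. If `run.tag` is not already a key in
        # your config, Hydra rejects the override and you need `+run.tag=...` instead.
        f"run.tag={ABLATE_TAG}"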
    ]

GRID = [build_override(a,b,c) for a in SYMBOLIC_ON for b in SMOOTH_W for c in ENCODER]
print(f"Grid size: {len(GRID)}")
for ex in GRID[:3]:
    print('example overrides:', ex)
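The nested comprehension above hard-codes three factors. A dict-driven sketch generalizes it and also offers a one-factor-at-a-time design, which keeps interactive grids small: a baseline (the first level of every factor) plus one run per alternative level, i.e. 1 + sum(k_i - 1) runs instead of the product of all k_i, where k_i is the number of levels of factor i. The `FACTORS` keys reuse the example override names from this notebook and are assumptions about your config; `full_grid` and `one_at_a_time` are hypothetical helper names.

```python
from itertools import product

# Hypothetical factor table: keys mirror the example override names used above
# (symbolic.enabled, symbolic.smooth_w, model.airs_encoder); adapt to your config.
FACTORS = {
    "symbolic.enabled": ["true", "false"],
    "symbolic.smooth_w": [0.0, 0.1],
    "model.airs_encoder": ["airs_gnn_gat", "airs_gnn_baseline"],
}

def full_grid(factors):
    """Cartesian product: one override list per combination of factor levels."""
    keys = list(factors)
    return [
        [f"{k}={v}" for k, v in zip(keys, combo)]
        for combo in product(*(factors[k] for k in keys))
    ]

def one_at_a_time(factors):
    """Baseline (first level of each factor) plus one run per alternative level,
    changing exactly one factor relative to the baseline."""
    baseline = {k: levels[0] for k, levels in factors.items()}
    grid = [[f"{k}={v}" for k, v in baseline.items()]]
    for key, levels in factors.items():
        for alt in levels[1:]:
            point = {**baseline, key: alt}
            grid.append([f"{k}={v}" for k, v in point.items()])
    return grid

print(len(full_grid(FACTORS)))      # 8 (2 x 2 x 2)
print(len(one_at_a_time(FACTORS)))  # 4 (baseline + 3 single-factor changes)
```

Each inner list has the same shape as the entries of `GRID`, so either design can be fed to the run loop; the run-specific tag and fast-dev overrides would still need to be appended.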

## Run Ablations via CLI
This cell first tries the official ablation entry point and otherwise falls back to a loop. Depending on your repository, the CLI may expose one of these shapes:

1. `spectramind ablate` with arguments like `--grid-json` or repeated `+override` flags.
2. `spectramind train ...` executed once per grid point (the fallback used below when no `ablate` subcommand is available).

All runs should write into `outputs/` with a consistent naming/tagging so we can harvest metrics afterward.

In [None]:
# Try a native 'ablate' subcommand first; if not, fall back to looping 'train'.
def run_ablation_native(overrides_grid):
    # If your CLI supports passing a JSON/CSV of overrides, prepare it here.
    # For demo, we'll fall back immediately.
    return False

def run_ablation_loop(overrides_grid):
    for i, ov in enumerate(overrides_grid, start=1):
        tag = f"{ABLATE_TAG}_cfg{i:02d}"
        cmd = [
            CLI, 'train',
            f"hydra.run.dir=outputs/runs/{tag}",
            f"hydra.job.name={tag}",
        ] + ov
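        # check=True aborts the whole batch at the first failing config;
        # switch to check=False to keep going and inspect failures afterwards.
        sh(" ".join(cmd), check=True)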
    return True

ok = run_ablation_native(GRID)
if not ok:
    print("Native ablate not available; running looped trains...")
    run_ablation_loop(GRID)

print("\nAblation batch complete. Tag:", ABLATE_TAG)
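The loop above joins each command into one string and runs it through the shell, so an override value containing spaces or shell metacharacters would be split or interpreted. A sketch of the same call using an argument list and no shell; it assumes, exactly as the loop does, that the CLI accepts Hydra overrides as positional arguments after `train`. `build_train_argv` and `run_argv` are hypothetical helper names, and the default command mirrors the notebook's `python -m spectramind` fallback. `shlex.split` follows POSIX rules, so on Windows it is safer to keep the command as a list from the start.

```python
import shlex
import subprocess
import sys

def build_train_argv(tag, overrides, cli=None):
    """Argv list for one training run (hypothetical helper).

    `cli` may be an executable path or a command string such as
    "python -m spectramind"; without it, fall back to the module form.
    """
    base = shlex.split(cli) if cli else [sys.executable, "-m", "spectramind"]
    return base + [
        "train",
        f"hydra.run.dir=outputs/runs/{tag}",
        f"hydra.job.name={tag}",
        *overrides,
    ]

def run_argv(argv, cwd=None, check=True):
    # Echo a correctly quoted, copy-pasteable version of the command.
    print("\n$", shlex.join(argv))
    # No shell=True: every argument reaches the CLI exactly as written.
    return subprocess.run(argv, cwd=cwd, check=check).returncode
```

Inside the loop, `sh(" ".join(cmd), check=True)` would then become `run_argv(build_train_argv(tag, ov, cli=CLI), cwd=ROOT)`.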

## Harvest Metrics & Build Leaderboard
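We scan `outputs/runs/` and date-based `outputs/YYYY-MM-DD/` folders for run directories whose name contains this ablation tag, parse each run's metrics file and config snapshot, and assemble a leaderboard. Runs without a readable metrics file are skipped.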
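The first lines of a config snapshot make it hard to tell which grid point a run belongs to. With Hydra's default `hydra.output_subdir`, each run also records its command-line overrides in `<run_dir>/.hydra/overrides.yaml`, so those can become leaderboard columns directly. A stdlib sketch; `read_hydra_overrides` is a hypothetical helper, and its line-based parsing assumes plain `key=value` overrides like the ones in this grid (use `yaml.safe_load` for anything richer).

```python
from pathlib import Path

def read_hydra_overrides(run_dir):
    """Return {key: value} from <run_dir>/.hydra/overrides.yaml, or {} if absent.

    Assumes one '- key=value' entry per line, optionally quoted.
    """
    path = Path(run_dir) / ".hydra" / "overrides.yaml"
    if not path.exists():
        return {}
    out = {}
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line.startswith("- "):
            continue
        entry = line[2:].strip().strip("'\"")
        key, sep, value = entry.partition("=")
        if sep:
            # Drop Hydra's append/force-append prefixes (+key, ++key) so the
            # column name matches the config key.
            out[key.lstrip("+~")] = value
    return out
```

Merging `read_hydra_overrides(r)` into each leaderboard row gives one column per ablation factor, which is easier to scan than the raw config head.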

In [None]:
def find_runs_by_tag(tag):
    # Accept layouts like outputs/runs/<tag>*, outputs/YYYY-MM-DD/<time>_<tag>
    found = []
    # layout A: outputs/runs/
    runs_root = OUT_ROOT / 'runs'
    if runs_root.exists():
        for p in runs_root.glob(f"{tag}*"):
            if p.is_dir():
                found.append(p)
    # layout B: date-based
    for date_dir in OUT_ROOT.glob("20*"):
        if date_dir.is_dir():
            for run_dir in date_dir.iterdir():
                if run_dir.is_dir() and tag in run_dir.name:
                    found.append(run_dir)
    # unique
    uniq = []
    seen = set()
    for p in found:
        if p not in seen:
            uniq.append(p); seen.add(p)
    return sorted(uniq)

def load_config_snapshot(run_dir: Path):
    candidates = [run_dir/'config.yaml', run_dir/'.hydra'/'config.yaml']
    for c in candidates:
        if c.exists():
            return c.read_text(errors='ignore')
    return None

def load_metrics_csv(run_dir: Path):
    for p in [run_dir/'metrics.csv', run_dir/'training_metrics.csv']:
        if p.exists():
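            try:
                return pd.read_csv(p)
            except Exception as e:
                # Unreadable or malformed file: report it and try the next candidate.
                print(f'Could not read {p}: {e}')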
    return None

RUNS = find_runs_by_tag(ABLATE_TAG)
print(f"Found {len(RUNS)} runs for tag {ABLATE_TAG}")

rows = []
for r in RUNS:
    cfg_txt = load_config_snapshot(r)
    met = load_metrics_csv(r)
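    # Heuristic final metrics (adapt to your schema).
    if met is None or met.empty:
        print(f'No readable metrics in {r}; skipping.')
        continue
    # Some loggers write train and val metrics on separate rows, so take the
    # last row that actually carries a validation value.
    val_cols = [c for c in ('val_gll', 'val_loss') if c in met.columns]
    if val_cols:
        met = met.dropna(subset=val_cols, how='all')
    if met.empty:
        print(f'No validation metrics in {r}; skipping.')
        continue
    last = met.iloc[-1]
    rows.append({
        'run_dir': str(r.relative_to(ROOT)),
        'val_loss': last.get('val_loss', np.nan),
        'val_gll': last.get('val_gll', np.nan),
        'val_coverage': last.get('val_coverage', np.nan),
        'config_head': ("\n".join(cfg_txt.splitlines()[:25]) if cfg_txt else None),
    })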

LB = pd.DataFrame(rows)
if LB.empty:
    print('No metrics harvested. Check your metrics file names/columns.')
else:
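    # Rank by Val GLL (higher is better) when any run reports it;
    # otherwise fall back to Val Loss (lower is better).
    if 'val_gll' in LB and LB['val_gll'].notna().any():
        LB = LB.sort_values('val_gll', ascending=False)
    elif 'val_loss' in LB and LB['val_loss'].notna().any():
        LB = LB.sort_values('val_loss', ascending=True)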
    LB.reset_index(drop=True, inplace=True)
    display(LB.head())

lb_csv = NB_OUT / 'ablation_leaderboard.csv'
LB.to_csv(lb_csv, index=False)
print('Wrote leaderboard CSV:', lb_csv)
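The harvest above reports one row of logged values per run. When early stopping or checkpoint selection means the model you would actually keep is the best epoch rather than the last one, the best-epoch row is the fairer comparison. A stdlib sketch; `best_row` is a hypothetical helper, and it assumes the metrics file is a CSV with a header row and that higher `val_gll` is better, as this notebook's ranking does.

```python
import csv
import math

def best_row(metrics_csv, metric="val_gll", higher_is_better=True):
    """Return the CSV row (as a dict of strings) with the best `metric` value.

    Rows where the metric is blank, missing, or non-numeric are skipped
    (e.g. train-only rows in logs that interleave train and val rows).
    Returns None when no row has a usable value.
    """
    best, best_val = None, None
    with open(metrics_csv, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            try:
                val = float(row.get(metric) or "nan")
            except ValueError:
                continue
            if math.isnan(val):
                continue
            if best_val is None or (val > best_val if higher_is_better else val < best_val):
                best, best_val = row, val
    return best
```

Swapping `last` for the row returned here changes the leaderboard from "final epoch" to "best epoch"; whichever you pick, apply it to every run so the comparison stays consistent.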

## Plots: Leaderboard & Metric Trends
A horizontal bar chart of the top-k runs (up to 10) by the ranking metric.

In [None]:
if not LB.empty:
    TOPK = min(10, len(LB))
    df = LB.head(TOPK).copy()
    plt.figure(figsize=(10,4))
    # Only plot a metric that at least one run actually reported.
    if 'val_gll' in df and df['val_gll'].notna().any():
        sns.barplot(x='val_gll', y='run_dir', data=df, orient='h', color='tab:blue')
        plt.title('Top Val GLL (higher is better)')
        plt.xlabel('Val GLL'); plt.ylabel('Run')
    elif 'val_loss' in df and df['val_loss'].notna().any():
        sns.barplot(x='val_loss', y='run_dir', data=df, orient='h', color='tab:orange')
        plt.title('Top Val Loss (lower is better)')
        plt.xlabel('Val Loss'); plt.ylabel('Run')
    plt.tight_layout()
    plt.savefig(NB_OUT/'leaderboard_bar.png', dpi=150)
    plt.close()
    print('Saved:', NB_OUT/'leaderboard_bar.png')
else:
    print('Leaderboard empty; skipping plots.')
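A bar chart ranks whole configurations; to see how much each individual factor contributes, one common summary for a full factorial grid is the mean metric per level of each factor (its main effect). In a complete Cartesian grid every level is averaged over the same combinations of the other factors, so the difference between level means is a fair marginal comparison; it ignores interactions, and with missing runs it is only approximate. A stdlib sketch, assuming each row already carries its override values as keys (for example merged in from each run's `.hydra/overrides.yaml`); `main_effects` is a hypothetical helper.

```python
import math
from collections import defaultdict
from statistics import mean

def main_effects(rows, factors, metric="val_gll"):
    """Average `metric` per level of each factor.

    `rows` is a list of dicts holding override values and the metric.
    Returns {factor: {level: mean_metric}}.
    """
    effects = {}
    for factor in factors:
        buckets = defaultdict(list)
        for r in rows:
            val = r.get(metric)
            # Skip runs that lack this factor or have no usable metric value.
            if factor not in r or not isinstance(val, (int, float)) or math.isnan(val):
                continue
            buckets[r[factor]].append(val)
        effects[factor] = {level: mean(vals) for level, vals in buckets.items()}
    return effects
```

With a leaderboard that has one column per override, `main_effects(LB.to_dict('records'), ['symbolic.enabled', 'model.airs_encoder'])` would give the per-level averages.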

## Export Markdown & HTML Leaderboard
Convenient artifacts for PRs/Wikis/Kaggle writeups.

In [None]:
def to_markdown_table(df: pd.DataFrame) -> str:
    cols = [c for c in ['run_dir','val_gll','val_loss','val_coverage'] if c in df.columns]
    md = ['| ' + ' | '.join(cols) + ' |', '| ' + ' | '.join(['---']*len(cols)) + ' |']
    for _,row in df.iterrows():
        md.append('| ' + ' | '.join([str(row.get(c,'')) for c in cols]) + ' |')
    return '\n'.join(md)

if not LB.empty:
    md_text = '# Ablation Leaderboard\n\n' + to_markdown_table(LB.head(min(25,len(LB))))
    (NB_OUT/'ablation_leaderboard.md').write_text(md_text, encoding='utf-8')
    print('Wrote:', NB_OUT/'ablation_leaderboard.md')

    # Minimal HTML wrapper: the Markdown table shown as preformatted text,
    # escaped so characters like '<' or '&' in run paths cannot break the page.
    import html as html_lib
    html_lines = ['<html><head><meta charset="utf-8"><title>Ablation Leaderboard</title></head><body>',
                  '<h1>Ablation Leaderboard</h1>',
                  '<pre>', html_lib.escape(md_text), '</pre>',
                  '</body></html>']
    (NB_OUT/'ablation_leaderboard.html').write_text('\n'.join(html_lines), encoding='utf-8')
    print('Wrote:', NB_OUT/'ablation_leaderboard.html')
else:
    print('No rows to export.')
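The HTML export above wraps the Markdown text in `<pre>`, which renders as plain text rather than a table. A stdlib sketch that emits a real `<table>` instead; `to_html_table` is a hypothetical helper, every cell goes through `html.escape`, and floats are shown with four decimals (an arbitrary choice).

```python
import html

def to_html_table(records, columns, title="Ablation Leaderboard"):
    """Render a list of dicts as a standalone HTML page containing a <table>."""
    def fmt(v):
        # Fixed precision for floats; empty cell for missing values.
        if isinstance(v, float):
            return f"{v:.4f}"
        return "" if v is None else str(v)

    head = "".join(f"<th>{html.escape(c)}</th>" for c in columns)
    body = "\n".join(
        "<tr>" + "".join(f"<td>{html.escape(fmt(r.get(c)))}</td>" for c in columns) + "</tr>"
        for r in records
    )
    t = html.escape(title)
    return (
        f'<html><head><meta charset="utf-8"><title>{t}</title></head><body>\n'
        f"<h1>{t}</h1>\n<table>\n<thead><tr>{head}</tr></thead>\n"
        f"<tbody>\n{body}\n</tbody>\n</table>\n</body></html>"
    )
```

With the notebook's DataFrame, `to_html_table(LB.head(25).to_dict('records'), ['run_dir', 'val_gll', 'val_loss', 'val_coverage'])` would produce the page; pandas' own `DataFrame.to_html` is an alternative if you prefer its styling options.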

## Optional: DVC Registration
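If your project uses DVC, add the notebook output directory so the leaderboard artifacts are versioned. This tracks only `outputs/notebooks/09_ablation_study/`; the run directories under `outputs/runs/` need to be tracked separately.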

In [None]:
if shutil.which('dvc'):
    try:
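        # If outputs/ is already DVC-tracked as a whole, DVC typically refuses
        # overlapping outputs; skip this cell in that case.
        sh(f"dvc add {NB_OUT}", check=False)
        # dvc add writes <dir>.dvc next to the directory and updates the
        # .gitignore in its parent directory (not the repo root).
        sh(f"git add {NB_OUT}.dvc {NB_OUT.parent / '.gitignore'}", check=False)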
        sh("dvc status", check=False)
    except Exception as e:
        print('DVC step failed (non-blocking):', e)
else:
    print('DVC not found; skipping.')
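For reference, `dvc add <target>` writes a `<target>.dvc` pointer file next to the target and lists the target in the `.gitignore` of the target's parent directory; those two files are what Git needs to commit. A tiny sketch computing both paths; `dvc_tracking_files` is a hypothetical helper name.

```python
from pathlib import Path

def dvc_tracking_files(target):
    """Paths Git should stage after `dvc add <target>`: the .dvc pointer file
    (sibling of the target) and the parent directory's .gitignore."""
    target = Path(target)
    return target.with_name(target.name + ".dvc"), target.parent / ".gitignore"
```

For this notebook that means `outputs/notebooks/09_ablation_study.dvc` and `outputs/notebooks/.gitignore`.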

---
### Tips
- Scale the grid and **remove `training.fast_dev_run=true`** when running actual ablations.
- Use CI to run larger grids and publish the generated **MD/HTML leaderboard** as artifacts.
- To include **symbolic diagnostics** in the leaderboard, ensure your training logs export the corresponding metrics (e.g., average violation norms) and add columns here.
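One way to make that column list easy to extend is to drive row extraction from a single list of metric names. A minimal sketch; `val_symbolic_violation` is a hypothetical column name standing in for whatever diagnostic your training logs export, and `extract_metrics` is a hypothetical helper that works on any mapping, including the pandas row `last` used in the harvest cell.

```python
import math

# Notebook's base metrics plus a hypothetical symbolic diagnostic column;
# replace 'val_symbolic_violation' with the name your logs actually use.
LEADERBOARD_METRICS = ["val_loss", "val_gll", "val_coverage", "val_symbolic_violation"]

def extract_metrics(last_row, columns=LEADERBOARD_METRICS):
    """Copy the requested metrics from a mapping as floats, NaN when absent or blank."""
    out = {}
    for c in columns:
        try:
            out[c] = float(last_row[c])
        except (KeyError, TypeError, ValueError):
            out[c] = math.nan
    return out
```

In the harvest loop, `row.update(extract_metrics(last))` would then add every listed metric, and `to_markdown_table` only needs the same names in its column list.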

**Artifacts written to**: `outputs/notebooks/09_ablation_study/`