
# SpectraMind V50 — Model Ablation Study

This notebook runs **systematic ablations** over the SpectraMind V50 architecture and loss stack to quantify each component’s contribution.

**Scope**
- Encoders: FGS1 (temporal) / AIRS (spectral)
- Fusion: concatenation vs. cross-attn (or SSM gating)
- Decoder heads: μ/σ joint vs. decoupled
- Loss stack: Gaussian NLL (base) + smoothness + non-negativity + band coherence + calibration
- Precision / mixed precision
- Training-time toggles (early stopping, lr scheduler)

**Outputs**
- CSV of runs & metrics
- Plots comparing ablated configs
- A short, printable summary

> Tip: This notebook calls the project CLI (`spectramind`) or Python entrypoints. Ensure your repo is installed (editable mode) in this kernel.


In [None]:

# --- Environment sanity check (customize paths if needed) ---
import sys, subprocess, json, os, shutil, time
import pandas as pd
import numpy as np
from pathlib import Path

print('Python:', sys.version)
print('CWD:', os.getcwd())

# Probe for the project package. HAS_PACKAGE records whether the import
# succeeded; run_cli() can still launch the tool when it did not.
try:
    import spectramind
    HAS_PACKAGE = True
    # Packages without a __version__ attribute are reported simply as 'OK'.
    version_label = getattr(spectramind, '__version__', 'OK')
    print('spectramind package found:', version_label)
except Exception as import_error:
    HAS_PACKAGE = False
    print('spectramind package not importable:', import_error)

# CLI resolver
def run_cli(args, capture_output=True, check=False):
    """Run a spectramind CLI command and return ``(rc, stdout, stderr)``.

    The ``spectramind`` executable is used when it is found on PATH;
    otherwise the command is launched as ``python -m spectramind`` with the
    interpreter running this kernel.

    Args:
        args: Iterable of CLI arguments placed after the program name. Each
            item is converted with ``str`` before the command is launched.
        capture_output: When true, stdout/stderr are captured and returned as
            text. When false, the child inherits the parent's streams and
            both returned values are ``None``.
        check: When true, a non-zero exit status raises ``RuntimeError``.

    Returns:
        A ``(returncode, stdout, stderr)`` tuple.

    Raises:
        RuntimeError: If ``check`` is true and the command exits non-zero.
            The message is the captured stderr (or stdout); when nothing was
            captured it names the exit status and the command instead.
    """
    if shutil.which('spectramind'):
        launcher = ['spectramind']
    else:
        launcher = [sys.executable, '-m', 'spectramind']
    # Accept any iterable (list, tuple, generator) and non-string items such
    # as Path or int, which subprocess would otherwise reject.
    cmd = [*launcher, *(str(arg) for arg in args)]
    print('> ', ' '.join(cmd))
    p = subprocess.run(cmd, capture_output=capture_output, text=True)
    if check and p.returncode != 0:
        # With capture_output=False both streams are None; fall back to a
        # message that still identifies what failed.
        fallback = f'command exited with status {p.returncode}: {" ".join(cmd)}'
        raise RuntimeError(p.stderr or p.stdout or fallback)
    return p.returncode, p.stdout, p.stderr


## 1) Define ablation grid

In [None]:

# Compact ablation grid: one (run name, Hydra overrides) pair per experiment.
# NOTE(review): every non-baseline entry lists only its own delta, and
# run_one() combines an entry's overrides with GLOBAL_OVERRIDES only, not with
# the baseline's. Confirm the `train` config supplies the remaining groups.
_ABLATION_SPECS = (
    # Reference configuration
    ('baseline', (
        '+env=local', '+data=nominal', '+calib=nominal', '+model=v50',
        '+training=lightning', '+loss=composite', '+logger=csv',
    )),
    # Loss stack: drop one term at a time
    ('no_smooth', ('+loss.composite.enable_smoothness=false',)),
    ('no_nonneg', ('+loss.composite.enable_nonneg=false',)),
    ('no_bandcoh', ('+loss.composite.enable_band_coherence=false',)),
    # Encoders: simpler MLP for FGS1, shallow CNN for AIRS
    ('fgs1_mlp', ('+model.fgs1_encoder=mlp_small',)),
    ('airs_cnn_shallow', ('+model.airs_encoder=cnn_shallow',)),
    # Fusion: concat -> cross-attn (if available)
    ('fusion_xattn', ('+model.fusion=attn',)),
    # Decoder: decoupled μ/σ heads
    ('decoupled_heads', ('+model.decoder.decoupled=true',)),
    # Precision: full 32-bit
    ('amp_off', ('+training.precision.precision=32',)),
)

ABLATIONS = [
    {'name': run_name, 'overrides': list(run_overrides)}
    for run_name, run_overrides in _ABLATION_SPECS
]

# Overrides applied to every run: a reduced training budget for quick sweeps
# and a shared CSV logger directory.
GLOBAL_OVERRIDES = [
    '+training.max_epochs=3',
    '+training.trainer.limit_train_batches=0.2',
    '+training.trainer.limit_val_batches=1.0',
    '+logger.csv.dir=artifacts/ablations',
]

print(f'{len(ABLATIONS)} ablations configured.')


## 2) Run ablations

In [None]:

from datetime import datetime

# Output root; the same location is named in the CSV logger override and is
# where the summary CSV is later written.
ARTIFACT_DIR = Path('artifacts/ablations')
os.makedirs(ARTIFACT_DIR, exist_ok=True)

# One record per executed ablation; run_one() appends to this list.
RESULTS = []

def _metrics_snapshot(root):
    """Map each ``metrics*.csv`` under *root* to its ``(mtime_ns, size)``."""
    snapshot = {}
    if not root.exists():
        return snapshot
    for candidate in root.rglob('metrics*.csv'):
        try:
            info = candidate.stat()
        except OSError:
            # The file disappeared between listing and stat; ignore it.
            continue
        snapshot[candidate] = (info.st_mtime_ns, info.st_size)
    return snapshot


def run_one(ablation):
    """Run one ablation through the CLI and append its record to ``RESULTS``.

    The command is ``train --config-name train`` followed by
    ``GLOBAL_OVERRIDES`` and the ablation's own overrides. The metrics file
    is located heuristically: the most recently modified ``metrics*.csv``
    under ``ARTIFACT_DIR`` that was created or changed while this run was
    executing. Files left untouched by the run are ignored, so a run that
    fails or writes no metrics is not credited with an earlier run's file.

    Args:
        ablation: Dict with a ``'name'`` key and an optional ``'overrides'``
            list of Hydra override strings.

    Returns:
        None. A dict is appended to ``RESULTS``; if anything raises, the
        record has ``returncode`` -1 and an ``'error'`` message instead of
        the stdout/stderr tails.
    """
    name = ablation['name']
    overrides = ablation.get('overrides', [])
    hydra_args = ['train', '--config-name', 'train', *GLOBAL_OVERRIDES, *overrides]
    started = time.time()
    try:
        before = _metrics_snapshot(ARTIFACT_DIR)
        # Restart the clock so only the CLI invocation is timed.
        started = time.time()
        rc, out, err = run_cli(hydra_args, capture_output=True, check=False)
        elapsed = time.time() - started
        after = _metrics_snapshot(ARTIFACT_DIR)
        # Keep files that are new or whose (mtime, size) signature changed.
        touched = {
            path: signature
            for path, signature in after.items()
            if before.get(path) != signature
        }
        metrics_path = None
        if touched:
            newest = max(touched, key=lambda path: touched[path][0])
            metrics_path = str(newest)
        RESULTS.append({
            'name': name,
            'returncode': rc,
            'elapsed_sec': round(elapsed, 2),
            'metrics_csv': metrics_path,
            # Keep only the tail so the results table stays readable.
            'stdout_tail': (out or '')[-800:],
            'stderr_tail': (err or '')[-800:]
        })
    except Exception as ex:
        # Best-effort sweep: record the failure and let the caller continue.
        elapsed = time.time() - started
        RESULTS.append({
            'name': name,
            'returncode': -1,
            'elapsed_sec': round(elapsed, 2),
            'metrics_csv': None,
            'error': str(ex)
        })

# Master switch: leave False while editing, set True to launch the sweep.
EXECUTE = False

# Runs are launched one after another; with EXECUTE off the loop is empty.
for abl in (ABLATIONS if EXECUTE else ()):
    print(f'\n=== Running: {abl["name"]} ===')
    run_one(abl)

# Last expression of the cell: renders the raw run records as a table.
pd.DataFrame(RESULTS)


## 3) Aggregate metrics

In [None]:

def load_metrics_row(row):
    """Load the final logged metrics for one run record.

    Args:
        row: Mapping-like run record (dict or pandas Series) whose
            ``'metrics_csv'`` entry is the path to a metrics CSV, or a
            non-path value such as ``None`` when the run produced none.

    Returns:
        A dict of the most recent value of every column whose name contains
        ``'val_'``, ``'epoch'``, ``'gll'`` or ``'loss'``. Returns ``{}`` when
        there is no usable metrics file, and ``{'metrics_error': <message>}``
        when the file cannot be read or has no rows.
    """
    path = row.get('metrics_csv')
    # None (failed run) or NaN (after a DataFrame round-trip) is not a path;
    # os.path.exists would raise TypeError on a float.
    if not isinstance(path, (str, os.PathLike)) or not path or not os.path.exists(path):
        return {}
    try:
        dfm = pd.read_csv(path)
        # Loggers that write train and validation metrics on separate rows
        # leave NaN in the final row for whichever group was not logged last.
        # Forward-filling makes the last row carry each column's latest value.
        last = dfm.ffill().iloc[-1].to_dict()
        wanted = ('val_', 'epoch', 'gll', 'loss')
        return {k: v for k, v in last.items() if any(tag in k for tag in wanted)}
    except Exception as e:
        return {'metrics_error': str(e)}

# Join each run record with the metrics read from its CSV, then show the run
# identity columns followed by every metric column that was found.
if not RESULTS:
    print('No RESULTS yet. Set EXECUTE=True to run.')
else:
    df = pd.DataFrame(RESULTS)
    expanded = []
    metric_cols = set()
    for _, run in df.iterrows():
        metrics = load_metrics_row(run)
        # Metric values win if a key collides with a run-record key.
        expanded.append({**dict(run), **metrics})
        metric_cols.update(metrics)
    df_exp = pd.DataFrame(expanded)
    metric_cols = sorted(metric_cols)
    display(df_exp[['name', 'returncode', 'elapsed_sec'] + metric_cols])


## 4) Plot comparisons

In [None]:

import matplotlib.pyplot as plt

# Horizontal bar chart of one validation metric per ablation, sorted by value.
if 'df_exp' not in globals():
    print('No aggregated DataFrame available.')
else:
    # First recognised metric name present in the table, in preference order.
    key = next(
        (name for name in ('val_gll', 'val_loss', 'val_metric') if name in df_exp.columns),
        None,
    )
    if not key:
        print('No recognized validation metric column found to plot.')
    else:
        plt.figure(figsize=(8,4))
        ax = plt.gca()
        df_plot = df_exp.sort_values(key)
        ax.barh(df_plot['name'], df_plot[key])
        ax.set_xlabel(key)
        ax.set_title('Ablation Comparison')
        # Flip the y-axis so the first sorted row is drawn at the top.
        ax.invert_yaxis()
        plt.tight_layout()
        plt.show()


## 5) Generate summary

In [None]:

# Export the aggregated table, if one exists, next to the run artifacts.
SUMMARY_PATH = ARTIFACT_DIR / 'ablation_summary.csv'
if 'df_exp' not in globals():
    print('Nothing to summarize yet.')
else:
    df_exp.to_csv(SUMMARY_PATH, index=False)
    print('Saved summary to', SUMMARY_PATH)



---

### Appendix: Notes

- The ablation overrides assume corresponding Hydra config groups:
  - `+model.fgs1_encoder=mlp_small`, `+model.airs_encoder=cnn_shallow`, `+model.fusion=attn`, etc.
  - `+loss.composite.enable_*` boolean flags within the loss stack.
- Swap `+env=local` for `+env=kaggle` when running on Kaggle.
- Increase `+training.max_epochs` for high-fidelity results; this notebook defaults to quick smoke tests.
- If the package is not importable in this kernel, install it with `pip install -e .` and confirm the CLI works with `python -m spectramind --help`.

**Suggested metrics to track** (and ensure your logger emits them as CSV columns):
- `val_gll` (Gaussian log-likelihood; higher is better — if your pipeline logs it under a loss convention, negate it before comparing)
- `val_loss` (lower is better)
- `val_smooth_penalty`, `val_nonneg_violations`, `val_band_coherence` (diagnostic)

