# 11 · COREL Calibration Demo (SpectraMind V50)

Mission‑grade **uncertainty calibration** walkthrough for SpectraMind V50 using **CLI + Hydra** outputs only (no ad‑hoc pipeline code).

This notebook:
1) Locates pre‑calibration predictions (`μ, σ`) and held‑out targets `y` under `outputs/`.
2) Computes baseline **z‑score** diagnostics and coverage before calibration.
3) (Optional) Invokes the COREL conformal/GNN calibrator via the CLI to produce calibrated `σ` or prediction intervals.
4) Loads calibrated artifacts and recomputes coverage/quantile diagnostics and reliability plots.
5) Writes a compact **calibration_report.json** plus per‑bin detail CSVs and PNG plots under `outputs/notebooks/11_corel_calibration/`.

**Contract**: Notebooks are thin orchestration. We only read CLI/Hydra artifacts and, when requested, *call* the CLI. All results are saved in `outputs/` and are DVC‑friendly.

In [None]:
import os, sys, json, shutil, subprocess, platform
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('notebook')
sns.set_style('whitegrid')

# Repository root = directory the notebook is launched from.
ROOT = Path.cwd().resolve()
# Every artifact written by this notebook lives under this directory.
NB_OUT = ROOT / 'outputs' / 'notebooks' / '11_corel_calibration'
NB_OUT.mkdir(parents=True, exist_ok=True)

# Resolve how to invoke the SpectraMind CLI, in order of preference:
#   1) a `spectramind` executable found on PATH,
#   2) a local `spectramind.py` run with this interpreter,
#   3) the `spectramind` module run with this interpreter (`-m`).
# NOTE: cases 2 and 3 produce a single space-joined command *string*.
_cli_on_path = shutil.which('spectramind')
_local_script = ROOT / 'spectramind.py'
if _cli_on_path:
    CLI = _cli_on_path
elif _local_script.exists():
    CLI = f"{sys.executable} {_local_script}"
else:
    CLI = f"{sys.executable} -m spectramind"

print('ROOT:', ROOT)
print('NB_OUT:', NB_OUT)
print('CLI  :', CLI)

# Minimal environment snapshot for reproducibility.
env_info = {
    'python': platform.python_version(),
    'platform': platform.platform(),
}
(NB_OUT / 'env_snapshot.json').write_text(json.dumps(env_info, indent=2))
print('Saved env snapshot.')

## Parameters
Tell the notebook where to look for predictions/labels. If left to `None`, we try to auto‑discover under `outputs/`.

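- **Before‑calibration**: long‑format CSV with columns `planet_id, wavelength_index, mu, sigma` (and optionally `y`). If `y` is absent, point `LABELS_HINT` at a labels CSV; otherwise a few common locations are tried automatically.
- **Labels**: long table `planet_id, wavelength_index, y`. Wide per‑planet label tables are not handled by this notebook; convert them to long format first.
- **After‑calibration**: CSV keyed by the same `planet_id, wavelength_index` columns, carrying either a calibrated `sigma_cal` (a plain `sigma` column in this file is also taken as the calibrated σ) or conformal interval bounds `lo, hi`.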

You can also toggle the CLI calibration cell below to generate the calibrated artifacts.

In [None]:
# Where to look for the *pre-calibration* predictions CSV. File entries are
# used directly; directory entries are scanned recursively for CSVs.
PRED_HINTS = [
    ROOT / 'outputs' / 'predictions' / 'predictions.csv',
    ROOT / 'outputs' / 'predictions.csv',
    ROOT / 'outputs' / 'runs',  # directory: scanned recursively
]
# Optional explicit labels CSV; None means "try the usual places".
LABELS_HINT = None  # e.g., ROOT/'outputs'/'val_labels.csv' or ROOT/'data'/'labels'/'val_labels.csv'
# Where to look for the *calibrated* CSV produced by COREL.
CAL_HINTS = [
    ROOT / 'outputs' / 'calibration' / 'corel_calibrated.csv',
    ROOT / 'outputs' / 'runs',
]

# Miscoverage alpha: nominal interval coverage is 1 - alpha (0.1 -> 90%).
TARGET_ALPHA = 0.1

print('PRED_HINTS:', list(map(str, PRED_HINTS)))
print('LABELS_HINT:', LABELS_HINT)
print('CAL_HINTS :', list(map(str, CAL_HINTS)))

## 1) Locate artifacts
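For each hint list we pick the most recently modified CSV. Directory hints (e.g. `outputs/runs`) are searched recursively, so any newer CSV there wins; set explicit file paths if discovery picks the wrong file. If predictions do not include labels, we load labels separately and merge on `(planet_id, wavelength_index)`.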

In [None]:
def newest_csv(paths):
    """Return the most recently modified ``.csv`` file found under *paths*.

    Each entry of *paths* may be a ``str`` or ``Path`` naming a file or a
    directory:

    * a file is a candidate when its suffix is ``.csv`` (case-insensitive);
    * a directory is searched recursively for regular files whose suffix is
      ``.csv`` (also case-insensitive);
    * entries that do not exist are skipped.

    Returns:
        The ``Path`` with the largest modification time, or ``None`` when no
        candidate was found.
    """
    candidates = []
    for hint in paths:
        hint = Path(hint)
        if not hint.exists():
            continue
        if hint.is_file():
            if hint.suffix.lower() == '.csv':
                candidates.append(hint)
        elif hint.is_dir():
            # rglob('*.csv') is case-sensitive on POSIX, which would miss
            # '*.CSV' exports; filter on the lower-cased suffix instead. The
            # is_file() check drops directories named '*.csv' and broken
            # symlinks, which would otherwise crash on stat() below.
            candidates.extend(
                p for p in hint.rglob('*')
                if p.suffix.lower() == '.csv' and p.is_file()
            )
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)

# Resolve the newest candidate for each artifact kind.
PRED_CSV = newest_csv(PRED_HINTS)
CAL_CSV = newest_csv(CAL_HINTS)
for _tag, _found in (('PRED_CSV:', PRED_CSV), ('CAL_CSV :', CAL_CSV)):
    print(_tag, _found)

# Predictions are mandatory: every later cell depends on them.
if PRED_CSV is None:
    raise FileNotFoundError('No predictions CSV found in PRED_HINTS; generate predictions first (see 04_predict_v50_demo).')

pred_df = pd.read_csv(PRED_CSV)
print('pred_df:', pred_df.shape)
pred_df.head(3)

In [None]:
# Pull labels: from predictions if included, otherwise from LABELS_HINT.
# The check is case-insensitive because column names are lower-cased later.
need_labels = 'y' not in {c.lower() for c in pred_df.columns}
labels_df = None
if need_labels:
    # Accept LABELS_HINT as str or Path; treat None/'' as "not given".
    LABELS_HINT = Path(LABELS_HINT) if LABELS_HINT else None
    if LABELS_HINT is None:
        # Try a few common places
        for guess in [ROOT/'outputs'/'val_labels.csv', ROOT/'data'/'labels'/'val_labels.csv']:
            if guess.exists():
                LABELS_HINT = guess
                break
    if LABELS_HINT is not None and LABELS_HINT.exists():
        labels_df = pd.read_csv(LABELS_HINT)
        print('labels_df:', labels_df.shape, 'from', LABELS_HINT)
    else:
        if LABELS_HINT is not None:
            # An explicit hint that points nowhere is almost certainly a typo.
            print(f'WARNING: LABELS_HINT does not exist: {LABELS_HINT}')
        print('WARNING: No labels found; coverage diagnostics will be limited.')

# Normalize column names
def normcols(df):
    """Return a copy of *df* with every column label lower-cased.

    Downstream cells look columns up by lower-case name ('mu', 'sigma',
    'y', ...), so this makes them insensitive to the CSV writer's casing.
    """
    return df.rename(columns=lambda label: label.lower())

pred_df = normcols(pred_df)
if labels_df is not None:
    labels_df = normcols(labels_df)

# Expect long format: planet_id, wavelength_index, mu, sigma, (y optional)
required = {'planet_id','wavelength_index','mu'}
if not required.issubset(set(pred_df.columns)):
    raise ValueError(f'predictions missing required columns {required}; got {list(pred_df.columns)}')
has_sigma = 'sigma' in pred_df.columns
print('has_sigma:', has_sigma)

# Labels must be long format keyed like the predictions. Anything else
# (e.g. a wide per-planet table) cannot be aligned; warn and ignore it
# instead of failing later with an opaque KeyError inside merge().
_label_keys = ['planet_id', 'wavelength_index']
if labels_df is not None:
    _missing = [c for c in _label_keys + ['y'] if c not in labels_df.columns]
    if _missing:
        print(f'WARNING: labels file lacks columns {_missing}; ignoring labels.')
        labels_df = None

if labels_df is not None:
    # Explicit column order (iterating a set gives a hash-dependent order).
    labels_df = labels_df[_label_keys + ['y']]
    df = pred_df.merge(labels_df, on=_label_keys, how='inner')
    if df.empty:
        print('WARNING: no (planet_id, wavelength_index) pairs matched between predictions and labels; check key dtypes/values.')
else:
    df = pred_df.copy()
print('merged df:', df.shape)
df.head(3)

## 2) Pre‑calibration diagnostics (baseline)
We compute z‑scores `z = (μ−y)/σ` where available, histogram them, and estimate empirical coverage for nominal Gaussian intervals `μ±kσ`.

In [None]:
def zscores(mu, sigma, y):
    """Return standardized residuals ``(mu - y) / sigma`` as a float ndarray.

    Entries whose ``sigma`` is not strictly positive are replaced by NaN
    before dividing. Those entries therefore yield NaN rather than
    inf/sign-flipped z-scores.
    """
    spread = np.asarray(sigma, float)
    spread = np.where(spread > 0, spread, np.nan)
    residual = np.asarray(mu, float) - np.asarray(y, float)
    return residual / spread

pre = {}
if ('y' in df.columns) and has_sigma:
    z = zscores(df['mu'], df['sigma'], df['y'])
    # Rows with sigma<=0 or missing mu/y produce NaN z. They say nothing about
    # coverage, so they are excluded from the denominator rather than being
    # counted as "outside" the interval (which biased coverage downward).
    z_valid = z[~np.isnan(z)]
    pre['n_valid'] = int(z_valid.size)
    pre['z_mean'] = float(np.nanmean(z))
    pre['z_std']  = float(np.nanstd(z))
    # Empirical coverage for 1σ/2σ
    pre['cov_1sigma'] = float(np.mean(np.abs(z_valid) <= 1.0)) if z_valid.size else float('nan')
    pre['cov_2sigma'] = float(np.mean(np.abs(z_valid) <= 2.0)) if z_valid.size else float('nan')
else:
    z = None
    z_valid = None
    print('Sigma and/or y missing; z‑score baseline limited.')

# Plot histogram if there is at least one defined z-score
if z_valid is not None and z_valid.size:
    plt.figure(figsize=(9,3))
    sns.histplot(z_valid, bins=60, kde=True, stat='density', color='tab:blue')
    plt.title(f'Pre‑calibration z distribution (mean={pre.get("z_mean",np.nan):.3f}, std={pre.get("z_std",np.nan):.3f})')
    plt.xlabel('z'); plt.ylabel('density')
    plt.tight_layout(); plt.savefig(NB_OUT/'pre_z_hist.png', dpi=150); plt.close()
    print('Saved pre_z_hist.png')

# Per-bin 1-sigma coverage line plot (optional)
bin_cov = None
if z is not None:
    tmp = df[['wavelength_index']].copy()
    # NaN (not False) where z is undefined so groupby().mean() skips those rows.
    tmp['in1'] = np.where(np.isnan(z), np.nan, np.abs(z) <= 1)
    bin_cov = tmp.groupby('wavelength_index')['in1'].mean().reset_index()
    plt.figure(figsize=(10,2.5))
    plt.plot(bin_cov['wavelength_index'], bin_cov['in1'], lw=1)
    plt.ylim(-0.05,1.05)
    plt.title('Pre‑calibration per‑bin 1σ coverage')
    plt.xlabel('wavelength index'); plt.ylabel('coverage')
    plt.tight_layout(); plt.savefig(NB_OUT/'pre_bin_coverage.png', dpi=150); plt.close()
    print('Saved pre_bin_coverage.png')

pre

## 3) (Optional) Run COREL calibrator via CLI
If your repository exposes a COREL or conformal calibration command, enable the cell below to generate calibrated artifacts.

Common patterns:
- `spectramind calibrate-uncertainty corel ...` (example)
- `spectramind corel-calibrate ...`

Adjust flags to point at your predictions/labels and target coverage (`1−α`).

In [None]:
RUN_COREL_CLI = False  # set True to enable

import shlex  # only needed to split a CLI command string into argv


def cli_argv(cli):
    """Turn the ``CLI`` spec from the setup cell into an argv list.

    ``CLI`` is either an executable path (from ``shutil.which``) or a string
    such as ``"<python> <script>"`` / ``"<python> -m spectramind"``. Passing
    such a string as ONE argv element makes the OS look for a program
    literally named ``"<python> <script>"``, so it must be split. Splitting
    on the known interpreter prefix (not on every space) keeps script paths
    that contain spaces intact.
    """
    cli = str(cli)
    if Path(cli).exists():
        return [cli]
    exe = sys.executable
    if cli.startswith(exe + ' '):
        rest = cli[len(exe) + 1:]
        if rest.startswith('-m '):
            return [exe, '-m', rest[3:]]
        return [exe, rest]
    return shlex.split(cli)


if RUN_COREL_CLI:
    overrides = [f'inputs={PRED_CSV}']               # predictions with mu/sigma
    if LABELS_HINT is not None:
        # Only pass labels when a file is known; otherwise the CLI would
        # receive the literal override "labels=None".
        overrides.append(f'labels={LABELS_HINT}')
    overrides += [
        f'alpha={TARGET_ALPHA}',                      # desired miscoverage
        f'outdir={ROOT / "outputs" / "calibration"}',
    ]
    cmd = cli_argv(CLI) + ['corel-calibrate'] + overrides  # <-- change subcommand to your actual one
    print('Running:', ' '.join(cmd))
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print('COREL CLI failed (non‑blocking):', e)
    else:
        # CAL_CSV was discovered before the calibrator ran; re-scan so the
        # next cell picks up the freshly written artifact.
        CAL_CSV = newest_csv(CAL_HINTS)
        print('CAL_CSV :', CAL_CSV)
else:
    print('COREL CLI disabled; set RUN_COREL_CLI=True to run it here.')

## 4) Load calibrated artifacts
We look for `sigma_cal` or prediction intervals `lo,hi` in the calibrated file; if missing, we stay with baseline.

In [None]:
cal_df = None
if CAL_CSV and Path(CAL_CSV).exists():
    if PRED_CSV is not None and Path(CAL_CSV).resolve() == Path(PRED_CSV).resolve():
        # PRED_HINTS and CAL_HINTS both scan outputs/runs, so discovery can
        # return the predictions file itself; treating it as "calibrated"
        # would just echo the baseline.
        print('WARNING: CAL_CSV is the same file as PRED_CSV; ignoring it as a calibrated artifact.')
    else:
        cal_df = normcols(pd.read_csv(CAL_CSV))
        print('cal_df:', cal_df.shape, 'from', CAL_CSV)
else:
    print('No calibrated CSV found; proceeding with baseline only.')

# Merge calibrated columns if structure matches
merged_df = df.copy()
has_sigma_cal = False
has_intervals = False
if cal_df is not None:
    join_keys = [k for k in ['planet_id','wavelength_index'] if k in cal_df.columns and k in merged_df.columns]
    if join_keys:
        # Map calibrated-file columns -> names used downstream. An explicit
        # `sigma_cal` wins; otherwise a plain `sigma` is taken as the
        # calibrated sigma. Only these columns are brought over so unrelated
        # ones (mu, metadata, ...) cannot collide with baseline columns.
        rename_map = {}
        if 'sigma_cal' in cal_df.columns:
            rename_map['sigma_cal'] = 'sigma_cal'
        elif 'sigma' in cal_df.columns:
            rename_map['sigma'] = 'sigma_cal'
        for col in ('lo', 'hi'):
            if col in cal_df.columns:
                rename_map[col] = col
        # Labels carried by the calibrated file are used only if the baseline lacks them.
        if 'y' in cal_df.columns and 'y' not in merged_df.columns:
            rename_map['y'] = 'y'
        cal_sub = cal_df[join_keys + list(rename_map)].rename(columns=rename_map)
        n_dup = int(cal_sub.duplicated(subset=join_keys).sum())
        if n_dup:
            # Duplicate keys would multiply rows in the left merge below.
            print(f'WARNING: {n_dup} duplicate key rows in calibrated file; keeping the last of each.')
            cal_sub = cal_sub.drop_duplicates(subset=join_keys, keep='last')
        # Calibrated values replace same-named baseline columns (e.g. pre-calibration
        # lo/hi) instead of being pushed aside to *_cal suffixes.
        clash = [c for c in cal_sub.columns if c not in join_keys and c in merged_df.columns]
        merged_df = merged_df.drop(columns=clash).merge(cal_sub, on=join_keys, how='left')
        has_sigma_cal = 'sigma_cal' in merged_df.columns
        has_intervals = {'lo','hi'}.issubset(set(merged_df.columns))
        print('has_sigma_cal:', has_sigma_cal, 'has_intervals:', has_intervals)
    else:
        print('WARNING: Could not align calibrated file with predictions on keys; skipping merge.')

merged_df.head(3)

## 5) Post‑calibration diagnostics
We recompute z‑scores using `σ_cal` if provided, and/or coverage using intervals `lo,hi`. We also plot reliability and per‑bin coverage.

In [None]:
post = {}
z_cal = None
if ('y' in merged_df.columns) and has_sigma_cal:
    z_cal = zscores(merged_df['mu'], merged_df['sigma_cal'], merged_df['y'])
    # NaN z_cal (sigma_cal<=0 or missing, e.g. rows the left merge could not
    # match, or missing y) is excluded from coverage rather than counted as a miss.
    z_cal_valid = z_cal[~np.isnan(z_cal)]
    post['n_valid'] = int(z_cal_valid.size)
    post['z_mean'] = float(np.nanmean(z_cal))
    post['z_std']  = float(np.nanstd(z_cal))
    post['cov_1sigma'] = float(np.mean(np.abs(z_cal_valid) <= 1.0)) if z_cal_valid.size else float('nan')
    post['cov_2sigma'] = float(np.mean(np.abs(z_cal_valid) <= 2.0)) if z_cal_valid.size else float('nan')

if z_cal is not None and np.any(~np.isnan(z_cal)):
    plt.figure(figsize=(9,3))
    sns.histplot(z_cal[~np.isnan(z_cal)], bins=60, kde=True, stat='density', color='tab:green')
    plt.title(f'Post‑calibration z distribution (mean={post.get("z_mean",np.nan):.3f}, std={post.get("z_std",np.nan):.3f})')
    plt.xlabel('z_cal'); plt.ylabel('density')
    plt.tight_layout(); plt.savefig(NB_OUT/'post_z_hist.png', dpi=150); plt.close()
    print('Saved post_z_hist.png')

# Interval coverage (if conformal intervals lo,hi available)
int_cov = None
if ('y' in merged_df.columns) and has_intervals:
    y_arr = merged_df['y'].to_numpy(dtype=float)
    lo_arr = merged_df['lo'].to_numpy(dtype=float)
    hi_arr = merged_df['hi'].to_numpy(dtype=float)
    # Rows lacking y/lo/hi (e.g. unmatched in the left merge) are undefined,
    # not "outside"; mark them NaN so the means below skip them.
    defined = ~(np.isnan(y_arr) | np.isnan(lo_arr) | np.isnan(hi_arr))
    inside = np.where(defined, (y_arr >= lo_arr) & (y_arr <= hi_arr), np.nan)
    if defined.any():
        int_cov = float(np.nanmean(inside))
        post['interval_coverage'] = int_cov
        # per-bin interval coverage
        perbin = merged_df[['wavelength_index']].copy()
        perbin['inside'] = inside
        gb = perbin.groupby('wavelength_index')['inside'].mean().reset_index()
        plt.figure(figsize=(10,2.5))
        plt.plot(gb['wavelength_index'], gb['inside'], lw=1, color='tab:purple')
        plt.ylim(-0.05,1.05)
        plt.title(f'Per‑bin interval coverage (target ~ {1-TARGET_ALPHA:.0%})')
        plt.xlabel('wavelength index'); plt.ylabel('coverage')
        plt.tight_layout(); plt.savefig(NB_OUT/'post_bin_interval_coverage.png', dpi=150); plt.close()
        print('Saved post_bin_interval_coverage.png')
    else:
        print('WARNING: no rows with y, lo and hi all present; interval coverage skipped.')

post

### Reliability: empirical vs nominal
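For Gaussian σ, the nominal probability mass inside `μ ± kσ` is `erf(k/√2)`; for conformal intervals, the nominal coverage is `1−α`. We plot empirical against nominal mass: points on the dashed diagonal indicate perfect calibration, points below it indicate over‑confident (too narrow) uncertainties.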

In [None]:
from math import erf, sqrt

def nominal_gauss_mass(k):
    """Probability that a standard normal variable lies within ±k.

    P(|Z| <= k) = erf(k / sqrt(2)).
    """
    scaled_half_width = k / sqrt(2.0)
    return erf(scaled_half_width)

def empirical_mass_from_z(z, ks=(0.5,1.0,1.5,2.0,2.5)):
    """Empirical fraction of z-scores with ``|z| <= k`` for each ``k`` in *ks*.

    NaN z-scores (undefined residuals, e.g. sigma <= 0 or missing labels) are
    dropped before computing the fractions. Counting them as misses would
    bias the empirical mass downward.

    Returns:
        ``(mass, ks_array)``: float arrays of equal length. ``mass`` is all
        NaN when *z* contains no defined values.
    """
    z = np.asarray(z, dtype=float)
    z = z[~np.isnan(z)]
    if z.size == 0:
        return np.full(len(ks), np.nan), np.array(ks)
    abs_z = np.abs(z)
    mass = np.array([np.mean(abs_z <= k) for k in ks])
    return mass, np.array(ks)

def _save_reliability(emp_mass, k_values, *, label, title, fname, color=None):
    """Save an empirical-vs-nominal Gaussian mass plot with a y=x reference.

    *emp_mass* / *k_values* are the two outputs of ``empirical_mass_from_z``;
    the nominal mass for each k comes from ``nominal_gauss_mass``. The PNG is
    written to ``NB_OUT / fname``.
    """
    nominal = np.array([nominal_gauss_mass(k) for k in k_values])
    plt.figure(figsize=(6,4))
    # Only pass a colour when one was requested, so the default cycle applies otherwise.
    line_style = {} if color is None else {'color': color}
    plt.plot(nominal, emp_mass, 'o-', label=label, **line_style)
    plt.plot([0,1],[0,1], 'k--', lw=1)
    plt.xlabel('Nominal mass')
    plt.ylabel('Empirical mass')
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(NB_OUT/fname, dpi=150)
    plt.close()
    print(f'Saved {fname}')


# Pre-calibration reliability (needs baseline z-scores)
if z is not None:
    _save_reliability(*empirical_mass_from_z(z),
                      label='pre (empirical vs nominal)',
                      title='Reliability (pre)',
                      fname='reliability_pre.png')

# Post-calibration reliability (needs calibrated z-scores)
if z_cal is not None:
    _save_reliability(*empirical_mass_from_z(z_cal),
                      label='post (empirical vs nominal)',
                      title='Reliability (post)',
                      fname='reliability_post.png',
                      color='tab:green')

# Interval reliability (conformal): empirical coverage next to the 1 - alpha target
if int_cov is not None:
    bar_heights = [int_cov, 1-TARGET_ALPHA]
    plt.figure(figsize=(4,3))
    plt.bar(['empirical','target'], bar_heights, color=['tab:purple','tab:gray'])
    plt.ylim(0,1)
    plt.title('Interval coverage (post)')
    plt.tight_layout()
    plt.savefig(NB_OUT/'reliability_interval.png', dpi=150)
    plt.close()
    print('Saved reliability_interval.png')

## 6) Report bundle
We store a machine‑readable `calibration_report.json` and detail CSV with per‑bin coverage (if computed).

In [None]:
def _json_safe(value):
    """Return *value* with non-finite floats (NaN/+-inf) replaced by None, recursively.

    ``json.dumps`` would otherwise emit bare ``NaN``/``Infinity`` tokens,
    which are not valid JSON and break strict parsers of the report.
    """
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


report = {
    'pred_file': str(PRED_CSV) if PRED_CSV else None,
    'cal_file': str(CAL_CSV) if CAL_CSV else None,
    'target_alpha': float(TARGET_ALPHA),
    'pre': pre,
    'post': post,
}
(NB_OUT/'calibration_report.json').write_text(json.dumps(_json_safe(report), indent=2, allow_nan=False))
print('Wrote calibration_report.json')

if bin_cov is not None:
    bin_cov.to_csv(NB_OUT/'pre_bin_coverage.csv', index=False)
    print('Wrote pre_bin_coverage.csv')

# Optionally, per-bin post coverage if z_cal is present
if (z_cal is not None) and ('wavelength_index' in merged_df.columns):
    tmp2 = merged_df[['wavelength_index']].copy()
    # NaN (not False) where z_cal is undefined so the per-bin mean skips those rows.
    tmp2['in1'] = np.where(np.isnan(z_cal), np.nan, np.abs(z_cal) <= 1)
    post_bin = tmp2.groupby('wavelength_index')['in1'].mean().reset_index()
    post_bin.to_csv(NB_OUT/'post_bin_coverage.csv', index=False)
    print('Wrote post_bin_coverage.csv')

## 7) (Optional) DVC add
Register outputs for full reproducibility.

In [None]:
if shutil.which('dvc'):
    try:
        add = subprocess.run(['dvc','add', str(NB_OUT)], check=False)
        if add.returncode == 0:
            # `dvc add <dir>` writes <dir>.dvc next to the directory and adds
            # the directory to the .gitignore in its *parent*, not the repo root.
            subprocess.run(['git','add', f'{NB_OUT}.dvc', str(NB_OUT.parent / '.gitignore')], check=False)
            subprocess.run(['dvc','status'], check=False)
            print('DVC add done (non‑blocking).')
        else:
            print(f'dvc add exited with code {add.returncode}; skipping git add.')
    except OSError as e:
        # check=False means run() only raises when an executable cannot be started.
        print('DVC step failed (non‑blocking):', e)
else:
    print('DVC not found; skipping.')