# KTND-Finance: Full Experiment Pipeline (v1.5.3)

### How to run:
1. **Cell 1** (Setup) -- Install deps + clone repo (~2 min)
2. **Cell 2** (Pipeline) -- Full pipeline + multi-seed + figures (~2-3 hours for the pipeline, plus ~1.5 hours for the extra seeds)
3. **Cell 3** (Ablations) -- Ablation study + Brownian gyrator (~6 hours, optional)
4. **Cell 4** (View) -- Display all figures
5. **Cell 5** (Download) -- Zip all results

### What changed (v1.5.3):
- **Lower regularization** -- beta_orth 0.001 (was 0.005), gamma_reg 1e-6 (was 1e-5)
- **More univariate modes** -- n_modes 8 (was 5) for narrower entropy gap
- **Multiasset robustness** -- full 7-test statistical battery on multiasset model too
- **Mode-aware results** -- robustness saves per-mode (statistical_tests.json / _multiasset.json)

### Previous (v1.5.2):
- 7 statistical tests (added time-reversal asymmetry)
- Cohen's d effect sizes on permutation test
- 1000 permutations, BIC model selection for HMM
- LR 3e-4, 800 epochs, 15 multiasset modes

### Notes:
- Set runtime to **GPU (T4)**: Runtime -> Change runtime type -> T4 GPU
- Cell 2 and Cell 3 are **independent** -- if Colab disconnects during Cell 3, Cell 2 outputs are preserved
- Cell 3 is **resume-safe** -- ablations save after each variant, so re-running picks up where it left off


In [None]:
#@title 1. Setup (install + clone + verify) - ~2 min

# Install missing dependencies (torch/numpy/pandas/scipy/sklearn/matplotlib are pre-installed)
!pip install -q yfinance>=1.0.0 hmmlearn>=0.3.0 statsmodels>=0.14.0 arch>=6.0.0 pyyaml>=6.0

# Clone repo
import os, sys
REPO_URL = "https://github.com/keshavkrishnan08/kind_finance.git"
REPO_DIR = "/content/ktnd_finance"

if os.path.exists(REPO_DIR):
    !cd {REPO_DIR} && git pull
else:
    !git clone {REPO_URL} {REPO_DIR}

os.chdir(REPO_DIR)
sys.path.insert(0, REPO_DIR)

# Verify
import torch, numpy as np
from src.model.vampnet import NonEquilibriumVAMPNet
print(f"Python {sys.version.split()[0]} | PyTorch {torch.__version__} | "
      f"CUDA: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
print("Setup complete.")

In [None]:
#@title 2. Run pipeline + multi-seed + figures (~2-3 hours)

import subprocess, time, json, os, sys, glob

# --------------------------------------------------------------------------
# Locations: every output directory hangs off the cloned repository
# --------------------------------------------------------------------------
REPO_DIR = "/content/ktnd_finance"
OUTPUT_DIR = f"{REPO_DIR}/outputs"
RESULTS_DIR = f"{OUTPUT_DIR}/results"
MODELS_DIR = f"{OUTPUT_DIR}/models"
FIGURES_DIR = f"{OUTPUT_DIR}/figures"
DATA_DIR = f"{REPO_DIR}/data"

# Create the output tree up front so later stages can write into it.
for out_dir in (OUTPUT_DIR, RESULTS_DIR, MODELS_DIR, FIGURES_DIR):
    os.makedirs(out_dir, exist_ok=True)

# Stage commands are launched with the interpreter that runs this notebook.
os.chdir(REPO_DIR)
python = sys.executable

# Quick sanity report of the environment and checkout layout.
print(f"Python: {python}")
print(f"CWD: {os.getcwd()}")
print(f"Repo dir exists: {os.path.exists(REPO_DIR)}")
for sub in ("src", "experiments", "config"):
    print(f"{sub}/ exists: {os.path.isdir(os.path.join(REPO_DIR, sub))}")

def run(name, cmd, check_files=None):
    """Execute one pipeline stage and report on it.

    The command runs through the shell from REPO_DIR with its output
    captured; stdout is echoed once the command has finished.

    Args:
        name: Label shown in the banner and status lines.
        cmd: Shell command line to execute.
        check_files: Optional paths that must exist after a zero exit code.

    Returns:
        True when the command exits 0 and every path in ``check_files``
        exists; False otherwise.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print(f"  STAGE: {name}")
    print(f"  CMD: {cmd}")
    print(f"{rule}")

    started = time.time()
    proc = subprocess.run(cmd, shell=True, cwd=REPO_DIR,
                          capture_output=True, text=True)
    minutes = (time.time() - started) / 60

    def _echo(text, prefix, tail=None):
        # Print captured text line by line, optionally only the last `tail` lines.
        rows = text.strip().split('\n')
        if tail is not None:
            rows = rows[-tail:]
        for row in rows:
            print(f"{prefix}{row}")

    if proc.stdout:
        _echo(proc.stdout, "  ")

    # Non-zero exit: dump all of stderr and give up on this stage.
    if proc.returncode != 0:
        print(f"\n  === STDERR ===")
        if proc.stderr:
            _echo(proc.stderr, "  ! ")
        print(f"  >> {name}: FAILED (exit code {proc.returncode}, {minutes:.1f} min)")
        return False

    # Zero exit: surface only stderr lines that look like errors.
    if proc.stderr:
        markers = ('Error', 'Exception', 'Traceback')
        suspicious = [row for row in proc.stderr.strip().split('\n')
                      if any(marker in row for marker in markers)]
        if suspicious:
            print(f"  === STDERR (errors) ===")
            for row in suspicious:
                print(f"  ! {row}")

    expected = check_files if check_files else []
    absent = [path for path in expected if not os.path.exists(path)]
    if absent:
        print(f"  WARNING: Missing expected output files:")
        for path in absent:
            print(f"    MISSING: {path}")
        if proc.stderr:
            _echo(proc.stderr, "  ! ", tail=30)
        print(f"  >> {name}: INCOMPLETE ({minutes:.1f} min)")
        return False

    for path in expected:
        print(f"  OK: {os.path.basename(path)} ({os.path.getsize(path):,} bytes)")

    print(f"  >> {name}: OK ({minutes:.1f} min)")
    return True

# Wall-clock start of the pipeline; `results` maps a stage key to the bool
# returned by run() for that stage.
pipeline_start = time.time()
results = {}

# ======================================================================
# PART A: FULL PIPELINE (~2-3 hours)
# Every stage is attempted even if an earlier one failed; failures are only
# recorded in `results`.
# ======================================================================
print(f"\n{'#'*70}")
print(f"#  PART A: FULL PIPELINE")
print(f"#  Order: tests -> download -> train(uni+multi) -> baselines ->")
print(f"#         rolling -> robustness(needs rolling output) -> figures")
print(f"{'#'*70}")

# --- Stage 1: Quick tests ---
# -k "not test_synthetic" deselects the synthetic tests here.
results['tests'] = run('Quick tests',
    f'{python} -m pytest tests/ -q --tb=short -k "not test_synthetic"')

# --- Stage 2: Download data ---
results['download'] = run('Download data',
    f'{python} {REPO_DIR}/data/download.py --mode all',
    check_files=[f'{DATA_DIR}/prices.csv', f'{DATA_DIR}/vix.csv'])

# --- Stage 3: Train univariate ---
# NOTE: --config default.yaml ensures all loss weights/training params are loaded.
# Mode-specific overrides (n_modes=8 as of v1.5.3, hidden_dims, batch_size) merge from univariate.yaml.
results['train_uni'] = run('Train univariate (SPY)',
    f'{python} {REPO_DIR}/experiments/run_main.py'
    f' --config config/default.yaml --mode univariate --seed 42'
    f' --output-dir {OUTPUT_DIR}',
    check_files=[
        f'{RESULTS_DIR}/analysis_results.json',
        f'{RESULTS_DIR}/eigenvalues.csv',
        f'{RESULTS_DIR}/entropy_decomposition.csv',
        f'{RESULTS_DIR}/irreversibility_field.npy',
        f'{MODELS_DIR}/vampnet_univariate.pt',
    ])

# --- Stage 4: Train multiasset ---
results['train_multi'] = run('Train multiasset (11 ETFs)',
    f'{python} {REPO_DIR}/experiments/run_main.py'
    f' --config config/default.yaml --mode multiasset --seed 42'
    f' --output-dir {OUTPUT_DIR}',
    check_files=[f'{RESULTS_DIR}/analysis_results_multiasset.json',
                 f'{MODELS_DIR}/vampnet_multiasset.pt'])

# --- Stage 5: Baselines ---
results['baselines'] = run('Baselines',
    f'{python} {REPO_DIR}/experiments/run_baselines.py'
    f' --config config/default.yaml --output-dir {RESULTS_DIR}',
    check_files=[f'{RESULTS_DIR}/baseline_comparison.csv'])

# --- Stage 6: Rolling (BEFORE robustness -- Granger needs spectral_gap_timeseries.csv) ---
results['rolling'] = run('Rolling spectral analysis',
    f'{python} {REPO_DIR}/experiments/run_rolling.py'
    f' --config config/default.yaml --mode univariate'
    f' --checkpoint {MODELS_DIR}/vampnet_univariate.pt'
    f' --output-dir {RESULTS_DIR}',
    check_files=[f'{RESULTS_DIR}/spectral_gap_timeseries.csv'])

# --- Stage 7: Robustness (AFTER rolling -- Granger needs spectral_gap_timeseries.csv) ---
results['robustness'] = run('Robustness tests',
    f'{python} {REPO_DIR}/experiments/run_robustness.py'
    f' --config config/default.yaml --mode univariate'
    f' --checkpoint {MODELS_DIR}/vampnet_univariate.pt'
    f' --output-dir {RESULTS_DIR}',
    check_files=[f'{RESULTS_DIR}/statistical_tests.json'])

# --- Stage 7b: Robustness multiasset (stronger non-eq signal, 12/15 complex modes) ---
results['robustness_multi'] = run('Robustness tests (multiasset)',
    f'{python} {REPO_DIR}/experiments/run_robustness.py'
    f' --config config/default.yaml --mode multiasset'
    f' --checkpoint {MODELS_DIR}/vampnet_multiasset.pt'
    f' --output-dir {RESULTS_DIR}',
    check_files=[f'{RESULTS_DIR}/statistical_tests_multiasset.json'])

# --- Stage 8: Figures (script) ---
# No check_files here: success is judged by the exit code alone.
results['figures'] = run('Generate figures (script)',
    f'{python} {REPO_DIR}/experiments/run_figures.py'
    f' --results-dir {RESULTS_DIR} --figures-dir {FIGURES_DIR}')

# --- Stage 9: Figures (inline fallback) ---
# Best-effort re-rendering of the key figures straight from the result files.
# Each figure is drawn by its own helper and guarded separately, so a single
# malformed result file cannot abort the rest of this (multi-hour) cell.
print(f"\n{'='*70}")
print(f"  GENERATING FIGURES INLINE (FALLBACK)")
print(f"{'='*70}")

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

n_figs = 0


def _save_fig(fig, filename):
    """Write *fig* into FIGURES_DIR at 300 dpi, close it and count it."""
    global n_figs
    fig.savefig(f"{FIGURES_DIR}/{filename}", dpi=300, bbox_inches="tight")
    plt.close(fig)
    n_figs += 1


def _plot_eigenvalue_spectra():
    """Scatter the eigenvalues in the complex plane, one figure per mode."""
    for mode_tag, label in [("univariate", "Univariate (SPY)"), ("multiasset", "Multiasset")]:
        ap = f"{RESULTS_DIR}/analysis_results_{mode_tag}.json"
        if not os.path.exists(ap) and mode_tag == "univariate":
            # NOTE(review): Stage 3 verifies the unsuffixed analysis_results.json
            # for the univariate run; fall back to it in case run_main.py does not
            # also write the suffixed name -- confirm against run_main.py.
            ap = f"{RESULTS_DIR}/analysis_results.json"
        if not os.path.exists(ap):
            continue
        with open(ap) as f:
            ar = json.load(f)

        er, ei = ar.get("eigenvalues_real"), ar.get("eigenvalues_imag")
        if not (er and ei):
            continue
        er, ei = np.array(er), np.array(ei)
        mags = np.sqrt(er**2 + ei**2)
        fig, ax = plt.subplots(figsize=(7,7))
        th = np.linspace(0, 2*np.pi, 300)
        ax.plot(np.cos(th), np.sin(th), "k--", lw=0.8, alpha=0.5)  # unit circle
        sc = ax.scatter(er, ei, c=mags, cmap="viridis", edgecolors="k", linewidths=0.4, s=80, zorder=3)
        plt.colorbar(sc, ax=ax, label="|$\\lambda$|")
        # Label the five largest-magnitude eigenvalues.
        for i, idx in enumerate(np.argsort(-mags)[:5]):
            ax.annotate(f"$\\lambda_{i}$", (er[idx], ei[idx]), textcoords="offset points", xytext=(8,8), fontsize=9)
        ax.set_xlabel("Re($\\lambda$)"); ax.set_ylabel("Im($\\lambda$)")
        ax.set_title(f"Koopman Eigenvalue Spectrum -- {label}"); ax.set_aspect("equal"); ax.grid(True, alpha=0.3)
        _save_fig(fig, f"fig1_eigenvalue_spectrum_{mode_tag}.png")


def _plot_mode_bars():
    """Bar charts of per-mode eigenvalue magnitude and entropy production."""
    for csv_name, title, ycol in [
        ("eigenvalues.csv", "Eigenvalue Magnitudes", "magnitude"),
        ("entropy_decomposition.csv", "Entropy Decomposition", "entropy_production"),
    ]:
        p = f"{RESULTS_DIR}/{csv_name}"
        if not os.path.exists(p):
            continue
        df = pd.read_csv(p)
        if ycol not in df.columns:
            continue
        # Fall back to the row position when the CSV has no "mode" column.
        modes = df["mode"] if "mode" in df.columns else range(len(df))
        fig, ax = plt.subplots(figsize=(8,5))
        ax.bar(modes, df[ycol], color="coral" if "entropy" in csv_name else "steelblue", edgecolor="black", lw=0.3)
        ax.set_xlabel("Mode"); ax.set_ylabel(ycol); ax.set_title(title); ax.grid(True, alpha=0.3, axis="y")
        _save_fig(fig, f"fig_{csv_name.replace('.csv','')}.png")


def _plot_irreversibility_field():
    """Plot the saved irreversibility field against its sample index."""
    irp = f"{RESULTS_DIR}/irreversibility_field.npy"
    if not os.path.exists(irp):
        return
    # allow_pickle is kept from the original loader; the file is written by
    # this pipeline's own training stage, not supplied by a third party.
    ir = np.load(irp, allow_pickle=True)
    fig, ax = plt.subplots(figsize=(14,4))
    ax.fill_between(range(len(ir)), ir, alpha=0.4, color="darkorange"); ax.plot(ir, lw=0.5, color="darkorange")
    ax.set_xlabel("Time"); ax.set_ylabel("$I(x)$"); ax.set_title("Irreversibility Field"); ax.grid(True, alpha=0.3)
    _save_fig(fig, "fig_irreversibility_field.png")


def _plot_spectral_gap():
    """Plot the rolling spectral gap, dated when center_date is available."""
    rcp = f"{RESULTS_DIR}/spectral_gap_timeseries.csv"
    if not os.path.exists(rcp):
        return
    rdf = pd.read_csv(rcp)
    if "spectral_gap" not in rdf.columns:
        return
    fig, ax = plt.subplots(figsize=(14,5))
    x = pd.to_datetime(rdf["center_date"]) if "center_date" in rdf.columns else range(len(rdf))
    ax.plot(x, rdf["spectral_gap"], color="steelblue", lw=1.0)
    ax.set_xlabel("Date"); ax.set_ylabel("Spectral Gap"); ax.set_title("Rolling Spectral Gap"); ax.grid(True, alpha=0.3)
    _save_fig(fig, "fig_spectral_gap.png")


def _plot_baseline_comparison():
    """Grouped bars of the NBER metrics present in the baseline table."""
    bcp = f"{RESULTS_DIR}/baseline_comparison.csv"
    if not os.path.exists(bcp):
        return
    bdf = pd.read_csv(bcp)
    ms = [m for m in ["nber_accuracy","nber_f1","nber_precision","nber_recall"] if m in bdf.columns]
    if not (ms and "method" in bdf.columns):
        return
    fig, ax = plt.subplots(figsize=(10,6))
    x = np.arange(len(bdf)); w = 0.8/len(ms)
    for i, m in enumerate(ms):
        ax.bar(x+i*w, bdf[m].astype(float), w, label=m.replace("nber_","").title(),
               color=["steelblue","coral","seagreen","orchid"][i%4], edgecolor="black", lw=0.3)
    ax.set_xticks(x+w*(len(ms)-1)/2); ax.set_xticklabels(bdf["method"], rotation=15, ha="right")
    ax.set_ylabel("Score"); ax.set_title("Baseline Comparison"); ax.legend(); ax.set_ylim(0,1.05)
    _save_fig(fig, "fig_baseline_comparison.png")


for _plotter in (_plot_eigenvalue_spectra, _plot_mode_bars, _plot_irreversibility_field,
                 _plot_spectral_gap, _plot_baseline_comparison):
    try:
        _plotter()
    except Exception as exc:
        # Fallback figures are optional: report the problem, drop any
        # half-built figure and move on to the next plot.
        plt.close("all")
        print(f"  WARNING: {_plotter.__name__} failed: {type(exc).__name__}: {exc}")

print(f"  Generated {n_figs} figures inline")

# ======================================================================
# PART A2: MULTI-SEED ERROR BARS (5 seeds total, ~1.5 hours)
# PRE requires error bars on main results. Run 4 additional seeds
# for training + analysis only (baselines/rolling/figures use seed 42).
# ======================================================================
N_MAIN_SEEDS = 5
EXTRA_SEEDS = [0, 1, 2, 3]  # + seed 42 from above = 5 total

print(f"\n{'#'*70}")
print(f"#  PART A2: MULTI-SEED ERROR BARS ({N_MAIN_SEEDS} seeds)")
print(f"#  Running {len(EXTRA_SEEDS)} additional seeds + seed 42 from above")
print(f"#  Only re-runs training + analysis (not baselines/rolling/figures)")
print(f"{'#'*70}")


def _load_analysis(results_dir, mode_tag):
    """Return the parsed analysis JSON for *mode_tag* in *results_dir*, or None.

    A file that exists but cannot be parsed (e.g. truncated by an interrupted
    run) is reported and treated as missing, so resuming re-runs that seed
    instead of crashing.
    """
    candidates = [f"{results_dir}/analysis_results_{mode_tag}.json"]
    if mode_tag == "univariate":
        # NOTE(review): Stage 3 verifies the unsuffixed analysis_results.json for
        # the univariate run; accept it too in case run_main.py does not also
        # write the suffixed name -- confirm against run_main.py.
        candidates.append(f"{results_dir}/analysis_results.json")
    for path in candidates:
        if not os.path.exists(path):
            continue
        try:
            with open(path) as f:
                return json.load(f)
        except json.JSONDecodeError as exc:
            print(f"  WARNING: ignoring unreadable {path}: {exc}")
    return None


multi_seed_results = {}

# Collect seed=42 results from the primary run
for mode_tag in ["univariate", "multiasset"]:
    primary = _load_analysis(RESULTS_DIR, mode_tag)
    if primary is not None:
        multi_seed_results.setdefault(mode_tag, {})[42] = primary

# Run extra seeds
for seed in EXTRA_SEEDS:
    seed_dir = f"{OUTPUT_DIR}/seed_{seed}"
    seed_results = f"{seed_dir}/results"
    seed_models = f"{seed_dir}/models"
    os.makedirs(seed_results, exist_ok=True)
    os.makedirs(seed_models, exist_ok=True)

    for mode_tag in ["univariate", "multiasset"]:
        # Check if already completed (resume-safe)
        cached = _load_analysis(seed_results, mode_tag)
        if cached is not None:
            print(f"\n  --- Seed {seed}, mode={mode_tag}: ALREADY DONE (resuming) ---")
            multi_seed_results.setdefault(mode_tag, {})[seed] = cached
            continue

        print(f"\n  --- Seed {seed}, mode={mode_tag} ---", flush=True)
        ok = run(f'Seed {seed} {mode_tag}',
            f'{python} {REPO_DIR}/experiments/run_main.py'
            f' --config config/default.yaml --mode {mode_tag} --seed {seed}'
            f' --output-dir {seed_dir}')

        fresh = _load_analysis(seed_results, mode_tag)
        if fresh is not None:
            if not ok:
                # The stage failed after writing its analysis file; keep the
                # numbers (as before) but make the failure visible.
                print(f"  WARNING: seed {seed} {mode_tag} reported failure; using the results it wrote")
            multi_seed_results.setdefault(mode_tag, {})[seed] = fresh
        else:
            print(f"  WARNING: No results for seed {seed} {mode_tag}")

# Aggregate and report
METRICS = [
    'vamp2_score', 'spectral_gap', 'entropy_empirical', 'entropy_total',
    'mean_irreversibility', 'detailed_balance_violation',
    'fluctuation_theorem_ratio', 'n_complex_modes', 'complex_fraction',
    'ktnd_nber_accuracy', 'ktnd_nber_f1',
]

multi_seed_summary = {}
for mode_tag in ["univariate", "multiasset"]:
    if mode_tag not in multi_seed_results:
        continue
    seed_data = multi_seed_results[mode_tag]
    seeds_present = sorted(seed_data.keys())
    print(f"\n  === {mode_tag.title()}: {len(seeds_present)} seeds ({seeds_present}) ===")

    summary = {'n_seeds': len(seeds_present), 'seeds': seeds_present}
    for metric in METRICS:
        # Seeds that did not report this metric are simply left out of its mean.
        vals = [seed_data[s].get(metric) for s in seeds_present
                if seed_data[s].get(metric) is not None]
        if vals:
            vals = [float(v) for v in vals]
            mean_val = np.mean(vals)
            # Sample std (ddof=1); a single value has no spread to estimate.
            std_val = np.std(vals, ddof=1) if len(vals) > 1 else 0.0
            summary[f'{metric}_mean'] = float(mean_val)
            summary[f'{metric}_std'] = float(std_val)
            print(f"    {metric:35s}  {mean_val:.4f} +/- {std_val:.4f}  (n={len(vals)})")

    multi_seed_summary[mode_tag] = summary

# Save aggregated results
ms_path = f"{RESULTS_DIR}/multi_seed_summary.json"
with open(ms_path, 'w') as f:
    json.dump(multi_seed_summary, f, indent=2, default=str)
print(f"\n  Saved: {ms_path}")

# The stage only counts as passed when every mode has all of its seeds; the
# summary file is written either way, so its existence alone proves nothing.
expected_seeds = len(EXTRA_SEEDS) + 1  # extra seeds + seed 42
for mode_tag in ["univariate", "multiasset"]:
    n_have = len(multi_seed_results.get(mode_tag, {}))
    if n_have != expected_seeds:
        print(f"  WARNING: {mode_tag} has {n_have}/{expected_seeds} seeds")
results['multi_seed'] = os.path.exists(ms_path) and all(
    len(multi_seed_results.get(mode_tag, {})) == expected_seeds
    for mode_tag in ["univariate", "multiasset"])


# ======================================================================
# FINAL REPORT
# ======================================================================

def _fmt(value, spec):
    """Format *value* with *spec*, degrading gracefully for the report.

    Returns 'N/A' for a missing value (None) and ``str(value)`` when the
    value cannot take the numeric format, instead of raising and cutting
    the report short.
    """
    if value is None:
        return 'N/A'
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


print(f"\n{'='*70}")
print(f"  ALL OUTPUT FILES")
print(f"{'='*70}")
for dirpath, dirnames, filenames in os.walk(OUTPUT_DIR):
    for f in sorted(filenames):
        fp = os.path.join(dirpath, f)
        sz = os.path.getsize(fp)
        rel = os.path.relpath(fp, OUTPUT_DIR)
        print(f"  {sz:>10,} bytes  {rel}")

total_min = (time.time() - pipeline_start) / 60
n_ok = sum(1 for v in results.values() if v is True)
n_total = len(results)

print(f"\n{'='*70}")
print(f"  COMPLETE: {n_ok}/{n_total} stages passed ({total_min:.1f} min total)")
print(f"  Version: v1.5.3")
print(f"{'='*70}")
for name, ok in results.items():
    print(f"  {'OK' if ok else 'FAIL':6s}  {name}")

# Print single-seed results for BOTH modes
for mode_tag, label in [("univariate", "Univariate (SPY)"), ("multiasset", "Multiasset (11 ETFs)")]:
    ap = f"{RESULTS_DIR}/analysis_results_{mode_tag}.json"
    if not os.path.exists(ap) and mode_tag == "univariate":
        # NOTE(review): Stage 3 verifies the unsuffixed analysis_results.json for
        # the univariate run; fall back to it in case run_main.py does not also
        # write the suffixed name -- confirm against run_main.py.
        ap = f"{RESULTS_DIR}/analysis_results.json"
    if not os.path.exists(ap):
        continue
    with open(ap) as f:
        r = json.load(f)
    print(f"\n  === {label} (seed 42) ===")
    print(f"    VAMP-2 score:         {r.get('vamp2_score', 'N/A')}")
    print(f"    Spectral gap:         {r.get('spectral_gap', 'N/A')}")
    print(f"    Entropy (empirical):  {r.get('entropy_empirical', 'N/A')} "
          f"[{r.get('entropy_ci_lower', '?')}, {r.get('entropy_ci_upper', '?')}] 95% CI")
    print(f"    Spectral entropy:     {r.get('entropy_total', 'N/A')}")
    print(f"    Mean irreversibility: {r.get('mean_irreversibility', 'N/A')}")
    print(f"    Irrev method:         {r.get('irrev_method', 'N/A')}")
    print(f"    DB violation:         {r.get('detailed_balance_violation', 'N/A')}")
    print(f"    Complex modes:        {r.get('n_complex_modes', 'N/A')}/{r.get('n_modes', 'N/A')}")
    print(f"    FT ratio:             {r.get('fluctuation_theorem_ratio', 'N/A')}")

    ktnd_acc = r.get('ktnd_nber_accuracy')
    if ktnd_acc is not None:
        # The companion keys may be absent even when the accuracy is present;
        # _fmt prints 'N/A' for them rather than applying a float format to a string.
        print(f"    KTND NBER accuracy:   {_fmt(ktnd_acc, '.3f')}")
        print(f"    KTND NBER F1:         {_fmt(r.get('ktnd_nber_f1'), '.3f')}")
        print(f"    KTND naive accuracy:  {_fmt(r.get('ktnd_naive_accuracy'), '.3f')}")
        print(f"    Mean regime duration: {_fmt(r.get('ktnd_mean_regime_duration'), '.1f')} days")

# Print multi-seed aggregated results
ms_path = f"{RESULTS_DIR}/multi_seed_summary.json"
if os.path.exists(ms_path):
    with open(ms_path) as f:
        ms = json.load(f)
    print(f"\n  === MULTI-SEED SUMMARY (mean +/- std) ===")
    for mode_tag in ["univariate", "multiasset"]:
        if mode_tag not in ms:
            continue
        s = ms[mode_tag]
        label = "Univariate (SPY)" if mode_tag == "univariate" else "Multiasset (11 ETFs)"
        print(f"\n  {label} ({s.get('n_seeds', '?')} seeds):")
        for metric in METRICS:
            mk, sk = f'{metric}_mean', f'{metric}_std'
            if mk in s:
                print(f"    {metric:35s}  {_fmt(s[mk], '.4f')} +/- {_fmt(s.get(sk), '.4f')}")

# Statistical tests: the univariate file plus the multiasset file that
# Stage 7b writes.
for stat_name, stat_title in [
    ("statistical_tests.json", "Statistical Tests"),
    ("statistical_tests_multiasset.json", "Statistical Tests (multiasset)"),
]:
    stat_path = f"{RESULTS_DIR}/{stat_name}"
    if not os.path.exists(stat_path):
        continue
    with open(stat_path) as f:
        st = json.load(f)
    print(f"\n  === {stat_title} ===")
    for k, v in st.items():
        if isinstance(v, dict):
            if v.get('skipped'):
                print(f"    {k}: SKIPPED ({v.get('reason', '')})")
            else:
                # Either spelling of the p-value key is accepted.
                pval = v.get('p_value', v.get('pvalue', None))
                if pval is not None:
                    print(f"    {k}: p={_fmt(pval, '.4f')}")
        elif isinstance(v, (int, float)):
            print(f"    {k}: {v}")

print(f"\n{'='*70}")
print(f"  DONE. Total wall time: {total_min:.1f} min")
print(f"  Next: Cell 3 (ablations, optional) -> Cell 4 (view figures) -> Cell 5 (download)")
print(f"{'='*70}")


In [None]:
#@title 3. Ablations + Brownian gyrator (~6 hours, resume-safe, optional)

import subprocess, time, json, os, sys
import numpy as np
import pandas as pd

# Paths are rebuilt here so this cell does not depend on Cell 2's variables.
REPO_DIR = "/content/ktnd_finance"
OUTPUT_DIR = f"{REPO_DIR}/outputs"
RESULTS_DIR = f"{OUTPUT_DIR}/results"
# Number of seeds handed to the ablation script via --n-seeds.
N_ABLATION_SEEDS = 10

# Stage commands use the notebook's own interpreter and run from the repo root.
python = sys.executable
os.chdir(REPO_DIR)

def run_streaming(name, cmd, check_files=None):
    """Run a stage with LIVE output streaming.

    stderr is merged into stdout and every line is echoed as soon as the
    child produces it.

    Args:
        name: Label shown in the banner and status lines.
        cmd: Shell command line, executed with ``shell=True`` from REPO_DIR.
        check_files: Optional paths that must exist after a zero exit code.

    Returns:
        True when the command exits 0 and every path in ``check_files``
        exists; False otherwise.
    """
    print(f"\n{'='*70}")
    print(f"  STAGE: {name}")
    print(f"  CMD: {cmd}")
    print(f"{'='*70}", flush=True)
    t0 = time.time()
    # The context manager closes the pipe and reaps the child on every exit
    # path. errors="replace" keeps one undecodable byte in the child's output
    # from aborting a multi-hour run.
    with subprocess.Popen(
        cmd, shell=True, cwd=REPO_DIR,
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        text=True, bufsize=1, errors="replace",
    ) as proc:
        try:
            for line in proc.stdout:
                print(f"  {line}", end="", flush=True)
        except BaseException:
            # Interrupted (e.g. the notebook stop button): do not leave the
            # stage running in the background. With shell=True this signals
            # the shell process that was launched for the command.
            proc.kill()
            raise
    elapsed = time.time() - t0
    if proc.returncode != 0:
        print(f"  >> {name}: FAILED (exit code {proc.returncode}, {elapsed/60:.1f} min)")
        return False
    if check_files:
        missing = [f for f in check_files if not os.path.exists(f)]
        if missing:
            for f in missing:
                print(f"    MISSING: {f}")
            print(f"  >> {name}: INCOMPLETE ({elapsed/60:.1f} min)")
            return False
        for f in check_files:
            sz = os.path.getsize(f)
            print(f"  OK: {os.path.basename(f)} ({sz:,} bytes)")
    print(f"  >> {name}: OK ({elapsed/60:.1f} min)")
    return True

def run(name, cmd, check_files=None):
    """Run a stage, print output, verify files.

    The command runs through the shell from REPO_DIR with its output
    captured; stdout is echoed once the command has finished.

    Args:
        name: Label shown in the banner and status lines.
        cmd: Shell command line to execute.
        check_files: Optional paths that must exist after a zero exit code.

    Returns:
        True when the command exits 0 and every path in ``check_files``
        exists; False otherwise.
    """
    print(f"\n{'='*70}")
    print(f"  STAGE: {name}")
    print(f"  CMD: {cmd}")
    print(f"{'='*70}")
    t0 = time.time()
    result = subprocess.run(cmd, shell=True, cwd=REPO_DIR, capture_output=True, text=True)
    elapsed = time.time() - t0
    if result.stdout:
        for line in result.stdout.strip().split('\n'):
            print(f"  {line}")
    if result.returncode != 0:
        if result.stderr:
            # Only the tail of stderr is shown.
            for line in result.stderr.strip().split('\n')[-20:]:
                print(f"  ! {line}")
        # Include the exit code, as run_streaming does.
        print(f"  >> {name}: FAILED (exit code {result.returncode}, {elapsed/60:.1f} min)")
        return False
    if check_files:
        missing = [f for f in check_files if not os.path.exists(f)]
        if missing:
            for f in missing:
                print(f"    MISSING: {f}")
            # Always finish with a verdict line, so a False return is never silent.
            print(f"  >> {name}: INCOMPLETE ({elapsed/60:.1f} min)")
            return False
    print(f"  >> {name}: OK ({elapsed/60:.1f} min)")
    return True

# Timer and per-stage status table for this cell.
t_start = time.time()
results = {}

# ----------------------------------------------------------------------
# PART 1: ABLATION STUDY (10 seeds x ~32 variants, ~6 hours)
# Safe to re-run: the header below states that completed variants are skipped.
# ----------------------------------------------------------------------
hash_rule = '#' * 70
print(f"\n{hash_rule}")
print(f"#  ABLATION STUDY ({N_ABLATION_SEEDS} seeds x ~32 variants)")
print(f"#  Resume-safe: skips completed variants")
print(f"{hash_rule}")

summary_path = f"{RESULTS_DIR}/ablation_summary.csv"
ablation_cmd = (
    f'{python} -u experiments/run_ablations.py --config config/default.yaml'
    f' --n-seeds {N_ABLATION_SEEDS} --n-jobs 1'
    f' --output-dir {RESULTS_DIR}'
)
results['ablations'] = run_streaming(
    f'Ablations ({N_ABLATION_SEEDS} seeds)',
    ablation_cmd,
    check_files=[f'{RESULTS_DIR}/ablation_summary.csv'],
)

# Summarise whatever summary table exists, even after a failed or partial run.
if os.path.exists(summary_path):
    abl_df = pd.read_csv(summary_path)
    print(f"\n  {len(abl_df)} ablation variants ({N_ABLATION_SEEDS} seeds each):")
    wanted = ['name', 'n_valid', 'vamp2_mean', 'vamp2_std',
              'spectral_gap_mean', 'spectral_gap_std',
              'entropy_total_mean', 'entropy_total_std']
    shown = [col for col in wanted if col in abl_df.columns]
    print(abl_df[shown].to_string(index=False))

    if 'vamp2_mean' in abl_df.columns:
        baseline_rows = abl_df.loc[abl_df['name'] == 'baseline']
        if len(baseline_rows) > 0:
            bl_vamp2 = baseline_rows['vamp2_mean'].iloc[0]
            print(f"\n  Baseline VAMP-2: {bl_vamp2:.4f}")
            # Relative VAMP-2 change of each variant versus the baseline, in percent.
            with_delta = abl_df.assign(
                vamp2_delta=(abl_df['vamp2_mean'] - bl_vamp2) / abs(bl_vamp2) * 100)
            big_moves = with_delta.loc[with_delta['vamp2_delta'].abs() > 5]
            big_moves = big_moves.sort_values('vamp2_delta')
            if len(big_moves) > 0:
                print(f"\n  Variants with >5% VAMP-2 change from baseline:")
                for variant, delta in zip(big_moves['name'], big_moves['vamp2_delta']):
                    print(f"    {variant:40s}  {delta:+.1f}%")

# ----------------------------------------------------------------------
# PART 2: BROWNIAN GYRATOR BENCHMARK (~5 min)
# ----------------------------------------------------------------------
print(f"\n{hash_rule}")
print(f"#  BROWNIAN GYRATOR -- analytical EP benchmark")
print(f"#  2D coupled OU, T1!=T2 breaks detailed balance")
print(f"{hash_rule}")

from scipy.linalg import solve_continuous_lyapunov

def analytical_ep(T1, T2, k=1.0, kappa=0.5):
    """Return the closed-form entropy-production rate used as the benchmark.

    Builds the 2x2 drift matrix ``[[k, -kappa], [-kappa, k]]`` and the
    diagonal matrix ``diag(T1, T2)``, solves the continuous Lyapunov equation
    for the covariance, and evaluates ``trace(Q @ Sigma @ Q.T @ D^-1)`` with
    ``Q = A - D @ inv(Sigma)``.

    Args:
        T1: First diagonal entry of D (must be non-zero).
        T2: Second diagonal entry of D (must be non-zero).
        k: Diagonal entry of the drift matrix.
        kappa: Coupling; the off-diagonal drift entries are ``-kappa``.

    Returns:
        The trace as a NumPy scalar.
    """
    drift = np.array([[k, -kappa], [-kappa, k]])
    diffusion = np.array([[T1, 0.0], [0.0, T2]])
    # Covariance from the Lyapunov solve with right-hand side 2 * diffusion.
    covariance = solve_continuous_lyapunov(drift, 2.0 * diffusion)
    precision = np.linalg.inv(covariance)
    q_matrix = drift - diffusion @ precision
    inv_diffusion = np.diag([1.0/T1, 1.0/T2])
    integrand = q_matrix @ covariance @ q_matrix.T @ inv_diffusion
    return np.trace(integrand)

# Tabulate the closed-form EP rate for a few T2 values at fixed T1 = 1.0.
print("\n  Analytical EP rates:")
for temp2 in (1.0, 1.5, 3.0, 5.0):
    rate = analytical_ep(1.0, temp2)
    note = '(equilibrium)' if temp2 == 1.0 else ''
    print(f"    T1=1.0, T2={temp2:.1f}:  EP = {rate:.6f}  {note}")

# Run the gyrator test class from the repository's synthetic test module.
gyrator_cmd = f'{python} -m pytest tests/test_synthetic.py::TestBrownianGyrator -v'
results['gyrator'] = run('Brownian gyrator tests', gyrator_cmd)

# ----------------------------------------------------------------------
# SUMMARY
# ----------------------------------------------------------------------
total_min = (time.time() - t_start) / 60
rule = '=' * 70
print(f"\n{rule}")
print(f"  ABLATIONS + GYRATOR COMPLETE ({total_min:.1f} min)")
print(f"{rule}")
for stage, passed in results.items():
    verdict = 'OK' if passed else 'FAIL'
    print(f"  {verdict:6s}  {stage}")
print(f"\n  Next: Cell 4 (view figures) -> Cell 5 (download)")
print(f"{rule}")


In [None]:
#@title 4. View figures (run after Cell 2 finishes)

import glob, os
from IPython.display import Image, display

FIGURES_DIR = "/content/ktnd_finance/outputs/figures"

# Main figures first, then anything under figures/supplemental (if present).
pngs = sorted(glob.glob(f"{FIGURES_DIR}/*.png"))
sup_dir = os.path.join(FIGURES_DIR, "supplemental")
if os.path.exists(sup_dir):
    pngs.extend(sorted(glob.glob(f"{sup_dir}/*.png")))

if not pngs:
    # Nothing to show: explain why and list what results exist.
    print("No figures found. Make sure Cell 2 has finished running first.")
    print(f"Checked: {FIGURES_DIR}")
    results_dir = "/content/ktnd_finance/outputs/results"
    if not os.path.exists(results_dir):
        print("No results directory found - Cell 2 needs to run first.")
    else:
        available = os.listdir(results_dir)
        print(f"Result files available ({len(available)}): {available}")
else:
    print(f"Found {len(pngs)} figures:\n")
    for png_path in pngs:
        print(f"--- {os.path.basename(png_path)} ---")
        display(Image(filename=png_path, width=800))
        print()


In [None]:
#@title 5. Download all results as zip

# Zip the whole outputs/ tree from the repo root (-r recurse, -q quiet), so
# paths inside the archive start with "outputs/".
# NOTE(review): if /content/ktnd_results.zip is left over from an earlier run,
# zip adds to it rather than recreating it -- delete it first for a clean archive.
!cd /content/ktnd_finance && zip -rq /content/ktnd_results.zip outputs/
# google.colab is only importable inside a Colab runtime.
from google.colab import files
files.download('/content/ktnd_results.zip')
print("Download started.")