# In [None]:  (notebook cell marker left over from the export)
#!/usr/bin/env python
"""
STAGE 07 · RISE-PROBABILITY EXPLORATION & LEADERBOARDS
───────────────────────────────────────────────────────
For every prediction file produced in Stage-06 this notebook writes

    Stage7_Rank_<flavour>_<YEAR>.csv
    Stage7_HistGrid_<flavour>_<YEAR>.png
    Stage7_Boxplot_<flavour>_<YEAR>.png

If **both** speed-domain and depth-domain are present it also writes

    Stage7_Speed_vs_Depth_<YEAR>.csv
    Stage7_SpeedDepth_DiffBar_<YEAR>.png
"""
from __future__ import annotations
from pathlib import Path
from typing  import Dict, List
import logging, os, warnings, math

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from pipeline_utils import load_cfg, resolve_run_dir
# Global defaults for this stage: silence RuntimeWarning, fix the figure
# resolution and use seaborn's "whitegrid" look for every plot below.
warnings.filterwarnings(action="ignore", category=RuntimeWarning)
plt.rcParams.update({"figure.dpi": 110})
sns.set_style(style="whitegrid")

# ══════════════════════ 0 · CONFIG / PATHS ════════════════════════
# Event keys are normalised to strings so they compare cleanly with the
# string value read from the SWAN_YEAR environment variable.
CFG = load_cfg()
EVENTS = {str(year): spec for year, spec in CFG["events"].items()}

# An explicit SWAN_YEAR wins; otherwise fall back to the first configured event.
SWAN_YEAR = str(os.getenv("SWAN_YEAR") or next(iter(EVENTS)))

# The Stage-06 speed-domain prediction file is passed as the `must_have`
# argument (presumably the marker a valid run directory has to contain).
RUN_DIR = resolve_run_dir(
    swan_year=SWAN_YEAR,
    run_tag=os.getenv("RUN_TAG"),
    must_have=f"stage06/Stage6_RISE_Predictions_{SWAN_YEAR}.csv",
)

# All artefacts of this stage go below <run dir>/stage07.
STAGE_DIR = RUN_DIR / "stage07"
STAGE_DIR.mkdir(exist_ok=True)

# Log to a file that is truncated on every run (mode "w") and to the console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-7s | %(message)s",
    handlers=[
        logging.FileHandler(STAGE_DIR / "stage07.log", "w", "utf-8"),
        logging.StreamHandler(),
    ],
)
log = logging.getLogger(__name__)
log.info("==========  STAGE 07 (%s) ==========", SWAN_YEAR)

# ══════════════════════ 1 · PREDICTION FLAVOURS ══════════════════
# flavour key → {"file":   path of the Stage-06 prediction CSV,
#                "suffix": suffix that marks the probability columns,
#                "title":  label used in the plot titles}
FLAV_INFO: Dict[str, Dict] = {
    # DOMAIN families
    "speed": {
        "file": RUN_DIR / f"stage06/Stage6_RISE_Predictions_{SWAN_YEAR}.csv",
        "suffix": "_RISE_prob",
        "title": "Speed – domain",
    },
    "depth": {
        "file": RUN_DIR / f"stage06/Stage6C_Depth_RISE_Predictions_{SWAN_YEAR}.csv",
        "suffix": "_DepthRISE_prob",
        "title": "Depth – domain",
    },
    "blend": {
        "file": RUN_DIR / f"stage06/Stage6E_Blend_RISE_Predictions_{SWAN_YEAR}.csv",
        "suffix": "_blendRISE_prob",
        "title": "Blend (domain)",
    },
    # STAGE families (kept separate so downstream authors can study them)
    "speedStage": {
        "file": RUN_DIR / f"stage06/Stage6B_Stage_RISE_Predictions_{SWAN_YEAR}.csv",
        "suffix": "_StageRISE_prob",
        "title": "Speed – stage",
    },
    "depthStage": {
        "file": RUN_DIR / f"stage06/Stage6D_DepthStage_RISE_Predictions_{SWAN_YEAR}.csv",
        "suffix": "_DepthStageRISE_prob",
        "title": "Depth – stage",
    },
    "blendStage": {
        "file": RUN_DIR / f"stage06/Stage6F_BlendStage_RISE_Predictions_{SWAN_YEAR}.csv",
        "suffix": "_blendStageRISE_prob",
        "title": "Blend (stage)",
    },
}

# Column names declared for the prediction tables (not referenced below).
ID_COL = "Symbol"
DATE_COL = "ReportDate"

# ══════════════════════ 2 · HELPERS ═══════════════════════════════
def make_leaderboard(df: pd.DataFrame, probs: List[str], suf: str) -> pd.DataFrame:
    """Rank probability columns by their mean value.

    Parameters
    ----------
    df
        Prediction table that contains the columns listed in ``probs``.
    probs
        Names of the probability columns to rank.
    suf
        Suffix carried by the probability columns; it is stripped from the
        end of each name to obtain the metric label.

    Returns
    -------
    pd.DataFrame
        Columns ``Metric`` and ``MeanProb``, sorted by ``MeanProb`` in
        descending order, with a fresh 0..n-1 index.
    """
    lb = (df[probs].mean()
            .rename("MeanProb")
            # Name the index explicitly so reset_index() always yields a
            # "Metric" column, even when df.columns itself carries a name.
            .rename_axis("Metric")
            .sort_values(ascending=False)
            .reset_index())

    def _strip_suffix(name: str) -> str:
        # Remove a single trailing occurrence only; an occurrence of the
        # suffix inside the metric name must survive.
        if suf and name.endswith(suf):
            return name[:-len(suf)]
        return name

    lb["Metric"] = lb["Metric"].map(_strip_suffix)
    return lb

def hist_grid(df: pd.DataFrame, cols: List[str], title: str, out: Path,
              suffix: str | None = None) -> None:
    """Save a grid of histograms, one panel per column, as an image.

    Parameters
    ----------
    df
        Table holding the columns to plot.
    cols
        Columns to draw; panels are laid out four per row.
    title
        Figure-level title.
    out
        Destination path of the image.
    suffix
        Optional column suffix to strip from the panel titles.  When it is
        omitted (or a column does not end with it) the last two
        ``_``-separated tokens are dropped instead, which matches the
        ``_<family>RISE_prob`` suffixes used by the prediction files.

    Raises
    ------
    ValueError
        If ``cols`` is empty (a grid needs at least one panel).
    """
    if not cols:
        raise ValueError("hist_grid needs at least one column to plot")

    def _panel_title(col: str) -> str:
        if suffix and col.endswith(suffix):
            return col[:-len(suffix)]
        # The previous implementation removed the last THREE tokens, which
        # blanked titles such as "Metric_RISE_prob" and left a dangling "_"
        # on longer names.
        return col.rsplit("_", 2)[0]

    n_per_row = 4
    n_rows = math.ceil(len(cols) / n_per_row)
    # squeeze=False always returns a 2-D array of axes, whatever the grid shape.
    fig, axs = plt.subplots(n_rows, n_per_row,
                            figsize=(4 * n_per_row, 3 * n_rows),
                            squeeze=False)
    panels = axs.ravel()

    for ax, col in zip(panels, cols):
        df[col].dropna().hist(ax=ax, bins=20, edgecolor="k")
        ax.set_xlim(0, 1)
        ax.set_title(_panel_title(col))
    # Hide the unused panels of the last row.
    for ax in panels[len(cols):]:
        ax.axis("off")

    fig.suptitle(title, y=1.02)
    fig.tight_layout()
    # The suptitle sits above the figure's top edge (y > 1); a tight bounding
    # box is required so it is not cropped from the saved image.
    fig.savefig(out, dpi=110, bbox_inches="tight")
    plt.close(fig)

def box_plot(df: pd.DataFrame, cols: List[str], title: str, out: Path) -> None:
    """Draw one box per column of ``cols`` and save the figure to ``out``.

    The figure grows wider with the number of columns (never narrower than
    6 inches) and the column labels are rotated by 45 degrees.
    """
    width = max(6, .6 * len(cols))
    fig, ax = plt.subplots(figsize=(width, 5))
    df[cols].boxplot(ax=ax, rot=45)
    ax.set_title(title)
    ax.set_ylabel("Predicted probability")
    fig.tight_layout()
    fig.savefig(out, dpi=110)
    plt.close(fig)

# ══════════════════════ 3 · MAIN LOOP ════════════════════════════
# Leaderboard of every flavour that was processed, keyed by flavour name.
leaderboards: Dict[str, pd.DataFrame] = {}

for flav, info in FLAV_INFO.items():
    fp = info["file"]
    # A flavour is skipped when its prediction file is absent or zero bytes long.
    if not (fp.exists() and fp.stat().st_size != 0):
        log.warning("%s missing/empty – %s flavour skipped", fp.name, flav)
        continue

    df = pd.read_csv(fp)
    suf = info["suffix"]
    # Probability columns are recognised purely by their suffix.
    cols = [name for name in df.columns if name.endswith(suf)]
    if len(cols) == 0:
        log.warning("No columns with suffix %s in %s", suf, fp.name)
        continue

    # Output locations for this flavour.
    rank_csv = STAGE_DIR / f"Stage7_Rank_{flav}_{SWAN_YEAR}.csv"
    hist_png = STAGE_DIR / f"Stage7_HistGrid_{flav}_{SWAN_YEAR}.png"
    box_png = STAGE_DIR / f"Stage7_Boxplot_{flav}_{SWAN_YEAR}.png"

    # 1. leaderboard
    lb = make_leaderboard(df, cols, suf)
    lb.to_csv(rank_csv, index=False)

    # 2. plots
    hist_grid(df, cols, f"{info['title']} – distribution by metric", hist_png)
    box_plot(df, cols, f"{info['title']} – box-plot", box_png)

    # Kept for the speed-vs-depth comparison further down.
    leaderboards[flav] = lb
    log.info("✓ outputs for %s flavour written", flav)

# ══════════════════════ 4 · SPEED vs DEPTH  (domain only) ═════════
# Only runs when both domain leaderboards were produced above.
if "speed" in leaderboards and "depth" in leaderboards:
    # Join on the metric name; metrics present in only one leaderboard are
    # dropped by the (default) inner merge.
    comp = leaderboards["speed"].merge(
        leaderboards["depth"],
        on="Metric",
        suffixes=("_Speed", "_Depth"),
    )
    comp["Diff"] = comp["MeanProb_Speed"] - comp["MeanProb_Depth"]
    comp = comp.sort_values("MeanProb_Speed", ascending=False)
    comp.to_csv(STAGE_DIR / f"Stage7_Speed_vs_Depth_{SWAN_YEAR}.csv", index=False)

    # Bar chart of the per-metric difference, with a zero reference line.
    plt.figure(figsize=(10, 4))
    sns.barplot(data=comp, x="Metric", y="Diff", palette="vlag")
    plt.title("Speed minus Depth (mean probability) – domain family")
    plt.ylabel("Δ (Speed – Depth)")
    plt.axhline(0, color="k")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(STAGE_DIR / f"Stage7_SpeedDepth_DiffBar_{SWAN_YEAR}.png", dpi=110)
    plt.close()
    log.info("✓ speed-vs-depth comparison written")

log.info("🎉  Stage 07 complete – artefacts in %s", STAGE_DIR)