# In [None]:
#!/usr/bin/env python
"""
STAGE-06B · STAGE-WEIGHTED RISE PREDICTIONS
v5.1 – 2025-07-04
────────────────────────────────────────────────────────
Runs the stage-score logits (from 05C) on FY-(SWAN-1) data.
Outputs one CSV + diagnostics per flavour (Temporal / Impact / Dynamic).
"""

from __future__ import annotations
import logging, os, warnings
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pipeline_utils import load_cfg, resolve_run_dir
# Suppress every RuntimeWarning for the whole process.
# NOTE(review): presumably aimed at numpy overflow / invalid-value noise
# (e.g. from np.exp in the logistic transform) — it also hides any other
# RuntimeWarning, so confirm this breadth is intended.
warnings.filterwarnings("ignore", category=RuntimeWarning)
# Default DPI for every figure created after this point.
plt.rcParams["figure.dpi"] = 110

# ═════════════ 0 · PATHS / LOGGING ═════════════════════════════
CFG     = load_cfg()
# Event keys are coerced to str so that SWAN is a string whichever source it
# comes from (environment or config).  A missing or empty ``events`` section
# is treated as "no events" instead of crashing on ``None.items()``.
EVENTS  = {str(k): v for k, v in (CFG.get("events") or {}).items()}

# The SWAN_YEAR environment variable wins; otherwise fall back to the first
# configured event.  Fail with an actionable message rather than the bare
# StopIteration that next() raises on an empty mapping.
SWAN = os.getenv("SWAN_YEAR") or next(iter(EVENTS), None)
if SWAN is None:
    raise RuntimeError(
        "Cannot determine the SWAN year: set the SWAN_YEAR environment "
        "variable or define at least one entry under 'events' in the config"
    )
try:
    # Predictions are made on the fiscal year preceding the SWAN event.
    SNAP_YEAR = int(SWAN) - 1
except ValueError as err:
    raise ValueError(f"SWAN year must be an integer, got {SWAN!r}") from err

# NOTE(review): resolve_run_dir is defined in pipeline_utils; presumably it
# returns the run directory that contains the ``must_have`` file — confirm.
RUN_DIR = resolve_run_dir(
    swan_year = SWAN,
    run_tag   = os.getenv("RUN_TAG"),
    must_have = f"stage05b/05B_AllScores_{SWAN}.csv",
)

ST05B = RUN_DIR / "stage05b"      # input: all-scores table
ST05C = RUN_DIR / "stage05c"      # input: per-metric coefficient files
OUT   = RUN_DIR / "stage06b"      # output: predictions, diagnostics, log
OUT.mkdir(exist_ok=True)

# Log to both a fresh file (mode "w" truncates any previous log) and the
# console.
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers (e.g. when this cell is re-run in the same kernel), in which
# case the file handler below is never installed.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-7s | %(message)s",
    handlers=[logging.FileHandler(OUT / "stage06b.log", "w", "utf-8"),
              logging.StreamHandler()],
)
log = logging.getLogger(__name__)
log.info("==========  STAGE-06B  (SWAN %s) ==========", SWAN)

# ═════════════ 1 · SNAPSHOT DATA ═══════════════════════════════
DATE_COL = "ReportDate"
ID_COL = "Symbol"

# Load the full 05B score table for this SWAN run.
snap = pd.read_csv(ST05B / f"05B_AllScores_{SWAN}.csv")

# Derive the year from the report date when the table has no Year column;
# unparseable dates are coerced to NaT and therefore yield a missing year.
if "Year" not in snap.columns:
    report_dates = pd.to_datetime(snap[DATE_COL], errors="coerce")
    snap["Year"] = report_dates.dt.year

# Keep only the snapshot fiscal year; an empty result is a hard error.
in_snapshot_year = snap["Year"] == SNAP_YEAR
snap = snap.loc[in_snapshot_year].copy()
if snap.empty:
    raise RuntimeError(f"No FY-{SNAP_YEAR} rows in 05B_AllScores_{SWAN}.csv")

# Reverse the scale of every *_Score / *_Q column (x -> 6 - x) so that 5 is
# the best value.  NOTE(review): assumes these columns are on a 1–5 scale.
flip_cols = [name for name in snap.columns if name.endswith(("_Score", "_Q"))]
for name in flip_cols:
    snap[name] = 6 - snap[name]

n_rows, n_cols = snap.shape
log.info("Snapshot FY-%d  (%d rows × %d cols)", SNAP_YEAR, n_rows, n_cols)

# ═════════════ 2 · CONSTANTS ═══════════════════════════════════
# Financial metrics for which 05C may have produced a coefficient file.
METRICS = [
    "NetIncome",
    "EarningBeforeInterestAndTax",
    "OperatingIncome",
    "EBITDA",
    "OperatingCashFlow",
    "FreeCashFlow",
    "Cash",
    "CashAndCashEquivalents",
    "TotalRevenue",
    "GrossProfit",
]

# One entry per flavour:
#   code – token used in the 05C coefficient file names
#   suf  – suffix appended to the metric name for the probability column
#   out  – name of the predictions CSV written to OUT
FLAVOURS: Dict[str, Dict] = {
    "Temporal": {
        "code": "temporal",
        "suf": "_StageTempRISE_prob",
        "out": f"Stage6B_TemporalStage_RISE_Predictions_{SWAN}.csv",
    },
    "Impact": {
        "code": "impact",
        "suf": "_StageImpactRISE_prob",
        "out": f"Stage6B_ImpactStage_RISE_Predictions_{SWAN}.csv",
    },
    "Dynamic": {
        "code": "dynamic",
        "suf": "_StageDynRISE_prob",
        "out": f"Stage6B_DynamicStage_RISE_Predictions_{SWAN}.csv",
    },
}

# Stage-score columns used as predictors (also the expected coefficient terms).
STG_COLS = [
    "Prepare_Score",
    "Absorb_Score",
    "Recover_Score",
    "Adapt_Score",
]

# ═════════════ 3 · HELPERS ═════════════════════════════════════
def coef_path(code: str, metric: str) -> Path | None:
    """Locate the 05C coefficient file for one flavour/metric pair.

    Args:
        code: Flavour token used in the file name (e.g. ``"temporal"``).
        metric: Metric name used in the file name.

    Returns:
        Path to ``Stage05C_Stage_{code}_{metric}_{SWAN}_Coefficients.csv``
        under the stage05c directory, or ``None`` when no such file exists.
    """
    candidate = ST05C / f"Stage05C_Stage_{code}_{metric}_{SWAN}_Coefficients.csv"
    if not candidate.exists():
        return None
    return candidate

def zscore(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
    """Return a standardised copy of ``df[cols]``; ``df`` is left untouched.

    Each column is centred on its own mean and divided by its population
    standard deviation (``ddof=0``).  A tiny epsilon is added to the divisor
    so a constant column maps to zeros instead of dividing by zero.

    Args:
        df: Source frame.
        cols: Names of the columns to standardise.

    Returns:
        New frame with the same index as ``df`` and exactly the columns
        ``cols``.
    """

    def _standardise(series):
        centre = series.mean()
        spread = series.std(ddof=0)
        return (series - centre) / (spread + 1e-9)

    scaled = df[cols].copy()
    for name in cols:
        scaled[name] = _standardise(scaled[name])
    return scaled

def diagnostics(df: pd.DataFrame, prob_cols: List[str], tag: str) -> None:
    """Write the diagnostic artefacts for one flavour into ``OUT``.

    Produces a per-metric summary CSV, a bar chart of mean probabilities, a
    grid of per-metric histograms and top/bottom-10 CSVs ranked by each row's
    mean probability.  Nothing is written when ``prob_cols`` is empty.

    Side effect: a ``Mean{tag}Prob`` column is added to ``df`` in place.

    Args:
        df: Prediction frame containing ``prob_cols`` and the ID column.
        prob_cols: Names of the probability columns to summarise.
        tag: Flavour label used in file names, titles and the mean column.
    """
    if not prob_cols:
        log.warning("No probability columns for %s – diagnostics skipped", tag)
        return

    probs = df[prob_cols]
    mean_col = f"Mean{tag}Prob"
    df[mean_col] = probs.mean(axis=1)

    # Summary statistics: one row per probability column, rounded to 3 dp.
    summary = probs.describe(percentiles=[.25, .5, .75]).T.round(3)
    summary.to_csv(OUT / f"Stage6{tag}_Summary_Probs_{SWAN}.csv")

    # Bar chart of the mean probability per metric, highest first.
    ranked_means = probs.mean().sort_values(ascending=False)
    bar_ax = ranked_means.plot(kind="bar", figsize=(10, 4))
    bar_ax.set_title(f"{tag} – mean predicted probability by metric")
    bar_ax.set_ylim(0, 1)
    bar_ax.set_ylabel("Probability")
    bar_fig = bar_ax.get_figure()
    bar_fig.tight_layout()
    bar_fig.savefig(OUT / f"Stage6{tag}_MeanBar_{SWAN}.png")
    plt.close(bar_fig)

    # Histogram grid, four panels per row; surplus panels are hidden.
    n_cols = 4
    n_rows = (len(prob_cols) + n_cols - 1) // n_cols
    fig, axarr = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    for idx, ax in enumerate(axarr.flatten()):
        if idx < len(prob_cols):
            col = prob_cols[idx]
            ax.hist(df[col].dropna(), bins=20, edgecolor="k")
            ax.set_title(col.replace("_prob", ""))
            ax.set_xlim(0, 1)
        else:
            ax.axis("off")
    fig.tight_layout()
    fig.savefig(OUT / f"Stage6{tag}_HistGrid_{SWAN}.png")
    plt.close(fig)

    # Ten highest and ten lowest rows by mean probability.
    ranking = df[[ID_COL, mean_col]]
    ranking.nlargest(10, mean_col).to_csv(
        OUT / f"Stage6{tag}_Top10_{SWAN}.csv", index=False
    )
    ranking.nsmallest(10, mean_col).to_csv(
        OUT / f"Stage6{tag}_Bottom10_{SWAN}.csv", index=False
    )

# ═════════════ 4 · MAIN LOOP ═══════════════════════════════════
# The stage scores never change inside the loops, so standardise them once
# and reuse the result for every flavour and metric.
# NOTE(review): the z-scores use the snapshot's own mean/std, not statistics
# from the 05C training sample — confirm this matches how 05C scaled inputs.
Z = zscore(snap, STG_COLS)

for tag, cfg in FLAVOURS.items():
    flav_code, suf, out_name = cfg["code"], cfg["suf"], cfg["out"]
    out_df   = snap.copy()
    prob_cols: List[str] = []

    for met in METRICS:
        cp = coef_path(flav_code, met)
        if cp is None:                           # model not trained
            continue
        beta = pd.read_csv(cp).set_index("Term")["Coefficient"]

        # Skip the metric when any stage predictor has no coefficient.
        missing = [term for term in STG_COLS if term not in beta.index]
        if missing:
            log.warning("%s: %s missing for %s", tag, ", ".join(missing), met)
            continue

        # Linear predictor = intercept + Σ coefficient × z-scored stage score.
        # A missing "const" term is treated as a zero intercept.  dtype=float
        # keeps the in-place accumulation valid even if the coefficient
        # column was parsed as integers.
        lin = np.full(len(out_df), beta.get("const", 0.0), dtype=float)
        for term in STG_COLS:
            # Positional (index-free) accumulation: Z shares snap's row order.
            lin += beta[term] * Z[term].to_numpy()

        lin_col  = f"{met}{suf.replace('prob','linpred')}"
        prob_col = f"{met}{suf}"
        out_df[lin_col]  = lin
        out_df[prob_col] = 1 / (1 + np.exp(-lin))   # logistic link
        prob_cols.append(prob_col)

    if prob_cols:
        # The CSV is written before diagnostics() adds its mean column, so
        # the predictions file contains only linpred / prob columns.
        out_df.to_csv(OUT / out_name, index=False)
        log.info("✓ %-8s CSV written (%d metrics)", tag, len(prob_cols))
        diagnostics(out_df, prob_cols, tag)
    else:
        log.warning("⚠️  No probabilities generated for %s flavour", tag)

log.info("🎉  STAGE-06B complete — artefacts in %s", OUT)
