In [None]:
# diag_stress.py — stress the slow paths until they squeal
import time, types, joblib, inspect, sys, concurrent.futures as cf
import numpy as np
import pandas as pd
from collections import Counter
from sklearn.naive_bayes import CategoricalNB
import AtBatSim

# Wall-clock budget handed to run_with_timeout for every simulated at-bat call
# in stress_ab and sweep_games; exceeding it raises TimeoutError.
CALL_LIMIT = 100.0  # seconds per call

def load_pack(path: str):
    """Deserialize a joblib pack and expose its mapping keys as attributes.

    The loaded object is splatted into ``types.SimpleNamespace``, so it must be
    a mapping with string keys (e.g. ``pack.pitcher_data_arch`` in ``__main__``).
    joblib files are pickle-based: only load packs from trusted sources.
    """
    contents = joblib.load(path)
    namespace = types.SimpleNamespace(**contents)
    return namespace

def need(df, cols, name):
    """Ensure every column in ``cols`` exists in ``df``.

    Raises:
        RuntimeError: listing, in ``cols`` order, each column ``df`` lacks;
            ``name`` identifies the table in the message.
    """
    existing = df.columns
    absent = list(filter(lambda col: col not in existing, cols))
    if not absent:
        return
    raise RuntimeError(f"{name} missing {absent}")

def run_with_timeout(func, timeout, desc, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` in a worker thread, bounded by ``timeout``.

    Args:
        func: callable to run.
        timeout: seconds to wait for the result (passed to ``Future.result``).
        desc: human-readable label included in the timeout message.
        *args, **kwargs: forwarded to ``func``.

    Returns:
        Whatever ``func`` returns; exceptions raised by ``func`` propagate.

    Raises:
        TimeoutError: ``func`` did not finish within ``timeout`` seconds.

    Python threads cannot be killed, so an overrunning call is abandoned, not
    stopped. It keeps running in the background, and concurrent.futures still
    joins its worker threads at interpreter exit. What this guarantees is that
    the caller receives the TimeoutError promptly.
    """
    ex = cf.ThreadPoolExecutor(max_workers=1)
    try:
        fut = ex.submit(func, *args, **kwargs)
        try:
            return fut.result(timeout=timeout)
        except cf.TimeoutError:
            raise TimeoutError(f"CALL TIMEOUT ({timeout:.0f}s): {desc}") from None
    finally:
        # wait=False is the point: a `with` block's implicit shutdown(wait=True)
        # would block on the hung call before the TimeoutError reached the caller.
        ex.shutdown(wait=False)

def predict_safe(model, **features):
    """Return one scalar prediction from a fitted regression model.

    Two model flavours are handled:

    * statsmodels-style results (detected via ``model.model.exog_names``).
      Every exogenous name except ``"Intercept"`` is read from ``features``.
      A few known aliases are tried for ``team_woba`` and ``ip`` before
      giving up. The row is passed to ``model.predict`` as a one-row DataFrame.
    * sklearn-style estimators. If the estimator records
      ``feature_names_in_`` and every one of those names is in ``features``,
      a one-row DataFrame in exactly that column order is used. Otherwise
      features are taken in sorted keyword order, truncated to
      ``n_features_in_`` when present. That fallback is a positional guess.

    Returns:
        The first element of the flattened prediction, as ``float``.

    Raises:
        ValueError: a required feature is missing, no features were given
            (sklearn path), or the model returned an empty prediction.
    """
    if hasattr(model, "model") and hasattr(model.model, "exog_names"):
        exog = [n for n in model.model.exog_names if n != "Intercept"]
        if any(c not in features for c in exog):
            aliases = {"team_woba": ["woba", "team_wOBA", "opp_woba", "team_woba_val"],
                       "ip": ["innings", "IP", "simulated_ip"]}
            for target, alts in aliases.items():
                if target in exog and target not in features:
                    alt = next((a for a in alts if a in features), None)
                    if alt is not None:
                        features[target] = features[alt]
            missing = [c for c in exog if c not in features]
            if missing:
                raise ValueError(f"predict_safe missing {missing}")
        X = pd.DataFrame({c: [features[c]] for c in exog})
    else:
        if not features:
            raise ValueError("predict_safe: no features")
        names = getattr(model, "feature_names_in_", None)
        if names is not None and all(n in features for n in names):
            # Named features recorded at fit time: use them in fit order.
            X = pd.DataFrame({n: [features[n]] for n in names})
        else:
            keys = sorted(features)
            n_expected = getattr(model, "n_features_in_", None)
            if n_expected is not None:
                keys = keys[:n_expected]
            X = [[features[k] for k in keys]]
    arr = np.asarray(model.predict(X)).ravel()
    if arr.size == 0:
        raise ValueError("predict_safe: model returned no prediction")
    return float(arr[0])

def build_models_from_pitcher_df(df):
    """Build count lookups and CategoricalNB models for pitch type and zone.

    Pitch model: keyed on (stand_enc, count_enc, arch_enc), target
    pitch_cluster_enc. Zone model: only rows with ``zone`` in 1..14 are used,
    re-encoded as ``zone_enc = zone - 1``; keyed on (pitch_cluster_enc,
    count_enc, stand_enc), target zone_enc.

    Each lookup maps an int key tuple to a ``Counter`` of target -> row count,
    with zero counts omitted.

    Returns:
        (nb_pitch, pitch_lookup, nb_pitch.classes_,
         nb_zone, zone_lookup, nb_zone.classes_)

    Raises:
        RuntimeError: a required column is missing, or no rows have a zone in 1..14.
    """
    need(df, ["stand_enc","count_enc","arch_enc","pitch_cluster_enc","zone"], "pitcher_data_arch")

    def tally(frame, keys, target):
        # Pivot to one row per key tuple and one column per target value,
        # then keep only the non-zero cells of each row.
        counts = frame.assign(_one=1).pivot_table(
            index=keys, columns=target, values="_one",
            aggfunc="sum", fill_value=0,
        )
        table = {}
        for key_tuple, cells in counts.iterrows():
            table[tuple(map(int, key_tuple))] = Counter(
                {int(label): int(n) for label, n in cells.items() if n}
            )
        return table

    def fit_nb(frame, keys, target):
        features = frame[keys].astype(int).values
        labels = frame[target].astype(int).values
        return CategoricalNB().fit(features, labels)

    pitch_keys = ["stand_enc", "count_enc", "arch_enc"]
    pitch_lookup = tally(df, pitch_keys, "pitch_cluster_enc")
    nb_pitch = fit_nb(df, pitch_keys, "pitch_cluster_enc")

    zoned = df[df["zone"].between(1, 14)].copy()
    if zoned.empty:
        raise RuntimeError("No valid zones 1..14")
    zoned["zone_enc"] = zoned["zone"].astype(int) - 1

    zone_keys = ["pitch_cluster_enc", "count_enc", "stand_enc"]
    zone_lookup = tally(zoned, zone_keys, "zone_enc")
    nb_zone = fit_nb(zoned, zone_keys, "zone_enc")

    return nb_pitch, pitch_lookup, nb_pitch.classes_, nb_zone, zone_lookup, nb_zone.classes_

def call_kwargs(nbp, plook, pcl, nbz, zlook, zcl):
    """Select the model kwargs that ``AtBatSim.simulate_at_bat_between`` accepts.

    Returns:
        (kwargs, accepts_rng): the subset of the candidate keyword arguments
        named in the simulator's signature, and whether it has an ``rng``
        parameter.
    """
    accepted = inspect.signature(AtBatSim.simulate_at_bat_between).parameters
    candidates = {
        "nb_pitch_model": nbp,
        "pitch_lookup_table": plook,
        "pitch_class_labels": pcl,
        "nb_zone_model": nbz,
        "zone_lookup_table": zlook,
        "zone_class_labels": zcl,
        "verbose": False,
        "verbose_audit": False,
    }
    filtered = {}
    for key, value in candidates.items():
        if key in accepted:
            filtered[key] = value
    return filtered, "rng" in accepted

# --------- 1) AB Stress ---------
def stress_ab(hitter, pitcher, nbp, plook, pcl, nbz, zlook, zcl, n_ab=2000, seed=7,
              report_every=200):
    """Time ``n_ab`` back-to-back at-bat simulations and print latency stats.

    Each call to ``AtBatSim.simulate_at_bat_between`` runs through
    ``run_with_timeout`` with ``CALL_LIMIT``. A seeded
    ``np.random.default_rng(seed)`` is passed as ``rng`` only when the
    simulator accepts it. Measured times include the per-call worker-thread
    setup done by ``run_with_timeout``.

    Args:
        n_ab: number of at-bats to simulate; must be positive.
        seed: seed for the generator passed to the simulator.
        report_every: print a progress line every this many ABs (0 disables).

    Raises:
        ValueError: ``n_ab`` is not positive.
        TimeoutError: a single AB exceeded ``CALL_LIMIT``.
    """
    if n_ab <= 0:
        raise ValueError(f"stress_ab: n_ab must be positive, got {n_ab}")
    rng = np.random.default_rng(seed)
    kwargs, has_rng = call_kwargs(nbp, plook, pcl, nbz, zlook, zcl)
    if has_rng:
        kwargs["rng"] = rng  # call_kwargs returns a fresh dict; safe to extend

    def one_ab():
        return AtBatSim.simulate_at_bat_between(hitter, pitcher, **kwargs)

    times = []
    n_hits = 0
    for i in range(1, n_ab + 1):
        t0 = time.perf_counter()
        # Assumes the simulator returns a (result, audit) pair — TODO confirm.
        result, _audit = run_with_timeout(one_ab, CALL_LIMIT, f"AB stress (i={i}/{n_ab})")
        times.append(time.perf_counter() - t0)
        if result == "HIT":  # NOTE(review): assumes hits are reported as "HIT"
            n_hits += 1
        if report_every and i % report_every == 0:
            print(f"[AB STRESS] {i}/{n_ab} done; last {report_every} avg="
                  f"{np.mean(times[-report_every:])*1000:.2f} ms")
    arr = np.array(times)
    print(f"\n[AB STRESS SUMMARY] n={n_ab} | hit_rate={n_hits/n_ab:.3f} | "
          f"mean={arr.mean()*1000:.2f} ms | p95={np.percentile(arr,95)*1000:.2f} ms | max={arr.max()*1000:.2f} ms")

# --------- 2) Multi-Game Sweep ---------
def _round_to_thirds(ip):
    """Round an innings value to the nearest third of an inning."""
    return round(ip * 3) / 3


def _outs_from_ip(ip):
    """Convert baseball-notation innings (X.1 = +1 out, X.2 = +2 outs) to outs."""
    whole, frac = divmod(round(ip * 10), 10)
    return whole * 3 + (2 if frac == 2 else 1 if frac == 1 else 0)


def _sample_bf_per_out(rng, bf_per_out, n_outs):
    """Draw ``n_outs`` BF-per-out values whose last draw is not 0.5 when possible.

    Equivalent in distribution to re-drawing the last value until it differs
    from 0.5, but samples the non-0.5 values directly, so an all-0.5 table
    cannot loop forever. In that case the 0.5 draw is kept.
    """
    draws = rng.choice(bf_per_out, size=n_outs, replace=True)
    if draws[-1] == 0.5:
        others = bf_per_out[bf_per_out != 0.5]
        if others.size:
            draws[-1] = rng.choice(others)
    return draws


def _as_list(value):
    """Return ``value`` as a list (None -> []), so array-like pools are safe to test."""
    return [] if value is None else list(value)


def sweep_games(hitter, pitcher, nbp, plook, pcl, nbz, zlook, zcl, trials=200, seed=11):
    """Simulate ``trials`` games for one hitter and print per-section timings.

    Per trial, each section is timed separately and the sections do not overlap:
      * SP: predict starter IP (``pitcher.IPLinReg`` + normal noise with
        ``ip_std``, rounded to thirds, clipped to [0, 9]), then batters faced
        (Poisson draw around ``pitcher.poisson_model``), then the hitter's PA
        count from ``most_recent_spot``.
      * SP_AB: simulate those PAs with ``AtBatSim.simulate_at_bat_between``,
        each bounded by ``CALL_LIMIT``.
      * RP: reliever batters faced from ``pitcher.bf_per_out`` draws, and the
        hitter's PA count against the bullpen.
      * EX: with probability ``prob_extra_innings`` (default 0.09 when the
        attribute is absent or None), sample extra-inning IP and BF draws.

    All randomness comes from one ``np.random.default_rng(seed)``, which is
    also passed to the simulator when it accepts ``rng``.

    Raises:
        ValueError: ``trials`` is not positive.
        RuntimeError: the pitcher pack lacks ``IPLinReg`` or ``poisson_model``.
        TimeoutError: a single AB exceeded ``CALL_LIMIT``.
    """
    if trials <= 0:
        raise ValueError(f"sweep_games: trials must be positive, got {trials}")
    rng = np.random.default_rng(seed)
    kwargs, has_rng = call_kwargs(nbp, plook, pcl, nbz, zlook, zcl)
    if has_rng:
        kwargs["rng"] = rng  # call_kwargs returns a fresh dict; safe to extend

    bf_per_out = np.asarray(getattr(pitcher, "bf_per_out", []), dtype=float)
    if bf_per_out.size == 0:
        bf_per_out = np.array([1.0])

    home_extras = _as_list(getattr(pitcher, "home_IP_extras", None))
    away_extras = _as_list(getattr(pitcher, "away_IP_extras", None))
    # Only fall back when unset: 0.0 is a legitimate probability.
    p_raw = getattr(pitcher, "prob_extra_innings", None)
    p_extras = 0.09 if p_raw is None else float(p_raw)

    team_to_abbr = {"Mets":"NYM","Tigers":"DET","Braves":"ATL","Yankees":"NYY"}
    team_woba = {"NYM":0.317,"DET":0.322,"ATL":0.311,"NYY":0.337}
    abbr = team_to_abbr.get(hitter.team_name, "NYM")
    team_woba_val = team_woba.get(abbr, 0.320)

    ip_model = getattr(pitcher, "IPLinReg", None)
    bf_model = getattr(pitcher, "poisson_model", None)
    if ip_model is None or bf_model is None:
        raise RuntimeError("Pitcher pack missing IP/BF models (IPLinReg/poisson_model).")
    ip_sigma = float(getattr(pitcher, "ip_std", 1.0))
    spot = int(getattr(hitter, "most_recent_spot", 3) or 3)
    is_home = True

    def one_ab():
        return AtBatSim.simulate_at_bat_between(hitter, pitcher, **kwargs)

    totals = dict(total_time=0.0, sp_time=0.0, ab_time=0.0, rp_time=0.0, ex_time=0.0)
    worst = dict(section="", trial=None, seconds=0.0)

    for t in range(1, trials + 1):
        t_total0 = time.perf_counter()

        # SP segment: model predictions and PA bookkeeping only (ABs timed below)
        t_sp0 = time.perf_counter()
        expected_ip = predict_safe(ip_model, team_woba=team_woba_val, team_woba_val=team_woba_val)
        simulated_ip = float(np.clip(_round_to_thirds(rng.normal(expected_ip, ip_sigma)), 0.0, 9.0))
        expected_bf = predict_safe(bf_model, ip=simulated_ip, simulated_ip=simulated_ip)
        simulated_bf = int(rng.poisson(expected_bf))
        full_cycles, remainder = divmod(simulated_bf, 9)
        # Lineup spots 1..remainder bat once more than the full cycles.
        pa_vs_sp = full_cycles + (1 if spot <= remainder else 0)
        t_sp = time.perf_counter() - t_sp0

        # ABs vs SP
        t_ab0 = time.perf_counter()
        desc = f"SWEEP: AB (trial={t}, pa_vs_sp={pa_vs_sp})"
        for _ in range(pa_vs_sp):
            run_with_timeout(one_ab, CALL_LIMIT, desc)
        t_ab = time.perf_counter() - t_ab0

        # RP segment
        t_rp0 = time.perf_counter()
        rel_ip = 9 - simulated_ip if (not is_home or rng.random() > 0.5) else 8 - simulated_ip
        rel_ip = max(0.0, rel_ip)
        whole = int(rel_ip)
        frac = rel_ip - whole
        # Convert a thirds fraction to baseball notation (1/3 -> .1, 2/3 -> .2).
        if 0.3 <= frac < 0.5:
            rel_ip = whole + 0.1
        elif 0.5 <= frac < 0.7:
            rel_ip = whole + 0.2
        outs_req = _outs_from_ip(rel_ip)
        bp_bf = int(_sample_bf_per_out(rng, bf_per_out, outs_req).sum()) if outs_req > 0 else 0
        # NOTE(review): reliever PAs are counted but not simulated, so the RP
        # timing covers bookkeeping only.
        next_spot = remainder + 1
        pa_vs_rp = sum(1 for i in range(bp_bf) if (next_spot + i - 1) % 9 + 1 == spot)
        t_rp = time.perf_counter() - t_rp0

        # Extras segment (sampled draws are not used further; timing only)
        t_ex0 = time.perf_counter()
        if (home_extras or away_extras) and rng.random() < p_extras:
            pool = away_extras if is_home else home_extras
            if pool:
                extra_ip = float(rng.choice(pool))
                ip_int = int(extra_ip)
                ip_frac = extra_ip - ip_int
                if np.isclose(ip_frac, 1/3):
                    extra_ip = ip_int + 0.1
                elif np.isclose(ip_frac, 2/3):
                    extra_ip = ip_int + 0.2
                outs_e = _outs_from_ip(extra_ip)
                if outs_e > 0:
                    _sample_bf_per_out(rng, bf_per_out, outs_e)
        t_ex = time.perf_counter() - t_ex0

        # totals
        t_total = time.perf_counter() - t_total0
        for key, sec in (("total_time", t_total), ("sp_time", t_sp), ("ab_time", t_ab),
                         ("rp_time", t_rp), ("ex_time", t_ex)):
            totals[key] += sec

        for name, sec in (("SP", t_sp), ("SP_AB", t_ab), ("RP", t_rp), ("EX", t_ex)):
            if sec > worst["seconds"]:
                worst.update(section=name, trial=t, seconds=sec)

        if t % 20 == 0:
            print(f"[SWEEP] {t}/{trials} games | avg total={(totals['total_time']/t):.4f}s "
                  f"(SP={(totals['sp_time']/t):.4f}s, AB={(totals['ab_time']/t):.4f}s, "
                  f"RP={(totals['rp_time']/t):.4f}s, EX={(totals['ex_time']/t):.4f}s)")

    n = float(trials)
    print("\n[SWEEP SUMMARY]")
    print(f"games={trials} | avg_total={totals['total_time']/n:.4f}s "
          f"| avg_SP={totals['sp_time']/n:.4f}s | avg_SP_AB={totals['ab_time']/n:.4f}s "
          f"| avg_RP={totals['rp_time']/n:.4f}s | avg_EX={totals['ex_time']/n:.4f}s")
    print(f"worst_section={worst['section']} | trial={worst['trial']} | seconds={worst['seconds']:.3f}")

if __name__ == "__main__":
    import traceback  # only the script entry point needs it

    try:
        t0 = time.perf_counter()
        hitter = load_pack("packs/hitter_mcneil.joblib")
        pitcher = load_pack("packs/pitcher_fried.joblib")
        print(f"[TIME] load packs: {time.perf_counter()-t0:.3f}s")

        dfp = pitcher.pitcher_data_arch
        nbp, plook, pcl, nbz, zlook, zcl = build_models_from_pitcher_df(dfp)

        # 1) Stress ABs — try to surface pitch-level bottlenecks
        stress_ab(hitter, pitcher, nbp, plook, pcl, nbz, zlook, zcl, n_ab=3000, seed=97)

        # 2) Sweep many games — catch rare slow branches
        sweep_games(hitter, pitcher, nbp, plook, pcl, nbz, zlook, zcl, trials=200, seed=13)

    except TimeoutError as e:
        # Exit code 2 distinguishes a per-call timeout from other failures.
        print(f"\n[HALT due to TIMEOUT] {e}", file=sys.stderr)
        sys.exit(2)
    except Exception as e:
        # Include type and traceback: a bare message hides where a diagnostic run failed.
        print(f"\n[ERROR] {type(e).__name__}: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)


[TIME] load packs: 0.052s
