In [1]:
import joblib, types

def load_pack(path):
    """Load a joblib pack from `path`.

    A dict payload is wrapped in a types.SimpleNamespace so its keys become
    attributes; any other payload is handed back untouched.
    """
    payload = joblib.load(path)
    if not isinstance(payload, dict):
        return payload
    return types.SimpleNamespace(**payload)

# Peek at the pitch-cluster encoder stored inside a hitter pack.
h = load_pack("packs/hitter_freeman.joblib")  # whichever hitter you pass into Pitcher(...)
enc = getattr(h, "cluster_encoder", None)
print("encoder type:", type(enc))
if enc is None:
    print("num classes:", 0)
    print("sample classes:", None)
else:
    print("num classes:", len(getattr(enc, "classes_", [])))
    print("sample classes:", getattr(enc, "classes_", [])[:10])

encoder type: <class 'sklearn.preprocessing._label.LabelEncoder'>
num classes: 36
sample classes: ['L2F1' 'L2F2' 'L2F3' 'L4F1' 'L4F2' 'L4F3' 'L4F4' 'LC1' 'LC2' 'LCF1']


In [1]:
# mc_audit.py — comprehensive, stall-resistant audit for Monte Carlo + AtBatSim
import time, types, joblib, numpy as np, pandas as pd
from collections import Counter, defaultdict
from sklearn.naive_bayes import CategoricalNB
from baseball_utils import *
# ---- Your modules ----
import AtBatSim
from General_Initialization import map_description_to_simple  # uses your updated mapper

# =========================
# Low-level hybrid helpers (audited)
# =========================

# ------ Lookup guardrails ------
MIN_LOOKUP_SUPPORT = 3        # require at least this many total observations for a lookup key
MIN_LOOKUP_DOMINANCE = 0.70   # top count must be at least this fraction of total to trust lookup
# NOTE: both values above are bound as default arguments of _lookup_is_trustworthy when it is
# defined, so changing them afterwards has no effect until that function is re-defined.

# ------ Two-strike foul breaker ------
FOUL_STREAK_CAP = 10          # after this many consecutive 2-strike fouls in a single PA, force the outcome stage onto the NB model
# NOTE: simulate_at_bat_between_AUDIT reads this via globals() on every call, so it can be re-tuned without re-defining the simulator.

def _lookup_is_trustworthy(counter: Counter, *, min_support=MIN_LOOKUP_SUPPORT, min_dom=MIN_LOOKUP_DOMINANCE):
    """Decide whether a lookup bucket is reliable enough to use directly.

    A bucket qualifies when it holds at least `min_support` observations in total
    and its most frequent class accounts for at least `min_dom` of them.
    """
    counts = list(counter.values())
    support = sum(counts)
    if support < min_support:
        return False
    dominance = max(counts) / support
    return dominance >= min_dom

    
def _encode_count(balls, strikes):
    table = {(0,0):0,(0,1):1,(0,2):2,(1,0):3,(1,1):4,(1,2):5,(2,0):6,(2,1):7,(2,2):8,(3,0):9,(3,1):10,(3,2):11}
    return table.get((int(balls), int(strikes)), 0)

def _get_nonempty_df(obj, primary_attr: str, fallback_attr: str):
    """Return obj.<primary_attr> if it exists and is non-empty; else obj.<fallback_attr>;
    return None if both are missing/empty."""
    df = getattr(obj, primary_attr, None)
    if df is None or (hasattr(df, "empty") and df.empty):
        df = getattr(obj, fallback_attr, None)
    if df is None or (hasattr(df, "empty") and df.empty):
        return None
    return df

def _argmax_counter(counter_dict):
    return max(counter_dict.items(), key=lambda kv: (kv[1], kv[0]))[0]

def _hybrid_predict_row_audit(key, lookup_table, nb_model, proba_classes, label, *, audit_list, force_model=False):
    """
    Predict one class for `key`, preferring a trusted lookup bucket over the NB model.

      - If force_model=True, skip lookup entirely.
      - Else, if `key` has a non-empty bucket in `lookup_table` that passes
        _lookup_is_trustworthy, return that bucket's argmax (deterministic;
        ties resolved by _argmax_counter toward the larger class id).
      - Otherwise (no bucket, or bucket rejected), sample one class from
        nb_model.predict_proba using the global np.random state.

    Every decision is appended to `audit_list` as a dict tagged with `label` and a
    "source" of "lookup", "lookup_rejected", or "model".
    NOTE(review): assumes `proba_classes` is ordered like the predict_proba columns
    (true for callers that pass nb_model.classes_) — confirm for other callers.

    Returns:
        (pred, source, probs): pred is an int; source is "lookup" or "model";
        probs is None on the lookup path, else the predict_proba row that was sampled.
    """
    if (not force_model) and (key in lookup_table) and (len(lookup_table[key]) > 0):
        bucket = lookup_table[key]
        if _lookup_is_trustworthy(bucket):
            # Option A (deterministic): argmax
            pred = _argmax_counter(bucket)
            audit_list.append({
                "stage": label, "source": "lookup", "key": tuple(map(int,key)),
                "pred": int(pred), "counts": dict(bucket), "lookup_trustworthy": True
            })
            return int(pred), "lookup", None

        # If not trustworthy, fall back to model
        audit_list.append({
            "stage": label, "source": "lookup_rejected", "key": tuple(map(int,key)),
            "counts": dict(bucket), "lookup_trustworthy": False
        })

    # Model path: stochastic draw from the NB posterior for this key.
    probs = nb_model.predict_proba([list(key)])[0]
    pred = int(np.random.choice(proba_classes, p=probs))
    audit_list.append({
        "stage": label, "source": "model", "key": tuple(map(int,key)),
        "probs": probs.tolist(), "classes": list(map(int, proba_classes)), "pred": pred
    })
    return pred, "model", probs
def _hybrid_pitch_predict_audit(rows, nb_model, lookup_table, class_labels, audit_list, force_model=False):
    preds = []
    for row in rows:
        key = tuple(int(v) for v in row)
        pred, source, _ = _hybrid_predict_row_audit(
            key, lookup_table, nb_model, class_labels, "pitch", audit_list=audit_list, force_model=force_model
        )
        preds.append(pred)
    return preds

def _predict_zone_hybrid_audit(rows, zone_nb_model, zone_lookup_table, zone_class_labels, audit_list, force_model=False):
    preds = []
    for row in rows:
        key = tuple(int(v) for v in row)
        pred, source, _ = _hybrid_predict_row_audit(
            key, zone_lookup_table, zone_nb_model, zone_class_labels, "zone", audit_list=audit_list, force_model=force_model
        )
        preds.append(pred)
    return preds

def _predict_outcome_audit(*, pitch_cluster_enc, zone_enc, count_enc, nb_model, lookup_table, class_labels, audit_list, force_model=False):
    """
    Predict the raw outcome class for one pitch, logging the decision to `audit_list`.

    Lookup back-off: the keys (c,z,k), (c,z), (c,k), (c,), (z,k), (z,) are tried in
    that order, and only the FIRST one with a non-empty bucket is considered. If that
    bucket passes _lookup_is_trustworthy, its argmax is returned. If it is rejected,
    no less specific key is tried and the NB model is used instead.
    force_model=True skips the lookup entirely.

    The model path samples from nb_model.predict_proba([[c, z, k]]) using the global
    np.random state.
    NOTE(review): assumes `class_labels` is ordered like the predict_proba columns.
    Confirm against how the hitter pack is built.

    Returns:
        (pred, source, probs): source is "lookup" (probs None) or "model".
    """
    c, z, k = int(pitch_cluster_enc), int(zone_enc), int(count_enc)

    if not force_model:
        # Most specific key first.
        for key in ((c, z, k), (c, z), (c, k), (c,), (z, k), (z,)):
            if key in lookup_table and len(lookup_table[key]) > 0:
                bucket = lookup_table[key]
                if _lookup_is_trustworthy(bucket):
                    pred = _argmax_counter(bucket)
                    audit_list.append({
                        "stage":"outcome","source":"lookup","key": key,
                        "pred_raw": int(pred), "counts": dict(bucket), "lookup_trustworthy": True
                    })
                    return int(pred), "lookup", None
                else:
                    audit_list.append({
                        "stage":"outcome","source":"lookup_rejected","key": key,
                        "counts": dict(bucket), "lookup_trustworthy": False
                    })
                break  # stop at the first non-empty bucket: if it was rejected, go to the model rather than a less specific key

    # Model path
    probs = nb_model.predict_proba([[c, z, k]])[0]
    pred = int(np.random.choice(class_labels, p=probs))
    audit_list.append({
        "stage":"outcome","source":"model","key": (c,z,k),
        "probs": probs.tolist(), "classes": list(map(int, class_labels)), "pred_raw": pred
    })
    return int(pred), "model", probs

# =========================
# Hardened, fully-audited PA simulator
# =========================
def _resolve_hitter_hand(pitcher, pitcher_df, hitter_name_str, pitcher_name_str):
    """Return the hitter's batting side ("stand") for this matchup.

    Uses the pitcher pack's `stand_by_batter_lower` dict when it has an entry for the
    hitter. Otherwise (no dict, or no entry for this hitter) it falls back to the first
    `pitcher_df` row whose `batter_name` contains the hitter's name.

    Raises:
        ValueError: when the data-frame fallback finds no matching row.
    """
    stand_map = getattr(pitcher, "stand_by_batter_lower", None)
    if isinstance(stand_map, dict):
        hand = stand_map.get(hitter_name_str)
        if hand is not None:
            return hand
        # A missing entry would otherwise feed None into stand_encoder.transform;
        # fall through to the data-frame search instead.

    # regex=False: match the name literally (names can contain regex metacharacters
    # such as "."); case-insensitive; NaN names never match.
    mask = pitcher_df["batter_name"].str.contains(hitter_name_str, case=False, na=False, regex=False)
    if not mask.any():
        raise ValueError(f"No matchup data found between {pitcher_name_str} and {hitter_name_str}")
    # Boolean-mask selection is correct even when pitcher_df has a non-unique index
    # (.loc[mask.idxmax()] could return several rows there).
    return str(pitcher_df.loc[mask, "stand"].iloc[0])


def simulate_at_bat_between_AUDIT(
    hitter,
    pitcher,
    nb_pitch_model,
    pitch_lookup_table,
    pitch_class_labels,
    nb_zone_model,
    zone_lookup_table,
    zone_class_labels,
    *,
    cluster_encoder=None,  # optional, for pretty labels
    MAX_PITCHES_PER_PA=30,
    MAX_SECONDS_PER_PA=3.0,
    print_every=True
):
    """
    Simulate one plate appearance pitch by pitch, logging every decision.

    Each iteration samples a pitch cluster and a zone for it (hybrid lookup/NB,
    pitcher-side models), then a raw outcome (hitter-side model). The outcome is
    mapped to a simple category with map_description_to_simple, and the count
    is updated.

    Stall guards:
    - Caps pitches per PA (MAX_PITCHES_PER_PA).
    - Caps wall-clock time per PA (MAX_SECONDS_PER_PA, checked once per pitch).
    - After FOUL_STREAK_CAP consecutive two-strike fouls, forces the outcome stage
      onto the NB model instead of the lookup table.
    - Enforces a progress invariant.

    Returns:
        (result, audit_log): result is one of "HIT", "OUT", "HBP", "WALK", "K",
        "ABORT_MAX_PITCHES", "ABORT_TIMEOUT", "ABORT_NO_PROGRESS". audit_log holds
        the per-pitch step dicts interleaved with the helpers' audit entries.

    Raises:
        ValueError: the pitcher has no non-empty data frame, or handedness has to be
        looked up from the data frame and no row matches the hitter.
    """
    t_start = time.perf_counter()

    # Outcome-stage models live in the hitter pack.
    outcome_encoder      = hitter.outcome_encoder
    nb_outcome_model     = hitter.nb_outcome_model
    outcome_lookup_table = hitter.outcome_lookup_table
    outcome_class_labels = hitter.outcome_class_labels
    # getattr: tolerate packs exported without an xBA table.
    xba_lookup_table     = getattr(hitter, "xba_lookup_table", None) or {}
    global_bip_xba       = float(getattr(hitter, "global_bip_xba", 0.300) or 0.300)

    # Resolve names
    hitter_name_str  = getattr(hitter, "full_lower", None) or f"{getattr(hitter,'first_lower','unknown')} {getattr(hitter,'last_lower','unknown')}"
    pitcher_name_str = getattr(pitcher, "full_lower", None) or f"{getattr(pitcher,'first_lower','unknown')} {getattr(pitcher,'last_lower','unknown')}"

    # --- Handedness / arch (SAFE DataFrame selection) ---
    pitcher_df = _get_nonempty_df(pitcher, "pitcher_data_arch", "pitcher_data")
    if pitcher_df is None:
        raise ValueError("Pitcher object has no non-empty `pitcher_data_arch` or `pitcher_data`.")

    hitter_hand = _resolve_hitter_hand(pitcher, pitcher_df, hitter_name_str, pitcher_name_str)
    stand_enc = hitter.stand_encoder.transform([hitter_hand])[0]
    arch_enc  = int(getattr(hitter, "arch_enc"))

    # State
    balls = strikes = 0
    pitch_num = 1
    log = []
    # Two-strike foul breaker state; the cap is read from module globals at call time.
    foul_streak_cap = globals().get("FOUL_STREAK_CAP", 10)
    foul_streak = 0
    force_model_outcome = False  # when True, skip lookup for outcome and use model

    if print_every:
        print(f"\n=== PA AUDIT: {hitter_name_str} vs {pitcher_name_str} ===")
        print(f"hand={hitter_hand} stand_enc={int(stand_enc)} arch_enc={arch_enc}")

    while True:
        # --- Hard stops ---
        if pitch_num > MAX_PITCHES_PER_PA:
            log.append({"ABORT":"MAX_PITCHES_PER_PA", "pitch_num": pitch_num, "count": (balls,strikes)})
            if print_every: print(f"[ABORT] Exceeded MAX_PITCHES_PER_PA={MAX_PITCHES_PER_PA}")
            return ("ABORT_MAX_PITCHES", log)

        if (time.perf_counter() - t_start) > MAX_SECONDS_PER_PA:
            log.append({"ABORT":"TIMEOUT", "pitch_num": pitch_num, "count": (balls,strikes)})
            if print_every: print(f"[ABORT] TIMEOUT after {MAX_SECONDS_PER_PA:.2f}s")
            return ("ABORT_TIMEOUT", log)

        count_enc = _encode_count(balls, strikes)
        step = {"pitch_num": pitch_num, "count_before": (balls, strikes), "count_enc": int(count_enc)}

        # --- PITCH ---
        pitch_global_id = _hybrid_pitch_predict_audit(
            [[int(stand_enc), int(count_enc), int(arch_enc)]],
            nb_pitch_model, pitch_lookup_table, pitch_class_labels, audit_list=log,
            force_model=False  # no breaker for pitch
        )[0]
        step["pitch_global_id"] = int(pitch_global_id)
        if cluster_encoder is not None:
            try:
                step["pitch_label"] = cluster_encoder.inverse_transform([pitch_global_id])[0]
            except Exception:
                # Pretty label is best-effort only; the numeric id is still logged.
                step["pitch_label"] = None

        # --- ZONE ---
        zone_enc = _predict_zone_hybrid_audit(
            [[int(pitch_global_id), int(count_enc), int(stand_enc)]],
            zone_nb_model=nb_zone_model,
            zone_lookup_table=zone_lookup_table,
            zone_class_labels=zone_class_labels,
            audit_list=log,
            force_model=False  # no breaker for zone
        )[0]
        step["zone_enc"] = int(zone_enc)

        # --- OUTCOME (raw → simple) ---
        outcome_enc, src, probs = _predict_outcome_audit(
            pitch_cluster_enc=pitch_global_id,
            zone_enc=zone_enc,
            count_enc=count_enc,
            nb_model=nb_outcome_model,
            lookup_table=outcome_lookup_table,
            class_labels=outcome_class_labels,
            audit_list=log,
            force_model=force_model_outcome  # breaker may force model path
        )
        try:
            raw_label = str(outcome_encoder.inverse_transform([outcome_enc])[0])
        except Exception:
            raw_label = str(outcome_enc)
        simple = map_description_to_simple(raw_label)
        if simple == "unknown":
            # fail-closed so the PA progresses
            simple = "strike"

        step["outcome_src"] = src
        step["outcome_raw"] = raw_label
        step["outcome_simple"] = simple

        prev_count = (balls, strikes)

        # --- Two-strike foul breaker bookkeeping (before count update) ---
        if simple == "foul" and strikes == 2:
            foul_streak += 1
            if foul_streak >= foul_streak_cap and not force_model_outcome:
                force_model_outcome = True
                log.append({"stage":"foul_breaker", "action":"force_model_outcome=True", "streak": foul_streak, "cap": foul_streak_cap})
        else:
            # Anything other than a two-strike foul ends the streak.
            if force_model_outcome:
                log.append({"stage":"foul_breaker", "action":"force_model_outcome=False", "reason":"state_changed"})
            force_model_outcome = False
            foul_streak = 0

        # --- Terminal fast paths ---
        if simple == "bip":
            xba = AtBatSim.predict_xba(pitch_global_id, zone_enc, count_enc, xba_lookup_table, global_fallback=global_bip_xba)
            hit = (np.random.rand() < float(xba))
            step["xBA"] = float(xba); step["bip_hit"] = bool(hit)
            log.append(step)
            if print_every:
                print(f"#{pitch_num} {prev_count} pitch={step.get('pitch_label',step['pitch_global_id'])} zone={zone_enc+1} OUTCOME=bip xBA={float(xba):.3f} -> {'HIT' if hit else 'OUT'}")
            return ("HIT", log) if hit else ("OUT", log)

        if simple == "hbp":
            log.append(step)
            if print_every:
                print(f"#{pitch_num} {prev_count} pitch={step.get('pitch_label',step['pitch_global_id'])} zone={zone_enc+1} OUTCOME=HBP")
            return ("HBP", log)

        # --- Count update ---
        if simple == "ball":
            balls += 1
        elif simple == "strike":
            strikes += 1
        elif simple == "foul":
            if strikes < 2:
                strikes += 1
        else:
            # shouldn't happen after normalization, but force progress
            strikes += 1
            step["forced_progress_on_unknown"] = True

        step["count_after"] = (balls, strikes)
        log.append(step)
        if print_every:
            print(f"#{pitch_num} {prev_count} -> {step['count_after']}  pitch={step.get('pitch_label',step['pitch_global_id'])}  zone={zone_enc+1}  outcome_raw={raw_label}  simple={simple}")

        # --- Terminal counts ---
        if balls >= 4:
            if print_every: print("   ↳ WALK")
            return ("WALK", log)
        if strikes >= 3:
            if print_every: print("   ↳ K")
            return ("K", log)

        # --- Progress invariant (safety net; two-strike fouls are the only legal no-op) ---
        progressed = (prev_count != (balls, strikes)) or (simple in ("bip","hbp")) or (simple == "foul" and prev_count[1] == 2)
        if not progressed:
            log.append({"ABORT":"NO_PROGRESS", "at": (balls,strikes), "simple": simple})
            if print_every: print("[ABORT] NO_PROGRESS invariant broke")
            return ("ABORT_NO_PROGRESS", log)

        pitch_num += 1

# =========================
# Monte Carlo audit wrapper
# =========================
def build_models_from_pitcher_df_AUDIT(pitcher_pack):
    """Fit the pitcher-side pitch-cluster and zone models from `pitcher_pack.pitcher_data_arch`.

    Pitch model: features (stand_enc, count_enc, arch_enc) -> pitch_cluster_enc.
    Zone model:  features (pitch_cluster_enc, count_enc, stand_enc) -> zone_enc,
                 where zone_enc = zone - 1, using only rows with zone in 1..14.
    For each model, a lookup table {feature tuple: Counter(class -> count)} is built
    alongside a CategoricalNB fitted on the same data.

    Returns:
        (nb_pitch, pitch_lookup, nb_pitch.classes_, nb_zone, zone_lookup, nb_zone.classes_)

    Raises:
        ValueError: missing/empty pitcher_data_arch, missing required columns
        (including "zone"), or no rows with a valid zone.
    """
    df = getattr(pitcher_pack, "pitcher_data_arch", None)
    if df is None or len(df) == 0:
        raise ValueError("pitcher_data_arch missing from pack. Export packs with include_full_df=True.")
    # "zone" feeds the zone model; validate it with the rest so a missing column surfaces
    # as the same ValueError instead of a KeyError halfway through fitting.
    need_cols = ["stand_enc","count_enc","arch_enc","pitch_cluster_enc","zone"]
    missing = [c for c in need_cols if c not in df.columns]
    if missing:
        raise ValueError(f"pitcher_data_arch missing columns: {missing}")

    # Keep only rows whose zone is a real 1..14 zone; checked before any fitting so a
    # pitcher with no usable zone data fails fast with a clear message.
    dfz = df[df["zone"].notna() & df["zone"].isin(range(1,15))].copy()
    if dfz.empty:
        raise ValueError("pitcher_data_arch has no rows with a valid zone (1-14).")
    dfz["zone_enc"] = dfz["zone"].astype(int) - 1

    # --- Pitch-cluster model ---
    Xp = df[["stand_enc","count_enc","arch_enc"]].astype(int).values
    yp = df["pitch_cluster_enc"].astype(int).values
    pitch_lookup = defaultdict(Counter)
    for x, y in zip(Xp, yp):
        pitch_lookup[tuple(x)][int(y)] += 1
    nb_pitch = CategoricalNB().fit(Xp, yp)

    # --- Zone model ---
    Xz = dfz[["pitch_cluster_enc","count_enc","stand_enc"]].astype(int).values
    yz = dfz["zone_enc"].astype(int).values
    zone_lookup = defaultdict(Counter)
    for x, y in zip(Xz, yz):
        zone_lookup[tuple(x)][int(y)] += 1
    nb_zone = CategoricalNB().fit(Xz, yz)
    return nb_pitch, pitch_lookup, nb_pitch.classes_, nb_zone, zone_lookup, nb_zone.classes_

def simulate_total_hits_AUDIT(
    hitter, pitcher, num_trials,
    nb_pitch_model, pitch_lookup_table, pitch_class_labels,
    nb_zone_model,  zone_lookup_table,  zone_class_labels,
    *,
    print_every=True,
    seed=123
):
    """
    Monte Carlo distribution of a hitter's total hits in one game against `pitcher`.

    Per trial:
      1. Starter IP ~ Normal(IPLinReg(team wOBA), ip_std), rounded to thirds, clipped to [0, 9].
      2. Starter batters faced ~ Poisson(poisson_model(ip)).
      3. The hitter's PAs vs the starter follow from his lineup spot. Each one is
         simulated pitch by pitch with simulate_at_bat_between_AUDIT.
      4. Bullpen IP covers the rest of 9 innings, or of 8 when the (home) hitter's team
         is randomly judged not to bat in the 9th. Bullpen BF is drawn per out from
         `bf_per_out`, and hits vs the bullpen ~ Binomial(PAs, hitter.xba).
      5. With probability `prob_extra_innings`, extra-inning BF/hits are added the same way.

    Args:
        print_every: when False, the per-trial summaries and the per-PA audit output
            are both silenced (aborted PAs are still reported).
        seed: seeds the Generator used for game-level draws. When it is an int, it also
            seeds the global np.random state used by the PA-level samplers, so a run
            is reproducible end to end.

    Returns:
        list[int]: total hits for each of the `num_trials` trials.

    Raises:
        KeyError: `hitter.team_name` is not a known club nickname, or its
            abbreviation has no wOBA entry.
    """
    rng = np.random.default_rng(seed)
    if isinstance(seed, (int, np.integer)):
        # The PA helpers sample with np.random.choice / np.random.rand (global state).
        np.random.seed(int(seed) % 2**32)
    results = []

    xba = float(getattr(hitter, "xba", 0.300) or 0.300)
    spot = int(getattr(hitter, "most_recent_spot", 3) or 3)
    is_home = True  # same as your current script

    team_woba = dict(zip(
        ["CHC","NYY","TOR","LAD","ARI","BOS","DET","NYM","MIL","SEA","PHI","HOU","STL","ATH","ATL","SDP","TBR","BAL","MIN","MIA","TEX","CIN","SFG","CLE","LAA","WSN","KCR","PIT","CHW","COL"],
        [0.333,0.337,0.328,0.334,0.329,0.328,0.322,0.317,0.313,0.319,0.323,0.318,0.312,0.323,0.311,0.307,0.316,0.314,0.312,0.309,0.298,0.313,0.302,0.296,0.311,0.305,0.298,0.285,0.293,0.296],
    ))
    # Abbreviations must match the wOBA table keys above ("ATH", "CHW").
    team_to_abbr = {
        "Angels":"LAA","Astros":"HOU","Athletics":"ATH","Blue Jays":"TOR","Braves":"ATL","Brewers":"MIL",
        "Cardinals":"STL","Cubs":"CHC","Diamondbacks":"ARI","Dodgers":"LAD","Giants":"SFG","Guardians":"CLE",
        "Mariners":"SEA","Marlins":"MIA","Mets":"NYM","Nationals":"WSN","Orioles":"BAL","Padres":"SDP",
        "Phillies":"PHI","Pirates":"PIT","Rangers":"TEX","Rays":"TBR","Reds":"CIN","Red Sox":"BOS",
        "Rockies":"COL","Royals":"KCR","Tigers":"DET","Twins":"MIN","White Sox":"CHW","Yankees":"NYY"
    }
    hitter_abbr = team_to_abbr[hitter.team_name]
    if hitter_abbr not in team_woba:
        raise KeyError(f"No team wOBA entry for {hitter_abbr!r} ({hitter.team_name})")
    team_woba_val = float(team_woba[hitter_abbr])

    IP_model = pitcher.IPLinReg
    BF_model = pitcher.poisson_model
    ip_sigma = float(pitcher.ip_std)

    BF_PER_OUT = np.array(getattr(pitcher, "bf_per_out", []), dtype=float)
    if BF_PER_OUT.size == 0:
        BF_PER_OUT = np.array([1.0], dtype=float)
    # Only re-draw a trailing 0.5 when some other value exists (otherwise it would loop forever).
    can_redraw_last = bool(np.any(BF_PER_OUT != 0.5))
    P_EXTRAS = float(getattr(pitcher, "prob_extra_innings", 0.09) or 0.09)
    extras_pool = getattr(pitcher, "home_IP_extras" if not is_home else "away_IP_extras", None)
    if extras_pool is None:
        extras_pool = []

    def round_to_thirds(ip):
        return round(ip * 3) / 3

    def outs_from_ip(ip: float) -> int:
        """Convert baseball-notation IP (x.1 / x.2 = one / two outs) to outs."""
        whole, frac = divmod(round(ip*10), 10)
        return whole*3 + (2 if frac==2 else 1 if frac==1 else 0)

    def sample_bf_per_out(outs):
        """One BF-per-out draw per out; a trailing 0.5 draw is re-drawn when possible."""
        samples = rng.choice(BF_PER_OUT, size=outs, replace=True)
        if can_redraw_last:
            while samples[-1] == 0.5:
                samples[-1] = rng.choice(BF_PER_OUT)
        return samples

    def pa_for_spot(first_spot, n_bf):
        """How many of `n_bf` consecutive batters, starting at lineup `first_spot`, are the hitter."""
        return sum(1 for i in range(n_bf) if ((first_spot + i - 1) % 9 + 1) == spot)

    for t in range(num_trials):
        if print_every:
            print(f"\n=== TRIAL {t+1}/{num_trials} ===")

        # ----- Starter IP -----
        exp_ip = float(IP_model.predict([[team_woba_val]])[0])
        sim_ip = round_to_thirds(rng.normal(exp_ip, ip_sigma))
        sim_ip = float(np.clip(sim_ip, 0.0, 9.0))
        if print_every: print(f"[IP] expected={exp_ip:.3f}  sigma={ip_sigma:.3f}  simulated={sim_ip:.3f}")

        # ----- Starter BF -----
        exp_bf = float(BF_model.predict(pd.DataFrame({"ip":[sim_ip]}))[0])
        sim_bf = int(rng.poisson(exp_bf))
        if print_every: print(f"[BF] expected={exp_bf:.3f}  simulated={sim_bf}")

        # ----- PA allocation vs SP -----
        full_cycles, remainder = divmod(sim_bf, 9)
        pa_vs_sp = full_cycles + (1 if spot <= remainder else 0)
        if print_every: print(f"[PA vs SP] spot={spot} full_cycles={full_cycles} remainder={remainder} -> pa_vs_sp={pa_vs_sp}")

        # ----- At-bats vs SP (audited) -----
        hits_vs_sp = 0
        for i in range(pa_vs_sp):
            if print_every: print(f"\n-- PA vs SP #{i+1}/{pa_vs_sp} --")
            res, audit = simulate_at_bat_between_AUDIT(
                hitter=hitter, pitcher=pitcher,
                nb_pitch_model=nb_pitch_model, pitch_lookup_table=pitch_lookup_table, pitch_class_labels=pitch_class_labels,
                nb_zone_model=nb_zone_model, zone_lookup_table=zone_lookup_table, zone_class_labels=zone_class_labels,
                cluster_encoder=getattr(hitter, "cluster_encoder", None),
                MAX_PITCHES_PER_PA=30, MAX_SECONDS_PER_PA=3.0, print_every=print_every
            )
            if res == "HIT": hits_vs_sp += 1
            if res.startswith("ABORT"):
                # Always report aborts, even in quiet runs; print the audit tail to pinpoint.
                print("[PA ABORTED]", res)
                tail = audit[-8:] if audit else []
                print("… Last steps:")
                for step in tail:
                    print(step)

        # ----- Bullpen -----
        hitter_win = float(getattr(hitter, "winning_pct_value", 0.5) or 0.5)
        pitcher_win = float(getattr(pitcher, "winning_pct_value", 0.5) or 0.5)

        if not is_home:
            relief_ip = 9 - sim_ip
        else:
            prob_not_hitting_9th = hitter_win / (hitter_win + pitcher_win + 1e-9)
            hits_in_9th = rng.random() > prob_not_hitting_9th
            relief_ip = 9 - sim_ip if hits_in_9th else 8 - sim_ip

        # Convert decimal thirds (x.333 / x.667) to baseball notation (x.1 / x.2).
        relief_ip = max(0.0, relief_ip)
        n = int(relief_ip); frac = relief_ip - n
        if 0.3 <= frac < 0.5: relief_ip = n + 0.1
        elif 0.5 <= frac < 0.7: relief_ip = n + 0.2
        else: relief_ip = float(n)
        if print_every: print(f"[Bullpen IP] relief_ip={relief_ip:.3f}")

        outs_req = outs_from_ip(relief_ip)
        # NOTE(review): bullpen BF truncates the sum while extras BF rounds it — confirm which is intended.
        bp_bf = int(sample_bf_per_out(outs_req).sum()) if outs_req > 0 else 0
        if print_every: print(f"[Bullpen BF] outs_req={outs_req}  sampled_bf={bp_bf}")

        next_spot = (sim_bf % 9) + 1
        pa_vs_rp = pa_for_spot(next_spot, bp_bf)
        if print_every: print(f"[PA vs RP] next_spot={next_spot}  pa_vs_rp={pa_vs_rp}")

        hits_vs_rp = int(rng.binomial(n=pa_vs_rp, p=xba))
        if print_every: print(f"[Hits vs RP] xba={xba:.3f}  hits={hits_vs_rp}")

        # ----- Extras -----
        if rng.random() < P_EXTRAS:
            if len(extras_pool) > 0:
                extra_ip = float(rng.choice(extras_pool))
                ip_int = int(extra_ip); ip_frac = extra_ip - ip_int
                if np.isclose(ip_frac, 0.33): extra_ip = ip_int + 0.1
                elif np.isclose(ip_frac, 0.67): extra_ip = ip_int + 0.2
                outs_needed = outs_from_ip(extra_ip)
                total_bf = 0
                if outs_needed > 0:
                    total_bf = int(round(sample_bf_per_out(outs_needed).sum()))
                    nxt = ((sim_bf + bp_bf) % 9) + 1
                    mc_ab = pa_for_spot(nxt, total_bf)
                    hits_ex = int(rng.binomial(mc_ab, xba))
                else:
                    mc_ab = 0; hits_ex = 0
                extras = {"extra_happens": True, "extra_ip": extra_ip, "extra_bf": total_bf, "mcneil_ab": mc_ab, "mcneil_hits": hits_ex}
            else:
                extras = {"extra_happens": True, "mcneil_hits": 0}
        else:
            extras = {"extra_happens": False, "mcneil_hits": 0}
        if print_every: print(f"[Extras] {extras}")

        total_hits = hits_vs_sp + hits_vs_rp + int(extras.get("mcneil_hits",0))
        results.append(int(total_hits))
        if print_every: print(f"[TOTAL HITS in trial] {total_hits}")

    return results

# =========================
# Example runner
# =========================
def load_pack(path: str):
    """Load a joblib-exported pack from `path`.

    A dict payload is wrapped in a types.SimpleNamespace so its keys read as attributes
    (e.g. pack.pitcher_data_arch). Any other payload is returned unchanged instead of
    failing inside SimpleNamespace(**payload). This keeps the behavior consistent with
    the load_pack used in the encoder-inspection cell.
    joblib.load unpickles the file, so only load packs from trusted sources.
    """
    payload = joblib.load(path)
    return types.SimpleNamespace(**payload) if isinstance(payload, dict) else payload

if __name__ == "__main__":
    # Hitter / pitcher packs exported ahead of time.
    mcneil = load_pack("packs/hitter_mcneil.joblib")
    fried = load_pack("packs/pitcher_fried.joblib")

    # Pitcher-side pitch-cluster and zone models, fitted from the pitcher's pack.
    (
        nb_pitch_model, pitch_lookup_table, pitch_class_labels,
        nb_zone_model, zone_lookup_table, zone_class_labels,
    ) = build_models_from_pitcher_df_AUDIT(fried)

    # One fully printed trial as a smoke test of the audit trail.
    results = simulate_total_hits_AUDIT(
        mcneil, fried, 1,
        nb_pitch_model, pitch_lookup_table, pitch_class_labels,
        nb_zone_model, zone_lookup_table, zone_class_labels,
        print_every=True,
        seed=42,
    )

    print("\nFinal results (hits per trial):", results)


# Quiet 1000-trial run reusing the packs/models built by the runner above.
results = simulate_total_hits_AUDIT(
    mcneil, fried, 1000,
    nb_pitch_model, pitch_lookup_table, pitch_class_labels,
    nb_zone_model, zone_lookup_table, zone_class_labels,
    print_every=False,
    seed=7,
)
print("mean hits:", sum(results) / len(results))
print("zero-hit %:", results.count(0) / len(results))
print(">=2 hits %:", sum(r >= 2 for r in results) / len(results))

False
Downloading / reading cached Statcast…
 Extra–inning rows: 4,645

=== TRIAL 1/1 ===
[IP] expected=5.905  sigma=1.073  simulated=6.333
[BF] expected=25.016  simulated=24
[PA vs SP] spot=5 full_cycles=2 remainder=6 -> pa_vs_sp=3

-- PA vs SP #1/3 --

=== PA AUDIT: jeff mcneil vs max fried ===
hand=L stand_enc=0 arch_enc=4
#1 (0, 0) -> (0, 1)  pitch=LS5  zone=13  outcome_raw=strike  simple=strike
#2 (0, 1) -> (1, 1)  pitch=LS5  zone=11  outcome_raw=ball  simple=ball
#3 (1, 1) -> (1, 2)  pitch=LS5  zone=9  outcome_raw=strike  simple=strike
#4 (1, 2) -> (2, 2)  pitch=LC2  zone=14  outcome_raw=ball  simple=ball
#5 (2, 2) pitch=L2F3 zone=6 OUTCOME=bip xBA=0.288 -> OUT

-- PA vs SP #2/3 --

=== PA AUDIT: jeff mcneil vs max fried ===
hand=L stand_enc=0 arch_enc=4
#1 (0, 0) -> (1, 0)  pitch=LC2  zone=13  outcome_raw=ball  simple=ball
#2 (1, 0) -> (1, 1)  pitch=LCF2  zone=1  outcome_raw=strike  simple=strike
#3 (1, 1) -> (1, 2)  pitch=LS5  zone=13  outcome_raw=foul  simple=foul
#4 (1, 2) ->

In [None]:
# mc_audit_v2.py — focused PA & outcome audit for Monte Carlo + AtBatSim
import time, types, joblib, numpy as np, pandas as pd
from collections import Counter, defaultdict
from sklearn.naive_bayes import CategoricalNB
from baseball_utils import *
import AtBatSim
from General_Initialization import map_description_to_simple  # your mapper

# ===== Config =====
MIN_LOOKUP_SUPPORT = 3        # minimum total observations in a lookup bucket before it can be trusted
MIN_LOOKUP_DOMINANCE = 0.70   # minimum share of the bucket held by its top class to trust the lookup
FOUL_STREAK_CAP = 10          # two-strike foul-streak threshold for the foul breaker

# ===== Small helpers =====
def _encode_count(balls, strikes):
    table = {(0,0):0,(0,1):1,(0,2):2,(1,0):3,(1,1):4,(1,2):5,(2,0):6,(2,1):7,(2,2):8,(3,0):9,(3,1):10,(3,2):11}
    return table.get((int(balls), int(strikes)), 0)

def _lookup_is_trustworthy(counter: Counter, *, min_support=MIN_LOOKUP_SUPPORT, min_dom=MIN_LOOKUP_DOMINANCE):
    """True when the bucket has >= min_support observations and its top class holds >= min_dom of them."""
    total = sum(counter.values())
    return total >= min_support and max(counter.values()) / total >= min_dom

def _get_nonempty_df(obj, primary_attr: str, fallback_attr: str):
    df = getattr(obj, primary_attr, None)
    if df is None or (hasattr(df, "empty") and df.empty):
        df = getattr(obj, fallback_attr, None)
    if df is None or (hasattr(df, "empty") and df.empty):
        return None
    return df

def _argmax_counter(counter_dict):
    return max(counter_dict.items(), key=lambda kv: (kv[1], kv[0]))[0]

def _hybrid_predict_row_audit(key, lookup_table, nb_model, class_labels, stage_label, audit, force_model=False):
    if (not force_model) and (key in lookup_table) and (len(lookup_table[key]) > 0):
        bucket = lookup_table[key]
        if _lookup_is_trustworthy(bucket):
            pred = _argmax_counter(bucket)
            audit["lookup_hits"][stage_label] += 1
            return int(pred), "lookup"
        audit["lookup_rejects"][stage_label] += 1
    # NB path
    probs = nb_model.predict_proba([list(key)])[0]
    pred = int(np.random.choice(class_labels, p=probs))
    audit["nb_uses"][stage_label] += 1
    return pred, "model"

# ===== Pitch, Zone, Outcome hybrids (audited) =====
def _hybrid_pitch_predict_audit(rows, nb_model, lookup_table, class_labels, audit, force_model=False):
    out = []
    for row in rows:
        key = tuple(int(v) for v in row)
        pred, _ = _hybrid_predict_row_audit(key, lookup_table, nb_model, class_labels, "pitch", audit, force_model)
        out.append(pred)
    return out

def _predict_zone_hybrid_audit(rows, nb_model, lookup_table, class_labels, audit, force_model=False):
    out = []
    for row in rows:
        key = tuple(int(v) for v in row)
        pred, _ = _hybrid_predict_row_audit(key, lookup_table, nb_model, class_labels, "zone", audit, force_model)
        out.append(pred)
    return out

def _predict_outcome_audit(pitch_cluster_enc, zone_enc, count_enc, nb_model, lookup_table, class_labels, audit, force_model=False):
    c, z, k = int(pitch_cluster_enc), int(zone_enc), int(count_enc)
    if not force_model:
        for key in ((c, z, k), (c, z), (c, k), (c,), (z, k), (z,)):
            if key in lookup_table and len(lookup_table[key]) > 0:
                bucket = lookup_table[key]
                if _lookup_is_trustworthy(bucket):
                    audit["lookup_hits"]["outcome"] += 1
                    return int(_argmax_counter(bucket)), "lookup"
                else:
                    audit["lookup_rejects"]["outcome"] += 1
                break
    # NB path
    probs = nb_model.predict_proba([[c, z, k]])[0]
    pred = int(np.random.choice(class_labels, p=probs))
    audit["nb_uses"]["outcome"] += 1
    return pred, "model"

# ===== PA simulator (audited) =====
def simulate_at_bat_between_AUDIT(
    hitter, pitcher,
    nb_pitch_model, pitch_lookup_table, pitch_class_labels,
    nb_zone_model,  zone_lookup_table,  zone_class_labels,
    *,
    cluster_encoder=None,
    MAX_PITCHES_PER_PA=30,
    MAX_SECONDS_PER_PA=3.0,
    print_every=False,
    global_audit=None
):
    """Simulate one plate appearance pitch by pitch, recording diagnostics.

    Each pitch draws a pitch cluster and then a zone (pitcher-side hybrid
    lookup/NB models passed in), then an outcome from the hitter's outcome
    model, and advances the count until a terminal result.

    Args:
        hitter: hitter pack; reads outcome_encoder, nb_outcome_model,
            outcome_lookup_table, outcome_class_labels, stand_encoder, arch_enc,
            and optionally xba_lookup_table, global_bip_xba and name fields.
        pitcher: pitcher pack; reads pitcher_data_arch (or pitcher_data) and
            optionally stand_by_batter_lower and name fields.
        nb_pitch_model, pitch_lookup_table, pitch_class_labels: pitch-cluster model.
        nb_zone_model, zone_lookup_table, zone_class_labels: zone model.
        cluster_encoder: accepted for call compatibility; not used here.
        MAX_PITCHES_PER_PA: abort after this many pitches in the PA.
        MAX_SECONDS_PER_PA: wall-clock abort threshold in seconds.
        print_every: accepted for call compatibility; not used here.
        global_audit: audit dict mutated in place. When None, a throwaway audit
            dict is used so the function can be called standalone.

    Returns:
        One of "HIT", "OUT", "HBP", "WALK", "K", "ABORT_MAX_PITCHES",
        "ABORT_TIMEOUT", "ABORT_NO_PROGRESS".

    Raises:
        ValueError: if the pitcher pack has no usable pitch-level DataFrame.
    """
    t0 = time.perf_counter()

    if global_audit is None:
        # Previously the None default crashed on the first audit write; use a
        # scratch audit with the same shape the helpers below index into.
        global_audit = {
            "lookup_hits": {"pitch": 0, "zone": 0, "outcome": 0},
            "lookup_rejects": {"pitch": 0, "zone": 0, "outcome": 0},
            "nb_uses": {"pitch": 0, "zone": 0, "outcome": 0},
            "bip_events": 0,
            "hits_from_bip": 0,
            "xba_samples": [],
            "foul_breaker_trips": 0,
            "aborts": {"max_pitches": 0, "timeout": 0, "no_progress": 0},
            "hand_resolution": defaultdict(list),
        }

    outcome_encoder      = hitter.outcome_encoder
    nb_outcome_model     = hitter.nb_outcome_model
    outcome_lookup_table = hitter.outcome_lookup_table
    outcome_class_labels = hitter.outcome_class_labels
    xba_lookup_table     = getattr(hitter, "xba_lookup_table", {}) or {}
    global_bip_xba       = float(getattr(hitter, "global_bip_xba", 0.300) or 0.300)

    # Names
    h_name = getattr(hitter, "full_lower", f"{getattr(hitter,'first_lower','?')} {getattr(hitter,'last_lower','?')}")
    p_name = getattr(pitcher, "full_lower", f"{getattr(pitcher,'first_lower','?')} {getattr(pitcher,'last_lower','?')}")

    # Source DF + handedness
    df = _get_nonempty_df(pitcher, "pitcher_data_arch", "pitcher_data")
    if df is None: raise ValueError("Pitcher has no pitcher_data_arch/pitcher_data.")

    hitter_hand = None
    stand_map = getattr(pitcher, "stand_by_batter_lower", None)
    if isinstance(stand_map, dict):
        hitter_hand = stand_map.get(h_name)
    if hitter_hand is None:
        # Fall back to scanning pitch data (also when the map lacks this hitter,
        # which previously passed None into stand_encoder). Literal match, not
        # regex, so punctuation in names is safe; first matching row wins.
        mask = df["batter_name"].str.contains(h_name, case=False, na=False, regex=False)
        hitter_hand = str(df.loc[mask, "stand"].iloc[0]) if mask.any() else "L"
    # Record stand resolution (for aggregate stats)
    global_audit["hand_resolution"][f"{h_name} vs {p_name}"].append(hitter_hand)

    stand_enc = hitter.stand_encoder.transform([hitter_hand])[0]
    arch_enc  = int(getattr(hitter, "arch_enc"))

    # State
    balls = strikes = 0
    pitch_num = 1
    foul_streak = 0
    force_model_outcome = False

    while True:
        # Stall guards: hard caps on pitch count and wall-clock time.
        if pitch_num > MAX_PITCHES_PER_PA:
            global_audit["aborts"]["max_pitches"] += 1
            return "ABORT_MAX_PITCHES"
        if (time.perf_counter() - t0) > MAX_SECONDS_PER_PA:
            global_audit["aborts"]["timeout"] += 1
            return "ABORT_TIMEOUT"

        count_enc = _encode_count(balls, strikes)

        # PITCH
        pitch_global_id = _hybrid_pitch_predict_audit(
            [[int(stand_enc), int(count_enc), int(arch_enc)]],
            nb_pitch_model, pitch_lookup_table, pitch_class_labels,
            audit=global_audit, force_model=False
        )[0]

        # ZONE
        zone_enc = _predict_zone_hybrid_audit(
            [[int(pitch_global_id), int(count_enc), int(stand_enc)]],
            nb_model=nb_zone_model, lookup_table=zone_lookup_table, class_labels=zone_class_labels,
            audit=global_audit, force_model=False
        )[0]

        # OUTCOME
        outcome_enc, src = _predict_outcome_audit(
            pitch_cluster_enc=pitch_global_id, zone_enc=zone_enc, count_enc=count_enc,
            nb_model=nb_outcome_model, lookup_table=outcome_lookup_table,
            class_labels=outcome_class_labels, audit=global_audit, force_model=force_model_outcome
        )
        try:
            raw_label = str(outcome_encoder.inverse_transform([outcome_enc])[0])
        except Exception:
            # Best-effort decode: fall back to the raw encoded value as text.
            raw_label = str(outcome_enc)
        simple = map_description_to_simple(raw_label)
        if simple == "unknown":
            simple = "strike"  # fail-closed

        # BIP terminal: resolve hit/out with a Bernoulli draw on xBA.
        if simple == "bip":
            xba = AtBatSim.predict_xba(pitch_global_id, zone_enc, count_enc, xba_lookup_table, global_fallback=global_bip_xba)
            xba = float(np.clip(xba, 0.0, 1.0))
            global_audit["xba_samples"].append(xba)
            global_audit["bip_events"] += 1
            if np.random.rand() < xba:
                global_audit["hits_from_bip"] += 1
                return "HIT"
            else:
                return "OUT"

        if simple == "hbp":
            return "HBP"

        # Count updates and foul breaker
        prev = (balls, strikes)
        if simple == "ball":
            balls += 1
        elif simple == "strike":
            strikes += 1
        elif simple == "foul":
            if strikes < 2:
                strikes += 1
            # Fouls that leave the count at two strikes extend the streak; after
            # FOUL_STREAK_CAP of them the outcome stage bypasses lookups.
            if strikes == 2:
                foul_streak += 1
                if foul_streak >= FOUL_STREAK_CAP and not force_model_outcome:
                    force_model_outcome = True
                    global_audit["foul_breaker_trips"] += 1
        else:
            strikes += 1  # force progress

        # Any non-foul pitch (or a sub-two-strike count) clears the breaker.
        if simple != "foul" or strikes < 2:
            foul_streak = 0
            force_model_outcome = False

        # Terminal counts
        if balls >= 4: return "WALK"
        if strikes >= 3: return "K"

        # Progress invariant
        progressed = (prev != (balls, strikes)) or (simple in ("bip","hbp")) or (simple == "foul" and prev[1] == 2)
        if not progressed:
            global_audit["aborts"]["no_progress"] += 1
            return "ABORT_NO_PROGRESS"

        pitch_num += 1

# ===== Build models from pitcher DF =====
def build_models_from_pitcher_df_AUDIT(pitcher_pack):
    """Fit the pitch-cluster and zone models (lookup tables + CategoricalNB) from a pitcher pack.

    Pitch model: features (stand_enc, count_enc, arch_enc) -> pitch_cluster_enc.
    Zone model: features (pitch_cluster_enc, count_enc, stand_enc) -> zone_enc,
    where zone_enc = zone - 1 over rows whose `zone` is in 1..14.

    Each lookup table maps a feature tuple (native ints) to a Counter of labels.

    Args:
        pitcher_pack: object exposing a `pitcher_data_arch` DataFrame.

    Returns:
        (nb_pitch, pitch_lookup, pitch_classes, nb_zone, zone_lookup, zone_classes)

    Raises:
        ValueError: if the DataFrame is missing/empty, lacks a required column
            (including `zone`), or has no rows with a valid zone.
    """
    df = getattr(pitcher_pack, "pitcher_data_arch", None)
    if df is None or len(df) == 0:
        raise ValueError("pitcher_data_arch missing from pack. Export packs with include_full_df=True.")
    # `zone` is needed for the zone model; validate it up front instead of
    # failing later with a bare KeyError.
    need_cols = ["stand_enc", "count_enc", "arch_enc", "pitch_cluster_enc", "zone"]
    missing = [c for c in need_cols if c not in df.columns]
    if missing:
        raise ValueError(f"pitcher_data_arch missing columns: {missing}")

    # Select zone rows first so an unusable zone column fails before any fitting.
    dfz = df[df["zone"].notna() & df["zone"].isin(range(1, 15))].copy()
    if dfz.empty:
        raise ValueError("pitcher_data_arch has no rows with a valid zone (1-14); cannot fit zone model.")

    # Pitch
    Xp = df[["stand_enc", "count_enc", "arch_enc"]].astype(int).values
    yp = df["pitch_cluster_enc"].astype(int).values
    pitch_lookup = defaultdict(Counter)
    for x, y in zip(Xp, yp):
        pitch_lookup[tuple(int(v) for v in x)][int(y)] += 1
    nb_pitch = CategoricalNB().fit(Xp, yp)

    # Zone
    dfz["zone_enc"] = dfz["zone"].astype(int) - 1
    Xz = dfz[["pitch_cluster_enc", "count_enc", "stand_enc"]].astype(int).values
    yz = dfz["zone_enc"].astype(int).values
    zone_lookup = defaultdict(Counter)
    for x, y in zip(Xz, yz):
        zone_lookup[tuple(int(v) for v in x)][int(y)] += 1
    nb_zone = CategoricalNB().fit(Xz, yz)

    return nb_pitch, pitch_lookup, nb_pitch.classes_, nb_zone, zone_lookup, nb_zone.classes_

# ===== Monte Carlo with PA + source auditing =====
# Fixed team-offense snapshot: wOBA keyed by team abbreviation.
_TEAM_WOBA = {
    "CHC": 0.333, "NYY": 0.337, "TOR": 0.328, "LAD": 0.334, "ARI": 0.329, "BOS": 0.328,
    "DET": 0.322, "NYM": 0.317, "MIL": 0.313, "SEA": 0.319, "PHI": 0.323, "HOU": 0.318,
    "STL": 0.312, "ATH": 0.323, "ATL": 0.311, "SDP": 0.307, "TBR": 0.316, "BAL": 0.314,
    "MIN": 0.312, "MIA": 0.309, "TEX": 0.298, "CIN": 0.313, "SFG": 0.302, "CLE": 0.296,
    "LAA": 0.311, "WSN": 0.305, "KCR": 0.298, "PIT": 0.285, "CHW": 0.293, "COL": 0.296,
}

# Team nickname -> key into _TEAM_WOBA. (Athletics/White Sox previously mapped
# to "OAK"/"CWS", which are absent from the table and crashed the lookup.)
_TEAM_TO_ABBR = {
    "Angels": "LAA", "Astros": "HOU", "Athletics": "ATH", "Blue Jays": "TOR", "Braves": "ATL", "Brewers": "MIL",
    "Cardinals": "STL", "Cubs": "CHC", "Diamondbacks": "ARI", "Dodgers": "LAD", "Giants": "SFG", "Guardians": "CLE",
    "Mariners": "SEA", "Marlins": "MIA", "Mets": "NYM", "Nationals": "WSN", "Orioles": "BAL", "Padres": "SDP",
    "Phillies": "PHI", "Pirates": "PIT", "Rangers": "TEX", "Rays": "TBR", "Reds": "CIN", "Red Sox": "BOS",
    "Rockies": "COL", "Royals": "KCR", "Tigers": "DET", "Twins": "MIN", "White Sox": "CHW", "Yankees": "NYY",
}


def _outs_from_ip(ip: float) -> int:
    """Convert baseball IP notation (e.g. 3.2 = 3 innings + 2 outs) to a count of outs."""
    whole, frac = divmod(round(ip * 10), 10)
    return whole * 3 + (2 if frac == 2 else 1 if frac == 1 else 0)


def _sample_bf_for_outs(rng, bf_per_out, outs: int) -> int:
    """Sample batters faced over `outs` outs by resampling per-out BF values.

    A trailing 0.5 draw is replaced by a uniformly chosen non-0.5 value (same
    distribution as the old rejection loop); if every value is 0.5 the draw is
    kept instead of looping forever.
    """
    if outs <= 0:
        return 0
    samples = rng.choice(bf_per_out, size=outs, replace=True)
    if samples[-1] == 0.5:
        alternatives = bf_per_out[bf_per_out != 0.5]
        if alternatives.size > 0:
            samples[-1] = rng.choice(alternatives)
    return int(max(0, round(samples.sum())))


def _count_pa_for_spot(first_spot: int, n_bf: int, spot: int) -> int:
    """Count how many of `n_bf` consecutive batters, starting at lineup slot `first_spot` (1-9), bat in slot `spot`."""
    return sum(1 for i in range(n_bf) if ((first_spot + i - 1) % 9 + 1) == spot)


def simulate_total_hits_AUDIT_V2(
    hitter, pitcher, num_trials,
    nb_pitch_model, pitch_lookup_table, pitch_class_labels,
    nb_zone_model,  zone_lookup_table,  zone_class_labels,
    *,
    print_every=False,
    seed=123,
    is_home=True
):
    """Monte Carlo of a hitter's total hits in one game, with aggregate auditing.

    Per trial:
      * starter IP ~ Normal(IPLinReg(team wOBA), ip_std), rounded to thirds and clipped to [0, 9];
      * starter BF ~ Poisson(poisson_model(ip));
      * each of the hitter's PAs vs. the starter is simulated pitch by pitch
        with `simulate_at_bat_between_AUDIT`;
      * bullpen and extra-inning PAs are binomial draws with the hitter's `xba`,
        with BF sampled from the pitcher pack's `bf_per_out`.

    Args:
        hitter, pitcher: packs (attribute access).
        num_trials: number of simulated games.
        nb_* / *_lookup_table / *_class_labels: pitch and zone models.
        print_every: print a per-trial trace.
        seed: seeds the local Generator AND the legacy global `np.random`
            stream used by the PA simulator, so runs are repeatable (wall-clock
            PA timeouts aside). None leaves both unseeded.
        is_home: whether the hitter's team is home (affects batting in the 9th
            and which extras pool is used). Defaults to the previous hard-coded True.

    Returns:
        (results, global_audit): per-trial hit totals and the aggregate audit dict.

    Raises:
        KeyError: if `hitter.team_name` has no wOBA entry.
    """
    rng = np.random.default_rng(seed)
    if isinstance(seed, (int, np.integer)):
        # PA-level draws use np.random.*; seed it too so `seed` covers the whole run.
        np.random.seed(int(seed) % (2 ** 32))
    results = []

    # Aggregate audit bins
    global_audit = {
        "lookup_hits": {"pitch":0,"zone":0,"outcome":0},
        "lookup_rejects": {"pitch":0,"zone":0,"outcome":0},
        "nb_uses": {"pitch":0,"zone":0,"outcome":0},
        "bip_events": 0,
        "hits_from_bip": 0,
        "xba_samples": [],
        "foul_breaker_trips": 0,
        "aborts": {"max_pitches":0,"timeout":0,"no_progress":0},
        "hand_resolution": defaultdict(list),
        "pa_tally": {"sp":0,"rp":0,"extras":0,"total":0},
        "bf_tally": {"sp":0,"rp":0,"extras":0,"total":0},
        "ip_sp_samples": [],
        "bf_sp_samples": [],
        "bf_rp_samples": [],
        "bf_extras_samples": [],
    }

    # Convenience pulls
    xba_default = float(getattr(hitter, "xba", 0.300) or 0.300)
    spot = int(getattr(hitter, "most_recent_spot", 3) or 3)

    hitter_abbr = _TEAM_TO_ABBR.get(hitter.team_name)
    if hitter_abbr not in _TEAM_WOBA:
        raise KeyError(f"No team wOBA for team_name={hitter.team_name!r} (abbr={hitter_abbr!r})")
    team_woba_val = float(_TEAM_WOBA[hitter_abbr])

    # Starter models & bullpen/extras distributions from pitcher pack
    IP_model = pitcher.IPLinReg
    BF_model = pitcher.poisson_model
    ip_sigma = float(pitcher.ip_std)

    raw_bf_per_out = getattr(pitcher, "bf_per_out", None)
    BF_PER_OUT = np.array([] if raw_bf_per_out is None else raw_bf_per_out, dtype=float)
    if BF_PER_OUT.size == 0: BF_PER_OUT = np.array([1.0], dtype=float)
    HOME_EXTRAS = getattr(pitcher, "home_IP_extras", [])
    AWAY_EXTRAS = getattr(pitcher, "away_IP_extras", [])
    # Missing/None -> 0.09 default; an explicit 0.0 is respected.
    raw_p_extras = getattr(pitcher, "prob_extra_innings", None)
    P_EXTRAS = 0.09 if raw_p_extras is None else float(raw_p_extras)

    # Loop-invariant quantities.
    exp_ip = float(IP_model.predict([[team_woba_val]])[0])
    hitter_win = float(getattr(hitter, "winning_pct_value", 0.5) or 0.5)
    pitcher_win = float(getattr(pitcher, "winning_pct_value", 0.5) or 0.5)
    cluster_encoder = getattr(hitter, "cluster_encoder", None)
    extras_pool = HOME_EXTRAS if (not is_home) else AWAY_EXTRAS

    for t in range(num_trials):
        if print_every: print(f"\n=== TRIAL {t+1}/{num_trials} ===")
        # ---- Starter IP ----
        sim_ip = round(rng.normal(exp_ip, ip_sigma) * 3) / 3
        sim_ip = float(np.clip(sim_ip, 0.0, 9.0))
        global_audit["ip_sp_samples"].append(sim_ip)
        if print_every: print(f"[IP] expected={exp_ip:.3f} sigma={ip_sigma:.3f} simulated={sim_ip:.3f}")

        # ---- Starter BF ----
        exp_bf = float(BF_model.predict(pd.DataFrame({"ip":[sim_ip]}))[0])
        sim_bf = max(0, int(rng.poisson(exp_bf)))
        global_audit["bf_sp_samples"].append(sim_bf)
        if print_every: print(f"[BF] expected={exp_bf:.3f} simulated={sim_bf}")

        # ---- PA vs SP (lineup starts at slot 1) ----
        full_cycles = sim_bf // 9
        remainder   = sim_bf % 9
        pa_vs_sp    = int(max(0, full_cycles + (1 if spot <= remainder else 0)))
        if print_every: print(f"[PA vs SP] spot={spot} full_cycles={full_cycles} remainder={remainder} -> pa_vs_sp={pa_vs_sp}")

        hits_vs_sp = 0
        for _ in range(pa_vs_sp):
            res = simulate_at_bat_between_AUDIT(
                hitter, pitcher,
                nb_pitch_model, pitch_lookup_table, pitch_class_labels,
                nb_zone_model,  zone_lookup_table,  zone_class_labels,
                cluster_encoder=cluster_encoder,
                MAX_PITCHES_PER_PA=30, MAX_SECONDS_PER_PA=3.0,
                print_every=False, global_audit=global_audit
            )
            if res == "HIT": hits_vs_sp += 1

        # ---- Bullpen IP & BF ----
        # If home, chance the offense doesn't bat in the 9th; otherwise they do
        if is_home:
            prob_not_hitting_9th = hitter_win / (hitter_win + pitcher_win + 1e-9)
            bat_9th = rng.random() > prob_not_hitting_9th
            relief_ip = (9 - sim_ip) if bat_9th else (8 - sim_ip)
        else:
            relief_ip = 9 - sim_ip

        relief_ip = max(0.0, float(relief_ip))
        # Convert thirds of an inning to baseball .1/.2 notation
        n = int(relief_ip); frac = relief_ip - n
        if 0.3 <= frac < 0.5: relief_ip = n + 0.1
        elif 0.5 <= frac < 0.7: relief_ip = n + 0.2
        else: relief_ip = float(n)
        if print_every: print(f"[Bullpen IP] relief_ip={relief_ip:.3f}")

        outs_req = _outs_from_ip(relief_ip)
        bp_bf = _sample_bf_for_outs(rng, BF_PER_OUT, outs_req)
        global_audit["bf_rp_samples"].append(bp_bf)
        if print_every: print(f"[Bullpen BF] outs_req={outs_req} sampled_bf={bp_bf}")

        next_spot = (sim_bf % 9) + 1
        pa_vs_rp = _count_pa_for_spot(next_spot, bp_bf, spot)
        if print_every: print(f"[PA vs RP] next_spot={next_spot} pa_vs_rp={pa_vs_rp}")

        hits_vs_rp = int(rng.binomial(n=pa_vs_rp, p=xba_default))
        if print_every: print(f"[Hits vs RP] xba={xba_default:.3f} hits={hits_vs_rp}")

        # ---- Extras ----
        hits_ex = 0; pa_vs_ex = 0; ex_bf = 0
        if rng.random() < P_EXTRAS:
            # len() check (not truthiness) so numpy/pandas pools work too.
            if extras_pool is not None and len(extras_pool) > 0:
                extra_ip = float(rng.choice(extras_pool))
                ex_outs = _outs_from_ip(extra_ip)
                if ex_outs > 0:
                    ex_bf = _sample_bf_for_outs(rng, BF_PER_OUT, ex_outs)
                    nxt = ((sim_bf + bp_bf) % 9) + 1
                    pa_vs_ex = _count_pa_for_spot(nxt, ex_bf, spot)
                    hits_ex = int(rng.binomial(pa_vs_ex, xba_default))
        global_audit["bf_extras_samples"].append(ex_bf)

        # ---- Totals + tallies ----
        total_pa = pa_vs_sp + pa_vs_rp + pa_vs_ex
        total_hits = hits_vs_sp + hits_vs_rp + hits_ex
        results.append(int(total_hits))

        global_audit["pa_tally"]["sp"] += pa_vs_sp
        global_audit["pa_tally"]["rp"] += pa_vs_rp
        global_audit["pa_tally"]["extras"] += pa_vs_ex
        global_audit["pa_tally"]["total"] += total_pa

        global_audit["bf_tally"]["sp"] += sim_bf
        global_audit["bf_tally"]["rp"] += bp_bf
        global_audit["bf_tally"]["extras"] += ex_bf
        global_audit["bf_tally"]["total"] += (sim_bf + bp_bf + ex_bf)

        if print_every:
            print(f"[SUMMARY] PA: SP={pa_vs_sp} RP={pa_vs_rp} EX={pa_vs_ex} | BF: SP={sim_bf} RP={bp_bf} EX={ex_bf} | Hits: {total_hits}")

    return results, global_audit

# ===== Pack loader + aggregate report =====
def load_pack(path: str):
    """Load a joblib pack and expose its fields as attributes.

    Dict packs are wrapped in a `types.SimpleNamespace`; any other loaded
    object is returned unchanged (previously non-dict packs raised TypeError).

    NOTE(security): `joblib.load` unpickles the file, which can execute
    arbitrary code. Only load packs from trusted sources.
    """
    pack = joblib.load(path)
    return types.SimpleNamespace(**pack) if isinstance(pack, dict) else pack

def report(global_audit, results):
    """Print an aggregate summary of a Monte Carlo run to stdout.

    Args:
        global_audit: audit dict produced by the simulator (tallies, lookup
            counters, abort counts, sample lists).
        results: per-trial hit totals.

    Rates divide by max(1, denominator), so zero PA or zero BIP prints 0.000
    rather than raising.
    """
    trial_count = len(results)
    avg_hits = np.mean(results)
    share_hitless = np.mean([hits == 0 for hits in results])
    share_multi = np.mean([hits >= 2 for hits in results])

    pa_total = global_audit["pa_tally"]["total"]
    bip_count = global_audit["bip_events"]
    bip_hits = global_audit["hits_from_bip"]
    xba_values = global_audit["xba_samples"]
    per_pa_bip = bip_count / max(1, pa_total)
    per_pa_hits = sum(results) / max(1, pa_total)
    per_bip_hits = bip_hits / max(1, bip_count)

    print("\n===== AUDIT REPORT =====")
    print(f"Trials: {trial_count}")
    print(f"Hits — mean={avg_hits:.3f}  P(0)={share_hitless:.3f}  P(>=2)={share_multi:.3f}")

    pa = global_audit["pa_tally"]
    print(f"PA — total={pa_total}  (SP={pa['sp']}, RP={pa['rp']}, EX={pa['extras']})")
    bf = global_audit["bf_tally"]
    print(f"BF — total={bf['total']}  (SP={bf['sp']}, RP={bf['rp']}, EX={bf['extras']})")
    print(f"BIP — count={bip_count}  rate/PA={per_pa_bip:.3f}  E[Hit|BIP]={per_bip_hits:.3f}  Hit/PA={per_pa_hits:.3f}")
    if xba_values:
        print(f"xBA — mean={np.mean(xba_values):.3f}  min={np.min(xba_values):.3f}  max={np.max(xba_values):.3f}")

    hits_lk = global_audit["lookup_hits"]
    nb = global_audit["nb_uses"]
    print(f"Lookup usage — pitch: {hits_lk['pitch']} hits / {nb['pitch']} NB; "
          f"zone: {hits_lk['zone']} / {nb['zone']}; "
          f"outcome: {hits_lk['outcome']} / {nb['outcome']}")
    rejects = global_audit["lookup_rejects"]
    print(f"Lookup rejects — pitch:{rejects['pitch']}  zone:{rejects['zone']}  outcome:{rejects['outcome']}")
    print(f"Foul breaker trips: {global_audit['foul_breaker_trips']}")
    aborts = global_audit["aborts"]
    print(f"Aborts — max_pitches:{aborts['max_pitches']} timeout:{aborts['timeout']} no_progress:{aborts['no_progress']}")

    # Distribution summaries; empty sample lists are skipped.
    for sample_key, label in (
        ("ip_sp_samples", "SP IP"),
        ("bf_sp_samples", "SP BF"),
        ("bf_rp_samples", "RP BF"),
        ("bf_extras_samples", "EX BF"),
    ):
        samples = global_audit[sample_key]
        if samples:
            print(f"{label} — mean={np.mean(samples):.2f}  sd={np.std(samples):.2f}")

# ===== Example runner =====
if __name__ == "__main__":
    # Example matchup: the Freeman hitter pack vs. the Soriano pitcher pack.
    # Point these paths at other packs to audit a different matchup.
    hitter_pack = load_pack("packs/hitter_freeman.joblib")
    pitcher_pack = load_pack("packs/pitcher_soriano.joblib")

    # Pitch-cluster and zone models are fit from the pitcher's own pitch data.
    (
        nb_pitch_model, pitch_lookup_table, pitch_class_labels,
        nb_zone_model, zone_lookup_table, zone_class_labels,
    ) = build_models_from_pitcher_df_AUDIT(pitcher_pack)

    results, audit = simulate_total_hits_AUDIT_V2(
        hitter=hitter_pack,
        pitcher=pitcher_pack,
        num_trials=200,
        nb_pitch_model=nb_pitch_model,
        pitch_lookup_table=pitch_lookup_table,
        pitch_class_labels=pitch_class_labels,
        nb_zone_model=nb_zone_model,
        zone_lookup_table=zone_lookup_table,
        zone_class_labels=zone_class_labels,
        print_every=False,
        seed=7,
    )
    report(audit, results)