# In [2]:
# Input:
#   PokerData_with_range_features.csv
#
# Output:
#   Poker_semantic_context_models_FINAL.csv
# ============================================================

import time
import math
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Optional, List, Dict

from sklearn.model_selection import train_test_split


# ============================================================
# 0) CONFIG
# ============================================================
# Input CSV read by the load step, and the output CSV named in the file header.
# NOTE(review): no write to OUT_PATH is visible in this part of the file.
DATA_PATH = "PokerData_with_range_features.csv"
OUT_PATH  = "Poker_semantic_context_models_FINAL.csv"

# Base seed reused for the data splits, the tuning subsample, the forests and
# the trees below; it also seeds NumPy's global RNG.
BASE_RANDOM_STATE = 42
np.random.seed(BASE_RANDOM_STATE)

# Filter rule: a row is kept when this column equals -1 or is >= the threshold.
VIL_CHECKS_COL = "% Villain Checks"
VIL_CHECKS_KEEP_THRESHOLD = 85.0

# Targets (solver strategy). The order of Y_COLS fixes the column order of
# every 3-wide target/prediction array: [check, bet 1/3, bet 2/3].
Y_COL_CHECK = "We Check Back"
Y_COL_BET13 = "We C-Bet 1/3"
Y_COL_BET23 = "We C-Bet 2/3"
Y_COLS = [Y_COL_CHECK, Y_COL_BET13, Y_COL_BET23]

# Context columns: SPR is used as a numeric feature, CAT_COLS are one-hot
# encoded when present in the data.
SPR_COL = "SPR"
POS_COL = "Position"
FULL_HOUSES_COL = "Full Houses"
FLUSHES_COL = "Flushes"
STRAIGHTS_COL = "Straights"
CAT_COLS = [POS_COL, FULL_HOUSES_COL, FLUSHES_COL, STRAIGHTS_COL]

# Asymmetric loss settings, passed as **LOSS to the loss functions and models.
# Each 3-tuple is per action in Y_COLS order; "over" weights apply where the
# prediction exceeds the target, "under" weights where it falls short.
LOSS = dict(
    w_over=(0.8, 1.2, 1.8),     # penalize over-betting more than over-check
    w_under=(1.0, 1.0, 1.0),
    over_power=2.0,
    under_power=2.0
)

# Speed knobs
GLOBAL_THRESHOLDS_PER_FEATURE = 16   # quantile cut points requested per feature
TUNING_SUBSAMPLE_N = 1200            # max training rows used while tuning
MAX_CONFIGS = 8                      # how many candidate configs are evaluated
N_SEEDS_FOR_STABILITY = 3            # seeds per config during tuning

# Teacher forest: tree count of the final teacher (overrides the tuned
# config's n_estimators).
FINAL_TEACHER_N_TREES = 240

# Final presentation trees (fixed)
FINAL_DEPTH = 4
FINAL_MIN_LEAF = 60

# Depth sweep grid (for tables)
DEPTH_GRID = list(range(2, 11))          # 2..10
MIN_LEAF_GRID = [60, 90]                # include 60 (final choice) and 90 as a simpler alt
SEED = BASE_RANDOM_STATE                # seed for the student trees in the sweep

# Misfire thresholds (keep consistent across reports).
# A "big misfire" is a row where the predicted frequency of a bet size is
# >= *_PRED while the true frequency is <= *_TRUE.
BIG_MISFIRE_B23_PRED = 0.50
BIG_MISFIRE_B23_TRUE = 0.10
BIG_MISFIRE_B13_PRED = 0.60
BIG_MISFIRE_B13_TRUE = 0.15


# ============================================================
# 1) Utilities
# ============================================================
def parse_percent_cell(x):
    """Convert one raw cell to a float, tolerating a trailing percent sign.

    Args:
        x: A scalar cell value (number, string, or missing value).

    Returns:
        ``float(x)`` for numeric input; for strings, the parsed number with
        surrounding whitespace and at most one trailing ``%`` removed. The
        value is NOT rescaled (``"85%"`` -> ``85.0``). Missing, empty, or
        unparseable cells yield ``np.nan``.
    """
    if pd.isna(x):
        return np.nan
    if isinstance(x, (int, float, np.integer, np.floating)):
        return float(x)

    s = str(x).strip()
    if s.endswith("%"):
        s = s[:-1]
    try:
        return float(s)
    except ValueError:
        # float() signals malformed (or empty) text with ValueError; catching
        # only that keeps KeyboardInterrupt/SystemExit from being swallowed.
        return np.nan


def normalize_actions(y: np.ndarray) -> np.ndarray:
    """Turn each row of a 2-D array into non-negative weights summing to 1.

    Negative entries are clipped to zero first. A row whose clipped sum is
    zero is divided by 1 instead, so it comes back as all zeros.
    """
    clipped = np.clip(np.asarray(y, dtype=float), 0.0, None)
    row_totals = clipped.sum(axis=1, keepdims=True)
    # Swap zero totals for 1 so all-zero rows do not trigger a 0/0 division.
    divisors = np.where(row_totals == 0, 1.0, row_totals)
    return clipped / divisors


def require_cols(df: pd.DataFrame, cols: List[str], where: str = ""):
    """Raise ValueError if any of ``cols`` is not a column of ``df``.

    ``where`` is an optional label appended to the error message to say which
    step needed the columns. Returns None when nothing is missing.
    """
    absent = [name for name in cols if name not in df.columns]
    if not absent:
        return
    context = f" (needed for {where})" if where else ""
    raise ValueError(f"Missing required columns {absent}" + context)


def safe_col(df: pd.DataFrame, name: str) -> pd.Series:
    """Return column ``name`` as numbers, or zeros if the column is absent.

    Values that cannot be parsed as numbers, and missing values, become 0.0.
    When the column does not exist the result is a Series of 0.0 sharing
    ``df``'s index.
    """
    if name not in df.columns:
        return pd.Series(0.0, index=df.index)
    as_numbers = pd.to_numeric(df[name], errors="coerce")
    return as_numbers.fillna(0.0)


def mean_vec(arr3):
    """Row-normalise ``arr3`` with normalize_actions, then average over rows."""
    per_row = normalize_actions(arr3)
    return np.mean(per_row, axis=0)


# ============================================================
# 2) Asymmetric loss
# ============================================================
def asym_loss_from_err(err: np.ndarray,
                       w_over=(0.8, 1.2, 1.8),
                       w_under=(1.0, 1.0, 1.0),
                       over_power=2.0,
                       under_power=2.0) -> float:
    """Mean asymmetric penalty of a (rows x 3) error array (pred minus true).

    Positive errors are raised to ``over_power`` and scaled by the per-column
    ``w_over``; negative errors use their magnitude, ``under_power`` and
    ``w_under``. The result is the mean over every element.
    """
    over_weights = np.array(w_over, dtype=float).reshape(1, 3)
    under_weights = np.array(w_under, dtype=float).reshape(1, 3)

    # Split the signed error into its two one-sided magnitudes.
    excess = np.maximum(err, 0.0)
    shortfall = np.maximum(-err, 0.0)

    over_term = over_weights * (excess ** over_power)
    under_term = under_weights * (shortfall ** under_power)
    return float(np.mean(over_term + under_term))


def asymmetric_loss(y_true: np.ndarray,
                    y_pred: np.ndarray,
                    w_over=(0.8, 1.2, 1.8),
                    w_under=(1.0, 1.0, 1.0),
                    over_power=2.0,
                    under_power=2.0) -> float:
    """Asymmetric loss between target and predicted action mixes.

    Both arrays are row-normalised first; the penalty is then computed on
    ``pred - true`` by asym_loss_from_err with the given weights and powers.
    """
    target = normalize_actions(y_true)
    estimate = normalize_actions(y_pred)
    residual = estimate - target
    return asym_loss_from_err(
        residual,
        w_over=w_over,
        w_under=w_under,
        over_power=over_power,
        under_power=under_power,
    )


def node_loss_constant_pred(y_true: np.ndarray,
                            pred_vec: np.ndarray,
                            w_over, w_under,
                            over_power, under_power) -> float:
    """Asymmetric loss of predicting one constant 3-vector for every row.

    ``pred_vec`` (3 values) and the rows of ``y_true`` are normalised before
    the constant prediction is broadcast against the targets.
    """
    targets = normalize_actions(y_true)
    # Kept as shape (1, 3) so it broadcasts across all target rows.
    constant = normalize_actions(pred_vec.reshape(1, 3))
    return asym_loss_from_err(constant - targets, w_over, w_under, over_power, under_power)


# ============================================================
# 3) Semantic features (compressed for interpretability)
# ============================================================
def add_semantic_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with aggregated hero/villain range features.

    Raw columns are named ``HeroMade_<hand>``, ``VillainMade_<hand>``,
    ``HeroDraw_<draw>`` and ``VillainDraw_<draw>``; each is read through
    safe_col, so absent or non-numeric inputs count as 0. Added columns:

    * ``HERO_/VIL_`` NUTS, STRONG, TP_PLUS, AIR, MEDIUM: sums of made-hand
      columns for each side.
    * ``*_ADV``: hero value minus villain value (NUTS, STRONG, TP_PLUS, AIR,
      DRAW, COMBO_DRAW, POLAR).
    * ``HERO_/VIL_POLAR``: (NUTS + combo/flush/oesd draws) minus MEDIUM.
    """
    out = df.copy()

    def summed(raw_side, kind, names, start=None):
        # Add the "<raw_side><kind>_<name>" columns left to right, optionally
        # continuing from an existing running total.
        total = start
        for nm in names:
            col = safe_col(out, f"{raw_side}{kind}_{nm}")
            total = col if total is None else total + col
        return total

    # (output prefix, raw column prefix) for the two players.
    sides = (("HERO", "Hero"), ("VIL", "Villain"))

    nut_hands = ("straight_flush", "four_of_a_kind", "full_house")
    strong_hands = nut_hands + ("flush", "straight", "set", "trips", "two_pair")
    tp_plus_hands = ("top_pair", "overpair")
    air_hands = ("no_made_hand", "ace_high")
    medium_hands = ("top_pair", "second_pair", "third_pair", "underpair")
    all_draws = ("flush_draw", "oesd", "gutshot", "combo_draw")
    polar_draws = ("combo_draw", "flush_draw", "oesd")

    # Made-hand buckets: hero total, villain total, then hero-minus-villain.
    made_groups = (
        ("NUTS", nut_hands),
        ("STRONG", strong_hands),
        ("TP_PLUS", tp_plus_hands),
        ("AIR", air_hands),
    )
    for group, hands in made_groups:
        for tag, raw in sides:
            out[f"{tag}_{group}"] = summed(raw, "Made", hands)
        out[f"{group}_ADV"] = out[f"HERO_{group}"] - out[f"VIL_{group}"]

    # Draw advantages (only the differences are stored).
    out["DRAW_ADV"] = summed("Hero", "Draw", all_draws) - summed("Villain", "Draw", all_draws)
    out["COMBO_DRAW_ADV"] = (
        safe_col(out, "HeroDraw_combo_draw") - safe_col(out, "VillainDraw_combo_draw")
    )

    for tag, raw in sides:
        out[f"{tag}_MEDIUM"] = summed(raw, "Made", medium_hands)

    # Polarisation: nuts plus strong draws, minus medium-strength pairs.
    for tag, raw in sides:
        top_end = summed(raw, "Draw", polar_draws, start=out[f"{tag}_NUTS"])
        out[f"{tag}_POLAR"] = top_end - out[f"{tag}_MEDIUM"]

    out["POLAR_ADV"] = out["HERO_POLAR"] - out["VIL_POLAR"]
    return out


# ============================================================
# 4) Precompute thresholds ONCE
# ============================================================
def build_global_thresholds(X: np.ndarray, n_thresholds: int) -> List[np.ndarray]:
    """Precompute candidate split thresholds for every feature column.

    Args:
        X: 2-D feature matrix (rows x features).
        n_thresholds: Number of evenly spaced quantile levels between 0.05
            and 0.95 evaluated per feature.

    Returns:
        One float array per column holding the sorted, de-duplicated quantile
        values. Columns whose values are all equal get an empty array, and so
        does every column when ``X`` has no rows.
    """
    n_rows, n_features = X.shape
    if n_rows == 0:
        # Nothing to take quantiles of (and col[0] below would raise).
        return [np.array([], dtype=float) for _ in range(n_features)]

    qs = np.linspace(0.05, 0.95, n_thresholds)
    thresholds = []
    for j in range(n_features):
        col = X[:, j]
        if np.all(col == col[0]):
            # A constant column can never be split.
            thresholds.append(np.array([], dtype=float))
            continue
        # np.unique sorts and drops repeats, so heavily tied columns end up
        # with fewer than n_thresholds candidates (never more).
        cuts = np.unique(np.quantile(col, qs))
        thresholds.append(cuts.astype(float))
    return thresholds


# ============================================================
# 5) Fast Tree + Forest
# ============================================================
@dataclass
class TreeNode:
    """One node of an AsymmetricLossTreeFast: either a leaf or a binary split."""
    # True for a terminal node; split fields below stay None in that case.
    is_leaf: bool
    # Normalised mean target of the training rows that reached this node
    # (stored on internal nodes as well as leaves).
    pred: np.ndarray
    # Number of training rows that reached this node.
    n: int
    # Column index tested at this node (internal nodes only).
    feat_idx: Optional[int] = None
    # Rows with feature value <= thresh go left, all others go right.
    thresh: Optional[float] = None
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


class AsymmetricLossTreeFast:
    """Binary tree predicting a 3-way action mix, grown on the asymmetric loss.

    Splits are searched only over precomputed per-feature thresholds
    (``thresholds_by_feature``), over a random subset of features at every
    node. Each node predicts the normalised mean of its training targets.

    Args:
        max_depth: Maximum number of splits on any root-to-leaf path.
        min_leaf: Minimum training rows required in each child of a split.
        max_features: Features examined per node: an int count, a float
            fraction of all features, or ``"sqrt"``. The resulting count is
            clamped to ``[1, n_features]``.
        thresholds_by_feature: Sequence with one array of candidate
            thresholds per feature column (required before ``fit``).
        w_over, w_under, over_power, under_power: Asymmetric-loss settings,
            forwarded to node_loss_constant_pred.
        min_gain: Smallest loss reduction that justifies a split.
        random_state: Seed for the per-node feature sampling.
    """

    def __init__(self,
                 max_depth=5,
                 min_leaf=60,
                 max_features=0.7,  # int, fraction, or "sqrt"
                 thresholds_by_feature=None,
                 w_over=(0.8, 1.2, 1.8),
                 w_under=(1.0, 1.0, 1.0),
                 over_power=2.0,
                 under_power=2.0,
                 min_gain=1e-6,
                 random_state=42):
        self.max_depth = int(max_depth)
        self.min_leaf = int(min_leaf)
        self.max_features = max_features
        self.thresholds_by_feature = thresholds_by_feature
        self.w_over = tuple(w_over)
        self.w_under = tuple(w_under)
        self.over_power = float(over_power)
        self.under_power = float(under_power)
        self.min_gain = float(min_gain)
        self.random_state = int(random_state)
        self.root = None

    def _node_pred_mean(self, y: np.ndarray) -> np.ndarray:
        """Return the normalised column-wise mean of the targets ``y``."""
        p = np.mean(y, axis=0)
        return normalize_actions(p.reshape(1, 3)).ravel()

    def _choose_feature_subset(self, d: int, rng: np.random.RandomState) -> np.ndarray:
        """Sample, without replacement, the feature indices to try at a node."""
        if isinstance(self.max_features, str) and self.max_features == "sqrt":
            k = int(math.sqrt(d))
        elif isinstance(self.max_features, float):
            k = int(round(self.max_features * d))
        else:
            k = int(self.max_features)
        # Clamp in every branch: sampling without replacement raises when
        # k > d (e.g. a fraction above 1.0).
        k = max(1, min(k, d))
        return rng.choice(d, size=k, replace=False)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Grow the tree on ``X`` (rows x features) and ``y`` (rows x 3).

        Returns:
            self.

        Raises:
            ValueError: If ``thresholds_by_feature`` is missing or has fewer
                entries than ``X`` has columns.
        """
        if self.thresholds_by_feature is None:
            raise ValueError("thresholds_by_feature required")
        n_features = X.shape[1]
        if len(self.thresholds_by_feature) < n_features:
            # Fail up front instead of with an IndexError deep inside _build.
            raise ValueError(
                f"thresholds_by_feature has {len(self.thresholds_by_feature)} entries "
                f"but X has {n_features} features"
            )
        rng = np.random.RandomState(self.random_state)
        self.root = self._build(X, y, depth=0, rng=rng)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the normalised leaf prediction (rows x 3) for each row of ``X``."""
        out = np.zeros((X.shape[0], 3), dtype=float)
        self._route(self.root, X, np.arange(X.shape[0]), out)
        return normalize_actions(out)

    def _route(self, node, X: np.ndarray, rows: np.ndarray, out: np.ndarray) -> None:
        """Write leaf predictions into ``out`` for the row indices ``rows``.

        Rows are pushed down the tree as index arrays, so the Python-level
        work is proportional to the number of nodes rather than rows.
        """
        if rows.size == 0:
            return
        if node.is_leaf:
            out[rows] = node.pred
            return
        go_left = X[rows, node.feat_idx] <= node.thresh
        self._route(node.left, X, rows[go_left], out)
        self._route(node.right, X, rows[~go_left], out)

    def _build(self, X: np.ndarray, y: np.ndarray, depth: int, rng: np.random.RandomState) -> "TreeNode":
        """Recursively build the subtree for the rows in ``X`` / ``y``."""
        n, d = X.shape
        pred = self._node_pred_mean(y)

        # Stop at the depth limit or when two children of min_leaf rows
        # cannot both be formed.
        if depth >= self.max_depth or n < 2 * self.min_leaf:
            return TreeNode(True, pred, n)

        parent_loss = node_loss_constant_pred(y, pred, self.w_over, self.w_under, self.over_power, self.under_power)

        feat_candidates = self._choose_feature_subset(d, rng)
        best_gain, best = 0.0, None

        for j in feat_candidates:
            xj = X[:, j]
            if np.all(xj == xj[0]):
                continue  # constant within this node: nothing to split on
            ths = self.thresholds_by_feature[j]
            if ths is None or len(ths) == 0:
                continue

            for t in ths:
                left_mask = (xj <= t)
                nl = int(left_mask.sum())
                nr = n - nl
                if nl < self.min_leaf or nr < self.min_leaf:
                    continue

                yl, yr = y[left_mask], y[~left_mask]
                pl, pr = self._node_pred_mean(yl), self._node_pred_mean(yr)

                loss_l = node_loss_constant_pred(yl, pl, self.w_over, self.w_under, self.over_power, self.under_power)
                loss_r = node_loss_constant_pred(yr, pr, self.w_over, self.w_under, self.over_power, self.under_power)

                # Size-weighted child loss versus the unsplit parent loss.
                child_loss = (nl / n) * loss_l + (nr / n) * loss_r
                gain = parent_loss - child_loss

                if gain > best_gain:
                    best_gain = gain
                    best = (j, float(t), left_mask)

        if best is None or best_gain < self.min_gain:
            return TreeNode(True, pred, n)

        j, t, left_mask = best
        left = self._build(X[left_mask], y[left_mask], depth + 1, rng)
        right = self._build(X[~left_mask], y[~left_mask], depth + 1, rng)
        return TreeNode(False, pred, n, feat_idx=j, thresh=t, left=left, right=right)


class AsymmetricLossRandomForestFast:
    """Averaged ensemble of AsymmetricLossTreeFast trees.

    Args:
        n_estimators: Number of trees.
        max_depth, min_leaf, max_features, thresholds_by_feature: Passed to
            every tree unchanged.
        bootstrap: If True each tree is fit on ``n`` rows drawn with
            replacement; otherwise every tree sees all rows in order.
        w_over, w_under, over_power, under_power: Asymmetric-loss settings
            passed to every tree.
        random_state: Seed for the bootstrap draws; tree ``b`` is seeded with
            ``random_state + 1000 + b``.

    After ``fit``, ``trees`` holds the fitted trees and ``boot_indices`` the
    row indices each one was trained on.
    """

    def __init__(self,
                 n_estimators=200,
                 max_depth=5,
                 min_leaf=60,
                 max_features=0.7,
                 thresholds_by_feature=None,
                 bootstrap=True,
                 w_over=(0.8, 1.2, 1.8),
                 w_under=(1.0, 1.0, 1.0),
                 over_power=2.0,
                 under_power=2.0,
                 random_state=42):
        self.n_estimators = int(n_estimators)
        self.max_depth = int(max_depth)
        self.min_leaf = int(min_leaf)
        self.max_features = max_features
        self.thresholds_by_feature = thresholds_by_feature
        self.bootstrap = bool(bootstrap)
        self.w_over = tuple(w_over)
        self.w_under = tuple(w_under)
        self.over_power = float(over_power)
        self.under_power = float(under_power)
        self.random_state = int(random_state)
        self.trees = []
        self.boot_indices = []

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Fit ``n_estimators`` trees on ``X`` (rows x features), ``y`` (rows x 3).

        Any previously fitted trees are discarded.

        Returns:
            self.

        Raises:
            ValueError: If ``thresholds_by_feature`` was not provided.
        """
        if self.thresholds_by_feature is None:
            raise ValueError("thresholds_by_feature required")

        n = X.shape[0]
        rng = np.random.RandomState(self.random_state)
        self.trees, self.boot_indices = [], []

        for b in range(self.n_estimators):
            idx = rng.randint(0, n, size=n) if self.bootstrap else np.arange(n)
            self.boot_indices.append(idx)

            tree = AsymmetricLossTreeFast(
                max_depth=self.max_depth,
                min_leaf=self.min_leaf,
                max_features=self.max_features,
                thresholds_by_feature=self.thresholds_by_feature,
                w_over=self.w_over,
                w_under=self.w_under,
                over_power=self.over_power,
                under_power=self.under_power,
                random_state=self.random_state + 1000 + b
            )
            tree.fit(X[idx], y[idx])
            self.trees.append(tree)

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the normalised average of all tree predictions (rows x 3).

        With no fitted trees the result is all zeros.
        """
        preds = np.zeros((X.shape[0], 3), dtype=float)
        for t in self.trees:
            preds += t.predict(X)
        preds /= max(1, len(self.trees))
        return normalize_actions(preds)

    def oob_predict(self, X: np.ndarray) -> np.ndarray:
        """Out-of-bag predictions for the matrix the forest was fitted on.

        Each row is predicted only by trees whose bootstrap sample did not
        contain it. Rows that were in-bag for every tree (always the case
        when ``bootstrap=False``) fall back to the ordinary forest
        prediction, which is in-sample for those rows.

        ``X`` must be the same matrix, in the same row order, that was passed
        to ``fit``: the stored bootstrap indices are interpreted as row
        positions in ``X``. This is not checked.
        """
        n = X.shape[0]
        acc = np.zeros((n, 3), dtype=float)
        cnt = np.zeros(n, dtype=int)

        for tree, idx in zip(self.trees, self.boot_indices):
            inbag = np.zeros(n, dtype=bool)
            inbag[idx] = True
            oob_mask = ~inbag
            if not np.any(oob_mask):
                continue
            acc[oob_mask] += tree.predict(X[oob_mask])
            cnt[oob_mask] += 1

        ok = cnt > 0
        out = np.zeros((n, 3), dtype=float)
        out[ok] = acc[ok] / cnt[ok].reshape(-1, 1)
        if not np.all(ok):
            # Only rows without any OOB tree need the full-forest fallback;
            # predictions are row-wise, so predicting just those rows gives
            # the same values as predicting everything and then selecting.
            out[~ok] = self.predict(X[~ok])
        return normalize_actions(out)


# ============================================================
# 6) Interpretability helpers
# ============================================================
ACTION_NAMES = ["CHECK", "BET13", "BET23"]

def action_label(p: np.ndarray) -> str:
    """Describe an action mix by its most frequent action.

    Returns "Mostly <ACTION>" when the top frequency is at least 0.60,
    "Lean <ACTION>" when it is at least 0.45, and "Mix" otherwise.
    """
    probs = np.asarray(p, dtype=float).ravel()
    best = int(np.argmax(probs))
    peak = float(np.max(probs))
    if peak >= 0.60:
        strength = "Mostly"
    elif peak >= 0.45:
        strength = "Lean"
    else:
        return "Mix"
    return f"{strength} {ACTION_NAMES[best]}"

def leaf_summary(pred: np.ndarray, baseline: np.ndarray) -> str:
    """One-line text summary of a leaf prediction relative to a baseline mix.

    Combines the action_label of ``pred``, its highest frequency, and the
    signed per-action differences ``pred - baseline``.
    """
    leaf = pred.ravel()
    ref = baseline.ravel()
    delta = leaf - ref
    headline = f"{action_label(leaf)} (conf={float(np.max(leaf)):.2f}) | "
    shifts = f"Δvs baseline: chk {delta[0]:+.2f}, b13 {delta[1]:+.2f}, b23 {delta[2]:+.2f}"
    return headline + shifts

def print_tree_rules_pretty(node: TreeNode, feature_names: List[str], baseline: np.ndarray, indent=""):
    """Print the subtree under ``node`` as nested IF / ELSE rules.

    Each level is indented by two more spaces. Leaves show their row count,
    their prediction and a leaf_summary against ``baseline``.
    """
    if not node.is_leaf:
        fname = feature_names[node.feat_idx]
        deeper = indent + "  "
        print(f"{indent}IF {fname} <= {node.thresh:.4f}:")
        print_tree_rules_pretty(node.left, feature_names, baseline, deeper)
        print(f"{indent}ELSE:  # {fname} > {node.thresh:.4f}")
        print_tree_rules_pretty(node.right, feature_names, baseline, deeper)
        return

    probs = node.pred
    summary = leaf_summary(probs, baseline)
    print(f"{indent}LEAF n={node.n} pred=[{probs[0]:.2f}, {probs[1]:.2f}, {probs[2]:.2f}] :: {summary}")

def count_leaves(node: TreeNode) -> int:
    """Count the leaf nodes in the subtree rooted at ``node``."""
    total = 0
    pending = [node]
    # Depth-first walk with an explicit stack instead of recursion.
    while pending:
        current = pending.pop()
        if current.is_leaf:
            total += 1
        else:
            pending.extend((current.left, current.right))
    return total


# ============================================================
# 7) Mistake analysis (presentation-friendly metrics)
# ============================================================
def mistake_report(y_true: np.ndarray,
                   y_pred: np.ndarray,
                   name: str,
                   big_misfire_b23_pred=0.50,
                   big_misfire_b23_true=0.10,
                   big_misfire_b13_pred=0.60,
                   big_misfire_b13_true=0.15) -> pd.DataFrame:
    """Build a one-row DataFrame of error metrics for a model.

    Both inputs are row-normalised first. The row contains the asymmetric
    loss (using the module-level LOSS settings), the mean over-prediction and
    the mean absolute error for each action, and the "big misfire" rates: the
    share of rows where a bet size is predicted at or above its ``*_pred``
    threshold while its true frequency is at or below ``*_true``.
    """
    truth = normalize_actions(y_true)
    guess = normalize_actions(y_pred)
    diff = guess - truth

    # Column index -> short tag used in the metric names.
    tags = ("chk", "b13", "b23")

    row = {
        "model": name,
        "asym_loss": asymmetric_loss(truth, guess, **LOSS),
    }
    for k, tag in enumerate(tags):
        row[f"mean_over_{tag}"] = float(np.mean(np.maximum(diff[:, k], 0.0)))

    hit_b23 = (guess[:, 2] >= big_misfire_b23_pred) & (truth[:, 2] <= big_misfire_b23_true)
    hit_b13 = (guess[:, 1] >= big_misfire_b13_pred) & (truth[:, 1] <= big_misfire_b13_true)
    row["big_misfire_b23_rate"] = float(np.mean(hit_b23))
    row["big_misfire_b13_rate"] = float(np.mean(hit_b13))

    for k, tag in enumerate(tags):
        row[f"MAE_{tag}"] = float(np.mean(np.abs(diff[:, k])))

    return pd.DataFrame([row])


# ============================================================
# 8) Load + filter + feature build
# ============================================================
# Load the raw table, keep only rows passing the villain-check filter, build
# the normalised 3-column target matrix `y` and the numeric feature matrix `X`.
print("=== LOAD DATA ===")
df0 = pd.read_csv(DATA_PATH)
print("Rows / Cols:", df0.shape)

require_cols(df0, [VIL_CHECKS_COL] + Y_COLS, where="filter/targets")
# Parse the filter column to floats ("85%" -> 85.0, unparseable -> NaN).
df0[VIL_CHECKS_COL] = df0[VIL_CHECKS_COL].apply(parse_percent_cell)

print(f"\n=== FILTER: keep ({VIL_CHECKS_COL} == -1) OR >= {VIL_CHECKS_KEEP_THRESHOLD} ===")
before = df0.shape[0]
# Rows whose value parsed to NaN fail both comparisons and are dropped.
# NOTE(review): -1 is presumably a "not applicable" sentinel - confirm.
mask_keep = (df0[VIL_CHECKS_COL] == -1) | (df0[VIL_CHECKS_COL] >= VIL_CHECKS_KEEP_THRESHOLD)
df = df0.loc[mask_keep].copy()
after = df.shape[0]
print(f"Before: {before} | After: {after} | Dropped: {before-after}")

for c in Y_COLS:
    df[c] = df[c].apply(parse_percent_cell)

y_raw = df[Y_COLS].values.astype(float)
# Heuristic: any target value above 1.5 means the columns are in percent.
if np.nanmax(y_raw) > 1.5:
    print("Targets look like percents. Dividing by 100.")
    y_raw = y_raw / 100.0
# NOTE(review): target cells that failed to parse are still NaN here and are
# neither dropped nor filled before normalisation - confirm the data has none.
y = normalize_actions(y_raw)

print("\nTargets sanity check (first 5):")
tmp = pd.DataFrame(y[:5], columns=["CHECK","BET13","BET23"])
tmp["sum"] = tmp.sum(axis=1)
print(tmp)

print("\n=== ADD SEMANTIC FEATURES ===")
df = add_semantic_features(df)

# Aggregated columns produced by add_semantic_features that enter the model.
semantic_cols = [
    "HERO_NUTS","VIL_NUTS","NUTS_ADV",
    "HERO_STRONG","VIL_STRONG","STRONG_ADV",
    "HERO_TP_PLUS","VIL_TP_PLUS","TP_PLUS_ADV",
    "HERO_AIR","VIL_AIR","AIR_ADV",
    "DRAW_ADV","COMBO_DRAW_ADV",
    "HERO_MEDIUM","VIL_MEDIUM",
    "HERO_POLAR","VIL_POLAR","POLAR_ADV"
]
require_cols(df, semantic_cols, where="semantic features")

feature_df = df[semantic_cols].copy()

# SPR: numeric, missing values imputed with the column median (0.0 if the
# median itself is not finite); a constant 0.0 column if SPR is absent.
if SPR_COL in df.columns:
    feature_df[SPR_COL] = pd.to_numeric(df[SPR_COL], errors="coerce")
    med = float(feature_df[SPR_COL].median()) if np.isfinite(feature_df[SPR_COL].median()) else 0.0
    feature_df[SPR_COL] = feature_df[SPR_COL].fillna(med)
else:
    feature_df[SPR_COL] = 0.0

# One-hot categories: only the CAT_COLS actually present are encoded; missing
# values become their own "nan" level and no level is dropped.
present_cat_cols = [c for c in CAT_COLS if c in df.columns]
if present_cat_cols:
    cats = df[present_cat_cols].astype("string").fillna("nan")
    feature_df = pd.concat([feature_df, pd.get_dummies(cats, prefix=present_cat_cols, drop_first=False)], axis=1)

# Final safety net: infinities and remaining NaNs become 0.0.
feature_df = feature_df.replace([np.inf, -np.inf], np.nan).fillna(0.0)

X = feature_df.values.astype(float)
feature_names = list(feature_df.columns)

print("\n=== FINAL FEATURE MATRIX ===")
print("X shape:", X.shape)
print("First 25 features:", feature_names[:25])


# ============================================================
# 9) Clean split: Train / Val / Test (test never used for choices)
# ============================================================
# Two-stage random split, candidate thresholds from TRAIN, and the subsample
# of TRAIN used for hyper-parameter tuning.
print("\n=== SPLITS (train/val/test) ===")
# Stage 1: hold out 20% of all rows as TEST.
X_trainval, X_test, y_trainval, y_test, df_trainval, df_test = train_test_split(
    X, y, df, test_size=0.20, random_state=BASE_RANDOM_STATE
)

# Stage 2: 25% of the remaining 80% becomes VAL.
X_train, X_val, y_train, y_val, df_train, df_val = train_test_split(
    X_trainval, y_trainval, df_trainval, test_size=0.25, random_state=BASE_RANDOM_STATE
)
# => 60% train, 20% val, 20% test

print("Train:", X_train.shape, "Val:", X_val.shape, "Test:", X_test.shape)

print("\n=== PRECOMPUTE GLOBAL THRESHOLDS (train only) ===")
t0 = time.time()
# Candidate split points come from TRAIN rows only; every model below,
# including those later fit on TRAIN+VAL, reuses this same list.
thresholds_by_feature = build_global_thresholds(X_train, GLOBAL_THRESHOLDS_PER_FEATURE)
print(f"Done. seconds={time.time()-t0:.2f} | per-feature thresholds={GLOBAL_THRESHOLDS_PER_FEATURE}")

# Tuning subset from TRAIN only: at most TUNING_SUBSAMPLE_N rows drawn
# without replacement, or all of TRAIN when it is not larger than that.
rng = np.random.RandomState(BASE_RANDOM_STATE)
if X_train.shape[0] > TUNING_SUBSAMPLE_N:
    idx = rng.choice(X_train.shape[0], size=TUNING_SUBSAMPLE_N, replace=False)
else:
    idx = np.arange(X_train.shape[0])
X_tune, y_tune = X_train[idx], y_train[idx]
print(f"\n=== TUNING SUBSET ===\nUsing {X_tune.shape[0]} / {X_train.shape[0]} training rows for tuning.")


# ============================================================
# 10) Tuning Teacher: choose best by VAL loss (stability across seeds)
# ============================================================
print("\n=== TUNING TEACHER (pick by VAL loss; report seed variance) ===")

# Hand-picked teacher-forest configurations. "name" is only a label for the
# report; the other keys are forwarded to AsymmetricLossRandomForestFast by
# fit_eval_once.
CANDIDATES = [
    dict(n_estimators=140, max_depth=5, min_leaf=60, max_features=0.7, bootstrap=True, name="A"),
    dict(n_estimators=180, max_depth=5, min_leaf=60, max_features=0.7, bootstrap=True, name="B"),
    dict(n_estimators=200, max_depth=5, min_leaf=60, max_features=0.7, bootstrap=True, name="C"),
    dict(n_estimators=180, max_depth=4, min_leaf=60, max_features=0.7, bootstrap=True, name="D"),
    dict(n_estimators=180, max_depth=5, min_leaf=80, max_features=0.7, bootstrap=True, name="E"),
    dict(n_estimators=220, max_depth=5, min_leaf=60, max_features=0.8, bootstrap=True, name="F"),
    dict(n_estimators=160, max_depth=5, min_leaf=60, max_features=0.7, bootstrap=True, name="G"),
    dict(n_estimators=160, max_depth=4, min_leaf=60, max_features=0.7, bootstrap=True, name="H"),
]
# Only the first MAX_CONFIGS candidates are evaluated.
GRID = CANDIDATES[:MAX_CONFIGS]

def fit_eval_once(cfg: Dict, seed: int) -> Dict:
    """Fit one candidate forest on the tuning subset and score it.

    Uses the module-level X_tune / y_tune, X_val / y_val,
    thresholds_by_feature and LOSS. Returns a dict with the config name, the
    seed, the fit time in seconds, and the asymmetric loss on the tuning rows
    (in-sample), on VAL, and on the out-of-bag tuning predictions.
    """
    forest_kwargs = {
        key: cfg[key]
        for key in ("n_estimators", "max_depth", "min_leaf", "max_features", "bootstrap")
    }
    forest = AsymmetricLossRandomForestFast(
        thresholds_by_feature=thresholds_by_feature,
        random_state=seed,
        **forest_kwargs,
        **LOSS
    )

    started = time.time()
    forest.fit(X_tune, y_tune)
    elapsed = time.time() - started

    result = {"cfg_name": cfg["name"], "seed": seed, "fit_seconds": elapsed}
    result["tune_loss"] = asymmetric_loss(y_tune, forest.predict(X_tune), **LOSS)
    result["val_loss"] = asymmetric_loss(y_val, forest.predict(X_val), **LOSS)
    result["oob_loss"] = asymmetric_loss(y_tune, forest.oob_predict(X_tune), **LOSS)
    return result

# Evaluate every config in GRID under several seeds and rank the configs by
# mean VAL loss, breaking ties by the spread of VAL loss across seeds.
tune_rows = []
# Seeds are spaced 10 apart starting at BASE_RANDOM_STATE.
seeds = [BASE_RANDOM_STATE + 10*s for s in range(N_SEEDS_FOR_STABILITY)]

for i, cfg in enumerate(GRID, start=1):
    print("\n" + "-"*70)
    print(f"CONFIG {i}/{len(GRID)} {cfg['name']}: {cfg}")
    per_seed = []
    for seed in seeds:
        r = fit_eval_once(cfg, seed)
        per_seed.append(r)
        print(f"  seed={seed} | fit={r['fit_seconds']:.2f}s | tune={r['tune_loss']:.6f} | val={r['val_loss']:.6f} | oob={r['oob_loss']:.6f}")
    val_losses = [r["val_loss"] for r in per_seed]
    oob_losses = [r["oob_loss"] for r in per_seed]
    fit_secs   = [r["fit_seconds"] for r in per_seed]
    # One summary row per config, aggregated over the seeds.
    tune_rows.append(dict(
        name=cfg["name"],
        n_estimators=cfg["n_estimators"],
        max_depth=cfg["max_depth"],
        min_leaf=cfg["min_leaf"],
        max_features=cfg["max_features"],
        bootstrap=cfg["bootstrap"],
        val_loss_mean=float(np.mean(val_losses)),
        val_loss_std=float(np.std(val_losses)),
        oob_loss_mean=float(np.mean(oob_losses)),
        fit_seconds_mean=float(np.mean(fit_secs)),
    ))

tune_df = pd.DataFrame(tune_rows).sort_values(["val_loss_mean","val_loss_std"]).reset_index(drop=True)
print("\n=== TEACHER TUNING SUMMARY (sorted by mean val loss, then stability) ===")
print(tune_df.to_string(index=False))

# The top row after sorting is the chosen teacher configuration.
best_row = tune_df.iloc[0].to_dict()
print("\n=== BEST TEACHER CONFIG (by VAL mean + stability) ===")
print(best_row)


# ============================================================
# 11) Fit FINAL Teacher on TRAIN+VAL, evaluate on VAL + TEST
# ============================================================
# Refit the chosen configuration on TRAIN+VAL with FINAL_TEACHER_N_TREES trees
# and report its loss and mistake metrics.
print("\n" + "="*70)
print("=== FINAL TEACHER (fit TRAIN+VAL) ===")

X_tv = np.vstack([X_train, X_val])
y_tv = np.vstack([y_train, y_val])

# Depth / leaf size / feature fraction / bootstrap come from the tuned config;
# the tree count is fixed by FINAL_TEACHER_N_TREES rather than taken from it.
teacher_params = dict(
    n_estimators=FINAL_TEACHER_N_TREES,
    max_depth=int(best_row["max_depth"]),
    min_leaf=int(best_row["min_leaf"]),
    max_features=float(best_row["max_features"]),
    bootstrap=bool(best_row["bootstrap"]),
    thresholds_by_feature=thresholds_by_feature,
    random_state=BASE_RANDOM_STATE,
    **LOSS
)

print("Teacher params:", teacher_params)

t0 = time.time()
teacher = AsymmetricLossRandomForestFast(**teacher_params)
teacher.fit(X_tv, y_tv)
print(f"Teacher fit seconds: {time.time()-t0:.2f}")

# NOTE(review): X_val is part of X_tv, so the VAL predictions and VAL loss
# below are in-sample for this teacher; only the TEST numbers are held out.
teacher_pred_val  = teacher.predict(X_val)
teacher_pred_test = teacher.predict(X_test)

teacher_val_loss  = asymmetric_loss(y_val,  teacher_pred_val,  **LOSS)
teacher_test_loss = asymmetric_loss(y_test, teacher_pred_test, **LOSS)

print("\n=== TEACHER SUMMARY ===")
print(f"Teacher asym_loss (VAL) : {teacher_val_loss:.6f}")
print(f"Teacher asym_loss (TEST): {teacher_test_loss:.6f}")

# Average predicted mix versus average true mix on TEST.
teach_mean_pred_test = mean_vec(teacher_pred_test)
true_mean_test = mean_vec(y_test)
print("\nTeacher mean strategy on TEST (pred vs true):")
print(f"  pred: check={teach_mean_pred_test[0]:.4f}, b13={teach_mean_pred_test[1]:.4f}, b23={teach_mean_pred_test[2]:.4f}")
print(f"  true: check={true_mean_test[0]:.4f}, b13={true_mean_test[1]:.4f}, b23={true_mean_test[2]:.4f}")

teacher_mist_test = mistake_report(
    y_test, teacher_pred_test, name="TEACHER_FOREST",
    big_misfire_b23_pred=BIG_MISFIRE_B23_PRED,
    big_misfire_b23_true=BIG_MISFIRE_B23_TRUE,
    big_misfire_b13_pred=BIG_MISFIRE_B13_PRED,
    big_misfire_b13_true=BIG_MISFIRE_B13_TRUE
)
print("\n=== TEACHER mistake report (TEST) ===")
print(teacher_mist_test.to_string(index=False))


# ============================================================
# 12) Depth sweep (NO TEST LEAKAGE):
#     - Students fit on TRAIN only (to mimic teacher on TRAIN)
#     - Fidelity measured on VAL vs teacher preds on VAL
#     - Performance measured on VAL vs TRUE y_val
# ============================================================
# For every (min_leaf, depth) pair, fit a single student tree on TRAIN to
# imitate the teacher's TRAIN predictions, then score it on VAL.
print("\n" + "="*70)
print("=== DEPTH SWEEP (VALIDATION ONLY for choosing depth) ===")
print("Depths:", DEPTH_GRID, "| min_leaf:", MIN_LEAF_GRID)

# NOTE(review): `teacher` was fit on TRAIN+VAL, so both sets of teacher
# predictions here are in-sample for the teacher.
teacher_pred_train = teacher.predict(X_train)
teacher_pred_val   = teacher.predict(X_val)

rows_fid = []
rows_val = []

for min_leaf in MIN_LEAF_GRID:
    for depth in DEPTH_GRID:
        t0 = time.time()
        # max_features=1.0 -> every feature is considered at every node.
        stud = AsymmetricLossTreeFast(
            max_depth=depth,
            min_leaf=min_leaf,
            max_features=1.0,
            thresholds_by_feature=thresholds_by_feature,
            random_state=SEED,
            **LOSS
        )
        # The student's targets are the teacher's predictions, not y_train.
        stud.fit(X_train, teacher_pred_train)
        fit_sec = time.time() - t0

        n_leaves = count_leaves(stud.root)

        stud_pred_val = stud.predict(X_val)

        # Fidelity (VAL): student vs teacher
        fid_mse_val = float(np.mean((stud_pred_val - teacher_pred_val) ** 2))

        # Performance (VAL): student vs TRUE solver
        val_loss_true = asymmetric_loss(y_val, stud_pred_val, **LOSS)

        # mean strategy alignment (VAL) as a sanity check
        stud_mean_val = mean_vec(stud_pred_val)
        true_mean_val = mean_vec(y_val)
        mean_L1_gap = float(np.sum(np.abs(stud_mean_val - true_mean_val)))

        rows_fid.append(dict(
            depth=depth,
            min_leaf=min_leaf,
            n_leaves=n_leaves,
            fit_seconds=round(fit_sec, 3),
            fidelity_MSE_val=fid_mse_val
        ))

        rows_val.append(dict(
            depth=depth,
            min_leaf=min_leaf,
            n_leaves=n_leaves,
            val_asym_loss_vs_TRUE=val_loss_true,
            mean_pred_CHECK_val=stud_mean_val[0],
            mean_pred_BET13_val=stud_mean_val[1],
            mean_pred_BET23_val=stud_mean_val[2],
            mean_true_CHECK_val=true_mean_val[0],
            mean_true_BET13_val=true_mean_val[1],
            mean_true_BET23_val=true_mean_val[2],
            mean_L1_gap_pred_vs_true_val=mean_L1_gap,
            fidelity_hint=fid_mse_val
        ))

# Table 1 is ordered best fidelity first; table 2 best true VAL loss first.
fid_df = pd.DataFrame(rows_fid).sort_values(["fidelity_MSE_val","depth","min_leaf"]).reset_index(drop=True)
val_df = pd.DataFrame(rows_val).sort_values(["val_asym_loss_vs_TRUE","depth","min_leaf"]).reset_index(drop=True)

print("\n=== TABLE 1: Fidelity vs Depth on VALIDATION ===")
print(fid_df.to_string(index=False))

print("\n=== TABLE 2: Validation Performance vs Depth ===")
print(val_df.to_string(index=False))

# Helper recommendation: among trees within 5% of the best fidelity MSE, pick
# the one with the fewest leaves (then smallest depth, then smallest min_leaf).
# It is only printed; the final tree below uses FINAL_DEPTH / FINAL_MIN_LEAF.
best_fid = float(fid_df["fidelity_MSE_val"].min())
close_cut = best_fid * 1.05
cands = fid_df[fid_df["fidelity_MSE_val"] <= close_cut].copy()
cands = cands.sort_values(["n_leaves","depth","min_leaf","fidelity_MSE_val"]).reset_index(drop=True)

print("\n=== DEPTH PICK HELPER (VALIDATION fidelity only) ===")
print(f"Best fidelity_MSE_val = {best_fid:.6f} | +5% cutoff = {close_cut:.6f}")
if len(cands) == 0:
    print("No candidate met +5% rule (unexpected). Consider +10%.")
else:
    rec = cands.iloc[0].to_dict()
    print(f"Smallest tree within +5% fidelity: depth={int(rec['depth'])}, min_leaf={int(rec['min_leaf'])}, n_leaves={int(rec['n_leaves'])}, fidelity_MSE_val={rec['fidelity_MSE_val']:.6f}")


# ============================================================
# 13) FINAL Born-Again tree:
#     depth=4, min_leaf=60
#     Train on TRAIN+VAL to mimic teacher preds on TRAIN+VAL
# ============================================================
# Fit the final born-again (distilled) tree and report how it does on TEST.
# Labels are derived from FINAL_DEPTH / FINAL_MIN_LEAF so the banner and the
# report name can never disagree with the tree that was actually built.
print("\n" + "="*70)
print(f"=== FINAL BORN-AGAIN TREE (depth={FINAL_DEPTH}, min_leaf={FINAL_MIN_LEAF}) ===")

# Distillation targets: the teacher's own predictions on TRAIN+VAL.
teacher_pred_tv = teacher.predict(X_tv)

# perf_counter is monotonic, so the elapsed time cannot be skewed by system
# clock adjustments the way time.time() differences can.
t0 = time.perf_counter()
born_again = AsymmetricLossTreeFast(
    max_depth=FINAL_DEPTH,
    min_leaf=FINAL_MIN_LEAF,
    max_features=1.0,
    thresholds_by_feature=thresholds_by_feature,
    random_state=BASE_RANDOM_STATE,
    **LOSS
)
# The student is fit to the teacher's outputs, not to the true solver targets.
born_again.fit(X_tv, teacher_pred_tv)
fit_sec = time.perf_counter() - t0

# TEST metrics: asymmetric loss against the true targets, and fidelity
# (plain MSE) against the teacher's TEST predictions.
born_again_pred_test = born_again.predict(X_test)
born_again_test_loss = asymmetric_loss(y_test, born_again_pred_test, **LOSS)
born_again_fid_test  = float(np.mean((born_again_pred_test - teacher_pred_test) ** 2))
born_again_leaves    = count_leaves(born_again.root)

print(f"Born-again fit seconds: {fit_sec:.3f} | n_leaves={born_again_leaves}")
print(f"Born-again asym_loss vs TRUE (TEST): {born_again_test_loss:.6f}")
print(f"Born-again fidelity MSE vs TEACHER (TEST): {born_again_fid_test:.6f}")

# Average predicted strategy next to the average true strategy on TEST.
ba_mean_pred_test = mean_vec(born_again_pred_test)
print("\nBorn-again mean strategy on TEST (pred vs true):")
print(f"  pred: check={ba_mean_pred_test[0]:.4f}, b13={ba_mean_pred_test[1]:.4f}, b23={ba_mean_pred_test[2]:.4f}")
print(f"  true: check={true_mean_test[0]:.4f}, b13={true_mean_test[1]:.4f}, b23={true_mean_test[2]:.4f}")

# Mistake report uses the shared misfire thresholds from the config section.
born_again_mist_test = mistake_report(
    y_test, born_again_pred_test,
    name=f"BORN_AGAIN_TREE_d{FINAL_DEPTH}_ml{FINAL_MIN_LEAF}",
    big_misfire_b23_pred=BIG_MISFIRE_B23_PRED,
    big_misfire_b23_true=BIG_MISFIRE_B23_TRUE,
    big_misfire_b13_pred=BIG_MISFIRE_B13_PRED,
    big_misfire_b13_true=BIG_MISFIRE_B13_TRUE
)
print("\n=== Born-Again mistake report (TEST) ===")
print(born_again_mist_test.to_string(index=False))

# Print the tree rules; `baseline` (mean true strategy on TRAIN+VAL) is
# passed to the rule printer and is reused by the simple-tree section.
print("\n=== BORN-AGAIN TREE RULES (for slides) ===")
baseline = mean_vec(y_tv)
print(f"Baseline (avg solver on TRAIN+VAL) = [check={baseline[0]:.2f}, b13={baseline[1]:.2f}, b23={baseline[2]:.2f}]")
print_tree_rules_pretty(born_again.root, feature_names, baseline)


# ============================================================
# 14) Simple Decision Tree baseline (same size as born-again):
#     depth=4, min_leaf=60
#     Train on TRAIN+VAL to predict TRUE solver y
# ============================================================
# Fit a same-size baseline tree directly on the true solver targets and
# report how it does on TEST. Labels are derived from FINAL_DEPTH /
# FINAL_MIN_LEAF so they always match the tree that was actually built.
print("\n" + "="*70)
print(f"=== SIMPLE TREE (depth={FINAL_DEPTH}, min_leaf={FINAL_MIN_LEAF}) trained on TRUE y ===")

# perf_counter is monotonic, so the elapsed time cannot be skewed by system
# clock adjustments the way time.time() differences can.
t0 = time.perf_counter()
simple_tree = AsymmetricLossTreeFast(
    max_depth=FINAL_DEPTH,
    min_leaf=FINAL_MIN_LEAF,
    max_features=1.0,
    thresholds_by_feature=thresholds_by_feature,
    random_state=BASE_RANDOM_STATE,
    **LOSS
)
# Unlike the born-again tree, this one is fit to the true targets y_tv.
simple_tree.fit(X_tv, y_tv)
fit_sec = time.perf_counter() - t0

# TEST metrics: asymmetric loss against the true targets, plus tree size.
simple_pred_test = simple_tree.predict(X_test)
simple_test_loss = asymmetric_loss(y_test, simple_pred_test, **LOSS)
simple_leaves    = count_leaves(simple_tree.root)

print(f"Simple tree fit seconds: {fit_sec:.3f} | n_leaves={simple_leaves}")
print(f"Simple tree asym_loss vs TRUE (TEST): {simple_test_loss:.6f}")

# Average predicted strategy next to the average true strategy on TEST.
simp_mean_pred_test = mean_vec(simple_pred_test)
print("\nSimple tree mean strategy on TEST (pred vs true):")
print(f"  pred: check={simp_mean_pred_test[0]:.4f}, b13={simp_mean_pred_test[1]:.4f}, b23={simp_mean_pred_test[2]:.4f}")
print(f"  true: check={true_mean_test[0]:.4f}, b13={true_mean_test[1]:.4f}, b23={true_mean_test[2]:.4f}")

# Mistake report uses the shared misfire thresholds from the config section.
simple_mist_test = mistake_report(
    y_test, simple_pred_test,
    name=f"SIMPLE_TREE_d{FINAL_DEPTH}_ml{FINAL_MIN_LEAF}",
    big_misfire_b23_pred=BIG_MISFIRE_B23_PRED,
    big_misfire_b23_true=BIG_MISFIRE_B23_TRUE,
    big_misfire_b13_pred=BIG_MISFIRE_B13_PRED,
    big_misfire_b13_true=BIG_MISFIRE_B13_TRUE
)
print("\n=== Simple tree mistake report (TEST) ===")
print(simple_mist_test.to_string(index=False))

# `baseline` is the TRAIN+VAL mean strategy computed in the born-again section.
print("\n=== SIMPLE TREE RULES ===")
print_tree_rules_pretty(simple_tree.root, feature_names, baseline)


# ============================================================
# 15) Save CSV with predictions (full filtered dataset)
# ============================================================
# Score every row of the filtered dataset with all three models and write
# the targets plus predictions next to the original columns.
print("\n" + "="*70)
print("=== SAVE CSV WITH PREDICTIONS (FULL FILTERED DATASET) ===")

teacher_pred_all = teacher.predict(X)
born_again_pred_all = born_again.predict(X)
simple_pred_all = simple_tree.predict(X)

# Output column -> (source array, column index within that array).
# Insertion order is the order in which the columns are added to the frame.
_output_columns = {
    "Y_CHECK": (y, 0),
    "Y_BET13": (y, 1),
    "Y_BET23": (y, 2),
    "PRED_TEACHER_CHECK": (teacher_pred_all, 0),
    "PRED_TEACHER_BET13": (teacher_pred_all, 1),
    "PRED_TEACHER_BET23": (teacher_pred_all, 2),
    "PRED_BORN_AGAIN_CHECK": (born_again_pred_all, 0),
    "PRED_BORN_AGAIN_BET13": (born_again_pred_all, 1),
    "PRED_BORN_AGAIN_BET23": (born_again_pred_all, 2),
    "PRED_SIMPLE_CHECK": (simple_pred_all, 0),
    "PRED_SIMPLE_BET13": (simple_pred_all, 1),
    "PRED_SIMPLE_BET23": (simple_pred_all, 2),
}

df_out = df.copy()
for _col_name, (_source, _k) in _output_columns.items():
    df_out[_col_name] = _source[:, _k]

df_out.to_csv(OUT_PATH, index=False)
print("Saved:", OUT_PATH)

# One-glance summary of the TEST metrics for the three models. The tree label
# is built from FINAL_DEPTH / FINAL_MIN_LEAF so it always reflects the
# configuration the trees were actually fit with.
print("\n=== QUICK NUMBERS ===")
print(f"Teacher asym_loss TEST: {teacher_test_loss:.6f}")
print(f"Born-again d{FINAL_DEPTH}/ml{FINAL_MIN_LEAF} asym_loss TEST: {born_again_test_loss:.6f} | fidelity MSE TEST: {born_again_fid_test:.6f} | leaves: {born_again_leaves}")
print(f"Simple tree d{FINAL_DEPTH}/ml{FINAL_MIN_LEAF} asym_loss TEST: {simple_test_loss:.6f} | leaves: {simple_leaves}")


=== LOAD DATA ===
Rows / Cols: (4416, 72)

=== FILTER: keep (% Villain Checks == -1) OR >= 85.0 ===
Before: 4416 | After: 4267 | Dropped: 149
Targets look like percents. Dividing by 100.

Targets sanity check (first 5):
      CHECK     BET13     BET23  sum
0  0.079208  0.792079  0.128713  1.0
1  0.030000  0.900000  0.070000  1.0
2  0.030000  0.920000  0.050000  1.0
3  0.030000  0.790000  0.180000  1.0
4  0.040000  0.780000  0.180000  1.0

=== ADD SEMANTIC FEATURES ===

=== FINAL FEATURE MATRIX ===
X shape: (4267, 32)
First 25 features: ['HERO_NUTS', 'VIL_NUTS', 'NUTS_ADV', 'HERO_STRONG', 'VIL_STRONG', 'STRONG_ADV', 'HERO_TP_PLUS', 'VIL_TP_PLUS', 'TP_PLUS_ADV', 'HERO_AIR', 'VIL_AIR', 'AIR_ADV', 'DRAW_ADV', 'COMBO_DRAW_ADV', 'HERO_MEDIUM', 'VIL_MEDIUM', 'HERO_POLAR', 'VIL_POLAR', 'POLAR_ADV', 'SPR', 'Position_IP', 'Position_OOP', 'Full Houses_PAIRED', 'Full Houses_TRIPS', 'Full Houses_UNPAIRED']

=== SPLITS (train/val/test) ===
Train: (2559, 32) Val: (854, 32) Test: (854, 32)

=== PRECOM