# Setup

The first steps to get started are:
1. Get the setup command
2. Execute it in the cell below

### >> https://hub.crunchdao.com/competitions/structural-break/submit/notebook

In [None]:
# %pip install crunch-cli --upgrade --quiet --progress-bar off
# !crunch setup-notebook structural-break nRmrHs5mYNH8RZ2ksRNw1qLL

In [None]:
import crunch

# Load the Crunch Toolings.
# Note: this rebinds the name `crunch` from the imported module to the object
# returned by `crunch.load_notebook()`.
crunch = crunch.load_notebook()

loaded inline runner with module: <module '__main__'>

cli version: 8.0.0
available ram: 8.00 gb
available cpu: 8 core
----


# Imports

In [None]:
# Standard library
import json
from pathlib import Path
from datetime import datetime
from joblib import dump, load
import re

# Typing
from typing import Iterable

# Data handling
import numpy as np
import pandas as pd

# For feature extraction
import math
import pywt  # for wavelet transforms
from scipy.special import (
    gammaincc,
)  # for chi-square survival function (ARCH-LM p-values)

# Machine learning frameworks
import lightgbm as lgb
import xgboost as xgb
from catboost import CatBoostClassifier
from xgboost.callback import EarlyStopping as XGBEarlyStop
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedGroupKFold, BaseCrossValidator
from sklearn.base import ClassifierMixin, BaseEstimator, TransformerMixin
from sklearn.metrics import roc_auc_score
from scipy.stats import rankdata

# Hyperparameter optimization
import optuna
from optuna import Trial


# Config

In [None]:
# @crunch/keep:on

# ─────────────────────────────────────────────────────────────────────
# GLOBAL CONFIG
# ─────────────────────────────────────────────────────────────────────

RANDOM_STATE = 69
# Seeds only NumPy's legacy global RNG; no other library is seeded in this cell.
np.random.seed(RANDOM_STATE)

# Small positive floor used by the feature helpers to guard divisions and log(0).
EPS = 1e-9

# ─────────────────────────────────────────────────────────────────────
# INFER CONFIG
# ─────────────────────────────────────────────────────────────────────

# NOTE(review): consumed outside this section; accepted values not visible here.
INFERENCE_MODE = "full"

# ─────────────────────────────────────────────────────────────────────
# STACKING CONFIG
# ─────────────────────────────────────────────────────────────────────

BASE_LEARNERS = (
    "xgb_main",
    "xgb_lite",
    "lgb_main",
    "cat_main",
)

N_SEEDS = 20
TOP_SEEDS = 2

FULL_REFIT = True
FULL_HP_SELECTION = "consensus"

# ─────────────────────────────────────────────────────────────────────
# TRAIN_MODEL CONFIG
# ─────────────────────────────────────────────────────────────────────

N_OPTUNA_TRIALS = 32
# Only let Optuna log at ERROR level (hides per-trial messages).
optuna.logging.set_verbosity(optuna.logging.ERROR)

K_OUTER = 5
K_MAX_INNER = 5
K_STOP_INNER = 1

TOPK_MIN_AUC = 0.52
TOPK_FEATURES = 60
TOPK_ALWAYS_KEEP = []

# NOTE(review): presumably features whose names contain any of these substrings
# are excluded from training; the consumer is outside this section — confirm.
EXCLUDE_FEATURE_KEYWORDS = ["logit", "fisher", "logratio", "scaled"]


MAX_BIN = 64
EARLY_STOPPING = 100


# parents=True: on a fresh checkout without a `resources/` directory, a plain
# mkdir(exist_ok=True) raises FileNotFoundError and the notebook stops here.
MODEL_DIR = Path("resources/model")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# ─────────────────────────────────────────────────────────────────────
# FEATURE EXTRACTION CONFIG
# ─────────────────────────────────────────────────────────────────────

# Moments: coarse grid, unpacked by compute_moments_block as Q10, Q25, Q50, Q75, Q90
QUANTILE_COARSE_GRID = [0.10, 0.25, 0.50, 0.75, 0.90]

# Quantiles: fine grid used by compute_quantiles_block / compute_rates_block.
# TOP_K = number of extreme values averaged by _topk_mean / _bottomk_mean.
QUANTILE_FINE_GRID = [0.05, 0.10, 0.25, 0.40, 0.50, 0.60, 0.75, 0.90, 0.95]
TOP_K = 5

# Crossing rates: W_FANO = window length for fano_burstiness;
# CROSSING_RATE_DEADBAND = half-width ε of the deadband around crossing levels.
W_FANO = 50
CROSSING_RATE_DEADBAND = 0.1

# Autocorrelation: ACF_MAX_LAG = default ACF/PACF lag count for the autocorr block.
# NOTE(review): LBQ_M is presumably the Ljung–Box lag count `m` — confirm at call site.
ACF_MAX_LAG = 10
LBQ_M = 20

# Tests & Distances
JS_QUANTILE_BINS = np.linspace(0.0, 1.0, 33)
MMD_MAX_N = 512

# Frequency
FREQ_BANDS = ((0.00, 0.05), (0.05, 0.15), (0.15, 0.30), (0.30, 0.50))
DWT_WAVELET = "db2"
DWT_LEVEL = 3
ENTROPY_M1, ENTROPY_M2 = 3, 5
ENTROPY_TAU = 1

# Boundary
BOUND_EDGE = 5
BOUND_WINDOW_SIZES = [32, 128]
BOUND_SKIP_AFTER = 0
BOUND_ACF_MAX_LAG = 6
ARCH_L = 5
BOUND_OFFSETS = (0, 8, 16, 32)

# Rolling
ROLL_WINDOWS = (10, 20, 50, 100, 200)
ROLL_MIN_POS_PER_HALF = 20
ROLL_TOPK = 3
EWVAR_HALFLIVES = (200, 400)

# AR
AR_ORDER = 1
AR_RIDGE_LAMBDA = 1.0
AR_SCORE_CAP = 256
# ─────────────────────────────────────────────────────────────────────
# PREPROCESS CONFIG
# ─────────────────────────────────────────────────────────────────────


# Floors computed as the 5% bottom quantile of s0, s_dz, s_abs, s_dd
S0_FLOOR = 0.0008216
S_DZ_FLOOR = 1.0909
S_ABS_FLOOR = 0.5555
S_DD_FLOOR = 1.6383
CLIP_QLOW, CLIP_QHIGH = 0.002, 0.998
CLIP_MIN_WIDTH = 1.0
CLIP_DEFAULT_BAND = 7.0  # computed as the 0.2% and 99.8% quantiles of z_before
# parents=True so the cache directory is created even when `resources/` is missing.
FEAT_CACHE_DIR = Path("resources/features")
FEAT_CACHE_DIR.mkdir(parents=True, exist_ok=True)

# @crunch/keep:off


# Preprocessing

In [None]:
# ─────────────────────────────────────────────────────────────────────
# Helper functions
# ─────────────────────────────────────────────────────────────────────


def pad_periods(X, period, max_len):
    """Right-pad every id's values of one period into a (n_ids, max_len) float32 matrix.

    Rows follow the (sorted) id order of the groupby; positions after an id's
    last observation are NaN. Returns ``(matrix, id_index)``.
    """
    in_period = X[X["period"] == period]
    per_id = in_period.groupby("id")["value"].apply(np.array)

    padded = np.full((len(per_id), max_len), np.nan, dtype=np.float32)
    for row, values in enumerate(per_id.to_list()):
        padded[row, : len(values)] = values
    return padded, per_id.index


def mad(arr, axis=1):
    """NaN-aware median absolute deviation along `axis` (dimension kept, size 1)."""
    center = np.nanmedian(arr, axis=axis, keepdims=True)
    deviations = np.abs(arr - center)
    return np.nanmedian(deviations, axis=axis, keepdims=True)


def winsorize_pair(
    before,
    after,
    qlow=CLIP_QLOW,
    qhigh=CLIP_QHIGH,
    min_width=CLIP_MIN_WIDTH,  # required (qh-ql) span in standardized units
    default_band=CLIP_DEFAULT_BAND,
):
    """
    Clip BEFORE and AFTER rows to per-row cutoffs estimated on BEFORE only.

    Row i of both arrays is clipped to
    [nanquantile(before_i, qlow), nanquantile(before_i, qhigh)]. When that span is
    non-finite or narrower than `min_width`, the symmetric band
    [-default_band, +default_band] is used for that row instead.
    NaNs stay NaN. Inputs are expected to be standardized (z); outputs are float32.
    """
    lower_q = np.nanquantile(before, qlow, axis=1, keepdims=True)
    upper_q = np.nanquantile(before, qhigh, axis=1, keepdims=True)

    span = upper_q - lower_q
    use_row_cutoffs = np.isfinite(span) & (span >= min_width)

    band = float(default_band)
    lo = np.where(use_row_cutoffs, lower_q, -band)
    hi = np.where(use_row_cutoffs, upper_q, band)

    def _clip_keep_nan(values):
        # clip to the per-row band, then restore NaN at originally-missing slots
        clipped = np.clip(values, lo, hi)
        return np.where(np.isnan(values), np.nan, clipped).astype(np.float32)

    return _clip_keep_nan(before), _clip_keep_nan(after)


def fast_detrend_ols(arr, mean_center=True):
    """
    Vectorized per-row linear detrend (OLS of value on time index) that ignores NaNs.

    Parameters
    ----------
    arr : 2D array (n_rows, T); NaN marks missing / padded positions.
    mean_center : if True, additionally subtract each row's nanmean of the residuals.

    Returns
    -------
    float32 array of the same shape: OLS residuals at valid positions, NaN elsewhere.
    A row with a single valid point gets slope 0 (its denominator is floored at EPS);
    an all-NaN row stays all-NaN.
    """
    y = arr.astype(np.float32, copy=False)
    _, T = y.shape

    # x-axis: 0..T-1, then center per row to improve numerics
    t = np.arange(T, dtype=np.float32)[None, :]  # (1, T)
    mask = ~np.isnan(y)  # (n, T)
    cnt = mask.sum(axis=1, keepdims=True).astype(np.float32)

    # only valid positions contribute to the per-row sums
    t_sum = (mask * t).sum(axis=1, keepdims=True)  # Σ t_i
    y_sum = np.nansum(y, axis=1, keepdims=True)  # Σ y_i

    t_bar = t_sum / cnt  # \bar t
    y_bar = y_sum / cnt  # \bar y

    tc = t - t_bar  # center x
    yc = np.where(mask, y - y_bar, 0.0)  # center y where valid

    num = (tc * yc * mask).sum(axis=1, keepdims=True)  # Σ (tc * yc)
    den = (tc * tc * mask).sum(axis=1, keepdims=True)  # Σ (tc^2)
    den = np.where(den <= EPS, EPS, den)  # floor (rows with < 2 valid points)

    slope = num / den
    intercept = y_bar - slope * t_bar

    yhat = intercept + slope * t
    resid = np.where(mask, y - yhat, np.nan)  # keep NaN at padded/missing slots

    if mean_center:
        resid = resid - np.nanmean(resid, axis=1, keepdims=True)

    return resid.astype(np.float32)


# ─────────────────────────────────────────────────────────────────────
# Main pipeline
# ─────────────────────────────────────────────────────────────────────


def process_series(X_train):
    """Heavy NumPy preprocessing for the entire dataset.

    Reads the "id", "period" (0 = before, 1 = after) and "value" data of
    `X_train` through `pad_periods`. Every view below is scaled and winsorized
    with statistics of the BEFORE segment only and applied to both segments;
    diffs and detrending are computed per segment, so nothing crosses the
    boundary.

    Returns a DataFrame indexed by (id, time), where time restarts at 0 for each
    id and runs over the before values then the after values, with one column
    per transformed view ("original", "standardized", "clipped", "detrended", ...).

    NOTE(review): before/after matrices are aligned positionally by the sorted
    groupby id order; this assumes every id has both periods — confirm upstream.
    NOTE(review): segment lengths are recovered by counting non-NaN values, so
    NaNs inside a raw series would truncate/shift the stitched output.
    """

    # --- Build padded arrays
    max_before = X_train.loc[X_train["period"] == 0].groupby("id").size().max()
    max_after = X_train.loc[X_train["period"] == 1].groupby("id").size().max()
    before_arr, ids = pad_periods(X_train, 0, max_before)
    after_arr, _ = pad_periods(X_train, 1, max_after)

    # --- Standardization by before stats (robust)
    # 1.4826 * MAD rescales the MAD to a normal-consistent sigma; floored to avoid /0.
    m0 = np.nanmedian(before_arr, axis=1, keepdims=True)
    s0 = 1.4826 * mad(before_arr)
    s0 = np.maximum(s0, S0_FLOOR)
    z_before = (before_arr - m0) / s0
    z_after = (after_arr - m0) / s0

    # --- Winsorized standardized (replaces fixed np.clip)
    zc_before, zc_after = winsorize_pair(z_before, z_after)

    # --- Detrended (per segment) on winsorized standardized
    zd_before = fast_detrend_ols(zc_before)
    zd_after = fast_detrend_ols(zc_after)

    # --- Segment-aware diffs on standardized (pre-winsor)
    # prepend the first value so each segment's first diff is 0 (no cross-period diff)
    dz_before = np.diff(z_before, axis=1, prepend=z_before[:, [0]])
    dz_after = np.diff(z_after, axis=1, prepend=z_after[:, [0]])

    m_dz = np.nanmedian(dz_before, axis=1, keepdims=True)
    s_dz = 1.4826 * mad(dz_before)
    s_dz = np.maximum(s_dz, S_DZ_FLOOR)

    d_before = (dz_before - m_dz) / s_dz
    d_after = (dz_after - m_dz) / s_dz

    # --- Winsorized diffs (replaces fixed clip on diffs)
    dc_before, dc_after = winsorize_pair(d_before, d_after)

    # --- Detrend of winsorized diffs
    dm_before = fast_detrend_ols(dc_before)
    dm_after = fast_detrend_ols(dc_after)

    # --- Absolute diffs (from normalized diffs)
    a_before = np.abs(d_before)
    a_after = np.abs(d_after)

    # Winsorize absolute diffs by before's absolute-diff quantiles
    ac_before, ac_after = winsorize_pair(a_before, a_after)

    # Detrend of winsorized absolute diffs
    am_before = fast_detrend_ols(ac_before)
    am_after = fast_detrend_ols(ac_after)

    # --- Absolute values of z (pre-winsor)
    absz_before = np.abs(z_before)
    absz_after = np.abs(z_after)

    m_abs = np.nanmedian(absz_before, axis=1, keepdims=True)
    s_abs = 1.4826 * mad(absz_before)
    s_abs = np.maximum(s_abs, S_ABS_FLOOR)
    abs_std_before = (absz_before - m_abs) / s_abs
    abs_std_after = (absz_after - m_abs) / s_abs

    # Winsorize normalized |z|
    abs_c_before, abs_c_after = winsorize_pair(abs_std_before, abs_std_after)

    # Detrend of winsorized |z| normalized
    abs_m_before = fast_detrend_ols(abs_c_before)
    abs_m_after = fast_detrend_ols(abs_c_after)

    # Build z^2, winsorize using BEFORE quantiles, then detrend per segment
    sq_before = np.square(z_before)
    sq_after = np.square(z_after)

    sqc_before, sqc_after = winsorize_pair(sq_before, sq_after)
    sqm_before = fast_detrend_ols(sqc_before)
    sqm_after = fast_detrend_ols(sqc_after)

    # --- Second differences (curvature) on standardized z, per segment (no cross-period leakage)
    dd_raw_before = np.diff(dz_before, axis=1, prepend=dz_before[:, [0]])
    dd_raw_after = np.diff(dz_after, axis=1, prepend=dz_after[:, [0]])

    # Robust standardization by BEFORE stats (median/MAD), floored at S_DD_FLOOR
    m_dd = np.nanmedian(dd_raw_before, axis=1, keepdims=True)
    s_dd = 1.4826 * mad(dd_raw_before)
    s_dd = np.maximum(s_dd, S_DD_FLOOR)

    dd_before = (dd_raw_before - m_dd) / s_dd
    dd_after = (dd_raw_after - m_dd) / s_dd

    # Winsorize second differences using BEFORE-based cutoffs
    ddc_before, ddc_after = winsorize_pair(dd_before, dd_after)

    # Detrend per segment (OLS) on winsorized second differences
    ddm_before = fast_detrend_ols(ddc_before)
    ddm_after = fast_detrend_ols(ddc_after)

    # --- Stitch back into MultiIndex DataFrame
    # Each id's segment length = number of non-NaN (non-padded) raw values.
    out_list = []
    for i, id_val in enumerate(ids):
        Lb = np.count_nonzero(~np.isnan(before_arr[i]))
        La = np.count_nonzero(~np.isnan(after_arr[i]))
        n_total = Lb + La
        time_index = np.arange(n_total)

        df = pd.DataFrame(
            {
                "original": np.r_[before_arr[i, :Lb], after_arr[i, :La]],
                "period": np.r_[np.zeros(Lb, dtype=int), np.ones(La, dtype=int)],
                "standardized": np.r_[z_before[i, :Lb], z_after[i, :La]],
                "clipped": np.r_[zc_before[i, :Lb], zc_after[i, :La]],
                "detrended": np.r_[zd_before[i, :Lb], zd_after[i, :La]],
                "diff_standardized": np.r_[d_before[i, :Lb], d_after[i, :La]],
                "diff_detrended": np.r_[dm_before[i, :Lb], dm_after[i, :La]],
                "absdiff_detrended": np.r_[am_before[i, :Lb], am_after[i, :La]],
                "absval_detrended": np.r_[abs_m_before[i, :Lb], abs_m_after[i, :La]],
                "squared_detrended": np.r_[sqm_before[i, :Lb], sqm_after[i, :La]],
                "diff2_standardized": np.r_[dd_before[i, :Lb], dd_after[i, :La]],
                "diff2_detrended": np.r_[ddm_before[i, :Lb], ddm_after[i, :La]],
            },
            index=pd.MultiIndex.from_product(
                [[id_val], time_index], names=["id", "time"]
            ),
        )
        out_list.append(df)

    # single concat at the end (no quadratic growth)
    return pd.concat(out_list, axis=0)


# ─────────────────────────────────────────────────────────────────────
# File-handling wrapper
# ─────────────────────────────────────────────────────────────────────
def _latest_cache(prefix: str):
    files = sorted(
        FEAT_CACHE_DIR.glob(f"{prefix}_*.parquet"),
        key=lambda f: f.stat().st_mtime,
        reverse=True,
    )
    latest = files[0] if files else None
    return latest


def _save_cache(df: pd.DataFrame, prefix: str) -> Path:
    """Write `df` to FEAT_CACHE_DIR as ``{prefix}_<MMDD_HHMM>.parquet`` and return its path.

    A second save with the same prefix within the same minute overwrites that file.
    """
    ts = datetime.now().strftime("%m%d_%H%M")
    target = FEAT_CACHE_DIR.joinpath(f"{prefix}_{ts}.parquet")
    df.to_parquet(target)
    return target


def detect_non_finite(feats: pd.DataFrame, max_report: int = 5) -> int:
    """Report NaN / ±inf cells of a numeric frame and return how many there are.

    Prints a one-line summary followed by up to `max_report` offending cells
    (row label, column label, value). Values are checked after a float32
    conversion, matching the dtype the pipeline works in.

    Returns
    -------
    The total number of non-finite cells (0 when the frame is clean), so callers
    can assert on it instead of only reading printed output.
    """
    arr = feats.to_numpy(dtype=np.float32, copy=False)
    rows, cols = np.nonzero(~np.isfinite(arr))
    n_bad = int(rows.size)
    if n_bad:
        print(f"Found {n_bad} non-finite value(s); showing up to {max_report}:")
        for r, c in zip(rows[:max_report], cols[:max_report]):
            print(f"  at row={feats.index[r]}, col={feats.columns[c]}, val={arr[r, c]}")
    return n_bad


def build_preprocessed(X_train, force=False, inference=False):
    """Return the preprocessed frame, reusing the newest "preprocess" cache when allowed.

    The cache is read only when one exists and neither `force` nor `inference`
    is set. Otherwise `process_series` runs, non-finite cells are reported, and
    the result is written to a new cache file unless `inference` is set.
    """
    prefix = "preprocess"
    cache = _latest_cache(prefix)
    may_reuse = not (force or inference)
    if may_reuse and cache:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    processed = process_series(X_train)
    detect_non_finite(processed)  # sanity check: prints offending cells, if any

    if not inference:
        _save_cache(processed, prefix)
    return processed


def save_essentials():
    """Create a lightweight preprocessed cache to study it in a notebook
    because the original 1GB preprocess file cannot be handled by my computer.

    Only the needed columns are read from the parquet file (`columns=`), so
    the full frame is never loaded into memory. Returns the path of the new
    "preprocess_essentials" cache file.

    Raises
    ------
    FileNotFoundError : if no "preprocess" cache exists yet.
    """
    path = _latest_cache("preprocess")
    if path is None:
        raise FileNotFoundError(
            "No 'preprocess' cache found; run build_preprocessed() first."
        )
    cols = ["original", "period", "standardized", "clipped", "detrended"]
    X_prep = pd.read_parquet(path, columns=cols)
    return _save_cache(X_prep, "preprocess_essentials")


# Feature Extraction

In [None]:
# ─────────────────────────────────────────────────────────────────────
# MOMENTS BLOCK
# ─────────────────────────────────────────────────────────────────────


def _group_series(
    df: pd.DataFrame, col: str, period: int
) -> pd.core.groupby.SeriesGroupBy:
    return df.loc[df["period"] == period, col].groupby(level="id")


def _log(x):
    """Natural log with inputs floored at EPS (finite x <= 0 maps to log(EPS))."""
    floored = np.maximum(x, EPS)
    return np.log(floored)


def _std(s):
    return np.sqrt(np.nanmean((s - np.nanmean(s)) ** 2))


def _skew_kurt(s):
    # classic (population) skew and excess kurtosis; NaN-safe
    m = np.nanmean(s)
    v = np.nanmean((s - m) ** 2)
    std = np.sqrt(v)
    z = (s - m) / std
    skew = np.nanmean(z**3)
    kurt = np.nanmean(z**4) - 3.0
    return float(skew), float(kurt)


def _mad(s):
    med = np.nanmedian(s)
    return np.nanmedian(np.abs(s - med))


def _ols_slope(y):
    # slope of y vs t (0..n-1), NaN-safe
    n = len(y)
    t = np.arange(n, dtype=np.float32)
    t_bar = t.mean()
    y_bar = y.mean()
    num = np.sum((t - t_bar) * (y - y_bar))
    den = np.sum((t - t_bar) ** 2)
    if den <= EPS:
        return 0.0
    return float(num / den)


def _topk_mean(arr, k=TOP_K):
    """Mean of the k largest values of `arr` (all of them if k >= size); NaN if empty."""
    size = arr.size
    if not size:
        return np.nan
    k_eff = k if k < size else size
    # np.partition is O(n): the last k_eff slots hold the k_eff largest values
    largest = np.partition(arr, -k_eff)[-k_eff:]
    return float(np.nanmean(largest))


def _bottomk_mean(arr, k=TOP_K):
    """Mean of the k smallest values of `arr` (all of them if k >= size).

    Returns NaN for empty input, mirroring `_topk_mean`. Previously k became 0
    and `np.partition(arr, -1)` raised on the empty array.
    """
    if arr.size == 0:
        return np.nan
    k = min(k, arr.size)
    # np.partition is O(n): the first k slots hold the k smallest values
    part = np.partition(arr, k - 1)[:k]
    return float(np.nanmean(part))


def compute_moments_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    quantile_coarse_grid: list[float] = QUANTILE_COARSE_GRID,
) -> pd.DataFrame:
    """
    Moment-based break features (after vs. before), one float32 row per id.

    Classic and robust moments come from the "standardized" column, and the OLS
    slope comes from the winsorized "clipped" column. Each feature contrasts the
    AFTER segment with the BEFORE segment (difference or log-ratio).

    Parameters
    ----------
    X_prep : preprocessed frame with an "id" index level and columns
        "period", "standardized", "clipped".
    force : recompute even if a cached parquet exists.
    inference : never read or write the cache.
    quantile_coarse_grid : exactly five quantiles, interpreted in order as
        Q10, Q25, Q50, Q75, Q90.
    """
    # Load cache
    prefix = "moments"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    # Group series
    z_col = "standardized"
    clip_col = "clipped"

    z_b = _group_series(X_prep, z_col, 0)
    z_a = _group_series(X_prep, z_col, 1)
    zc_b = _group_series(X_prep, clip_col, 0)
    zc_a = _group_series(X_prep, clip_col, 1)
    ids = z_b.size().index

    # FEATURE BUILDING BLOCKS

    # Base moments
    mean_b = z_b.mean()
    mean_a = z_a.mean()
    std_b = z_b.apply(_std)
    std_a = z_a.apply(_std)
    skew_b = z_b.apply(lambda s: _skew_kurt(s.values)[0])
    skew_a = z_a.apply(lambda s: _skew_kurt(s.values)[0])
    kurt_b = z_b.apply(lambda s: _skew_kurt(s.values)[1])
    kurt_a = z_a.apply(lambda s: _skew_kurt(s.values)[1])

    # Medians & MADs
    med_b = z_b.median()
    med_a = z_a.median()
    mad_b = z_b.apply(lambda s: _mad(s.values))
    mad_a = z_a.apply(lambda s: _mad(s.values))

    # Base quantiles
    qs = quantile_coarse_grid
    Q10_b, Q25_b, Q50_b, Q75_b, Q90_b = [z_b.quantile(q) for q in qs]
    Q10_a, Q25_a, Q50_a, Q75_a, Q90_a = [z_a.quantile(q) for q in qs]

    # Robust skew via central asymmetry
    rob_skew_b = _log((Q75_b - Q50_b) / (Q50_b - Q25_b + EPS))
    rob_skew_a = _log((Q75_a - Q50_a) / (Q50_a - Q25_a + EPS))

    # Robust kurtosis pieces: IQR, IDR
    iqr_b = Q75_b - Q25_b
    iqr_a = Q75_a - Q25_a
    idr_b = Q90_b - Q10_b
    idr_a = Q90_a - Q10_a

    # Trend (slope) on clipped/winsorized standardized series
    slope_b = zc_b.apply(lambda s: _ols_slope(s.values))
    slope_a = zc_a.apply(lambda s: _ols_slope(s.values))

    # FEATURE COMPUTATION

    # Robust location shift, Δmedian/MAD (on original) = median_after (on standardized)
    med_delta = med_a - med_b  # med_b should be 0

    # Robust scale shift
    mad_logratio = _log((mad_a + EPS) / (mad_b + EPS))

    # Classic-vs-robust contrasts: (after contrast) - (before contrast)
    mean_vs_med = mean_a - med_a - (mean_b - med_b)
    std_vs_mad = _log(std_a / (1.4826 * mad_a + EPS)) - _log(
        std_b / (1.4826 * mad_b + EPS)
    )
    # The BEFORE contrast must be subtracted as a whole. The former
    # `skew_a - rob_skew_a - skew_b - rob_skew_b` added -rob_skew_b instead of +rob_skew_b.
    skew_contrast = (skew_a - rob_skew_a) - (skew_b - rob_skew_b)
    kurt_contrast = (
        kurt_a - _log(idr_a / (iqr_a + EPS)) - (kurt_b - _log(idr_b / (iqr_b + EPS)))
    )

    # Slope
    slope_delta = slope_a - slope_b

    # Assemble
    df = pd.DataFrame(
        {
            # robust moments
            "med_delta": med_delta,
            "mad_logratio": mad_logratio,
            # contrasts
            "mean_vs_med": mean_vs_med,
            "std_vs_mad": std_vs_mad,
            "skew_contrast": skew_contrast,
            "kurt_contrast": kurt_contrast,
            # slope
            "slope_delta": slope_delta,
        },
        index=ids,
        dtype=np.float32,
    )
    if not inference:
        _save_cache(df, prefix)
    return df


# ─────────────────────────────────────────────────────────────────────
# QUANTILES BLOCK
# ─────────────────────────────────────────────────────────────────────


def compute_quantiles_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    quantile_fine_grid: list[float] = QUANTILE_FINE_GRID,
    top_k: int = TOP_K,
) -> pd.DataFrame:
    """
    Quantile-shape break features (after vs. before) on the "standardized" column.

    Produces quantile deltas, the IDR log-ratio, robust asymmetry / tail-weight /
    decay contrasts and top-/bottom-k extreme contrasts, one float32 row per id.
    The cache is read only when present and neither `force` nor `inference` is set,
    and written only when `inference` is False.

    NOTE(review): `quantile_fine_grid` is unpacked positionally into
    Q5, Q10, Q25, Q40, Q50, Q60, Q75, Q90, Q95. It must have exactly nine entries
    in that order; otherwise the names silently bind to the wrong levels.
    """
    # Load cache
    prefix = "quantiles"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    # Group series
    z_col = "standardized"
    z_b = _group_series(X_prep, z_col, 0)
    z_a = _group_series(X_prep, z_col, 1)
    ids = z_b.size().index

    # FEATURE BUILDING BLOCKS

    # Base quantiles (unstack → one column per quantile level, one row per id)
    qs = quantile_fine_grid
    qb = z_b.quantile(q=qs).unstack(level=-1)
    qa = z_a.quantile(q=qs).unstack(level=-1)
    Q5_b, Q10_b, Q25_b, Q40_b, Q50_b, Q60_b, Q75_b, Q90_b, Q95_b = [qb[q] for q in qs]
    Q5_a, Q10_a, Q25_a, Q40_a, Q50_a, Q60_a, Q75_a, Q90_a, Q95_a = [qa[q] for q in qs]

    IQR_b = Q75_b - Q25_b
    IQR_a = Q75_a - Q25_a
    IDR_b = Q90_b - Q10_b
    IDR_a = Q90_a - Q10_a

    IDR_logratio = _log((IDR_a + EPS) / (IDR_b + EPS))

    # Tail asymmetry (robust skew): log of upper-gap / lower-gap at three depths
    central_asym_b = _log((Q75_b - Q50_b) / (Q50_b - Q25_b + EPS))
    central_asym_a = _log((Q75_a - Q50_a) / (Q50_a - Q25_a + EPS))

    shoulder_asym_b = _log((Q90_b - Q75_b) / (Q25_b - Q10_b + EPS))
    shoulder_asym_a = _log((Q90_a - Q75_a) / (Q25_a - Q10_a + EPS))

    tail_asym_b = _log((Q95_b - Q90_b) / (Q10_b - Q5_b + EPS))
    tail_asym_a = _log((Q95_a - Q90_a) / (Q10_a - Q5_a + EPS))

    # Tail thickness / peakedness (robust kurtosis): log of outer-range / inner-range
    central_weight_b = _log(IQR_b / (Q60_b - Q40_b + EPS))
    central_weight_a = _log(IQR_a / (Q60_a - Q40_a + EPS))

    shoulder_weight_b = _log(IDR_b / (IQR_b + EPS))
    shoulder_weight_a = _log(IDR_a / (IQR_a + EPS))

    tail_weight_b = _log((Q95_b - Q5_b) / (IDR_b + EPS))
    tail_weight_a = _log((Q95_a - Q5_a) / (IDR_a + EPS))

    # Moors-like variant: ((Q90 - Q60) - (Q40 - Q10)) / (Q75 - Q25)
    moors_b = ((Q90_b - Q60_b) - (Q40_b - Q10_b)) / (IQR_b + EPS)
    moors_a = ((Q90_a - Q60_a) - (Q40_a - Q10_a)) / (IQR_a + EPS)

    # Tail decay rates: log ratio of consecutive upper-side quantile gaps
    central_decay_b = _log((Q75_b - Q60_b) / (Q60_b - Q50_b + EPS))
    central_decay_a = _log((Q75_a - Q60_a) / (Q60_a - Q50_a + EPS))

    shoulder_decay_b = _log((Q90_b - Q75_b) / (Q75_b - Q60_b + EPS))
    shoulder_decay_a = _log((Q90_a - Q75_a) / (Q75_a - Q60_a + EPS))

    tail_decay_b = _log((Q95_b - Q90_b) / (Q90_b - Q75_b + EPS))
    tail_decay_a = _log((Q95_a - Q90_a) / (Q90_a - Q75_a + EPS))

    # Tail extremes (top/bottom k means) relative to IDR
    topk_b = z_b.apply(lambda s: _topk_mean(s.values, top_k))
    topk_a = z_a.apply(lambda s: _topk_mean(s.values, top_k))
    botk_b = z_b.apply(lambda s: _bottomk_mean(s.values, top_k))
    botk_a = z_a.apply(lambda s: _bottomk_mean(s.values, top_k))

    upper_ext_b = _log(np.maximum(topk_b, EPS) / (IDR_b + EPS))
    upper_ext_a = _log(np.maximum(topk_a, EPS) / (IDR_a + EPS))

    # For lower extension we expect botk to be negative; use -mean_bottom_k
    lower_ext_b = _log(np.maximum(-botk_b, EPS) / (IDR_b + EPS))
    lower_ext_a = _log(np.maximum(-botk_a, EPS) / (IDR_a + EPS))

    # Extreme tail asymmetry: log(mean_top_k / (-mean_bottom_k))
    # Guard denominator sign; if mean_bottom_k >= 0, push to EPS to avoid invalid log.
    eta_b = _log(np.maximum(topk_b, EPS) / np.maximum(-botk_b, EPS))
    eta_a = _log(np.maximum(topk_a, EPS) / np.maximum(-botk_a, EPS))

    # FEATURE COMPUTATION (after − before)

    # Base quantiles
    Q5_delta = Q5_a - Q5_b
    Q10_delta = Q10_a - Q10_b
    Q25_delta = Q25_a - Q25_b
    Q40_delta = Q40_a - Q40_b
    Q60_delta = Q60_a - Q60_b
    Q75_delta = Q75_a - Q75_b
    Q90_delta = Q90_a - Q90_b
    Q95_delta = Q95_a - Q95_b

    # Tail asymmetry (robust skew)
    central_asym = central_asym_a - central_asym_b
    shoulder_asym = shoulder_asym_a - shoulder_asym_b
    tail_asym = tail_asym_a - tail_asym_b

    # Tail thickness / peakedness (robust kurtosis)
    central_weight = central_weight_a - central_weight_b
    shoulder_weight = shoulder_weight_a - shoulder_weight_b
    tail_weight = tail_weight_a - tail_weight_b
    moors = moors_a - moors_b

    # Tail decay rates
    central_decay = central_decay_a - central_decay_b
    shoulder_decay = shoulder_decay_a - shoulder_decay_b
    tail_decay = tail_decay_a - tail_decay_b

    # Tail extremes
    upper_ext = upper_ext_a - upper_ext_b
    lower_ext = lower_ext_a - lower_ext_b
    extreme_tail_asym = eta_a - eta_b

    # Assemble
    q_df = pd.DataFrame(
        {
            # base quantiles
            "Q5_delta": Q5_delta,
            "Q10_delta": Q10_delta,
            "Q25_delta": Q25_delta,
            "Q40_delta": Q40_delta,
            "Q60_delta": Q60_delta,
            "Q75_delta": Q75_delta,
            "Q90_delta": Q90_delta,
            "Q95_delta": Q95_delta,
            "IDR_logratio": IDR_logratio,
            # tail asym
            "central_asym": central_asym,
            "shoulder_asym": shoulder_asym,
            "tail_asym": tail_asym,
            # tail thickness
            "central_weight": central_weight,
            "shoulder_weight": shoulder_weight,
            "tail_weight": tail_weight,
            "moors": moors,
            # decay rates
            "central_decay": central_decay,
            "shoulder_decay": shoulder_decay,
            "tail_decay": tail_decay,
            # extremes
            "upper_ext": upper_ext,
            "lower_ext": lower_ext,
            "extreme_tail_asym": extreme_tail_asym,
        },
        index=ids,
        dtype=np.float32,
    )
    if not inference:
        _save_cache(q_df, prefix)
    return q_df


# ─────────────────────────────────────────────────────────────────────
# RATES BLOCK
# ─────────────────────────────────────────────────────────────────────


def jeffreys_logit(k: int, m: int) -> float:
    """Logit of Jeffreys-smoothed rate: p~ = (k+0.5)/(m+1).

    Parameters
    ----------
    k : number of events (successes).
    m : number of trials; must be positive.

    Raises
    ------
    ValueError : if m <= 0 (no trials, e.g. an empty segment).
    """
    if m <= 0:
        raise ValueError(f"jeffreys_logit needs m > 0 trials, got m={m} (k={k})")
    p = (k + 0.5) / (m + 1.0)
    p = np.clip(p, EPS, 1 - EPS)  # keep p strictly inside (0, 1) before the logit
    return _log(p / (1.0 - p))


def exceedance_logits(
    z: np.ndarray, k_abs: float, k_fixed: float
) -> tuple[float, float]:
    """Jeffreys logits of P(|z| > k_abs) and P(|z| > k_fixed), in that order."""
    magnitude = np.abs(z)
    n_trials = z.size
    hits_abs = int((magnitude > k_abs).sum())
    hits_fixed = int((magnitude > k_fixed).sum())
    return jeffreys_logit(hits_abs, n_trials), jeffreys_logit(hits_fixed, n_trials)


def upper_lower_logits(
    z: np.ndarray, q_low: float, q_high: float
) -> tuple[float, float]:
    """Jeffreys logits of P(z > q_high) and P(z < q_low), in that order."""
    n_trials = z.size
    n_above = int((z > q_high).sum())
    n_below = int((z < q_low).sum())
    return jeffreys_logit(n_above, n_trials), jeffreys_logit(n_below, n_trials)


def fano_burstiness(z: np.ndarray, thr_abs: float, w: int) -> float:
    """Fano factor (variance / mean) of per-window counts of |z| > thr_abs.

    `z` is cut into consecutive non-overlapping windows of length `w`; a trailing
    partial window is dropped. Returns 0.0 when no full window exists or no
    window has any hit.
    """
    usable = (z.size // w) * w
    if not usable:
        return 0.0
    windows = z[:usable].reshape(-1, w)
    counts = (np.abs(windows) > thr_abs).sum(axis=1).astype(np.float32)
    mean_count = counts.mean()
    if mean_count <= 0:
        return 0.0
    return float(counts.var(ddof=0) / (mean_count + EPS))


def crossing_rates_logits(
    z: np.ndarray, Q: float, eps: float
) -> tuple[float, float, float]:
    """
    Jeffreys logits of the up, down and total crossing rates of level Q per transition.

    A deadband of half-width `eps` is used:
        up:   z_t <= Q - eps and z_{t+1} >= Q + eps
        down: z_t >= Q + eps and z_{t+1} <= Q - eps
    """
    lo, hi = Q - eps, Q + eps
    cur, nxt = z[:-1], z[1:]
    n_transitions = cur.size
    n_up = np.sum((cur <= lo) & (nxt >= hi))
    n_down = np.sum((cur >= hi) & (nxt <= lo))
    n_total = int(n_up + n_down)
    up_logit = jeffreys_logit(int(n_up), n_transitions)
    down_logit = jeffreys_logit(int(n_down), n_transitions)
    total_logit = jeffreys_logit(n_total, n_transitions)
    return up_logit, down_logit, total_logit


def median_crossing_asym(z: np.ndarray, Q: float, eps: float) -> float:
    """Signed imbalance of deadband up- vs down-crossings of level Q.

    Returns (UP - DOWN) / (UP + DOWN + 1) from raw (unsmoothed) counts, or 0.0
    when there is no crossing. The +1 keeps the result strictly inside (-1, 1),
    so a later arctanh stays finite.
    """
    lo, hi = Q - eps, Q + eps
    cur, nxt = z[:-1], z[1:]
    n_up = np.sum((cur <= lo) & (nxt >= hi))
    n_down = np.sum((cur >= hi) & (nxt <= lo))
    total = n_up + n_down
    if total == 0:
        return 0.0
    return float((n_up - n_down) / (total + 1))


def mean_log_res_time(z: np.ndarray, Q: float, eps: float) -> float:
    """
    Log of the mean gap (in steps) between consecutive deadband crossings of level Q.

    Up and down crossings are pooled. With no crossing the result is 0.0; with a
    single crossing the mean gap is approximated as (#transitions) / 2.
    """
    lo, hi = Q - eps, Q + eps
    cur, nxt = z[:-1], z[1:]
    crossed = ((cur <= lo) & (nxt >= hi)) | ((cur >= hi) & (nxt <= lo))
    positions = np.flatnonzero(crossed)
    n_cross = positions.size

    if n_cross == 0:
        return 0.0
    if n_cross == 1:
        # approx mean gap ≈ (m)/(tot+1) with m = number of transitions
        return _log(cur.size / (n_cross + 1.0) + EPS)
    return _log(np.diff(positions).mean() + EPS)


def fisher_delta(a: float, b: float, eps: float = EPS) -> float:
    """Fisher-z difference atanh(a) - atanh(b), clamping both inputs to [-1+eps, 1-eps]."""
    lower, upper = -1 + eps, 1 - eps
    fisher_a = np.arctanh(np.clip(a, lower, upper))
    fisher_b = np.arctanh(np.clip(b, lower, upper))
    return float(fisher_a - fisher_b)


def compute_rates_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    quantile_fine_grid: list[float] = QUANTILE_FINE_GRID,
    w_fano: int = W_FANO,
    crossing_rate_deadband: float = CROSSING_RATE_DEADBAND,
) -> pd.DataFrame:
    """
    Rates block (after − before or log-ratios), computed per id via a single groupby-apply.
    Uses BEFORE quantiles as thresholds for both segments.

    Parameters
    ----------
    X_prep : preprocessed frame with an "id" index level and columns
        "standardized", "original", "period" (0 = before, 1 = after).
    force : recompute even if a cached parquet exists.
    inference : never read or write the cache.
    quantile_fine_grid : quantiles evaluated on the BEFORE segment. It must contain
        0.05, 0.25, 0.50, 0.75 and 0.95; they are looked up by value, so order and
        extra entries do not matter.
    w_fano : window length for the Fano burstiness factor.
    crossing_rate_deadband : half-width ε of the deadband around crossing levels.

    Returns
    -------
    float32 DataFrame indexed by id, in order of first appearance (sort=False).

    Raises
    ------
    ValueError : if `quantile_fine_grid` lacks one of the required thresholds
        (checked only when features are actually computed).
    """
    prefix = "rates"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    # Thresholds used below, looked up by value. The previous positional unpacking
    # silently bound the wrong levels whenever the grid differed from the default.
    qs = list(quantile_fine_grid)
    required_qs = (0.05, 0.25, 0.50, 0.75, 0.95)
    missing = [q for q in required_qs if q not in qs]
    if missing:
        raise ValueError(
            f"quantile_fine_grid must contain {required_qs}; missing {missing}"
        )

    # keep only what we need (list selection already returns a new frame)
    cols = ["standardized", "original", "period"]
    df = X_prep[cols]

    def _one_id(g: pd.DataFrame) -> pd.Series:
        # split
        gb = g[g["period"] == 0]
        ga = g[g["period"] == 1]

        zb = gb["standardized"].to_numpy(np.float32, copy=False)
        za = ga["standardized"].to_numpy(np.float32, copy=False)
        xb = gb["original"].to_numpy(np.float32, copy=False)
        xa = ga["original"].to_numpy(np.float32, copy=False)

        # BEFORE thresholds, keyed by quantile level
        q_before = dict(zip(qs, np.quantile(zb, qs)))
        Q5_b, Q25_b, Q50_b, Q75_b, Q95_b = (q_before[q] for q in required_qs)

        # 1) Extreme events (Jeffreys-smoothed logits)
        b1, b2 = exceedance_logits(zb, k_abs=Q95_b, k_fixed=3.0)
        a1, a2 = exceedance_logits(za, k_abs=Q95_b, k_fixed=3.0)

        u_b, l_b = upper_lower_logits(zb, q_low=Q5_b, q_high=Q95_b)
        u_a, l_a = upper_lower_logits(za, q_low=Q5_b, q_high=Q95_b)

        # Burstiness: Fano on |z|>Q95_before, then log-ratio
        fb = fano_burstiness(zb, thr_abs=Q95_b, w=w_fano)
        fa = fano_burstiness(za, thr_abs=Q95_b, w=w_fano)
        fano_logratio = _log((fa + EPS) / (fb + EPS))

        # 2) Crossing rates (deadband ε), keep totals at Q25/Q50/Q75
        _, _, cr25_b_t = crossing_rates_logits(zb, Q25_b, crossing_rate_deadband)
        _, _, cr25_a_t = crossing_rates_logits(za, Q25_b, crossing_rate_deadband)

        _, _, cr50_b_t = crossing_rates_logits(zb, Q50_b, crossing_rate_deadband)
        _, _, cr50_a_t = crossing_rates_logits(za, Q50_b, crossing_rate_deadband)

        _, _, cr75_b_t = crossing_rates_logits(zb, Q75_b, crossing_rate_deadband)
        _, _, cr75_a_t = crossing_rates_logits(za, Q75_b, crossing_rate_deadband)

        # Median crossing asymmetry (Fisher-z delta)
        med_asym_b = median_crossing_asym(zb, Q50_b, crossing_rate_deadband)
        med_asym_a = median_crossing_asym(za, Q50_b, crossing_rate_deadband)
        med_cross_asym = fisher_delta(med_asym_a, med_asym_b)

        # CR decay: change in (upper-level minus lower-level) crossing logits
        cr_decay = (cr75_a_t - cr25_a_t) - (cr75_b_t - cr25_b_t)

        # 3) % zeros (original) → Jeffreys-logit delta
        kb, mb = int(np.isclose(xb, 0.0).sum()), xb.size
        ka, ma = int(np.isclose(xa, 0.0).sum()), xa.size
        pct_zeros = jeffreys_logit(ka, ma) - jeffreys_logit(kb, mb)

        return pd.Series(
            {
                "abs_excd_q95_logit_delta": a1 - b1,
                "abs_excd_3_logit_delta": a2 - b2,
                "up_excd_q95_logit_delta": u_a - u_b,
                "low_excd_q5_logit_delta": l_a - l_b,
                "fano_logratio": fano_logratio,
                "signflips_logit_delta": cr50_a_t - cr50_b_t,
                "CR25_logit_delta": cr25_a_t - cr25_b_t,
                "CR75_logit_delta": cr75_a_t - cr75_b_t,
                "med_cross_asym_fisher_delta": med_cross_asym,
                "cr_decay_logit_delta": cr_decay,
                "pct_zeros_logit_delta": pct_zeros,
            },
            dtype=np.float32,
        )

    out = (
        df.groupby(level="id", sort=False, group_keys=False)
        .apply(_one_id)
        .astype(np.float32)
    )

    if not inference:
        _save_cache(out, prefix)
    return out


# ─────────────────────────────────────────────────────────────────────
# AUTOCORRELATION BLOCK
# ─────────────────────────────────────────────────────────────────────


def acf_1d(x: np.ndarray, max_lag: int) -> np.ndarray:
    """Sample autocorrelations r[1..max_lag] of a 1-D series.

    Each lag uses the (n − k)-normalised ("unbiased") estimator

        r(k) = Σ_t (x_t − x̄)(x_{t+k} − x̄) / ((n − k) · var(x)),

    so values are not guaranteed to lie in [-1, 1].

    Parameters
    ----------
    x : np.ndarray
        1-D series (callers in this file pass float32).
    max_lag : int
        Largest lag to compute; the output has length ``max_lag``.

    Returns
    -------
    np.ndarray
        float32 array whose element ``k - 1`` is r(k). Empty input,
        zero-variance input, and lags with no overlapping pairs (k >= n)
        give 0.0 instead of NaN.
    """
    n = x.size
    r = np.zeros(max_lag, dtype=np.float32)
    if n == 0:
        # x.mean() of an empty array would warn and return NaN.
        return r
    # mean-center (detrended is near zero-mean, but do it anyway)
    x = x - x.mean()
    var = x.var()
    if var <= 0:
        return r
    # O(n*K) direct sums are fine for the small K used here. Lags k >= n have
    # no overlapping pairs and an (n - k) divisor of 0 or below, so they stay 0.
    for k in range(1, min(max_lag, n - 1) + 1):
        num = np.dot(x[k:], x[:-k])
        r[k - 1] = num / ((n - k) * var)  # consistent scaling across lags
    return r


def pacf_yw(x: np.ndarray, max_lag: int) -> np.ndarray:
    """PACF via Yule–Walker / Durbin–Levinson; returns pacf[1..max_lag].

    Autocorrelations come from ``acf_1d`` (the (n − k)-normalised ACF), with
    r(0) fixed at 1. Durbin–Levinson solves the Yule–Walker equations order by
    order; the lag-k partial autocorrelation is the last coefficient phi_kk of
    the order-k fit.

    Parameters
    ----------
    x : np.ndarray
        1-D series, passed unchanged to ``acf_1d``.
    max_lag : int
        Highest PACF lag. Must be >= 1: ``phi[1, 1]`` and ``r_full[1]`` are
        accessed unconditionally.

    Returns
    -------
    np.ndarray
        float32 array of length ``max_lag``; element ``k - 1`` is phi_kk.
    """
    r = acf_1d(x, max_lag)  # r[1..K]
    r0 = 1.0
    # build autocorr sequence r_full[0..K]
    r_full = np.concatenate(([r0], r))
    pacf = np.zeros(max_lag, dtype=np.float32)
    # Durbin–Levinson
    # phi[j, k] holds coefficient j of the order-k AR fit (row/column 0 unused).
    # sig[k] is the order-k innovation variance relative to r(0); sig[0] is never read.
    phi = np.zeros((max_lag + 1, max_lag + 1), dtype=np.float32)
    sig = np.empty(max_lag + 1, dtype=np.float32)
    phi[1, 1] = r_full[1]
    sig[1] = 1 - r_full[1] ** 2
    pacf[0] = phi[1, 1]
    for k in range(2, max_lag + 1):
        # phi_kk = (r_k − Σ_{j<k} phi_{j,k−1} · r_{k−j}) / sig_{k−1}
        num = r_full[k] - np.dot(phi[1:k, k - 1], r_full[1:k][::-1])
        # The (n − k)-normalised ACF need not be positive definite, so sig can
        # become <= 0; EPS is used as the divisor in that case.
        den = sig[k - 1] if sig[k - 1] > 0 else EPS
        phi[k, k] = num / den
        # Lower-order update: phi_{j,k} = phi_{j,k−1} − phi_kk · phi_{k−j,k−1}
        for j in range(1, k):
            phi[j, k] = phi[j, k - 1] - phi[k, k] * phi[k - j, k - 1]
        sig[k] = sig[k - 1] * (1 - phi[k, k] ** 2)
        pacf[k - 1] = phi[k, k]
    return pacf


def ljung_box_z(x: np.ndarray, m: int) -> float:
    """Standardised Ljung–Box statistic over lags 1..m.

    Q = n (n + 2) Σ_{k=1..m} r(k)² / (n − k), with r from ``acf_1d`` (the
    (n − k)-normalised ACF). Q is then mapped to z = (Q − m) / sqrt(2m) (the
    mean and standard deviation of a χ²_m variable), so values are on a
    comparable scale for any m.

    Returns 0.0 when the statistic is undefined (m < 1, or n <= m so that lag
    k = n has no overlapping pairs) or when z is not finite.
    """
    n = x.size
    if m < 1 or n <= m:
        # Lag k = n would contribute r(n)² / 0; report "no evidence" explicitly
        # instead of relying on NaN propagation and RuntimeWarnings.
        return 0.0
    r = acf_1d(x.astype(np.float32), m)
    ks = np.arange(1, m + 1, dtype=np.float32)
    Q = n * (n + 2.0) * np.sum((r**2) / (n - ks))
    z = (Q - m) / np.sqrt(2.0 * m)
    return float(0.0 if not np.isfinite(z) else z)


def compute_autocorr_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    acf_max_lag: int = ACF_MAX_LAG,  # ACF lags 1..acf_max_lag for shape/summaries
    lbq_m: int = LBQ_M,  # Ljung–Box Q statistic uses lbq_m
) -> pd.DataFrame:
    """
    Autocorrelation features on the *detrended* standardized series.

    Outputs (one float32 row per id; K = acf_max_lag; every value is after − before):
      - acf1_delta, acf2_delta : ACF at lags 1 and 2 (0.0 when K is too small)
      - pacf2_delta            : Durbin–Levinson PACF at lag 2 (0.0 when K < 2)
      - shortlag_l1_delta      : Σ_{ℓ=2..K} |ρ(ℓ)|/ℓ  (lag 1 is reported by acf1_delta)
      - lbq_stat_delta         : Ljung–Box z-score with m = lbq_m (see ljung_box_z)
      - alt_signed_sum_delta   : Σ_{ℓ=1..K} (−1)^{ℓ−1} ρ(ℓ)

    Notes:
      • Results are cached under the "autocorr" prefix unless ``inference``.
      • O(n·K) per id (K small); no FFT needed.
      • NOTE(review): non-finite values are not checked here; they propagate
        as NaN through acf_1d. Confirm upstream data hygiene.
    """
    # Caching
    prefix = "autocorr"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    # Group series and materialize NumPy float32 arrays once
    zcol = "detrended"
    zd_b = _group_series(X_prep, zcol, 0).apply(
        lambda s: s.to_numpy(dtype=np.float32, copy=False)
    )
    zd_a = _group_series(X_prep, zcol, 1).apply(
        lambda s: s.to_numpy(dtype=np.float32, copy=False)
    )
    ids = zd_b.index

    # Loop-invariant: depend only on acf_max_lag.
    # Short-lag weights 1/ℓ for ℓ = 2..K.
    shortlag_w = 1.0 / np.arange(2, acf_max_lag + 1, dtype=np.float32)
    # Alternating signs +1, −1, +1, ... for ℓ = 1..K.
    alt_signs = np.where((np.arange(1, acf_max_lag + 1) % 2) == 1, 1.0, -1.0)

    # ---------- compute per id ----------
    acf1_d, acf2_d, pacf2_d = [], [], []
    shortlag_d, lbq_d, alt_sum_d = [], [], []

    for i in ids:
        xb = zd_b.loc[i].astype(np.float32, copy=False)
        xa = zd_a.loc[i].astype(np.float32, copy=False)

        # ACF up to K
        acf_b = acf_1d(xb, acf_max_lag)  # r[1..acf_max_lag]
        acf_a = acf_1d(xa, acf_max_lag)

        # lag-1/2 deltas
        acf1_d.append(acf_a[0] - acf_b[0] if acf_max_lag >= 1 else 0.0)
        acf2_d.append(acf_a[1] - acf_b[1] if acf_max_lag >= 2 else 0.0)
        # PACF(2) depends only on r(1), r(2), so Durbin–Levinson to order 2
        # gives the same value as order K with less work.
        if acf_max_lag >= 2:
            pacf2_d.append(pacf_yw(xa, 2)[1] - pacf_yw(xb, 2)[1])
        else:
            pacf2_d.append(0.0)

        # short-lag dependence Σ |ρ(ℓ)|/ℓ, starting from ℓ=2
        s_b = np.sum(np.abs(acf_b[1:]) * shortlag_w)
        s_a = np.sum(np.abs(acf_a[1:]) * shortlag_w)
        shortlag_d.append(s_a - s_b)

        # Ljung–Box z-score (m = lbq_m)
        lbq_d.append(ljung_box_z(xa, lbq_m) - ljung_box_z(xb, lbq_m))

        # Alternating signed sum Σ (-1)^{ℓ-1} ρ(ℓ)
        alt_b = float(np.sum(alt_signs * acf_b))
        alt_a = float(np.sum(alt_signs * acf_a))
        alt_sum_d.append(alt_a - alt_b)

    autoc = pd.DataFrame(
        {
            "acf1_delta": acf1_d,
            "acf2_delta": acf2_d,
            "pacf2_delta": pacf2_d,
            "shortlag_l1_delta": shortlag_d,
            "lbq_stat_delta": lbq_d,
            "alt_signed_sum_delta": alt_sum_d,
        },
        index=ids,
        dtype=np.float32,
    )

    if not inference:
        _save_cache(autoc, prefix)
    return autoc


# ─────────────────────────────────────────────────────────────────────
# TESTS & DISTANCES BLOCK
# ─────────────────────────────────────────────────────────────────────


def _idr(x):
    q = np.quantile(x, [0.25, 0.75])
    return float(q[1] - q[0])


def _ks_normalized(before: np.ndarray, after: np.ndarray) -> float:
    """
    Kolmogorov two-sample normalized statistic:
      KS_norm = (sqrt(n_eff) + 0.12 + 0.11/sqrt(n_eff)) * D,
    where n_eff = n0*n1/(n0+n1) and D is the sup ECDF distance.
    Returns 0.0 if either segment is empty.
    """
    n0, n1 = before.size, after.size
    if n0 == 0 or n1 == 0:
        return 0.0

    # compute D (two-sample KS) inline (merge-walk)
    xa = np.sort(before)
    xb = np.sort(after)
    ia = ib = 0
    cdfa = cdfb = 0.0
    D = 0.0
    while ia < n0 and ib < n1:
        if xa[ia] <= xb[ib]:
            ia += 1
            cdfa = ia / n0
        else:
            ib += 1
            cdfb = ib / n1
        D = max(D, abs(cdfa - cdfb))
    if ia < n0:
        D = max(D, abs(1.0 - cdfb))
    if ib < n1:
        D = max(D, abs(cdfa - 1.0))

    n_eff = (n0 * n1) / (n0 + n1)
    s = np.sqrt(max(n_eff, 1.0))
    return float((s + 0.12 + 0.11 / s) * D)


def _css_normalized(before: np.ndarray, after: np.ndarray) -> float:
    """
    Inclán–Tiao CUSUM-of-squares with Brownian-bridge scaling:
      CSS_norm = sqrt(n) * max_t | S_t / S_T - t/T |,
    computed on the concatenated series (before || after).
    """
    if before.size + after.size <= 1:
        return 0.0
    y = np.concatenate([before, after])
    n = y.size
    s2 = np.cumsum(y * y)
    ST = s2[-1]
    if ST <= 0:
        return 0.0
    t = np.arange(1, n + 1, dtype=np.float32)
    D = s2 / ST - t / n
    css = float(np.max(np.abs(D)))
    return float(np.sqrt(n) * css)


def _wasserstein_quant(a, b, qs, scale):
    Qa = np.quantile(a, qs)
    Qb = np.quantile(b, qs)
    w1 = np.mean(np.abs(Qa - Qb))
    denom = max(scale, 1e-8)
    return float(w1 / denom)


def _js_divergence(a, b, q_edges):
    # Quantile-based bin edges from BEFORE segment
    edges = np.quantile(a, q_edges)
    # ensure strictly increasing edges (collapse-safe)
    edges = np.unique(edges)
    # histograms
    pa, _ = np.histogram(a, bins=edges, density=False)
    pb, _ = np.histogram(b, bins=edges, density=False)
    # Jeffreys smoothing (0.5) → probabilities
    pa = pa.astype(np.float32) + 0.5
    pb = pb.astype(np.float32) + 0.5
    pa /= pa.sum()
    pb /= pb.sum()
    m = 0.5 * (pa + pb)
    # JS in nats
    with np.errstate(divide="ignore", invalid="ignore"):
        KL_am = np.sum(pa * (np.log(pa) - np.log(m)))
        KL_bm = np.sum(pb * (np.log(pb) - np.log(m)))
    JS = 0.5 * (KL_am + KL_bm)
    if not np.isfinite(JS):
        return 0.0
    return float(JS)


def _mmd2_rbf(a, b, sigma, mmd_max_n):
    # Unbiased MMD^2; cap sample sizes for speed
    na = a.size
    nb = b.size
    if na > mmd_max_n:
        idx = np.linspace(0, na - 1, mmd_max_n, dtype=int)
        a = a[idx]
        na = a.size
    if nb > mmd_max_n:
        idx = np.linspace(0, nb - 1, mmd_max_n, dtype=int)
        b = b[idx]
        nb = b.size
    gamma = 1.0 / (2.0 * sigma * sigma)

    def _kxx(x):
        # exclude diagonal for unbiased estimator
        d2 = (x[:, None] - x[None, :]) ** 2
        np.fill_diagonal(d2, 0.0)
        K = np.exp(-gamma * d2)
        return K.sum() / (x.size * (x.size - 1) + EPS)

    def _kxy(x, y):
        d2 = (x[:, None] - y[None, :]) ** 2
        K = np.exp(-gamma * d2)
        return K.mean()

    kxx = _kxx(a) if na > 1 else 0.0
    kyy = _kxx(b) if nb > 1 else 0.0
    kxy = _kxy(a, b)
    return float(kxx + kyy - 2.0 * kxy)


def _gaussian_glr_per(before: np.ndarray, after: np.ndarray) -> float:
    """
    Per-sample Gaussian GLR for a joint mean+variance change (length-invariant).

      LRT = n*log(s2_all) - n0*log(s2_b) - n1*log(s2_a);   returns LRT / n

    s2_b and s2_a are MLE variances (ddof=0) about each segment's own mean.
    s2_all is the MLE variance of the concatenated series about the *common*
    mean. That equals the pooled within-segment variance plus the between-means
    term n0*n1/n² · (mean_b − mean_a)²; without this term only a variance
    change would be detected. Variances are floored at EPS. Returns 0.0 if
    either segment is empty.
    """
    n0 = before.size
    n1 = after.size
    if n0 == 0 or n1 == 0:
        return 0.0
    n = n0 + n1
    mu_b = float(np.mean(before))
    mu_a = float(np.mean(after))
    # guard against degenerate variance
    s2_b = max(float(np.var(before, ddof=0)), EPS)
    s2_a = max(float(np.var(after, ddof=0)), EPS)
    within = (n0 * s2_b + n1 * s2_a) / n
    between = (n0 * n1 / (n * n)) * (mu_b - mu_a) ** 2
    s2_all = max(within + between, EPS)
    lrt = n * np.log(s2_all) - n0 * np.log(s2_b) - n1 * np.log(s2_a)
    return float(lrt / n)


def _arch_neglogp(x: np.ndarray, L: int) -> float:
    """
    −log p-value of Engle's ARCH-LM test with ``L`` lags.

    The statistic LM = n_eff · R² from ``_arch_lm_LM`` is referred to χ²_L.
    Its upper-tail probability equals the regularized upper incomplete gamma
    Q(L/2, LM/2), computed with SciPy's ``gammaincc``. Negative LM is treated
    as 0, and p is floored at EPS, so the result is at most −log(EPS).
    """
    lm_stat = max(_arch_lm_LM(x, L=L), 0.0)
    p_value = float(gammaincc(0.5 * L, 0.5 * lm_stat))
    return float(-np.log(max(p_value, EPS)))


def compute_tests_distances_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    js_quantile_bins: np.ndarray = JS_QUANTILE_BINS,  # e.g., np.linspace(0,1,33)
    mmd_max_n: int = MMD_MAX_N,  # e.g., 512
    arch_L: int = ARCH_L,  # e.g., 5
) -> pd.DataFrame:
    """
    Two-sample tests and distances between the BEFORE and AFTER segments of
    the 'standardized' series, one float32 row per id.

    Tests:
      - gauss_glr_per_sample   : _gaussian_glr_per(before, after)
      - css_norm               : _css_normalized on the concatenated series
      - archlm_neglogp_delta   : −log p of ARCH-LM (df = arch_L), after − before
    Distances:
      - ks_norm                : _ks_normalized (Kolmogorov D, n_eff-normalised)
      - js_divergence          : _js_divergence on BEFORE-quantile bins js_quantile_bins
      - mmd2_rbf_idrband_equal : unbiased RBF MMD² on m = min(n_b, n_a, mmd_max_n)
                                 evenly spaced points from each segment, with
                                 σ = _idr(before) / 2; 0.0 when m < 2

    Results are cached under the "tests_distances" prefix unless ``inference``.
    """
    prefix = "tests_distances"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    def _as_f32(s):
        return s.to_numpy(np.float32, copy=False)

    before_by_id = _group_series(X_prep, "standardized", 0).apply(_as_f32)
    after_by_id = _group_series(X_prep, "standardized", 1).apply(_as_f32)
    ids = before_by_id.index

    columns = [
        "gauss_glr_per_sample",
        "css_norm",
        "archlm_neglogp_delta",
        "ks_norm",
        "js_divergence",
        "mmd2_rbf_idrband_equal",
    ]

    def _mmd2_equal_size(b, a):
        # Equal-size deterministic subsampling, so both samples contribute alike.
        m = int(min(b.size, a.size, mmd_max_n))
        if m < 2:
            return 0.0
        b_sub = b[np.linspace(0, b.size - 1, m, dtype=int)]
        a_sub = a[np.linspace(0, a.size - 1, m, dtype=int)]
        bandwidth = max(_idr(b), EPS) / 2.0
        # mmd_max_n=m: already subsampled, so no further thinning happens inside
        return _mmd2_rbf(b_sub, a_sub, bandwidth, mmd_max_n=m)

    rows = []
    for i in ids:
        b = before_by_id.loc[i]
        a = after_by_id.loc[i]
        arch_before = _arch_neglogp(b, arch_L)
        arch_after = _arch_neglogp(a, arch_L)
        rows.append(
            [
                _gaussian_glr_per(b, a),
                _css_normalized(b, a),
                float(arch_after - arch_before),
                _ks_normalized(b, a),
                _js_divergence(b, a, js_quantile_bins),
                _mmd2_equal_size(b, a),
            ]
        )

    out = pd.DataFrame(rows, index=ids, columns=columns, dtype=np.float32)

    if not inference:
        _save_cache(out, prefix)
    return out


# ─────────────────────────────────────────────────────────────────────
# FREQUENCY BLOCK
# ─────────────────────────────────────────────────────────────────────


def _psd_rfft(x: np.ndarray):
    """Return (freq in [0,0.5], power spectrum). Mean-centered; no window for speed."""
    n = x.size
    y = x - x.mean()
    fft = np.fft.rfft(y)
    P = (fft.real**2 + fft.imag**2) / max(n, 1)
    f = np.fft.rfftfreq(n, d=1.0)  # normalized to sampling step 1 → Nyquist 0.5
    return f, P


def _spectral_centroid(f, P):
    """Power-weighted mean frequency Σ f·P / (Σ P + EPS); 0.0 unless Σ P > 0."""
    total = P.sum()
    if not total > 0:
        return 0.0
    return float((f * P).sum() / (total + EPS))


def _log_flatness(P):
    """
    Spectral flatness of a power spectrum:
    geometric_mean(P + EPS) / arithmetic_mean(P + EPS), in (0, 1].

    Despite the name this returns the ratio itself, not its logarithm;
    callers take the log.
    """
    shifted = P + EPS
    geometric = np.exp(np.log(shifted).mean())
    return float(geometric / shifted.mean())


def _bandpower_logratio(f, P, bands):
    """
    Fraction of total power in each half-open band [lo, hi), in ``bands`` order.

    Despite the name this returns plain fractions; callers form the
    after/before log-ratio. Returns a list of zeros when the total power is
    <= 0. NOTE(review): the upper bound is exclusive, so a band ending exactly
    at 0.5 does not include the Nyquist bin — confirm this is intended.
    """
    total = P.sum()
    if total <= 0:
        return [0.0] * len(bands)
    denom = total + EPS
    return [P[(f >= lo) & (f < hi)].sum() / denom for lo, hi in bands]


def _dwt_l3_ratio(x, dwt_wavelet, dwt_level):
    """
    Share of detail-coefficient energy at decomposition level 3.

    Runs ``pywt.wavedec`` (symmetric padding) and returns
    E(cD3) / (Σ_levels E(cD) + EPS), where E is the sum of squares. Returns
    0.0 when fewer than three detail levels are produced.
    """
    # wavedec returns [cA_L, cD_L, ..., cD_2, cD_1]; drop the approximation
    _, *details = pywt.wavedec(x, dwt_wavelet, level=dwt_level, mode="symmetric")
    if len(details) < 3:
        return 0.0
    level3 = details[-3]  # third from the finest end → cD3
    energy_l3 = float(np.sum(level3 * level3))
    energy_all = float(sum(np.sum(d * d) for d in details)) + EPS
    return energy_l3 / energy_all


def _perm_entropy(x: np.ndarray, m: int = 3, tau: int = 1) -> float:
    """
    Normalized permutation entropy in [0,1] using ordinal patterns.
    Deterministic tie-breaking via stable argsort.
    """
    x = np.asarray(x, dtype=np.float64)
    n = x.size
    span = (m - 1) * tau
    N = n - span
    if m < 2 or tau < 1 or N <= 0:
        return 0.0
    # counts
    counts = {}
    # stable tie-breaking: argsort twice to get ranks
    for i in range(N):
        window = x[i : i + span + 1 : tau]
        # rank vector (0..m-1) with stable tie handling
        order = np.argsort(window, kind="mergesort")
        ranks = np.empty(m, dtype=np.int64)
        ranks[order] = np.arange(m, dtype=np.int64)
        key = tuple(ranks.tolist())
        counts[key] = counts.get(key, 0) + 1
    total = float(sum(counts.values()))
    if total <= 0:
        return 0.0
    probs = np.fromiter((c / total for c in counts.values()), dtype=np.float64)
    H = -np.sum(probs * np.log(probs + EPS))
    Hmax = np.log(math.factorial(m))
    return float(H / (Hmax + EPS))


def compute_frequency_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    bands: tuple = FREQ_BANDS,  # normalized freq bands
    dwt_wavelet: str = DWT_WAVELET,
    dwt_level: int = DWT_LEVEL,
    entropy_m1: int = ENTROPY_M1,
    entropy_m2: int = ENTROPY_M2,
    entropy_tau: int = ENTROPY_TAU,
) -> pd.DataFrame:
    """
    Frequency & entropy features per id (one float32 row per id).

    Spectral (on 'detrended'):
      - spec_flatness_logratio   : log(flatness_after) − log(flatness_before),
                                   with the DC bin excluded from the spectrum
      - dwt_l3_energy_logratio   : log-ratio (after/before) of the level-3
                                   detail energy share
      - bandpower_b{k}_logratio  : log-ratio (after/before) of the power
                                   fraction in the k-th entry of ``bands`` (1-based)
    Ordinal (on 'standardized'):
      - perm_entropy_m3_delta    : permutation entropy with m = entropy_m1, after − before
      - perm_entropy_m5_delta    : permutation entropy with m = entropy_m2, after − before
      (these column names are fixed regardless of the entropy_m1/m2 values)

    Results are cached under the "frequency" prefix unless ``inference``.
    """
    prefix = "frequency"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    # --- choose series ---
    z_std = "standardized"
    z_det = "detrended"

    z_b = _group_series(X_prep, z_std, 0)
    z_a = _group_series(X_prep, z_std, 1)
    zd_b = _group_series(X_prep, z_det, 0)
    zd_a = _group_series(X_prep, z_det, 1)

    z_b = z_b.apply(lambda s: s.to_numpy(dtype=np.float32, copy=False))
    z_a = z_a.apply(lambda s: s.to_numpy(dtype=np.float32, copy=False))
    zd_b = zd_b.apply(lambda s: s.to_numpy(dtype=np.float32, copy=False))
    zd_a = zd_a.apply(lambda s: s.to_numpy(dtype=np.float32, copy=False))
    ids = zd_b.index

    def _without_dc(P):
        # _psd_rfft mean-centres, so P[0] is ~0. Inside the geometric mean it
        # would add about log(EPS)/len(P), a bias that depends on segment length.
        return P[1:] if P.size > 1 else P

    # ---- accumulators ----
    log_flatness_d = []
    band_logratios = [[] for _ in bands]  # one list per band
    dwt_l3_logratio = []
    perm_m3_delta = []
    perm_m5_delta = []

    # ---- compute per id ----
    for i in ids:
        xb = zd_b.loc[i]
        xa = zd_a.loc[i]

        # FFT stats (detrended)
        fb, Pb = _psd_rfft(xb)
        fa, Pa = _psd_rfft(xa)

        # Spectral flatness → log-ratio (after/before)
        sf_b = _log_flatness(_without_dc(Pb))
        sf_a = _log_flatness(_without_dc(Pa))
        log_flatness_d.append(np.log(sf_a + EPS) - np.log(sf_b + EPS))

        # Band powers → fractions → log-ratios
        frac_b = _bandpower_logratio(fb, Pb, bands)
        frac_a = _bandpower_logratio(fa, Pa, bands)
        for bi, (fb_i, fa_i) in enumerate(zip(frac_b, frac_a)):
            band_logratios[bi].append(np.log((fa_i + EPS) / (fb_i + EPS)))

        # DWT L3 energy ratio (detrended) → log-ratio
        r_b = _dwt_l3_ratio(xb, dwt_wavelet, dwt_level)
        r_a = _dwt_l3_ratio(xa, dwt_wavelet, dwt_level)
        dwt_l3_logratio.append(np.log((r_a + EPS) / (r_b + EPS)))

        # permutation entropy deltas (m = entropy_m1 / entropy_m2, tau = entropy_tau)
        zb = z_b.loc[i]
        za = z_a.loc[i]
        pe_b_m3 = _perm_entropy(zb, m=entropy_m1, tau=entropy_tau)
        pe_a_m3 = _perm_entropy(za, m=entropy_m1, tau=entropy_tau)
        pe_b_m5 = _perm_entropy(zb, m=entropy_m2, tau=entropy_tau)
        pe_a_m5 = _perm_entropy(za, m=entropy_m2, tau=entropy_tau)
        perm_m3_delta.append(pe_a_m3 - pe_b_m3)
        perm_m5_delta.append(pe_a_m5 - pe_b_m5)

    # ---- assemble ----
    cols = {
        "spec_flatness_logratio": log_flatness_d,
        "dwt_l3_energy_logratio": dwt_l3_logratio,
        "perm_entropy_m3_delta": perm_m3_delta,
        "perm_entropy_m5_delta": perm_m5_delta,
    }
    # add band logratios with names
    for bi, _ in enumerate(bands):
        cols[f"bandpower_b{bi + 1}_logratio"] = band_logratios[bi]

    out = pd.DataFrame(cols, index=ids, dtype=np.float32)

    if not inference:
        _save_cache(out, prefix)
    return out


# ─────────────────────────────────────────────────────────────────────
# DIFFERENCES BLOCK
# ─────────────────────────────────────────────────────────────────────


def _shortlag_L1(x: np.ndarray, K: int) -> float:
    """
    Weighted short-lag dependence Σ_{ℓ=2..K} |ρ(ℓ)| / ℓ, with ρ from acf_1d.

    Lag 1 is excluded from the sum. Returns 0.0 when K < 2.
    """
    rho = acf_1d(x.astype(np.float32), K)  # rho[ℓ-1] = ρ(ℓ)
    weights = 1.0 / np.arange(2, K + 1, dtype=np.float32)
    return float((np.abs(rho[1:]) * weights).sum())


def compute_differences_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    quantile_fine_grid: list[float] = QUANTILE_FINE_GRID,
    crossing_rate_deadband: float = CROSSING_RATE_DEADBAND,
    lbq_m: int = LBQ_M,
    acf_max_lag: int = ACF_MAX_LAG,
) -> pd.DataFrame:
    """
    Features of the differenced series, one float32 row per id (all after − before).

    On 'diff_standardized' (d):
      - diff_abs_excd_3_logit_delta  : 2nd output of exceedance_logits with
                                       k_abs = before's q95 and k_fixed = 3.0
                                       (presumably the |d| > 3 exceedance logit)
      - diff_up_excd_q95_logit_delta : upper-tail logit from upper_lower_logits,
                                       threshold = before's q95
      - diff_low_excd_q5_logit_delta : lower-tail logit, threshold = before's q5
      - diff_signflips_logit_delta   : 3rd output of crossing_rates_logits with
                                       Q = 0 and deadband crossing_rate_deadband
    On 'diff_detrended' (dm):
      - diff_shortlag_l1_delta       : Σ_{ℓ=2..K} |ACF(ℓ)|/ℓ, K = acf_max_lag
      - diff_lbq_stat_delta          : Ljung–Box z-score with m = lbq_m
      - diff_spec_centroid_delta     : spectral centroid of the periodogram
    On 'absdiff_detrended' (am):
      - absdiff_acf1_delta           : lag-1 ACF

    Results are cached under the "differences" prefix unless ``inference``.
    """
    prefix = "differences"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    # keep only needed cols to speed up groupby.apply
    cols = [
        "period",
        "diff_standardized",
        "diff_detrended",
        "absdiff_detrended",
    ]
    df = X_prep[cols].copy()

    # ---------- per-id computation ----------
    def _one_id(g: pd.DataFrame) -> pd.Series:
        gb = g[g["period"] == 0]
        ga = g[g["period"] == 1]

        # fetch arrays (float32)
        d_b = gb["diff_standardized"].to_numpy(np.float32, copy=False)
        d_a = ga["diff_standardized"].to_numpy(np.float32, copy=False)
        dm_b = gb["diff_detrended"].to_numpy(np.float32, copy=False)
        dm_a = ga["diff_detrended"].to_numpy(np.float32, copy=False)
        am_b = gb["absdiff_detrended"].to_numpy(np.float32, copy=False)
        am_a = ga["absdiff_detrended"].to_numpy(np.float32, copy=False)

        # Zero-crossing rate change (with deadband) on standardized diffs; Q=0
        _, _, cr_b = crossing_rates_logits(d_b, Q=0.0, eps=crossing_rate_deadband)
        _, _, cr_a = crossing_rates_logits(d_a, Q=0.0, eps=crossing_rate_deadband)
        signflips_logit_delta = float(cr_a - cr_b)

        # Extreme-event thresholds from BEFORE's quantiles.
        # NOTE(review): indices 0 and 8 assume quantile_fine_grid[0] == 0.05 and
        # quantile_fine_grid[8] == 0.95, as the feature names imply — confirm.
        Qb = np.quantile(d_b, quantile_fine_grid)
        Qb5, Qb95 = Qb[0], Qb[8]
        _, b2 = exceedance_logits(d_b, k_abs=Qb95, k_fixed=3.0)
        _, a2 = exceedance_logits(d_a, k_abs=Qb95, k_fixed=3.0)

        u_b, l_b = upper_lower_logits(d_b, q_low=Qb5, q_high=Qb95)
        u_a, l_a = upper_lower_logits(d_a, q_low=Qb5, q_high=Qb95)

        # --- On dm: short-lag L1, Ljung–Box z, spectral centroid (all deltas) ---
        diff_shortlag_l1_delta = _shortlag_L1(dm_a, acf_max_lag) - _shortlag_L1(
            dm_b, acf_max_lag
        )
        diff_lbq_delta = ljung_box_z(dm_a, lbq_m) - ljung_box_z(dm_b, lbq_m)

        fa, Pa = _psd_rfft(dm_a)
        fb, Pb = _psd_rfft(dm_b)
        diff_spec_centroid_delta = _spectral_centroid(fa, Pa) - _spectral_centroid(
            fb, Pb
        )

        # --- On am: volatility clustering → lag-1 ACF delta ---
        r1_b = acf_1d(am_b, 1)[0]
        r1_a = acf_1d(am_a, 1)[0]
        absdiff_acf1_delta = float(r1_a - r1_b)

        return pd.Series(
            {
                "diff_abs_excd_3_logit_delta": a2 - b2,
                "diff_up_excd_q95_logit_delta": u_a - u_b,
                "diff_low_excd_q5_logit_delta": l_a - l_b,
                "diff_signflips_logit_delta": signflips_logit_delta,
                "diff_shortlag_l1_delta": diff_shortlag_l1_delta,
                "diff_lbq_stat_delta": diff_lbq_delta,
                "diff_spec_centroid_delta": diff_spec_centroid_delta,
                "absdiff_acf1_delta": absdiff_acf1_delta,
            },
            dtype=np.float32,
        )

    out = (
        df.groupby(level="id", sort=False, group_keys=False)
        .apply(_one_id)
        .astype(np.float32)
    )

    if not inference:
        _save_cache(out, prefix)
    return out


# ─────────────────────────────────────────────────────────────────────
# ABSOLUTE BLOCK
# ─────────────────────────────────────────────────────────────────────


def compute_absolute_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    acf_max_lag: int = ACF_MAX_LAG,
    lbq_m: int = LBQ_M,
) -> pd.DataFrame:
    """
    Dependence features of absolute magnitudes ('absval_detrended'), per id.

    Outputs (float32, after − before):
      - absval_acf1_delta        : lag-1 ACF
      - absval_shortlag_L1_delta : Σ_{ℓ=2..K} |ACF(ℓ)|/ℓ, K = acf_max_lag
      - absval_lbq_stat_delta    : Ljung–Box z-score with m = lbq_m

    Results are cached under the "absolute" prefix unless ``inference``.
    """
    prefix = "absolute"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    col = "absval_detrended"
    frame = X_prep[[col, "period"]].copy()

    def _segment(g: pd.DataFrame, period: int) -> np.ndarray:
        return g.loc[g["period"] == period, col].to_numpy(np.float32, copy=False)

    def _per_id(g: pd.DataFrame) -> pd.Series:
        before = _segment(g, 0)
        after = _segment(g, 1)
        lag1 = float(acf_1d(after, 1)[0] - acf_1d(before, 1)[0])
        shortlag = _shortlag_L1(after, acf_max_lag) - _shortlag_L1(before, acf_max_lag)
        lbq = ljung_box_z(after, lbq_m) - ljung_box_z(before, lbq_m)
        return pd.Series(
            {
                "absval_acf1_delta": lag1,
                "absval_shortlag_L1_delta": shortlag,
                "absval_lbq_stat_delta": lbq,
            },
            dtype=np.float32,
        )

    result = frame.groupby(level="id", sort=False, group_keys=False).apply(_per_id)
    result = result.astype(np.float32)

    if not inference:
        _save_cache(result, prefix)
    return result


# ─────────────────────────────────────────────────────────────────────
# SQUARED BLOCK
# ─────────────────────────────────────────────────────────────────────


def compute_squared_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    acf_max_lag: int = ACF_MAX_LAG,
    lbq_m: int = LBQ_M,
) -> pd.DataFrame:
    """
    Dependence features of squared magnitudes ('squared_detrended'), per id.

    Outputs (float32, after − before):
      - squared_acf1_delta        : lag-1 ACF
      - squared_shortlag_L1_delta : Σ_{ℓ=2..K} |ACF(ℓ)|/ℓ, K = acf_max_lag
      - squared_lbq_stat_delta    : Ljung–Box z-score with m = lbq_m

    Results are cached under the "squared" prefix unless ``inference``.
    """
    prefix = "squared"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    col = "squared_detrended"
    frame = X_prep[[col, "period"]].copy()

    def _segment(g: pd.DataFrame, period: int) -> np.ndarray:
        return g.loc[g["period"] == period, col].to_numpy(np.float32, copy=False)

    def _per_id(g: pd.DataFrame) -> pd.Series:
        before = _segment(g, 0)
        after = _segment(g, 1)
        lag1 = float(acf_1d(after, 1)[0] - acf_1d(before, 1)[0])
        shortlag = _shortlag_L1(after, acf_max_lag) - _shortlag_L1(before, acf_max_lag)
        lbq = ljung_box_z(after, lbq_m) - ljung_box_z(before, lbq_m)
        return pd.Series(
            {
                "squared_acf1_delta": lag1,
                "squared_shortlag_L1_delta": shortlag,
                "squared_lbq_stat_delta": lbq,
            },
            dtype=np.float32,
        )

    result = frame.groupby(level="id", sort=False, group_keys=False).apply(_per_id)
    result = result.astype(np.float32)

    if not inference:
        _save_cache(result, prefix)
    return result


# ─────────────────────────────────────────────────────────────────────
# BOUNDARY LOCAL BLOCK
# ─────────────────────────────────────────────────────────────────────


def _window_around_boundary(
    arr_before: np.ndarray,
    arr_after: np.ndarray,
    w_before: int,
    w_after: int,
    skip_after: int = 0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Return boundary-local windows:
      before: last w_before samples,
      after : first w_after samples (optionally skip 'skip_after' transient).
    """
    if arr_before.size == 0 or arr_after.size == 0:
        return arr_before, arr_after
    wb = int(min(w_before, arr_before.size))
    sa = int(min(skip_after, max(0, arr_after.size - 1)))
    wa = int(min(w_after, max(0, arr_after.size - sa)))
    b = arr_before[-wb:] if wb > 0 else arr_before[:0]
    a = arr_after[sa : sa + wa] if wa > 0 else arr_after[:0]
    return b.astype(np.float32, copy=False), a.astype(np.float32, copy=False)


def _ecdf_logit_against_before(v: float, before_sorted: np.ndarray) -> float:
    """
    Logit of the mid-rank position of ``v`` within the ascending BEFORE sample.

    Ties count half: rank = #(before < v) + 0.5 · #(before == v). The rank is
    Jeffreys-smoothed as p = (rank + 0.5) / (n + 1) and clipped to
    [EPS, 1 − EPS] before taking log(p / (1 − p)).
    """
    n = before_sorted.size
    below = np.searchsorted(before_sorted, v, side="left")
    upto = np.searchsorted(before_sorted, v, side="right")
    smoothed = (below + 0.5 * (upto - below) + 0.5) / (n + 1.0)
    p = np.clip(smoothed, EPS, 1 - EPS)
    return float(np.log(p / (1.0 - p)))


def _cliffs_delta(a: np.ndarray, b: np.ndarray) -> float:
    """
    Cliff's delta in [-1, 1]: P(a>b) - P(a<b).
    O(n log n) using searchsorted counts.
    """
    na, nb = a.size, b.size
    if na == 0 or nb == 0:
        return 0.0
    b_sorted = np.sort(b)
    # counts for each a: how many b are < a, and > a
    lt_counts = np.searchsorted(b_sorted, a, side="left")  # b < a
    gt_counts = nb - np.searchsorted(b_sorted, a, side="right")  # b > a
    gt = int(np.sum(lt_counts))  # P(a>b) numerator
    lt = int(np.sum(gt_counts))  # P(a<b) numerator
    return float((gt - lt) / (na * nb))


def _chow_F(y_b: np.ndarray, y_a: np.ndarray) -> float:
    """
    Chow F test at the boundary for y ~ a + b t.
    Returns the F-statistic. Uses clipped/detrended to limit outliers.
    """

    def _ols_ssr(y):
        n = y.size
        if n < 2:
            return 0.0, 0
        t = np.arange(n, dtype=np.float32)
        X = np.c_[np.ones(n), t]
        # OLS via normal equations (2x2)
        XtX = X.T @ X
        Xty = X.T @ y
        beta = np.linalg.solve(XtX, Xty)
        resid = y - X @ beta
        return float(np.dot(resid, resid)), 2  # ssr, k

    ssr_b, k = _ols_ssr(y_b)
    ssr_a, _ = _ols_ssr(y_a)
    y = np.concatenate([y_b, y_a])
    ssr_pooled, _ = _ols_ssr(y)
    n_b, n_a = max(0, y_b.size), max(0, y_a.size)
    n = n_b + n_a
    # F = ((SSR_pooled - (SSR_b+SSR_a)) / k) / ((SSR_b+SSR_a) / (n - 2k))
    num = (ssr_pooled - (ssr_b + ssr_a)) / max(k, 1)
    den = (ssr_b + ssr_a) / max(n - 2 * k, 1)
    F = num / den if den > 0 else 0.0
    return float(max(F, 0.0)) if np.isfinite(F) else 0.0


def _cusum_signed_stat(y_b: np.ndarray, y_a: np.ndarray) -> float:
    """
    Signed CUSUM jump across the boundary.

    The concatenated series (cast to float32) is mean-centred and cumulatively
    summed. The result is (CUSUM at the end − CUSUM at the last BEFORE point),
    divided by (std + EPS) · sqrt(n) for rough scale invariance. Returns 0.0
    for fewer than three points in total.
    """
    before = y_b.astype(np.float32)
    after = y_a.astype(np.float32)
    if before.size + after.size < 3:
        return 0.0
    joined = np.concatenate([before, after])
    joined = joined - joined.mean()
    csum = np.cumsum(joined)
    split = before.size
    if split > 0:
        jump = csum[-1] - csum[split - 1]
    else:
        jump = csum[-1]
    scale = (joined.std(ddof=0) + EPS) * np.sqrt(joined.size)
    return float(jump / scale)


def _gaussian_glr(zb: np.ndarray, za: np.ndarray) -> float:
    """
    Gaussian log-likelihood ratio for a mean+variance change at the split:

      LRT = n*log(var_all) − n0*log(var_b) − n1*log(var_a)

    All variances are MLE (ddof=0): var_b and var_a are about each segment's
    own mean, var_all is about the mean of the concatenated series. Each is
    floored at EPS. Returns 0.0 if either segment has at most one point.
    """
    n0, n1 = int(zb.size), int(za.size)
    if n0 <= 1 or n1 <= 1:
        return 0.0
    n = n0 + n1

    def _mle_var(seg):
        centre = float(np.mean(seg))
        return float(np.mean((seg - centre) ** 2))

    v0 = max(_mle_var(zb), EPS)
    v1 = max(_mle_var(za), EPS)
    v_all = max(_mle_var(np.concatenate([zb, za], axis=0)), EPS)
    return float(n * np.log(v_all) - n0 * np.log(v0) - n1 * np.log(v1))


def _arch_lm_LM(z: np.ndarray, L: int = 5) -> float:
    """
    Engle's ARCH LM statistic on a single segment z (standardized):
      1) center z, set e2 = (z - mean(z))^2
      2) regress e2[L:] on [1, e2_{t-1},...,e2_{t-L}]
      3) LM = n_eff * R^2
    """
    z = z.astype(np.float64)
    n = z.size
    if n <= L + 1:
        return 0.0
    e2 = (z - z.mean()) ** 2
    Y = e2[L:]
    Xcols = [np.ones_like(Y)]
    for j in range(1, L + 1):
        Xcols.append(e2[L - j : n - j])
    X = np.column_stack(Xcols)  # shape (n_eff, 1+L)
    beta, *_ = np.linalg.lstsq(X, Y, rcond=None)
    Yhat = X @ beta
    ssr = float(np.sum((Yhat - Y.mean()) ** 2))  # regression sum of squares
    sst = float(np.sum((Y - Y.mean()) ** 2)) + EPS  # total sum of squares
    R2 = ssr / sst
    n_eff = Y.shape[0]
    return float(n_eff * R2)


def compute_boundary_local_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    w_local: list[int] = BOUND_WINDOW_SIZES,
    skip_after: int = BOUND_SKIP_AFTER,
    acf_K: int = BOUND_ACF_MAX_LAG,
    eps_deadband: float = CROSSING_RATE_DEADBAND,
    quantile_coarse_grid: list[float] = QUANTILE_COARSE_GRID,
) -> pd.DataFrame:
    """
    Boundary-local features computed on windows around the before/after split.

    Features (AFTER − BEFORE unless noted): median jump, MAD / IDR / RMS
    log-ratios, detrended slope jump, exceedance and crossing-rate logit deltas,
    median-crossing asymmetry, residence-time log-ratio, lag-1 ACF and
    short-lag L1 deltas, spectral-centroid delta, scaled local W1, Cliff's
    delta, Chow F, signed CUSUM, Gaussian GLR, and difference variants
    (diff median jump, diff2 MAD log-ratio, diff2 ACF1 delta, diff2 scaled W1).

    The full feature set is computed once per window size in `w_local`
    (e.g. [32, 64, 128]; duplicates are dropped and sizes sorted), and every
    column is suffixed with `_w{size}`.

    Median jumps fall back to 0.0, and quantile grids to zeros, when a local
    window is empty. This applies to both the level and the difference
    features.

    Parameters
    ----------
    X_prep : pd.DataFrame
        Preprocessed long frame grouped by index level "id"; must contain the
        columns listed in `require_cols`.
    force : bool
        Recompute even if a cached parquet exists.
    inference : bool
        Never read or write the cache.
    w_local, skip_after, acf_K, eps_deadband, quantile_coarse_grid
        Window sizes, after-side skip, short-lag ACF order, crossing deadband
        and quantile grid used by the individual features.

    Returns
    -------
    pd.DataFrame
        float32 features, one row per id.
    """
    w_list = sorted(set(int(w) for w in w_local))

    prefix = "boundary_local"
    # NOTE(review): the cache lookup uses only `prefix`; changing w_local or the
    # other parameters presumably still loads a stale cache unless force=True.
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    require_cols = [
        "period",
        "standardized",
        "clipped",
        "detrended",
        "diff_standardized",
        "diff2_standardized",
        "diff2_detrended",
    ]
    g = X_prep[require_cols].groupby(level="id", sort=False)

    def _one_id(df: pd.DataFrame, w_curr: int) -> pd.Series:
        b = df[df["period"] == 0]
        a = df[df["period"] == 1]

        z_b = b["standardized"].to_numpy(np.float32, copy=False)
        z_a = a["standardized"].to_numpy(np.float32, copy=False)
        zc_b = b["clipped"].to_numpy(np.float32, copy=False)
        zc_a = a["clipped"].to_numpy(np.float32, copy=False)
        zd_b = b["detrended"].to_numpy(np.float32, copy=False)
        zd_a = a["detrended"].to_numpy(np.float32, copy=False)
        dz_b = b["diff_standardized"].to_numpy(np.float32, copy=False)
        dz_a = a["diff_standardized"].to_numpy(np.float32, copy=False)
        dd_b = b["diff2_standardized"].to_numpy(np.float32, copy=False)
        dd_a = a["diff2_standardized"].to_numpy(np.float32, copy=False)
        ddm_b = b["diff2_detrended"].to_numpy(np.float32, copy=False)
        ddm_a = a["diff2_detrended"].to_numpy(np.float32, copy=False)

        # Local windows (use current w)
        z_bw, z_aw = _window_around_boundary(z_b, z_a, w_curr, w_curr, skip_after)
        zc_bw, zc_aw = _window_around_boundary(zc_b, zc_a, w_curr, w_curr, skip_after)
        zd_bw, zd_aw = _window_around_boundary(zd_b, zd_a, w_curr, w_curr, skip_after)
        dz_bw, dz_aw = _window_around_boundary(dz_b, dz_a, w_curr, w_curr, skip_after)
        dd_bw, dd_aw = _window_around_boundary(dd_b, dd_a, w_curr, w_curr, skip_after)
        ddmw, ddaw = _window_around_boundary(ddm_b, ddm_a, w_curr, w_curr, skip_after)

        # Local median jump (0.0 when either window is empty)
        local_median_jump = (
            float(np.median(z_aw) - np.median(z_bw))
            if (z_aw.size and z_bw.size)
            else 0.0
        )

        # Local scale logratio (MAD, scaled to be sigma-consistent under normality)
        mad_b = 1.4826 * _mad(z_bw) + EPS
        mad_a = 1.4826 * _mad(z_aw) + EPS
        local_scale_logratio = float(np.log(mad_a) - np.log(mad_b))

        # Local slope jump on detrended
        slope_jump = _ols_slope(zd_aw) - _ols_slope(zd_bw)

        # Local IDR logratio on standardized (first/last grid quantiles;
        # assumes the coarse grid spans Q10..Q90 — TODO confirm)
        qs = quantile_coarse_grid
        Qb = np.quantile(z_bw, qs) if z_bw.size else np.zeros(len(qs), dtype=np.float32)
        Qa = np.quantile(z_aw, qs) if z_aw.size else np.zeros(len(qs), dtype=np.float32)
        Q10_b, Q90_b = Qb[0], Qb[-1]
        Q10_a, Q90_a = Qa[0], Qa[-1]
        idr_b = Q90_b - Q10_b + EPS
        idr_a = Q90_a - Q10_a + EPS
        IDR_logratio = float(np.log(idr_a) - np.log(idr_b))

        # Exceedances using a threshold from the full BEFORE segment (its 95th pct)
        _, Qb95 = np.quantile(z_b, [0.05, 0.95]) if z_b.size else (0.0, 0.0)
        b1, b2 = exceedance_logits(z_bw, k_abs=Qb95, k_fixed=3.0)
        a1, a2 = exceedance_logits(z_aw, k_abs=Qb95, k_fixed=3.0)

        # Crossing rates & asymmetry @ median (Q=0) with deadband
        _, _, cr_b = crossing_rates_logits(z_bw, Q=0.0, eps=eps_deadband)
        _, _, cr_a = crossing_rates_logits(z_aw, Q=0.0, eps=eps_deadband)
        signflips_logit_delta = float(cr_a - cr_b)
        asym_b = median_crossing_asym(z_bw, Q=0.0, eps=eps_deadband)
        asym_a = median_crossing_asym(z_aw, Q=0.0, eps=eps_deadband)
        med_cross_asym_fisher_delta = fisher_delta(asym_a, asym_b)

        # Residence-time logratio
        res_time_logratio = float(
            mean_log_res_time(z_aw, Q=0.0, eps=eps_deadband)
            - mean_log_res_time(z_bw, Q=0.0, eps=eps_deadband)
        )

        # ACF/short-lag on detrended local windows
        r1_b = acf_1d(zd_bw, 1)[0] if zd_bw.size > 1 else 0.0
        r1_a = acf_1d(zd_aw, 1)[0] if zd_aw.size > 1 else 0.0
        acf1_local_delta = float(r1_a - r1_b)
        shortlag_local_delta = _shortlag_L1(zd_aw, acf_K) - _shortlag_L1(zd_bw, acf_K)

        # Spectral centroid delta on detrended local windows
        fa, Pa = _psd_rfft(zd_aw)
        fb, Pb = _psd_rfft(zd_bw)
        spec_centroid_local_delta = _spectral_centroid(fa, Pa) - _spectral_centroid(
            fb, Pb
        )

        # Local W1 on standardized, scaled by an inner BEFORE quantile range
        # (second / second-to-last grid points when the grid has >= 4 points)
        iqr_b = (Qb[-2] - Qb[1]) if len(Qb) >= 4 else (Q90_b - Q10_b)
        w1_scale = max(iqr_b, EPS)
        w1_scaled = _wasserstein_quant(z_aw, z_bw, quantile_coarse_grid, w1_scale)

        # Local RMS logratio on detrended
        rms_b = float(np.sqrt(np.mean(zd_bw**2))) if zd_bw.size else 0.0
        rms_a = float(np.sqrt(np.mean(zd_aw**2))) if zd_aw.size else 0.0
        rms_logratio = np.log(rms_a + EPS) - np.log(rms_b + EPS)

        # Cliff's delta (standardized windows)
        cd = _cliffs_delta(z_aw, z_bw)
        cliffs_fisher_delta = fisher_delta(cd, 0.0)

        # Chow F on clipped local windows
        chow_F = _chow_F(zc_bw, zc_aw)

        # Signed CUSUM on detrended local windows
        cusum_signed = _cusum_signed_stat(zd_bw, zd_aw)

        # Gauss GLR
        gauss_glr = _gaussian_glr(z_bw, z_aw)

        # Diff median jump (0.0 when either window is empty, like the level jump;
        # np.median of an empty window would be NaN)
        local_diff_median_jump = (
            float(np.median(dz_aw) - np.median(dz_bw))
            if (dz_aw.size and dz_bw.size)
            else 0.0
        )

        # Diff2 MAD logratio
        mad_b2 = 1.4826 * _mad(dd_bw) + EPS
        mad_a2 = 1.4826 * _mad(dd_aw) + EPS
        local_dd_MAD_logratio = float(np.log(mad_a2) - np.log(mad_b2))

        # Diff2 ACF1 delta on detrended local windows
        r1_b2 = acf_1d(ddmw, 1)[0] if ddmw.size > 1 else 0.0
        r1_a2 = acf_1d(ddaw, 1)[0] if ddaw.size > 1 else 0.0
        local_dd_acf1_delta = float(r1_a2 - r1_b2)

        # Diff2 local W1 scaled by the local BEFORE range of the coarse grid;
        # an empty BEFORE window falls back to a zero grid, exactly like Qb above.
        Qb2 = (
            np.quantile(dd_bw, qs)
            if dd_bw.size
            else np.zeros(len(qs), dtype=np.float32)
        )
        idr_local_b = float(Qb2[-1] - Qb2[0]) + EPS
        local_dd_w1_scaled = _wasserstein_quant(
            dd_aw, dd_bw, quantile_coarse_grid, idr_local_b
        )

        return pd.Series(
            {
                "local_median_jump": local_median_jump,
                "local_MAD_logratio": local_scale_logratio,
                "local_slope_jump": slope_jump,
                "local_IDR_logratio": IDR_logratio,
                "local_abs_excd_q95_logit_delta": a1 - b1,
                "local_abs_excd_3_logit_delta": a2 - b2,
                "local_signflips_logit_delta": signflips_logit_delta,
                "local_med_cross_asym_fisher_delta": med_cross_asym_fisher_delta,
                "local_res_time_logratio": res_time_logratio,
                "local_acf1_delta": acf1_local_delta,
                "local_shortlag_L1_delta": shortlag_local_delta,
                "local_spec_centroid_delta": spec_centroid_local_delta,
                "local_w1_scaled": w1_scaled,
                "local_rms_logratio": rms_logratio,
                "local_cliffs_fisher_delta": cliffs_fisher_delta,
                "local_chow_F": chow_F,
                "local_cusum_signed": cusum_signed,
                "local_gauss_glr": gauss_glr,
                "local_diff_median_jump": local_diff_median_jump,
                "local_dd_MAD_logratio": local_dd_MAD_logratio,
                "local_dd_acf1_delta": local_dd_acf1_delta,
                "local_dd_w1_scaled": local_dd_w1_scaled,
            },
            dtype=np.float32,
        )

    # compute per-window, suffix columns, then concat horizontally
    pieces = []
    for w in w_list:
        # the lambda is consumed immediately by apply, so binding `w` late is safe
        out_w = g.apply(lambda df: _one_id(df, w)).astype(np.float32)
        out_w = out_w.add_suffix(f"_w{w}")  # suffix all columns for clarity
        pieces.append(out_w)

    out = pd.concat(pieces, axis=1)

    if not inference:
        _save_cache(out, prefix)
    return out


# ─────────────────────────────────────────────────────────────────────
# BOUNDARY EDGE BLOCK
# ─────────────────────────────────────────────────────────────────────


def compute_boundary_edge_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    w_edge: int = BOUND_EDGE,  # edge window length
    offsets: tuple[int, ...] = BOUND_OFFSETS,  # after-side offsets to test
) -> pd.DataFrame:
    """
    Compute edge-window median jumps and rank-logit deltas at several after-side offsets.

    For each id:
      • edge_dz_o{off}              : median(after[off:off+w_edge]) − median(before[-w_edge:])
      • edge_ranklogit_delta_o{off} : logit(rank(v_a vs BEFORE)) − logit(rank(v_b vs BEFORE))
      • edge_dz_max_abs, edge_ranklogit_max_abs :
            the signed per-offset value with the largest magnitude

    Notes:
      - Uses the 'standardized' column.
      - Independent of w_local; only relies on w_edge and offsets.
      - Offsets are clamped to [0, len(after) - 1], so an offset past the end of
        AFTER still sees the last after-sample instead of an empty window.
      - An empty window, an empty BEFORE segment (for the rank logits), or an
        empty `offsets` tuple (for the *_max_abs summaries) yields 0.0.

    Returns
    -------
    pd.DataFrame
        float32 features, one row per id.
    """
    prefix = "boundary_edge"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    # pull standardized segments as numpy (one array per id and period)
    z_b = _group_series(X_prep, "standardized", 0).apply(
        lambda s: s.to_numpy(np.float32, copy=False)
    )
    z_a = _group_series(X_prep, "standardized", 1).apply(
        lambda s: s.to_numpy(np.float32, copy=False)
    )
    ids = z_b.index

    # prebuild column names (one pair per offset)
    dz_cols = [f"edge_dz_o{off}" for off in offsets]
    rk_cols = [f"edge_ranklogit_delta_o{off}" for off in offsets]

    rows = []
    for i in ids:
        zb = z_b.loc[i]
        za = z_a.loc[i]

        # BEFORE edge median over the last w_edge samples (offset-independent)
        wb = int(min(w_edge, zb.size))
        z_be = zb[-wb:] if wb > 0 else zb[:0]
        v_b = float(np.median(z_be)) if z_be.size else 0.0

        # Jeffreys-smoothed logit rank against the BEFORE ECDF; 0.0 when BEFORE is
        # empty (applied to both v_b and every v_a so the delta stays consistent)
        zb_sorted = np.sort(zb) if zb.size else np.array([], dtype=np.float32)
        has_before = zb_sorted.size > 0
        logit_vb = _ecdf_logit_against_before(v_b, zb_sorted) if has_before else 0.0

        dz_vals = []
        rk_vals = []
        for off in offsets:
            # AFTER edge median at offset (start clamped into the AFTER segment)
            start = int(min(max(off, 0), max(za.size - 1, 0)))
            wa = int(min(w_edge, max(za.size - start, 0)))
            z_ae = za[start : start + wa] if wa > 0 else za[:0]
            v_a = float(np.median(z_ae)) if z_ae.size else 0.0

            # signed jump
            dz_vals.append(v_a - v_b)

            # rank-logit delta vs BEFORE ECDF
            logit_va = (
                _ecdf_logit_against_before(v_a, zb_sorted) if has_before else 0.0
            )
            rk_vals.append(float(logit_va - logit_vb))

        dz_arr = np.asarray(dz_vals, dtype=np.float32)
        rk_arr = np.asarray(rk_vals, dtype=np.float32)

        # summaries: the signed value with the largest |value| (argmax of an empty
        # array raises, hence the explicit fallback when no offsets are given)
        if dz_arr.size:
            dz_max = float(dz_arr[int(np.argmax(np.abs(dz_arr)))])
            rk_max = float(rk_arr[int(np.argmax(np.abs(rk_arr)))])
        else:
            dz_max = rk_max = 0.0

        rows.append(
            {
                **dict(zip(dz_cols, dz_arr.tolist())),
                **dict(zip(rk_cols, rk_arr.tolist())),
                "edge_dz_max_abs": dz_max,
                "edge_ranklogit_max_abs": rk_max,
            }
        )

    out = pd.DataFrame(rows, index=ids, dtype=np.float32)

    if not inference:
        _save_cache(out, prefix)
    return out


# ─────────────────────────────────────────────────────────────────────
# CURVATURE (SECOND DIFFERENCES) BLOCK
# ─────────────────────────────────────────────────────────────────────


def compute_curvature_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    acf_max_lag: int = ACF_MAX_LAG,
    lbq_m: int = LBQ_M,
) -> pd.DataFrame:
    """
    Global second-difference (curvature) features, AFTER − BEFORE per id.

    Inputs per id: 'diff2_standardized' (dd), 'diff2_detrended' (ddm) and
    'diff_standardized' (d, used only for the curvature-vs-slope ratio).

    Output columns:
      curv_energy_logratio     log mean(dd^2) ratio
      curv_vs_slope_logratio   log of mean(dd^2)/mean(d^2) ratio
      dd_posrate_logit_delta   Jeffreys-logit delta of P(dd > 0)
      dd_acf1_delta            lag-1 ACF delta on ddm
      dd_shortlag_L1_delta     short-lag ACF L1 delta on ddm
      dd_lbq_z_delta           standardized Ljung-Box Q delta on ddm
      dd_spec_centroid_delta   spectral-centroid delta on ddm
    """
    prefix = "curvature"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    grouped = X_prep[
        ["period", "diff2_standardized", "diff2_detrended", "diff_standardized"]
    ].groupby(level="id", sort=False)

    def _ljung_box_z(x: np.ndarray, m: int) -> float:
        # Ljung-Box Q over min(m, len(x) - 1) lags (at least 1), standardized as
        # (Q - m) / sqrt(2m); non-finite Q maps to 0.0.
        lags = int(min(m, max(1, x.size - 1)))
        rho = acf_1d(x.astype(np.float32), lags)
        n = x.size
        k = np.arange(1, lags + 1, dtype=np.float32)
        q_stat = n * (n + 2.0) * np.sum((rho**2) / (n - k))
        if not np.isfinite(q_stat):
            return 0.0
        return float((q_stat - lags) / np.sqrt(2.0 * lags))

    def _mean_square(v: np.ndarray) -> float:
        # Mean of v^2, 0.0 for an empty segment.
        return float(np.mean(v * v)) if v.size else 0.0

    def _lag1_acf(v: np.ndarray):
        # Lag-1 autocorrelation, 0.0 when fewer than two samples.
        return acf_1d(v, 1)[0] if v.size > 1 else 0.0

    def _arr(frame: pd.DataFrame, name: str) -> np.ndarray:
        return frame[name].to_numpy(np.float32, copy=False)

    def _one_id(df: pd.DataFrame) -> pd.Series:
        before = df[df["period"] == 0]
        after = df[df["period"] == 1]

        dd_b, dd_a = _arr(before, "diff2_standardized"), _arr(after, "diff2_standardized")
        ddm_b, ddm_a = _arr(before, "diff2_detrended"), _arr(after, "diff2_detrended")
        d_b, d_a = _arr(before, "diff_standardized"), _arr(after, "diff_standardized")

        # curvature energy (dd) and slope energy (d) per half
        e_b, e_a = _mean_square(dd_b), _mean_square(dd_a)
        s_b, s_a = _mean_square(d_b), _mean_square(d_a)

        # one common Ljung-Box lag count so both halves are comparable
        lag_cap = int(min(lbq_m, max(1, ddm_b.size - 1), max(1, ddm_a.size - 1)))

        fa, Pa = _psd_rfft(ddm_a)
        fb, Pb = _psd_rfft(ddm_b)

        features = {
            "curv_energy_logratio": float(np.log(e_a + EPS) - np.log(e_b + EPS)),
            "curv_vs_slope_logratio": float(
                np.log((e_a / (s_a + EPS)) + EPS) - np.log((e_b / (s_b + EPS)) + EPS)
            ),
            "dd_posrate_logit_delta": (
                jeffreys_logit(int(np.sum(dd_a > 0.0)), max(dd_a.size, 1))
                - jeffreys_logit(int(np.sum(dd_b > 0.0)), max(dd_b.size, 1))
            ),
            "dd_acf1_delta": float(_lag1_acf(ddm_a) - _lag1_acf(ddm_b)),
            "dd_shortlag_L1_delta": (
                _shortlag_L1(ddm_a, acf_max_lag) - _shortlag_L1(ddm_b, acf_max_lag)
            ),
            "dd_lbq_z_delta": _ljung_box_z(ddm_a, lag_cap) - _ljung_box_z(ddm_b, lag_cap),
            "dd_spec_centroid_delta": (
                _spectral_centroid(fa, Pa) - _spectral_centroid(fb, Pb)
            ),
        }
        return pd.Series(features, dtype=np.float32)

    out = grouped.apply(_one_id).astype(np.float32)

    if not inference:
        _save_cache(out, prefix)
    return out


# ─────────────────────────────────────────────────────────────────────
# ROLLING BLOCK
# ─────────────────────────────────────────────────────────────────────


def _cumsums_z(z):
    S = np.empty(z.size + 1, dtype=np.float32)
    S[0] = 0.0
    np.cumsum(z, out=S[1:])
    S2 = np.empty(z.size + 1, dtype=np.float32)
    S2[0] = 0.0
    np.cumsum(z * z, out=S2[1:])
    return S, S2


def _cumsums_dm2(dm):
    Sd = np.empty(dm.size + 1, dtype=np.float32)
    Sd[0] = 0.0
    np.cumsum(dm * dm, out=Sd[1:])
    return Sd


def _cross_prefix(z, eps):
    c = ((z[:-1] <= -eps) & (z[1:] >= eps)) | ((z[:-1] >= eps) & (z[1:] <= -eps))
    C = np.empty(c.size + 1, dtype=np.int32)
    C[0] = 0
    np.cumsum(c.astype(np.int32), out=C[1:])
    return C


# --- small helpers (fast top-k mean on |x|; no full sort) ---------------------
def _topk_mean_abs(x: np.ndarray, k: int) -> float:
    n = x.size
    if n == 0:
        return 0.0
    kk = min(k, n)
    ax = np.abs(x)
    # take k largest by abs via partition (O(n))
    # np.partition keeps the k largest in the last kk positions (unordered)
    idx = np.argpartition(ax, n - kk)[-kk:]
    return float(ax[idx].mean())


# --- rolling jumps from precomputed prefixes ----------------------------------
def _roll_logstd_jump_stats_from_cumsums(
    S: np.ndarray, S2: np.ndarray, w: int, min_pos: int, topk: int
) -> tuple[float, float, float]:
    """Returns (maxpos, minneg, topkabs_mean) for Δ log-std between adjacent windows of size w."""
    n = S.size - 1  # since S is cumsum with S[0]=0 of length n+1
    pos = n - 2 * w + 1
    if n < 2 * w or pos < min_pos:
        return 0.0, 0.0, 0.0
    # rolling mean and var (population) for all windows
    # for t in [0..n-w]: sum = S[t+w]-S[t], mean = sum/w
    sum_w = S[w:] - S[:-w]  # length n-w+1
    mean_w = sum_w / float(w)
    sumsq_w = S2[w:] - S2[:-w]  # length n-w+1
    var_w = np.maximum(sumsq_w / float(w) - mean_w * mean_w, 0.0)  # numerical safety
    logstd_w = np.log(np.sqrt(var_w) + EPS)  # length n-w+1

    # adjacent jump J[t] = R[t] - L[t] with R = logstd[t+w], L = logstd[t], t in [0..n-2w]
    L = logstd_w[:-w]
    R = logstd_w[w:]
    J = R - L
    if J.size == 0:
        return 0.0, 0.0, 0.0
    return float(np.max(J)), float(np.min(J)), _topk_mean_abs(J, topk)


def _roll_logrms_jump_stats_from_cumsums(
    Sd: np.ndarray, w: int, min_pos: int, topk: int
) -> tuple[float, float, float]:
    """Returns (maxpos, minneg, topkabs_mean) for Δ log-RMS of diff_detrended between adjacent windows of size w."""
    n = Sd.size - 1  # Sd is cumsum of dm^2 with Sd[0]=0
    pos = n - 2 * w + 1
    if n < 2 * w or pos < min_pos:
        return 0.0, 0.0, 0.0
    sumsq_w = Sd[w:] - Sd[:-w]  # length n-w+1
    rms_w = np.sqrt(np.maximum(sumsq_w / float(w), 0.0))
    logrms_w = np.log(rms_w + EPS)

    L = logrms_w[:-w]
    R = logrms_w[w:]
    J = R - L
    if J.size == 0:
        return 0.0, 0.0, 0.0
    return float(np.max(J)), float(np.min(J)), _topk_mean_abs(J, topk)


def _roll_crossrate_logit_jump_from_prefix(
    C: np.ndarray, w: int, min_pos: int, topk: int
) -> tuple[float, float, float]:
    """
    C is prefix sum over 'cross' of length n-1 (C[0]=0, C[t]=sum_{u< t} cross[u]).
    Window [i, i+w-1] has (w-1) transitions: count = C[i+w-1]-C[i].
    p_hat = (count + 0.5) / ((w-1) + 1.0)   (Jeffreys), logit(p_hat) then adjacent jump.
    """
    # C length = n_trans+1 where n_trans = n-1; valid start i: 0..(n-w)
    n_trans = C.size - 1  # = (n-1)
    n = n_trans + 1
    pos = n - 2 * w + 1
    if n < 2 * w or pos < min_pos or w < 2:
        return 0.0, 0.0, 0.0

    # for i in [0..n-w], transitions idx span [i .. i+w-2] ⇒ count = C[i+w-1]-C[i]
    cnt_w = C[w - 1 :] - C[: -w + 1]  # length n-w+1
    denom = (w - 1) + 1.0  # Jeffreys (w-1 transitions + 1.0)
    p = (cnt_w + 0.5) / denom
    p = np.clip(p, EPS, 1.0 - EPS)
    logit = np.log(p) - np.log(1.0 - p)

    L = logit[:-w]
    R = logit[w:]
    J = R - L
    if J.size == 0:
        return 0.0, 0.0, 0.0
    return float(np.max(J)), float(np.min(J)), _topk_mean_abs(J, topk)


def _ewvar_last(x, hl):
    """
    EWMA of x^2 with half-life hl; return last value.
    alpha = 1 - 2^(-1/hl)
    """
    if x.size == 0:
        return 0.0
    alpha = 1.0 - 2.0 ** (-1.0 / max(hl, 1.0))
    v = float(x[0] * x[0])
    for xi in x[1:]:
        v = (1 - alpha) * v + alpha * float(xi * xi)
    return v


def compute_rolling_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    windows: tuple = ROLL_WINDOWS,
    min_positions_per_half: int = ROLL_MIN_POS_PER_HALF,
    ewvar_half_lives: tuple = EWVAR_HALFLIVES,
    crossing_rate_deadband: float = CROSSING_RATE_DEADBAND,
    topk: int = ROLL_TOPK,
) -> pd.DataFrame:
    """
    Fast rolling / localized-change features, AFTER − BEFORE per id.

      • Rolling log-std jumps on 'clipped' (winsorized standardized)
      • Rolling log-RMS jumps on 'diff_detrended'
      • Rolling crossing-rate (Jeffreys-logit) jumps on 'clipped'
      • EWVAR log-ratios for each half-life in `ewvar_half_lives`

    For each window size w, each rolling family yields three columns
    (`maxpos_delta`, `maxneg_delta`, `topkabs_mean_delta`). These are 0.0 when
    either half has fewer than `min_positions_per_half` jump positions.

    NOTE(review): EWVAR runs forward over each half and keeps the final value,
    so the AFTER value reflects the end of the series rather than the region
    right after the boundary. If "at the boundary" is the intent, the AFTER
    side presumably needs a reversed pass — confirm.

    Returns
    -------
    pd.DataFrame
        float32 features, one row per id.
    """
    prefix = "rolling"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    def _per_id_arrays(column: str, period: int) -> pd.Series:
        # One float32 numpy array per id for `column` in the given period.
        return _group_series(X_prep, column, period).apply(
            lambda s: s.to_numpy(np.float32, copy=False)
        )

    clip_before = _per_id_arrays("clipped", 0)
    clip_after = _per_id_arrays("clipped", 1)
    ddet_before = _per_id_arrays("diff_detrended", 0)
    ddet_after = _per_id_arrays("diff_detrended", 1)
    ids = clip_before.index

    stems = ("roll_logstd_jump", "roll_rms_jump", "roll_crossrate_logit")
    suffixes = ("maxpos_delta", "maxneg_delta", "topkabs_mean_delta")
    zero_stats = (0.0, 0.0, 0.0)
    mp = min_positions_per_half

    per_id = {}
    for i in ids:
        zb, za = clip_before.loc[i], clip_after.loc[i]
        dmb, dma = ddet_before.loc[i], ddet_after.loc[i]

        # prefix sums / counts, built once per half and reused for every window
        S_b, S2_b = _cumsums_z(zb)
        S_a, S2_a = _cumsums_z(za)
        Sd_b, Sd_a = _cumsums_dm2(dmb), _cumsums_dm2(dma)
        C_b = _cross_prefix(zb, crossing_rate_deadband)
        C_a = _cross_prefix(za, crossing_rate_deadband)

        feats = {}
        for w in windows:
            # both halves need enough adjacent-window positions for size w
            if min(zb.size, za.size) - 2 * w + 1 >= mp:
                family_stats = (
                    (
                        "roll_logstd_jump",
                        _roll_logstd_jump_stats_from_cumsums(S_b, S2_b, w, mp, topk),
                        _roll_logstd_jump_stats_from_cumsums(S_a, S2_a, w, mp, topk),
                    ),
                    (
                        "roll_rms_jump",
                        _roll_logrms_jump_stats_from_cumsums(Sd_b, w, mp, topk),
                        _roll_logrms_jump_stats_from_cumsums(Sd_a, w, mp, topk),
                    ),
                    (
                        "roll_crossrate_logit",
                        _roll_crossrate_logit_jump_from_prefix(C_b, w, mp, topk),
                        _roll_crossrate_logit_jump_from_prefix(C_a, w, mp, topk),
                    ),
                )
            else:
                family_stats = tuple((stem, zero_stats, zero_stats) for stem in stems)

            for stem, stats_before, stats_after in family_stats:
                for suffix, v_before, v_after in zip(suffixes, stats_before, stats_after):
                    feats[f"{stem}_w{w}_{suffix}"] = np.float32(v_after - v_before)

        for hl in ewvar_half_lives:
            feats[f"ewvar_hl{hl}_logratio"] = np.float32(
                np.log(_ewvar_last(za, hl) + EPS) - np.log(_ewvar_last(zb, hl) + EPS)
            )

        per_id[i] = feats

    out = pd.DataFrame(per_id).T
    out.index = ids
    out = out.astype(np.float32)

    if not inference:
        _save_cache(out, prefix)
    return out


# ─────────────────────────────────────────────────────────────────────
# AUTOREGRESSIVE (AR) BLOCK
# ─────────────────────────────────────────────────────────────────────


def _design(y: np.ndarray, p: int):
    n = y.size
    if n <= p:
        return None, None
    Y = y[p:]
    X = np.column_stack(
        [np.ones(n - p, dtype=np.float32)] + [y[p - j : n - j] for j in range(1, p + 1)]
    ).astype(np.float32, copy=False)
    return X, Y


def _ridge_fit(y: np.ndarray, p: int, lam: float):
    """
    Ridge-regularised AR(p) fit with an unpenalised intercept.

    Solves (X^T X + lam * I') beta = X^T Y, where I' is the identity with a
    zero in the intercept slot.

    Returns (phi, c, sigma2): the float32 lag coefficients, the intercept, and
    the in-sample mean squared residual plus EPS. Returns None when len(y) <= p.
    """
    design, target = _design(y, p)
    if design is None:
        return None

    penalty = lam * np.eye(p + 1, dtype=np.float32)
    penalty[0, 0] = 0.0  # leave the intercept unpenalised
    gram = design.T @ design
    beta = np.linalg.solve(gram + penalty, design.T @ target).astype(np.float32)

    resid = (target - design @ beta).astype(np.float32)
    sigma2 = float(np.mean(resid * resid) + EPS)
    return beta[1:].astype(np.float32), float(beta[0]), sigma2


def _mean_nll_window(y: np.ndarray, p: int, phi: np.ndarray, c: float, sigma2: float):
    """
    Mean Gaussian negative log-likelihood of y under a fixed AR(p) model.

    One-step predictions are c + lags @ phi, with noise variance sigma2 + 1e-12.
    Returns (mean NLL, float32 residuals), or (0.0, empty float32 array) when
    y has at most p samples.
    """
    design, target = _design(y, p)
    if design is None:
        return 0.0, np.zeros(0, dtype=np.float32)

    var = sigma2 + 1e-12
    prediction = c + design[:, 1:] @ phi
    resid = (target - prediction).astype(np.float32)
    per_point = 0.5 * (np.log(2.0 * np.pi * var) + (resid * resid) * (1.0 / var))
    return float(np.mean(per_point)), resid


def compute_ar_block(
    X_prep: pd.DataFrame,
    force: bool = False,
    inference: bool = False,
    p: int = AR_ORDER,  # AR order
    ridge_lambda: float = AR_RIDGE_LAMBDA,
    score_cap: int = AR_SCORE_CAP,  # equal-length scoring cap H
) -> pd.DataFrame:
    """
    Ridge AR(p) on 'detrended', emit ONLY:
        ar_ridge_nll_logratio = log( meanNLL(AFTER_head) / meanNLL(BEFORE_hold) )

    Train on BEFORE (excluding a length-H holdout at the end), score on:
      • BEFORE_hold = last H of BEFORE
      • AFTER_head  = first H of AFTER
    with H = min(score_cap, floor(nb/2), na).

    No guards: if BEFORE-train has at most p samples, `_ridge_fit` returns None
    and the tuple unpacking below raises TypeError (as requested). A scoring
    window with at most p samples contributes a mean NLL of 0.0.

    NOTE(review): the Gaussian mean NLL becomes negative when the training
    residual variance is small (0.5*log(2*pi*sigma2) < 0 dominates). The log of
    the ratio is then NaN or sign-inverted; a difference of mean NLLs would be
    well defined if this occurs in practice.

    Returns
    -------
    pd.DataFrame
        One float32 column "ar_ridge_nll_logratio", one row per id.
    """
    prefix = "ar"
    cache = _latest_cache(prefix)
    if cache and not force and not inference:
        print(f"Loading cached data from {cache}")
        return pd.read_parquet(cache)

    # pull detrended series (one array per id and period)
    zb = _group_series(X_prep, "detrended", 0).apply(
        lambda s: s.to_numpy(np.float32, copy=False)
    )
    za = _group_series(X_prep, "detrended", 1).apply(
        lambda s: s.to_numpy(np.float32, copy=False)
    )
    ids = zb.index

    records = []
    for i in ids:
        b = zb.loc[i]
        a = za.loc[i]
        nb, na = b.size, a.size

        # equal-length scoring window
        H = int(min(score_cap, nb // 2, na))

        # split BEFORE into train + hold (no guards)
        b_train = b[: nb - H]
        b_hold = b[nb - H :]

        # fit on BEFORE-train (no guards)
        phi_b, c_b, s2_b = _ridge_fit(b_train, p, ridge_lambda)

        # mean NLL on BEFORE_hold and AFTER_head under the BEFORE model
        nll_bh, _ = _mean_nll_window(b_hold, p, phi_b, c_b, s2_b)
        nll_af, _ = _mean_nll_window(a[:H], p, phi_b, c_b, s2_b)

        # module-level EPS (previously shadowed by a local with a different fallback)
        nll_logratio = float(np.log((nll_af + EPS) / (nll_bh + EPS)))
        records.append({"ar_ridge_nll_logratio": np.float32(nll_logratio)})

    out = pd.DataFrame(records, index=ids).astype(np.float32)

    if not inference:
        _save_cache(out, prefix)
    return out


# ─────────────────────────────────────────────────────────────────────
# Build features wrapper
# ─────────────────────────────────────────────────────────────────────

# @crunch/keep:on

# Registry of feature blocks: block name -> compute function.
# `build_features` iterates this dict in insertion order and calls each function
# as fn(X_prep, force_flag, inference). The names are also the keys accepted by
# build_features' per-block `force` dict. Reordering entries changes the order
# in which blocks are computed.
FEATURE_BLOCKS = {
    "moments": compute_moments_block,
    "quantiles": compute_quantiles_block,
    "rates": compute_rates_block,
    "autocorrelation": compute_autocorr_block,
    "tests_distances": compute_tests_distances_block,
    "frequency": compute_frequency_block,
    "differences": compute_differences_block,
    "absolute": compute_absolute_block,
    "squared": compute_squared_block,
    "boundary_local": compute_boundary_local_block,
    "boundary_edge": compute_boundary_edge_block,
    "curvature": compute_curvature_block,
    "rolling": compute_rolling_block,
    "ar": compute_ar_block,
}

# @crunch/keep:off


def build_features(
    X_train: pd.DataFrame,
    force_prep: bool = False,
    force_all: bool = False,
    force: dict[str, bool] | None = None,
    inference: bool = False,
    feature_blocks: dict = FEATURE_BLOCKS
) -> pd.DataFrame:
    """
    Build and return a wide per-id feature table by:
      1) Preprocessing the raw, long-format input (delegated to `build_preprocessed`).
      2) Computing each block registered in `feature_blocks` (each block caches itself).
      3) Outer-joining all block outputs on the index (id).

    Parameters
    ----------
    X_train : pd.DataFrame
        Raw input in the expected MultiIndex (id, time) format for preprocessing.
    force_prep : bool, optional
        If True, recompute preprocessing and ignore the upstream preprocess cache.
    force_all : bool, optional
        If True, recompute *all* feature blocks (overrides per-block cache).
    force : dict[str, bool] | None, optional
        Per-block override, e.g. {'moments': True, 'quantiles': False}.
        Only used when `force_all` is False. Keys should match `feature_blocks`
        names: unknown keys never force a block, but a truthy unknown key still
        bypasses the combined cache.
    inference : bool, optional
        If True, implies `force_prep=True` and `force_all=True`, and the combined
        table is not written to the cache.
    feature_blocks : dict, optional
        Mapping name -> block function, called as `fn(X_prep, force_flag, inference)`.
        Defaults to `FEATURE_BLOCKS`.

    Returns
    -------
    pd.DataFrame
        Wide feature matrix with one row per id and columns from all blocks.

    Raises
    ------
    ValueError
        If `feature_blocks` is empty (and no cache is reused), or if the merged
        table does not have exactly one row per input id.
    """
    # Decide which blocks to recompute. Inference always recomputes everything.
    force = force or {}
    if inference:
        force_prep = True
        force_all = True
    if force_all:
        force = {name: True for name in feature_blocks}

    # Reuse the combined cache only when nothing is forced AND it covers as many
    # ids as the current input; otherwise it belongs to a different dataset.
    prefix = "all"
    cache = _latest_cache(prefix)
    if cache and not force_prep and not any(force.values()):
        cached = pd.read_parquet(cache)
        n_input_ids = X_train.index.get_level_values("id").unique().size
        if cached.index.size == n_input_ids:
            print(f"Loading cached data from {cache}")
            return cached
        print(
            f"Ignoring cached data from {cache}: it has {cached.index.size} ids, "
            f"but input has {n_input_ids} ids"
        )

    if not feature_blocks:
        raise ValueError("No feature blocks registered")

    # 1) Preprocess raw X
    X_prep = build_preprocessed(X_train, force_prep, inference)

    # Sanity check 1
    detect_non_finite(X_prep)

    # 2) Compute each block (respect per-block force flags)
    parts = [
        fn(X_prep, force.get(name, False), inference)
        for name, fn in feature_blocks.items()
    ]

    # 3) Merge all blocks on id with an OUTER join: ids missing from a block are
    #    kept (their features from that block become NaN); the id-count check
    #    below catches unexpected extra ids.
    feats = parts[0].join(parts[1:], how="outer")

    # Sanity check 2
    ids = X_train.index.get_level_values("id").unique().size
    n = feats.index.size
    if ids != n:
        raise ValueError(f"Feature table has {n} ids, but input had {ids} ids")
    detect_non_finite(feats)

    # Save features
    if not inference:
        _save_cache(feats, prefix)

    return feats


# Train one model

In [None]:
# ─────────────────────────────────────────────────────────────────────
#  1. build estimator
# ─────────────────────────────────────────────────────────────────────


def make_estimator(
    model_type: str,
    params: dict,
) -> ClassifierMixin:
    """
    Return an untrained binary classifier with default settings updated by `params`.

    The family is chosen by substring match on `model_type`, checked in this order:
    "xgb" -> XGBClassifier, "lgb" -> LGBMClassifier, "cat" -> CatBoostClassifier.
    Entries of `params` override the family defaults; the caller's dict is not mutated.

    Raises
    ------
    ImportError
        If the selected library name is bound to None.
    NotImplementedError
        If `model_type` matches none of the supported families.
    """
    params = params.copy()  # don't mutate caller's dict

    # ---------------- XGBoost ----------------
    if "xgb" in model_type:
        if xgb is None:
            raise ImportError("XGBoost not installed in the runner.")

        # `use_label_encoder` is intentionally not passed: it was deprecated in
        # XGBoost 1.6 and removed in 2.0, where it only triggers an
        # "unused parameter" warning on every fit.
        defaults = dict(
            objective="binary:logistic",
            eval_metric="auc",
            random_state=RANDOM_STATE,
            verbosity=0,
            max_bin=MAX_BIN,
            n_jobs=-1,
            tree_method="hist",  # faster histogram optimized
        )
        defaults.update(params)

        return xgb.XGBClassifier(**defaults)

    # ---------------- LightGBM -------------
    elif "lgb" in model_type:
        if lgb is None:
            raise ImportError("LightGBM not installed in the runner.")
        defaults = dict(
            objective="binary",
            metric="auc",
            random_state=RANDOM_STATE,
            verbosity=-1,
            max_bin=MAX_BIN,  # controls histogram granularity
            n_jobs=-1,  # controls how many threads you use, -1 is for max
            force_row_wise=True,  # parallelizes the computation of the histogram along rows instead of columns
        )
        defaults.update(params)

        return lgb.LGBMClassifier(**defaults)

    # ---------------- CatBoost -------------------------------
    elif "cat" in model_type:
        if CatBoostClassifier is None:
            raise ImportError("CatBoost not installed in the runner.")
        defaults = dict(
            loss_function="Logloss",
            eval_metric="AUC",
            random_seed=RANDOM_STATE,
            verbose=False,
            allow_writing_files=False,
        )
        defaults.update(params)
        return CatBoostClassifier(**defaults)

    else:
        raise NotImplementedError(f"Unknown model_type: {model_type}")


# ─────────────────────────────────────────────────────────────────────
#  2. build pipeline
# ─────────────────────────────────────────────────────────────────────


class DataFrameScaler(BaseEstimator, TransformerMixin):
    """
    Wrap a StandardScaler (or any sklearn transformer) so that `transform`
    returns a DataFrame with the training-time columns (in training order)
    and the input's index.

    Parameters
    ----------
    scaler : transformer, optional
        Object with `fit` / `transform`; defaults to a new `StandardScaler()`.
        It is fitted in place and reused by `transform`.
    exclude : iterable of str, optional
        Case-insensitive *keywords*: every column whose name contains one of
        them as a substring is passed through without scaling.

    `exclude` is stored exactly as given (sklearn convention: `__init__` must
    not modify parameters), which keeps the estimator compatible with
    `sklearn.base.clone`; it is normalised in `fit`.
    """

    def __init__(self, scaler=None, exclude=None):
        self.scaler = scaler if scaler is not None else StandardScaler()
        self.exclude = exclude

    def fit(self, X, y=None):
        """Record column order, resolve the excluded columns and fit the scaler."""
        self.columns_ = list(X.columns)

        # Keyword matching is case-insensitive and substring-based.
        keywords = (
            [] if self.exclude is None else [str(k).lower() for k in self.exclude]
        )

        # columns to exclude = any keyword appears as a substring of the column name
        self.exclude_ = [
            c for c in self.columns_ if any(k in str(c).lower() for k in keywords)
        ]
        self.scale_cols_ = [c for c in self.columns_ if c not in self.exclude_]

        # fit on the subset to be scaled (if any)
        if len(self.scale_cols_) > 0:
            self.scaler.fit(X[self.scale_cols_], y)

        return self

    def transform(self, X):
        """Scale the non-excluded columns; return a DataFrame in training column order."""
        # align columns to training order; missing columns will raise KeyError
        X = X.loc[:, self.columns_]
        if len(self.scale_cols_) > 0:
            Xt_scaled = self.scaler.transform(X[self.scale_cols_])
            X_out = X.copy()
            X_out.loc[:, self.scale_cols_] = Xt_scaled
        else:
            X_out = X.copy()
        return pd.DataFrame(X_out.values, columns=self.columns_, index=X.index)

    def get_feature_names_out(self, input_features=None):
        """
        Output column names in training order.

        `input_features` is accepted (and ignored) because
        `Pipeline.get_feature_names_out` passes it positionally to every step.
        """
        return np.asarray(self.columns_)


def _fast_auc_binary(x: np.ndarray, y: np.ndarray) -> float:
    """Mann–Whitney U / (n_pos*n_neg). NaN-safe (drops NaNs)."""
    mask = np.isfinite(x)
    x = x[mask]
    yb = y[mask]
    n1 = int((yb == 1).sum())
    n0 = int((yb == 0).sum())
    if n1 == 0 or n0 == 0:
        return 0.5
    r = rankdata(x)  # average ranks for ties
    R1 = r[yb == 1].sum()
    U1 = R1 - n1 * (n1 + 1) / 2.0
    return float(U1 / (n1 * n0))


class TopKUnivariateAUC(BaseEstimator, TransformerMixin):
    """
    Keep the top-K features ranked by univariate separation |AUC - 0.5| with the label.

    Parameters
    ----------
    k : int
        Maximum number of features selected by score (before `always_keep`).
    min_auc : float
        A scored feature is only kept if |AUC - 0.5| >= min_auc - 0.5.
    always_keep : iterable, optional
        Columns that are always kept when present in X; they come first in the
        output, in the given order.

    Constructor arguments are stored unmodified (sklearn convention) so the
    selector can be cloned with `sklearn.base.clone`; they are coerced in `fit`.

    Fitted attributes
    -----------------
    keep_cols_ : list
        Output columns: `always_keep` first, then the rest by descending score
        (ties broken by column name).
    feature_scores_ : pd.Series
        |AUC - 0.5| for every input column.
    """

    def __init__(
        self,
        k: int = TOPK_FEATURES,
        min_auc: float = TOPK_MIN_AUC,
        always_keep=None,
    ):
        self.k = k
        self.min_auc = min_auc
        self.always_keep = always_keep

    def fit(self, X: pd.DataFrame, y: pd.Series):
        """Score every column against `y` (Series or array-like) and pick the columns to keep."""
        k = int(self.k)
        min_auc = float(self.min_auc)
        always_keep = [] if self.always_keep is None else list(self.always_keep)
        y_arr = np.asarray(y)

        self.original_n_features_ = X.shape[1]

        aucs = np.array(
            [_fast_auc_binary(X[c].to_numpy(dtype=float), y_arr) for c in X.columns]
        )
        # Direction-agnostic: AUC 0.3 separates as well as AUC 0.7.
        score = np.abs(aucs - 0.5)

        order = np.argsort(score)[::-1]
        keep_idx = order[:k]
        keep_idx = keep_idx[score[keep_idx] >= (min_auc - 0.5)]
        selected = list(X.columns[keep_idx])

        always = [c for c in always_keep if c in X.columns]
        selected_set = set(selected) | set(always)
        rest = sorted(
            [c for c in selected_set if c not in always],
            key=lambda c: (-score[X.columns.get_loc(c)], c),
        )
        self.keep_cols_ = always + rest

        # bookkeeping
        self.n_features_selected_ = len(self.keep_cols_)
        self.selection_rate_ = self.n_features_selected_ / max(
            1, self.original_n_features_
        )
        self.feature_scores_ = pd.Series(
            score, index=X.columns
        )  # optional: for inspection
        return self

    def transform(self, X: pd.DataFrame):
        """Return only the kept columns that are present in X, in `keep_cols_` order."""
        cols = [c for c in self.keep_cols_ if c in X.columns]
        return X.loc[:, cols]

    def get_feature_names_out(self, input_features=None):
        """
        Names of the kept columns.

        `input_features` is accepted (and ignored) because
        `Pipeline.get_feature_names_out` passes it positionally to every step.
        """
        return np.asarray(self.keep_cols_)


def make_pipeline(model_type: str, params: dict) -> Pipeline:
    """
    Assemble the per-model sklearn Pipeline:

        topk_auc (TopKUnivariateAUC) → scaler (DataFrameScaler) → model

    Selection and scaling settings come from the TOPK_* and
    EXCLUDE_FEATURE_KEYWORDS config constants; `params` only configures the model.
    """
    # Not shown in code: an earlier offline pass computed pairwise feature
    # correlations and removed one feature of every pair with |corr| > 0.98
    # directly in this file.
    steps = [
        (
            "topk_auc",
            TopKUnivariateAUC(
                k=TOPK_FEATURES, min_auc=TOPK_MIN_AUC, always_keep=TOPK_ALWAYS_KEEP
            ),
        ),
        ("scaler", DataFrameScaler(exclude=EXCLUDE_FEATURE_KEYWORDS)),
        ("model", make_estimator(model_type, params)),
    ]
    return Pipeline(steps)


# ─────────────────────────────────────────────────────────────────────
#  4. one-fold fit util
# ─────────────────────────────────────────────────────────────────────


def fit_one_fold(
    model_type: str,
    pipeline: Pipeline,
    Xtr: pd.DataFrame,
    ytr: pd.Series,
    Xva: pd.DataFrame,
    yva: pd.Series,
    early_stopping: int = EARLY_STOPPING,
) -> Pipeline:
    """
    Fit the pipeline on one CV fold, passing an eval set that has gone through
    the *fitted preprocessors* to the final estimator for early stopping.
    This prevents feature mismatches with steps like TopK selectors.

    Early stopping is configured per family (substring of `model_type`):
    "xgb" → EarlyStopping callback (keeps the best model), "lgb" →
    `lgb.early_stopping` callback, "cat" → `early_stopping_rounds` +
    `use_best_model`. Other types are fitted without an eval set.

    NOTE: the step objects of `pipeline` are fitted in place and reused in the
    returned Pipeline (nothing is cloned), so calling this again with the same
    `pipeline` refits those same objects.
    NOTE(review): (Xva, yva) drives early stopping; callers in this file also
    score the fold on the same data, which makes that score slightly optimistic.

    Returns
    -------
    Pipeline
        [fitted preprocessors..., ("model", fitted estimator)].
    """
    # split pipeline into [preprocessors] and [final model]
    pre = Pipeline(pipeline.steps[:-1])  # e.g. ("topk_auc","scaler")
    model = pipeline.steps[-1][1]  # the estimator instance

    # fit preprocessors on training split only (CV-safe) and transform both sets
    Xtr_p = pre.fit_transform(Xtr, ytr)
    Xva_p = pre.transform(Xva)

    # Define model specific kwargs
    fit_kwargs = {}

    if "xgb" in model_type:
        # A fresh callback on every call: the callback object carries its own state.
        model.set_params(
            callbacks=[
                XGBEarlyStop(rounds=early_stopping, save_best=True, maximize=True)
            ]
        )
        fit_kwargs["eval_set"] = [(Xva_p, yva)]
        fit_kwargs["verbose"] = 0  # silence XGBoost logging
    elif "lgb" in model_type:
        fit_kwargs["eval_set"] = [(Xva_p, yva)]
        fit_kwargs["callbacks"] = [lgb.early_stopping(early_stopping, verbose=False)]
    elif "cat" in model_type:
        fit_kwargs["eval_set"] = (Xva_p, yva)
        fit_kwargs["use_best_model"] = True
        fit_kwargs["early_stopping_rounds"] = early_stopping

    # Fit final model with early stopping on the *transformed* eval set
    model.fit(Xtr_p, ytr, **fit_kwargs)

    # rebuild a single pipeline carrying the fitted preprocessors + model
    return Pipeline(pre.steps + [("model", model)])


# ─────────────────────────────────────────────────────────────────────
#  5. hyper-parameter tuning
# ─────────────────────────────────────────────────────────────────────


def default_param_grid(model_type: str):
    """
    Return the Optuna search space for `model_type`.

    Value conventions, as interpreted by `optuna_objective`:
      • (low, high)        → int or float range (chosen per key)
      • (low, high, True)  → float range sampled on a log scale
      • list               → categorical choices
      • plain scalar       → constant, passed through untouched

    A new dict is built on every call. Unknown model types raise NotImplementedError.
    """
    search_spaces = {
        "xgb_main": dict(
            learning_rate=(0.02, 0.06, True),
            max_depth=(3, 5),
            min_child_weight=(70.0, 180.0, True),
            gamma=(1.0, 3.0, True),
            subsample=(0.65, 0.90),
            colsample_bytree=(0.50, 0.70),
            colsample_bylevel=(0.70, 0.90),
            colsample_bynode=(0.70, 0.90),
            reg_alpha=(0.5, 4.0, True),
            reg_lambda=(8.0, 40.0, True),
            scale_pos_weight=(0.9, 1.2),
        ),
        "xgb_lite": dict(
            learning_rate=(0.035, 0.07, True),
            max_depth=(3, 3),  # degenerate range: stays truly shallow
            min_child_weight=(120.0, 250.0, True),
            gamma=(1.4, 3.5, True),
            subsample=(0.60, 0.80),
            colsample_bytree=(0.45, 0.62),
            colsample_bylevel=(0.70, 0.90),
            colsample_bynode=(0.70, 0.90),
            reg_alpha=(1.0, 4.0, True),
            reg_lambda=(15.0, 40.0, True),
            scale_pos_weight=(0.9, 1.2),
        ),
        "lgb_main": dict(
            learning_rate=(0.02, 0.06, True),
            max_depth=(4, 6),
            num_leaves=(24, 48),  # consistent with depth 4–6
            min_child_samples=(40, 120),
            min_split_gain=(0.0, 0.20),
            feature_fraction=(0.60, 0.85),
            bagging_fraction=(0.60, 0.85),
            lambda_l1=(0.2, 2.5, True),
            lambda_l2=(5.0, 25.0, True),
            scale_pos_weight=(0.9, 1.2),
            # constants (not tuned)
            extra_trees=True,
            bagging_freq=1,
        ),
        "cat_main": dict(
            learning_rate=(0.02, 0.05, True),
            depth=(3, 5),
            l2_leaf_reg=(15.0, 60.0, True),
            bootstrap_type=["Bayesian"],  # single scheme; only its temperature is tuned
            bagging_temperature=(1.2, 2.6),
            rsm=(0.50, 0.70),
            random_strength=(1.2, 2.2),
            min_data_in_leaf=(150, 260),
        ),
    }
    if model_type not in search_spaces:
        raise NotImplementedError
    return search_spaces[model_type]


def optuna_objective(
    trial: Trial,
    model_type: str,
    X_train: pd.DataFrame,
    y_train: pd.Series,
    groups: pd.Index,
    cv: BaseCrossValidator,
) -> tuple[float, dict]:
    """
    Optuna objective for each model family (and variants).
    Supports:
      • XGBoost
      • LightGBM
      • CatBoost

    Samples hyper-parameters from `default_param_grid(model_type)`, runs `cv`, and
    scores out-of-fold predictions. Only rows that appeared in some validation fold
    are scored, because `cv` may yield fewer folds than it partitions into (see
    `_inner_cv` with K_stop_inner < K_max_inner). Unscored rows would otherwise
    enter the AUC with a placeholder prediction of 0.

    Returns (OOF AUC over validated rows, sampled_hp_dict).

    Raises
    ------
    NotImplementedError
        For an unsupported `model_type`.
    KeyError
        If a grid key is not represented in the sampled hp dict.
    ValueError
        If `cv` yields no validation rows.

    NOTE(review): `n_estimators` is neither sampled here nor set by
    `make_estimator`, so the library default applies while tuning. Confirm
    this is intended.
    """
    grid = default_param_grid(model_type)

    # ----- strict helpers: no fallbacks -----
    def ffloat(name):
        rng = grid[name]
        if isinstance(rng, tuple) and len(rng) == 3:
            return trial.suggest_float(name, rng[0], rng[1], log=rng[2])
        return trial.suggest_float(name, rng[0], rng[1])

    def fint(name):
        rng = grid[name]
        return trial.suggest_int(name, rng[0], rng[1])

    def fcat(name):
        return trial.suggest_categorical(name, grid[name])

    is_xgb = "xgb" in model_type
    is_lgb = "lgb" in model_type
    is_cat = "cat" in model_type

    # ----------------- XGBoost -----------------
    if is_xgb:
        hp = dict(
            learning_rate=ffloat("learning_rate"),
            max_depth=fint("max_depth"),
            min_child_weight=ffloat("min_child_weight"),
            gamma=ffloat("gamma"),
            reg_alpha=ffloat("reg_alpha"),
            reg_lambda=ffloat("reg_lambda"),
            subsample=ffloat("subsample"),
            colsample_bytree=ffloat("colsample_bytree"),
            colsample_bylevel=ffloat("colsample_bylevel"),
            colsample_bynode=ffloat("colsample_bynode"),
            scale_pos_weight=ffloat("scale_pos_weight"),
        )

    # ----------------- LightGBM -----------------
    elif is_lgb:
        hp = dict(
            learning_rate=ffloat("learning_rate"),
            max_depth=fint("max_depth"),
            num_leaves=fint("num_leaves"),
            min_child_samples=fint("min_child_samples"),
            min_split_gain=ffloat("min_split_gain"),
            feature_fraction=ffloat("feature_fraction"),
            bagging_fraction=ffloat("bagging_fraction"),
            lambda_l1=ffloat("lambda_l1"),
            lambda_l2=ffloat("lambda_l2"),
            scale_pos_weight=ffloat("scale_pos_weight"),
            extra_trees=grid["extra_trees"],
            bagging_freq=grid["bagging_freq"],
        )

    # ----------------- CatBoost -----------------
    elif is_cat:
        hp = dict(
            learning_rate=ffloat("learning_rate"),
            depth=fint("depth"),
            l2_leaf_reg=ffloat("l2_leaf_reg"),
            rsm=ffloat("rsm"),
            random_strength=ffloat("random_strength"),
            min_data_in_leaf=fint("min_data_in_leaf"),
            bootstrap_type=fcat("bootstrap_type"),
        )
        if hp["bootstrap_type"] == "Bayesian":
            hp["bagging_temperature"] = ffloat("bagging_temperature")
        elif hp["bootstrap_type"] == "Bernoulli":
            hp["subsample"] = ffloat("subsample")

    else:
        raise NotImplementedError(f"Unknown model_type: {model_type}")

    # ensure every grid key is represented either as a sampled HP or a constant we injected
    expected = set(grid.keys())
    present = set(hp.keys())
    missing = expected - present
    if missing:
        raise KeyError(f"Optuna grid keys not used for {model_type}: {sorted(missing)}")

    # ----------------- CV training -----------------
    # One pipeline object is refitted per fold by `fit_one_fold`; each fold's
    # predictions are taken before the next refit.
    pipe = make_pipeline(model_type, hp)
    oof = np.zeros(len(X_train), dtype=float)
    scored = np.zeros(len(X_train), dtype=bool)

    for tr_idx, va_idx in cv.split(X_train, y_train, groups):
        pipe_fold = fit_one_fold(
            model_type,
            pipe,
            X_train.iloc[tr_idx],
            y_train.iloc[tr_idx],
            X_train.iloc[va_idx],
            y_train.iloc[va_idx],
        )
        oof[va_idx] = pipe_fold.predict_proba(X_train.iloc[va_idx])[:, 1]
        scored[va_idx] = True

    if not scored.any():
        raise ValueError(f"Cross-validator yielded no validation rows for {model_type}")

    return roc_auc_score(np.asarray(y_train)[scored], oof[scored]), hp


def default_params(model_type: str):
    """
    Default hyper-parameters used when Optuna tuning is skipped.

    Returns a new dict on every call.

    Raises
    ------
    NotImplementedError
        For an unknown `model_type`. Previously this returned None, which only
        failed later and obscurely, e.g. in `dict.update(None)` inside
        `make_estimator`. This matches `default_param_grid`.
    """
    if model_type == "xgb_main":
        return {
            "booster": "gbtree",
            "n_estimators": 2400,
            "learning_rate": 0.035,
            "max_depth": 4,
            "min_child_weight": 110.0,
            "gamma": 1.8,
            "subsample": 0.75,
            "colsample_bytree": 0.58,
            "colsample_bylevel": 0.80,
            "colsample_bynode": 0.80,
            "reg_alpha": 1.5,
            "reg_lambda": 18.0,
            "scale_pos_weight": 1.0,
        }

    elif model_type == "xgb_lite":
        return {
            "booster": "gbtree",
            "n_estimators": 1800,
            "learning_rate": 0.050,
            "max_depth": 3,
            "min_child_weight": 170.0,
            "gamma": 2.2,
            "subsample": 0.68,
            "colsample_bytree": 0.52,
            "colsample_bylevel": 0.80,
            "colsample_bynode": 0.80,
            "reg_alpha": 2.0,
            "reg_lambda": 25.0,
            "scale_pos_weight": 1.0,
        }

    elif model_type == "lgb_main":
        return {
            "boosting_type": "gbdt",
            "n_estimators": 2400,
            "learning_rate": 0.035,
            "max_depth": 5,
            "num_leaves": 32,
            "min_child_samples": 64,
            "min_sum_hessian_in_leaf": 5.0,
            "min_split_gain": 0.1,
            "feature_fraction": 0.75,
            "bagging_freq": 1,
            "bagging_fraction": 0.75,
            "lambda_l1": 1.0,
            "lambda_l2": 10.0,
            "extra_trees": True,
            "scale_pos_weight": 1.0,
            "force_row_wise": True,
        }

    elif model_type == "cat_main":
        return {
            "n_estimators": 2200,
            "learning_rate": 0.035,
            "depth": 4,
            "l2_leaf_reg": 30.0,
            "bootstrap_type": "Bayesian",
            "bagging_temperature": 1.8,
            "rsm": 0.58,
            "random_strength": 1.7,
            "min_data_in_leaf": 200,
            "od_type": "Iter",
            "od_wait": 100,
        }

    raise NotImplementedError(f"Unknown model_type: {model_type}")


def tune_params(
    model_type: str,
    X_train: pd.DataFrame,
    y_train: pd.Series,
    groups: pd.Index,
    cv: BaseCrossValidator,
    n_optuna_trials: int,
) -> dict:
    """
    Run Optuna and return the best hyper-parameter dict.

    Falls back to `default_params(model_type)` when tuning is disabled
    (`n_optuna_trials == 0`) or when `model_type` contains "dart".

    Each trial's hp dict is stored as the trial user attribute "hp". The best
    one is read back directly instead of re-evaluating the objective on
    `study.best_trial`, which would repeat a full inner-CV training pass only to
    rebuild the same dict.
    """
    if n_optuna_trials == 0 or "dart" in model_type:
        return default_params(model_type)

    def obj(trial):
        score, hp = optuna_objective(trial, model_type, X_train, y_train, groups, cv)
        trial.set_user_attr("hp", hp)
        return score

    study = optuna.create_study(
        direction="maximize", sampler=optuna.samplers.TPESampler(seed=RANDOM_STATE)
    )
    study.optimize(obj, n_trials=n_optuna_trials, show_progress_bar=False)

    # Copy so callers can mutate the result freely.
    return dict(study.best_trial.user_attrs["hp"])


# ─────────────────────────────────────────────────────────────────────
#  6. main train_model() entry point
# ─────────────────────────────────────────────────────────────────────


def _outer_cv(K, seed):
    """Shuffled, seeded StratifiedGroupKFold with K splits for the outer CV loop."""
    splitter = StratifiedGroupKFold(
        n_splits=K,
        shuffle=True,
        random_state=seed,
    )
    return splitter


def _inner_cv(
    X_tr,
    y_tr,
    g_tr,
    seed: int,
    fold_idx: int,
    K_max_inner: int = K_MAX_INNER,
    K_stop_inner: int = K_STOP_INNER,
) -> BaseCrossValidator:
    """
    Build an inner CV that yields only the first `K_stop_inner` folds of a
    StratifiedGroupKFold constructed with n_splits=`K_max_inner`.

    The shuffle seed is `seed + 69 + fold_idx`, so each outer fold gets a
    different but reproducible inner split.

    Examples:
      - K_max_inner=2, K_stop_inner=1  -> single 50/50 holdout
      - K_max_inner=5, K_stop_inner=1  -> single ~80/20 holdout
      - K_max_inner=5, K_stop_inner=3  -> use 3 of the 5 folds (early-stop CV)

    When K_stop_inner < K_max_inner, some rows never appear in a validation
    fold. Code that scores out-of-fold predictions must restrict itself to
    the rows that were actually validated.

    Raises
    ------
    ValueError
        If not 1 <= K_stop_inner <= K_max_inner.
    """
    if K_stop_inner < 1 or K_stop_inner > K_max_inner:
        raise ValueError("Require 1 <= K_stop_inner <= K_max_inner")

    rng = seed + 69 + fold_idx

    # Fast path: single holdout. The split is computed once here on
    # (X_tr, y_tr, g_tr), and the same indices are yielded for every call,
    # whatever data `split` receives.
    if K_stop_inner == 1:
        sgk = StratifiedGroupKFold(n_splits=K_max_inner, shuffle=True, random_state=rng)
        tr_idx, va_idx = next(sgk.split(X_tr, y_tr, g_tr))

        class _One(BaseCrossValidator):
            def get_n_splits(self, *_, **__):
                return 1
            def split(self, X=None, y=None, groups=None):
                # Reuse the precomputed indices
                yield tr_idx, va_idx

        return _One()

    # General case: lazily yield only the first K_stop_inner folds (no list(...))
    class _LimitedCV(BaseCrossValidator):
        def __init__(self, n_splits, n_take, random_state):
            self.n_splits = int(n_splits)
            self.n_take = int(n_take)
            self.random_state = int(random_state)

        def get_n_splits(self, *_, **__):
            return self.n_take

        def split(self, X=None, y=None, groups=None):
            sgk = StratifiedGroupKFold(
                n_splits=self.n_splits,
                shuffle=True,
                random_state=self.random_state,
            )
            # yield only the first n_take folds on demand
            for i, (tr, va) in enumerate(sgk.split(X, y, groups)):
                if i >= self.n_take:
                    break
                yield tr, va

    return _LimitedCV(K_max_inner, K_stop_inner, rng)


def _fold_dir(root: Path) -> Path:
    d = Path(root) / "fold_models"
    d.mkdir(parents=True, exist_ok=True)
    return d


def _fold_paths(root: Path, k: int):
    """
    Artefact paths for outer fold `k` inside `<root>/fold_models` (created if missing).

    Keys: "pipe" → fitted pipeline (.joblib), "hp" → hyper-parameters (.json),
    "met" → metrics (.json).
    """
    base = _fold_dir(root)
    stem = f"fold_{k}"
    return {
        "pipe": base / f"{stem}.joblib",
        "hp": base / f"{stem}_hp.json",
        "met": base / f"{stem}_metrics.json",
    }


def _extract_best_iteration(pipe: Pipeline, model_type: str) -> int | None:
    est = pipe.named_steps.get("model")
    if est is None:
        return None
    if "xgb" in model_type:
        bi = getattr(est, "best_iteration", None)
        if bi is None:
            booster = getattr(est, "get_booster", lambda: None)()
            bi = getattr(booster, "best_ntree_limit", None)
        return int(bi) if bi else None
    if "lgb" in model_type:
        bi = getattr(est, "best_iteration_", None)
        return int(bi) if bi else None
    if "cat" in model_type:
        bi = est.get_best_iteration()
        return int(bi) if bi else None
    return None


def _save_fold(
    root: Path,
    k: int,
    pipe: Pipeline,
    hp: dict,
    train_auc: float,
    val_auc: float,
    best_iter: int | None,
):
    """
    Persist outer fold `k` under `<root>/fold_models`:
    the fitted pipeline (joblib), its hp dict (JSON) and a metrics JSON holding
    train_auc, val_auc and best_iter.
    """
    paths = _fold_paths(root, k)
    dump(pipe, paths["pipe"])

    with open(paths["hp"], "w") as fh:
        json.dump(hp, fh)

    with open(paths["met"], "w") as fh:
        json.dump(
            dict(
                train_auc=float(train_auc),
                val_auc=float(val_auc),
                best_iter=best_iter,
            ),
            fh,
        )


def _has_full_cv_cache(root: Path, model_type: str, K: int, N: int) -> bool:
    oof_path = Path(root) / f"{model_type}_oof.npy"
    if not oof_path.exists():
        return False
    if np.load(oof_path).shape[0] != N:
        return False
    for k in range(K):
        p = _fold_paths(root, k)
        if not (p["pipe"].exists() and p["hp"].exists() and p["met"].exists()):
            return False
    return True


def _load_cv_cache(root: Path, model_type: str, K: int):
    """
    Load the cached OOF array and the per-fold hp / metrics files for K folds.

    Returns (oof, hps, val_aucs, best_iters). `best_iters` only keeps positive
    numeric `best_iter` values, so it may hold fewer than K entries.
    """
    oof = np.load(Path(root) / f"{model_type}_oof.npy")
    hps, val_aucs, best_iters = [], [], []

    for fold in range(K):
        paths = _fold_paths(root, fold)
        hps.append(json.loads(paths["hp"].read_text()))

        metrics = json.loads(paths["met"].read_text())
        val_aucs.append(float(metrics["val_auc"]))

        best = metrics.get("best_iter")
        if isinstance(best, (int, float)) and int(best) > 0:
            best_iters.append(int(best))

    return oof, hps, val_aucs, best_iters


def _select_full_hps(
    mode: str, model_type: str, fold_hps: list[dict], val_aucs: list[float]
) -> dict:
    if mode == "best_outer":
        j = int(np.nanargmax(val_aucs))
        return fold_hps[j]
    if mode == "consensus":
        # median for numerics, majority for bools
        out = {}
        for k in fold_hps[0].keys():
            vals = [hp[k] for hp in fold_hps]
            if all(isinstance(v, (bool, np.bool_)) for v in vals):
                out[k] = sum(vals) >= (len(vals) - sum(vals))
            elif all(
                isinstance(v, (int, float, np.integer, np.floating)) for v in vals
            ):
                med = float(np.median(vals))
                out[k] = (
                    int(round(med))
                    if all(isinstance(v, (int, np.integer)) for v in vals)
                    else med
                )
        return out
    return default_params(model_type)


# ---------- main training ----------


def train_model(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    model_type: str,
    K_outer: int = K_OUTER,
    K_max_inner: int = K_MAX_INNER,
    K_stop_inner: int = K_STOP_INNER,
    n_optuna_trials: int = N_OPTUNA_TRIALS,
    seed: int = RANDOM_STATE,
    model_dir: str = MODEL_DIR,
    full_refit: bool = FULL_REFIT,
    full_hp_selection: str = FULL_HP_SELECTION,
):
    """
    Simple, cache-aware training:
      • If a complete CV cache exists in `model_dir` → load OOF/hps/metrics.
      • Else run outer CV: tune per fold, early-stop on the outer validation fold,
        save each fold’s model+hp+metrics, then save OOF.
      • If `full_refit`: choose hps via `full_hp_selection`, set `n_estimators` to
        median(best_iter) over folds (only when at least one fold reported one),
        and fit — or load the cached — full-data model.

    Returns
    -------
    np.ndarray
        OOF predictions, one per row of `X_train`.

    NOTE(review): each fold's VAL AUC / OOF is computed on the same data used for
    its early stopping, so those numbers are slightly optimistic.
    """
    model_dir = Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    N = len(X_train)
    groups = X_train.index  # id per row

    # --- CV phase (load or train) ---
    if _has_full_cv_cache(model_dir, model_type, K_outer, N):
        print(f"[{model_type}] Using cached CV.")
        oof, fold_hps, val_aucs, best_iters = _load_cv_cache(
            model_dir, model_type, K_outer
        )
    else:
        oof = np.zeros(N, dtype=float)
        fold_hps, val_aucs, best_iters = [], [], []

        for fold_idx, (tr_idx, va_idx) in enumerate(
            _outer_cv(K_outer, seed).split(X_train, y_train, groups)
        ):
            X_tr, y_tr, g_tr = (
                X_train.iloc[tr_idx],
                y_train.iloc[tr_idx],
                groups[tr_idx],
            )
            X_va, y_va = X_train.iloc[va_idx], y_train.iloc[va_idx]

            inner = _inner_cv(X_tr, y_tr, g_tr, seed, fold_idx, K_max_inner, K_stop_inner)
            hp = tune_params(model_type, X_tr, y_tr, g_tr, inner, n_optuna_trials)

            # fit this outer fold with early stopping on (X_va, y_va)
            pipe = make_pipeline(model_type, hp)
            pipe = fit_one_fold(model_type, pipe, X_tr, y_tr, X_va, y_va)

            # metrics + OOF
            p_tr = pipe.predict_proba(X_tr)[:, 1]
            p_va = pipe.predict_proba(X_va)[:, 1]
            auc_tr = float(roc_auc_score(y_tr, p_tr))
            auc_va = float(roc_auc_score(y_va, p_va))
            oof[va_idx] = p_va

            # save fold
            bi = _extract_best_iteration(pipe, model_type)
            _save_fold(model_dir, fold_idx, pipe, hp, auc_tr, auc_va, bi)

            fold_hps.append(hp)
            val_aucs.append(auc_va)
            # Same filter as applied to cached metrics: keep positive numbers only.
            if isinstance(bi, (int, float)) and int(bi) > 0:
                best_iters.append(int(bi))

            print(f"[{model_type}] Fold {fold_idx}: TRAIN {auc_tr:.4f} | VAL {auc_va:.4f}")

        # save OOF
        np.save(model_dir / f"{model_type}_oof.npy", oof)
        print(f"[{model_type}] OOF AUC = {roc_auc_score(y_train, oof):.4f}")

    # --- Full refit (optional) ---
    if full_refit:
        # Copy first: the selector may hand back a dict shared with `fold_hps`,
        # and the edits below must not leak into it.
        hp_final = dict(
            _select_full_hps(full_hp_selection, model_type, fold_hps, val_aucs)
        )

        # align capacity with early-stopped folds
        if best_iters:
            hp_final["n_estimators"] = int(np.median(best_iters))
        else:
            # np.median([]) is NaN and int(NaN) raises, so keep the selected value.
            print(
                f"[{model_type}] No fold reported a best iteration; "
                f"keeping n_estimators={hp_final.get('n_estimators')}."
            )

        # propagate seed
        if "cat" in model_type:
            hp_final.setdefault("random_seed", seed)
        else:
            hp_final.setdefault("random_state", seed)

        full_path = model_dir / f"{model_type}_full.joblib"
        if full_path.exists():
            pipe_full = load(full_path)
            print(f"[{model_type}] Loaded full refit from cache.")
        else:
            pipe_full = make_pipeline(model_type, hp_final)
            pipe_full.fit(X_train, y_train)
            dump(pipe_full, full_path)

        # quick sanity
        auc_full = roc_auc_score(y_train, pipe_full.predict_proba(X_train)[:, 1])
        print(f"[{model_type}] Full-data TRAIN AUC = {auc_full:.4f}")

    return oof


# Stacking / The actual `train()` Function

In [None]:
# ─────────────────────────────────────────────────────────────────────
# Banners
# ─────────────────────────────────────────────────────────────────────


def _print_base_header(model_type: str, width: int = 72):
    line = "═" * width
    title = f" TRAINING BASE — {model_type} "
    print("\n" + line)
    print(title.center(width, "═"))
    print(line)


def _print_seed_header(
    model_type: str, seed: int, idx: int, total: int, width: int = 72
):
    # one thin line above, left-aligned label
    line = "─" * width
    title = f"[{model_type}] SEED {seed} ({idx}/{total}) — START"
    print("\n" + line)
    print(title)


def _print_auc_highlight(scope: str, auc: float, top_k: int, width: int = 72):
    bar = "█" * width
    msg = f"{scope} — TOP{top_k} avg OOF AUC = {auc:.4f}"
    print("\n" + bar)
    print(msg.center(width))
    print(bar)


def _print_meta_header(model_type: str, width: int = 72):
    line = "═" * width
    title = f" TRAINING META — {model_type} "
    print("\n" + line)
    print(title.center(width, "═"))
    print(line)


# ─────────────────────────────────────────────────────────────────────
# Small utilities
# ─────────────────────────────────────────────────────────────────────


def _seed_list(seed0: int, n: int) -> list[int]:
    return [int(seed0 + 17 * i) for i in range(n)]


def _avg_topk(records: list[dict], k: int) -> tuple[np.ndarray, list[dict], float]:
    ranked = sorted(records, key=lambda r: r["auc"], reverse=True)[: max(1, k)]
    oofs = np.vstack([r["oof"] for r in ranked])
    avg = np.mean(oofs, axis=0)
    return avg, ranked, float(np.mean([r["auc"] for r in ranked]))


def _ensure_dir(p: Path) -> Path:
    p.mkdir(parents=True, exist_ok=True)
    return p


# ─────────────────────────────────────────────────────────────────────
# Base learners (multi-seed)
# ─────────────────────────────────────────────────────────────────────


def _train_base_one_seed(
    X_feat: pd.DataFrame,
    y: pd.Series,
    model_type: str,
    seed: int,
    root: Path,
    K_outer: int = K_OUTER,
    K_max_inner: int = K_MAX_INNER,
    K_stop_inner: int = K_STOP_INNER,
    n_optuna_trials: int = N_OPTUNA_TRIALS,
    full_refit: bool = FULL_REFIT,
    full_hp_selection: str = FULL_HP_SELECTION,
) -> dict:
    """Train one seed of a base learner under ``root/base/<model_type>/seed_<seed>/``.

    The seed directory is created if missing and handed to ``train_model`` as its
    ``model_dir`` (as a string). The returned OOF vector is also saved there as
    ``oof.npy`` (float32).

    Returns:
        ``{"seed", "oof", "auc", "dir"}``, where ``auc`` is the OOF ROC-AUC against ``y``.
    """
    seed_dir = _ensure_dir(root / "base" / model_type / f"seed_{seed}")
    cv_options = dict(
        K_outer=K_outer,
        K_max_inner=K_max_inner,
        K_stop_inner=K_stop_inner,
        n_optuna_trials=n_optuna_trials,
        full_refit=full_refit,
        full_hp_selection=full_hp_selection,
    )
    oof = train_model(
        X_feat,
        y,
        model_type=model_type,
        seed=seed,
        model_dir=str(seed_dir),
        **cv_options,
    )
    np.save(seed_dir / "oof.npy", oof.astype(np.float32))
    return {
        "seed": seed,
        "oof": oof,
        "auc": float(roc_auc_score(y, oof)),
        "dir": str(seed_dir),
    }


def _train_base_multi_seeds(
    X_feat: pd.DataFrame,
    y: pd.Series,
    model_type: str,
    root: Path,
    seed0: int,
    n_seeds: int = N_SEEDS,
    top_seeds: int = TOP_SEEDS,
) -> dict:
    """Train ``model_type`` over ``n_seeds`` seeds and keep the ``top_seeds`` best by OOF AUC.

    Everything is written under ``root/base/<model_type>/``:
      * ``avg_top{top_seeds}_oof.npy``: element-wise mean of the kept seeds' OOFs (float32).
      * ``best_top{top_seeds}.txt``: one ``seed=… auc=… dir=…`` line per kept seed,
        then ``AVG_AUC=…`` (mean of the kept per-seed AUCs).

    Returns:
        dict with ``model_type``, ``avg_oof``, ``top`` (kept seed records) and ``root``.
    """
    model_root = _ensure_dir(root / "base" / model_type)
    seed_values = _seed_list(seed0, n_seeds)
    n_total = len(seed_values)

    records = []
    for position, seed_value in enumerate(seed_values, start=1):
        _print_seed_header(model_type, seed_value, position, n_total)
        records.append(_train_base_one_seed(X_feat, y, model_type, seed_value, root))

    avg_oof, top, avg_auc = _avg_topk(records, top_seeds)
    np.save(model_root / f"avg_top{top_seeds}_oof.npy", avg_oof.astype(np.float32))

    # _best_seeds() (inference side) parses the "seed=" lines of this summary.
    summary_lines = [
        f"seed={rec['seed']} auc={rec['auc']:.6f} dir={rec['dir']}" for rec in top
    ]
    summary_lines.append(f"AVG_AUC={avg_auc:.6f}")
    (model_root / f"best_top{top_seeds}.txt").write_text("\n".join(summary_lines) + "\n")

    _print_auc_highlight(f"[BASE {model_type}]", avg_auc, top_seeds)
    return {
        "model_type": model_type,
        "avg_oof": avg_oof,
        "top": top,
        "root": str(model_root),
    }


# ─────────────────────────────────────────────────────────────────────
# XGBoost meta-learner (multi-seed)
# ─────────────────────────────────────────────────────────────────────


def _stack_features(S_base: pd.DataFrame) -> pd.DataFrame:
    """Build meta features: logits + simple stats."""

    def _logit_clip(p: np.ndarray, eps: float = 1e-5) -> np.ndarray:
        p = np.clip(p.astype(np.float32), eps, 1 - eps)
        return np.log(p) - np.log(1.0 - p)

    L = pd.DataFrame(
        {f"{c}_logit": _logit_clip(S_base[c].values) for c in S_base.columns},
        index=S_base.index,
    )
    F = pd.concat(
        [
            L,
            pd.DataFrame(
                {
                    "logit_mean": L.mean(axis=1).astype(np.float32),
                    "logit_std": L.std(axis=1).astype(np.float32),
                },
                index=S_base.index,
            ),
        ],
        axis=1,
    )
    return F


def _train_xgb(F, y, params, num_round):
    """Fit an XGBoost booster on every row of ``F`` (labels ``y``) for exactly ``num_round`` rounds.

    No validation set and no early stopping are used; ``params`` goes to ``xgb.train`` as-is.
    """
    feature_names = F.columns.tolist()
    train_matrix = xgb.DMatrix(F.values, label=y.values, feature_names=feature_names)
    booster = xgb.train(
        params=params,
        dtrain=train_matrix,
        num_boost_round=num_round,
        verbose_eval=False,
    )
    return booster


def _meta_xgb_oof(
    F: pd.DataFrame,
    y: pd.Series,
    seed: int,
    params: dict,
    K_outer: int = K_OUTER,
) -> tuple[float, np.ndarray]:
    """Out-of-fold predictions of the XGB meta-learner.

    Uses ``StratifiedGroupKFold`` (shuffled, seeded by ``seed``) with ``F``'s index as
    the group labels, so rows sharing an index value stay in the same fold (presumably
    one row per id). Each fold trains a fresh booster for ``params["n_estimators"]``
    rounds (default 200; the key is removed from a copy before ``xgb.train``) and logs
    train/validation AUC.

    Returns:
        ``(oof_auc, oof)``, where ``oof`` is a float32 vector aligned with ``F``'s rows.
    """
    labels = y.astype(int).to_numpy()
    splitter = StratifiedGroupKFold(n_splits=K_outer, shuffle=True, random_state=seed)
    feature_names = F.columns.tolist()
    oof = np.zeros(len(F), dtype=np.float32)

    folds = splitter.split(F, labels, F.index.to_numpy())
    for fold_no, (train_pos, valid_pos) in enumerate(folds, start=1):
        # Copy so popping n_estimators never touches the caller's dict.
        fold_params = dict(params)
        n_rounds = int(fold_params.pop("n_estimators", 200))

        y_fit = y.iloc[train_pos]
        y_hold = y.iloc[valid_pos]
        d_fit = xgb.DMatrix(
            F.iloc[train_pos].values, label=y_fit.values, feature_names=feature_names
        )
        d_hold = xgb.DMatrix(
            F.iloc[valid_pos].values, label=y_hold.values, feature_names=feature_names
        )
        model = xgb.train(
            params=fold_params,
            dtrain=d_fit,
            num_boost_round=n_rounds,
            verbose_eval=False,
        )

        pred_fit = model.predict(d_fit).astype(np.float32)
        pred_hold = model.predict(d_hold).astype(np.float32)
        oof[valid_pos] = pred_hold

        fit_auc = roc_auc_score(y_fit, pred_fit)
        hold_auc = roc_auc_score(y_hold, pred_hold)
        print(
            f"[META-XGB] fold {fold_no}/{K_outer}: TRAIN AUC={fit_auc:.4f} | VAL AUC={hold_auc:.4f}"
        )

    auc_oof = roc_auc_score(y, oof)
    print(f"[META-XGB] OOF AUC = {auc_oof:.4f}")
    return auc_oof, oof


def _train_meta_xgb_one_seed(
    S_base: pd.DataFrame,
    y: pd.Series,
    seed: int,
    out_dir: Path,
    params: dict | None = None,
    K_outer: int = K_OUTER,
) -> dict:
    """Train one seed of the XGBoost meta-learner.

    Steps:
      1. Build meta features from ``S_base`` with ``_stack_features``.
      2. Compute meta OOF predictions via ``_meta_xgb_oof`` (K-fold).
      3. Refit one booster on all rows and save it to ``out_dir/model.json``,
         with the OOF vector in ``out_dir/oof.npy``.
      4. Log the refit's in-sample AUC.

    Args:
        params: optional overrides merged on top of the default hyper-parameters.

    Returns:
        ``{"seed", "auc" (meta OOF AUC), "oof", "dir"}``.
    """
    meta_X = _stack_features(S_base)

    # Shallow, heavily regularised defaults; n_estimators is used as num_boost_round.
    hp = {
        "objective": "binary:logistic",
        "eval_metric": "auc",
        "max_depth": 2,
        "learning_rate": 0.05,
        "n_estimators": 200,
        "subsample": 0.7,
        "colsample_bytree": 1.0,
        "min_child_weight": 20.0,
        "gamma": 2.0,
        "reg_lambda": 20.0,
        "reg_alpha": 0.0,
        "max_bin": 64,
        "tree_method": "hist",
        "verbosity": 0,
        "seed": seed,
    }
    hp.update(params or {})

    auc_oof, oof = _meta_xgb_oof(meta_X, y, seed=seed, params=hp, K_outer=K_outer)

    refit_params = dict(hp)
    n_rounds = int(refit_params.pop("n_estimators", 200))
    booster = _train_xgb(meta_X, y, refit_params, n_rounds)

    out_dir.mkdir(parents=True, exist_ok=True)
    booster.save_model(str(out_dir / "model.json"))
    np.save(out_dir / "oof.npy", oof)

    full_matrix = xgb.DMatrix(meta_X.values, feature_names=meta_X.columns.tolist())
    in_sample = booster.predict(full_matrix).astype(np.float32)  # P(class 1)
    print(f"[META-XGB] Full-data TRAIN AUC = {roc_auc_score(y, in_sample):.4f}")

    return {"seed": seed, "auc": auc_oof, "oof": oof, "dir": str(out_dir)}


def _train_meta_xgb_multi_seeds(
    S_base: pd.DataFrame,
    y: pd.Series,
    root: Path,
    seed0: int,
    n_seeds: int = N_SEEDS,
    top_seeds: int = TOP_SEEDS,
) -> dict:
    """Train several XGB meta seeds and keep the ``top_seeds`` best by meta OOF AUC.

    Per-seed models live in ``root/meta/xgb/seed_<s>/``. The mean OOF of the kept
    seeds goes to ``root/meta/xgb/avg_top{top_seeds}_oof.npy``. An artifact dict
    (``type``, ``top_seeds``, ``model_paths``) is dumped to ``root/meta_artifact.joblib``.

    Returns:
        dict with ``avg_oof``, ``auc`` (mean of kept per-seed OOF AUCs), ``artifact`` and ``top``.
    """
    base_dir = _ensure_dir(root / "meta" / "xgb")
    seed_values = _seed_list(seed0, n_seeds)

    records = []
    for position, seed_value in enumerate(seed_values, start=1):
        _print_seed_header("META-XGB", seed_value, position, len(seed_values))
        records.append(
            _train_meta_xgb_one_seed(S_base, y, seed_value, base_dir / f"seed_{seed_value}")
        )

    avg_oof, top, avg_auc = _avg_topk(records, top_seeds)
    np.save(base_dir / f"avg_top{top_seeds}_oof.npy", avg_oof.astype(np.float32))

    kept_seeds = [rec["seed"] for rec in top]
    artifact = {
        "type": "xgb_meta",
        "top_seeds": kept_seeds,
        "model_paths": [str(base_dir / f"seed_{s}" / "model.json") for s in kept_seeds],
    }
    dump(artifact, root / "meta_artifact.joblib")

    _print_auc_highlight("[META-XGB]", avg_auc, top_seeds)
    return {"avg_oof": avg_oof, "auc": avg_auc, "artifact": artifact, "top": top}


# ─────────────────────────────────────────────────────────────────────
# Orchestrator
# ─────────────────────────────────────────────────────────────────────


def train(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    base_models: tuple[str, ...] = BASE_LEARNERS,
    seed: int = RANDOM_STATE,
    model_dir: str = MODEL_DIR,
) -> dict:
    """Train the stacked ensemble and persist every artifact under ``model_dir``.

    1) Build per-id features once.
    2) Train each base across N_SEEDS; keep TOP_SEEDS; average their OOFs.
    3) Stack averaged OOF columns → train XGB meta (multi-seed); keep TOP_SEEDS; save artifact.

    Args:
        X_train: long-format training data whose index has an ``"id"`` level.
        y_train: per-id labels indexed by id. They are re-ordered to match the
            feature rows, so their incoming order does not matter.
        base_models: base-learner model types, trained in this order.
        seed: root seed. Base ``i`` uses ``seed + 100 * (i + 1)``; the meta uses ``seed + 2000``.
        model_dir: output directory (created if missing).

    Returns:
        dict with ``stack_base`` (per-id averaged base OOFs), ``per_base``
        (``_train_base_multi_seeds`` results keyed by model type), ``meta_oof``,
        ``meta_auc`` and ``meta_artifact``.

    Raises:
        ValueError: if the feature ids do not match the ids in ``X_train``
            (e.g. a feature cache built from a different ``X_train``).
        KeyError: if ``y_train`` has no label for some feature id.
    """
    root = _ensure_dir(Path(model_dir))

    # 1) features
    X_feat = build_features(X_train, force_prep=False, force_all=False)

    # Sanity check on id membership, not just counts: a stale cache built on a
    # different X_train can have the same number of ids but different ones.
    train_ids = X_train.index.get_level_values("id").unique()
    if len(X_feat) != train_ids.size or not X_feat.index.isin(train_ids).all():
        raise ValueError(
            f"train_n != feat_n or ids differ ({train_ids.size} train ids vs "
            f"{len(X_feat)} feature rows), you should probably delete "
            "resources/features which was computed on a different X_train"
        )

    # Downstream code pairs labels with feature rows by position (e.g. `.iloc`
    # splits in _meta_xgb_oof, roc_auc_score), so order labels like the features.
    y_feat = y_train.loc[X_feat.index]

    # 2) bases
    oof_cols: dict[str, np.ndarray] = {}
    base_info: dict[str, dict] = {}
    for i, m in enumerate(base_models):
        _print_base_header(m)
        res = _train_base_multi_seeds(
            X_feat, y_feat, model_type=m, root=root, seed0=seed + 100 * (i + 1)
        )
        oof_cols[f"{m}_oof"] = res["avg_oof"]
        base_info[m] = res

    S_base = pd.DataFrame(oof_cols, index=X_feat.index)

    # 3) meta (only XGB)
    _print_meta_header("XGB")
    meta_res = _train_meta_xgb_multi_seeds(
        S_base, y_feat, root=root, seed0=seed + 2000
    )

    print(f"\n[STACK] Final meta OOF AUC = {meta_res['auc']:.4f}")

    return {
        "stack_base": S_base,
        "per_base": base_info,
        "meta_oof": meta_res["avg_oof"],
        "meta_auc": meta_res["auc"],
        "meta_artifact": meta_res["artifact"],
    }

# The `infer()` Function

In [None]:
# ─────────────────────────────────────────────────────────────────────
# Loaders
# ─────────────────────────────────────────────────────────────────────


def _best_seeds(base_dir: Path) -> list[int]:
    """Parse best_top*.txt and return the listed seeds (sorted)."""
    txt = sorted(base_dir.glob("best_top*.txt"))[-1]
    seeds = []
    for line in txt.read_text().splitlines():
        if line.startswith("seed="):
            seeds.append(int(line.split("seed=")[1].split()[0]))
    return sorted(seeds)


def _load_fold_pipelines(seed_root: Path) -> list:
    """Load all fold_{k}.joblib under a seed path, sorted by k."""
    fold_dir = seed_root / "fold_models"
    folds = []
    fold_re = re.compile(r"fold_(\d+)\.joblib$")
    for p in sorted(
        fold_dir.glob("fold_*.joblib"),
        key=lambda q: int(fold_re.search(q.name).group(1)),
    ):
        folds.append(load(p))
    return folds


def _load_base_fold_models(
    model_dir: Path, base_learners: Iterable[str]
) -> dict[str, dict[int, list]]:
    """Load the per-fold pipelines of every kept seed of every base learner.

    Kept seeds come from ``_best_seeds(<model_dir>/base/<model_type>)``. Fold models
    are read from ``<model_dir>/base/<model_type>/seed_<seed>/``.

    Returns:
      { model_type: { seed: [fold_pipeline0, fold_pipeline1, ...] } }
    """
    base_root = model_dir / "base"
    return {
        model_type: {
            s: _load_fold_pipelines(base_root / model_type / f"seed_{s}")
            for s in _best_seeds(base_root / model_type)
        }
        for model_type in base_learners
    }


def _load_base_full_models(
    model_dir: Path, base_learners: Iterable[str]
) -> dict[str, dict[int, object]]:
    """Load the full-refit pipeline of every kept seed of every base learner.

    Returns:
      { model_type: { seed: full_refit_pipeline } }
    Expects a file named {model_type}_full.joblib under each seed directory
    (``<model_dir>/base/<model_type>/seed_<seed>/``). Kept seeds come from ``_best_seeds``.
    """
    base_root = model_dir / "base"
    loaded: dict[str, dict[int, object]] = {}
    for model_type in base_learners:
        type_root = base_root / model_type
        loaded[model_type] = {
            s: load(type_root / f"seed_{s}" / f"{model_type}_full.joblib")
            for s in _best_seeds(type_root)
        }
    return loaded


def _load_meta_model_paths(model_dir: Path) -> list[str]:
    """Load ``meta_artifact.joblib`` and return the XGB meta model paths under ``model_dir``.

    Training stores ``model_paths`` built from the training-time root, which may be a
    relative path resolved against the training CWD. When the artifact lists
    ``top_seeds``, the paths are rebuilt from the given ``model_dir`` using the
    layout training writes (``meta/xgb/seed_<s>/model.json``). They then stay valid
    if the model directory is moved or passed to inference differently. Artifacts
    without ``top_seeds`` fall back to the stored ``model_paths``.
    """
    model_dir = Path(model_dir)
    art = load(model_dir / "meta_artifact.joblib")
    seeds = art.get("top_seeds")
    if seeds:
        return [
            str(model_dir / "meta" / "xgb" / f"seed_{s}" / "model.json") for s in seeds
        ]
    return list(art["model_paths"])


# ─────────────────────────────────────────────────────────────────────
# Prediction helpers
# ─────────────────────────────────────────────────────────────────────


def _predict_base_fold_ensemble(fold_pipes: list, X: pd.DataFrame) -> np.ndarray:
    """Average predictions across all folds for ONE seed."""
    preds = [
        pipe.predict_proba(X)[:, 1].astype(np.float32, copy=False)
        for pipe in fold_pipes
    ]
    return np.mean(np.vstack(preds), axis=0).astype(np.float32, copy=False)


def _predict_base_across_seeds_fold(
    models_per_seed: dict[int, list], X: pd.DataFrame
) -> np.ndarray:
    """Fold-mode prediction for one base: mean over folds within each seed, then mean over seeds."""
    seed_preds = []
    for folds in models_per_seed.values():
        seed_preds.append(_predict_base_fold_ensemble(folds, X))
    return np.vstack(seed_preds).mean(axis=0).astype(np.float32, copy=False)


def _predict_base_across_seeds_full(
    models_per_seed: dict[int, object], X: pd.DataFrame
) -> np.ndarray:
    """For one base (full-refit mode): average full-refit pipelines across seeds."""
    per_seed = [
        pipe.predict_proba(X)[:, 1].astype(np.float32, copy=False)
        for pipe in models_per_seed.values()
    ]
    return np.mean(np.vstack(per_seed), axis=0).astype(np.float32, copy=False)


def _predict_meta_xgb(model_paths: list[str], F_star: pd.DataFrame) -> np.ndarray:
    """Mean probability of the meta top-seed XGBoost boosters on ``F_star`` (float32).

    NOTE(review): every booster is re-read from disk on each call, and ``infer``
    calls this once per shard.
    """
    dtest = xgb.DMatrix(F_star.values, feature_names=list(F_star.columns))

    def _predict_one(path: str) -> np.ndarray:
        booster = xgb.Booster()
        booster.load_model(path)
        return booster.predict(dtest).astype(np.float32, copy=False)

    stacked = np.vstack([_predict_one(path) for path in model_paths])
    return stacked.mean(axis=0).astype(np.float32, copy=False)


# ─────────────────────────────────────────────────────────────────────
# infer(): choose between fold-ensemble or full-refit bases
# ─────────────────────────────────────────────────────────────────────


def infer(
    X_test: Iterable[pd.DataFrame],
    model_dir: Path = MODEL_DIR,
    inference_mode: str = INFERENCE_MODE,  # "fold" | "full"
    base_learners: Iterable[str] = BASE_LEARNERS,
):
    """Stream stacked-ensemble probabilities for the test shards.

    Generator protocol: the first ``yield`` produces ``None`` once all models are
    loaded. Each later ``yield`` produces a float32 array holding the meta probability
    for one shard (one id per shard).

    inference_mode = "fold":
        • Load per-seed fold models for each base.
        • For each shard (one id):
            - build per-id features (inference=True),
            - base preds: average across folds per seed, then across seeds,
            - stack → meta features → XGB meta (avg across meta top seeds),
            - yield np.ndarray([prob]).

    inference_mode = "full":
        • Load per-seed FULL-REFIT pipelines for each base ({model}_full.joblib).
        • Same as above but base preds are averaged across seeds only (no folds).

    The default mode comes from the ``INFERENCE_MODE`` config constant.

    Raises:
        ValueError: if ``inference_mode`` is neither "fold" nor "full". Because this
            is a generator, the error surfaces on the first ``next()``.
    """
    # Accept str or Path: the loaders build sub-paths with the ``/`` operator.
    model_dir = Path(model_dir)

    if inference_mode == "fold":
        base_models = _load_base_fold_models(model_dir, base_learners)
        base_predict = _predict_base_across_seeds_fold
    elif inference_mode == "full":
        base_models = _load_base_full_models(model_dir, base_learners)
        base_predict = _predict_base_across_seeds_full
    else:
        raise ValueError("inference_mode must be 'fold' or 'full'")

    meta_paths = _load_meta_model_paths(model_dir)

    # signal readiness
    yield

    # process each dataset exactly once (each shard = one id's long df)
    for df_raw in X_test:
        # one id per shard
        id1 = df_raw.index.get_level_values("id").unique()[0]

        # per-id features (no cache I/O)
        X_feat = build_features(df_raw, inference=True).reindex([id1])

        # base predictions → stacked frame (column names must match training's "<model>_oof")
        base_stack = {
            f"{m}_oof": base_predict(models_per_seed, X_feat)
            for m, models_per_seed in base_models.items()
        }
        S_star = pd.DataFrame(base_stack, index=[id1])

        # meta features + XGB meta prediction
        F_star = _stack_features(S_star)
        pred = _predict_meta_xgb(meta_paths, F_star)

        yield pred.astype(np.float32, copy=False)

# Local testing

To make sure your `train()` and `infer()` functions are working properly, you can call `crunch.test()`, which reproduces the cloud environment locally. <br />
It is not a perfect replica, but it should quickly tell you whether your model is working properly.

In [None]:
# Run train() and then infer() through the local emulation of the cloud runner
# (the log below shows both commands plus a determinism re-run of infer()).
crunch.test(
    # Uncomment to disable the train
    # force_first_train=False,

    # Uncomment to disable the determinism check
    # no_determinism_check=True,
)

23:17:30 
23:17:31 started
23:17:31 running local test
23:17:31 internet access isn't restricted, no check will be done
23:17:31 
23:17:32 starting unstructured loop...
23:17:32 executing - command=train


data/X_train.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/X_train.parquet (204327238 bytes)
data/X_train.parquet: already exists, file length match
data/X_test.reduced.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/X_test.reduced.parquet (2380918 bytes)
data/X_test.reduced.parquet: already exists, file length match
data/y_train.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/y_train.parquet (61003 bytes)
data/y_train.parquet: already exists, file length match
data/y_test.reduced.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/y_test.reduced.parquet (2655 bytes)
data/y_test.reduced.parquet: already exists, file length match


train: skip param with default value: base_models=('xgb_main', 'xgb_lite', 'lgb_main', 'cat_main')
train: skip param with default value: seed=69
train: skip param with default value: model_dir=resources/model


Loading cached data from resources/features/all_0930_1346.parquet

════════════════════════════════════════════════════════════════════════
═══════════════════════ TRAINING BASE — xgb_main ═══════════════════════
════════════════════════════════════════════════════════════════════════

────────────────────────────────────────────────────────────────────────
[xgb_main] SEED 169 (1/20) — START
[xgb_main] Fold 0: TRAIN 0.7387 | VAL 0.6898
[xgb_main] Fold 1: TRAIN 0.7407 | VAL 0.6989
[xgb_main] Fold 2: TRAIN 0.7221 | VAL 0.6787
[xgb_main] Fold 3: TRAIN 0.7453 | VAL 0.6730
[xgb_main] Fold 4: TRAIN 0.7544 | VAL 0.6584
[xgb_main] OOF AUC = 0.6781
[xgb_main] Full-data TRAIN AUC = 0.7373

────────────────────────────────────────────────────────────────────────
[xgb_main] SEED 186 (2/20) — START
[xgb_main] Fold 0: TRAIN 0.7396 | VAL 0.6801
[xgb_main] Fold 1: TRAIN 0.7549 | VAL 0.6928
[xgb_main] Fold 2: TRAIN 0.7495 | VAL 0.6908
[xgb_main] Fold 3: TRAIN 0.7369 | VAL 0.6752
[xgb_main] Fold 4: TRAI

23:44:46 executing - command=infer


[META-XGB] Full-data TRAIN AUC = 0.7170

████████████████████████████████████████████████████████████████████████
                 [META-XGB] — TOP2 avg OOF AUC = 0.7022                 
████████████████████████████████████████████████████████████████████████

[STACK] Final meta OOF AUC = 0.7022


infer: skip param with default value: inference_mode=full
infer: skip param with default value: base_learners=('xgb_main', 'xgb_lite', 'lgb_main', 'cat_main')
23:45:08 checking determinism by executing the inference again with 30% of the data (tolerance: 1e-08)
23:45:08 executing - command=infer
23:45:13 determinism check: passed
23:45:13 save prediction - path=data/prediction.parquet
23:45:13 ended
23:45:13 duration - time=00:27:42
23:45:13 memory - before="268.93 MB" after="266.75 MB" consumed="-2179072 bytes"


# Results

Once the local tester is done, you can preview the result stored in `data/prediction.parquet`.

In [None]:
# Preview the predictions crunch.test() saved (one probability per test id).
prediction = pd.read_parquet("data/prediction.parquet")
prediction

Unnamed: 0_level_0,prediction
id,Unnamed: 1_level_1
10001,0.160384
10002,0.200757
10003,0.131639
10004,0.168218
10005,0.197960
...,...
10097,0.208788
10098,0.130918
10099,0.104289
10100,0.244442


### Local scoring

You can call the function that the system uses to estimate your score locally.

In [None]:
# Load the targets
target = pd.read_parquet("data/y_test.reduced.parquet")["structural_breakpoint"]

# roc_auc_score pairs values by position, not by index label, so align the
# predictions to the target ids explicitly (this raises KeyError if an id is missing).
aligned_prediction = prediction.loc[target.index, "prediction"]

# Call the scoring function
roc_auc_score(
    target,
    aligned_prediction,
)

0.711737089201878

# Submit your Notebook

To submit your work, you must:
1. Download your Notebook from Colab
2. Upload it to the platform
3. Create a run to validate it

### >> https://hub.crunchdao.com/competitions/structural-break/submit/notebook