# Go/No-Go Diagnostic Experiment

**Goal:** Determine whether there is *learnable conditional signal* in our current setup,
and whether "marginal collapse" is due to discretization + imbalance + objective rather than a true lack of signal.

- **Part 1:** Sanity diagnostics (MI + smoothed baselines, NO neural nets)
- **Part 2:** Bin-count ablation (`n_bins` in {25, 35, 40, 55})
- **Part 3:** Minimal neural check (only if Part 2 shows promise)
- **Part 4:** Final conclusion + save results

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime
import os, json, warnings

warnings.filterwarnings('ignore')
sns.set(style="whitegrid")
%matplotlib inline

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

EPS = 1e-10

print("Imports ready.")

Imports ready.


## Data Loading

Copied from `TransitionProbMatrix_NEWDATA.ipynb` (cells `8b3d1378`, `414278a2`, `3c86ff39`).
Same dataset, same split, same preprocessing.

In [2]:
# ====================================================================
# DATA LOADING — from TransitionProbMatrix_NEWDATA.ipynb
# ====================================================================
train_df  = pd.read_csv("dataset/train_diagnostic.csv")
labels_df = pd.read_csv("dataset/label_diagnostic.csv")

# Compute forward percent change from Price
train_df["Percent_change_forward"] = (
    train_df["Price"].shift(-1) / train_df["Price"] - 1
) * 100.0

# Drop last row (forward return undefined)
train_df = train_df.iloc[:-1].copy()
labels_df = labels_df.iloc[:-1].copy()

# Drop Opinion column (NaN)
train_df = train_df.drop(columns=["Opinion"], errors="ignore")

# Feature columns
drop_cols = ["index", "Percent_change_forward", "Backward_Bin"]
feature_cols = [c for c in train_df.columns if c not in drop_cols]
X_all = train_df[feature_cols].values.astype(np.float32)

# States: 0-based
s_curr_all = (train_df["Backward_Bin"].values.astype(np.int64) - 1)
y_all      = (labels_df["Forward_Bin"].values.astype(np.int64) - 1)

# Raw percent changes (for rebinning in Part 2)
pct_backward_all = train_df["Percent_change_backward"].values.astype(np.float64)
pct_forward_all  = train_df["Percent_change_forward"].values.astype(np.float64)

n_samples, n_features = X_all.shape
n_states = int(max(s_curr_all.max(), y_all.max()) + 1)

# ====================================================================
# TEMPORAL SPLIT: 70 / 15 / 15 (same as TransitionProbMatrix_NEWDATA)
# ====================================================================
T = n_samples
train_end = int(0.7 * T)
val_end   = int(0.85 * T)

idx_train = np.arange(0,         train_end)
idx_val   = np.arange(train_end, val_end)
idx_test  = np.arange(val_end,   T)

s_train, s_val, s_test = s_curr_all[idx_train], s_curr_all[idx_val], s_curr_all[idx_test]
y_train, y_val, y_test = y_all[idx_train], y_all[idx_val], y_all[idx_test]

X_train = X_all[idx_train]
X_val   = X_all[idx_val]
X_test  = X_all[idx_test]

# Standardize features (train stats only)
mean_feat = X_train.mean(axis=0, keepdims=True)
std_feat  = X_train.std(axis=0, keepdims=True) + 1e-8
X_train_std = (X_train - mean_feat) / std_feat
X_val_std   = (X_val   - mean_feat) / std_feat
X_test_std  = (X_test  - mean_feat) / std_feat

print(f"n_samples: {n_samples}, n_features: {n_features}, n_states: {n_states}")
print(f"Train: {len(idx_train)}, Val: {len(idx_val)}, Test: {len(idx_test)}")
print(f"Train indices: [{idx_train[0]}, {idx_train[-1]}]")
print(f"Val   indices: [{idx_val[0]}, {idx_val[-1]}]")
print(f"Test  indices: [{idx_test[0]}, {idx_test[-1]}]")

n_samples: 2368, n_features: 195, n_states: 55
Train: 1657, Val: 355, Test: 356
Train indices: [0, 1656]
Val   indices: [1657, 2011]
Test  indices: [2012, 2367]


## Part 1 — Sanity Diagnostics (NO neural nets)

In [3]:
# ====================================================================
# HELPER FUNCTIONS
# ====================================================================

def compute_marginal(y, n_classes):
    """Compute marginal distribution from integer labels."""
    counts = np.bincount(y, minlength=n_classes).astype(np.float64)
    return counts / counts.sum()


def compute_conditional(s, y, n_x, n_y, alpha=1.0):
    """
    Compute smoothed conditional P(y | x) via Laplace smoothing.
    C_alpha[x, y] = C[x, y] + alpha
    P(y|x) = C_alpha[x, y] / sum_y' C_alpha[x, y']
    
    Returns: P (n_x, n_y), raw_counts C (n_x, n_y)
    """
    C = np.zeros((n_x, n_y), dtype=np.float64)
    for si, yi in zip(s, y):
        C[si, yi] += 1
    C_alpha = C + alpha
    P = C_alpha / C_alpha.sum(axis=1, keepdims=True)
    return P, C


def evaluate_distribution(pred_dist, y_true, n_classes, label=""):
    """
    Evaluate predicted distributions vs true labels.
    pred_dist: (N, n_classes)  y_true: (N,)
    Returns dict with: mean_ll, accuracy, severity.
    
    Severity definition matches TransitionProbMatrix_NEWDATA.ipynb:
      severity = mean(|E_p[bin] - y_true|)
    where E_p[bin] = sum_j p[j] * j
    """
    N = len(y_true)
    probs_true = pred_dist[np.arange(N), y_true]
    mean_ll = np.log(probs_true + EPS).mean()
    accuracy = (pred_dist.argmax(axis=1) == y_true).mean()
    bins = np.arange(n_classes, dtype=np.float64)
    expected_bins = (pred_dist * bins[np.newaxis, :]).sum(axis=1)
    severity = np.abs(expected_bins - y_true.astype(np.float64)).mean()
    return {"label": label, "mean_ll": mean_ll, "accuracy": accuracy,
            "severity": severity, "n": N}


def compute_mi(s, y, n_x, n_y, eps=1e-10):
    """Compute MI(X; Y) from integer arrays with epsilon smoothing."""
    C = np.zeros((n_x, n_y), dtype=np.float64)
    for si, yi in zip(s, y):
        C[si, yi] += 1
    P_joint = (C + eps)
    P_joint = P_joint / P_joint.sum()
    P_x = P_joint.sum(axis=1)
    P_y = P_joint.sum(axis=0)
    mi = 0.0
    for i in range(n_x):
        for j in range(n_y):
            if P_joint[i, j] > 0:
                mi += P_joint[i, j] * np.log(P_joint[i, j] / (P_x[i] * P_y[j]))
    return mi


print("Helpers defined.")

Helpers defined.


In [4]:
# ====================================================================
# PART 1: COMPUTE & EVALUATE BASELINES
# ====================================================================

# 1. Marginal baseline (train only)
marginal_dist = compute_marginal(y_train, n_states)
entropy = -np.sum(marginal_dist * np.log(marginal_dist + EPS))
print(f"Marginal: mode=bin {marginal_dist.argmax()}, "
      f"max_prob={marginal_dist.max():.4f}, entropy={entropy:.4f} nats")

# 2. Smoothed conditional baselines
P_cond_01, C_counts = compute_conditional(s_train, y_train, n_states, n_states, alpha=0.1)
P_cond_10, _        = compute_conditional(s_train, y_train, n_states, n_states, alpha=1.0)

# State coverage diagnostics
state_counts_train = np.bincount(s_train, minlength=n_states)
print(f"\nState coverage (train): {(state_counts_train > 0).sum()}/{n_states} occupied, "
      f"{(state_counts_train < 5).sum()} with <5 samples")
print(f"  Count distribution: min={state_counts_train[state_counts_train>0].min()}, "
      f"median={np.median(state_counts_train[state_counts_train>0]):.0f}, "
      f"max={state_counts_train.max()}")

# 3. Evaluate on VAL and TEST
results_part1 = []

for split_name, s_sp, y_sp in [("val", s_val, y_val), ("test", s_test, y_test)]:
    N = len(y_sp)
    pred_m = np.tile(marginal_dist[np.newaxis, :], (N, 1))
    r = evaluate_distribution(pred_m, y_sp, n_states, f"Marginal")
    r["alpha"] = "-"; r["split"] = split_name
    results_part1.append(r)

    for alpha_val, P_cond in [(0.1, P_cond_01), (1.0, P_cond_10)]:
        pred_c = P_cond[s_sp]
        r = evaluate_distribution(pred_c, y_sp, n_states, f"Conditional")
        r["alpha"] = alpha_val; r["split"] = split_name
        results_part1.append(r)

df1 = pd.DataFrame(results_part1)[["label", "alpha", "split", "mean_ll", "accuracy", "severity", "n"]]

print("\n" + "="*90)
print("PART 1: BASELINE RESULTS (n_states=55)")
print("="*90)
print(df1.to_string(index=False, float_format=lambda x: f"{x:.6f}"))
print("="*90)

# Flag conditional vs marginal
for sp in ["val", "test"]:
    m_ll = df1[(df1["split"]==sp) & (df1["label"]=="Marginal")]["mean_ll"].values[0]
    c_ll = df1[(df1["split"]==sp) & (df1["label"]=="Conditional")]["mean_ll"].max()
    delta = c_ll - m_ll
    print(f"\n{sp.upper()}: Best conditional - Marginal = {delta:+.6f} nats"
          f"  {'BEATS marginal' if delta > 0 else '*** DOES NOT beat marginal ***'}")

Marginal: mode=bin 26, max_prob=0.0416, entropy=3.7083 nats

State coverage (train): 53/55 occupied, 11 with <5 samples
  Count distribution: min=1, median=35, max=69

PART 1: BASELINE RESULTS (n_states=55)
      label    alpha split   mean_ll  accuracy  severity   n
   Marginal        -   val -3.682208  0.047887  8.101643 355
Conditional 0.100000   val -4.125278  0.025352  8.185760 355
Conditional 1.000000   val -3.882289  0.025352  8.144085 355
   Marginal        -  test -3.692749  0.042135  8.266910 356
Conditional 0.100000  test -4.184170  0.042135  8.294956 356
Conditional 1.000000  test -3.891563  0.042135  8.317989 356

VAL: Best conditional - Marginal = -0.200081 nats  *** DOES NOT beat marginal ***

TEST: Best conditional - Marginal = -0.198814 nats  *** DOES NOT beat marginal ***


In [5]:
# ====================================================================
# PART 1: MUTUAL INFORMATION DIAGNOSTIC (TRAIN ONLY)
# ====================================================================

mi_real = compute_mi(s_train, y_train, n_states, n_states)
mi_bits = mi_real / np.log(2)

print(f"MI(Backward_Bin, Forward_Bin) on TRAIN:")
print(f"  MI = {mi_real:.6f} nats = {mi_bits:.6f} bits")

# Permutation test
mi_shuffled = []
for i in range(5):
    y_perm = np.random.permutation(y_train)
    mi_s = compute_mi(s_train, y_perm, n_states, n_states)
    mi_shuffled.append(mi_s)
    print(f"  Shuffled [{i}] = {mi_s:.6f} nats")

mi_shuffled = np.array(mi_shuffled)
print(f"\n  Shuffled mean = {mi_shuffled.mean():.6f}, std = {mi_shuffled.std():.6f}")
print(f"  Real MI / Shuffled mean = {mi_real / mi_shuffled.mean():.1f}x")

if mi_real > 3 * mi_shuffled.mean():
    print("  MI is significantly above noise floor.")
else:
    print("  *** WARNING: MI is NOT significantly above noise floor! ***")

MI(Backward_Bin, Forward_Bin) on TRAIN:
  MI = 0.659427 nats = 0.951352 bits
  Shuffled [0] = 0.617008 nats
  Shuffled [1] = 0.637256 nats
  Shuffled [2] = 0.628616 nats
  Shuffled [3] = 0.638155 nats
  Shuffled [4] = 0.630453 nats

  Shuffled mean = 0.630298, std = 0.007611
  Real MI / Shuffled mean = 1.0x


## Part 2 — Bin-Count Ablation

**Original discretization:** Fixed-width bins hardcoded in `dataset.Rmd` (55 bins, variable widths: 0.2% center, 0.5% mid, 1% tails).

**Ablation approach:**
- For `n_bins=55`: use original pre-computed bins from CSV.
- For `n_bins` in {25, 35, 40}: use quantile-based binning fit on **TRAIN forward returns only**, applied to all splits.

The purpose is to test whether fewer bins increase learnable conditional structure. Quantile bins ensure roughly equal frequency per bin, reducing sparsity.

In [6]:
# ====================================================================
# PART 2: BIN-COUNT ABLATION
# ====================================================================

def rebin_quantile(pct_values, pct_train, n_bins):
    """Rebin percent changes using quantile edges fit on training data."""
    edges = np.quantile(pct_train, np.linspace(0, 1, n_bins + 1))
    edges[0] = -np.inf
    edges[-1] = np.inf
    # Handle tied quantiles
    edges = np.unique(edges)
    actual_n = len(edges) - 1
    bins = np.clip(np.digitize(pct_values, edges) - 1, 0, actual_n - 1)
    return bins, edges, actual_n


N_BINS_LIST = [25, 35, 40, 55]
ablation_results = []
ablation_mi = {}  # store MI per n_bins

pct_fwd_train = pct_forward_all[idx_train]

for n_bins in N_BINS_LIST:
    print(f"\n{'='*60}")
    print(f"n_bins = {n_bins}")
    print(f"{'='*60}")

    if n_bins == 55:
        # Use original pre-computed bins from CSV
        s_new = s_curr_all.copy()
        y_new = y_all.copy()
        actual_n = 55
        method = "original_fixed"
        print("  Using original fixed-width bins from dataset.Rmd")
    else:
        # Quantile-based, fit on TRAIN forward returns only
        y_new, edges, actual_n = rebin_quantile(pct_forward_all, pct_fwd_train, n_bins)
        s_new, _, _ = rebin_quantile(pct_backward_all, pct_fwd_train, n_bins)
        method = f"quantile"
        print(f"  Quantile bins: {actual_n} actual (requested {n_bins})")

    # Split
    s_tr = s_new[idx_train]; s_va = s_new[idx_val]; s_te = s_new[idx_test]
    y_tr = y_new[idx_train]; y_va = y_new[idx_val]; y_te = y_new[idx_test]

    # State coverage
    sc = np.bincount(s_tr, minlength=actual_n)
    print(f"  Avg samples/state (train): {sc[sc>0].mean():.1f}, "
          f"min={sc[sc>0].min()}, states<5: {(sc<5).sum()}")

    # MI
    mi = compute_mi(s_tr, y_tr, actual_n, actual_n)
    ablation_mi[actual_n] = mi
    print(f"  MI = {mi:.6f} nats ({mi/np.log(2):.6f} bits)")

    # Marginal
    marginal = compute_marginal(y_tr, actual_n)

    # Conditional with both alphas
    for alpha in [0.1, 1.0]:
        P_cond, _ = compute_conditional(s_tr, y_tr, actual_n, actual_n, alpha=alpha)

        for sp_name, s_sp, y_sp in [("val", s_va, y_va), ("test", s_te, y_te)]:
            N = len(y_sp)
            r_m = evaluate_distribution(
                np.tile(marginal[np.newaxis, :], (N, 1)), y_sp, actual_n)
            r_c = evaluate_distribution(P_cond[s_sp], y_sp, actual_n)
            delta = r_c["mean_ll"] - r_m["mean_ll"]

            ablation_results.append({
                "n_bins": actual_n, "method": method, "alpha": alpha,
                "split": sp_name,
                "MI_nats": mi, "MI_bits": mi / np.log(2),
                "marginal_LL": r_m["mean_ll"],
                "conditional_LL": r_c["mean_ll"],
                "delta_LL": delta,
                "marginal_acc": r_m["accuracy"],
                "conditional_acc": r_c["accuracy"],
                "marginal_sev": r_m["severity"],
                "conditional_sev": r_c["severity"],
            })

            if sp_name == "test":
                print(f"  alpha={alpha}, test: marg_LL={r_m['mean_ll']:.4f}, "
                      f"cond_LL={r_c['mean_ll']:.4f}, delta={delta:+.4f}")

print("\nAblation complete.")


n_bins = 25
  Quantile bins: 25 actual (requested 25)
  Avg samples/state (train): 66.3, min=66, states<5: 0
  MI = 0.256094 nats (0.369466 bits)
  alpha=0.1, test: marg_LL=-3.2188, cond_LL=-3.5423, delta=-0.3236
  alpha=1.0, test: marg_LL=-3.2188, cond_LL=-3.3173, delta=-0.0985

n_bins = 35
  Quantile bins: 35 actual (requested 35)
  Avg samples/state (train): 47.3, min=47, states<5: 0
  MI = 0.471888 nats (0.680791 bits)
  alpha=0.1, test: marg_LL=-3.5547, cond_LL=-4.1986, delta=-0.6439
  alpha=1.0, test: marg_LL=-3.5547, cond_LL=-3.7045, delta=-0.1498

n_bins = 40
  Quantile bins: 40 actual (requested 40)
  Avg samples/state (train): 41.4, min=41, states<5: 0
  MI = 0.573988 nats (0.828090 bits)
  alpha=0.1, test: marg_LL=-3.6886, cond_LL=-4.4660, delta=-0.7774
  alpha=1.0, test: marg_LL=-3.6886, cond_LL=-3.8372, delta=-0.1486

n_bins = 55
  Using original fixed-width bins from dataset.Rmd
  Avg samples/state (train): 31.3, min=1, states<5: 11
  MI = 0.659427 nats (0.951352 bits)
 

In [7]:
# ====================================================================
# PART 2 SUMMARY
# ====================================================================
df2 = pd.DataFrame(ablation_results)

# Best alpha per (n_bins, split)
summary2 = []
for nb in sorted(df2["n_bins"].unique()):
    for sp in ["val", "test"]:
        sub = df2[(df2["n_bins"] == nb) & (df2["split"] == sp)]
        best = sub.loc[sub["delta_LL"].idxmax()]
        summary2.append({
            "n_bins": int(best["n_bins"]), "split": sp,
            "best_alpha": best["alpha"],
            "MI_nats": best["MI_nats"],
            "marginal_LL": best["marginal_LL"],
            "conditional_LL": best["conditional_LL"],
            "delta_LL": best["delta_LL"],
            "cond_acc": best["conditional_acc"],
            "cond_sev": best["conditional_sev"],
        })

df2_summary = pd.DataFrame(summary2)

print("="*100)
print("PART 2: BIN-COUNT ABLATION SUMMARY (best alpha per n_bins)")
print("="*100)
print(df2_summary.to_string(index=False, float_format=lambda x: f"{x:.6f}"))
print("="*100)

# Identify best n_bins for Part 3
best_row_val = df2_summary[df2_summary["split"] == "val"].sort_values(
    "delta_LL", ascending=False).iloc[0]
best_n_bins = int(best_row_val["n_bins"])
best_delta_val = best_row_val["delta_LL"]
best_alpha_val = best_row_val["best_alpha"]

proceed_to_part3 = best_delta_val > 0.01

print(f"\nBest n_bins on VAL: {best_n_bins} "
      f"(delta_LL = {best_delta_val:+.6f}, alpha = {best_alpha_val})")
if proceed_to_part3:
    print(f"Delta > 0.01 nats: PROCEEDING to Part 3 with n_bins={best_n_bins}")
else:
    print("Delta <= 0.01 nats: Conditional advantage is negligible.")
    print("Part 3 will be SKIPPED (no evidence of learnable conditional signal).")

PART 2: BIN-COUNT ABLATION SUMMARY (best alpha per n_bins)
 n_bins split  best_alpha  MI_nats  marginal_LL  conditional_LL  delta_LL  cond_acc  cond_sev
     25   val    1.000000 0.256094    -3.218704       -3.348023 -0.129319  0.042254  5.953382
     25  test    1.000000 0.256094    -3.218758       -3.317298 -0.098540  0.050562  5.970979
     35   val    1.000000 0.471888    -3.554669       -3.683649 -0.128980  0.050704  8.313137
     35  test    1.000000 0.471888    -3.554692       -3.704510 -0.149818  0.022472  8.376967
     40   val    1.000000 0.573988    -3.688670       -3.796610 -0.107940  0.016901  9.573489
     40  test    1.000000 0.573988    -3.688565       -3.837156 -0.148592  0.022472  9.537670
     55   val    1.000000 0.659427    -3.682208       -3.882289 -0.200081  0.025352  8.144085
     55  test    1.000000 0.659427    -3.692749       -3.891563 -0.198814  0.042135  8.317989

Best n_bins on VAL: 40 (delta_LL = -0.107940, alpha = 1.0)
Delta <= 0.01 nats: Conditional adv

## Part 3 — Minimal Neural Check

Only runs if Part 2 showed a non-trivial conditional advantage (delta_LL > 0.01 on VAL).

Trains two variants:
- **(A)** Default sampling (current)
- **(B)** State-balanced sampling (WeightedRandomSampler inversely proportional to current-state frequency)

Uses `soft_gaussian_sigma1.0` (best-performing label strategy from prior experiments).

Model: `TransitionNetStateHead` — copied from `TransitionProbMatrix_NEWDATA.ipynb` cell `6978c078`.

In [8]:
# ====================================================================
# PART 3: MINIMAL NEURAL CHECK
# ====================================================================

if not proceed_to_part3:
    print("SKIPPING Part 3: no conditional advantage found in Part 2.")
    print("Final summary will be based on Parts 1-2 only.")
    r_default = r_balanced = r_marginal_p3 = r_cond_p3 = None
    df3 = pd.DataFrame()  # empty
    n_states_p3 = None
else:
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    N_EPOCHS_P3 = 30
    PATIENCE = 8

    # --- Re-discretize with best n_bins ---
    if best_n_bins == 55:
        s_p3 = s_curr_all.copy()
        y_p3 = y_all.copy()
        n_states_p3 = 55
    else:
        y_p3, _, n_states_p3 = rebin_quantile(pct_forward_all, pct_fwd_train, best_n_bins)
        s_p3, _, _ = rebin_quantile(pct_backward_all, pct_fwd_train, best_n_bins)

    s_tr3 = s_p3[idx_train]; s_va3 = s_p3[idx_val]; s_te3 = s_p3[idx_test]
    y_tr3 = y_p3[idx_train]; y_va3 = y_p3[idx_val]; y_te3 = y_p3[idx_test]

    print(f"Part 3: n_bins={n_states_p3}, device={DEVICE}")

    # --- Model (from TransitionProbMatrix_NEWDATA.ipynb, cell 6978c078) ---
    class TransitionNetStateHead(nn.Module):
        def __init__(self, n_features, n_states, trunk_dims=(64, 128, 256, 128),
                     trunk_out=64, dropout=0.2):
            super().__init__()
            self.n_features = n_features
            self.n_states = n_states
            self.trunk_out = trunk_out
            layers = []
            in_dim = n_features
            for h in trunk_dims:
                layers += [nn.Linear(in_dim, h), nn.GELU(), nn.Dropout(dropout)]
                in_dim = h
            layers += [nn.Linear(in_dim, trunk_out), nn.GELU()]
            self.trunk = nn.Sequential(*layers)
            self.head_W = nn.Parameter(torch.randn(n_states, trunk_out, n_states) * 0.01)
            self.head_b = nn.Parameter(torch.zeros(n_states, n_states))

        def forward(self, x, s_curr):
            h = self.trunk(x)
            W_s = self.head_W[s_curr]
            b_s = self.head_b[s_curr]
            return torch.bmm(h.unsqueeze(1), W_s).squeeze(1) + b_s

    # --- Soft labels (from cell kq2127g0bk) ---
    def create_soft_labels_batch(y_hard, n_st, sigma=1.0):
        B = y_hard.shape[0]
        dev = y_hard.device
        j = torch.arange(n_st, device=dev, dtype=torch.float32).unsqueeze(0).expand(B, -1)
        d2 = (j - y_hard.unsqueeze(1).float()) ** 2
        u = torch.exp(-d2 / (2 * sigma ** 2))
        return u / (u.sum(dim=1, keepdim=True) + 1e-8)

    # --- Dataset ---
    class TransitionDataset(Dataset):
        def __init__(self, X, s, y):
            self.X = torch.tensor(X, dtype=torch.float32)
            self.s = torch.tensor(s, dtype=torch.long)
            self.y = torch.tensor(y, dtype=torch.long)
        def __len__(self): return len(self.X)
        def __getitem__(self, i): return self.X[i], self.s[i], self.y[i]

    # --- Train + evaluate function ---
    def train_and_evaluate(s_tr, y_tr, s_va, y_va, s_te, y_te,
                           n_st, sampler=None, label="default"):
        torch.manual_seed(SEED)
        model = TransitionNetStateHead(n_features, n_st).to(DEVICE)
        opt = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

        tr_ds = TransitionDataset(X_train_std, s_tr, y_tr)
        va_ds = TransitionDataset(X_val_std, s_va, y_va)
        te_ds = TransitionDataset(X_test_std, s_te, y_te)

        tr_ld = DataLoader(tr_ds, batch_size=256,
                           sampler=sampler, shuffle=(sampler is None))
        va_ld = DataLoader(va_ds, batch_size=512, shuffle=False)
        te_ld = DataLoader(te_ds, batch_size=512, shuffle=False)

        best_vl, best_st, patience_ctr = float("inf"), None, 0

        for ep in range(1, N_EPOCHS_P3 + 1):
            model.train()
            for xb, sb, yb in tr_ld:
                xb, sb, yb = xb.to(DEVICE), sb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad()
                logits = model(xb, sb)
                soft = create_soft_labels_batch(yb, n_st, sigma=1.0)
                loss = F.kl_div(F.log_softmax(logits, 1), soft, reduction='batchmean')
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()

            model.eval()
            vl, vn = 0.0, 0
            with torch.no_grad():
                for xb, sb, yb in va_ld:
                    xb, sb, yb = xb.to(DEVICE), sb.to(DEVICE), yb.to(DEVICE)
                    logits = model(xb, sb)
                    soft = create_soft_labels_batch(yb, n_st, sigma=1.0)
                    loss = F.kl_div(F.log_softmax(logits, 1), soft, reduction='batchmean')
                    vl += loss.item() * len(xb); vn += len(xb)
            vl /= vn

            if vl < best_vl:
                best_vl = vl
                best_st = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                patience_ctr = 0
            else:
                patience_ctr += 1

            if ep % 10 == 0 or ep == 1:
                print(f"  [{label}] ep {ep:02d}: val_loss={vl:.4f} (best={best_vl:.4f})")
            if patience_ctr >= PATIENCE:
                print(f"  [{label}] Early stop at epoch {ep}")
                break

        model.load_state_dict(best_st); model.to(DEVICE); model.eval()

        all_p, all_y = [], []
        with torch.no_grad():
            for xb, sb, yb in te_ld:
                xb, sb = xb.to(DEVICE), sb.to(DEVICE)
                probs = F.softmax(model(xb, sb), dim=1).cpu().numpy()
                all_p.append(probs); all_y.append(yb.numpy())
        return evaluate_distribution(np.vstack(all_p), np.concatenate(all_y), n_st, label)

    # --- (A) Default sampling ---
    print("\n--- Training: DEFAULT sampling ---")
    r_default = train_and_evaluate(
        s_tr3, y_tr3, s_va3, y_va3, s_te3, y_te3, n_states_p3, label="neural_default")

    # --- (B) Balanced sampling ---
    print("\n--- Training: BALANCED sampling ---")
    sc3 = np.bincount(s_tr3, minlength=n_states_p3)
    w = 1.0 / np.maximum(sc3, 1).astype(np.float64)
    sw = w[s_tr3]; sw = sw / sw.sum()
    sampler_b = WeightedRandomSampler(torch.DoubleTensor(sw), len(s_tr3), replacement=True)
    r_balanced = train_and_evaluate(
        s_tr3, y_tr3, s_va3, y_va3, s_te3, y_te3, n_states_p3,
        sampler=sampler_b, label="neural_balanced")

    # --- Baselines at this n_bins ---
    marg_p3 = compute_marginal(y_tr3, n_states_p3)
    r_marginal_p3 = evaluate_distribution(
        np.tile(marg_p3[np.newaxis, :], (len(y_te3), 1)),
        y_te3, n_states_p3, "marginal")

    P_cond_p3, _ = compute_conditional(
        s_tr3, y_tr3, n_states_p3, n_states_p3, alpha=float(best_alpha_val))
    r_cond_p3 = evaluate_distribution(
        P_cond_p3[s_te3], y_te3, n_states_p3,
        f"conditional (a={best_alpha_val})")

    # --- Part 3 table ---
    p3_rows = [r_marginal_p3, r_cond_p3, r_default, r_balanced]
    df3 = pd.DataFrame(p3_rows)[["label", "mean_ll", "accuracy", "severity"]]

    print("\n" + "="*90)
    print(f"PART 3: NEURAL CHECK (n_bins={n_states_p3}, TEST)")
    print("="*90)
    print(df3.to_string(index=False, float_format=lambda x: f"{x:.6f}"))
    print("="*90)

    for r in [r_default, r_balanced]:
        d = r["mean_ll"] - r_marginal_p3["mean_ll"]
        print(f"  {r['label']} vs marginal: delta = {d:+.6f} -> "
              f"{'BEATS marginal' if d > 0 else 'does NOT beat marginal'}")

SKIPPING Part 3: no conditional advantage found in Part 2.
Final summary will be based on Parts 1-2 only.


## Part 4 — Final Conclusion

In [9]:
# ====================================================================
# PART 4: FINAL CONCLUSION + SAVE
# ====================================================================
os.makedirs("results/diagnostics", exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")

df1.to_csv(f"results/diagnostics/part1_baselines_{ts}.csv", index=False)
df2.to_csv(f"results/diagnostics/part2_ablation_{ts}.csv", index=False)
df2_summary.to_csv(f"results/diagnostics/part2_summary_{ts}.csv", index=False)
if len(df3) > 0:
    df3.to_csv(f"results/diagnostics/part3_neural_{ts}.csv", index=False)

# ====================================================================
# FINAL SUMMARY
# ====================================================================
print("\n" + "="*90)
print("FINAL CONCLUSION")
print("="*90)

# 1. Conditional vs marginal
best_test = df2_summary[df2_summary["split"] == "test"].sort_values(
    "delta_LL", ascending=False).iloc[0]
bt_delta = best_test["delta_LL"]
bt_nb = int(best_test["n_bins"])

if bt_delta > 0:
    print(f"  [+] Conditional baseline BEATS marginal on TEST "
          f"(delta={bt_delta:+.6f} nats, n_bins={bt_nb})")
else:
    print(f"  [-] Conditional baseline does NOT beat marginal on TEST "
          f"(best delta={bt_delta:+.6f})")

for nb in sorted(df2_summary["n_bins"].unique()):
    d = df2_summary[(df2_summary["n_bins"]==nb) & (df2_summary["split"]=="test")]["delta_LL"].values[0]
    print(f"      n_bins={nb}: delta_LL={d:+.6f} {'[+]' if d > 0 else '[-]'}")

# 2. MI
print(f"\n  MI(X_t, Y_t) = {mi_real:.6f} nats ({mi_bits:.6f} bits)")
print(f"  Shuffled MI mean = {mi_shuffled.mean():.6f} nats")
ratio = mi_real / mi_shuffled.mean() if mi_shuffled.mean() > 0 else float('inf')
if ratio > 3:
    print(f"  [+] MI is {ratio:.1f}x noise floor -> statistically significant")
else:
    print(f"  [-] MI is only {ratio:.1f}x noise floor -> weak")

# 3. Neural model
if r_default is not None:
    dd = r_default["mean_ll"] - r_marginal_p3["mean_ll"]
    db = r_balanced["mean_ll"] - r_marginal_p3["mean_ll"]
    if dd > 0:
        print(f"\n  [+] Neural (default) BEATS marginal (delta={dd:+.6f})")
    else:
        print(f"\n  [-] Neural (default) does NOT beat marginal (delta={dd:+.6f})")
    if db > 0:
        print(f"  [+] Neural (balanced) BEATS marginal (delta={db:+.6f})")
    else:
        print(f"  [-] Neural (balanced) does NOT beat marginal (delta={db:+.6f})")
    if db > dd:
        print(f"  Balanced sampling helps: +{db - dd:.6f} nats")
    else:
        print(f"  Balanced sampling does NOT help: {db - dd:+.6f} nats")
else:
    print(f"\n  Neural model: SKIPPED (no conditional signal found in Part 2)")

# 4. Recommendation
print(f"\n{'='*90}")
print("RECOMMENDATION")
print("="*90)

neural_beats = (r_default is not None and
                max(r_default["mean_ll"], r_balanced["mean_ll"]) > r_marginal_p3["mean_ll"])
cond_beats = bt_delta > 0

if neural_beats:
    print("  The neural model BEATS marginal.")
    print("  RECOMMENDATION: PROCEED with further architecture improvements.")
    if r_balanced["mean_ll"] > r_default["mean_ll"]:
        print("  Use balanced sampling as default going forward.")
elif cond_beats and not neural_beats:
    print("  Conditional structure EXISTS (count-based baseline beats marginal),")
    print("  but the neural model fails to exploit it.")
    print("  RECOMMENDATION: Focus on improving the neural model (objective, capacity,")
    print("  regularization) rather than changing architecture or state design.")
else:
    print("  No conditional advantage found at any discretization tested.")
    print("  RECOMMENDATION: REVISE the research claim / state design.")
    print("  Consider: different state definitions, continuous targets, or")
    print("  multi-variate state spaces.")

print("="*90)
print(f"\nAll tables saved to results/diagnostics/ (timestamp: {ts})")


FINAL CONCLUSION
  [-] Conditional baseline does NOT beat marginal on TEST (best delta=-0.098540)
      n_bins=25: delta_LL=-0.098540 [-]
      n_bins=35: delta_LL=-0.149818 [-]
      n_bins=40: delta_LL=-0.148592 [-]
      n_bins=55: delta_LL=-0.198814 [-]

  MI(X_t, Y_t) = 0.659427 nats (0.951352 bits)
  Shuffled MI mean = 0.630298 nats
  [-] MI is only 1.0x noise floor -> weak

  Neural model: SKIPPED (no conditional signal found in Part 2)

RECOMMENDATION
  No conditional advantage found at any discretization tested.
  RECOMMENDATION: REVISE the research claim / state design.
  Consider: different state definitions, continuous targets, or
  multi-variate state spaces.

All tables saved to results/diagnostics/ (timestamp: 20260213_155510)
