# Multi-Step Evaluation Framework

Evaluates a time-varying Markov transition model on **k-step distributional predictions**.

**Key idea:** The model produces time-varying transition matrices $A_t$ of shape $(n, n)$. For k-step prediction:

$$\pi_{t+k} = \pi_t \cdot A_t \cdot A_{t+1} \cdots A_{t+k-1}$$

We compare these predicted distributions against empirical observations using KL, JSD, TV, and severity metrics.

In [None]:
import json
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from scipy import stats

sns.set(style="whitegrid")
%matplotlib inline

# ============================================
# CONFIGURATION
# ============================================
K_VALUES = [1, 2, 3, 5, 10]         # Prediction horizons (in steps) to evaluate
EXPERIMENT = "hard_labels"           # Sub-directory of results/ used as the primary experiment
BOOTSTRAP_N = 1000                   # Default number of bootstrap resamples for CIs
EPS = 1e-10                          # Additive smoothing for KL/JSD and floor inside log-likelihoods

# Echo the active configuration so the notebook output records it.
print(f"Evaluating experiment: {EXPERIMENT}")
print(f"K-step horizons: {K_VALUES}")

## 1. Load Data and Model Outputs

In [None]:
# ============================================
# LOAD DATASET
# ============================================
train_df  = pd.read_csv("dataset/train_diagnostic.csv")
labels_df = pd.read_csv("dataset/label_diagnostic.csv")

# Forward percent return: price at the next row relative to the current row.
# The final row has no successor, so its forward return is undefined; that row
# is dropped from both frames below, and the Opinion column is removed if present.
# NOTE(review): assumes train_df and labels_df are row-aligned — confirm.
train_df["Percent_change_forward"] = (
    train_df["Price"].shift(-1) / train_df["Price"] - 1
) * 100.0
train_df = train_df.iloc[:-1].copy()
labels_df = labels_df.iloc[:-1].copy()
train_df = train_df.drop(columns=["Opinion"], errors="ignore")

# States: shift the bin ids by -1 so they are 0-based
# (presumably the CSV bins are 1-based — verify against the dataset).
s_curr_all = (train_df["Backward_Bin"].values.astype(np.int64) - 1)
y_all      = (labels_df["Forward_Bin"].values.astype(np.int64) - 1)

n_samples = len(s_curr_all)
# Number of states = largest observed state id (current or forward) + 1.
n_states = int(max(s_curr_all.max(), y_all.max()) + 1)

# Temporal split: 70 / 15 / 15 (must match training notebook)
train_end = int(0.7 * n_samples)
val_end   = int(0.85 * n_samples)

# Contiguous, non-overlapping index ranges in time order.
idx_train = np.arange(0, train_end)
idx_val   = np.arange(train_end, val_end)
idx_test  = np.arange(val_end, n_samples)

s_train = s_curr_all[idx_train]
s_test  = s_curr_all[idx_test]
y_train = y_all[idx_train]
y_test  = y_all[idx_test]

print(f"n_samples: {n_samples}, n_states: {n_states}")
print(f"Train: {len(idx_train)}, Val: {len(idx_val)}, Test: {len(idx_test)}")
print(f"Test indices: [{idx_test[0]}, {idx_test[-1]}]")

# ============================================
# LOAD TRANSITION MATRICES
# ============================================
results_dir = Path("results")

# Load primary experiment
A_all = torch.load(results_dir / EXPERIMENT / "models" / "A_all_model.pt",
                   map_location="cpu", weights_only=True)
A_all_np = A_all.numpy()

print(f"\nA_all shape: {A_all.shape}  (expected: ({n_samples}, {n_states}, {n_states}))")
print(f"Row sums check (sample): {A_all_np[0].sum(axis=1)[:5]}  (should be ~1.0)")

# Load all available experiments for comparison.
# Path.iterdir() yields entries in arbitrary order, so sort them to keep the
# comparison tables, plot legends and saved CSV rows in a stable order across
# runs. Plain files (e.g. the PNG/CSV outputs this notebook writes into
# results/) are skipped; only experiment directories are inspected.
all_experiments = {}
for exp_path in sorted(results_dir.iterdir()):
    if not exp_path.is_dir():
        continue
    a_file = exp_path / "models" / "A_all_model.pt"
    if a_file.exists():
        all_experiments[exp_path.name] = torch.load(
            a_file, map_location="cpu", weights_only=True
        ).numpy()
        print(f"  Loaded: {exp_path.name} -> shape {all_experiments[exp_path.name].shape}")

print(f"\nLoaded {len(all_experiments)} experiments total")

## 2. k-Step Distribution Propagation

In [None]:
def predict_k_step_distribution(A_all, k_start, s_0, k, n_states):
    """
    Push a point-mass distribution through k consecutive transition matrices.

    Starting from a one-hot vector on ``s_0``, the distribution is
    right-multiplied by A_all[k_start], A_all[k_start + 1], ...,
    A_all[k_start + k - 1] in that order.

    Args:
        A_all: (T, n, n) array of transition matrices
        k_start: time index of the first matrix applied
        s_0: initial state (integer index into the distribution)
        k: number of matrices to apply
        n_states: length of the distribution vector

    Returns:
        (n_states,) float64 vector; the one-hot start vector when no matrix
        is applied (k <= 0).
    """
    dist = np.zeros(n_states, dtype=np.float64)
    dist[s_0] = 1.0

    # Row vector times matrix: (n,) @ (n, n) -> (n,)
    for t in range(k_start, k_start + k):
        dist = dist @ A_all[t]

    return dist


def compute_k_step_matrix(A_all, k_start, k):
    """
    Compute the k-step transition matrix: A_t * A_{t+1} * ... * A_{t+k-1}

    Args:
        A_all: (T, n, n) array of transition matrices
        k_start: time index t of the first factor
        k: number of factors in the product

    Returns:
        (n, n) float64 matrix. For k <= 0 the empty product (identity) is
        returned, which matches predict_k_step_distribution leaving the
        start distribution untouched when no step is taken.
    """
    if k <= 0:
        return np.eye(np.shape(A_all)[-1], dtype=np.float64)

    # np.array(...) makes one float64 copy, so the caller's A_all is never mutated.
    M = np.array(A_all[k_start], dtype=np.float64)
    for t in range(k_start + 1, k_start + k):
        M = M @ A_all[t]
    return M


# Sanity check: for k=1 the propagated distribution must equal the row of
# A_all[t] selected by the starting state (one-hot @ A_t picks that row).
t_check = idx_test[0]
s_check = s_test[0]
pi_1 = predict_k_step_distribution(A_all_np, t_check, s_check, k=1, n_states=n_states)
direct = A_all_np[t_check, s_check, :]
print(f"k=1 sanity check: max diff = {np.abs(pi_1 - direct).max():.2e}")
print(f"Row sum: {pi_1.sum():.6f}")

# k=2 check: a product of row-stochastic matrices should have rows summing to 1
M2 = compute_k_step_matrix(A_all_np, t_check, k=2)
print(f"k=2 matrix row sums: min={M2.sum(axis=1).min():.6f}, max={M2.sum(axis=1).max():.6f}")

## 3. Empirical k-Step Transition Distributions

In [None]:
def compute_empirical_k_step(s_curr_all, y_all, idx_set, k, n_states):
    """
    Compute the empirical k-step transition matrix from observed data.

    For each time index t in ``idx_set`` the starting state is
    ``s_curr_all[t]``. Because ``y_all[t]`` already holds the state one step
    ahead of t, the state k steps after t is read from ``y_all[t + k - 1]``:

        k=1: s_curr_all[t] -> y_all[t]
        k=2: s_curr_all[t] -> y_all[t + 1]
        k=j: s_curr_all[t] -> y_all[t + j - 1]

    Start indices whose target ``t + k - 1`` falls past the end of ``y_all``
    are skipped. Targets are NOT restricted to ``idx_set``.

    Args:
        s_curr_all: integer array of current states, indexed by time
        y_all: integer array of one-step-ahead states, indexed by time
        idx_set: iterable of starting time indices (may be empty)
        k: prediction horizon in steps
        n_states: number of states (matrix side length)

    Returns:
        P_emp: (n_states, n_states) row-normalised empirical transition
            matrix; rows of never-visited starting states stay all-zero
        counts: (n_states,) int64 number of counted transitions per
            starting state
    """
    s_curr_all = np.asarray(s_curr_all)
    y_all = np.asarray(y_all)

    # Explicit dtype keeps an empty idx_set usable as an index array.
    starts = np.asarray(idx_set, dtype=np.int64)
    targets = starts + (k - 1)
    in_range = targets < len(y_all)

    s_start = s_curr_all[starts[in_range]]
    s_end = y_all[targets[in_range]]

    # np.add.at accumulates correctly when the same (start, end) pair repeats;
    # plain fancy-index "+=" would count each distinct pair only once.
    transition_counts = np.zeros((n_states, n_states), dtype=np.float64)
    np.add.at(transition_counts, (s_start, s_end), 1.0)
    state_counts = transition_counts.sum(axis=1).astype(np.int64)

    # Normalise only the visited rows to avoid dividing by zero.
    P_emp = np.zeros_like(transition_counts)
    visited = state_counts > 0
    P_emp[visited] = transition_counts[visited] / state_counts[visited, np.newaxis]

    return P_emp, state_counts


# Compute empirical matrices for all k values, using test-set start indices.
# Both dicts are keyed by the horizon k.
empirical_matrices = {}
empirical_counts = {}

for k in K_VALUES:
    P_emp, counts = compute_empirical_k_step(s_curr_all, y_all, idx_test, k, n_states)
    empirical_matrices[k] = P_emp
    empirical_counts[k] = counts
    
    # "valid" = starting states with at least 5 observed transitions;
    # "sparse" = starting states that were visited but fewer than 5 times.
    valid_states = (counts >= 5).sum()
    total_samples = counts.sum()
    print(f"k={k:2d}: {total_samples} valid transitions, "
          f"{valid_states}/{n_states} states with >=5 samples, "
          f"sparse states: {(counts > 0).sum() - valid_states}")

## 4. Distributional Metrics

In [None]:
def kl_divergence(p, q, eps=EPS):
    """
    Smoothed Kullback-Leibler divergence KL(p || q).

    ``eps`` is added to every entry of both inputs, which are then
    renormalised to sum to 1 before the divergence is evaluated.
    """
    p_smooth = np.asarray(p, dtype=np.float64) + eps
    q_smooth = np.asarray(q, dtype=np.float64) + eps
    p_norm = p_smooth / p_smooth.sum()
    q_norm = q_smooth / q_smooth.sum()
    log_ratio = np.log(p_norm / q_norm)
    return np.sum(p_norm * log_ratio)


def jensen_shannon_divergence(p, q, eps=EPS):
    """
    Smoothed Jensen-Shannon divergence.

    JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m) with m = (p + q) / 2.
    Both inputs get ``eps`` added to every entry and are renormalised first.
    """
    p_smooth = np.asarray(p, dtype=np.float64) + eps
    q_smooth = np.asarray(q, dtype=np.float64) + eps
    p_norm = p_smooth / p_smooth.sum()
    q_norm = q_smooth / q_smooth.sum()
    mixture = 0.5 * (p_norm + q_norm)
    kl_p_to_m = np.sum(p_norm * np.log(p_norm / mixture))
    kl_q_to_m = np.sum(q_norm * np.log(q_norm / mixture))
    return 0.5 * kl_p_to_m + 0.5 * kl_q_to_m


def total_variation(p, q):
    """
    Total variation distance TV(p, q) = 0.5 * ||p - q||_1.

    Each input is rescaled to sum to 1 when its sum is positive; an input
    whose sum is not positive is used as given.
    """
    def _rescaled(vec):
        arr = np.asarray(vec, dtype=np.float64)
        total = arr.sum()
        return arr / total if total > 0 else arr

    return 0.5 * np.abs(_rescaled(p) - _rescaled(q)).sum()


def mean_bin_distance(p_model, q_empirical, n_states):
    """
    Severity: expected absolute bin error of the model's mean-bin prediction.

    The model distribution is collapsed to its expected bin index, and the
    absolute distance from that value to every bin j is averaged under the
    empirical distribution:

        severity = sum_j |E_model[bin] - j| * q(j)

    Each input is rescaled to sum to 1 when its sum is positive.
    """
    bin_index = np.arange(n_states, dtype=np.float64)

    model = np.asarray(p_model, dtype=np.float64)
    empirical = np.asarray(q_empirical, dtype=np.float64)

    model_total = model.sum()
    empirical_total = empirical.sum()
    if model_total > 0:
        model = model / model_total
    if empirical_total > 0:
        empirical = empirical / empirical_total

    predicted_bin = (model * bin_index).sum()
    abs_error = np.abs(predicted_bin - bin_index)
    return np.sum(abs_error * empirical)


def compute_all_metrics(p_model, q_empirical, n_states):
    """
    Evaluate KL, JSD, TV and Severity for one model/empirical distribution pair.

    KL is taken in the direction KL(empirical || model); the other three
    metrics receive (model, empirical) in that order.

    Returns:
        dict with keys "KL", "JSD", "TV", "Severity".
    """
    kl = kl_divergence(q_empirical, p_model)
    jsd = jensen_shannon_divergence(p_model, q_empirical)
    tv = total_variation(p_model, q_empirical)
    severity = mean_bin_distance(p_model, q_empirical, n_states)
    return {"KL": kl, "JSD": jsd, "TV": tv, "Severity": severity}


# Quick test: a point mass on one bin vs. an even split over that bin and its
# right-hand neighbour. Bins 27/28 are used when they exist; otherwise the
# indices are clamped into range so the check does not raise IndexError on
# state spaces with fewer than 29 bins.
bin_a = min(27, max(n_states - 2, 0))
bin_b = min(bin_a + 1, n_states - 1)
p_test = np.zeros(n_states)
p_test[bin_a] = 1.0
q_test = np.zeros(n_states)
q_test[bin_a] += 0.5
q_test[bin_b] += 0.5  # "+=" keeps q_test summing to 1 even if bin_a == bin_b
metrics_test = compute_all_metrics(p_test, q_test, n_states)
print(f"Metric test (delta_{bin_a} vs 0.5*delta_{bin_a} + 0.5*delta_{bin_b}):")
for name, val in metrics_test.items():
    print(f"  {name}: {val:.4f}")

## 5. Evaluate Multi-Step Performance

In [None]:
def evaluate_multistep(A_all_np, s_curr_all, y_all, idx_test, k_values, n_states,
                       empirical_matrices, empirical_counts, label="Model"):
    """
    Score a stack of transition matrices against empirical k-step transitions.

    For every horizon k, the model's k-step distribution is computed from each
    usable test start index and averaged per starting state. Each averaged row
    is then compared with the matching empirical row using KL, JSD, TV and
    Severity. Only starting states with at least 5 empirical samples and at
    least one model prediction are scored.

    Args:
        A_all_np: (T, n, n) array of transition matrices
        s_curr_all: current state per time index
        y_all: one-step-ahead state per time index (used here only for its length)
        idx_test: starting time indices to evaluate
        k_values: horizons to evaluate
        n_states: number of states
        empirical_matrices: dict k -> (n, n) empirical transition matrix
        empirical_counts: dict k -> (n,) sample count per starting state
        label: accepted for call-site readability; not used in the computation

    Returns:
        dict k -> {
            metric_name -> {per_state, weighted_avg, mean, std},
            "n_valid_states": int,
            "model_matrix": (n, n) averaged model k-step matrix,
        }
        where weighted_avg weights states by their empirical sample count.
    """
    metric_names = ["KL", "JSD", "TV", "Severity"]
    results = {}

    for k in k_values:
        P_emp = empirical_matrices[k]
        counts = empirical_counts[k]

        # Accumulate the predicted k-step distribution per starting state.
        model_k_step = np.zeros((n_states, n_states), dtype=np.float64)
        model_counts = np.zeros(n_states, dtype=np.int64)

        for t in idx_test:
            s_0 = s_curr_all[t]

            # Skip starts without enough future matrices or without a target.
            if t + k > len(A_all_np) or t + k - 1 >= len(y_all):
                continue

            model_k_step[s_0] += predict_k_step_distribution(A_all_np, t, s_0, k, n_states)
            model_counts[s_0] += 1

        # Turn sums into per-state averages (rows with no predictions stay zero).
        predicted = model_counts > 0
        model_k_step[predicted] /= model_counts[predicted, np.newaxis]

        # Score only states with enough empirical data and a model prediction.
        valid_mask = (counts >= 5) & predicted
        per_state = {name: np.full(n_states, np.nan) for name in metric_names}
        for s in np.flatnonzero(valid_mask):
            state_metrics = compute_all_metrics(model_k_step[s], P_emp[s], n_states)
            for name in metric_names:
                per_state[name][s] = state_metrics[name]

        # Aggregate across valid states, weighting by empirical sample count.
        total_weight = counts[valid_mask].sum()
        if total_weight > 0:
            weights = counts[valid_mask].astype(np.float64) / total_weight
        else:
            weights = None

        k_results = {}
        for name in metric_names:
            vals = per_state[name][valid_mask]
            has_vals = len(vals) > 0
            if weights is not None and has_vals:
                weighted_avg = float(np.average(vals, weights=weights))
            else:
                weighted_avg = np.nan
            k_results[name] = {
                "per_state": per_state[name],
                "weighted_avg": weighted_avg,
                "mean": float(np.nanmean(vals)) if has_vals else np.nan,
                "std": float(np.nanstd(vals)) if has_vals else np.nan,
            }

        k_results["n_valid_states"] = int(valid_mask.sum())
        k_results["model_matrix"] = model_k_step
        results[k] = k_results

    return results


# Evaluate the primary experiment (matrices loaded from results/EXPERIMENT)
# against the empirical k-step matrices built from the test indices.
results_primary = evaluate_multistep(
    A_all_np, s_curr_all, y_all, idx_test, K_VALUES, n_states,
    empirical_matrices, empirical_counts, label=EXPERIMENT
)

# Print summary: one row per horizon, showing the sample-count-weighted
# average of each metric and the number of states that were scored.
print(f"\n{'='*80}")
print(f"MULTI-STEP EVALUATION: {EXPERIMENT}")
print(f"{'='*80}")
print(f"{'k':>3} | {'KL':>8} | {'JSD':>8} | {'TV':>8} | {'Severity':>10} | {'Valid States':>12}")
print(f"{'-'*3}-+-{'-'*8}-+-{'-'*8}-+-{'-'*8}-+-{'-'*10}-+-{'-'*12}")
for k in K_VALUES:
    r = results_primary[k]
    print(f"{k:3d} | {r['KL']['weighted_avg']:8.4f} | {r['JSD']['weighted_avg']:8.4f} | "
          f"{r['TV']['weighted_avg']:8.4f} | {r['Severity']['weighted_avg']:10.4f} | {r['n_valid_states']:12d}")
print(f"{'='*80}")

## 6. Baseline Comparisons

In [None]:
# ============================================
# BASELINE 1: Stationary Empirical Matrix^k
# ============================================
# Compute empirical 1-step transition matrix from TRAINING data only.
# np.add.at accumulates repeated (from, to) pairs correctly, which a plain
# fancy-index "+=" would not.
A_emp_train = np.zeros((n_states, n_states), dtype=np.float64)
np.add.at(A_emp_train, (s_curr_all[idx_train], y_all[idx_train]), 1.0)

# Normalize rows. States never visited in training have an all-zero row; they
# fall back to a uniform distribution. Only visited rows are divided, so no
# 0/0 is evaluated (the previous np.where form computed it and emitted a
# RuntimeWarning before discarding the result).
row_sums = A_emp_train.sum(axis=1, keepdims=True)
visited_rows = row_sums[:, 0] > 0
A_emp_train[visited_rows] /= row_sums[visited_rows]
A_emp_train[~visited_rows] = 1.0 / n_states

# For k-step: raise to power k
def stationary_k_step(A_emp, k, n_samples_needed):
    """
    Build an A_all-shaped stack that repeats ``A_emp`` at every time step.

    The result is NOT raised to the k-th power here: chaining k identical
    matrices downstream already yields A_emp^k.

    Args:
        A_emp: (n, n) stationary transition matrix
        k: unused; kept so existing call sites keep working
        n_samples_needed: number of time steps T in the returned stack

    Returns:
        (n_samples_needed, n, n) array whose every slice equals ``A_emp``.
    """
    A_emp = np.asarray(A_emp)
    return np.tile(A_emp[np.newaxis, :, :], (n_samples_needed, 1, 1))

# Stationary baseline: the same training-set matrix at every time step, so a
# k-step propagation multiplies k copies of A_emp_train.
A_stationary = stationary_k_step(A_emp_train, 1, n_samples)

results_stationary = evaluate_multistep(
    A_stationary, s_curr_all, y_all, idx_test, K_VALUES, n_states,
    empirical_matrices, empirical_counts, label="Stationary"
)

# ============================================
# BASELINE 2: Marginal (ignore current state)
# ============================================
# Marginal distribution of the one-step-ahead states in the training data
marginal_dist = np.bincount(y_train, minlength=n_states).astype(np.float64)
marginal_dist = marginal_dist / marginal_dist.sum()

# Create A_all where every row is the marginal distribution (state-independent)
A_marginal = np.tile(marginal_dist[np.newaxis, np.newaxis, :], (n_samples, n_states, 1))

results_marginal = evaluate_multistep(
    A_marginal, s_curr_all, y_all, idx_test, K_VALUES, n_states,
    empirical_matrices, empirical_counts, label="Marginal"
)

# ============================================
# EVALUATE ALL SAVED EXPERIMENTS
# ============================================
# Keyed by experiment directory name; all are scored against the same
# empirical matrices.
results_all_experiments = {}
for exp_name, A_exp in all_experiments.items():
    results_all_experiments[exp_name] = evaluate_multistep(
        A_exp, s_curr_all, y_all, idx_test, K_VALUES, n_states,
        empirical_matrices, empirical_counts, label=exp_name
    )

# Add baselines (inserted last, so they appear after the saved experiments)
results_all_experiments["Stationary (train)"] = results_stationary
results_all_experiments["Marginal"] = results_marginal

# ============================================
# COMPARISON TABLE
# ============================================
# One table per metric: rows are models, columns are horizons. Header and
# value cells share a fixed width of 8 so the columns line up (the header
# cells used to be narrower than the formatted values).
print(f"\n{'='*100}")
print("MULTI-STEP COMPARISON: ALL MODELS")
print(f"{'='*100}")

for metric_name in ["JSD", "TV", "Severity"]:
    print(f"\n--- {metric_name} (weighted avg) ---")
    header = f"{'Model':<30}"
    for k in K_VALUES:
        header += f" | {f'k={k}':>8}"
    print(header)
    print("-" * len(header))

    for model_name, model_results in results_all_experiments.items():
        row = f"{model_name:<30}"
        for k in K_VALUES:
            val = model_results[k][metric_name]["weighted_avg"]
            row += f" | {val:8.4f}"
        print(row)

print(f"\n{'='*100}")

## 7. Chapman-Kolmogorov Consistency

In [None]:
def chapman_kolmogorov_test(A_all_np, s_curr_all, y_all, idx_test, n_states, k_values=(2, 3, 5)):
    """
    Per-sample quality of the model's k-step predictions.

    A full empirical k-step matrix cannot be estimated at a single time t, so
    for every usable test start index t and every horizon k the model's
    k-step distribution for the actual starting state is compared with the
    single observed outcome ``y_all[t + k - 1]``.

    Args:
        A_all_np: (T, n, n) array of transition matrices
        s_curr_all: current state per time index
        y_all: one-step-ahead state per time index
        idx_test: starting time indices to evaluate
        n_states: number of states
        k_values: horizons to evaluate (only iterated, never modified)

    Returns:
        dict k -> {
            "mean_ll", "std_ll": mean / std of log(p(true state) + EPS),
            "accuracy": fraction of samples whose argmax equals the true state,
            "severity": mean |expected bin - true bin|,
            "n_samples": number of scored samples,
            "log_likelihoods": per-sample log-likelihood array,
        }
        When a horizon has no usable sample, the scalar metrics are NaN,
        n_samples is 0 and log_likelihoods is empty.
    """
    bins = np.arange(n_states, dtype=np.float64)
    results = {}

    for k in k_values:
        predicted_probs = []
        true_states = []

        for t in idx_test:
            s_0 = s_curr_all[t]

            # Need k matrices starting at t and an observed target at t + k - 1.
            if t + k > len(A_all_np) or t + k - 1 >= len(y_all):
                continue

            predicted_probs.append(
                predict_k_step_distribution(A_all_np, t, s_0, k, n_states)
            )
            true_states.append(y_all[t + k - 1])

        if not true_states:
            # argmax over an empty (0,) array would raise; report "no data" instead.
            results[k] = {
                "mean_ll": float("nan"),
                "std_ll": float("nan"),
                "accuracy": float("nan"),
                "severity": float("nan"),
                "n_samples": 0,
                "log_likelihoods": np.array([], dtype=np.float64),
            }
            continue

        predicted_probs = np.array(predicted_probs)
        true_states = np.array(true_states)

        # Probability assigned to the observed state; EPS guards against log(0).
        p_true = predicted_probs[np.arange(len(true_states)), true_states]
        log_likelihoods = np.log(p_true + EPS)

        # Accuracy: most likely state == true state
        accuracy = (predicted_probs.argmax(axis=1) == true_states).mean()

        # Severity: distance between the model's expected bin and the true bin
        expected_bins = (predicted_probs * bins[np.newaxis, :]).sum(axis=1)
        severity = np.abs(expected_bins - true_states.astype(np.float64)).mean()

        results[k] = {
            "mean_ll": float(log_likelihoods.mean()),
            "std_ll": float(log_likelihoods.std()),
            "accuracy": float(accuracy),
            "severity": float(severity),
            "n_samples": len(log_likelihoods),
            "log_likelihoods": log_likelihoods,
        }

    return results


# Numerical stability of the chained k-step products: a product of
# row-stochastic matrices should itself be row-stochastic, so for up to 200
# test start indices we record the largest deviation of any row sum from 1.0.
print("=" * 60)
print("CHAPMAN-KOLMOGOROV CONSISTENCY")
print("=" * 60)

# NOTE: the name is historical; the lists hold max row-sum deviations,
# not Frobenius norms.
frobenius_deviations = {k: [] for k in [2, 3, 5]}

for i in range(0, min(len(idx_test) - 10, 200)):
    t = idx_test[i]
    for k in [2, 3, 5]:
        if t + k > len(A_all_np):
            continue

        M_seq = compute_k_step_matrix(A_all_np, t, k)
        row_sum_dev = np.abs(M_seq.sum(axis=1) - 1.0).max()
        frobenius_deviations[k].append(row_sum_dev)

print("\nNumerical stability (max row-sum deviation from 1.0):")
for k in [2, 3, 5]:
    devs = np.array(frobenius_deviations[k])
    if devs.size == 0:
        # max() of an empty array raises; happens when the test set is tiny.
        print(f"  k={k}: no test windows available")
        continue
    print(f"  k={k}: mean={devs.mean():.2e}, max={devs.max():.2e}")

# Per-sample k-step prediction quality for every saved experiment and both
# baselines: log-likelihood, accuracy and severity of the true k-step outcome.
print("\n" + "-" * 60)
print("Per-sample k-step prediction quality:")
print("-" * 60)

ck_horizons = [1, 2, 3, 5, 10]
baseline_matrices = (
    ("Stationary (train)", A_stationary),
    ("Marginal", A_marginal),
)

ck_results_all = {}
for exp_name, A_exp in all_experiments.items():
    # Baseline names are reserved; the baselines are added separately below.
    if exp_name not in ("Stationary (train)", "Marginal"):
        ck_results_all[exp_name] = chapman_kolmogorov_test(
            A_exp, s_curr_all, y_all, idx_test, n_states, k_values=ck_horizons
        )

# Baselines go in last so they follow the experiments in every listing.
for baseline_label, A_baseline in baseline_matrices:
    ck_results_all[baseline_label] = chapman_kolmogorov_test(
        A_baseline, s_curr_all, y_all, idx_test, n_states, k_values=ck_horizons
    )

print(f"\n{'Model':<30} | {'k':>3} | {'Mean LL':>8} | {'Accuracy':>8} | {'Severity':>8} | {'N':>5}")
print("-" * 80)
for exp_name, ck_res in ck_results_all.items():
    for k in ck_horizons:
        if k not in ck_res:
            continue
        r = ck_res[k]
        print(f"{exp_name:<30} | {k:3d} | {r['mean_ll']:8.4f} | {r['accuracy']:8.4f} | "
              f"{r['severity']:8.2f} | {r['n_samples']:5d}")

## 8. Visualization

In [None]:
# ============================================
# PLOT 1: Divergence vs k (all models)
# ============================================
# One panel per metric. Lines show the weighted average per horizon; the
# shaded band is +/- the across-state standard deviation. Baselines are
# drawn dashed and slightly more transparent.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
baseline_names = ("Stationary (train)", "Marginal")

for ax, metric_name in zip(axes, ["JSD", "TV", "Severity"]):
    for model_name, model_results in results_all_experiments.items():
        per_horizon = [model_results[k][metric_name] for k in K_VALUES]
        means = np.array([entry["weighted_avg"] for entry in per_horizon])
        stds = np.array([entry["std"] for entry in per_horizon])

        is_baseline = model_name in baseline_names
        ax.plot(
            K_VALUES,
            means,
            "--" if is_baseline else "-",
            marker="o",
            label=model_name.replace("_", " ").title(),
            alpha=0.6 if is_baseline else 0.9,
            linewidth=2,
        )
        ax.fill_between(K_VALUES, means - stds, means + stds, alpha=0.15)

    ax.set_xlabel("Prediction Horizon k", fontsize=12)
    ax.set_ylabel(metric_name, fontsize=12)
    ax.set_title(f"{metric_name} vs Prediction Horizon", fontsize=13, fontweight="bold")
    ax.legend(fontsize=8, loc="best")
    ax.grid(alpha=0.3)
    ax.set_xticks(K_VALUES)

plt.tight_layout()
plt.savefig("results/multistep_divergence_vs_k.png", dpi=300, bbox_inches="tight")
plt.show()

# ============================================
# PLOT 2: Heatmaps - Model vs Empirical k-step matrices
# ============================================
# Only horizons that were actually evaluated are plotted, so no column is
# left as an empty axes when K_VALUES lacks one of the preferred horizons.
k_plot = [k for k in [1, 3, 5, 10] if k in results_primary]

if k_plot:
    # squeeze=False keeps `axes` two-dimensional even with a single column.
    fig, axes = plt.subplots(2, len(k_plot), figsize=(5 * len(k_plot), 10), squeeze=False)

    for col, k in enumerate(k_plot):
        model_mat = results_primary[k]["model_matrix"]
        emp_mat = empirical_matrices[k]

        # Shared colour scale so the model and empirical panels are comparable.
        vmax = max(model_mat.max(), emp_mat.max())

        sns.heatmap(model_mat, ax=axes[0, col], cmap="viridis", vmin=0, vmax=vmax, cbar=True)
        axes[0, col].set_title(f"Model k={k}", fontsize=11, fontweight="bold")
        axes[0, col].set_xlabel("Next State")
        if col == 0:
            axes[0, col].set_ylabel("Current State")

        sns.heatmap(emp_mat, ax=axes[1, col], cmap="viridis", vmin=0, vmax=vmax, cbar=True)
        axes[1, col].set_title(f"Empirical k={k}", fontsize=11, fontweight="bold")
        axes[1, col].set_xlabel("Next State")
        if col == 0:
            axes[1, col].set_ylabel("Current State")

    plt.suptitle(f"Model vs Empirical k-Step Transition Matrices ({EXPERIMENT})",
                 fontsize=14, fontweight="bold", y=1.02)
    plt.tight_layout()
    plt.savefig("results/multistep_heatmaps.png", dpi=300, bbox_inches="tight")
    plt.show()
else:
    print("No evaluated horizon available for heatmaps; skipping plot.")

# ============================================
# PLOT 3: Per-sample log-likelihood vs k
# ============================================
# Left panel: mean severity per horizon; right panel: mean log-likelihood.
# Baselines are drawn dashed.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    (axes[0], "severity", "Mean Severity (bins)", "Per-Sample Severity vs Horizon"),
    (axes[1], "mean_ll", "Mean Log-Likelihood", "Per-Sample Log-Likelihood vs Horizon"),
]

for ax, result_key, y_label, title in panels:
    for exp_name, ck_res in ck_results_all.items():
        ks = sorted(ck_res.keys())
        values = [ck_res[k][result_key] for k in ks]
        style = "--" if exp_name in ["Stationary (train)", "Marginal"] else "-"
        label = exp_name.replace("_", " ").title()
        ax.plot(ks, values, style, marker="o", label=label, linewidth=2)

    ax.set_xlabel("Prediction Horizon k", fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=13, fontweight="bold")
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig("results/multistep_persample_metrics.png", dpi=300, bbox_inches="tight")
plt.show()

# ============================================
# PLOT 4: Per-state JSD for the primary experiment
# ============================================
# One line per horizon; states that were not scored (NaN) are left out.
fig, ax = plt.subplots(figsize=(12, 5))

for k in K_VALUES:
    jsd_per_state = results_primary[k]["JSD"]["per_state"]
    scored_states = np.flatnonzero(~np.isnan(jsd_per_state))
    ax.plot(
        scored_states,
        jsd_per_state[scored_states],
        "o-",
        label=f"k={k}",
        alpha=0.7,
        markersize=4,
    )

ax.set_xlabel("State", fontsize=12)
ax.set_ylabel("Jensen-Shannon Divergence", fontsize=12)
ax.set_title(f"Per-State JSD by Prediction Horizon ({EXPERIMENT})", fontsize=13, fontweight="bold")
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig("results/multistep_perstate_jsd.png", dpi=300, bbox_inches="tight")
plt.show()

## 9. Statistical Validation

In [None]:
def bootstrap_ci(data, statistic=np.mean, n_boot=BOOTSTRAP_N, ci=0.95):
    """
    Percentile bootstrap confidence interval for ``statistic(data)``.

    Draws ``n_boot`` resamples of ``data`` (with replacement, same size) from
    NumPy's global random state, evaluates ``statistic`` on each and returns
    the order statistics at the (1 - ci) / 2 and (1 + ci) / 2 positions.

    Args:
        data: 1-D sequence of observations
        statistic: callable mapping a resample to a scalar
        n_boot: number of bootstrap resamples
        ci: confidence level in (0, 1]

    Returns:
        (lower, upper) bounds; (nan, nan) when ``data`` is empty.
    """
    data = np.asarray(data)
    if data.size == 0:
        # Nothing to resample; avoid n_boot warnings from statistic(empty).
        return np.nan, np.nan

    boot_stats = np.sort([
        statistic(np.random.choice(data, size=len(data), replace=True))
        for _ in range(n_boot)
    ])

    # Clamp so that ci=1.0 maps to the last order statistic instead of
    # indexing one past the end of the array.
    last = n_boot - 1
    lower_idx = min(max(int((1 - ci) / 2 * n_boot), 0), last)
    upper_idx = min(max(int((1 + ci) / 2 * n_boot), 0), last)
    return boot_stats[lower_idx], boot_stats[upper_idx]


# Fix the global NumPy seed so the bootstrap intervals below are reproducible.
np.random.seed(42)

print("=" * 90)
print("STATISTICAL VALIDATION")
print("=" * 90)

# For each k, compare models using per-sample log-likelihoods
for k in K_VALUES:
    print(f"\n--- k = {k} ---")
    print(f"{'Model':<30} | {'Mean LL':>8} | {'95% CI':>18} | {'Severity':>8} | {'95% CI':>18}")
    print("-" * 95)
    
    for exp_name, ck_res in ck_results_all.items():
        if k not in ck_res:
            continue
        r = ck_res[k]
        ll_data = r["log_likelihoods"]
        
        # Bootstrap CI of the mean per-sample log-likelihood.
        ll_ci = bootstrap_ci(ll_data)
        
        # The severity CI column is intentionally left blank: only the mean
        # severity is read here, and a bootstrap would need the individual
        # per-sample values. The log-likelihood CI is the primary interval.
        print(f"{exp_name:<30} | {r['mean_ll']:8.4f} | [{ll_ci[0]:8.4f}, {ll_ci[1]:8.4f}] | "
              f"{r['severity']:8.2f} |                   ")

# ============================================
# PAIRED COMPARISON: Model vs Baselines
# ============================================
print(f"\n{'='*90}")
print("PAIRED WILCOXON TESTS (vs Stationary Baseline)")
print(f"{'='*90}")

baseline_name = "Stationary (train)"
if baseline_name in ck_results_all:
    for k in K_VALUES:
        if k not in ck_results_all[baseline_name]:
            continue
        
        baseline_ll = ck_results_all[baseline_name][k]["log_likelihoods"]
        
        print(f"\nk={k}:")
        for exp_name, ck_res in ck_results_all.items():
            if exp_name == baseline_name or k not in ck_res:
                continue
            
            model_ll = ck_res[k]["log_likelihoods"]
            
            # Truncate both series to a common length before pairing.
            # NOTE(review): pairing by position assumes both series were built
            # from the same start indices in the same order — confirm.
            n_min = min(len(model_ll), len(baseline_ll))
            if n_min < 10:
                print(f"  {exp_name}: insufficient samples ({n_min})")
                continue
            
            # Positive difference = model assigns higher likelihood than baseline.
            diff = model_ll[:n_min] - baseline_ll[:n_min]
            
            # A failing test is reported on its own line instead of aborting
            # the whole comparison.
            try:
                stat, p_value = stats.wilcoxon(diff)
                mean_diff = diff.mean()
                direction = "better" if mean_diff > 0 else "worse"
                sig = "***" if p_value < 0.001 else "**" if p_value < 0.01 else "*" if p_value < 0.05 else "n.s."
                print(f"  {exp_name:<30} diff={mean_diff:+.4f} ({direction}) p={p_value:.4f} {sig}")
            except Exception as e:
                print(f"  {exp_name:<30} test failed: {e}")

print(f"\n{'='*90}")

## 10. Save Results

In [None]:
# ============================================
# SAVE ALL RESULTS
# ============================================
import os

# Prepare serializable results
def _to_jsonable(value):
    """Recursively convert numpy arrays/scalars inside ``value`` to plain Python objects."""
    if isinstance(value, dict):
        # "model_matrix" entries are large (n x n) and deliberately not saved.
        return {
            key: _to_jsonable(item)
            for key, item in value.items()
            if key != "model_matrix"
        }
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.bool_):
        return bool(value)
    if isinstance(value, np.integer):
        return int(value)  # keep counts integral instead of turning them into floats
    if isinstance(value, np.floating):
        return float(value)
    return value


def make_serializable(results):
    """
    Convert an evaluation-results dict into JSON-serialisable Python objects.

    Top-level keys (the horizons k) are converted to strings. Below that,
    dicts are walked at any depth: numpy arrays become lists, numpy scalars
    become the matching Python scalar, and "model_matrix" entries are dropped.

    Args:
        results: dict as returned by evaluate_multistep

    Returns:
        A new dict; ``results`` is not modified.
    """
    return {str(key): _to_jsonable(value) for key, value in results.items()}


# Save per-experiment results: one JSON file per saved experiment holding the
# distributional metrics and, when available, the per-sample CK summary.
ck_summary_fields = ("mean_ll", "std_ll", "accuracy", "severity", "n_samples")

for exp_name in all_experiments:
    if exp_name in ["Stationary (train)", "Marginal"]:
        continue

    exp_dir = results_dir / exp_name / "metrics"
    os.makedirs(exp_dir, exist_ok=True)

    save_data = {
        "k_values": K_VALUES,
        "distributional_metrics": make_serializable(results_all_experiments[exp_name]),
    }

    if exp_name in ck_results_all:
        # Only scalar summaries are stored; per-sample arrays are left out.
        save_data["chapman_kolmogorov"] = {
            str(k): {field: v[field] for field in ck_summary_fields}
            for k, v in ck_results_all[exp_name].items()
        }

    out_path = exp_dir / "multistep_evaluation.json"
    with open(out_path, "w") as f:
        json.dump(save_data, f, indent=2)
    print(f"Saved: {out_path}")

# Save comparison summary: one CSV row per (model, horizon), baselines included.
summary_rows = [
    {
        "Model": exp_name,
        "k": k,
        "Mean_LL": ck_res[k]["mean_ll"],
        "Accuracy": ck_res[k]["accuracy"],
        "Severity": ck_res[k]["severity"],
        "N_Samples": ck_res[k]["n_samples"],
    }
    for exp_name, ck_res in ck_results_all.items()
    for k in sorted(ck_res.keys())
]

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv("results/multistep_summary.csv", index=False)
print("\nSaved: results/multistep_summary.csv")
print("\nAll figures saved to results/multistep_*.png")
print("\nDone!")