In [None]:
# HRET Empirical Validation with Monte Carlo Re-Entry Dispersion

## 1. Setup Environment

In [None]:
!apt update && apt install -y curl build-essential pkg-config libssl-dev
!curl https://sh.rustup.rs -sSf | sh -s -- -y
import os
import sys
import shutil
import subprocess
import json
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass
os.environ['PATH'] += f":{Path.home()}/.cargo/bin"
!{sys.executable} -m pip install --upgrade pip wheel maturin seaborn ipywidgets scipy plotly kaleido
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from ipywidgets import interact, FloatSlider
from scipy.integrate import solve_ivp
sns.set(style="whitegrid")
GLOBAL_SEED = 2026
np.random.seed(GLOBAL_SEED)


In [None]:
print('Python:', sys.version)
print('maturin:', shutil.which('maturin'))

def ensure_plotly_png_export():
    try:
        _ = go.Figure().to_image(format='png', width=16, height=16, scale=1)
        print('Plotly PNG export is ready.')
        return True
    except Exception as first_error:
        print(f'Plotly PNG export not ready yet: {first_error}')
        try:
            import kaleido
            chrome_path = kaleido.get_chrome_sync()
            print(f'Installed Chrome for Kaleido at: {chrome_path}')
            _ = go.Figure().to_image(format='png', width=16, height=16, scale=1)
            print('Plotly PNG export is ready after installing Chrome.')
            return True
        except Exception as second_error:
            print(f'WARNING: Could not enable Plotly PNG export: {second_error}')
            print('Falling back to HTML exports if PNG export fails.')
            return False

PLOTLY_PNG_READY = ensure_plotly_png_export()

In [None]:
## 2. Clone Repo and Build/Install Crate

In [None]:
repo_url = 'https://github.com/infinityabundance/dsfb.git'
workspace_root = Path.cwd().resolve()
local_crate = workspace_root if (workspace_root / 'Cargo.toml').exists() and (workspace_root / 'src' / 'lib.rs').exists() else None

if local_crate is not None:
    crate_root = local_crate
else:
    repo_root = workspace_root / 'dsfb'
    crate_root = repo_root / 'crates' / 'dsfb-hret'
    if repo_root.exists():
        shutil.rmtree(repo_root)
    subprocess.run(['git', 'clone', '--depth', '1', repo_url, str(repo_root)], check=True)

env = os.environ.copy()
env['PYO3_USE_ABI3_FORWARD_COMPATIBILITY'] = '1'
subprocess.run([
    'maturin',
    'build',
    '--release',
    '--manifest-path', str(crate_root / 'Cargo.toml'),
    '--out', str(crate_root / 'dist'),
    '--quiet',
], check=True, env=env)

wheels = sorted((crate_root / 'dist').glob('dsfb_hret-*.whl'))
if not wheels:
    raise RuntimeError('No dsfb_hret wheel produced by maturin build')

subprocess.run([
    sys.executable,
    '-m',
    'pip',
    'install',
    '--force-reinstall',
    str(wheels[-1]),
], check=True)

import importlib
dsfb_hret = importlib.import_module('dsfb_hret')
print('Imported OK from', dsfb_hret.__file__)

In [None]:
## 3. Toy Validation: Constant Velocity with Correlated Faults
# Test HRET on simple system (eq. 1: x[n+1] = x[n] + dt*v).

In [None]:
def toy_sim(N=500, dt=0.1, v=1.0, D_k=0.1, correlated_group=0, dist_start=200, dist_mag=5.0):
    m, g = 10, 2
    group_map = [0]*5 + [1]*5
    group_map_arr = np.array(group_map)
    rho, rho_g = 0.95, [0.9, 0.9]
    beta_k, beta_g = [1.0]*m, [1.0]*g
    k_k = [[0.5]*m]  # p=1, simple gains
    observer = dsfb_hret.HretObserver(m, g, group_map, rho, rho_g, beta_k, beta_g, k_k)
    true_x = np.cumsum(np.ones(N)*dt*v)
    hat_x = np.zeros(N); hat_x[0] = true_x[0] + np.random.randn()
    errors, weights_log = [], np.zeros((N, m))
    for n in range(1, N):
        d_k = np.random.uniform(-D_k, D_k, m)
        if dist_start <= n < dist_start + 100:
            d_k[group_map_arr == correlated_group] += dist_mag if np.random.rand() > 0.5 else -dist_mag
        y_k = true_x[n] + d_k
        r_k = y_k - hat_x[n-1]
        delta_x, weights, s_k, s_g = observer.update(r_k.tolist())
        hat_x[n] = hat_x[n-1] + dt*v + delta_x[0]
        errors.append(hat_x[n] - true_x[n])
        weights_log[n] = weights
    return true_x, hat_x, np.array(errors), weights_log

true_x, hat_x, errors, weights = toy_sim()
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
axs[0].plot(true_x, label='True'); axs[0].plot(hat_x, label='Estimate'); axs[0].legend()
axs[0].set_title(f'State Estimation (RMSE: {np.sqrt(np.mean(errors**2)):.3f})')
axs[1].imshow(weights.T, aspect='auto', cmap='viridis'); axs[1].set_title('Weights Evolution')
plt.show()

In [None]:
## 4. Re-Entry Model and HRET Fusion
# 3-DoF: [x, y, vx, vy, theta]. Exponential atmosphere, drag/gravity.

In [None]:
def reentry_dynamics(t, state, params):
    rho0, H, g0, R, m, A, Cd = params['rho0'], params['H'], params['g0'], params['R'], params['m_vehicle'], params['A'], params['Cd']
    x, y, vx, vy, theta = state
    h = y
    g = g0 * (R / (R + h))**2 if h > 0 else g0
    v = np.hypot(vx, vy)
    v_safe = max(v, 1e-9)
    rho = rho0 * np.exp(-h / H) if h > 0 else rho0
    Fd = 0.5 * rho * v**2 * Cd * A
    ax = -(Fd / m) * (vx / v_safe)
    ay = -g - (Fd / m) * (vy / v_safe)
    dtheta_dt = -(g / v_safe) * np.cos(theta) + (v / (R + h)) * np.cos(theta)
    return np.array([vx, vy, ax, ay, dtheta_dt], dtype=float)


def propagate_state(state, dt, params):
    return state + dt * reentry_dynamics(0.0, state, params)


def build_sensor_matrix(params):
    # Two channels per state component: channels 0..4 (group 0), 5..9 (group 1)
    p = params['state_dim']
    M = params['M']
    H = np.zeros((M, p), dtype=float)
    for i in range(p):
        H[i, i] = 1.0
        H[i + p, i] = 1.0
    return H


def make_disturbance_scale(dist_mag):
    return np.array([
        dist_mag,
        dist_mag,
        0.20 * dist_mag,
        0.20 * dist_mag,
        np.deg2rad(10.0),
    ], dtype=float)


def default_params(rho=0.95, rho_g=0.85, beta_g=4.0, dist_mag=1500.0):
    p = 5
    M = 10
    k_k = np.zeros((p, M), dtype=float)
    for i in range(p):
        k_k[i, i] = 0.5
        k_k[i, i + p] = 0.5

    group_map = [0] * p + [1] * p
    return {
        'rho0': 1.225,
        'H': 11000.0,
        'g0': 9.81,
        'R': 6371000.0,
        'm_vehicle': 3000.0,
        'A': 20.0,
        'Cd': 1.0,
        'state_dim': p,
        'M': M,
        'G': 2,
        'group_map': group_map,
        'group_map_arr': np.array(group_map, dtype=int),
        'rho': float(rho),
        'rho_g': [float(rho_g), float(rho_g)],
        'beta_k': [1.0] * M,
        'beta_g': [float(beta_g), float(beta_g)],
        'k_k': k_k.tolist(),
        # Stronger correlated-disturbance scenario for HRET vs DSFB validation.
        'D_k': 100.0,
        'channel_noise_scale': np.array([80.0, 80.0, 15.0, 15.0, 0.03, 80.0, 80.0, 15.0, 15.0, 0.03], dtype=float),
        'correlated_group': 1,
        'dist_mag': float(dist_mag),
        'disturbance_scale': make_disturbance_scale(float(dist_mag)),
        'plasma_start': 80,
        'plasma_end': 320,
        'initial_state': np.array([0.0, 120000.0, 7800.0, -500.0, -np.pi / 90.0], dtype=float),
        'perturbations': np.array([1000.0, 1000.0, 100.0, 100.0, np.pi / 36.0], dtype=float),
    }


def build_params(param_overrides=None):
    params = default_params()
    explicit_disturbance_scale = param_overrides is not None and 'disturbance_scale' in param_overrides
    if param_overrides:
        params.update(param_overrides)

    params['rho'] = float(params['rho'])

    rho_g = params.get('rho_g', [0.85, 0.85])
    if np.isscalar(rho_g):
        rho_g = [float(rho_g)] * params['G']
    params['rho_g'] = list(np.asarray(rho_g, dtype=float))

    beta_g = params.get('beta_g', [4.0, 4.0])
    if np.isscalar(beta_g):
        beta_g = [float(beta_g)] * params['G']
    params['beta_g'] = list(np.asarray(beta_g, dtype=float))

    params['initial_state'] = np.asarray(params['initial_state'], dtype=float)
    params['perturbations'] = np.asarray(params['perturbations'], dtype=float)
    params['channel_noise_scale'] = np.asarray(params.get('channel_noise_scale'), dtype=float)

    dist_mag = float(params.get('dist_mag', 1500.0))
    params['dist_mag'] = dist_mag
    if explicit_disturbance_scale:
        params['disturbance_scale'] = np.asarray(params['disturbance_scale'], dtype=float)
    else:
        params['disturbance_scale'] = make_disturbance_scale(dist_mag)

    params['group_map_arr'] = np.asarray(params['group_map'], dtype=int)

    return params


@dataclass
class TrialData:
    t: np.ndarray
    true_states: np.ndarray
    sensor_noise: np.ndarray
    group_disturbance: np.ndarray
    plasma_mask: np.ndarray
    initial_estimate_offset: np.ndarray


def make_trial_data(params, seed):
    rng = np.random.default_rng(seed)
    initial_state = params['initial_state'].copy()
    initial_state += rng.uniform(-params['perturbations'], params['perturbations'])

    ground_event = lambda t, y, _params: y[1]
    ground_event.terminal = True
    ground_event.direction = -1

    sol = solve_ivp(
        reentry_dynamics,
        (0.0, 600.0),
        initial_state,
        args=(params,),
        rtol=1e-6,
        max_step=1.0,
        events=ground_event,
    )

    true_states = sol.y.T
    n_steps = len(sol.t) - 1

    sensor_noise = rng.uniform(-1.0, 1.0, size=(n_steps, params['M'])) * params['channel_noise_scale']

    group_disturbance = rng.uniform(-1.0, 1.0, size=(n_steps, params['state_dim'])) * params['disturbance_scale']

    plasma_mask = np.array([
        params['plasma_start'] <= (step + 1) < params['plasma_end']
        for step in range(n_steps)
    ])
    initial_estimate_offset = rng.uniform(-100.0, 100.0, size=params['state_dim'])

    return TrialData(
        t=sol.t,
        true_states=true_states,
        sensor_noise=sensor_noise,
        group_disturbance=group_disturbance,
        plasma_mask=plasma_mask,
        initial_estimate_offset=initial_estimate_offset,
    )


class SimpleEKF:
    def __init__(self, p, process_var=5.0, meas_var=100.0):
        self.P = np.eye(p) * 100.0
        self.Q = np.eye(p) * process_var
        self.meas_var = meas_var

    def predict(self):
        self.P = self.P + self.Q

    def update(self, residual, H):
        R = np.eye(H.shape[0]) * self.meas_var
        S = H @ self.P @ H.T + R
        K = self.P @ H.T @ np.linalg.inv(S)
        delta = K @ residual
        I = np.eye(self.P.shape[0])
        self.P = (I - K @ H) @ self.P
        return delta


def make_observer(params, mode):
    if mode == 'hret':
        return dsfb_hret.HretObserver(
            params['M'], params['G'], params['group_map'],
            params['rho'], params['rho_g'], params['beta_k'], params['beta_g'], params['k_k']
        )

    if mode == 'dsfb':
        # DSFB-style baseline: singleton groups with neutral group trust layer.
        return dsfb_hret.HretObserver(
            params['M'], params['M'], list(range(params['M'])),
            params['rho'], [params['rho']] * params['M'], params['beta_k'], [0.0] * params['M'], params['k_k']
        )

    raise ValueError(f'Unsupported observer mode: {mode}')


def simulate_mode(params, trial_data, mode):
    H = build_sensor_matrix(params)
    est_state = trial_data.true_states[0] + trial_data.initial_estimate_offset
    state_errors = []

    if mode in ('hret', 'dsfb'):
        observer = make_observer(params, mode)
    elif mode == 'ekf':
        ekf = SimpleEKF(params['state_dim'])
    else:
        raise ValueError(f'Unknown mode: {mode}')

    disturbed_channels = np.where(params['group_map_arr'] == params['correlated_group'])[0]

    for step in range(1, len(trial_data.t)):
        dt = trial_data.t[step] - trial_data.t[step - 1]
        true_state = trial_data.true_states[step]
        pred_state = propagate_state(est_state, dt, params)

        sensor_noise = trial_data.sensor_noise[step - 1].copy()
        if trial_data.plasma_mask[step - 1] and params['correlated_group'] is not None:
            if len(disturbed_channels) > 0:
                disturbance_vec = trial_data.group_disturbance[step - 1]
                sensor_noise[disturbed_channels] += disturbance_vec[: len(disturbed_channels)]

        measurement = H @ true_state + sensor_noise
        residual = measurement - H @ pred_state

        if mode in ('hret', 'dsfb'):
            delta, _, _, _ = observer.update(residual.tolist())
            correction = np.asarray(delta, dtype=float)
        else:
            ekf.predict()
            correction = ekf.update(residual, H)

        est_state = pred_state + correction
        state_errors.append(est_state - true_state)

    state_errors = np.asarray(state_errors, dtype=float)

    error_norm = np.linalg.norm(state_errors, axis=1)
    rmse = float(np.sqrt(np.mean(error_norm**2)))

    per_state_rmse = np.sqrt(np.mean(state_errors**2, axis=0))
    true_state_window = trial_data.true_states[1:]
    state_range = np.maximum(np.max(true_state_window, axis=0) - np.min(true_state_window, axis=0), 1e-6)
    nrmse = per_state_rmse / state_range

    impact = est_state[:2] / 1000.0

    return {
        'rmse': rmse,
        'nrmse': nrmse,
        'nrmse_mean': float(np.mean(nrmse)),
        'impact': impact,
    }




In [None]:
## 5. Monte Carlo Dispersion Simulation
# x360 Monte Carlo comparison: HRET vs DSFB-style vs EKF baseline.

In [None]:
MODE_LABEL = {
    'hret': 'HRET',
    'dsfb': 'DSFB (singleton)',
    'ekf': 'EKF baseline',
}
MODE_COLOR = {
    'hret': '#1f77b4',
    'dsfb': '#ff7f0e',
    'ekf': '#2ca02c',
}
STATE_LABEL = ['x', 'y', 'vx', 'vy', 'theta']


def cep50(impacts):
    center = np.mean(impacts, axis=0)
    radial = np.linalg.norm(impacts - center, axis=1)
    return float(np.median(radial))


def bootstrap_mean_ci(samples, alpha=0.05, n_boot=2000, seed=GLOBAL_SEED):
    samples = np.asarray(samples, dtype=float)
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(samples), size=(n_boot, len(samples)))
    boot_means = np.mean(samples[idx], axis=1)
    lower = float(np.quantile(boot_means, alpha / 2))
    upper = float(np.quantile(boot_means, 1 - alpha / 2))
    return lower, upper


def paired_permutation_pvalue(delta, n_perm=20000, seed=GLOBAL_SEED):
    delta = np.asarray(delta, dtype=float)
    rng = np.random.default_rng(seed)
    observed = abs(np.mean(delta))
    signs = rng.choice(np.array([-1.0, 1.0]), size=(n_perm, len(delta)))
    perm_means = np.mean(signs * delta, axis=1)
    p_value = float((np.sum(np.abs(perm_means) >= observed) + 1) / (n_perm + 1))
    return p_value


def paired_effect_stats(metrics, mode_a='hret', mode_b='dsfb', key='rmse', ci_seed=GLOBAL_SEED):
    a = np.asarray(metrics[mode_a][key], dtype=float)
    b = np.asarray(metrics[mode_b][key], dtype=float)

    if a.ndim > 1:
        a = np.mean(a, axis=1)
        b = np.mean(b, axis=1)

    delta = a - b
    mean_delta = float(np.mean(delta))
    std_delta = float(np.std(delta, ddof=1))
    effect_size_dz = float(mean_delta / std_delta) if std_delta > 0 else float('nan')
    ci_low, ci_high = bootstrap_mean_ci(delta, seed=ci_seed)
    p_value = paired_permutation_pvalue(delta, n_perm=20000, seed=ci_seed + 17)

    return {
        'mode_a': mode_a,
        'mode_b': mode_b,
        'key': key,
        'mean_delta': mean_delta,
        'ci95': (ci_low, ci_high),
        'effect_size_dz': effect_size_dz,
        'p_value': p_value,
        'prob_a_better': float(np.mean(delta < 0)),
        'delta_samples': delta,
    }


def find_repo_root(start_path):
    start_path = Path(start_path).resolve()
    git_roots = [candidate for candidate in (start_path, *start_path.parents) if (candidate / '.git').exists()]
    if git_roots:
        # Prefer outermost git root to avoid nested output paths on repeated notebook reruns.
        return git_roots[-1]
    return start_path


def make_run_output_dir(base_dir):
    base_dir = Path(base_dir)
    base_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = base_dir / stamp
    suffix = 1
    while run_dir.exists():
        run_dir = base_dir / f"{stamp}_{suffix:02d}"
        suffix += 1
    run_dir.mkdir(parents=True, exist_ok=False)
    return run_dir


def run_mc(n_trials=360, seed=GLOBAL_SEED, modes=('hret', 'dsfb', 'ekf'), param_overrides=None):
    params = build_params(param_overrides)
    root_rng = np.random.default_rng(seed)
    trial_seeds = root_rng.integers(0, np.iinfo(np.uint32).max, size=n_trials, dtype=np.uint32)

    metrics = {
        mode: {'rmse': [], 'nrmse': [], 'nrmse_mean': [], 'impact': []}
        for mode in modes
    }

    for trial_seed in trial_seeds:
        trial = make_trial_data(params, int(trial_seed))
        for mode in modes:
            result = simulate_mode(params, trial, mode)
            metrics[mode]['rmse'].append(result['rmse'])
            metrics[mode]['nrmse'].append(result['nrmse'])
            metrics[mode]['nrmse_mean'].append(result['nrmse_mean'])
            metrics[mode]['impact'].append(result['impact'])

    for mode in modes:
        metrics[mode]['rmse'] = np.asarray(metrics[mode]['rmse'], dtype=float)
        metrics[mode]['nrmse'] = np.asarray(metrics[mode]['nrmse'], dtype=float)
        metrics[mode]['nrmse_mean'] = np.asarray(metrics[mode]['nrmse_mean'], dtype=float)
        metrics[mode]['impact'] = np.asarray(metrics[mode]['impact'], dtype=float)

    return metrics


def summarize_metrics(metrics):
    print('Mode                     Mean RMSE [m]   Std RMSE [m]   Mean NRMSE   CEP50 [km]')
    print('-------------------------------------------------------------------------------')
    for mode, data in metrics.items():
        rmse = data['rmse']
        impacts = data['impact']
        nrmse_mean = data['nrmse_mean']
        print(
            f"{MODE_LABEL[mode]:<24} {np.mean(rmse):>12.2f} {np.std(rmse):>14.2f} {np.mean(nrmse_mean):>11.4f} {cep50(impacts):>11.2f}"
        )


def summarize_nrmse_by_state(metrics):
    print()
    print('Per-state NRMSE means:')
    print('Mode                     ' + '  '.join([f"{name:>8}" for name in STATE_LABEL]))
    print('-' * 24 + ' ' + '-' * (10 * len(STATE_LABEL)))
    for mode, data in metrics.items():
        state_mean = np.mean(data['nrmse'], axis=0)
        state_str = '  '.join([f"{v:8.4f}" for v in state_mean])
        print(f"{MODE_LABEL[mode]:<24} {state_str}")


def summarize_paired_effect(metrics, mode_a='hret', mode_b='dsfb'):
    rmse_stats = paired_effect_stats(metrics, mode_a=mode_a, mode_b=mode_b, key='rmse', ci_seed=GLOBAL_SEED + 5)
    nrmse_stats = paired_effect_stats(metrics, mode_a=mode_a, mode_b=mode_b, key='nrmse_mean', ci_seed=GLOBAL_SEED + 7)

    print()
    print(f"Paired effect ({MODE_LABEL[mode_a]} - {MODE_LABEL[mode_b]}):")
    print(
        f"RMSE delta mean [m]: {rmse_stats['mean_delta']:.2f} "
        f"(95% CI {rmse_stats['ci95'][0]:.2f} .. {rmse_stats['ci95'][1]:.2f}), "
        f"p={rmse_stats['p_value']:.4g}, dz={rmse_stats['effect_size_dz']:.3f}, "
        f"P({MODE_LABEL[mode_a]} better)={rmse_stats['prob_a_better']:.3f}"
    )
    print(
        f"Mean-NRMSE delta: {nrmse_stats['mean_delta']:.4f} "
        f"(95% CI {nrmse_stats['ci95'][0]:.4f} .. {nrmse_stats['ci95'][1]:.4f}), "
        f"p={nrmse_stats['p_value']:.4g}, dz={nrmse_stats['effect_size_dz']:.3f}, "
        f"P({MODE_LABEL[mode_a]} better)={nrmse_stats['prob_a_better']:.3f}"
    )

    return rmse_stats, nrmse_stats


def disturbance_overrides(dist_mag, duration_steps, rho_g, beta_g, rho=0.95, plasma_start=70):
    return {
        'rho': float(rho),
        'rho_g': [float(rho_g), float(rho_g)],
        'beta_g': [float(beta_g), float(beta_g)],
        'dist_mag': float(dist_mag),
        'disturbance_scale': make_disturbance_scale(float(dist_mag)),
        'plasma_start': int(plasma_start),
        'plasma_end': int(plasma_start + duration_steps),
    }


def run_disturbance_sweep(
    dist_mag_grid=(1200.0, 1800.0),
    duration_grid=(160, 260),
    beta_g_grid=(2.0, 5.0),
    rho_g_grid=(0.80, 0.92),
    rho=0.95,
    n_trials=80,
    seed=GLOBAL_SEED,
):
    rows = []

    for dist_mag in dist_mag_grid:
        for duration in duration_grid:
            for beta_g in beta_g_grid:
                for rho_g in rho_g_grid:
                    overrides = disturbance_overrides(
                        dist_mag=dist_mag,
                        duration_steps=duration,
                        rho_g=rho_g,
                        beta_g=beta_g,
                        rho=rho,
                        plasma_start=70,
                    )
                    local_metrics = run_mc(
                        n_trials=n_trials,
                        seed=seed + int(dist_mag) + int(duration) + int(10 * beta_g) + int(100 * rho_g),
                        modes=('hret', 'dsfb'),
                        param_overrides=overrides,
                    )
                    effect_rmse = paired_effect_stats(local_metrics, mode_a='hret', mode_b='dsfb', key='rmse', ci_seed=seed + 123)
                    effect_nrmse = paired_effect_stats(local_metrics, mode_a='hret', mode_b='dsfb', key='nrmse_mean', ci_seed=seed + 223)

                    rows.append({
                        'dist_mag': float(dist_mag),
                        'duration_steps': int(duration),
                        'beta_g': float(beta_g),
                        'rho_g': float(rho_g),
                        'delta_rmse_mean_m': effect_rmse['mean_delta'],
                        'delta_rmse_ci_low_m': effect_rmse['ci95'][0],
                        'delta_rmse_ci_high_m': effect_rmse['ci95'][1],
                        'delta_rmse_p_value': effect_rmse['p_value'],
                        'delta_nrmse_mean': effect_nrmse['mean_delta'],
                        'delta_nrmse_p_value': effect_nrmse['p_value'],
                        'effect_size_dz': effect_rmse['effect_size_dz'],
                        'hret_better_prob': effect_rmse['prob_a_better'],
                    })

    rows.sort(key=lambda r: (r['delta_rmse_mean_m'], r['delta_nrmse_mean']))
    return rows


def print_sweep_table(rows, top_k=10):
    print()
    print('Disturbance sweep ranking (lower delta_rmse_mean_m is better for HRET):')
    print(' rank  dist_mag  dur   beta_g  rho_g  delta_mean[m]      CI95[m]       p-value    dz   P(HRET better)')
    print('---------------------------------------------------------------------------------------------------------')
    for idx, row in enumerate(rows[:top_k], start=1):
        print(
            f" {idx:>4}  {row['dist_mag']:>8.0f}  {row['duration_steps']:>3d}   {row['beta_g']:>5.1f}  {row['rho_g']:>4.2f}"
            f"    {row['delta_rmse_mean_m']:>10.2f}"
            f"   [{row['delta_rmse_ci_low_m']:>7.2f}, {row['delta_rmse_ci_high_m']:>7.2f}]"
            f"   {row['delta_rmse_p_value']:>8.3g}  {row['effect_size_dz']:>6.3f}      {row['hret_better_prob']:.3f}"
        )


def write_plotly_artifact(fig, png_path, width, height, scale=2):
    png_path = Path(png_path)
    html_path = png_path.with_suffix('.html')
    fig.write_html(str(html_path), include_plotlyjs='cdn')
    try:
        fig.write_image(str(png_path), width=width, height=height, scale=scale)
        return True, png_path, html_path, None
    except Exception as export_error:
        return False, None, html_path, str(export_error)


def export_plotly_results(metrics, output_dir, effect_rmse=None, effect_nrmse=None, sweep_rows=None, run_metadata=None):
    output_dir = Path(output_dir)
    mode_order = list(metrics.keys())
    exported_pngs = []
    exported_html = []
    export_errors = []

    fig_rmse = go.Figure()
    for mode in mode_order:
        fig_rmse.add_trace(go.Histogram(
            x=metrics[mode]['rmse'],
            nbinsx=60,
            name=MODE_LABEL[mode],
            opacity=0.55,
            marker_color=MODE_COLOR[mode],
        ))
    fig_rmse.update_layout(
        barmode='overlay',
        template='plotly_white',
        title='RMSE Distribution (x360 Monte Carlo)',
        xaxis_title='RMSE [m]',
        yaxis_title='Count',
    )
    ok, png_path, html_path, error = write_plotly_artifact(
        fig_rmse, output_dir / 'rmse_distribution_plotly.png', width=1400, height=800, scale=2
    )
    exported_html.append(html_path)
    if ok:
        exported_pngs.append(png_path)
    else:
        export_errors.append(('rmse_distribution_plotly.png', error))

    fig_impact = go.Figure()
    for mode in mode_order:
        impacts = metrics[mode]['impact']
        fig_impact.add_trace(go.Scatter(
            x=impacts[:, 0],
            y=impacts[:, 1],
            mode='markers',
            name=MODE_LABEL[mode],
            marker=dict(color=MODE_COLOR[mode], size=6, opacity=0.45),
        ))
    fig_impact.update_layout(
        template='plotly_white',
        title='Impact Dispersion',
        xaxis_title='Downrange [km]',
        yaxis_title='Crossrange [km]',
    )
    ok, png_path, html_path, error = write_plotly_artifact(
        fig_impact, output_dir / 'impact_dispersion_plotly.png', width=1400, height=800, scale=2
    )
    exported_html.append(html_path)
    if ok:
        exported_pngs.append(png_path)
    else:
        export_errors.append(('impact_dispersion_plotly.png', error))

    labels = [MODE_LABEL[mode] for mode in mode_order]
    mean_rmse = [float(np.mean(metrics[mode]['rmse'])) for mode in mode_order]
    std_rmse = [float(np.std(metrics[mode]['rmse'])) for mode in mode_order]
    fig_summary = go.Figure(go.Bar(
        x=labels,
        y=mean_rmse,
        marker_color=[MODE_COLOR[mode] for mode in mode_order],
        error_y=dict(type='data', array=std_rmse, visible=True),
    ))
    fig_summary.update_layout(
        template='plotly_white',
        title='Mean RMSE by Method',
        yaxis_title='Mean RMSE [m]',
    )
    ok, png_path, html_path, error = write_plotly_artifact(
        fig_summary, output_dir / 'rmse_summary_plotly.png', width=1200, height=700, scale=2
    )
    exported_html.append(html_path)
    if ok:
        exported_pngs.append(png_path)
    else:
        export_errors.append(('rmse_summary_plotly.png', error))

    fig_nrmse = go.Figure()
    for mode in mode_order:
        state_nrmse = np.mean(metrics[mode]['nrmse'], axis=0)
        fig_nrmse.add_trace(go.Bar(
            x=STATE_LABEL,
            y=state_nrmse,
            name=MODE_LABEL[mode],
            marker_color=MODE_COLOR[mode],
        ))
    fig_nrmse.update_layout(
        barmode='group',
        template='plotly_white',
        title='Per-State Mean NRMSE',
        yaxis_title='NRMSE',
    )
    ok, png_path, html_path, error = write_plotly_artifact(
        fig_nrmse, output_dir / 'nrmse_by_state_plotly.png', width=1300, height=800, scale=2
    )
    exported_html.append(html_path)
    if ok:
        exported_pngs.append(png_path)
    else:
        export_errors.append(('nrmse_by_state_plotly.png', error))

    summary = {
        mode: {
            'mean_rmse_m': float(np.mean(metrics[mode]['rmse'])),
            'std_rmse_m': float(np.std(metrics[mode]['rmse'])),
            'mean_nrmse': float(np.mean(metrics[mode]['nrmse_mean'])),
            'cep50_km': cep50(metrics[mode]['impact']),
            'mean_nrmse_by_state': {
                STATE_LABEL[i]: float(np.mean(metrics[mode]['nrmse'][:, i]))
                for i in range(len(STATE_LABEL))
            },
        }
        for mode in mode_order
    }

    if effect_rmse is not None:
        summary['paired_effect_rmse_hret_minus_dsfb'] = {
            'mean_delta_m': effect_rmse['mean_delta'],
            'ci95_m': list(effect_rmse['ci95']),
            'p_value': effect_rmse['p_value'],
            'effect_size_dz': effect_rmse['effect_size_dz'],
            'hret_better_probability': effect_rmse['prob_a_better'],
        }
    if effect_nrmse is not None:
        summary['paired_effect_nrmse_hret_minus_dsfb'] = {
            'mean_delta': effect_nrmse['mean_delta'],
            'ci95': list(effect_nrmse['ci95']),
            'p_value': effect_nrmse['p_value'],
            'effect_size_dz': effect_nrmse['effect_size_dz'],
            'hret_better_probability': effect_nrmse['prob_a_better'],
        }
    if sweep_rows is not None:
        summary['disturbance_sweep_top5'] = sweep_rows[:5]
    if run_metadata is not None:
        summary['run_metadata'] = run_metadata

    (output_dir / 'metrics_summary.json').write_text(json.dumps(summary, indent=2))

    np.savez(
        output_dir / 'metrics_arrays.npz',
        **{f"{mode}_rmse": metrics[mode]['rmse'] for mode in mode_order},
        **{f"{mode}_impact": metrics[mode]['impact'] for mode in mode_order},
        **{f"{mode}_nrmse": metrics[mode]['nrmse'] for mode in mode_order},
        **{f"{mode}_nrmse_mean": metrics[mode]['nrmse_mean'] for mode in mode_order},
    )

    if sweep_rows is not None:
        (output_dir / 'disturbance_sweep_rows.json').write_text(json.dumps(sweep_rows, indent=2))

    return exported_pngs, exported_html, export_errors


repo_root = find_repo_root(Path.cwd())
output_root = repo_root / 'output-dsfb-hret'
run_output_dir = make_run_output_dir(output_root)
print('Run output directory:', run_output_dir)

print('Empirical protocol: paired trials, shared seeds per trial, no post-hoc filtering or score manipulation.')

# 1) Disturbance sweep over requested dimensions.
sweep_rows = run_disturbance_sweep(
    dist_mag_grid=(1200.0, 1800.0),
    duration_grid=(160, 260),
    beta_g_grid=(2.0, 5.0),
    rho_g_grid=(0.80, 0.92),
    rho=0.95,
    n_trials=80,
    seed=GLOBAL_SEED + 101,
)
print_sweep_table(sweep_rows, top_k=8)

best = sweep_rows[0]
best_overrides = disturbance_overrides(
    dist_mag=best['dist_mag'],
    duration_steps=best['duration_steps'],
    rho_g=best['rho_g'],
    beta_g=best['beta_g'],
    rho=0.95,
    plasma_start=70,
)
print()
print('Using best sweep parameters for x360 run:')
print(best_overrides)

# 2) Final x360 evaluation with paired CI + p-value reporting.
metrics = run_mc(n_trials=360, seed=GLOBAL_SEED + 202, param_overrides=best_overrides)
summarize_metrics(metrics)
summarize_nrmse_by_state(metrics)
effect_rmse, effect_nrmse = summarize_paired_effect(metrics, mode_a='hret', mode_b='dsfb')

run_metadata = {
    'global_seed': int(GLOBAL_SEED),
    'sweep_seed': int(GLOBAL_SEED + 101),
    'final_eval_seed': int(GLOBAL_SEED + 202),
    'sweep_grid': {
        'dist_mag': [1200.0, 1800.0],
        'duration_steps': [160, 260],
        'beta_g': [2.0, 5.0],
        'rho_g': [0.80, 0.92],
        'rho': 0.95,
    },
    'best_overrides': {
        k: (v.tolist() if isinstance(v, np.ndarray) else v)
        for k, v in best_overrides.items()
    },
}

exported_pngs, exported_html, export_errors = export_plotly_results(
    metrics,
    run_output_dir,
    effect_rmse=effect_rmse,
    effect_nrmse=effect_nrmse,
    sweep_rows=sweep_rows,
    run_metadata=run_metadata,
)
print()
print('Exported Plotly PNG files:')
for path in exported_pngs:
    print(' -', path)
print('Exported Plotly HTML files:')
for path in exported_html:
    print(' -', path)
if export_errors:
    print('WARNING: Some PNG exports failed. HTML fallbacks were saved instead:')
    for name, error in export_errors:
        print(f' - {name}: {error}')

fig, axs = plt.subplots(1, 3, figsize=(18, 5))
for mode, data in metrics.items():
    sns.kdeplot(data['rmse'], ax=axs[0], fill=False, label=MODE_LABEL[mode], color=MODE_COLOR[mode])
axs[0].set_title('RMSE Distribution (x360 Monte Carlo)')
axs[0].set_xlabel('RMSE [m]')
axs[0].legend()

for mode, data in metrics.items():
    impacts = data['impact']
    axs[1].scatter(impacts[:, 0], impacts[:, 1], alpha=0.35, s=18, label=MODE_LABEL[mode], color=MODE_COLOR[mode])
axs[1].set_title('Impact Dispersion')
axs[1].set_xlabel('Downrange [km]')
axs[1].set_ylabel('Crossrange [km]')
axs[1].legend()

mode_pair = ('hret', 'dsfb')
x = np.arange(len(STATE_LABEL))
width = 0.38
axs[2].bar(x - width / 2, np.mean(metrics[mode_pair[0]]['nrmse'], axis=0), width=width, label=MODE_LABEL[mode_pair[0]], color=MODE_COLOR[mode_pair[0]])
axs[2].bar(x + width / 2, np.mean(metrics[mode_pair[1]]['nrmse'], axis=0), width=width, label=MODE_LABEL[mode_pair[1]], color=MODE_COLOR[mode_pair[1]])
axs[2].set_xticks(x, STATE_LABEL)
axs[2].set_title('Per-State Mean NRMSE')
axs[2].set_ylabel('NRMSE')
axs[2].legend()

plt.tight_layout()
plt.show()



In [None]:
## 6. Baselines and Sensitivity
# Compare to DSFB (singleton groups) and EKF. Sweep rho.

In [None]:
@interact(
    dist_mag=FloatSlider(min=800.0, max=2400.0, step=200.0, value=1800.0),
    duration_steps=FloatSlider(min=120, max=320, step=20, value=220),
    rho_g=FloatSlider(min=0.70, max=0.98, step=0.01, value=0.85),
    beta_g=FloatSlider(min=0.5, max=8.0, step=0.5, value=4.0),
)
def sensitivity(dist_mag, duration_steps, rho_g, beta_g):
    # Interactive local sweep point for HRET vs DSFB under grouped disturbance.
    overrides = disturbance_overrides(
        dist_mag=dist_mag,
        duration_steps=int(duration_steps),
        rho_g=rho_g,
        beta_g=beta_g,
        rho=0.95,
        plasma_start=70,
    )
    local_metrics = run_mc(
        n_trials=120,
        seed=GLOBAL_SEED + 303,
        modes=('hret', 'dsfb'),
        param_overrides=overrides,
    )
    summarize_metrics(local_metrics)
    summarize_nrmse_by_state(local_metrics)
    summarize_paired_effect(local_metrics, mode_a='hret', mode_b='dsfb')

    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    means = [np.mean(local_metrics[m]['rmse']) for m in ('hret', 'dsfb')]
    labels = [MODE_LABEL[m] for m in ('hret', 'dsfb')]
    colors = [MODE_COLOR[m] for m in ('hret', 'dsfb')]
    axs[0].bar(labels, means, color=colors)
    axs[0].set_ylabel('Mean RMSE [m]')
    axs[0].set_title('RMSE sensitivity (120 trials)')

    x = np.arange(len(STATE_LABEL))
    width = 0.38
    axs[1].bar(x - width / 2, np.mean(local_metrics['hret']['nrmse'], axis=0), width=width, label='HRET', color=MODE_COLOR['hret'])
    axs[1].bar(x + width / 2, np.mean(local_metrics['dsfb']['nrmse'], axis=0), width=width, label='DSFB (singleton)', color=MODE_COLOR['dsfb'])
    axs[1].set_xticks(x, STATE_LABEL)
    axs[1].set_ylabel('NRMSE')
    axs[1].set_title(f'Per-state NRMSE at dist={dist_mag:.0f}, dur={int(duration_steps)}, rho_g={rho_g:.2f}, beta_g={beta_g:.1f}')
    axs[1].legend()

    plt.tight_layout()
    plt.show()

