In [None]:
# ============================================================================
# SIMULATION VALIDATION NOTEBOOK - BY ODOR TYPE
# Purpose: Verify simulator matches real environment and produces realistic behavior
#          Separately for each odor type (Methyl Butyrate and Alpha-pinene)
# ============================================================================

# %% Section 1: Setup and Imports

import os

os.environ["JAX_PLATFORMS"] = "cpu"  # Force CPU to avoid Metal issues

import json
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Import simulator
from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.simulator import PatchForagingDDM_JAX, create_prior
from jax import random

# Set plotting style
plt.style.use("seaborn-v0_8-darkgrid")
%matplotlib inline

print(f"JAX backend: {jax.default_backend()}")
print(f"JAX version: {jax.__version__}")

# ============================================================================
# Load Real Data Statistics BY ODOR TYPE
# ============================================================================

# Root of the processed data. Defaults to the original local path; set the
# VR_FORAGING_DATA_DIR environment variable to run this notebook on another
# machine without editing the code.
base_path = Path(
    os.environ.get(
        "VR_FORAGING_DATA_DIR",
        "/Users/laura.driscoll/Documents/data/VR foraging/vr_foraging_data",
    )
).expanduser()

# Batch-processing summary (presumably one row per session); only rows whose
# "status" column equals "success" are analysed.
results_df = pd.read_csv(base_path / "batch_processing_by_odor_results.csv")
successful_sessions = results_df[results_df["status"] == "success"]

print(f"\n{'=' * 70}")
print("DATA SUMMARY")
print(f"{'=' * 70}")
print(f"Successfully processed sessions: {len(successful_sessions)}")

# Define odor types to analyze
odor_types = ["Methyl_Butyrate", "Alpha_pinene"]
odor_display_names = {"Methyl_Butyrate": "Methyl Butyrate", "Alpha_pinene": "Alpha-pinene"}


def _within_patch_distances(positions, stopped):
    """Return positive position steps between consecutive stopped sites.

    ``positions`` and ``stopped`` are 2-D arrays indexed ``[window, site]``.
    The step ``positions[w, i + 1] - positions[w, i]`` is kept when both sites
    have ``stopped == 1`` and the step is strictly positive. Results come out
    in row-major (window, then site) order.
    """
    same_patch = (stopped[:, :-1] == 1) & (stopped[:, 1:] == 1)
    steps = np.diff(positions, axis=1)
    return steps[same_patch & (steps > 0)]


def _completed_patch_runs(rewards, stopped):
    """Return ``(sites_per_patch, rewards_per_patch)`` for completed patches.

    A patch is a run of consecutive ``stopped == 1`` sites within one window
    that is terminated by a site with any other ``stopped`` value. A run that
    is still open at the end of a window is not counted.
    """
    sites_per_patch = []
    rewards_per_patch = []
    for window_rewards, window_stopped in zip(rewards, stopped):
        run_length = 0
        run_reward = 0
        for site_reward, site_stopped in zip(window_rewards, window_stopped):
            if site_stopped == 1:
                run_length += 1
                run_reward += site_reward
                continue
            # Leaving the patch closes the current run (if any)
            if run_length > 0:
                sites_per_patch.append(run_length)
                rewards_per_patch.append(run_reward)
            run_length = 0
            run_reward = 0
    return sites_per_patch, rewards_per_patch


def _stat_or_nan(values, reducer):
    """Apply ``reducer`` to ``values``, or return NaN when ``values`` is empty.

    Prevents ``np.min``/``np.max`` from raising for an odor with no usable data.
    """
    return reducer(values) if values.size else np.nan


# Collect statistics separately for each odor type
data_by_odor = {}

for odor_type in odor_types:
    print(f"\n{'=' * 70}")
    print(f"LOADING DATA FOR: {odor_display_names[odor_type]}")
    print(f"{'=' * 70}")

    interval_chunks = []
    all_exit_positions = []
    all_stop_rates = []
    all_reward_rates = []
    all_sites_per_patch = []
    all_rewards_per_patch = []

    n_sessions_with_odor = 0

    for session_dir in successful_sessions["session_dir"]:
        odor_dir = Path(session_dir) / "window_data_by_odor" / odor_type

        # Skip sessions that have no data directory for this odor
        if not odor_dir.exists():
            continue

        n_sessions_with_odor += 1

        with open(odor_dir / "metadata.json", "r", encoding="utf-8") as f:
            metadata = json.load(f)

        n_windows = metadata["n_windows"]
        if n_windows == 0:
            # Nothing to stack: an empty array lacks the [window, site, channel] axes used below
            continue
        session_windows = np.stack([np.load(odor_dir / f"window_{i:03d}.npy") for i in range(n_windows)])

        # Channel layout used throughout this notebook: 0 = position, 1 = reward, 2 = stopped
        positions = session_windows[:, :, 0]
        rewards = session_windows[:, :, 1]
        stopped = session_windows[:, :, 2]

        # Within-patch inter-site distances
        interval_chunks.append(_within_patch_distances(positions, stopped))

        # Exit positions: every site with stopped == 0
        all_exit_positions.extend(positions[stopped == 0].tolist())

        # Behavioral stats (one value per session)
        all_stop_rates.append(stopped.mean())
        all_reward_rates.append(rewards.mean())

        # Sites and rewards per completed patch
        patch_sites, patch_rewards = _completed_patch_runs(rewards, stopped)
        all_sites_per_patch.extend(patch_sites)
        all_rewards_per_patch.extend(patch_rewards)

    # Convert to arrays
    all_within_patch_distances = np.concatenate(interval_chunks) if interval_chunks else np.array([])
    all_exit_positions = np.array(all_exit_positions)
    all_stop_rates = np.array(all_stop_rates)
    all_reward_rates = np.array(all_reward_rates)
    all_sites_per_patch = np.array(all_sites_per_patch)
    all_rewards_per_patch = np.array(all_rewards_per_patch)

    if all_within_patch_distances.size == 0:
        print(f"WARNING: no within-patch intervals found for {odor_display_names[odor_type]}; statistics will be NaN")

    # Store statistics
    data_stats = {
        "odor_type": odor_type,
        "odor_display_name": odor_display_names[odor_type],
        "n_sessions": n_sessions_with_odor,
        "interval_mean": _stat_or_nan(all_within_patch_distances, np.mean),
        "interval_std": _stat_or_nan(all_within_patch_distances, np.std),
        "interval_min": _stat_or_nan(all_within_patch_distances, np.min),
        "interval_max": _stat_or_nan(all_within_patch_distances, np.max),
        "exit_position_mean": _stat_or_nan(all_exit_positions, np.mean),
        "exit_position_std": _stat_or_nan(all_exit_positions, np.std),
        "stop_rate": _stat_or_nan(all_stop_rates, np.mean),
        "reward_rate": _stat_or_nan(all_reward_rates, np.mean),
        "sites_per_patch": _stat_or_nan(all_sites_per_patch, np.mean),
        "rewards_per_patch": _stat_or_nan(all_rewards_per_patch, np.mean),
        "all_within_patch_distances": all_within_patch_distances,
        "all_exit_positions": all_exit_positions,
        "all_stop_rates": all_stop_rates,
        "all_reward_rates": all_reward_rates,
        "all_sites_per_patch": all_sites_per_patch,
        "all_rewards_per_patch": all_rewards_per_patch,
    }

    data_by_odor[odor_type] = data_stats

    print(f"\nStatistics for {odor_display_names[odor_type]} (from {n_sessions_with_odor} sessions):")
    print(f"  Inter-site interval: {data_stats['interval_mean']:.2f} ± {data_stats['interval_std']:.2f} cm")
    print(f"  Range: [{data_stats['interval_min']:.1f}, {data_stats['interval_max']:.1f}] cm")
    print("\nBehavioral Statistics:")
    print(f"  Exit position: {data_stats['exit_position_mean']:.1f} ± {data_stats['exit_position_std']:.1f} cm")
    print(f"  Stop rate: {data_stats['stop_rate']:.3f}")
    print(f"  Reward rate: {data_stats['reward_rate']:.3f}")
    print(f"  Sites per patch: {data_stats['sites_per_patch']:.2f}")
    print(f"  Rewards per patch: {data_stats['rewards_per_patch']:.2f}")

# ============================================================================
# Compare statistics across odor types
# ============================================================================

print(f"\n{'=' * 70}")
print("COMPARISON ACROSS ODOR TYPES")
print(f"{'=' * 70}")

# Relative differences are reported as a percentage of the Methyl Butyrate value.
_mb_stats = data_by_odor["Methyl_Butyrate"]
_ap_stats = data_by_odor["Alpha_pinene"]
_compared_stats = (
    "interval_mean",
    "exit_position_mean",
    "stop_rate",
    "reward_rate",
    "sites_per_patch",
    "rewards_per_patch",
)

for stat_name in _compared_stats:
    mb_val, ap_val = _mb_stats[stat_name], _ap_stats[stat_name]
    abs_gap = abs(mb_val - ap_val)
    print(f"\n{stat_name}:")
    print(f"  Methyl Butyrate: {mb_val:.3f}")
    print(f"  Alpha-pinene: {ap_val:.3f}")
    print(f"  Difference: {abs_gap:.3f} ({abs_gap / mb_val * 100:.1f}%)")

# ============================================================================
# Initialize Simulators (one per odor type if needed, or shared)
# ============================================================================

print(f"\n{'=' * 70}")
print("SIMULATOR INITIALIZATION")
print(f"{'=' * 70}")

# InterSite geometry in cm, currently shared by both odors (could be fit per odor later).
interval_min_cm = 20.0
interval_scale_cm = 19.0
odor_site_length = 50.0

# Normalization length: unweighted mean of the per-odor mean inter-site intervals.
avg_interval_mean = np.mean([data_by_odor[o]["interval_mean"] for o in odor_types])
interval_normalization = avg_interval_mean

print("Simulator parameters (shared across odors):")
for _param_label, _param_cm in (
    ("interval_min", interval_min_cm),
    ("interval_scale", interval_scale_cm),
    ("odor_site_length", odor_site_length),
    ("interval_normalization", interval_normalization),
):
    print(f"  {_param_label}: {_param_cm:.2f} cm")

_simulator_kwargs = {
    "initial_prob": 0.8,
    "decay_rate": -0.1,
    "threshold": 1.0,
    "start_point": 0.0,
    "interval_min": interval_min_cm,
    "interval_scale": interval_scale_cm,
    "interval_normalization": interval_normalization,
    "odor_site_length": odor_site_length,
    "max_sites_per_window": 100,
}
simulator = PatchForagingDDM_JAX(**_simulator_kwargs)

# Default prior from create_prior(); Section 3 reassigns prior_fn with explicit bounds.
prior_fn = create_prior()

print(f"\n{'=' * 70}")
print("PRIOR RANGES")
print(f"{'=' * 70}")
# NOTE(review): these ranges are hard-coded text, not read from create_prior() — confirm they match its defaults.
for _range_line in (
    "  drift_rate: [0., 2.]",
    "  reward_bump: [0., 2.]",
    "  failure_bump: [0., 2.]",
    "  noise_std: [0.05, 0.5]",
):
    print(_range_line)

print(f"\n{'=' * 70}")
print("Setup complete! Ready for validation by odor type.")
print(f"{'=' * 70}")

In [None]:
# %% Section 2: Verify Environment Generation (Intervals)

print(f"\n{'=' * 70}")
print("SECTION 2: VERIFY ENVIRONMENT GENERATION")
print("Note: Interval distribution should be same for both odors")
print(f"{'=' * 70}")

# ============================================================================
# Combine interval data from both odors for comparison
# ============================================================================

all_intervals_combined = np.concatenate(
    [
        data_by_odor["Methyl_Butyrate"]["all_within_patch_distances"],
        data_by_odor["Alpha_pinene"]["all_within_patch_distances"],
    ]
)

# Normalize real intervals by the simulator's interval_normalization, i.e. the
# average of the per-odor mean intervals (its value depends on the loaded data).
interval_scale = interval_normalization
all_intervals_normalized = all_intervals_combined / interval_scale

print("\nCombined data from both odors:")
print(f"  Total intervals: {len(all_intervals_combined):,}")
print(f"  Normalization scale: {interval_scale:.2f} cm")
print(f"  Normalized range: [{all_intervals_normalized.min():.3f}, {all_intervals_normalized.max():.3f}]")

# ============================================================================
# Generate interval samples from simulator
# ============================================================================

print("\nGenerating interval samples from simulator...")

rng_key = random.PRNGKey(42)
# The subkey from this split is unused, but the split advances rng_key, so
# removing it would change every simulated sample drawn below.
rng_key, subkey = random.split(rng_key)
# Fixed reference parameters, ordered (drift_rate, reward_bump, failure_bump, noise_std)
theta = jnp.array([0.5, 0.5, 0.5, 0.2])

n_simulations = 100
simulated_chunks = []

# NOTE(review): simulated positions are compared below against real intervals
# divided by interval_scale, i.e. they are assumed to already be in normalized
# units (the simulator receives interval_normalization) — confirm.
for i in range(n_simulations):
    rng_key, subkey = random.split(rng_key)
    window_data, _ = simulator.simulate_one_window(theta, subkey)
    window_data = np.array(window_data)

    positions = window_data[:, 0]
    stopped = window_data[:, 2]

    # Same rule as for real data: both consecutive sites stopped (same patch), positive step
    same_patch = (stopped[:-1] == 1) & (stopped[1:] == 1)
    steps = np.diff(positions)
    simulated_chunks.append(steps[same_patch & (steps > 0)])

simulated_intervals = np.concatenate(simulated_chunks) if simulated_chunks else np.array([])
if simulated_intervals.size == 0:
    print("WARNING: simulator produced no within-patch intervals; the comparison below cannot be computed")

# ============================================================================
# Compare normalized data to simulator output
# ============================================================================

print(f"\n{'=' * 70}")
print("INTERVAL DISTRIBUTION COMPARISON (NORMALIZED)")
print(f"{'=' * 70}")


def _print_interval_summary(title, values):
    """Print the sample size, mean, std, min, max and median of *values* (3 decimals)."""
    print(f"\n{title} (N={len(values):,}):")
    print(f"  Mean:   {values.mean():.3f}")
    print(f"  Std:    {values.std():.3f}")
    print(f"  Min:    {values.min():.3f}")
    print(f"  Max:    {values.max():.3f}")
    print(f"  Median: {np.median(values):.3f}")


_print_interval_summary("Real Data - Combined", all_intervals_normalized)
_print_interval_summary("Simulation", simulated_intervals)

# Absolute gaps between simulation and data; percentages are relative to the data.
_real_mean = all_intervals_normalized.mean()
_real_std = all_intervals_normalized.std()
mean_diff = abs(simulated_intervals.mean() - _real_mean)
std_diff = abs(simulated_intervals.std() - _real_std)

print("\nDifferences:")
print(f"  Mean difference: {mean_diff:.3f} ({mean_diff / _real_mean * 100:.1f}%)")
print(f"  Std difference:  {std_diff:.3f} ({std_diff / _real_std * 100:.1f}%)")

# ============================================================================
# Visualize - show both odors separately
# ============================================================================


def _overlay_interval_hists(ax, real_values, real_label, title):
    """Overlay real (blue) and simulated (red) normalized interval densities with dashed mean lines."""
    for values, label, color in ((real_values, real_label, "blue"), (simulated_intervals, "Simulation", "red")):
        ax.hist(values, bins=50, alpha=0.6, density=True, label=label, color=color, edgecolor="black")
    for values, color in ((real_values, "blue"), (simulated_intervals, "red")):
        ax.axvline(values.mean(), color=color, linestyle="--", linewidth=2)
    ax.set_xlabel("Inter-site interval (normalized)")
    ax.set_ylabel("Density")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top row: each odor's intervals against the simulation
for idx, odor_type in enumerate(odor_types):
    odor_data = data_by_odor[odor_type]["all_within_patch_distances"] / interval_scale
    _overlay_interval_hists(
        axes[0, idx],
        odor_data,
        f"{odor_display_names[odor_type]} Data",
        f"{odor_display_names[odor_type]}: Intervals",
    )

# Bottom left: both odors pooled
_overlay_interval_hists(axes[1, 0], all_intervals_normalized, "Combined Data", "Combined: Simulation vs Data")

# Bottom right: Q-Q plot evaluated at 100 evenly spaced quantile levels
ax = axes[1, 1]
probe_levels = np.linspace(0, 1, 100)
data_sorted = np.sort(all_intervals_normalized)
sim_sorted = np.sort(simulated_intervals)
data_interp = np.interp(probe_levels, np.linspace(0, 1, len(data_sorted)), data_sorted)
sim_interp = np.interp(probe_levels, np.linspace(0, 1, len(sim_sorted)), sim_sorted)

ax.scatter(data_interp, sim_interp, alpha=0.5, s=20)
diag_lo = min(data_interp.min(), sim_interp.min())
diag_hi = max(data_interp.max(), sim_interp.max())
ax.plot([diag_lo, diag_hi], [diag_lo, diag_hi], "r--", linewidth=2, label="Perfect match")
ax.set_xlabel("Data quantiles (normalized)")
ax.set_ylabel("Simulation quantiles (normalized)")
ax.set_title("Q-Q Plot: Intervals")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# ============================================================================
# Validation check
# ============================================================================

print(f"\n{'=' * 70}")
print("SECTION 2 VALIDATION")
print(f"{'=' * 70}")

tolerance_pct = 10
checks_passed = []

# Each entry: (statistic label, |sim - data| gap, data value the gap is measured against)
for _stat_label, _gap, _reference in (
    ("Mean", mean_diff, all_intervals_normalized.mean()),
    ("Std", std_diff, all_intervals_normalized.std()),
):
    _within_tolerance = bool(_gap / _reference * 100 < tolerance_pct)
    if _within_tolerance:
        print(f"✓ {_stat_label} interval matches within {tolerance_pct}%")
    else:
        print(f"✗ {_stat_label} interval does NOT match (>{tolerance_pct}% difference)")
    checks_passed.append(_within_tolerance)

print(
    "\n✅ Section 2 PASSED: Environment generation matches real data"
    if all(checks_passed)
    else "\n⚠️  Section 2 FAILED: Environment generation needs adjustment"
)

print(f"{'=' * 70}")

In [None]:
# %% Section 3: Verify Position Outputs and Prior Coverage BY ODOR TYPE

print(f"\n{'=' * 70}")
print("SECTION 3 & 4: POSITION OUTPUTS AND PRIOR COVERAGE BY ODOR TYPE")
print(f"{'=' * 70}")

# We'll analyze each odor type separately to check if:
# 1. Exit positions match
# 2. Priors cover the data range

# Prior over theta = (drift_rate, reward_bump, failure_bump, noise_std)
# (using Option C from previous analysis)
prior_fn = create_prior(prior_low=jnp.array([0.0, 0.0, 0.0, 0.05]), prior_high=jnp.array([2.0, 2.0, 2.0, 0.5]))

# Sample from prior to generate predictions
n_prior_samples = 3000
rng_key = random.PRNGKey(789)

# Pool every simulated exit position, mirroring how the real-data exit
# positions were collected (one entry per exit site, not one per window).
# Per-window averages have a narrower spread than the individual exits they
# summarize, which would bias the histogram/CDF overlays and the 1-99%
# coverage check below.
prior_exits = []
n_windows_without_exit = 0

print("\nSampling from priors...")
for i in range(n_prior_samples):
    rng_key, subkey1, subkey2 = random.split(rng_key, 3)
    theta = prior_fn().sample(seed=subkey1)["theta"]

    window_data, _ = simulator.simulate_one_window(theta, subkey2)
    window_data = np.array(window_data)

    positions = window_data[:, 0]
    stopped = window_data[:, 2]

    exit_mask = stopped == 0
    if exit_mask.any():
        prior_exits.extend(positions[exit_mask].tolist())
    else:
        # No exit in this window: fall back to the furthest position reached.
        # NOTE(review): this is a censored value with no real-data counterpart.
        prior_exits.append(positions.max())
        n_windows_without_exit += 1

prior_exits = np.array(prior_exits)
print(f"  Pooled {len(prior_exits):,} simulated exit positions ({n_windows_without_exit} windows had no exit)")

# ============================================================================
# Analyze each odor type separately
# ============================================================================


def _exit_coverage(exit_norm):
    """Compare 1st-99th percentile exit ranges of the real data and the prior predictive.

    Returns ``(data_range, sim_range, covers_low, covers_high)``. The slack factors
    are heuristic: the simulated low end may be up to 1.5x the data's, and the
    simulated high end may be as low as 0.67x the data's.
    """
    data_range = (np.percentile(exit_norm, 1), np.percentile(exit_norm, 99))
    sim_range = (np.percentile(prior_exits, 1), np.percentile(prior_exits, 99))
    return data_range, sim_range, sim_range[0] <= data_range[0] * 1.5, sim_range[1] >= data_range[1] * 0.67


def _plot_exit_histogram(ax, exit_norm, odor_name):
    """Overlay data (blue) and prior-predictive (red) exit-position densities on [0, 50]."""
    hist_style = {"bins": 50, "alpha": 0.5, "density": True, "edgecolor": "black", "range": (0, 50)}
    ax.hist(exit_norm, label=f"{odor_name} Data", color="blue", **hist_style)
    ax.hist(prior_exits, label="Prior Predictive", color="red", **hist_style)
    ax.axvline(exit_norm.mean(), color="blue", linestyle="--", linewidth=2)
    ax.axvline(prior_exits.mean(), color="red", linestyle="--", linewidth=2)
    ax.set_xlabel("Exit Position (InterSite gaps)")
    ax.set_ylabel("Density")
    ax.set_title(f"{odor_name}: Exit Positions")
    ax.legend()
    ax.grid(True, alpha=0.3)


def _plot_exit_cdf(ax, exit_norm, odor_name):
    """Plot empirical CDFs of data (blue) and prior-predictive (red) exit positions."""
    for values, fmt, label in ((exit_norm, "b-", f"{odor_name} Data"), (prior_exits, "r-", "Prior Predictive")):
        ordered = np.sort(values)
        ax.plot(ordered, np.arange(1, len(ordered) + 1) / len(ordered), fmt, linewidth=2, label=label, alpha=0.7)
    ax.set_xlabel("Exit Position (InterSite gaps)")
    ax.set_ylabel("Cumulative Probability")
    ax.set_title(f"{odor_name}: CDF")
    ax.set_xlim(0, 50)
    ax.legend()
    ax.grid(True, alpha=0.3)


def _plot_reward_rates(ax, odor_stats, odor_name):
    """Histogram of per-patch reward rate (rewards / sites) with NaNs dropped and a mean line."""
    rates = odor_stats["all_rewards_per_patch"] / odor_stats["all_sites_per_patch"]
    rates = rates[~np.isnan(rates)]
    mean_rate = rates.mean()
    ax.hist(rates, bins=30, alpha=0.7, edgecolor="black", color="green")
    ax.axvline(mean_rate, color="red", linestyle="--", linewidth=2, label=f"Mean: {mean_rate:.2f}")
    ax.set_xlabel("Reward Rate (rewards/sites)")
    ax.set_ylabel("Count")
    ax.set_title(f"{odor_name}: Reward Rate Distribution")
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 3, figsize=(18, 10))

for odor_idx, odor_type in enumerate(odor_types):
    odor_name = odor_display_names[odor_type]
    odor_data = data_by_odor[odor_type]
    hist_ax, cdf_ax, reward_ax = axes[odor_idx]

    print(f"\n{'=' * 70}")
    print(f"ANALYZING: {odor_name}")
    print(f"{'=' * 70}")

    # Express exit positions in units of the InterSite normalization length
    exit_positions_norm = odor_data["all_exit_positions"] / interval_normalization

    print("\nExit positions (normalized by InterSite gap):")
    print(f"  Mean:   {exit_positions_norm.mean():.2f}")
    print(f"  Median: {np.median(exit_positions_norm):.2f}")
    print(f"  1-99%:  [{np.percentile(exit_positions_norm, 1):.2f}, {np.percentile(exit_positions_norm, 99):.2f}]")

    print("\nPrior predictive (same for both odors):")
    print(f"  Mean:   {prior_exits.mean():.2f}")
    print(f"  1-99%:  [{np.percentile(prior_exits, 1):.2f}, {np.percentile(prior_exits, 99):.2f}]")

    data_range, sim_range, covers_low, covers_high = _exit_coverage(exit_positions_norm)

    print("\nCoverage check:")
    print(f"  Data range: [{data_range[0]:.2f}, {data_range[1]:.2f}]")
    print(f"  Sim range:  [{sim_range[0]:.2f}, {sim_range[1]:.2f}]")
    print(f"  Covers low: {'✓' if covers_low else '✗'}")
    print(f"  Covers high: {'✓' if covers_high else '✗'}")
    print(
        f"  → ✅ Adequate coverage for {odor_name}"
        if covers_low and covers_high
        else f"  → ⚠️  May need wider priors for {odor_name}"
    )

    _plot_exit_histogram(hist_ax, exit_positions_norm, odor_name)
    _plot_exit_cdf(cdf_ax, exit_positions_norm, odor_name)
    _plot_reward_rates(reward_ax, odor_data, odor_name)

plt.tight_layout()
plt.show()

# ============================================================================
# Summary comparison
# ============================================================================

print(f"\n{'=' * 70}")
print("SUMMARY: PRIOR COVERAGE ACROSS ODOR TYPES")
print(f"{'=' * 70}")

# Keep these in sync with the create_prior(prior_low=[0.0, 0.0, 0.0, 0.05],
# prior_high=[2.0, 2.0, 2.0, 0.5]) call at the start of Section 3.
print("\nCurrent priors:")
print("  drift_rate: [0.0, 2.0]")
print("  reward_bump: [0.0, 2.0]")
print("  failure_bump: [0.0, 2.0]")
print("  noise_std: [0.05, 0.5]")

for odor_type in odor_types:
    odor_name = odor_display_names[odor_type]
    exit_positions_norm = data_by_odor[odor_type]["all_exit_positions"] / interval_normalization

    data_range = (np.percentile(exit_positions_norm, 1), np.percentile(exit_positions_norm, 99))
    sim_range = (np.percentile(prior_exits, 1), np.percentile(prior_exits, 99))

    print(f"\n{odor_name}:")
    print(f"  Data exit range (1-99%): [{data_range[0]:.2f}, {data_range[1]:.2f}]")
    print(f"  Sim exit range (1-99%):  [{sim_range[0]:.2f}, {sim_range[1]:.2f}]")

    # Same heuristic slack factors as the per-odor coverage check
    covers = sim_range[0] <= data_range[0] * 1.5 and sim_range[1] >= data_range[1] * 0.67
    print(f"  Coverage: {'✅ Adequate' if covers else '⚠️  Needs adjustment'}")

print(f"\n{'=' * 70}")
print("KEY OBSERVATIONS:")
print(f"{'=' * 70}")


def _relative_diff_pct(reference, other):
    """Return |reference - other| as a percentage of reference (NaN when reference is 0)."""
    return abs(reference - other) / reference * 100 if reference else float("nan")


mb_reward_rate = data_by_odor["Methyl_Butyrate"]["reward_rate"]
ap_reward_rate = data_by_odor["Alpha_pinene"]["reward_rate"]

# Descriptive only: no statistical test is run, so do not claim significance.
print("\n1. Reward rates (descriptive; no significance test):")
print(f"   Methyl Butyrate: {mb_reward_rate:.1%}")
print(f"   Alpha-pinene: {ap_reward_rate:.1%}")
print(f"   Relative difference: {_relative_diff_pct(mb_reward_rate, ap_reward_rate):.1f}%")
print("   → A large gap would suggest different reward probabilities or decay rates")

mb_exit_mean = (data_by_odor["Methyl_Butyrate"]["all_exit_positions"] / interval_normalization).mean()
ap_exit_mean = (data_by_odor["Alpha_pinene"]["all_exit_positions"] / interval_normalization).mean()

print("\n2. Exit positions (descriptive; no significance test):")
print(f"   Methyl Butyrate: {mb_exit_mean:.2f} InterSite gaps")
print(f"   Alpha-pinene: {ap_exit_mean:.2f} InterSite gaps")
print(f"   Relative difference: {_relative_diff_pct(mb_exit_mean, ap_exit_mean):.1f}%")
print("   → A large gap may require odor-specific parameters for best fit")

print("\n3. RECOMMENDATION:")
print("   Fit separate parameters for each odor type during SBI")
print("   This allows model to capture different reward structures")

print(f"{'=' * 70}")

In [None]:
# Check current prior settings

print(f"\n{'=' * 70}")
print("CURRENT PRIOR SETTINGS")
print(f"{'=' * 70}")

# Draw parameter sets from the active prior_fn to check its empirical ranges
rng_key = random.PRNGKey(999)
samples = []

for i in range(1000):
    rng_key, subkey = random.split(rng_key)
    theta = prior_fn().sample(seed=subkey)["theta"]
    samples.append(np.array(theta))

samples = np.array(samples)

# Column order of theta
param_names = ["drift_rate", "reward_bump", "failure_bump", "noise_std"]

print("\nPrior ranges (from 1000 samples):")
for i, name in enumerate(param_names):
    print(f"  {name:15s}: [{samples[:, i].min():.4f}, {samples[:, i].max():.4f}]")
    print(f"                  Mean: {samples[:, i].mean():.4f}, Std: {samples[:, i].std():.4f}")

# Echo the prior definition. Keep this text in sync with the create_prior(...)
# call at the start of Section 3 (low = [0.0, 0.0, 0.0, 0.05], high = [2.0, 2.0, 2.0, 0.5]).
print(f"\n{'=' * 70}")
print("To see the exact prior definition, check where create_prior() was called.")
print("In Section 3, it was set as:")
print("  prior_fn = create_prior(")
print("      prior_low=jnp.array([0.0, 0.0, 0.0, 0.05]),")
print("      prior_high=jnp.array([2.0, 2.0, 2.0, 0.5])")
print("  )")
print(f"{'=' * 70}")