# Enhanced Feature Visualizations (31 Features)

This notebook visualizes how each of the 31 summary statistics responds to changes in the drift-diffusion model (DDM) parameters.

In [None]:
import os
# Force CPU backend on Apple Silicon to avoid Metal issues
# NOTE(review): this assignment sits above the `jax` imports below; presumably
# JAX reads JAX_PLATFORMS when its backend is initialised, so keep this order.
os.environ['JAX_PLATFORMS'] = 'cpu'

# === Setup ===
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import random
import seaborn as sns
from scipy.stats import spearmanr, pearsonr
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance

from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.simulator import PatchForagingDDM_JAX, create_prior

# Reproducibility
# `np.random.seed` seeds NumPy's global RNG; `rng_key` is the module-level JAX
# key that the parameter-sweep cell threads through `simulate_and_extract`.
# The analysis functions defined in this notebook create their own
# `random.PRNGKey(42)` and do not consume this key.
np.random.seed(1)
rng_key = random.PRNGKey(42)

In [None]:
def _print_banner(title):
    """Print a 70-column section banner preceded by a blank line."""
    print("\n" + "="*70)
    print(title)
    print("="*70)


def _abs_or_zero(value):
    """Return ``abs(value)``, treating NaN (undefined correlation) as 0.0."""
    return 0.0 if np.isnan(value) else abs(value)


def _safe_ratio(value, max_value):
    """Return ``value / max_value``, or 0.0 when ``max_value`` is not a positive finite number."""
    if not np.isfinite(max_value) or max_value <= 0:
        return 0.0
    return value / max_value


def _simulate_prior_samples(simulator, prior_fn, n_samples, seed=42):
    """Draw ``n_samples`` (theta, summary-stats) pairs by sampling the prior and simulating one window each."""
    rng_key = random.PRNGKey(seed)
    thetas = []
    stats = []

    for i in range(n_samples):
        if i % 200 == 0:
            print(f"  {i}/{n_samples}...")

        # One fresh subkey for the prior draw and one for the simulation.
        rng_key, theta_key, sim_key = random.split(rng_key, 3)
        theta = prior_fn().sample(seed=theta_key)['theta']
        _, summary_stats = simulator.simulate_one_window(theta, sim_key)

        thetas.append(theta)
        stats.append(summary_stats)

    return jnp.array(thetas), jnp.array(stats)


def _run_correlation_analysis(thetas_np, stats_np, param_names, feature_names):
    """Rank features per parameter by |Spearman| correlation and print the top 10."""
    _print_banner("METHOD 1: CORRELATION ANALYSIS")

    results = {}
    for i, param in enumerate(param_names):
        theta_col = thetas_np[:, i]
        records = []
        for j, fname in enumerate(feature_names):
            # Pearson captures linear, Spearman monotonic relationships.
            r_pearson, _ = pearsonr(theta_col, stats_np[:, j])
            r_spearman, _ = spearmanr(theta_col, stats_np[:, j])

            records.append({
                'feature': fname,
                'feature_idx': j,
                'pearson': r_pearson,
                'spearman': r_spearman,
                # A constant feature yields NaN; score it as 0 so that sorting
                # and the ensemble ranks stay well-defined.
                'abs_pearson': _abs_or_zero(r_pearson),
                'abs_spearman': _abs_or_zero(r_spearman),
            })

        records.sort(key=lambda rec: rec['abs_spearman'], reverse=True)
        results[param] = records

        print(f"\n{param} - Top 10 correlated features:")
        print(f"  {'Rank':<6} {'Feature':<30} {'Pearson':>10} {'Spearman':>10}")
        print("  " + "-"*60)
        for rank, c in enumerate(records[:10], 1):
            print(f"  {rank:<6} {c['feature']:<30} {c['pearson']:>10.3f} {c['spearman']:>10.3f}")

    return results


def _run_mutual_info_analysis(thetas_np, stats_np, param_names, feature_names):
    """Rank features per parameter by mutual information and print the top 10."""
    from sklearn.feature_selection import mutual_info_regression

    _print_banner("METHOD 2: MUTUAL INFORMATION")

    results = {}
    for i, param in enumerate(param_names):
        mi_scores = mutual_info_regression(stats_np, thetas_np[:, i], random_state=42)

        mi_list = [
            {'feature': fname, 'feature_idx': j, 'mi': mi_scores[j]}
            for j, fname in enumerate(feature_names)
        ]
        mi_list.sort(key=lambda rec: rec['mi'], reverse=True)
        results[param] = mi_list

        print(f"\n{param} - Top 10 by Mutual Information:")
        print(f"  {'Rank':<6} {'Feature':<30} {'MI Score':>12}")
        print("  " + "-"*50)
        for rank, m in enumerate(mi_list[:10], 1):
            print(f"  {rank:<6} {m['feature']:<30} {m['mi']:>12.4f}")

    return results


def _run_random_forest_analysis(thetas_np, stats_np, param_names, feature_names):
    """Rank features per parameter by random-forest permutation importance and print the top 10."""
    _print_banner("METHOD 3: RANDOM FOREST IMPORTANCE")

    results = {}
    for i, param in enumerate(param_names):
        print(f"\nTraining Random Forest for {param}...")
        target = thetas_np[:, i]

        rf = RandomForestRegressor(
            n_estimators=100,
            max_depth=10,
            random_state=42,
            n_jobs=-1
        )
        rf.fit(stats_np, target)
        importances = rf.feature_importances_

        # Permutation importance is evaluated on the same data the forest was fit on.
        perm_importance = permutation_importance(
            rf, stats_np, target,
            n_repeats=10, random_state=42, n_jobs=-1
        )

        rf_list = [
            {
                'feature': fname,
                'feature_idx': j,
                'importance': importances[j],
                'perm_importance': perm_importance.importances_mean[j],
                'perm_std': perm_importance.importances_std[j],
            }
            for j, fname in enumerate(feature_names)
        ]
        rf_list.sort(key=lambda rec: rec['perm_importance'], reverse=True)
        results[param] = rf_list

        print(f"\n{param} - Top 10 by Random Forest:")
        print(f"  {'Rank':<6} {'Feature':<30} {'Importance':>12} {'Perm.Imp':>12}")
        print("  " + "-"*62)
        for rank, r in enumerate(rf_list[:10], 1):
            print(f"  {rank:<6} {r['feature']:<30} {r['importance']:>12.4f} "
                  f"{r['perm_importance']:>12.4f}")

    return results


def _run_lasso_analysis(thetas_np, stats_np, param_names, feature_names):
    """Rank features per parameter by |LASSO coefficient| on standardized features and print the selection."""
    from sklearn.linear_model import LassoCV
    from sklearn.preprocessing import StandardScaler

    _print_banner("METHOD 4: LASSO SPARSE SELECTION")

    # Standardize once so coefficient magnitudes are comparable across features.
    stats_scaled = StandardScaler().fit_transform(stats_np)
    n_features = len(feature_names)

    results = {}
    for i, param in enumerate(param_names):
        print(f"\nRunning LASSO for {param}...")
        target = thetas_np[:, i]

        lasso = LassoCV(cv=5, random_state=42, max_iter=5000)
        lasso.fit(stats_scaled, target)

        lasso_list = [
            {
                'feature': fname,
                'feature_idx': j,
                'coefficient': lasso.coef_[j],
                'abs_coefficient': abs(lasso.coef_[j]),
            }
            for j, fname in enumerate(feature_names)
        ]
        lasso_list.sort(key=lambda rec: rec['abs_coefficient'], reverse=True)
        results[param] = lasso_list

        # Coefficients at or below 0.001 in magnitude count as "not selected".
        selected = [rec for rec in lasso_list if rec['abs_coefficient'] > 0.001]

        print(f"\n{param} - LASSO selected {len(selected)}/{n_features} features")
        print(f"  Alpha: {lasso.alpha_:.6f}, R²: {lasso.score(stats_scaled, target):.3f}")
        print(f"\n  {'Rank':<6} {'Feature':<30} {'Coefficient':>15}")
        print("  " + "-"*53)
        for rank, rec in enumerate(selected[:10], 1):
            print(f"  {rank:<6} {rec['feature']:<30} {rec['coefficient']:>15.6f}")

    return results


def _run_ensemble_ranking(param_names, feature_names, method_results):
    """Average each feature's rank across all methods (lower is better) and print the top 15."""
    _print_banner("METHOD 5: ENSEMBLE RANKING")

    results = {}
    for param in param_names:
        ranks_by_feature = {fname: [] for fname in feature_names}
        for ranking in method_results:
            # Every per-method list is already sorted best-first.
            for rank, item in enumerate(ranking[param], 1):
                ranks_by_feature[item['feature']].append(rank)

        ensemble_list = [
            {
                'feature': fname,
                'avg_rank': np.mean(ranks),
                'min_rank': min(ranks),
                'max_rank': max(ranks),
                'rank_std': np.std(ranks),
            }
            for fname, ranks in ranks_by_feature.items()
        ]
        ensemble_list.sort(key=lambda rec: rec['avg_rank'])
        results[param] = ensemble_list

        print(f"\n{param} - ENSEMBLE RANKING (lower rank = more important):")
        print(f"  {'Rank':<6} {'Feature':<30} {'Avg Rank':>10} {'Best':>6} {'Worst':>6} {'Std':>6}")
        print("  " + "-"*70)
        for rank, e in enumerate(ensemble_list[:15], 1):
            print(f"  {rank:<6} {e['feature']:<30} {e['avg_rank']:>10.1f} "
                  f"{e['min_rank']:>6.0f} {e['max_rank']:>6.0f} {e['rank_std']:>6.1f}")

    return results


def _plot_importance_heatmaps(param_names, correlation_results, mi_results,
                              rf_results, lasso_results, ensemble_results):
    """Save and show a 2x2 grid of per-parameter heatmaps comparing the four methods."""
    methods = ['Correlation', 'Mutual Info', 'Random Forest', 'LASSO']
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))

    for ax, param in zip(axes.flat, param_names):
        top_features = [e['feature'] for e in ensemble_results[param][:15]]

        # Index each method's records by feature name once instead of rescanning.
        corr_by_name = {c['feature']: c for c in correlation_results[param]}
        mi_by_name = {m['feature']: m for m in mi_results[param]}
        rf_by_name = {r['feature']: r for r in rf_results[param]}
        lasso_by_name = {l['feature']: l for l in lasso_results[param]}

        mi_max = max(m['mi'] for m in mi_results[param])
        rf_max = max(r['perm_importance'] for r in rf_results[param])
        lasso_max = max(l['abs_coefficient'] for l in lasso_results[param])

        importance_matrix = np.zeros((len(top_features), len(methods)))
        for row, fname in enumerate(top_features):
            # |Spearman| is already in [0, 1]; the others are scaled by their
            # per-parameter maximum (guarded against a zero/negative maximum).
            importance_matrix[row, 0] = corr_by_name[fname]['abs_spearman']
            importance_matrix[row, 1] = _safe_ratio(mi_by_name[fname]['mi'], mi_max)
            importance_matrix[row, 2] = _safe_ratio(rf_by_name[fname]['perm_importance'], rf_max)
            importance_matrix[row, 3] = _safe_ratio(lasso_by_name[fname]['abs_coefficient'], lasso_max)

        sns.heatmap(importance_matrix, ax=ax, cmap='YlOrRd',
                    xticklabels=methods, yticklabels=top_features,
                    cbar_kws={'label': 'Normalized Importance'},
                    vmin=0, vmax=1, annot=True, fmt='.2f')
        ax.set_title(f'{param} - Top 15 Features', fontsize=14, fontweight='bold')
        ax.set_xlabel('Method', fontsize=12)
        ax.set_ylabel('Feature', fontsize=12)

    plt.tight_layout()
    plt.savefig('feature_importance_heatmap.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("✓ Saved: feature_importance_heatmap.png")


def _plot_top_feature_scatter(thetas_np, stats_np, param_names,
                              correlation_results, ensemble_results):
    """Save and show scatter plots of each parameter against its four best-ranked features."""
    fig, axes = plt.subplots(4, 4, figsize=(20, 16))

    for i, param in enumerate(param_names):
        corr_by_name = {c['feature']: c for c in correlation_results[param]}
        top_4_features = [e['feature'] for e in ensemble_results[param][:4]]

        for j, fname in enumerate(top_4_features):
            ax = axes[i, j]
            corr_item = corr_by_name[fname]

            ax.scatter(stats_np[:, corr_item['feature_idx']], thetas_np[:, i], alpha=0.3, s=1)
            ax.set_xlabel(fname, fontsize=10)
            ax.set_ylabel(param if j == 0 else '', fontsize=10)

            r = corr_item['spearman']
            ax.set_title(f'ρ={r:.3f}', fontsize=9)
            ax.grid(True, alpha=0.3)

    plt.suptitle('Top 4 Features for Each Parameter', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('top_features_scatter.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("✓ Saved: top_features_scatter.png")


def _print_recommendations(param_names, feature_names, ensemble_results):
    """Print features shared across parameters' top-10 lists and suggested minimal feature sets."""
    _print_banner("FEATURE SELECTION RECOMMENDATIONS")

    # Top-10 membership per parameter, built once for O(1) lookups.
    top10 = {
        param: {e['feature'] for e in ensemble_results[param][:10]}
        for param in param_names
    }

    def params_for(fname):
        return [p for p in param_names if fname in top10[p]]

    feature_frequency = {}
    for fname in feature_names:
        count = len(params_for(fname))
        if count > 0:
            feature_frequency[fname] = count

    print("\nFeatures appearing in top 10 for multiple parameters:")
    for fname, count in sorted(feature_frequency.items(), key=lambda kv: kv[1], reverse=True):
        print(f"  {fname:<30} → {count} parameters: {', '.join(params_for(fname))}")

    _print_banner("RECOMMENDED MINIMAL FEATURE SETS")

    for n_features_target in [10, 15, 20]:
        print(f"\nTop {n_features_target} features (union across all parameters):")

        # Union of each parameter's top N/2 features. An insertion-ordered
        # dict (instead of a set) keeps tie order reproducible between runs.
        selected = {}
        for param in param_names:
            for e in ensemble_results[param][:n_features_target//2]:
                selected.setdefault(e['feature'], None)

        selected_list = sorted(
            selected,
            key=lambda fname: len(params_for(fname)),
            reverse=True
        )[:n_features_target]

        print(f"  Total: {len(selected_list)} features")
        for fname in selected_list:
            good_for = params_for(fname)
            print(f"    {fname:<30} → {', '.join(good_for) if good_for else 'general'}")


def comprehensive_feature_importance_analysis(simulator, prior_fn, feature_names, n_samples=2000):
    """
    Identify the most informative summary statistics for each DDM parameter.

    Samples parameters from the prior, simulates one window per sample, and
    ranks every feature for each parameter with four methods (correlation,
    mutual information, random-forest permutation importance, LASSO). The
    per-method ranks are averaged into an ensemble ranking. Progress tables are
    printed, and two figures are saved to the working directory
    ('feature_importance_heatmap.png', 'top_features_scatter.png') and shown.

    Args:
        simulator: Object exposing ``simulate_one_window(theta, key)`` that
            returns a pair whose second element is the summary-statistics vector.
        prior_fn: Callable returning a distribution whose
            ``sample(seed=key)['theta']`` yields one parameter vector.
        feature_names: Sequence of labels, one per summary statistic, in the
            same order as the simulator's statistics.
        n_samples: Number of prior draws / simulations to run.

    Returns:
        dict with keys 'correlation', 'mutual_info', 'random_forest', 'lasso'
        and 'ensemble' (each mapping parameter name -> list of per-feature
        records sorted best-first) and 'data' (the ``(thetas, stats)`` arrays).

    Raises:
        ValueError: If the simulated statistics are not a 2-D array with exactly
            ``len(feature_names)`` columns; the rankings would otherwise be
            attached to the wrong labels.
    """
    print("="*70)
    print("COMPREHENSIVE FEATURE IMPORTANCE ANALYSIS")
    print("="*70)

    # 1. Generate data
    print(f"\nGenerating {n_samples} samples from prior...")
    thetas, stats = _simulate_prior_samples(simulator, prior_fn, n_samples)

    param_names = ['drift_rate', 'reward_bump', 'failure_bump', 'noise_std']
    feature_names = list(feature_names)

    print(f"\n✓ Generated data shape: thetas={thetas.shape}, stats={stats.shape}")

    # Convert to NumPy once; SciPy and scikit-learn operate on NumPy arrays.
    thetas_np = np.array(thetas)
    stats_np = np.array(stats)

    if stats_np.ndim != 2 or stats_np.shape[1] != len(feature_names):
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but the simulator "
            f"produced summary statistics of shape {stats_np.shape}; "
            "labels and statistic columns must match one-to-one."
        )

    # 2-5. Individual importance methods
    correlation_results = _run_correlation_analysis(thetas_np, stats_np, param_names, feature_names)
    mi_results = _run_mutual_info_analysis(thetas_np, stats_np, param_names, feature_names)
    rf_results = _run_random_forest_analysis(thetas_np, stats_np, param_names, feature_names)
    lasso_results = _run_lasso_analysis(thetas_np, stats_np, param_names, feature_names)

    # 6. Ensemble ranking across all methods
    ensemble_results = _run_ensemble_ranking(
        param_names, feature_names,
        [correlation_results, mi_results, rf_results, lasso_results],
    )

    # 7. Visualizations
    _print_banner("GENERATING VISUALIZATIONS")
    _plot_importance_heatmaps(param_names, correlation_results, mi_results,
                              rf_results, lasso_results, ensemble_results)
    _plot_top_feature_scatter(thetas_np, stats_np, param_names,
                              correlation_results, ensemble_results)

    # 8. Recommendations
    _print_recommendations(param_names, feature_names, ensemble_results)

    return {
        'correlation': correlation_results,
        'mutual_info': mi_results,
        'random_forest': rf_results,
        'lasso': lasso_results,
        'ensemble': ensemble_results,
        'data': (thetas, stats)
    }

In [None]:
# 35-feature model: load its feature labels and run the full importance analysis.
# NOTE(review): the 37- and 23-feature cells also import a name called
# FEATURE_NAMES, so code reading the global sees whichever cell ran last.
from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.enhanced_stats_35 import FEATURE_NAMES  # 35 features


# Simulator and prior used to generate the samples.
# NOTE(review): the simulator is built with the same arguments as in the other
# feature-set cells; confirm it actually emits 35 summary statistics.
simulator_35 = PatchForagingDDM_JAX(max_sites_per_window=100)
prior_fn = create_prior()

results_35 = comprehensive_feature_importance_analysis(
    simulator=simulator_35,
    prior_fn=prior_fn,
    feature_names=FEATURE_NAMES,
    n_samples=2000,
)

In [None]:
# 37-feature model: load its feature labels and run the full importance analysis.
# NOTE(review): the 35- and 23-feature cells also import a name called
# FEATURE_NAMES, so code reading the global sees whichever cell ran last.
from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.enhanced_stats_37 import FEATURE_NAMES  # 37 features


# Simulator and prior used to generate the samples.
# NOTE(review): the simulator is built with the same arguments as in the other
# feature-set cells; confirm it actually emits 37 summary statistics.
simulator_37 = PatchForagingDDM_JAX(max_sites_per_window=100)
prior_fn = create_prior()

results_37 = comprehensive_feature_importance_analysis(
    simulator=simulator_37,
    prior_fn=prior_fn,
    feature_names=FEATURE_NAMES,
    n_samples=2000,
)

In [None]:
# 23-feature model: load its feature labels and run the full importance analysis.
# NOTE(review): the 35- and 37-feature cells also import a name called
# FEATURE_NAMES, so code reading the global sees whichever cell ran last.
from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.enhanced_stats_23 import FEATURE_NAMES  # 23 features

# Simulator and prior used to generate the samples.
# NOTE(review): the simulator is built with the same arguments as in the other
# feature-set cells; confirm it actually emits 23 summary statistics.
simulator_23 = PatchForagingDDM_JAX(max_sites_per_window=100)
prior_fn = create_prior()

results_23 = comprehensive_feature_importance_analysis(
    simulator=simulator_23,
    prior_fn=prior_fn,
    feature_names=FEATURE_NAMES,
    n_samples=2000,
)

In [None]:
from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.enhanced_stats import compute_summary_stats

# Test sensitivity of new features
def test_new_features():
    """
    Print, for four high-vs-low parameter contrasts, the five summary
    statistics whose means differ most between the two settings.

    Each contrast simulates 50 windows per setting and compares the mean
    statistics by absolute difference.

    NOTE(review): `simulate_and_extract` and `OPTIMIZED_FEATURE_NAMES` are not
    defined or imported in the visible part of this notebook; confirm they are
    in scope before running this cell.
    """
    key = random.PRNGKey(42)
    n_simulations = 50
    n_top = 5

    # Each entry: contrast label -> [theta with the high value, theta with the low value].
    # Theta order follows the rest of the notebook:
    # (drift_rate, reward_bump, failure_bump, noise_std).
    contrasts = {
        'High drift vs low drift': [
            jnp.array([1.5, 0.5, 0.5, 0.2]),
            jnp.array([0.3, 0.5, 0.5, 0.2]),
        ],
        'High reward_bump vs low': [
            jnp.array([0.75, 1.5, 0.5, 0.2]),
            jnp.array([0.75, 0.2, 0.5, 0.2]),
        ],
        'High failure_bump vs low': [
            jnp.array([0.75, 0.5, 1.5, 0.2]),
            jnp.array([0.75, 0.5, 0.2, 0.2]),
        ],
        'High noise vs low': [
            jnp.array([0.75, 0.5, 0.5, 0.4]),
            jnp.array([0.75, 0.5, 0.5, 0.05]),
        ],
    }

    for label, (theta_high, theta_low) in contrasts.items():
        print(f"\n{label}:")

        high_stats = []
        low_stats = []

        # The key is carried across contrasts, so every simulation gets a distinct subkey.
        for _ in range(n_simulations):
            key, key_high, key_low = random.split(key, 3)

            window_high, _, _ = simulate_and_extract(theta_high, key_high)
            high_stats.append(compute_summary_stats(window_high))

            window_low, _, _ = simulate_and_extract(theta_low, key_low)
            low_stats.append(compute_summary_stats(window_low))

        mean_high = jnp.mean(jnp.array(high_stats), axis=0)
        mean_low = jnp.mean(jnp.array(low_stats), axis=0)

        # Largest absolute mean differences first.
        abs_diff = jnp.abs(mean_high - mean_low)
        best = jnp.argsort(abs_diff)[::-1][:n_top]

        print("  Top 5 discriminating features:")
        for position, feat in enumerate(best, 1):
            print(f"    {position}. {OPTIMIZED_FEATURE_NAMES[feat]:<30} "
                  f"({mean_high[feat]:.3f} vs {mean_low[feat]:.3f})")

test_new_features()

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage

def analyze_summary_stat_correlations(n_samples=1000, feature_names=None):
    """
    Generate samples from the prior, correlate the summary statistics with each
    other, and report redundancy.

    Produces and saves two figures ('correlation_matrix_full.png',
    'stat_clustering.png'), prints every pair with |r| > 0.95, and greedily
    recommends one statistic to drop from each such pair.

    Args:
        n_samples: Number of prior draws / simulations.
        feature_names: Labels for the summary statistics, in column order.
            Defaults to the module-level ``FEATURE_NAMES``.

    Returns:
        Tuple ``(all_stats, corr_matrix, to_remove)``: the
        ``(n_samples, n_features)`` statistics array, the raw feature-by-feature
        Pearson correlation matrix (may contain NaN for constant statistics),
        and a sorted list of column indices recommended for removal.

    Raises:
        ValueError: If the simulated statistics do not have exactly
            ``len(feature_names)`` columns.
    """
    from collections import Counter
    from scipy.spatial.distance import squareform

    feature_names = list(FEATURE_NAMES if feature_names is None else feature_names)
    n_features = len(feature_names)

    print("Generating samples from prior...")
    rng_key = random.PRNGKey(42)
    prior_fn = create_prior()

    # Sample from prior. Split first so the key used for the prior draw is
    # never reused for the per-simulation subkeys below.
    prior_dist = prior_fn()
    rng_key, prior_key = random.split(rng_key)
    theta_samples = prior_dist.sample(n_samples, seed=prior_key)['theta']

    # Simulate and extract stats
    all_stats = []
    for i in range(n_samples):
        rng_key, subkey = random.split(rng_key)
        _, stats, _ = simulate_and_extract(theta_samples[i], subkey)
        all_stats.append(stats)

    all_stats = np.array(all_stats)  # (n_samples, n_features)

    if all_stats.ndim != 2 or all_stats.shape[1] != n_features:
        raise ValueError(
            f"Expected {n_features} summary statistics (one per feature name) "
            f"but simulations produced an array of shape {all_stats.shape}."
        )

    # Compute correlation matrix
    corr_matrix = np.corrcoef(all_stats.T)

    # === PLOT 1: Full correlation heatmap ===
    fig, ax = plt.subplots(figsize=(16, 14))
    im = ax.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')

    # Add feature names
    ax.set_xticks(range(n_features))
    ax.set_yticks(range(n_features))
    ax.set_xticklabels(feature_names, rotation=90, ha='right', fontsize=8)
    ax.set_yticklabels(feature_names, fontsize=8)

    # Add colorbar
    plt.colorbar(im, ax=ax, label='Correlation')

    # Add grid lines between groups.
    # NOTE(review): these boundaries presumably describe the 37-feature layout;
    # boundaries beyond the current feature count are skipped so they do not
    # stretch the axes.
    group_boundaries = [0, 7, 11, 16, 20, 23, 26, 37]
    for boundary in group_boundaries:
        if boundary > n_features:
            continue
        ax.axhline(boundary - 0.5, color='black', linewidth=1.5)
        ax.axvline(boundary - 0.5, color='black', linewidth=1.5)

    plt.title('Summary Statistics Correlation Matrix', fontsize=14, pad=20)
    plt.tight_layout()
    plt.savefig('correlation_matrix_full.png', dpi=150, bbox_inches='tight')
    plt.show()

    # === PLOT 2: Hierarchical clustering to find redundant groups ===
    fig, ax = plt.subplots(figsize=(12, 8))

    # Convert correlation to distance. An undefined (NaN) correlation, e.g. for
    # a constant statistic, is treated as "uncorrelated" (distance 1).
    abs_corr = np.abs(np.nan_to_num(corr_matrix, nan=0.0))
    distance_matrix = np.clip(1 - abs_corr, 0.0, None)
    distance_matrix = (distance_matrix + distance_matrix.T) / 2  # remove float asymmetry
    np.fill_diagonal(distance_matrix, 0.0)

    # `linkage` interprets a 2-D array as observation vectors, so a square
    # distance matrix must be condensed first. Ward linkage is only defined for
    # Euclidean distances; average linkage is valid for 1 - |r|.
    condensed_distances = squareform(distance_matrix, checks=False)
    linkage_matrix = linkage(condensed_distances, method='average')

    dendrogram(linkage_matrix, labels=feature_names, ax=ax,
               orientation='right', leaf_font_size=8)
    ax.set_xlabel('Distance (1 - |correlation|)')
    ax.set_title('Hierarchical Clustering of Summary Statistics')
    plt.tight_layout()
    plt.savefig('stat_clustering.png', dpi=150, bbox_inches='tight')
    plt.show()

    # === IDENTIFY HIGHLY CORRELATED PAIRS ===
    print("\n" + "="*70)
    print("HIGHLY CORRELATED PAIRS (|r| > 0.95)")
    print("="*70)

    # NaN entries compare False and are therefore never reported.
    high_corr_pairs = [
        (i, j, corr_matrix[i, j])
        for i in range(n_features)
        for j in range(i + 1, n_features)
        if abs(corr_matrix[i, j]) > 0.95
    ]
    high_corr_pairs.sort(key=lambda pair: abs(pair[2]), reverse=True)

    for i, j, r in high_corr_pairs:
        print(f"{feature_names[i]:<30} <-> {feature_names[j]:<30}  r = {r:>6.3f}")

    print(f"\nTotal highly correlated pairs: {len(high_corr_pairs)}")

    # === RECOMMEND REMOVALS ===
    print("\n" + "="*70)
    print("RECOMMENDED REMOVALS (keep one from each pair)")
    print("="*70)

    # Number of high-correlation pairs each statistic takes part in.
    pair_counts = Counter(idx for i, j, _ in high_corr_pairs for idx in (i, j))

    to_remove = set()
    for i, j, r in high_corr_pairs:
        if i not in to_remove and j not in to_remove:
            # Keep the one that appears in fewer pairs
            if pair_counts[j] > pair_counts[i]:
                to_remove.add(j)
                print(f"Remove: {feature_names[j]:<30} (redundant with {feature_names[i]})")
            else:
                to_remove.add(i)
                print(f"Remove: {feature_names[i]:<30} (redundant with {feature_names[j]})")

    print(f"\nRecommended to remove: {len(to_remove)} stats")
    print(f"Would reduce from {n_features} -> {n_features - len(to_remove)} stats")

    # Sorted so the returned list has a reproducible order.
    return all_stats, corr_matrix, sorted(to_remove)

# Run the analysis
all_stats, corr_matrix, to_remove = analyze_summary_stat_correlations(n_samples=1000)

In [None]:
def parameter_sensitivity_analysis(n_repeats=50, feature_names=None):
    """
    For each parameter, vary it while holding others constant.
    See which summary stats are most sensitive.

    Sensitivity is the absolute Spearman correlation between the swept
    parameter value and the mean of each statistic over ``n_repeats``
    simulations. A heatmap is saved to 'parameter_sensitivity.png' and shown;
    the top-5 statistics per parameter and the "weak" statistics (max
    sensitivity < 0.3) are printed.

    Args:
        n_repeats: Simulations averaged per parameter value to reduce noise.
        feature_names: Labels for the summary statistics, in column order.
            Defaults to the module-level ``FEATURE_NAMES``.

    Returns:
        Tuple ``(results, sensitivity_matrix)``. ``results`` maps each parameter
        name to ``{'values': swept values, 'stats': (n_values, n_features)
        array of mean statistics}``; ``sensitivity_matrix`` has shape
        ``(n_params, n_features)``.

    Raises:
        ValueError: If a simulated statistics vector does not have exactly
            ``len(feature_names)`` entries.
    """
    from scipy.stats import spearmanr

    feature_names = list(FEATURE_NAMES if feature_names is None else feature_names)
    n_features = len(feature_names)

    rng_key = random.PRNGKey(42)

    # Base parameter values (presumably the middle of the prior; confirm
    # against create_prior).
    base_theta = jnp.array([1.0, 1.0, 1.0, 0.275])

    param_names = ["drift_rate", "reward_bump", "failure_bump", "noise_std"]
    param_ranges = {
        'drift_rate': jnp.linspace(0.0, 2.0, 15),
        'reward_bump': jnp.linspace(0.0, 2.0, 15),
        'failure_bump': jnp.linspace(0.0, 2.0, 15),
        'noise_std': jnp.linspace(0.05, 0.5, 15),
    }
    n_params = len(param_names)

    results = {}

    # Iterate over param_names so the theta index is tied to the name
    # explicitly rather than to the dict's insertion order.
    for param_idx, param_name in enumerate(param_names):
        param_values = param_ranges[param_name]
        print(f"\nAnalyzing sensitivity to {param_name}...")

        stats_for_param = []

        for param_val in param_values:
            # Create theta with one parameter varied
            theta = base_theta.at[param_idx].set(param_val)

            # Average over n_repeats simulations to reduce noise
            stats_samples = []
            for _ in range(n_repeats):
                rng_key, subkey = random.split(rng_key)
                _, stats, _ = simulate_and_extract(theta, subkey)
                stats_samples.append(stats)

            mean_stats = np.mean(np.array(stats_samples), axis=0)
            if mean_stats.shape != (n_features,):
                raise ValueError(
                    f"Expected {n_features} summary statistics (one per feature "
                    f"name) but a simulation produced shape {mean_stats.shape}."
                )
            stats_for_param.append(mean_stats)

        results[param_name] = {
            'values': param_values,
            'stats': np.array(stats_for_param)  # (n_values, n_features)
        }

    # === COMPUTE SENSITIVITY SCORES ===
    sensitivity_matrix = np.zeros((n_params, n_features))

    for param_idx, param_name in enumerate(param_names):
        param_values = np.asarray(results[param_name]['values'])
        param_stats = results[param_name]['stats']

        for stat_idx in range(n_features):
            # Spearman correlation between parameter value and stat value
            corr, _ = spearmanr(param_values, param_stats[:, stat_idx])
            # A statistic that never changes gives NaN. Score it as 0 so it is
            # not sorted to the top of the ranking and is reported as weak.
            sensitivity_matrix[param_idx, stat_idx] = 0.0 if np.isnan(corr) else abs(corr)

    # === PLOT SENSITIVITY HEATMAP ===
    fig, ax = plt.subplots(figsize=(16, 6))
    im = ax.imshow(sensitivity_matrix, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)

    ax.set_yticks(range(n_params))
    ax.set_yticklabels(param_names, fontsize=10)
    ax.set_xticks(range(n_features))
    ax.set_xticklabels(feature_names, rotation=90, ha='right', fontsize=8)

    plt.colorbar(im, ax=ax, label='|Spearman correlation|')
    plt.title('Parameter Sensitivity: Which stats respond to which parameters?', fontsize=14)
    plt.tight_layout()
    plt.savefig('parameter_sensitivity.png', dpi=150, bbox_inches='tight')
    plt.show()

    # === IDENTIFY MOST SENSITIVE STATS PER PARAMETER ===
    print("\n" + "="*70)
    print("TOP 5 MOST SENSITIVE STATS FOR EACH PARAMETER")
    print("="*70)

    for param_idx, param_name in enumerate(param_names):
        print(f"\n{param_name.upper()}:")

        # Get top 5 stats
        top_indices = np.argsort(sensitivity_matrix[param_idx])[::-1][:5]

        for rank, stat_idx in enumerate(top_indices, 1):
            corr = sensitivity_matrix[param_idx, stat_idx]
            print(f"  {rank}. {feature_names[stat_idx]:<30} (|r| = {corr:.3f})")

    # === IDENTIFY WEAK STATS (not sensitive to anything) ===
    print("\n" + "="*70)
    print("WEAK STATS (max sensitivity < 0.3 to any parameter)")
    print("="*70)

    max_sensitivity = sensitivity_matrix.max(axis=0)
    weak_stats = np.where(max_sensitivity < 0.3)[0]

    for stat_idx in weak_stats:
        print(f"{feature_names[stat_idx]:<30} (max |r| = {max_sensitivity[stat_idx]:.3f})")

    return results, sensitivity_matrix

# Run sensitivity analysis
sensitivity_results, sensitivity_matrix = parameter_sensitivity_analysis()

In [None]:
# === Sweep Setup ===
# Parameter order used for every theta vector in the sweeps.
theta_labels = ["drift_rate", "reward_bump", "failure_bump", "noise_std"]
# Baseline theta; parameters not being swept keep these values.
theta_base = jnp.asarray([0.4, 0.3, 0.1, 0.1])

# Simulations averaged for each parameter combination.
n_repeats = 20
# Five levels for the secondary ("gradient") parameter, ten for the x-axis parameter.
gradient_values = np.linspace(0.01, 1.0, num=5)
x_values = np.linspace(0.01, 1.0, num=10)

# Result containers: swept-parameter label -> {gradient key -> array}.
# Each label gets its own inner dict.
results_mean = {}
results_sem = {}
for label in theta_labels:
    results_mean[label] = {}
    results_sem[label] = {}

n_total_simulations = len(theta_labels) * len(gradient_values) * len(x_values) * n_repeats

print(f"Will run {len(theta_labels)} parameter sweeps")
print(f"Each sweep: {len(gradient_values)} gradients × {len(x_values)} x-values × {n_repeats} repeats")
print(f"Total simulations: {n_total_simulations}")

In [None]:
# === Run Parameter Sweeps ===
# For each x-axis parameter, sweep it over `x_values` at several levels of one
# other ("gradient") parameter and record the mean and SEM of the summary
# statistics over `n_repeats` simulations.
import time

start_time = time.time()

for param_idx, param_x in enumerate(theta_labels):
    print(f"\nSweeping {param_x} ({param_idx+1}/{len(theta_labels)})...")

    # The first parameter other than the swept one acts as the gradient parameter.
    other_params = [label for pos, label in enumerate(theta_labels) if pos != param_idx]
    gradient_param_idx = theta_labels.index(other_params[0])

    for grad_val in gradient_values:
        mean_list = []
        sem_list = []

        for x_val in x_values:
            # Start from the base theta, override the x-axis and gradient
            # entries; `.at[...].set` returns a new array, so theta_base is
            # left untouched.
            theta = (
                theta_base
                .at[param_idx].set(x_val)
                .at[gradient_param_idx].set(grad_val)
            )

            # Repeated simulations; the returned key replaces the global
            # rng_key on every call.
            runs = []
            for _ in range(n_repeats):
                _, summary, rng_key = simulate_and_extract(theta, rng_key)
                runs.append(np.array(summary))

            runs = np.vstack(runs)
            mean_list.append(np.mean(runs, axis=0))
            sem_list.append(np.std(runs, axis=0, ddof=1) / np.sqrt(n_repeats))

        # Results are keyed by the gradient value formatted to 4 decimals.
        key = f"{grad_val:.4f}"
        results_mean[param_x][key] = np.vstack(mean_list)
        results_sem[param_x][key] = np.vstack(sem_list)

    # Time elapsed since the start of the first sweep (cumulative).
    elapsed = time.time() - start_time
    print(f"  Completed in {elapsed:.1f}s")

total_time = time.time() - start_time
print(f"\n✓ All sweeps completed in {total_time/60:.1f} minutes")

## Visualization 1: Key Features by Parameter

Focus on the most important features for each parameter:
- **drift_rate**: Basic time statistics, temporal trends
- **reward_bump**: Reward history effects (mean_time_after_reward)
- **failure_bump**: Reward history effects (mean_time_after_failure)
- **noise_std**: Distribution shape, sequential dependencies

In [None]:
# === Visualization 1: Most Important Features ===
# One figure per swept parameter with one panel per hand-picked feature.
# Each panel shows the sweep mean per gradient level with a ±1.96·SEM band.

# Feature indices (into FEATURE_NAMES) to show for each parameter.
# NOTE(review): the trailing names are the original author's annotations and
# are not checked here — confirm them against FEATURE_NAMES.
KEY_FEATURES = {
    "drift_rate": [1, 2, 11, 12, 13],  # mean_time, std_time, early_mean, late_mean, trend
    "reward_bump": [7, 8, 9, 14, 23],  # After reward/failure, late-early, reward_rate
    "failure_bump": [7, 8, 9, 10, 14], # After reward/failure stats
    "noise_std": [2, 16, 18, 19, 20],  # std_time, p25, p75, iqr, autocorr
}

# One RGBA colour per gradient level (identical for every figure).
cmap = plt.cm.viridis(np.linspace(0, 1, len(gradient_values)))

for param_x in theta_labels:
    # The gradient parameter is the first label that is not the x-axis one.
    other_params = [p for p in theta_labels if p != param_x]
    gradient_param = other_params[0]

    feature_indices = KEY_FEATURES[param_x]
    n_panels = len(feature_indices)

    # squeeze=False always yields a 2-D axes array, so a single-panel figure
    # needs no special case; take the only row.
    fig, axes_grid = plt.subplots(1, n_panels, figsize=(4*n_panels, 4), squeeze=False)
    axes = axes_grid[0]

    fig.suptitle(f"Effect of {param_x} (gradient: {gradient_param})", fontsize=14, fontweight='bold')

    for ax, feature_idx in zip(axes, feature_indices):
        for line_color, grad_val in zip(cmap, gradient_values):
            key = f"{grad_val:.4f}"
            mean_vals = results_mean[param_x][key][:, feature_idx]
            sem_vals = results_sem[param_x][key][:, feature_idx]
            band = 1.96 * sem_vals

            ax.plot(
                x_values, mean_vals,
                color=line_color, marker='o', markersize=4, linewidth=2,
                label=f"{grad_val:.2f}",
            )
            ax.fill_between(
                x_values, mean_vals - band, mean_vals + band,
                color=line_color, alpha=0.2,
            )

        ax.set_title(FEATURE_NAMES[feature_idx], fontsize=11, fontweight='bold')
        ax.set_xlabel(param_x, fontsize=10)
        ax.set_ylabel('Value', fontsize=10)
        ax.grid(True, linestyle='--', alpha=0.3)

    # Shared legend built from the first panel's lines.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles, labels,
        loc='center left', bbox_to_anchor=(1, 0.5),
        title=gradient_param, fontsize=9,
    )

    plt.tight_layout(rect=[0, 0, 0.95, 0.96])
    plt.show()

## Visualization 2: Feature Groups

Visualize all features organized by functional groups.

In [None]:
# === Visualization 2: All Features by Group ===
# One large figure per swept parameter. Each entry of FEATURE_GROUPS gets one
# grid row: the first cell holds the group name, the following cells hold one
# panel per feature (sweep mean per gradient level with a ±1.96·SEM band).

# A single grid geometry is used for every subplot, including the legend.
# The legend was previously placed with a 7x7 grid while the panels used 8x7,
# so the legend cell straddled panel cells (and older Matplotlib versions
# delete axes that a new `plt.subplot` overlaps).
GRID_ROWS, GRID_COLS = 8, 7

for param_x in theta_labels:
    # The gradient parameter is the first label that is not the x-axis one.
    other_params = [p for p in theta_labels if p != param_x]
    gradient_param = other_params[0]

    # Create one large figure with subplots for each feature group
    fig = plt.figure(figsize=(20, 12))
    fig.suptitle(f"All Features: {param_x} (gradient: {gradient_param})",
                fontsize=16, fontweight='bold', y=0.995)

    cmap = plt.cm.viridis(np.linspace(0, 1, len(gradient_values)))

    # NOTE(review): assumes at most GRID_ROWS groups and at most GRID_COLS - 1
    # features per group; larger groups spill into the next row's cells.
    for group_row, (group_name, feature_indices) in enumerate(FEATURE_GROUPS.items()):
        n_features = len(feature_indices)  # not used below; kept as a notebook-level name
        row_start = group_row * GRID_COLS

        # Group title occupies the first cell of the row.
        ax_title = plt.subplot(GRID_ROWS, GRID_COLS, row_start + 1)
        ax_title.text(0.5, 0.5, group_name, fontsize=12, fontweight='bold',
                     ha='center', va='center')
        ax_title.axis('off')

        # Feature panels fill the cells after the title.
        for i, feature_idx in enumerate(feature_indices):
            ax = plt.subplot(GRID_ROWS, GRID_COLS, row_start + i + 2)
            feature_name = FEATURE_NAMES[feature_idx]

            for color_idx, grad_val in enumerate(gradient_values):
                key = f"{grad_val:.4f}"
                mean_vals = results_mean[param_x][key][:, feature_idx]
                sem_vals = results_sem[param_x][key][:, feature_idx]

                ax.plot(x_values, mean_vals, color=cmap[color_idx],
                       marker='o', markersize=3, linewidth=1.5, alpha=0.8)
                ax.fill_between(x_values,
                              mean_vals - 1.96 * sem_vals,
                              mean_vals + 1.96 * sem_vals,
                              color=cmap[color_idx], alpha=0.15)

            ax.set_title(feature_name, fontsize=8)
            ax.tick_params(labelsize=7)
            ax.grid(True, linestyle='--', alpha=0.3)

    # Legend lives in the bottom-right cell of the same grid, drawn from
    # empty proxy lines so it shows only the gradient colours.
    # NOTE(review): if the last group fills its whole row, this cell is shared
    # with that group's final panel — confirm against FEATURE_GROUPS.
    legend_ax = plt.subplot(GRID_ROWS, GRID_COLS, GRID_ROWS * GRID_COLS)
    for color_idx, grad_val in enumerate(gradient_values):
        legend_ax.plot([], [], color=cmap[color_idx], linewidth=3,
                      label=f"{grad_val:.2f}")
    legend_ax.legend(title=gradient_param, loc='center', fontsize=8)
    legend_ax.axis('off')

    plt.tight_layout(rect=[0, 0, 1, 0.99])
    plt.show()

## Visualization 3: Reward History Effects

Deep dive into the critical reward history features that distinguish reward_bump from failure_bump.

In [None]:
# === Visualization 3: Reward History Deep Dive ===
# Grid of panels: one row per bump parameter, one column per reward-history
# feature. Each panel shows the sweep mean per gradient level with a
# ±1.96·SEM band.

# Feature indices for reward history
# NOTE(review): the original annotation reads "After reward mean/std, after
# failure mean/std", while the printed insight below calls feature 8
# 'mean_time_after_failure' — confirm the ordering against FEATURE_NAMES.
reward_features = [7, 8, 9, 10]
params_to_show = ['reward_bump', 'failure_bump']

# Size the grid from the lists above. The grid was previously a fixed 2x2,
# which raised IndexError as soon as the third feature column was indexed.
# squeeze=False keeps `axes` 2-D even if either list has a single entry.
fig, axes = plt.subplots(
    len(params_to_show), len(reward_features),
    figsize=(5 * len(reward_features), 5 * len(params_to_show)),
    squeeze=False,
)
fig.suptitle('Reward History Effects: Critical Features for Bump Parameters',
            fontsize=14, fontweight='bold')

for param_idx, param_x in enumerate(params_to_show):
    # The gradient parameter is the first label that is not the x-axis one.
    other_params = [p for p in theta_labels if p != param_x]
    gradient_param = other_params[0]

    cmap = plt.cm.viridis(np.linspace(0, 1, len(gradient_values)))

    for feat_idx, feature_idx in enumerate(reward_features):
        ax = axes[param_idx, feat_idx]
        feature_name = FEATURE_NAMES[feature_idx]

        for color_idx, grad_val in enumerate(gradient_values):
            key = f"{grad_val:.4f}"
            mean_vals = results_mean[param_x][key][:, feature_idx]
            sem_vals = results_sem[param_x][key][:, feature_idx]

            ax.plot(x_values, mean_vals, color=cmap[color_idx],
                   marker='o', markersize=5, linewidth=2.5,
                   label=f"{gradient_param}={grad_val:.2f}")
            ax.fill_between(x_values,
                          mean_vals - 1.96 * sem_vals,
                          mean_vals + 1.96 * sem_vals,
                          color=cmap[color_idx], alpha=0.2)

        ax.set_title(f"{param_x} → {feature_name}", fontsize=11, fontweight='bold')
        ax.set_xlabel(param_x, fontsize=10)
        ax.set_ylabel('Time (s)', fontsize=10)
        ax.grid(True, linestyle='--', alpha=0.3)

        # One legend per row (first column only) to avoid repetition.
        if feat_idx == 0:
            ax.legend(fontsize=8, loc='best')

plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()

print("\nKey Insight:")
print("  • reward_bump should primarily affect 'mean_time_after_reward' (feature 7)")
print("  • failure_bump should primarily affect 'mean_time_after_failure' (feature 8)")
print("  • These features enable parameter identifiability!")

## Visualization 4: Feature Sensitivity Heatmap

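Show which features are most sensitive to each parameter. Sensitivity is measured as the range (max − min) of a feature's mean across the sweep, divided by its absolute mean, using the sweep at the middle gradient level.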

In [None]:
# === Visualization 4: Sensitivity Heatmap ===
# Sensitivity of a feature to a parameter is measured here as the spread
# (max - min) of the feature's sweep mean across the x-axis values, divided by
# the absolute average of those means (+1e-6 to avoid division by zero).
# Only the sweep at the middle gradient level is used.
# NOTE(review): this rebinds `sensitivity_matrix`, replacing the matrix
# returned earlier by parameter_sensitivity_analysis().

# Sweep key of the middle gradient level.
middle_gradient_idx = len(gradient_values) // 2
middle_gradient_val = gradient_values[middle_gradient_idx]
key = f"{middle_gradient_val:.4f}"

# Rows = parameters, columns = features.
sensitivity_matrix = np.zeros((len(theta_labels), len(FEATURE_NAMES)))

for param_idx, param_x in enumerate(theta_labels):
    means = results_mean[param_x][key]
    for feat_idx in range(len(FEATURE_NAMES)):
        feat_values = means[:, feat_idx]
        spread = feat_values.max() - feat_values.min()
        scale = np.abs(feat_values.mean()) + 1e-6
        # Relative sensitivity: spread normalized by the mean magnitude.
        sensitivity_matrix[param_idx, feat_idx] = spread / scale

# --- Heatmap ---
fig, ax = plt.subplots(figsize=(16, 6))
im = ax.imshow(
    sensitivity_matrix,
    aspect='auto',
    cmap='YlOrRd',
    interpolation='nearest',
)

feature_ticks = np.arange(len(FEATURE_NAMES))
param_ticks = np.arange(len(theta_labels))

# Major ticks carry the labels ...
ax.set_xticks(feature_ticks)
ax.set_yticks(param_ticks)
ax.set_xticklabels(FEATURE_NAMES, rotation=90, fontsize=8)
ax.set_yticklabels(theta_labels, fontsize=10)

cbar = fig.colorbar(im, ax=ax)
cbar.set_label('Relative Sensitivity (Range/Mean)', rotation=270, labelpad=20)

ax.set_title(
    'Feature Sensitivity to Parameters (Higher = More Informative)',
    fontsize=12, fontweight='bold', pad=20,
)

# ... minor ticks sit on the cell edges and carry the white grid lines.
ax.set_xticks(feature_ticks - 0.5, minor=True)
ax.set_yticks(param_ticks - 0.5, minor=True)
ax.grid(which='minor', color='white', linestyle='-', linewidth=1)

plt.tight_layout()
plt.show()

# --- Text summary: five highest-sensitivity features per parameter ---
print("\nTop 5 Most Sensitive Features per Parameter:")
print("="*70)
for param_idx, param_x in enumerate(theta_labels):
    # Descending order, then keep the first five.
    top_indices = np.argsort(sensitivity_matrix[param_idx])[::-1][:5]
    print(f"\n{param_x}:")
    for rank, idx in enumerate(top_indices, 1):
        sens = sensitivity_matrix[param_idx, idx]
        print(f"  {rank}. {FEATURE_NAMES[idx]:30s} (sensitivity: {sens:.3f})")

## Visualization 5: Pairwise Feature Relationships

Examine pairwise correlations between all features, computed over the sweep means, to identify redundant features.

In [None]:
# === Visualization 5: Feature Correlation Analysis ===
# Stack the sweep means of every (parameter, gradient level) combination into
# one (rows x features) table, correlate the feature columns, draw the matrix
# as a heatmap, and list feature pairs with |r| > 0.9.

# One block of rows (all x-values) per parameter and gradient level.
all_features = np.vstack([
    sweep_means
    for param_x in theta_labels
    for sweep_means in results_mean[param_x].values()
])

# Features are columns, so transpose for np.corrcoef (which correlates rows).
corr_matrix = np.corrcoef(all_features.T)

# --- Heatmap ---
fig, ax = plt.subplots(figsize=(18, 16))
im = ax.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')

feature_ticks = np.arange(len(FEATURE_NAMES))
ax.set_xticks(feature_ticks)
ax.set_yticks(feature_ticks)
ax.set_xticklabels(FEATURE_NAMES, rotation=90, fontsize=9)
ax.set_yticklabels(FEATURE_NAMES, fontsize=9)

cbar = fig.colorbar(im, ax=ax)
cbar.set_label('Correlation', rotation=270, labelpad=20)

ax.set_title(
    'Feature Correlation Matrix (All Parameters)',
    fontsize=14, fontweight='bold', pad=20,
)

# Separator lines at the cumulative group sizes.
# NOTE(review): this assumes the groups in FEATURE_GROUPS list consecutive
# feature indices in FEATURE_NAMES order — confirm.
group_sizes = [len(group_indices) for group_indices in FEATURE_GROUPS.values()]
group_boundaries = [0, *np.cumsum(group_sizes).tolist()]

# Skip the outer edges (first and last boundary).
for boundary in group_boundaries[1:-1]:
    edge = boundary - 0.5
    ax.axhline(edge, color='black', linewidth=2)
    ax.axvline(edge, color='black', linewidth=2)

plt.tight_layout()
plt.show()

# Find highly correlated feature pairs (|corr| > 0.9), upper triangle only so
# each unordered pair appears once.
high_corr_pairs = [
    (FEATURE_NAMES[i], FEATURE_NAMES[j], corr_matrix[i, j])
    for i in range(len(FEATURE_NAMES))
    for j in range(i + 1, len(FEATURE_NAMES))
    if abs(corr_matrix[i, j]) > 0.9
]

if not high_corr_pairs:
    print("\n✓ No highly correlated features (|r| > 0.9)")
    print("  This is good - features are relatively independent!")
else:
    print("\nHighly Correlated Features (|r| > 0.9):")
    print("="*70)
    # Strongest correlations first.
    ranked_pairs = sorted(high_corr_pairs, key=lambda x: abs(x[2]), reverse=True)
    for feat1, feat2, corr in ranked_pairs:
        print(f"  {feat1:30s} <-> {feat2:30s}: r={corr:+.3f}")
    print("\nNote: Highly correlated features may be redundant for inference.")