# Enhanced Feature Visualizations (37 Features)

This notebook visualizes how each of the 37 summary statistics responds to changes in the drift-diffusion model (DDM) parameters.

**Feature Groups:**
- Basic (7): Max, mean, std of times, stops, rewards
- Reward History (4): Behavior after reward vs failure
- Temporal (5): Early/late session dynamics
- Distribution (4): Percentiles and shape
- Sequential (3): Autocorrelation and persistence
- Reward Stats (3): Detailed reward behavior
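- Patch Stats (3): Exit patterns
- Consistency (8): Variability after reward/failure, transition reliability, local variability, signal-to-noise, reward/failure effects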

In [None]:
import os
# Force CPU backend on Apple Silicon to avoid Metal issues
os.environ['JAX_PLATFORMS'] = 'cpu'

# === Setup ===
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import random
import seaborn as sns

from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.simulator import PatchForagingDDM_JAX, create_prior

# Seed NumPy's legacy global RNG so NumPy-side randomness repeats across runs
np.random.seed(1)

# Root JAX PRNG key; it is split and passed explicitly to each simulation
rng_key = random.PRNGKey(42)

# Simulator instance shared by the cells below (max_sites_per_window=100)
simulator = PatchForagingDDM_JAX(max_sites_per_window=100)

# Names of the summary statistics, in the order they appear in the summary
# vector returned by the simulator (position in this list == column index).
FEATURE_NAMES = [
    # Basic (7)
    "max_time", "mean_time", "std_time",
    "mean_stops", "std_stops", "mean_rewards", "std_rewards",

    # Reward history (4) - KEY FEATURES
    "mean_time_after_reward",
    "mean_time_after_failure",
    "std_time_after_reward",
    "std_time_after_failure",

    # Temporal (5)
    "early_mean", "late_mean", "temporal_trend",
    "late_minus_early", "middle_mean",

    # Distribution (4)
    "p25", "median", "p75", "iqr",

    # Sequential (3)
    "autocorr_lag1", "diff_std", "mean_abs_change",

    # Reward stats (3)
    "reward_rate", "mean_reward_trial", "prop_patches_with_reward",

    # Patch stats (3)
    "n_patches", "mean_sites_per_patch", "stop_rate",

    # Consistency stats (8) - distinguish systematic effects from noise
    "after_reward_cv",               # Coefficient of variation after rewards
    "after_failure_cv",              # Coefficient of variation after failures
    "transition_reliability",        # Consistency of reward→failure transitions
    "reward_effect_predictability",  # How well rewards predict next duration
    "mean_local_std",                # Average local variability
    "signal_to_noise",               # Between-context / within-context variance
    "failure_effect",                # Time change from average after a failure
    "reward_effect",                 # Time change from average after a reward
]

# Feature groups for organized visualization.
# Groups are contiguous index ranges listed in index order; the correlation
# plot relies on that ordering to draw group boundaries.
FEATURE_GROUPS = {
    "Basic": list(range(0, 7)),
    "Reward History": list(range(7, 11)),
    "Temporal": list(range(11, 16)),
    "Distribution": list(range(16, 20)),
    "Sequential": list(range(20, 23)),
    "Reward Stats": list(range(23, 26)),
    "Patch Stats": list(range(26, 29)),
    # Previously folded into "Patch Stats" by an over-long range(26, 37)
    "Consistency": list(range(29, 37)),
}

# Guard against the names list and the group ranges drifting apart again:
# every feature index must belong to exactly one group.
assert sorted(i for indices in FEATURE_GROUPS.values() for i in indices) == list(
    range(len(FEATURE_NAMES))
), "FEATURE_GROUPS must cover every FEATURE_NAMES index exactly once"

def simulate_and_extract(theta, rng_key, window_sites=100):
    """
    Simulate a single window for `theta` and return its summary statistics.

    The incoming key is split in two: one half drives this simulation, the
    other half is handed back so the caller can keep drawing fresh randomness.

    Args:
        theta: Parameter vector forwarded to `simulator.simulate_one_window`.
        rng_key: JAX PRNG key to split.
        window_sites: Currently unused; kept so existing calls keep working.

    Returns:
        (window_data, summary_stats, new_rng_key)
    """
    next_key, sim_key = random.split(rng_key)
    window_data, summary_stats = simulator.simulate_one_window(theta, sim_key)
    return window_data, summary_stats, next_key

print(f"✓ Loaded simulator with {len(FEATURE_NAMES)} features")
print(f"✓ Feature groups: {list(FEATURE_GROUPS.keys())}")

In [None]:
# === Sweep Setup ===
# Parameter names, in the order they occupy in the theta vector
theta_labels = ["drift_rate", "reward_bump", "failure_bump", "noise_std"]
# Baseline theta; entries not being varied in a sweep keep these values
theta_base = jnp.array([0.4, 0.3, 0.1, 0.1])

# Number of simulations averaged for every (x, gradient) combination
n_repeats = 20
# Levels of the second ("gradient") parameter, one coloured curve each
gradient_values = np.linspace(0.01, 1.0, num=5)
# Values of the swept parameter shown on the x-axis
x_values = np.linspace(0.01, 1.0, num=10)

# Result containers, keyed first by swept parameter; the sweep cell fills the
# inner dicts with one entry per gradient level
results_mean = dict((label, {}) for label in theta_labels)
results_sem = dict((label, {}) for label in theta_labels)

print(f"Will run {len(theta_labels)} parameter sweeps")
print(f"Each sweep: {len(gradient_values)} gradients × {len(x_values)} x-values × {n_repeats} repeats")
print(f"Total simulations: {len(theta_labels) * len(gradient_values) * len(x_values) * n_repeats}")

In [None]:
# === Run Parameter Sweeps ===
# For every parameter: sweep it over `x_values` while a second ("gradient")
# parameter steps through `gradient_values`; all remaining parameters stay at
# their `theta_base` values. Each combination is simulated `n_repeats` times and
# the per-feature mean and standard error are stored in results_mean / results_sem.
import time

total_start = time.perf_counter()

# Thread a *local* key through the sweep instead of advancing the global
# `rng_key`: the global stays at its initial value, so re-running this cell
# reproduces exactly the same simulations.
sweep_key = rng_key

for param_idx, param_x in enumerate(theta_labels):
    print(f"\nSweeping {param_x} ({param_idx+1}/{len(theta_labels)})...")
    sweep_start = time.perf_counter()

    # The first parameter other than the swept one acts as the gradient parameter
    other_params = [p for j, p in enumerate(theta_labels) if j != param_idx]
    gradient_param_idx = theta_labels.index(other_params[0])

    for grad_val in gradient_values:
        mean_list, sem_list = [], []

        for x_val in x_values:
            # JAX arrays are immutable, so .at[].set() already returns a new
            # array and theta_base is never modified.
            theta = theta_base.at[param_idx].set(x_val)           # X-axis parameter
            theta = theta.at[gradient_param_idx].set(grad_val)    # Gradient parameter

            # Repeat the simulation and stack one summary vector per run
            runs = []
            for _ in range(n_repeats):
                _, summary, sweep_key = simulate_and_extract(theta, sweep_key)
                runs.append(np.array(summary))
            runs = np.vstack(runs)

            mean_list.append(runs.mean(axis=0))
            # Standard error of the mean (sample std, ddof=1)
            sem_list.append(runs.std(axis=0, ddof=1) / np.sqrt(n_repeats))

        # Rows follow x_values, columns follow the summary vector.
        # The key format must match the lookups in the plotting cells.
        key = f"{grad_val:.4f}"
        results_mean[param_x][key] = np.vstack(mean_list)
        results_sem[param_x][key] = np.vstack(sem_list)

    # Time for this sweep only (not cumulative since the first sweep)
    elapsed = time.perf_counter() - sweep_start
    print(f"  Completed in {elapsed:.1f}s")

total_time = time.perf_counter() - total_start
print(f"\n✓ All sweeps completed in {total_time/60:.1f} minutes")

## Visualization 1: Key Features by Parameter

Focus on the most important features for each parameter:
- **drift_rate**: Basic time statistics, temporal trends
- **reward_bump**: Reward history effects (mean_time_after_reward)
- **failure_bump**: Reward history effects (mean_time_after_failure)
- **noise_std**: Distribution shape, sequential dependencies

In [None]:
# === Visualization 1: Most Important Features ===

# Key features to visualize for each parameter, given by name so the selection
# cannot silently drift if the feature order changes.
# NOTE: failure_bump previously started with index 19 ("iqr"), which did not
# match its "after reward/failure stats" description; it now starts with
# mean_time_after_reward so the panel shows all four reward-history features.
KEY_FEATURE_NAMES = {
    "drift_rate": ["mean_time", "std_time", "early_mean", "late_mean", "temporal_trend"],
    "reward_bump": ["mean_time_after_reward", "mean_time_after_failure",
                    "std_time_after_reward", "late_minus_early", "reward_rate"],
    "failure_bump": ["mean_time_after_reward", "mean_time_after_failure",
                     "std_time_after_reward", "std_time_after_failure", "late_minus_early"],
    "noise_std": ["std_time", "p25", "p75", "iqr", "autocorr_lag1"],
}
# Same shape as before: parameter name -> list of column indices
KEY_FEATURES = {
    param: [FEATURE_NAMES.index(name) for name in names]
    for param, names in KEY_FEATURE_NAMES.items()
}

for param_x in theta_labels:
    # The gradient parameter is the first parameter other than the swept one
    # (same rule the sweep cell used when it generated the data).
    gradient_param = next(p for p in theta_labels if p != param_x)

    feature_indices = KEY_FEATURES[param_x]

    fig, axes = plt.subplots(1, len(feature_indices), figsize=(4*len(feature_indices), 4))
    # plt.subplots returns a bare Axes for a single panel; normalise to 1-D
    axes = np.atleast_1d(axes)

    fig.suptitle(f"Effect of {param_x} (gradient: {gradient_param})", fontsize=14, fontweight='bold')

    # One colour per gradient level
    cmap = plt.cm.viridis(np.linspace(0, 1, len(gradient_values)))

    for ax, feature_idx in zip(axes, feature_indices):
        for color_idx, grad_val in enumerate(gradient_values):
            key = f"{grad_val:.4f}"
            mean_vals = results_mean[param_x][key][:, feature_idx]
            sem_vals = results_sem[param_x][key][:, feature_idx]

            ax.plot(x_values, mean_vals, color=cmap[color_idx],
                    marker='o', markersize=4, linewidth=2,
                    label=f"{grad_val:.2f}")
            # Band of mean ± 1.96 × SEM
            ax.fill_between(x_values,
                            mean_vals - 1.96 * sem_vals,
                            mean_vals + 1.96 * sem_vals,
                            color=cmap[color_idx], alpha=0.2)

        ax.set_title(FEATURE_NAMES[feature_idx], fontsize=11, fontweight='bold')
        ax.set_xlabel(param_x, fontsize=10)
        ax.set_ylabel('Value', fontsize=10)
        ax.grid(True, linestyle='--', alpha=0.3)

    # Single shared legend, built from the first panel's lines
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='center left', bbox_to_anchor=(1, 0.5),
               title=gradient_param, fontsize=9)

    plt.tight_layout(rect=[0, 0, 0.95, 0.96])
    plt.show()

## Visualization 2: Feature Groups

Visualize all features organized by functional groups.

In [None]:
# === Visualization 2: All Features by Group ===

# Grid geometry is derived from FEATURE_GROUPS rather than hard-coded:
# one row per group; one leading column for the group label plus enough
# columns for the largest group. (A fixed 8x7 grid could not fit a label and
# 7 features on one row, so features spilled onto the next group's label cell.)
n_group_rows = len(FEATURE_GROUPS)
n_grid_cols = 1 + max(len(indices) for indices in FEATURE_GROUPS.values())

for param_x in theta_labels:
    # The gradient parameter is the first parameter other than the swept one
    gradient_param = next(p for p in theta_labels if p != param_x)

    fig = plt.figure(figsize=(2.2 * n_grid_cols, 1.7 * n_group_rows))
    fig.suptitle(f"All Features: {param_x} (gradient: {gradient_param})",
                 fontsize=16, fontweight='bold', y=0.995)

    # One colour per gradient level
    cmap = plt.cm.viridis(np.linspace(0, 1, len(gradient_values)))
    legend_handles = []  # one line per gradient level, taken from the first panel

    for group_row, (group_name, feature_indices) in enumerate(FEATURE_GROUPS.items()):
        row_offset = group_row * n_grid_cols  # cells preceding this row

        # First cell of the row carries the group label only
        ax_title = fig.add_subplot(n_group_rows, n_grid_cols, row_offset + 1)
        ax_title.text(0.5, 0.5, group_name, fontsize=12, fontweight='bold',
                      ha='center', va='center')
        ax_title.axis('off')

        # Remaining cells: one panel per feature in the group
        for col, feature_idx in enumerate(feature_indices):
            ax = fig.add_subplot(n_group_rows, n_grid_cols, row_offset + col + 2)

            for color_idx, grad_val in enumerate(gradient_values):
                key = f"{grad_val:.4f}"
                mean_vals = results_mean[param_x][key][:, feature_idx]
                sem_vals = results_sem[param_x][key][:, feature_idx]

                (line,) = ax.plot(x_values, mean_vals, color=cmap[color_idx],
                                  marker='o', markersize=3, linewidth=1.5, alpha=0.8)
                if len(legend_handles) < len(gradient_values):
                    legend_handles.append(line)
                # Band of mean ± 1.96 × SEM
                ax.fill_between(x_values,
                                mean_vals - 1.96 * sem_vals,
                                mean_vals + 1.96 * sem_vals,
                                color=cmap[color_idx], alpha=0.15)

            ax.set_title(FEATURE_NAMES[feature_idx], fontsize=8)
            ax.tick_params(labelsize=7)
            ax.grid(True, linestyle='--', alpha=0.3)

    # Figure-level legend in the right margin, so it never overlaps a plot cell
    fig.legend(legend_handles, [f"{grad_val:.2f}" for grad_val in gradient_values],
               title=gradient_param, loc='center right', fontsize=8)

    # Reserve the right margin for the legend and the top for the suptitle
    plt.tight_layout(rect=[0, 0, 0.95, 0.98])
    plt.show()

## Visualization 3: Reward History Effects

Deep dive into the critical reward history features that distinguish reward_bump from failure_bump.

In [None]:
# === Visualization 3: Reward History Deep Dive ===

# Columns: the four reward-history features
# (7: mean after reward, 8: mean after failure, 9: std after reward, 10: std after failure)
reward_features = [7, 8, 9, 10]
# Rows: the two bump parameters
params_to_show = ['reward_bump', 'failure_bump']

# The grid must have one column per feature; a fixed 2x2 grid raised an
# IndexError as soon as the third feature was plotted. squeeze=False keeps
# `axes` two-dimensional for any row/column count.
fig, axes = plt.subplots(len(params_to_show), len(reward_features),
                         figsize=(5 * len(reward_features), 5 * len(params_to_show)),
                         squeeze=False)
fig.suptitle('Reward History Effects: Critical Features for Bump Parameters',
             fontsize=14, fontweight='bold')

# One colour per gradient level
cmap = plt.cm.viridis(np.linspace(0, 1, len(gradient_values)))

for row_idx, param_x in enumerate(params_to_show):
    # The gradient parameter is the first parameter other than the swept one
    gradient_param = next(p for p in theta_labels if p != param_x)

    for feat_idx, feature_idx in enumerate(reward_features):
        ax = axes[row_idx, feat_idx]
        feature_name = FEATURE_NAMES[feature_idx]

        for color_idx, grad_val in enumerate(gradient_values):
            key = f"{grad_val:.4f}"
            mean_vals = results_mean[param_x][key][:, feature_idx]
            sem_vals = results_sem[param_x][key][:, feature_idx]

            ax.plot(x_values, mean_vals, color=cmap[color_idx],
                    marker='o', markersize=5, linewidth=2.5,
                    label=f"{gradient_param}={grad_val:.2f}")
            # Band of mean ± 1.96 × SEM
            ax.fill_between(x_values,
                            mean_vals - 1.96 * sem_vals,
                            mean_vals + 1.96 * sem_vals,
                            color=cmap[color_idx], alpha=0.2)

        ax.set_title(f"{param_x} → {feature_name}", fontsize=11, fontweight='bold')
        ax.set_xlabel(param_x, fontsize=10)
        ax.set_ylabel('Time (s)', fontsize=10)
        ax.grid(True, linestyle='--', alpha=0.3)

        # One legend per row is enough: every panel in a row shares the same curves
        if feat_idx == 0:
            ax.legend(fontsize=8, loc='best')

plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()

print("\nKey Insight:")
print("  • reward_bump should primarily affect 'mean_time_after_reward' (feature 7)")
print("  • failure_bump should primarily affect 'mean_time_after_failure' (feature 8)")
print("  • These features enable parameter identifiability!")

## Visualization 4: Feature Sensitivity Heatmap

Show which features are most sensitive to each parameter.

In [None]:
# === Visualization 4: Sensitivity Heatmap ===

# Sensitivity of a feature to a parameter is measured on the feature's mean
# curve along that parameter's sweep, taken at the middle gradient level:
#     (max - min) / (|mean| + 1e-6)
middle_gradient_idx = len(gradient_values) // 2
middle_gradient_val = gradient_values[middle_gradient_idx]
key = f"{middle_gradient_val:.4f}"

# Rows: parameters, columns: features
sensitivity_matrix = np.zeros((len(theta_labels), len(FEATURE_NAMES)))
for row, param_x in enumerate(theta_labels):
    sweep_means = results_mean[param_x][key]
    for col in range(len(FEATURE_NAMES)):
        curve = sweep_means[:, col]
        spread = np.max(curve) - np.min(curve)
        # The small constant keeps the ratio finite when the mean is zero
        sensitivity_matrix[row, col] = spread / (np.abs(np.mean(curve)) + 1e-6)

# --- Heatmap ---
fig, ax = plt.subplots(figsize=(16, 6))
im = ax.imshow(sensitivity_matrix, aspect='auto', cmap='YlOrRd', interpolation='nearest')

ax.set_title('Feature Sensitivity to Parameters (Higher = More Informative)',
             fontsize=12, fontweight='bold', pad=20)

# Major ticks carry the feature / parameter names
feature_ticks = np.arange(len(FEATURE_NAMES))
param_ticks = np.arange(len(theta_labels))
ax.set_xticks(feature_ticks)
ax.set_yticks(param_ticks)
ax.set_xticklabels(FEATURE_NAMES, rotation=90, fontsize=8)
ax.set_yticklabels(theta_labels, fontsize=10)

cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Relative Sensitivity (Range/Mean)', rotation=270, labelpad=20)

# Minor ticks sit on the cell edges; gridding them draws white cell separators
ax.set_xticks(feature_ticks - 0.5, minor=True)
ax.set_yticks(param_ticks - 0.5, minor=True)
ax.grid(which='minor', color='white', linestyle='-', linewidth=1)

plt.tight_layout()
plt.show()

# --- Text summary: five highest-scoring features per parameter ---
print("\nTop 5 Most Sensitive Features per Parameter:")
print("="*70)
for row, param_x in enumerate(theta_labels):
    ranked = np.argsort(sensitivity_matrix[row])[::-1][:5]  # descending order
    print(f"\n{param_x}:")
    for rank, idx in enumerate(ranked, 1):
        sens = sensitivity_matrix[row, idx]
        print(f"  {rank}. {FEATURE_NAMES[idx]:30s} (sensitivity: {sens:.3f})")

## Visualization 5: Pairwise Feature Relationships

Examine correlations between key features to understand redundancy.

In [None]:
# === Visualization 5: Feature Correlation Analysis ===

# Stack the per-condition mean feature vectors from every sweep and gradient
# level: one row per (swept parameter, gradient level, x value) combination.
all_features = np.vstack([
    results_mean[param_x][grad_key]
    for param_x in theta_labels
    for grad_key in results_mean[param_x]
])

# Feature-by-feature correlation (columns are the variables)
corr_matrix = np.corrcoef(all_features, rowvar=False)

# --- Heatmap ---
fig, ax = plt.subplots(figsize=(18, 16))
im = ax.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')

ax.set_title('Feature Correlation Matrix (All Parameters)',
             fontsize=14, fontweight='bold', pad=20)

tick_positions = np.arange(len(FEATURE_NAMES))
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(FEATURE_NAMES, rotation=90, fontsize=9)
ax.set_yticklabels(FEATURE_NAMES, fontsize=9)

cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Correlation', rotation=270, labelpad=20)

# Separator lines between consecutive feature groups: cumulative group sizes
# give the index at which each group ends; the last one is the outer edge.
group_sizes = [len(group_indices) for group_indices in FEATURE_GROUPS.values()]
for boundary in np.cumsum(group_sizes)[:-1]:
    ax.axhline(boundary - 0.5, color='black', linewidth=2)
    ax.axvline(boundary - 0.5, color='black', linewidth=2)

plt.tight_layout()
plt.show()

# --- Strongly correlated pairs (|corr| > 0.9), upper triangle only ---
n_features = len(FEATURE_NAMES)
high_corr_pairs = [
    (FEATURE_NAMES[i], FEATURE_NAMES[j], corr_matrix[i, j])
    for i in range(n_features)
    for j in range(i + 1, n_features)
    if abs(corr_matrix[i, j]) > 0.9
]

if not high_corr_pairs:
    print("\n✓ No highly correlated features (|r| > 0.9)")
    print("  This is good - features are relatively independent!")
else:
    print("\nHighly Correlated Features (|r| > 0.9):")
    print("="*70)
    # Strongest correlations first
    for feat1, feat2, corr in sorted(high_corr_pairs, key=lambda x: abs(x[2]), reverse=True):
        print(f"  {feat1:30s} <-> {feat2:30s}: r={corr:+.3f}")
    print("\nNote: Highly correlated features may be redundant for inference.")