In [None]:
# Notebook setup: backend selection, matplotlib text settings, and imports.
import os
# Force CPU backend on Apple Silicon to avoid Metal issues.
# This assignment sits before the `jax` imports further down; keep that order.
# NOTE(review): JAX is expected to read JAX_PLATFORMS when it selects its
# backend, so setting it after the import may have no effect; confirm.
os.environ['JAX_PLATFORMS'] = 'cpu'

import matplotlib.pyplot as plt

# Disable LaTeX rendering in matplotlib and fall back to a sans-serif font
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = "sans-serif"

from jax import random
from jax import numpy as jnp
from sbijax import plot_loss_profile
# NOTE(review): `matplotlib.pyplot` is already imported above; this second
# import is redundant but harmless.
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl

# Project helpers. NOTE(review): neither name is referenced in the cells visible here.
from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.snle.snle_utils_jax import plot_real_synth_hist, extract_samples
# from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.enhanced_stats_37 import FEATURE_NAMES

In [None]:
import pickle
from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.simulator import PatchForagingDDM_JAX, create_prior
from aind_behavior_vrforaging_analysis.sbi_ddm_analysis.snle.snle_inference_jax import infer_parameters_snle
from sbijax import NLE
from sbijax.nn import make_maf

# --- Load saved parameters ---
# Name of the saved run, relative to the results directory and WITHOUT the
# '.pkl' suffix (the suffix is appended when the path is built below).
# Other runs listed in the original cell:
#   'snle_2M_h128_l8_b256_37feat', 'snle_2M_h128_l8_b256_1feat',
#   'snle_2M_h128_l8_b1024_300feat', 'snle_2M_h128_l8_b2048_300feat',
#   'snle_5M_h128_l8_b256_37feat', 'snle_2M_h64_l5_b256_37feat',
#   'snle_2M_lr0.0001_ts2000_h128_l8_b256_37feat/model'
model_name = 'snle_2M_lr0.0005_ts2000_h128_l8_b256_37feat_0/model'

# The results directory can be overridden with the SBI_RESULTS_DIR environment
# variable; the default is the original hard-coded location.
results_dir = os.environ.get(
    'SBI_RESULTS_DIR',
    '/Users/laura.driscoll/Documents/code/sbi_results',
)
model_path = os.path.join(results_dir, f'{model_name}.pkl')

# SECURITY NOTE: pickle.load can execute arbitrary code. Only load files that
# were produced by a trusted training run.
with open(model_path, 'rb') as f:
    saved_data = pickle.load(f)

# --- Reconstruct the model from the stored training config ---
simulator = PatchForagingDDM_JAX(
    max_sites_per_window=100,
    interval_normalization=saved_data['config']['interval_normalization']
)
prior_fn = create_prior(
    prior_low=jnp.array(saved_data['config']['prior_low']),
    prior_high=jnp.array(saved_data['config']['prior_high'])
)

# Draw one throw-away simulation; only the size of its last axis is used, to
# set the flow's output dimension.
rng_key = random.PRNGKey(saved_data['config']['seed']+1)
rng_key, test_key = random.split(rng_key)
# Use independent keys for the prior draw and the simulator call; a JAX PRNG
# key should not be consumed by two different random operations.
theta_key, sim_key = random.split(test_key)
test_theta = prior_fn().sample(seed=theta_key)
test_x = simulator.simulator_fn(seed=sim_key, theta=test_theta)
n_features = test_x.shape[-1]

# Rebuild the flow architecture (must match training)
flow = make_maf(
    n_dimension=n_features,  # obs data dimension
    n_layers=saved_data['config']['num_layers'],
    hidden_sizes=(saved_data['config']['hidden_dim'], saved_data['config']['hidden_dim'])
)

# Create SNLE model
fns = prior_fn, simulator.simulator_fn
snle = NLE(fns, flow)

print("Model reconstructed! Ready for inference.")

In [None]:
# Render text without LaTeX, using a sans-serif font.
plt.rcParams.update({"text.usetex": False, "font.family": "sans-serif"})

# Loss profile stored with the saved run, with the y-axis clipped to [-200, 0].
_, axes = plt.subplots(figsize=(6, 3))
plot_loss_profile(saved_data['losses'], axes)
axes.set_ylim(-200, 0)
plt.show()

In [None]:
# Sweep the drift rate (parameter 0) over fixed values, run SNLE inference on
# one simulated window per value, and overlay the marginal posterior
# histograms of all four parameters.
param_names = ["drift_rate", "reward_bump", "failure_bump", "noise_std"]
param_labels = [
    "drift_rate: evidence accumulation rate",
    "reward_bump: evidence boost from receiving reward",
    "failure_bump: evidence boost from not receiving reward",
    "noise_std: std of noise in evidence accumulation"
]

number_of_samples = 5
cmap = mpl.colormaps['rainbow']

# One RGBA colour per swept value, sampled evenly across the colormap.
gradient = np.linspace(0, 1, number_of_samples)
colors_rgba = cmap(gradient)
# Drift values to test, evenly spaced on [0, 2].
# NOTE(review): this range is hard-coded rather than taken from the prior
# bounds in saved_data['config']; confirm it lies inside the prior.
drift_values = np.linspace(0, 1, number_of_samples)*2

fig, axes = plt.subplots(1, 4, figsize=(10, 2))
axes = axes.flatten()

# --- Simulate observed data ---
print("\n2. Simulating observed data...")
rng_key = random.PRNGKey(88)
rng_key, subkey = random.split(rng_key)
# Base parameter vector; only entry 0 is overwritten inside the sweep.
true_theta = prior_fn().sample(seed=subkey)['theta']

for drift_i, drift in enumerate(drift_values):
    true_theta = true_theta.at[0].set(drift)
    rng_key, subkey = random.split(rng_key)
    _, observed_stats = simulator.simulate_one_window(true_theta, subkey)
    print(f"   True theta: {true_theta}")

    # --- Run inference ---
    print("\n3. Testing inference...")
    rng_key, subkey = random.split(rng_key)
    posterior_samples, diagnostics = infer_parameters_snle(
        snle,
        saved_data['snle_params'],
        observed_stats,
        saved_data['y_mean'], saved_data['y_std'],
        num_samples=5_000,
        num_warmup=50,
        num_chains=2,
        rng_key=subkey
    )

    # --- Plot posterior distributions ---
    sweep_color = colors_rgba[drift_i]
    for i, ax in enumerate(axes):

        # Marginal histogram of parameter i
        counts, bins, _ = ax.hist(posterior_samples[:, i], bins=30, color=sweep_color, edgecolor=None, alpha=0.3)

        # Centre of the fullest histogram bin. This is the mode of the
        # marginal histogram, shown in the legend as "MAP estimate".
        mode_index = int(np.argmax(counts))
        posterior_mode = (bins[mode_index] + bins[mode_index + 1]) / 2

        ax.axvline(float(true_theta[i]), color=sweep_color, linestyle='-', alpha=0.5)
        ax.axvline(posterior_mode, color=sweep_color, linestyle=':', alpha=0.5)

        ax.set_xlabel(param_names[i])
        ax.set_ylabel("Frequency")

        ax.set_xlim(saved_data['config']['prior_low'][i], saved_data['config']['prior_high'][i])

# Build the legend once, after the sweep, from proxy handles: two entries
# explain the line styles and one entry per swept value explains the colours.
# (Previously the legend was rebuilt on every sweep step via the leaked inner
# loop index and ended up with repeated "true value"/"MAP estimate" entries.)
legend_handles = [
    plt.Line2D([0], [0], color='gray', linestyle='-', alpha=0.5, label="true value"),
    plt.Line2D([0], [0], color='gray', linestyle=':', alpha=0.5, label='MAP estimate'),
]
legend_handles += [
    plt.Line2D([0], [0], color=colors_rgba[k], linestyle='-', label=f"{param_names[0]} = {value:.2f}")
    for k, value in enumerate(drift_values)
]
axes[-1].legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.5))

plt.tight_layout()
plt.show()


In [None]:
# Sweep one parameter (index `vary_theta`) across its prior range, run SNLE
# inference on one simulated window per value, and overlay the marginal
# posterior histograms of all four parameters.
vary_theta = 1
param_names = ["drift_rate", "reward_bump", "failure_bump", "noise_std"]
param_labels = [
    "drift_rate: evidence accumulation rate",
    "reward_bump: evidence boost from receiving reward",
    "failure_bump: evidence boost from not receiving reward",
    "noise_std: std of noise in evidence accumulation"
]

number_of_samples = 5
cmap = mpl.colormaps['rainbow']

# One RGBA colour per swept value, sampled evenly across the colormap.
gradient = np.linspace(0, 1, number_of_samples)
colors_rgba = cmap(gradient)
# Swept values span the prior bounds of the chosen parameter (endpoints included).
theta_values = np.linspace(saved_data['config']['prior_low'][vary_theta],
                           saved_data['config']['prior_high'][vary_theta],
                           number_of_samples)

fig, axes = plt.subplots(1, 4, figsize=(10, 2))
axes = axes.flatten()

# --- Simulate observed data ---
print("\n2. Simulating observed data...")
rng_key = random.PRNGKey(111)
rng_key, subkey = random.split(rng_key)
# Base parameter vector; only entry `vary_theta` is overwritten inside the sweep.
true_theta = prior_fn().sample(seed=subkey)['theta']

for theta_i, theta in enumerate(theta_values):
    true_theta = true_theta.at[vary_theta].set(theta)
    rng_key, subkey = random.split(rng_key)
    _, observed_stats = simulator.simulate_one_window(true_theta, subkey)
    print(f"   True theta: {true_theta}")

    # --- Run inference ---
    print("\n3. Testing inference...")
    rng_key, subkey = random.split(rng_key)
    posterior_samples, diagnostics = infer_parameters_snle(
        snle,
        saved_data['snle_params'],
        observed_stats,
        saved_data['y_mean'], saved_data['y_std'],
        num_samples=1000,
        num_warmup=50,
        num_chains=2,
        rng_key=subkey
    )

    # --- Plot posterior distributions ---
    sweep_color = colors_rgba[theta_i]
    for i, ax in enumerate(axes):

        # Marginal histogram of parameter i
        counts, bins, _ = ax.hist(posterior_samples[:, i], bins=30, color=sweep_color, edgecolor=None, alpha=0.3)

        # Centre of the fullest histogram bin. This is the mode of the
        # marginal histogram, shown in the legend as "MAP estimate".
        mode_index = int(np.argmax(counts))
        posterior_mode = (bins[mode_index] + bins[mode_index + 1]) / 2

        ax.axvline(float(true_theta[i]), color=sweep_color, linestyle='-', alpha=0.5)
        ax.axvline(posterior_mode, color=sweep_color, linestyle=':', alpha=0.5)

        ax.set_xlabel(param_names[i])
        ax.set_ylabel("Frequency")

        ax.set_xlim(saved_data['config']['prior_low'][i], saved_data['config']['prior_high'][i])

# Build the legend once, after the sweep, from proxy handles: two entries
# explain the line styles and one entry per swept value explains the colours.
# (Previously the legend was rebuilt on every sweep step via the leaked inner
# loop index and ended up with repeated "true value"/"MAP estimate" entries.)
legend_handles = [
    plt.Line2D([0], [0], color='gray', linestyle='-', alpha=0.5, label="true value"),
    plt.Line2D([0], [0], color='gray', linestyle=':', alpha=0.5, label='MAP estimate'),
]
legend_handles += [
    plt.Line2D([0], [0], color=colors_rgba[k], linestyle='-', label=f"{param_names[vary_theta]} = {value:.2f}")
    for k, value in enumerate(theta_values)
]
axes[-1].legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.5))

plt.tight_layout()
plt.show()


In [None]:
import seaborn as sns
import pandas as pd
import numpy as np
from scipy.stats import gaussian_kde

def pairplot(posterior_samples, true_params=None, param_names=None, 
             figsize_per_param=2.5, grid_points=100, fig=None, axes=None,
             color_i = 1, density_threshold=0.5):
    """
    Draw a lower-triangle corner plot of posterior samples.

    - Diagonal panels: filled 1D Gaussian KDE of each parameter.
    - Lower off-diagonal panels: filled 2D Gaussian KDE contours, with
      densities below ``density_threshold`` masked out.
    - Upper off-diagonal panels: switched off.
    - True parameters (if given): a vertical line on the diagonal and an 'X'
      marker off the diagonal, drawn in the top colour of the chosen colormap.

    Parameters
    ----------
    posterior_samples : array-like, shape (n_samples, n_params)
        Samples to plot; converted with ``np.asarray``.
    true_params : sequence of length n_params, optional
        Reference values to mark in every panel.
    param_names : list of str, optional
        Axis labels; defaults to ``param0``, ``param1``, ...
    figsize_per_param : float
        Side length per panel, used only when a new figure is created.
    grid_points : int
        Number of grid points per axis on which the KDEs are evaluated.
    fig, axes : optional
        Existing figure and n_params x n_params grid of axes to draw into.
        A new figure is created unless BOTH are supplied. Passing the same
        pair repeatedly overlays several posteriors.
    color_i : int
        Index into ('Purples', 'Blues', 'Greens', 'Oranges', 'Reds');
        indices past the end wrap around instead of raising IndexError.
    density_threshold : float
        Absolute KDE density below which the 2D contours are hidden. The
        value is in density units, so a suitable setting depends on the
        scale of the parameters.

    Returns
    -------
    (fig, axes)
        The figure and the n_params x n_params array of axes drawn into.
    """

    cmap_names = ['Purples', 'Blues', 'Greens', 'Oranges', 'Reds']
    cmap_name = cmap_names[color_i % len(cmap_names)]
    # `mpl.colormaps` is the colormap registry already used elsewhere in this
    # notebook; `plt.cm.get_cmap` is not available in recent Matplotlib.
    color = mpl.colormaps[cmap_name](1.)

    # Accept JAX arrays, NumPy arrays, or nested lists alike.
    posterior_samples = np.asarray(posterior_samples)

    n_params = posterior_samples.shape[1]
    if param_names is None:
        param_names = [f"param{i}" for i in range(n_params)]

    if fig is None or axes is None:
        side = figsize_per_param * n_params
        # squeeze=False keeps a 2D grid even when n_params == 1.
        fig, axes = plt.subplots(n_params, n_params, figsize=(side, side), squeeze=False)
    # Normalise to an (n_params, n_params) grid so `axes[i, j]` always works,
    # including a bare Axes passed in for a single parameter.
    axes = np.asarray(axes, dtype=object).reshape(n_params, n_params)

    for i in range(n_params):
        for j in range(n_params):
            ax = axes[i, j]

            # Only fill lower triangle
            if i < j:
                ax.axis('off')
                continue

            # Diagonal: 1D KDE
            if i == j:
                data = posterior_samples[:, i]
                kde = gaussian_kde(data)
                x_grid = np.linspace(data.min(), data.max(), grid_points)
                ax.fill_between(x_grid, kde(x_grid), color=color, alpha=0.5)

                if true_params is not None:
                    ax.axvline(true_params[i], color=color, linestyle='-', lw=1)

            # Off-diagonal: 2D KDE (column parameter on x, row parameter on y)
            else:
                x = posterior_samples[:, j]
                y = posterior_samples[:, i]
                kde = gaussian_kde(np.vstack([x, y]))
                x_grid = np.linspace(x.min(), x.max(), grid_points)
                y_grid = np.linspace(y.min(), y.max(), grid_points)
                X, Y = np.meshgrid(x_grid, y_grid)
                Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)

                # Hide the low-density tail so overlaid posteriors stay readable.
                Z_masked = np.ma.masked_where(Z < density_threshold, Z)

                # Skip the contour call when nothing clears the threshold;
                # there would be no valid data to derive contour levels from.
                if Z_masked.count() > 0:
                    ax.contourf(X, Y, Z_masked, levels=20, cmap=cmap_name, alpha=0.7)

                if true_params is not None:
                    ax.scatter(true_params[j], true_params[i], c=color, s=50, marker='X', label='True')

            # Only label left and bottom axes
            if i < n_params - 1:
                ax.set_xticklabels([])
            else:
                ax.set_xlabel(param_names[j])
            if j > 0:
                ax.set_yticklabels([])
            else:
                ax.set_ylabel(param_names[i])

    # Legend goes in the first row, second column (an unused upper-triangle
    # panel); falls back to the only panel when n_params == 1. It is skipped
    # when there is nothing to label, so no empty legend box is drawn.
    if true_params is not None:
        true_handle = plt.Line2D([0], [0], marker='X', color='w', markerfacecolor=color, markersize=8, label='True')
        axes[0, min(1, n_params - 1)].legend(handles=[true_handle], loc='upper left')

    return fig, axes
    

In [None]:
# Corner plot of the inference run currently held in `posterior_samples`,
# marked with the current `true_theta`. Both are module-level variables
# assigned by an earlier cell; if cells ran top to bottom, they hold the last
# step of the preceding sweep.
figsize_per_param = 2.0
n_params = posterior_samples.shape[1]
fig, axes = plt.subplots(
    nrows=n_params,
    ncols=n_params,
    figsize=(figsize_per_param * n_params,) * 2,
)
pairplot(
    posterior_samples,
    true_params=true_theta,
    param_names=param_names,
    figsize_per_param=figsize_per_param,
    fig=fig,
    axes=axes,
)

fig.tight_layout()
plt.show()

In [None]:
# Sweep one parameter (index `vary_theta`) across the interior of its prior
# range and overlay the resulting posteriors as corner plots on one shared
# grid, one colormap per swept value.
vary_theta = 0
number_of_samples = 5
param_names = ["drift_rate", "reward_bump", "failure_bump", "noise_std"]
param_labels = [
    "drift_rate: evidence accumulation rate",
    "reward_bump: evidence boost from receiving reward",
    "failure_bump: evidence boost from not receiving reward",
    "noise_std: std of noise in evidence accumulation"
]

cmap = mpl.colormaps['rainbow']
# RGBA colours sampled evenly across the colormap, dropping both end points.
# These are not passed to pairplot below, which selects its own colormap
# from `color_i`.
gradient = np.linspace(0, 1, number_of_samples+2)[1:-1]
colors_rgba = cmap(gradient)
# Swept values: evenly spaced strictly inside the prior bounds (the two
# boundary values themselves are dropped).
theta_values = np.linspace(saved_data['config']['prior_low'][vary_theta],
                           saved_data['config']['prior_high'][vary_theta],
                           number_of_samples+2)[1:-1]

figsize_per_param = 2.0
# Size the grid from the parameter list defined in this cell, not from
# `posterior_samples` left over from an earlier cell, so the cell does not
# depend on hidden kernel state.
n_params = len(param_names)
fig, axes = plt.subplots(n_params, n_params, figsize=(figsize_per_param*n_params, figsize_per_param*n_params))

# --- Simulate observed data ---
print("\n2. Simulating observed data...")
rng_key = random.PRNGKey(666)
rng_key, subkey = random.split(rng_key)
# Base parameter vector; only entry `vary_theta` is overwritten inside the sweep.
true_theta = prior_fn().sample(seed=subkey)['theta']

for theta_i, theta in enumerate(theta_values):
    true_theta = true_theta.at[vary_theta].set(theta)
    rng_key, subkey = random.split(rng_key)
    _, observed_stats = simulator.simulate_one_window(true_theta, subkey)
    print(f"   True theta: {true_theta}")

    # --- Run inference ---
    print("\n3. Testing inference...")
    rng_key, subkey = random.split(rng_key)
    posterior_samples, diagnostics = infer_parameters_snle(
        snle,
        saved_data['snle_params'],
        observed_stats,
        saved_data['y_mean'], saved_data['y_std'],
        num_samples=1000,
        num_warmup=50,
        num_chains=2,
        rng_key=subkey
    )

    # --- Plot posterior distributions ---
    # color_i selects one colormap per swept value.
    pairplot(posterior_samples, true_theta, param_names,
             figsize_per_param=figsize_per_param, fig=fig, axes=axes,
             color_i=theta_i)


plt.tight_layout()
plt.show()