#### Imports

In [8]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import Beta, Uniform, LogNormal, HalfNormal, TransformedDistribution, ExpTransform

import sbi
from sbi.utils import BoxUniform, MultipleIndependent
from sbi.inference import MNLE
import sbi.neural_nets as nn
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
from sbi.analysis import sbc_rank_plot, pairplot

#### Define DDM model

In [None]:
# ========= pulse-DDM model parameters =========

DT = 0.005   # Euler–Maruyama time step (s)
T_MAX = 8.0  # total trial window (s); simulated RTs are capped at this value
T0 = 0.1     # fixed non-decision time (s)
T_DEC = T_MAX - T0  # window available to the diffusion process (s)
# round() rather than int(): int() truncates, so a float quotient that lands at
# e.g. 1579.9999999 instead of 1580 would silently drop a step.
N_STEPS = round(T_DEC / DT)

PULSE_INTERVAL = 0.1  # time between evidence pulses (s)
STEPS_PER_PULSE = round(PULSE_INTERVAL / DT)
# Number of pulse signs drawn per trial; the "+ 1" guarantees at least as many
# signs as pulse onsets in range(0, N_STEPS, STEPS_PER_PULSE).
N_PULSES = N_STEPS // STEPS_PER_PULSE + 1

In [None]:
def simulate_pulse_ddm_single(theta: np.ndarray, rng: np.random.Generator,) -> tuple[float, int]:
    """
    Draw a single (rt, choice) trial from the pulse-DDM.

    inputs:
        theta: numpy array ordered as [a0, lam, nu, B, sigma]. lam, nu, B and
            sigma are used by absolute value; a0 is clipped to [0, 1].
        rng: numpy random Generator. It supplies all pulse signs first, then
            one Gaussian increment per integration step until a bound is hit.

    Returns:
        (rt, choice_idx): rt in seconds, including the non-decision time T0 and
        capped at T_MAX. choice_idx is 1 for the upper bound (B) and 0 for the
        lower bound (0). A trial that reaches no bound within T_DEC is assigned
        by which half of [0, B] the accumulator ends in.
    """
    a0, lam, nu, B, sigma = theta.astype(np.float32)
    lam, nu, B, sigma = abs(lam), abs(nu), abs(B), abs(sigma)

    # The accumulator starts a fraction a0 of the way from the lower bound to B.
    evidence = float(np.clip(a0, 0.0, 1.0) * B)

    # Random pulse signs: +1 or -1 with equal probability.
    signs = np.where(rng.random(size=N_PULSES) < 0.5, 1.0, -1.0).astype(np.float32)

    # Per-step jump schedule: a signed kick of size nu every STEPS_PER_PULSE steps.
    onsets = np.arange(0, N_STEPS, STEPS_PER_PULSE, dtype=int)
    kicks = np.zeros(N_STEPS, dtype=np.float32)
    kicks[onsets] = signs[: len(onsets)] * nu

    decision_time = T_DEC
    # Euler–Maruyama integration of the leaky, pulse-driven diffusion; stops
    # at the first bound crossing.
    for step, kick in enumerate(kicks):
        leak = -lam * evidence * DT
        diffusion = sigma * np.sqrt(DT) * rng.normal()
        evidence = evidence + leak + kick + diffusion

        if evidence >= B:
            choice = 1  # upper bound
        elif evidence <= 0.0:
            choice = 0  # lower bound
        else:
            continue
        decision_time = (step + 1) * DT
        break
    else:
        # No bound reached: choose the side of the midpoint the accumulator ended on.
        choice = 1 if evidence >= (B / 2.0) else 0

    rt = float(np.clip(T0 + decision_time, 1e-3, T_MAX))
    return rt, int(choice)

In [None]:
def pulse_ddm_simulator_torch(theta: Tensor) -> Tensor:
    """
    Batched torch wrapper around ``simulate_pulse_ddm_single``.

    Input:
        theta: torch.Tensor of parameters with columns [a0, lam, nu, B, sigma],
            shape (batch, 5). A single 1-D parameter vector is also accepted
            and treated as a batch of one.

    Returns:
        x: float32 torch.Tensor of shape (batch, 2) with columns
            [rt, choice_idx] (choice_idx: 0 = lower bound, 1 = upper bound).
    """
    # atleast_2d lets a single parameter vector through as a batch of one.
    theta_np = np.atleast_2d(theta.detach().cpu().numpy().astype(np.float32))
    rng = np.random.default_rng()  # one unseeded generator shared by the batch

    xs = np.zeros((theta_np.shape[0], 2), dtype=np.float32)
    for i, theta_i in enumerate(theta_np):
        xs[i] = simulate_pulse_ddm_single(theta_i, rng)

    return torch.from_numpy(xs).to(torch.float32)

#### Build NN for MNLE 

In [None]:
# ========= Define priors over the parameters (these also bound the region MCMC explores) =========
#
# Component order matches the simulator's theta = [a0, lam, nu, B, sigma]:
#   a0    ~ Beta(2, 2)               starting-point bias in [0, 1], weighted towards 0.5
#   lam   ~ HalfNormal(0.5)          leak rate (drift term -lam * a), kept positive
#   nu    ~ LogNormal(log 1, 0.5)    pulse size, i.e. log(nu) ~ Normal(log 1, 0.5)
#   B     ~ LogNormal(log 2, 0.5)    upper bound, i.e. log(B) ~ Normal(log 2, 0.5)
#   sigma ~ HalfNormal(0.5)          diffusion noise scale, kept positive
#                                    (an Exponential prior is an alternative; HalfNormal is tried first)
prior = MultipleIndependent(
    [
        # a0 ~ Beta(2,2)
        Beta(
            concentration1=torch.tensor([2.0], dtype=torch.float32),
            concentration0=torch.tensor([2.0], dtype=torch.float32),
        ),

        # lam ~ HalfNormal(0.5)
        HalfNormal(
            scale=torch.tensor([0.5], dtype=torch.float32),
        ),

        # nu ~ LogNormal(log 1, 0.5)
        LogNormal(
            loc=torch.tensor([np.log(1.0)], dtype=torch.float32),
            scale=torch.tensor([0.5], dtype=torch.float32),
        ),

        # B ~ LogNormal(log 2, 0.5)
        LogNormal(
            loc=torch.tensor([np.log(2.0)], dtype=torch.float32),
            scale=torch.tensor([0.5], dtype=torch.float32),
        ),

        # sigma ~ HalfNormal(0.5)
        HalfNormal(
            scale=torch.tensor([0.5], dtype=torch.float32),
        ),
    ],
    validate_args=False,
)

In [None]:
def train_mnle(
    num_simulations: int = 10_000,
    simulation_batch_size: int = 512,
    mcmc_kwargs: dict | None = None,
):
    """
    Train MNLE for the pulse-DDM and return an MCMC-based posterior object.

    Parameters are drawn from the global ``prior``, and one trial is simulated
    per draw. MNLE then learns the single-trial likelihood
        p(x | theta),  x = (rt, choice_idx),  choice_idx ∈ {0, 1}
    where choice_idx = 0 is the lower bound and 1 the upper bound.

    Args:
        num_simulations: number of (theta, x) training pairs.
        simulation_batch_size: parameter draws simulated per call to
            ``pulse_ddm_simulator_torch``.
        mcmc_kwargs: options for the MCMC posterior (num_chains, warmup_steps,
            thin, init_strategy, ...). Defaults below are used when None.

    Returns:
        (posterior, trainer): the MCMC posterior built from the trained
        likelihood estimator, and the MNLE trainer itself.

    Raises:
        ValueError: if the sampled parameters or simulated data contain NaN/Inf.
    """
    if mcmc_kwargs is None:
        mcmc_kwargs = dict(
            num_chains=50,
            warmup_steps=200,
            thin=5,
            init_strategy="proposal",
        )

    # Sample parameters from the prior -- specified by Ryan
    theta_train = prior.sample((num_simulations,)).to(torch.float32)

    # Generate single-trial training data with the simulator, in batches.
    x_train_list = []
    for start in range(0, num_simulations, simulation_batch_size):
        batch = theta_train[start:start + simulation_batch_size]
        x_train_list.append(pulse_ddm_simulator_torch(batch))
    x_train = torch.cat(x_train_list, dim=0).to(torch.float32)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not torch.isfinite(theta_train).all():
        raise ValueError("NaN/Inf in theta_train.")
    if not torch.isfinite(x_train).all():
        raise ValueError("NaN/Inf in x_train.")

    # MNLE likelihood network builder; `likelihood_nn` comes from
    # sbi.neural_nets, which this notebook imports as `nn`.
    estimator_builder = nn.likelihood_nn(
        model="mnle",
        # NOTE(review): `x_num_conDim` does not look like a documented MNLE
        # builder option; depending on the sbi version it may be ignored or
        # rejected -- confirm against the installed sbi.
        x_num_conDim=1,
        log_transform_x=True,  # RTs are modelled on log scale (must be > 0)
        z_score_theta="independent",
        z_score_x="independent",
    )

    # Initialize MNLE trainer, train on (theta, x) pairs.
    trainer = MNLE(prior, estimator_builder)
    trainer.append_simulations(
        theta_train,
        x_train,
        exclude_invalid_x=False,
    ).train()

    # Build an MCMC-based posterior from the trained likelihood estimator.
    posterior = trainer.build_posterior(
        prior=prior,
        sample_with="mcmc",
        mcmc_method="slice_np_vectorized",
        mcmc_parameters=mcmc_kwargs,
    )

    return posterior, trainer

#### Parameter recovery

In [None]:
def simulate_session_data(theta_true: Tensor, num_trials: int = 1000) -> Tensor:
    """
    Simulate one session of independent trials at a fixed parameter vector.

    Args:
        theta_true: tensor holding the five parameters [a0, lam, nu, B, sigma].
        num_trials: number of trials in the session.

    Returns:
        float32 tensor of shape (num_trials, 2) with columns [rt, choice_idx].
    """
    params = theta_true.detach().cpu().numpy().astype(np.float32)
    session = np.zeros((num_trials, 2), dtype=np.float32)
    generator = np.random.default_rng()  # fresh unseeded generator shared by all trials

    # Each `trial` is a view into `session`, so writing into it fills the array in place.
    for trial in session:
        trial[:] = simulate_pulse_ddm_single(params, generator)

    return torch.from_numpy(session).to(torch.float32)

def run_parameter_recovery(
    num_simulations=50_000,
    num_trials_session=200,
    num_posterior_samples=5_000,
    theta_true=None,
    posterior=None,
):
    """
    Test how well MNLE recovers theta_true from one synthetic session.

    Args:
        num_simulations: training simulations, used only when ``posterior`` is None.
        num_trials_session: number of trials in the synthetic observed session.
        num_posterior_samples: posterior samples drawn for the recovery plot.
        theta_true: ground-truth [a0, lam, nu, B, sigma]; a default vector is
            used when None.
        posterior: optional already-trained MNLE posterior (e.g. from
            ``train_mnle``). When None a new one is trained; passing one avoids
            retraining.

    Returns:
        (theta_true, x_o, posterior, posterior_samples)
    """
    if posterior is None:
        posterior, _ = train_mnle(num_simulations)

    if theta_true is None:
        theta_true = torch.tensor([0.5, 0.4, 1.0, 2.0, 0.2], dtype=torch.float32)

    theta_true = theta_true.to(torch.float32)

    x_o = simulate_session_data(theta_true, num_trials_session)

    posterior_samples = posterior.sample(
        sample_shape=(num_posterior_samples,),
        x=x_o,
    )

    labels = [r"$a_0$", r"$\lambda$", r"$\nu$", r"$B$", r"$\sigma$"]

    # Prior samples are plotted alongside the posterior to show how much the data narrowed it.
    fig, ax = pairplot(
        [prior.sample((2000,)), posterior_samples],
        points=theta_true.unsqueeze(0),
        diag="kde",
        upper="kde",
        labels=labels,
    )
    plt.suptitle("Parameter Recovery: MNLE posterior vs prior", fontsize=14)
    plt.show()

    return theta_true, x_o, posterior, posterior_samples


#### Posterior Predictive Checks

In [None]:
def posterior_predictive_checks(
    posterior,
    x_o: Tensor,
    num_sessions: int = 50,
    num_trials_per_session: int | None = None,
    num_posterior_samples: int = 1_000,
):
    """
    Post hoc analysis: compare RT and choice statistics between the observed
    session and posterior-predictive sessions.

    A pool of ``num_posterior_samples`` parameter vectors is drawn from the
    posterior once. Each of the ``num_sessions`` replicated sessions is then
    simulated at one vector picked uniformly at random (with replacement) from
    that pool. Histograms of the replicates' mean RT and P(right) are plotted
    with the observed values marked.

    Inputs:
        posterior: sbi posterior (MCMCPosterior) from MNLE fitting
        x_o: observed data, shape (N_trials, 2) [rt, choice_idx]
        num_sessions: number of posterior predictive datasets to draw
        num_trials_per_session: optional; if None, use len(x_o)
        num_posterior_samples: size of the posterior sample pool that the
            per-session parameters are drawn from
    """
    if num_trials_per_session is None:
        num_trials_per_session = x_o.shape[0]

    # ---------- 1) Sample from posterior ONCE ----------
    with torch.no_grad():
        theta_samples = posterior.sample(
            sample_shape=(num_posterior_samples,),
            x=x_o,
        )  # shape: [S, 5] -- columns [a0, lam, nu, B, sigma]

    theta_samples_np = theta_samples.cpu().numpy()
    S = theta_samples_np.shape[0]

    # ---------- 2) Compute observed summary stats ----------
    x_o_np = x_o.detach().cpu().numpy()
    rt_obs = x_o_np[:, 0]
    choice_idx_obs = x_o_np[:, 1]  # 0 = left, 1 = right (internal coding)

    obs_mean_rt = float(rt_obs.mean())
    obs_p_right = float(choice_idx_obs.mean())  # since this is 0/1, mean = P(right)

    # ---------- 3) Simulate posterior-predictive datasets ----------
    rng = np.random.default_rng()

    rep_mean_rt = []
    rep_p_right = []

    for _ in range(num_sessions):
        # pick one θ sample from the posterior pool
        idx = rng.integers(0, S)
        theta_i = torch.from_numpy(theta_samples_np[idx])

        # simulate a session at that θ
        x_rep_np = simulate_session_data(theta_i, num_trials=num_trials_per_session).numpy()

        rep_mean_rt.append(x_rep_np[:, 0].mean())
        rep_p_right.append(x_rep_np[:, 1].mean())

    rep_mean_rt = np.array(rep_mean_rt)
    rep_p_right = np.array(rep_p_right)

    # ---------- 4) Plot PPC summary histograms ----------
    fig, axes = plt.subplots(1, 2, figsize=(8, 3))

    # Mean RT across sessions
    axes[0].hist(rep_mean_rt, bins=20, alpha=0.8)
    axes[0].axvline(obs_mean_rt, color="red", linestyle="--", label="observed")
    axes[0].set_xlabel("Mean RT (s)")
    axes[0].set_ylabel("Count")
    axes[0].set_title("PPC: session mean RT")
    axes[0].legend(frameon=False)

    # Proportion right across sessions
    axes[1].hist(rep_p_right, bins=20, alpha=0.8)
    axes[1].axvline(obs_p_right, color="red", linestyle="--", label="observed")
    axes[1].set_xlabel("P(right)")
    axes[1].set_ylabel("Count")
    axes[1].set_title("PPC: session P(right)")
    axes[1].set_xlim(0, 1)
    axes[1].legend(frameon=False)

    plt.tight_layout()
    plt.show()

In [None]:
# ========= Simulation-Based Calibration (SBC) =========

def simulate_dataset_for_sbc(theta_batch: Tensor, num_trials: int) -> list[Tensor]:
    """
    Simulate one session per parameter vector.

    Args:
        theta_batch: parameter vectors stacked along dim 0.
        num_trials: number of trials in each simulated session.

    Returns:
        A list with one (num_trials, 2) tensor per row of ``theta_batch``, in row order.
    """
    return [
        simulate_session_data(theta_row, num_trials=num_trials)
        for theta_row in theta_batch
    ]

from sbi.diagnostics import run_sbc
from sbi.analysis.plot import sbc_rank_plot

def run_sbc_for_mnle(
    posterior,
    num_sbc_samples: int = 20,
    num_trials_per_dataset: int = 20,
    num_posterior_samples: int = 200,
):
    """
    Run simulation-based calibration (SBC) for an MNLE posterior.

    The steps are:
    1. Draw ``num_sbc_samples`` parameter vectors from the prior.
    2. Simulate one iid session of ``num_trials_per_dataset`` trials for each.
    3. Let ``run_sbc`` draw ``num_posterior_samples`` posterior samples per
       session and compute ranks.
    4. Plot the rank histograms.

    Returns:
        (ranks, dap_samples, fig, ax) from ``run_sbc`` and ``sbc_rank_plot``.
    """
    # ---------- 1) Draw ground-truth parameters from the prior ----------
    thetas = prior.sample((num_sbc_samples,))  # (num_sbc_samples, 5)

    # ---------- 2) Simulate one iid dataset x_i for each theta_i ----------
    # Stored as a 3D tensor: (num_sbc_samples, num_trials_per_dataset, 2)
    xs = torch.stack(
        simulate_dataset_for_sbc(thetas, num_trials=num_trials_per_dataset),
        dim=0,
    )

    # ---------- 3) Run SBC ----------
    # IMPORTANT: pass thetas & xs POSITIONALLY, not by the old keyword names.
    ranks, dap_samples = run_sbc(
        thetas,
        xs,
        posterior,
        num_posterior_samples=num_posterior_samples,
        use_batched_sampling=False,   # safer on memory for MNLE+MCMC posteriors
        num_workers=1,                # bump if you want CPU parallelism
    )

    # ---------- 4) Plot SBC rank histograms ----------
    fig, ax = sbc_rank_plot(
        ranks,
        num_posterior_samples,
        num_bins=20,
        figsize=(5, 3),
    )

    return ranks, dap_samples, fig, ax


In [None]:
def plot_empirical_rt_choice(x_o: torch.Tensor, title: str = "Observed data"):
    """
    Quick visualization of the RT distribution and choice proportions of a dataset.

    Args:
        x_o: (N_trials, 2) tensor with columns [rt, choice_idx_mnle]
            (0 = left, 1 = right internally).
        title: prefix used in both subplot titles.
    """
    # detach().cpu() so GPU or grad-tracking tensors can be plotted too.
    x_np = x_o.detach().cpu().numpy()
    rt = x_np[:, 0]
    choice_idx = x_np[:, 1]
    p_right = choice_idx.mean()  # 0/1 coding, so the mean is P(right)

    fig, axes = plt.subplots(1, 2, figsize=(8, 3))

    axes[0].hist(rt, bins=30, alpha=0.8)
    axes[0].set_xlabel("RT (s)")
    axes[0].set_ylabel("Count")
    axes[0].set_title(f"{title}: RT distribution")

    axes[1].bar(["Left", "Right"], [1.0 - p_right, p_right])
    axes[1].set_ylim(0, 1)
    axes[1].set_title(f"{title}: choice proportions")

    plt.tight_layout()
    plt.show()

def plot_posterior_marginals(posterior_samples: torch.Tensor, labels: list[str] | None = None):
    """
    Simple marginal histograms for each parameter from posterior samples.

    Args:
        posterior_samples: (num_samples, num_params) tensor.
        labels: optional panel titles, one per column. Defaults to the prior's
            parameter order [a0, lam, nu, B, sigma]; columns beyond the label
            list get a generic "param i" title.
    """
    if labels is None:
        # Must match the column order of the prior / simulator theta vector.
        labels = [r"$a_0$", r"$\lambda$", r"$\nu$", r"$B$", r"$\sigma$"]

    samples_np = posterior_samples.detach().cpu().numpy()
    num_params = samples_np.shape[1]
    # squeeze=False keeps `axes` 2-D even when there is a single parameter.
    fig, axes = plt.subplots(1, num_params, figsize=(12, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.hist(samples_np[:, i], bins=30, alpha=0.8)
        ax.set_title(labels[i] if i < len(labels) else f"param {i}")
    plt.tight_layout()
    plt.show()

#### MNLE pipeline 

In [None]:
# Train an MNLE posterior once with 10k simulations; keeps (posterior, trainer).
print(">>> Training MNLE + parameter recovery...")
posterior, trainer = train_mnle(num_simulations=10_000)

In [None]:
# run parameter recovery
theta_true, x_o, posterior, posterior_samples = run_parameter_recovery(
    num_simulations=10_000,
    num_trials_session=20,
    num_posterior_samples=2_000,
)

# Plot empirical data summaries for the observed synthetic session
plot_empirical_rt_choice(x_o, title="Synthetic session (θ_true)")

# Plot posterior marginals for θ
plot_posterior_marginals(posterior_samples)

In [None]:
# Posterior predictive check: 50 replicated sessions, each as long as the
# observed session, simulated at parameters taken from a pool of 1k posterior draws.
ppc_settings = dict(
    num_sessions=50,
    num_trials_per_session=x_o.shape[0],
    num_posterior_samples=1_000,
)
posterior_predictive_checks(posterior, x_o, **ppc_settings)

In [None]:
# Small SBC run: 5 prior draws, 20 trials per dataset, 100 posterior samples each.
sbc_settings = dict(
    num_sbc_samples=5,
    num_trials_per_dataset=20,
    num_posterior_samples=100,
)
ranks, dap_samples, fig_sbc, ax_sbc = run_sbc_for_mnle(posterior, **sbc_settings)

#### Run code on rat data 

For real data, the following must change inside the `generate_session_data()` function:
- `correct_side` must be replaced by the correct side recorded by the task;
- the correctness logic stays the same: `is_correct = (choice == correct_side)`, from which `correctness_idx` is derived as before.

In [None]:
# load dataframe:

rat_df = pd.read_csv("rat_data_clean.csv")

# we must fit the model per rat:
# print(rat_df["name"].unique())

# Keep only rat 1054's trials (rats are identified by the "name" column).
# NOTE(review): reset_index() without drop=True keeps the old row labels as an
# extra "index" column; pass drop=True if that column is not needed.
rat_individual_df = rat_df[rat_df["name"] == 1054].reset_index()

rat_individual_df.head()

In [None]:
def make_x_from_rat_df(rat_individual_df) -> torch.Tensor:
    """
    Convert one rat's trial table into the (num_trials, 2) float32 tensor used as MNLE data.

    Column 0 is the reaction time from ``rt``. Column 1 is ``outcome``, expected
    to be coded 0/1.

    The following rows are dropped:
    - rows with a missing ``rt`` or ``outcome``;
    - rows whose RT is not a finite, strictly positive number. The MNLE network
      is built with ``log_transform_x=True``, so such RTs would become NaN/-inf.

    NOTE(review): the simulator's second column is the bound that was hit
    (0 = lower, 1 = upper), not correctness. Confirm ``outcome`` is coded
    compatibly before interpreting posteriors fitted to it.
    NOTE(review): RTs are assumed to be in seconds, like the simulator (which
    caps RTs at T_MAX) -- confirm the units in the CSV.
    """
    df = rat_individual_df.dropna(subset=["rt", "outcome"])

    rts = df["rt"].to_numpy(dtype="float32")
    outcome = df["outcome"].to_numpy(dtype="float32")  # 0/1

    keep = np.isfinite(rts) & (rts > 0)
    x = np.stack([rts[keep], outcome[keep]], axis=1)  # (num_trials, 2)
    return torch.from_numpy(x)

In [None]:
"""
Use posterior and trainer from train_mnle

Select one individual: rat_individual_df, preprocess and get posterior samples.

"""
posterior, trainer = train_mnle(num_simulations=50_000)

x_rat = make_x_from_rat_df(rat_individual_df)

num_posterior_samples = 5_000

posterior_samples_rat = posterior.sample(
    sample_shape=(num_posterior_samples,),
    x=x_rat,
)
labels = [r"$\lambda$", r"$\nu$", r"$B$", r"$\sigma$"]

fig, ax = pairplot(
    [prior.sample((2_000,)), posterior_samples_rat],
    diag="kde",
    upper="kde",
    labels=labels,
)
plt.suptitle("Rat 1054: posterior vs prior", fontsize=14)
plt.show()