In [3]:
import numpy as np
import torch
from sbi.utils import BoxUniform
from sbi.inference import SNPE, simulate_for_sbi

from sbi_for_diffusion_models.choice_model import choice_model_simulator_torch

# Pick the training device: the first CUDA GPU if one is available, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Training device:", device)

# One seed for every random source in this cell. The NumPy generator below is
# only handed to the simulator; draws from the torch-based prior and the network
# training use torch's global RNG, which would otherwise be left unseeded.
SEED = 0
torch.manual_seed(SEED)

# Lower / upper bounds of the 5-dimensional parameter vector.
# NOTE(review): the parameter order is not documented here — confirm it matches
# what choice_model_simulator_torch expects.
low = torch.tensor([0, 0, -2, 0.2, 0.0], dtype=torch.float32)
high = torch.tensor([1, 5,  2, 3.0, 1.0], dtype=torch.float32)

# 1) Prior for simulation (CPU) so simulate_for_sbi samples theta on CPU.
prior_sim = BoxUniform(low=low, high=high)

# NumPy generator passed to the simulator on every call.
rng = np.random.default_rng(SEED)

def simulator(th: torch.Tensor) -> torch.Tensor:
    """Simulate observations for the parameter batch ``th``.

    Thin adapter around ``choice_model_simulator_torch`` that binds the shared
    module-level ``rng`` and turns on ``resample_invalid``, leaving the
    single-argument signature that is handed to ``simulate_for_sbi``.
    """
    # No device transfer happens here: ``th`` is forwarded exactly as received.
    simulated = choice_model_simulator_torch(th, rng=rng, resample_invalid=True)
    return simulated

# Draw parameters from the CPU prior and run them through the simulator.
theta, x = simulate_for_sbi(
    simulator,
    prior_sim,
    num_simulations=10_000,
    simulation_batch_size=2048,
    num_workers=1,  # Windows-friendly
)

# 2) Prior for training, placed on the training device
#    (sbi asserts prior.device == training device).
prior_train = BoxUniform(low=low.to(device), high=high.to(device))

# 3) Train on `device` while the stored simulations stay on the CPU.
#    data_device is "cpu" rather than a hard-coded "cuda" so this cell also runs
#    when `device` falls back to CPU, and so it matches the stated intent.
inference = SNPE(prior=prior_train, device=str(device))
density_estimator = inference.append_simulations(theta, x, data_device="cpu").train(
    training_batch_size=4096,
)
posterior = inference.build_posterior(density_estimator)

# Report where the trained network's weights ended up.
print("Density estimator device:", next(density_estimator.parameters()).device)


Training device: cuda:0


Running 10000 simulations.:   0%|          | 0/10000 [00:00<?, ?it/s]



 Neural network successfully converged after 283 epochs.Density estimator device: cuda:0


In [None]:
from sbi.inference import SNLE

# Likelihood estimation on the same simulations. prior_train was created on the
# training device, so it can be reused here as-is.
inference = SNLE(prior=prior_train, device=str(device))

# Register the simulations first (stored on the CPU — usually best), then fit.
# append_simulations returns the inference object itself, so calling train on
# `inference` is the same as chaining the two calls.
inference.append_simulations(theta, x, data_device="cpu")
density_estimator = inference.train(training_batch_size=4096)


  warn("In one-dimensional output space, this flow is limited to Gaussians")


 Neural network successfully converged after 519 epochs.

AttributeError: 'SNLE_A' object has no attribute 'build_likelihood'

In [11]:
# Evaluate the trained estimator on the simulated pairs, on whichever device
# its weights live on.
post_device = next(density_estimator.parameters()).device
theta_d, x_d = theta.to(post_device), x.to(post_device)

# Pure evaluation: no autograd graph is needed.
with torch.no_grad():
    ll = density_estimator.log_prob(x_d, context=theta_d)

# Mean and standard deviation of the returned log-probabilities.
print(*(stat.item() for stat in (ll.mean(), ll.std())))


-0.49954721331596375 1.0380749702453613


In [None]:
import torch
import matplotlib.pyplot as plt

from sbi.inference import MNLE
from sbi.analysis import pairplot
from sbi.utils.get_nn_models import likelihood_nn

from sbi_for_diffusion_models.rt_choice_model import (
    rt_choice_model_simulator_torch,
    simulate_session_data_rt_choice,
)

# ------------- 1) Define prior -------------
# Seed torch's global RNG so the random draws made through torch in this cell
# (e.g. prior samples) are repeatable across runs.
torch.manual_seed(0)

# Device on which the training simulations are run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example: adjust bounds to your problem.
# theta = [a0_frac, lam, v, B, t_nd]
low = torch.tensor([0.0, -2.0, 0.0, 0.2, 0.0])
high = torch.tensor([1.0,  2.0, 5.0, 5.0, 1.0])

# Independent(..., 1) folds the five uniform marginals into a single event, so
# log_prob yields one value per parameter vector. The prior itself stays on CPU.
prior = torch.distributions.Independent(
    torch.distributions.Uniform(low, high), 1
)


# ------------- 2) Generate training simulations -------------
def simulate_training_set(num_simulations: int, batch_size: int, **sim_kwargs):
    """Sample parameters from ``prior`` and simulate one observation per draw.

    Args:
        num_simulations: Number of parameter vectors to draw from ``prior``.
        batch_size: Maximum number of parameter rows sent to the simulator per call.
        **sim_kwargs: Forwarded unchanged to ``rt_choice_model_simulator_torch``.

    Returns:
        ``(theta, x)`` as CPU tensors; ``x`` is cast to float32.

    Raises:
        AssertionError: If ``theta`` or ``x`` contains a non-finite entry.
    """
    drawn = prior.sample((num_simulations,))
    theta = drawn.to(device=device, dtype=torch.float32)

    # Simulate on `device` slice by slice and move each result to the CPU right
    # away, so only one batch of simulator output lives on the device at a time.
    batch_starts = range(0, num_simulations, batch_size)
    outputs = [
        rt_choice_model_simulator_torch(
            theta[first : first + batch_size], **sim_kwargs
        ).detach().cpu()  # (B,2)
        for first in batch_starts
    ]

    # Stored data stays on the CPU.
    x = torch.cat(outputs, dim=0).to(torch.float32)
    theta_cpu = theta.detach().cpu()

    # Refuse to hand NaN/inf values to the trainer.
    assert torch.isfinite(theta_cpu).all()
    assert torch.isfinite(x).all()
    return theta_cpu, x


num_simulations = 50_000
simulation_batch_size = 2048

# Simulator settings shared by the training set and the synthetic test session.
# Defined once so the two calls cannot silently drift apart.
task_kwargs = dict(mu_sensory=1.0, p_success=0.75)

theta_train, x_train = simulate_training_set(
    num_simulations,
    simulation_batch_size,
    **task_kwargs,
)

# ------------- 3) Build MNLE likelihood estimator -------------
# MNLE expects mixed data; likelihood_nn(model="mnle") handles the discrete component.
# log_transform_x=True is typically good because RTs are positive and heavy-tailed.
# NOTE(review): sbi warns that MNLE assumes continuous data in the first columns
# and the categorical choice in the last column — confirm the simulator returns
# its two columns in that order.
estimator_builder = likelihood_nn(
    model="mnle",
    log_transform_x=True,
    z_score_theta="independent",
    z_score_x="independent",
)

# No device is passed to MNLE, so training uses sbi's default device.
trainer = MNLE(prior=prior, density_estimator=estimator_builder)
# exclude_invalid_x=False: simulate_training_set already asserted x is finite.
trainer.append_simulations(theta_train, x_train, exclude_invalid_x=False)
trainer.train(training_batch_size=4096)

# ------------- 4) Build posterior (MCMC over learned likelihood) -------------
# MCMC settings are given once, at build time; posterior.sample() falls back to
# them, so they are not repeated (and cannot diverge) at sampling time.
mcmc_settings = dict(
    warmup_steps=200,
    thin=5,
    num_chains=20,
    init_strategy="proposal",
)
posterior = trainer.build_posterior(
    prior=prior,
    mcmc_method="slice_np_vectorized",
    mcmc_parameters=mcmc_settings,
)

# ------------- 5) Test on synthetic "session" dataset (IID trials) -------------
# Ground-truth parameters, ordered as [a0_frac, lam, v, B, t_nd].
theta_true = torch.tensor([0.55, 0.2, 1.5, 2.0, 0.25], dtype=torch.float32)
x_o = simulate_session_data_rt_choice(
    theta_true,
    num_trials=300,
    **task_kwargs,
)

# Sample posterior. x_o holds many IID trials, so the result is p(theta | all trials).
num_posterior_samples = 5_000
samples = posterior.sample((num_posterior_samples,), x=x_o)

# Overlay prior draws and posterior samples; mark the ground truth as a point.
labels = [r"$a_0$", r"$\lambda$", r"$v$", r"$B$", r"$t_{nd}$"]
fig, ax = pairplot(
    [prior.sample((2000,)), samples],
    points=theta_true.unsqueeze(0),
    diag="kde",
    upper="kde",
    labels=labels,
)
# Title the figure returned by pairplot rather than pyplot's implicit current figure.
fig.suptitle("MNLE posterior vs prior (RT+choice)", fontsize=14)
plt.show()


        continuous data in the first n-1 columns (e.g., reaction times) and
        categorical data in the last column (e.g., corresponding choices). If
        this is not the case for the passed `x` do not use this function.


 Neural network successfully converged after 290 epochs.

            distributed data X={x_1, ..., x_n}, i.e., data generated based on the
            same underlying (unknown) parameter. The resulting posterior will be with
            respect to entire batch, i.e,. p(theta | X).
  warn(


Running vectorized MCMC with 20 chains:   0%|          | 0/45000 [00:00<?, ?it/s]