#### Imports

In [11]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import Beta, Uniform, LogNormal, HalfNormal, TransformedDistribution, ExpTransform

import sbi
from sbi.utils import BoxUniform, MultipleIndependent
from sbi.inference import MNLE
import sbi.neural_nets as nn
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
from sbi.analysis import pairplot

In [12]:
from __future__ import annotations
import sys
from pathlib import Path

# Make the project root (parent of the current working directory) importable so
# that the local `sbi_for_diffusion_models` package can be imported below.
# Rebuilding sys.path (rather than a bare insert) keeps exactly one copy of the
# entry at the front, so re-running this cell does not pile up duplicates.
SRC_DIR = Path.cwd().resolve().parent

sys.path[:] = [str(SRC_DIR)] + [entry for entry in sys.path if entry != str(SRC_DIR)]

from sbi_for_diffusion_models.ddm_simulator import (
    simulate_pulse_ddm_single,
    pulse_ddm_simulator_torch,
    simulate_session_data,
)

#### Build NN for MNLE 

In [13]:
def _build_prior():
    """Return the joint prior over the simulator parameters.

    The prior is a product of seven independent one-dimensional marginals, in the
    dimension order (a0, lam, nu, B, sigma_a, t_nd, sigma_s).
    """

    def f32(value):
        # Every marginal is parameterized with shape-(1,) float32 tensors.
        return torch.tensor([value], dtype=torch.float32)

    marginals = [
        Beta(concentration1=f32(2.0), concentration0=f32(2.0)),  # a0 in [0, 1]
        HalfNormal(scale=f32(0.5)),  # lam >= 0
        LogNormal(loc=f32(np.log(1.0)), scale=f32(0.5)),  # nu >= 0 (pulse strength)
        LogNormal(loc=f32(np.log(2.0)), scale=f32(0.5)),  # B > 0
        HalfNormal(scale=f32(0.3)),  # sigma_a >= 0 (accumulator diffusion noise)
        Uniform(low=f32(0.05), high=f32(0.6)),  # t_nd in a sensible range (seconds)
        HalfNormal(scale=f32(0.5)),  # sigma_s >= 0 (sensory noise on eta)
    ]
    return MultipleIndependent(marginals, validate_args=False)


# Define priors over parameters
prior = _build_prior()

#### Run code on rat data 

To use real data, the following changes are needed inside the `generate_session_data()` function:
- `correct_side` must be replaced with the correct side recorded in the task data;
- the correctness logic stays the same: `is_correct = (choice == correct_side)`, followed by the existing `correctness_idx` logic.

In [None]:
# Load the cleaned trial table. The model is fitted one rat at a time.
RAT_ID = 1054  # rat to fit; inspect rat_df["name"].unique() for the available IDs

rat_df = pd.read_csv("rat_data_clean.csv")

# drop=True stops the old row labels from being added as an extra "index" column.
rat_individual_df = rat_df.loc[rat_df["name"] == RAT_ID].reset_index(drop=True)
assert not rat_individual_df.empty, f"no trials found for rat {RAT_ID}"

rat_individual_df.head()

In [None]:
def make_x_from_rat_df(
    rat_individual_df: pd.DataFrame,
    rt_col: str = "rt",
    outcome_col: str = "outcome",
) -> torch.Tensor:
    """Build the MNLE observation tensor from one rat's trial table.

    Trials with a missing reaction time or outcome are dropped. The input frame
    is not modified.

    Args:
        rat_individual_df: Trial-level data for a single rat.
        rt_col: Column holding reaction times (presumably in seconds, to match
            the t_nd prior — TODO confirm against the raw data).
        outcome_col: Column holding the binary trial outcome (0 or 1).

    Returns:
        float32 tensor of shape (num_trials, 2) with columns
        [reaction time, outcome].

    Raises:
        ValueError: if ``outcome_col`` contains values other than 0 and 1.
    """
    trials = rat_individual_df.dropna(subset=[rt_col, outcome_col])

    rts = trials[rt_col].to_numpy(dtype="float32")
    outcomes = trials[outcome_col].to_numpy(dtype="float32")

    # The discrete column must be exactly 0/1; another coding (e.g. -1/1) would
    # otherwise be passed through silently as invalid categories.
    is_binary = np.isin(outcomes, (0.0, 1.0))
    if not is_binary.all():
        bad_values = sorted(set(outcomes[~is_binary].tolist()))
        raise ValueError(
            f"{outcome_col!r} must contain only 0/1, found {bad_values}"
        )

    # Continuous RT first, discrete outcome last (sbi's MNLE treats the trailing
    # column as the discrete one).
    return torch.stack([torch.from_numpy(rts), torch.from_numpy(outcomes)], dim=1)

In [None]:
"""
Use posterior and trainer from train_mnle

Select one individual: rat_individual_df, preprocess and get posterior samples.

"""
posterior, trainer = train_mnle(num_simulations=50_000)

x_rat = make_x_from_rat_df(rat_individual_df)

num_posterior_samples = 5_000

posterior_samples_rat = posterior.sample(
    sample_shape=(num_posterior_samples,),
    x=x_rat,
)
labels = [r"$\lambda$", r"$\nu$", r"$B$", r"$\sigma$"]

fig, ax = pairplot(
    [prior.sample((2_000,)), posterior_samples_rat],
    diag="kde",
    upper="kde",
    labels=labels,
)
plt.suptitle("Rat 1054: posterior vs prior", fontsize=14)
plt.show()