## Load the packages

In [None]:
# Import the packages
import hssm
import pytensor  # Graph-based tensor library
import bambi as bmb
import pandas as pd
from matplotlib import pyplot as plt
import arviz as az
import numpy as np
import pymc as pm

# Run PyTensor graphs in single precision, and keep JAX in 32-bit mode to
# match (the `nuts_numpyro` sampler used later in this notebook runs on JAX).
pytensor.config.floatX = "float32"
# Disable PyTensor graph rewrites.
# NOTE(review): this can slow compilation/evaluation; presumably set for
# debugging or compatibility — confirm it is still needed.
pytensor.config.optimizer = "None"
# `from jax.config import config` was deprecated and later removed from JAX;
# the supported spelling is `jax.config.update`.
import jax

jax.config.update("jax_enable_x64", False)

## Load the data

In [None]:
# Fetch the example dataset that ships with HSSM.
data = hssm.load_data("cavanagh_theta")

# Map any 0-coded responses to -1, so that rt * response yields a signed RT
# whose sign encodes the response; then show the signed-RT distribution.
data["response"] = data["response"].replace({0: -1})
plt.hist(data["rt"] * data["response"], bins=20)
plt.show()

In [None]:
data.head(n=5)  # first five rows, to eyeball columns and coding

## Specify the model

In [4]:
# Specify a hierarchical DDM. For each of v (drift rate), a (boundary
# separation) and z (starting point), the formula "1 + conf + (1 + conf|subj_idx)"
# gives a population intercept, a population effect of `conf`, and per-subject
# deviations of both (random intercept and random slope).
#
# Prior keys must match the term names Bambi derives from each formula:
#   "Intercept"      population intercept
#   "conf"           population effect of conf
#   "1|subj_idx"     per-subject intercept deviation
#   "conf|subj_idx"  per-subject deviation in the conf effect
model = hssm.HSSM(
    model="ddm",
    loglik_kind="approx_differentiable",
    data=data,
    # Prior on the probability that a trial comes from the lapse component.
    p_outlier={"name": "Uniform", "lower": 0.01, "upper": 0.05},
    # Distribution of RTs generated by the lapse component.
    lapse=bmb.Prior("Uniform", lower=0.0, upper=20.0),
    include=[
        {
            "name": "v",
            "prior": {
                "Intercept": {"name": "Normal", "mu": 0.0, "sigma": 1.0},
                "1|subj_idx": {
                    "name": "Normal",
                    "mu": 0.0,
                    "sigma": {"name": "Gamma", "alpha": 2.0, "beta": 10.0},
                },
                "conf": {"name": "Normal", "mu": 0.0, "sigma": 0.5},
                "conf|subj_idx": {
                    "name": "Normal",
                    "mu": 0.0,
                    "sigma": {"name": "Gamma", "alpha": 2.0, "beta": 10.0},
                },
            },
            "formula": "v ~ 1 + conf + (1 + conf|subj_idx)",
        },
        {
            "name": "a",
            "prior": {
                # NOTE(review): Gamma(alpha=2, beta=10) has mean 0.2, while the
                # initial value chosen for a_Intercept is 1.0 — confirm this
                # prior is intended.
                "Intercept": {"name": "Gamma", "alpha": 2.0, "beta": 10.0},
                "1|subj_idx": {
                    "name": "Normal",
                    "mu": 0.0,
                    "sigma": {"name": "Gamma", "alpha": 2.0, "beta": 10.0},
                },
                "conf": {"name": "Normal", "mu": 0.0, "sigma": 0.2},
                "conf|subj_idx": {
                    "name": "Normal",
                    "mu": 0.0,
                    "sigma": {"name": "Gamma", "alpha": 2.0, "beta": 10.0},
                },
            },
            "formula": "a ~ 1 + conf + (1 + conf|subj_idx)",
        },
        {
            "name": "z",
            "prior": {
                "Intercept": {"name": "Uniform", "lower": 0.1, "upper": 0.9},
                "1|subj_idx": {
                    "name": "Normal",
                    "mu": 0.0,
                    "sigma": {"name": "Gamma", "alpha": 2.0, "beta": 10.0},
                },
                "conf": {"name": "Normal", "mu": 0.0, "sigma": 0.2},
                "conf|subj_idx": {
                    "name": "Normal",
                    "mu": 0.0,
                    "sigma": {"name": "Gamma", "alpha": 2.0, "beta": 10.0},
                },
            },
            "formula": "z ~ 1 + conf + (1 + conf|subj_idx)",
        },
    ],
)

## Set the initial values for the sampler

In [5]:
# Build the dictionary of initial values handed to the sampler. Keys are the
# names of the variables in the underlying PyMC model; every group-level term
# appears as a `_sigma` (group scale) and an `_offset` (per-subject) variable.
n_subjects = data["subj_idx"].nunique()

# To check the shapes of all model variables:
# model.pymc_model.eval_rv_shapes()

# To check the initial point. The sampler's init method may alter it (PyMC's
# default "jitter+adapt_diag" adds random jitter; init="adapt_diag" does not).
# NOTE(review): model.sample() below does not pass `init` — confirm whether
# jitter is applied on top of these values with the nuts_numpyro sampler.
# model.pymc_model.initial_point()

my_inits = {
    "t": 0.1,

    "v_Intercept": 0.0,
    # NOTE(review): 0.01 here vs 0.1 for a and z — confirm this is intended.
    "v_1|subj_idx_sigma": 0.01,
    # Random-intercept offsets: one per subject, shape (n_subjects,).
    "v_1|subj_idx_offset": np.zeros(n_subjects, dtype=np.float32),
    "v_conf": np.array([0.0], dtype=np.float32),
    "v_conf|subj_idx_sigma": np.array([0.1], dtype=np.float32),
    # Random-slope offsets: shape (n_subjects, 1).
    "v_conf|subj_idx_offset": np.zeros((n_subjects, 1), dtype=np.float32),

    "a_Intercept": 1.0,
    "a_1|subj_idx_sigma": 0.1,
    "a_1|subj_idx_offset": np.zeros(n_subjects, dtype=np.float32),
    "a_conf": np.array([0.0], dtype=np.float32),
    "a_conf|subj_idx_sigma": np.array([0.1], dtype=np.float32),
    "a_conf|subj_idx_offset": np.zeros((n_subjects, 1), dtype=np.float32),

    "z_Intercept": 0.5,
    "z_1|subj_idx_sigma": 0.1,
    "z_1|subj_idx_offset": np.zeros(n_subjects, dtype=np.float32),
    "z_conf": np.array([0.0], dtype=np.float32),
    "z_conf|subj_idx_sigma": np.array([0.1], dtype=np.float32),
    "z_conf|subj_idx_offset": np.zeros((n_subjects, 1), dtype=np.float32),
}


## Fit the model

In [None]:
# Run NUTS via NumPyro: 4 chains, each with 200 tuning and 200 kept draws,
# starting from the initial values in `my_inits`. This is a short run; for
# final inference use more draws, and try target_accept=0.95 if divergences
# show up.
modelObject = model.sample(
    draws=200,
    tune=200,
    chains=4,
    cores=4,
    sampler="nuts_numpyro",
    initvals=my_inits,
)

## Analyze the posterior

In [None]:
# Plot the MCMC traces, raising ArviZ's per-figure subplot cap to 20 so more
# variables fit in a single figure.
az.rcParams["plot.max_subplots"] = 20
az.plot_trace(modelObject)
plt.gcf().tight_layout()
plt.show()

In [None]:
# Posterior summary table for all variables except a, t and z; let pandas
# show up to 500 rows so the table is not truncated.
pd.options.display.max_rows = 500
az.summary(modelObject, var_names=["~a", "~t", "~z"])


In [None]:
# Marginal posterior plots for the same variables as the summary table
# (everything except a, t and z), capped at 20 subplots.
az.rcParams["plot.max_subplots"] = 20
az.plot_posterior(
    modelObject,
    var_names=["~a", "~t", "~z"],
)

In [None]:
# Pairwise joint posterior plot, with each pair drawn as a KDE.
# NOTE(review): no var_names are given, so every variable is a candidate;
# consider restricting the selection for a readable figure.
az.plot_pair(modelObject, kind="kde")
