## Imports

In [1]:
import glob
import os

import jax
import numpyro
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import hssm
import arviz as az
import seaborn as sns

# Run JAX on the CPU backend.
jax.config.update('jax_platform_name', 'cpu')
# Use float32 for HSSM/PyTensor (the log printed by this cell shows it also sets "jax_enable_x64" to False).
hssm.set_floatX("float32")
# NOTE(review): the device count is hard-coded to 14 — presumably chosen for one specific machine;
# confirm before running elsewhere.
numpyro.set_host_device_count(14)



Setting PyTensor floatX type to float32.
Setting "jax_enable_x64" to False. If this is not intended, please set `jax` to False.


In [2]:
# Load the raw data; the path is relative to the notebook's working directory.
df_raw = pd.read_csv('data_wt_exp5.csv')

## Definitions

In [3]:
def subsitute_values_sequential(data,varname,new_values):
    """Recode a column by mapping its sorted unique levels onto ``new_values``.

    The i-th smallest non-missing unique value of ``data[varname]`` is
    replaced by ``new_values[i]``. Missing values are left untouched.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame containing the column to recode (not modified).
    varname : str
        Name of the column to recode.
    new_values : sequence
        Replacement values, one per unique level, in ascending level order.

    Returns
    -------
    pandas.Series
        The recoded column.

    Raises
    ------
    ValueError
        If the number of unique levels differs from ``len(new_values)``.
        Without this check ``zip`` would silently leave levels unmapped.
    """
    # NaN is excluded: it cannot be ordered reliably by sorted()
    unique_values = sorted(data[varname].dropna().unique())
    new_values = list(new_values)
    if len(unique_values) != len(new_values):
        raise ValueError(
            "Column %r has %d unique values %r but %d new values were given"
            % (varname, len(unique_values), unique_values, len(new_values))
        )
    substitutions = dict(zip(unique_values, new_values))
    return data[varname].replace(substitutions)

In [4]:
def robust_z(x):
    """Return median/MAD-based z-scores of ``x``.

    Each value is expressed as ``0.6745 * (value - median) / MAD``, where MAD
    is the median absolute deviation from the median. When the MAD is not
    positive, machine epsilon is used as the divisor instead.
    """
    values = np.asarray(x, float)
    deviations = values - np.median(values)
    scale = np.median(np.abs(deviations))
    if not scale > 0:
        # Degenerate spread: fall back to the smallest positive float step
        scale = np.finfo(float).eps
    return 0.6745 * deviations / scale

### Data Cleanup

In [5]:
# Start from a fresh copy so df_raw stays untouched and this cell can be re-run
data = df_raw.copy()

# Stimulus side: sign of `stim` (left/right coded as -1/+1)
stim_sign = np.sign(data['stim'])
data['stim_side'] = stim_sign

# Desirable side: stimulus sign multiplied by the desirability code
data['des_side'] = stim_sign * data['desirability']

# Stimulus evidence: sorted unique `stim_strength` levels recoded to 0.25 / 0.5 / 0.75 / 1
data['stim_easy'] = subsitute_values_sequential(data, 'stim_strength', [0.25, .5, 0.75, 1])

# Incentive: the two sorted unique levels recoded to -0.5 / +0.5
data['incentive'] = subsitute_values_sequential(data, 'incentive', [-.5, .5])

# Evidence magnitude signed by desirability
data['stim_des'] = data['stim_easy'] * data['desirability']

# Response relative to the desirable side: desirable vs undesirable (+1/-1)
data['resp_des'] = data['resp'] * data['des_side']

# Keep only trials with a desirability manipulation (for now)
has_desirability = data['desirability'] != 0
data = data[has_desirability]

#%% Data cleaning/exclusion

# Trial-level RT bounds (same unit as the `rt` column) and robust-z cut-off
MIN_RT = 200
MAX_RT = 35000
MAD_THRESH_RT = 3

# Participant-level accuracy bounds (both exclusive)
ACC_THRESH_LO = 0.525
ACC_THRESH_HI = 0.975

# Minimum number of trials required in every cell of COND_VARS
MIN_TRIALS_PER_COND = 8
COND_VARS = ['incentive', 'desirability']

# Trial count before any exclusion, used later to report how many were dropped
n_trials_raw = len(data)

In [6]:
# Keep trials that have accuracy, RT and response recorded ...
has_all_fields = data[['correct', 'rt', 'resp']].notna().all(axis=1)
# ... and whose RT lies strictly between MIN_RT and MAX_RT
rt_in_range = (data['rt'] > MIN_RT) & (data['rt'] < MAX_RT)
data = data[has_all_fields & rt_in_range]

In [7]:
# Drop trials whose RT is an outlier within the participant's own RT distribution
rt_abs_z = data.groupby('participant')['rt'].transform(lambda rts: np.abs(robust_z(rts)))
data = data[rt_abs_z < MAD_THRESH_RT]


In [8]:
# Number of trials removed since n_trials_raw was recorded
n_trials_excluded = n_trials_raw - len(data)
print('N trials excluded = %i / %i'%(n_trials_excluded,n_trials_raw))

# RT summary of the remaining trials (%i truncates to whole milliseconds)
rt_kept = data['rt']
print('Mean RT = %i ms, median RT = %i ms, max RT = %i ms' % (rt_kept.mean(),rt_kept.median(),rt_kept.max()))

N trials excluded = 1066 / 19200
Mean RT = 1734 ms, median RT = 1502 ms, max RT = 29557 ms


## Participant Cleanup

In [9]:
# Robust-z cut-off for a participant's median RT relative to the whole sample
MAD_THRESH_PARTICIPANT_RT = 3

# --- 1. Accuracy: keep participants strictly between the two accuracy bounds ---
participant_accuracy = data.groupby('participant')['correct'].mean()
valid_participants = participant_accuracy[(participant_accuracy > ACC_THRESH_LO) & (participant_accuracy < ACC_THRESH_HI)].index

print('N participants kept (accuracy) = %i / %i'%(len(valid_participants),len(data['participant'].unique())))
data = data[data['participant'].isin(valid_participants)]

# --- 2. RT: remove participants with extreme median RTs compared to the sample ---
participant_rt = data.groupby('participant')['rt'].median()
valid_participants_rt = participant_rt[np.abs(robust_z(participant_rt)) < MAD_THRESH_PARTICIPANT_RT].index
print('N participants kept (RT) = %i / %i'%(len(valid_participants_rt),len(data['participant'].unique())))
data = data[data['participant'].isin(valid_participants_rt)]

# --- 3. Trial counts: remove participants who, after exclusion, have too few
#        trials in one or more conditions (combinations of COND_VARS) ---
groupby_vars = ['participant', *COND_VARS]
sub_trials_per_cond = data.groupby(groupby_vars).size().reset_index(name='n_trials')

# groupby().size() has no row for a condition in which a participant has zero
# trials, so a completely missing condition would pass an "all rows >= min" test.
# Pivot to participant x condition with explicit zeros before thresholding.
trials_per_cond_wide = (
    sub_trials_per_cond
    .set_index(groupby_vars)['n_trials']
    .unstack(COND_VARS, fill_value=0)
)
has_min_trials = (trials_per_cond_wide >= MIN_TRIALS_PER_COND).all(axis=1)
valid_participants = has_min_trials[has_min_trials].index.tolist()
df_participants_with_min_trials = sub_trials_per_cond[sub_trials_per_cond['participant'].isin(valid_participants)]
print('N participants kept after excluding participants with too few trials = %i / %i'%(len(valid_participants),len(data['participant'].unique())))
data = data[data['participant'].isin(valid_participants)]

N participants kept (accuracy) = 438 / 600
N participants kept (RT) = 427 / 438
N participants kept after excluding participants with too few trials = 425 / 427


## HSSM DataFrame

In [10]:
# Build the HSSM input frame: response coded relative to the desirable side,
# RT divided by 1000, and the participant identifier.
# Every column is cast explicitly, so no float64 column remains to be downcast.
df_hssm = (
    data[['resp_des', 'rt', 'participant']]
    .rename(columns={'resp_des': 'response', 'participant': 'participant_id'})
    .assign(rt=lambda frame: frame['rt'] / 1000)
    .astype({'response': 'int32', 'rt': 'float32', 'participant_id': 'int32'})
)

In [11]:
# Inspect the assembled HSSM input frame (response, rt, participant_id)
df_hssm

Unnamed: 0,response,rt,participant_id
0,-1,1.115,1
1,-1,2.170,1
2,-1,2.096,1
3,1,1.687,1
4,1,2.857,1
...,...,...,...
19195,-1,2.616,600
19196,1,2.698,600
19197,1,1.673,600
19198,1,1.415,600


In [12]:
# Draw roughly 10% of the participants (without replacement) as a pilot subset.
# A seeded generator is used so the same participants are selected on every
# Restart & Run All; the unseeded global RNG gave a different subset each run.
SUBSAMPLE_SEED = 0
subsample_rng = np.random.default_rng(SUBSAMPLE_SEED)

all_participant_ids = df_hssm.participant_id.unique()
random_ids = subsample_rng.choice(
    all_participant_ids,
    size=round(len(all_participant_ids)/10),
    replace=False,
)

In [13]:
# Restrict the HSSM frame to the randomly drawn participants
in_subsample = df_hssm['participant_id'].isin(random_ids)
df_test = df_hssm[in_subsample]

In [14]:
# Inspect the subset of trials belonging to the randomly drawn participants
df_test

Unnamed: 0,response,rt,participant_id
1760,-1,0.8854,56
1761,-1,1.8173,56
1762,-1,2.2342,56
1763,-1,1.5838,56
1765,-1,1.5828,56
...,...,...,...
19131,-1,0.6140,598
19132,-1,1.2660,598
19133,-1,0.6180,598
19134,1,1.0000,598


## Fit model

In [15]:
# NOTE(review): `results` is initialised here but nothing in the visible cells appends to it —
# confirm whether it is still needed.
results = []

In [20]:
# Fit an intercept-only DDM separately to each participant in df_test and write
# the posterior summary table and posterior plot into a per-participant folder.
# NOTE(review): the recorded output of this cell ends in
# "ValueError: different number of dimensions on data and dims: 2 vs 3" right
# after sampling the first participant; the cause is not visible here.
# NOTE(review): outputs go to "plots/Exp4" although the data file loaded at the
# top is 'data_wt_exp5.csv' — confirm the folder name is intended.
participant_ids = df_test['participant_id'].unique()
n_participants = len(participant_ids)

for nsub, isub in enumerate(participant_ids):
    print(f"___Participant {isub}, {nsub+1}/{n_participants}___")

    participant_folder = f"plots/Exp4/S{int(isub):04d}"
    os.makedirs(participant_folder, exist_ok=True)

    # Single-participant data; the id column is dropped before fitting
    is_current = df_test['participant_id'] == isub
    df_sub = df_test[is_current].drop(columns='participant_id')

    print("Median RT =", np.median(df_sub['rt']))
    print("N trials =", len(df_sub))

    # Intercept-only specifications for drift rate (v) and boundary separation (a)
    v_spec = {
        "name":"v",
        "formula":"v ~ 1",
        "prior":{"Intercept":{"name":"Normal","mu":0,"sigma":1}},
    }
    a_spec = {
        "name":"a",
        "formula":"a ~ 1",
        "prior":{"Intercept":{"name":"Normal","mu":1.5,"sigma":0.5}},
    }
    model = hssm.HSSM(model="ddm", data=df_sub, include=[v_spec, a_spec])

    print("starting model")

    infer_data_sub = model.sample(
        cores=3,
        chains=3,
        draws=300,
        tune=1000,
        idata_kwargs=dict(log_likelihood=False),
        progressbar=True,
        target_accept=0.99,
    )

    print("model done")

    fit_dict = infer_data_sub.to_dict()

    # Posterior summary: echoed to the cell output and stored as plain text
    summary_table = az.summary(infer_data_sub)
    summary_text = summary_table.to_string()
    print(summary_text)
    summary_path = os.path.join(participant_folder, "summary_table.txt")
    with open(summary_path, "w") as file:
        file.write(summary_text)

    # Save the posterior plot, then close the figure so figures do not pile up
    az.plot_posterior(infer_data_sub)
    plt.savefig(os.path.join(participant_folder, "posterior_plot.png"))
    plt.close()

___Participant 56, 1/42___
Median RT = 1.601
N trials = 31


Only 300 samples in chain.


starting model


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 3 jobs)
NUTS: [t, z, v_Intercept, a_Intercept]


Sampling 3 chains for 1_000 tune and 300 draw iterations (3_000 + 900 draws total) took 12 seconds.


ValueError: different number of dimensions on data and dims: 2 vs 3

In [25]:
# Debug cell: re-sample the model left over from the fitting loop (`model` is the
# last one it built), asking for a raw trace instead of InferenceData, then print
# the shape of every sampled variable.
# NOTE(review): the recorded output shows the same ValueError is raised even with
# return_inferencedata=False, so the shape inspection below was never reached.
try:
    trace = model.sample(
        cores=3, chains=3, draws=300, tune=1000,
        return_inferencedata=False,  # intended to skip the automatic InferenceData conversion
        progressbar=True, target_accept=0.99,
        idata_kwargs=dict(log_likelihood=False),
    )
    print("Got raw trace; inspecting variable shapes...")
    # Assumes `trace` is a PyMC MultiTrace (has .varnames / .get_values) — TODO confirm.
    # Names ending in "__" are skipped.
    varnames = [v for v in trace.varnames if not v.endswith("__")]
    for v in varnames:
        arr = trace.get_values(v, combine=False, squeeze=False)  # one entry per chain
        # Stack the per-chain entries so the chain axis comes first
        arr = np.asarray(arr)  # expected shape: (n_chains, n_draws, *event_shape)
        print(v, "-> shape:", arr.shape, "ndim:", arr.ndim)
except Exception as e:
    # Report the failure, then re-raise so the full traceback is still shown
    print("Sampling returned error (or not PyMC backend). Exception:", e)
    raise

Only 300 samples in chain.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 3 jobs)
NUTS: [t, z, v_Intercept, a_Intercept]


Sampling 3 chains for 1_000 tune and 300 draw iterations (3_000 + 900 draws total) took 11 seconds.


Sampling returned error (or not PyMC backend). Exception: different number of dimensions on data and dims: 2 vs 3


ValueError: different number of dimensions on data and dims: 2 vs 3

In [24]:
# Trial count of `df_sub`, i.e. the participant most recently processed by the fitting loop
len(df_sub)

31

In [None]:
# ArviZ summary from the fitting loop's latest iteration (only defined once an iteration gets past sampling)
summary_table

In [None]:
# Dict form of the latest inference data from the fitting loop (only defined once an iteration gets past sampling)
fit_dict

In [None]:
# Collect the per-participant summary tables written by the fitting loop into
# one long DataFrame with a `param` column and a `participant_id` column.

# Base folder containing participant subfolders
base_folder = "plots/Exp4"

# One summary_table.txt per "S<id>" subfolder. Sorted because glob returns
# matches in arbitrary order, which would make the row order of df_all vary.
txt_files = sorted(glob.glob(os.path.join(base_folder, "S*", "summary_table.txt")))

all_summaries = []

for txt_file in txt_files:
    # Extract participant id from folder name: 'S0001' -> 1
    participant_id = int(os.path.basename(os.path.dirname(txt_file))[1:])

    # Read the whitespace-aligned text table; the first column holds the parameter names.
    # sep=r"\s+" replaces the deprecated delim_whitespace=True.
    try:
        df = pd.read_csv(txt_file, sep=r"\s+", index_col=0)
    except pd.errors.ParserError:
        # Fallback: parse as a fixed-width table
        df = pd.read_fwf(txt_file, index_col=0)

    df['participant_id'] = participant_id
    all_summaries.append(df.reset_index().rename(columns={'index':'param'}))

# pd.concat fails with an unhelpful "No objects to concatenate" on an empty
# list, so report the actual problem instead.
if not all_summaries:
    raise FileNotFoundError(
        "No summary_table.txt files found under %r - run the fitting loop first" % base_folder
    )

# Concatenate all participants into a single DataFrame
df_all = pd.concat(all_summaries, ignore_index=True)

print(df_all.head())

In [None]:
# Inspect the combined per-participant summary table
df_all

In [None]:
# Keep only the drift-rate intercept rows of the combined summary table
is_drift_row = df_all['param'] == 'v_Intercept'
df_v = df_all[is_drift_row]

In [None]:
# Quick look: kernel density of the `mean` column for the v_Intercept rows
sns.kdeplot(df_v['mean'])

In [None]:
# Distribution of the `mean` column for the v_Intercept rows: KDE with a rug underneath.
# seaborn and matplotlib are already imported in the first cell.
df_v = df_all[df_all['param'] == 'v_Intercept']
v_means = df_v['mean']

fig, ax = plt.subplots(figsize=(8, 5))
sns.kdeplot(v_means, fill=True, bw_adjust=0.5, ax=ax)
sns.rugplot(v_means, color='k', ax=ax)
ax.set_xlabel("Drift rate (v_Intercept)")
ax.set_ylabel("Density")
ax.set_title("Distribution of fitted drift rates across participants")
fig.tight_layout()
plt.show()


In [None]:
# NOTE(review): the aggregation cell earlier collects into `all_summaries`; this differently
# named empty list looks like a leftover — confirm.
all_summary = []


In [None]:
# pd.concat raises ValueError ("No objects to concatenate") on an empty list,
# so only rebuild df_all when there is something to combine; otherwise the
# existing df_all is left as it is.
if all_summary:
    df_all = pd.concat(all_summary, ignore_index=True)
else:
    print("all_summary is empty - df_all left unchanged")