# Hierarchical Bayesian Linear Regression using PyStan
**Florian Ott, 2021**

Here we regress (log) response times on planning conflict, additionally testing for an interaction with task context (easy vs. hard trials). A further description of the analysis and a visualization of the results can be found in the main manuscript.

In [1]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import pystan
import glob as glob 
import time as time
import arviz as az
# NOTE(review): 'ara' is not a built-in Matplotlib style; presumably a project
# style sheet installed in the local stylelib -- TODO confirm / document it.
plt.style.use('ara')

# Load participant data (a single CSV export for all participants).
DATA_FILE = '../data/behaviour/data_all_participants_20220215120148.csv'
filename = glob.glob(DATA_FILE)
if not filename:
    # Fail with a clear message instead of an IndexError on filename[0].
    raise FileNotFoundError(f"Participant data not found: {DATA_FILE}")
dat = pd.read_csv(filename[0], index_col=0)

## Specifying the model

In [2]:
model = '''
data {
  int<lower=1> N;                          // number of trials entering the fit
  vector[N] ishard;                        // context indicator per trial (assumed 0 = easy, 1 = hard, as in generated quantities)
  vector[N] logrt;                         // log reaction time per trial
  vector[N] conflict;                      // planning conflict per trial
  int<lower=1> N_subjects;                 // number of participants
  int<lower=1, upper=N_subjects> vpn[N];   // participant index of each trial, 1..N_subjects
  int<lower=1> N_rep;                      // number of simulated (posterior-predictive) trials
  vector[N_rep] conflict_rep;              // conflict values used for the simulated trials
}

transformed data {
  // generated quantities look up vpn[n] for n in 1:N_rep, so N_rep must not exceed N
  if (N_rep > N)
    reject("N_rep must not exceed N, got N_rep = ", N_rep, " and N = ", N);
}

parameters {
  // hyperparameters: population means and SDs of the subject-level coefficients
  real mu_intercept;
  real mu_beta_conflict;
  real mu_beta_ishard;
  real mu_beta_interaction;
  real<lower=0> sigma_intercept;
  real<lower=0> sigma_beta_conflict;
  real<lower=0> sigma_ishard;
  real<lower=0> sigma_interaction;

  // subject-level parameters
  vector[N_subjects] intercept;
  vector[N_subjects] beta_conflict;
  vector[N_subjects] beta_ishard;
  vector[N_subjects] beta_interaction;
  real<lower=0> sigma; // same data-level noise for all subjects
}

model {
  // hyperpriors (the <lower=0> constraint makes the SD priors half-normal)
  mu_intercept ~ normal(0,10);
  mu_beta_conflict ~ normal(0,10);
  mu_beta_ishard ~ normal(0,10);
  mu_beta_interaction ~ normal(0,10);
  sigma_intercept ~ normal(0,10);
  sigma_beta_conflict ~ normal(0,10);
  sigma_ishard ~ normal(0,10);
  sigma_interaction ~ normal(0,10);

  // subject-level priors (centered parameterization)
  intercept ~ normal(mu_intercept,sigma_intercept);
  beta_conflict ~ normal(mu_beta_conflict,sigma_beta_conflict);
  beta_ishard ~ normal(mu_beta_ishard,sigma_ishard);
  beta_interaction ~ normal(mu_beta_interaction,sigma_interaction);
  sigma ~ normal(0, 10);

  // likelihood: per-trial linear predictor on the log-RT scale, indexed by participant
  logrt ~ normal(intercept[vpn] + beta_conflict[vpn].*conflict + beta_ishard[vpn].*ishard + beta_interaction[vpn] .* conflict .* ishard,sigma);
}

generated quantities {
  vector[N_subjects] intercept_rep;
  vector[N_subjects] beta_conflict_rep;
  vector[N_subjects] beta_ishard_rep;
  vector[N_subjects] beta_interaction_rep;
  vector[N_rep] rt_new_easy;
  vector[N_rep] rt_new_hard;

  // replicated subject-level coefficients
  // NOTE(review): each draw is centred on the subject's own estimate with the
  // population SD; simulating *new* subjects from the population would use
  // mu_* as the mean instead -- confirm this is intended.
  for (n in 1:N_subjects){
    intercept_rep[n] = normal_rng(intercept[n], sigma_intercept);
    beta_conflict_rep[n] = normal_rng(beta_conflict[n], sigma_beta_conflict);
    beta_ishard_rep[n] = normal_rng(beta_ishard[n], sigma_ishard);
    beta_interaction_rep[n] = normal_rng(beta_interaction[n], sigma_interaction);
    }

  // simulated RTs on the raw scale (lognormal, because the likelihood is on log RT)
  // for the easy (ishard = 0) and hard (ishard = 1) context at each conflict_rep
  // value, using the participant of trial n
  for (n in 1:N_rep){
    rt_new_easy[n] = lognormal_rng(intercept_rep[vpn[n]] + conflict_rep[n] * beta_conflict_rep[vpn[n]] + beta_ishard_rep[vpn[n]]*0 + beta_interaction_rep[vpn[n]] * conflict_rep[n] * 0 ,sigma);
    rt_new_hard[n] = lognormal_rng(intercept_rep[vpn[n]] + conflict_rep[n] * beta_conflict_rep[vpn[n]] + beta_ishard_rep[vpn[n]]*1 + beta_interaction_rep[vpn[n]] * conflict_rep[n] * 1 ,sigma);
    }
}
'''

## Compiling

In [3]:
# Compile the Stan program defined above into an executable model object.
sm = pystan.StanModel(
    model_code=model,
    verbose=False,
)

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_1989a5f201a87f58e7e2e78ff7aa273a NOW.


## Specifying the data

In [4]:
# Exclude timed-out trials from the fit.
idx = (dat['timeout'] == 0)
trials = dat.loc[idx]

# The Stan likelihood is on the log-RT scale.
logrt = np.log(trials['reaction_time'].to_numpy())
conflict = trials['conflict_planning'].to_numpy()
ishard = trials['is_23'].to_numpy(dtype='int')
N = len(logrt)
assert N > 0, "no non-timeout trials left to fit"

# Map participant IDs to contiguous 1-based indices, as required by the Stan
# `vpn` array (1..N_subjects). For IDs 101, 102, ... this is identical to
# `vpn - 100`, but it stays valid if IDs have gaps or use another offset.
# `subject_ids[k - 1]` recovers the original ID of Stan subject k.
subject_ids, subject_idx = np.unique(trials['vpn'].to_numpy(), return_inverse=True)
vpn = subject_idx + 1
N_subjects = len(subject_ids)

# Posterior-predictive simulations reuse the observed conflict values.
conflict_rep = conflict
N_rep = len(conflict_rep)

dat_dict = {'N': N,
            'logrt': logrt,
            'conflict': conflict,
            'ishard': ishard,
            'N_subjects': N_subjects,
            'vpn': vpn,
            'N_rep': N_rep,
            'conflict_rep': conflict_rep,
            }

## Sampling the posterior

In [5]:
# NUTS sampling: 4 chains, 2000 iterations each, of which the first 1000 are
# warmup (1000 retained draws per chain). The fixed seed makes the run
# reproducible. adapt_delta is raised above Stan's default of 0.8, presumably
# to limit divergent transitions in this centered hierarchical model.
res = sm.sampling(data=dat_dict, iter=2000, warmup=1000, thin=1, chains=4,
                  control=dict(adapt_delta=0.97), seed=101, verbose=False)

# Check divergences, tree-depth saturation and E-BFMI before trusting the fit;
# returns a dict of booleans (True = check passed) and logs any problems.
hmc_diagnostics = pystan.check_hmc_diagnostics(res)
hmc_diagnostics

To run all diagnostics call pystan.check_hmc_diagnostics(fit)
