In [None]:
%matplotlib inline
import os
import sys
# Put the parent of the current working directory on the import path so the
# `nireact` package imported below can be found when it is not installed.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
import scipy.stats as st
import matplotlib.pyplot as plt
import numpy as np
from nireact import model, lba, task

# Root directory of the study data. The MISTR_DATA_DIR environment variable
# overrides it so the notebook can run on other machines; when the variable
# is unset the original location is used unchanged.
data_dir = os.environ.get('MISTR_DATA_DIR', '/Users/morton/work/mistr')
# Directory under data_dir where the results of this notebook are saved.
lba_dir = os.path.join(data_dir, 'batch', 'lba2')

In [None]:
# Read the test data for all subjects from data_dir, then pull out the
# per-trial arrays that are passed to the model recovery call.
data = task.read_test_all_subj(data_dir)
subj_idx = data.subj_idx.values
test = data.test_type.values
# Number of distinct subject indices present in the data.
n_subj = np.unique(subj_idx).size

In [None]:
# Variable drift model ("var") specification.
# Parameters held at constant values.
fixed = dict(s=1, tau=0, b=8)
# Zero-argument samplers; each call draws one value from a uniform
# distribution given as (loc, scale).
gen_param = {
    'A': lambda: st.uniform.rvs(0, 8),
    'v2': lambda: st.uniform.rvs(0, 10),
    'r': lambda: st.uniform.rvs(0, 1),
}
# Sampler for the subject-level parameter v1; n_subj is looked up when the
# lambda is called, not when it is defined.
gen_subj = {
    'v1': lambda: model.sample_hier_drift(4, 1.5, 0.75, n_subj),
}
spec_var = dict(
    name='var',
    model=lba.LBAVar(),
    fixed=fixed,
    param=gen_param,
    subj_param=gen_subj,
)

In [None]:
# Full model ("nav") specification.
# Parameters held at constant values; unlike the variable drift model this
# one also fixes v4.
fixed = dict(s=1, tau=0, b=8, v4=-10)
# Zero-argument samplers; each call draws one value from a uniform
# distribution given as (loc, scale).
gen_param = {
    'A': lambda: st.uniform.rvs(0, 8),
    'v2': lambda: st.uniform.rvs(0, 10),
    'r': lambda: st.uniform.rvs(0, 1),
}
# Samplers for the subject-level parameters v1 and v3; n_subj is looked up
# when each lambda is called, not when it is defined.
gen_subj = {
    'v1': lambda: model.sample_hier_drift(4, 1.5, 0.75, n_subj),
    'v3': lambda: model.sample_hier_drift(4, 1.5, 0.75, n_subj),
}
spec_nav = dict(
    name='nav',
    model=lba.LBANav(),
    fixed=fixed,
    param=gen_param,
    subj_param=gen_subj,
)

In [None]:
# Run model recovery over both model specifications with n_rep repetitions.
n_rep = 100

# Seed NumPy's global generator so the parameter draws made through
# scipy.stats (which falls back to that generator when no random_state is
# given) are repeatable across kernel restarts.
# NOTE(review): assumes model_recovery and the sampler it calls also derive
# their randomness from NumPy's global state -- confirm in nireact.model.
random_seed = 42
np.random.seed(random_seed)

# draws / tune / target_accept are forwarded as keyword arguments
# (presumably sampler settings -- confirm in nireact.model).
results = model.model_recovery([spec_var, spec_nav], test, subj_idx, n_rep,
                               draws=10000, tune=5000, target_accept=.95)

In [None]:
# Save the recovery results as an .npz archive in lba_dir, with the number of
# repetitions in the file name.
# Create the output directory first: np.savez does not create missing parent
# directories, and failing here would discard the results of the long run.
os.makedirs(lba_dir, exist_ok=True)
res_file = os.path.join(lba_dir, f'model_comp{n_rep}.npz')
# Each key of `results` becomes a named array in the archive.
np.savez(res_file, **results)