In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import interpolate, optimize
import ddm
import os 
import csv

In [2]:
class LossWLS(ddm.LossFunction):
    """Weighted least-squares loss following Ratcliff & Tuerlinckx (2002).

    For every condition combination the loss adds:
      * 4 * (model P(correct) - empirical P(correct))**2, and
      * the squared differences between model and empirical RT quantiles of correct
        responses, weighted by ``rt_q_weights`` and scaled by the empirical P(correct).
        This RT term is skipped when the model P(correct) is <= 0.001 or the empirical
        P(correct) is 0, because the correct-RT distribution is then degenerate.
    """
    name = 'Weighted least squares as described in Ratcliff & Tuerlinckx 2002'
    # RT quantiles compared between model and data, and the weight applied to each
    # quantile's squared error (same order as rt_quantiles).
    rt_quantiles = [0.1, 0.3, 0.5, 0.7, 0.9]
    rt_q_weights = [2, 2, 1, 1, 0.5]

    def setup(self, dt, T_dur, **kwargs):
        """Store the time step and duration used to build the empirical RT CDF."""
        self.dt = dt
        self.T_dur = T_dur

    def get_rt_quantiles(self, x, t_domain, exp=False):
        """Return the RTs at which the correct-response CDF reaches each of ``rt_quantiles``.

        Parameters
        ----------
        x : object providing ``cdf_corr()`` -- a model solution (exp=False) or an
            experimental sample (exp=True).
        t_domain : time points matching the CDF returned by ``x.cdf_corr``.
        exp : if True, ``x`` is a sample and its CDF is computed with this loss's
            ``T_dur`` and ``dt``.

        Returns
        -------
        numpy array with one RT per entry of ``rt_quantiles``.
        """
        cdf = x.cdf_corr(T_dur=self.T_dur, dt=self.dt) if exp else x.cdf_corr()
        # Normalise so the CDF ends at 1, i.e. quantiles are of correct-response RTs only.
        cdf_interp = interpolate.interp1d(t_domain, cdf/cdf[-1])
        cdf_at_zero = cdf_interp(0)
        rt_quantile_values = []
        for quantile in self.rt_quantiles:
            if cdf_at_zero < quantile:
                # root_scalar is evaluated immediately, so capturing `quantile` here is safe.
                root = optimize.root_scalar(lambda t: cdf_interp(t) - quantile,
                                            bracket=(0, t_domain[-1])).root
            else:
                # Very fast RTs: the interpolated CDF already reaches this quantile at t=0, so
                # there is no sign change to bracket; use half a time step instead.
                root = self.dt/2
            rt_quantile_values.append(root)
        return np.array(rt_quantile_values)

    def loss(self, model):
        """Return the WLS summed over all condition combinations the model requires."""
        solutions = self.cache_by_conditions(model)
        t_domain = model.t_domain()
        wls = 0
        for comb in self.sample.condition_combinations(required_conditions=self.required_conditions):
            solution = solutions[frozenset(comb.items())]
            comb_sample = self.sample.subset(**comb)
            model_p_correct = solution.prob_correct()
            exp_p_correct = comb_sample.prob_correct()
            wls += 4*(model_p_correct - exp_p_correct)**2
            # When model P(correct) is very close to 0 its RT distribution is unreliable, and
            # with no correct trials in the data there are no RTs to compare: skip RT error.
            if model_p_correct > 0.001 and exp_p_correct > 0:
                model_rt_q = self.get_rt_quantiles(solution, t_domain, exp=False)
                exp_rt_q = self.get_rt_quantiles(comb_sample, t_domain, exp=True)
                wls += np.dot((model_rt_q - exp_rt_q)**2, self.rt_q_weights)*exp_p_correct
        return wls

In [3]:
# Duration of each simulated trial; passed as T_dur to ddm.Model below.
T_dur = 2.5

class DriftTtaDistance(ddm.models.Drift):
    """Drift rate driven by the time-to-arrival (TTA) and distance remaining at time t.

    drift(t) = alpha * [(tta(t) - tta_crit) + beta * (d(t) - d_crit)], where
    tta(t) = tta_condition - t and d(t) = d_condition - (d_condition / tta_condition) * t,
    i.e. distance shrinks at the constant speed implied by the initial condition.
    """
    name = 'Drift depends on TTA and distance'
    required_parameters = ['alpha', 'tta_crit', 'beta', 'd_crit']
    required_conditions = ['tta_condition', 'd_condition']

    def get_drift(self, t, conditions, **kwargs):
        tta_initial = conditions['tta_condition']
        d_initial = conditions['d_condition']
        speed = d_initial/tta_initial
        tta_now = tta_initial - t
        d_now = d_initial - speed*t
        tta_term = tta_now - self.tta_crit
        d_term = d_now - self.d_crit
        return self.alpha*(tta_term + self.beta*d_term)

class BoundCollapsingTta(ddm.models.Bound):
    """Bound that collapses logistically as the remaining time-to-arrival shrinks.

    bound(t) = b_0 / (1 + exp(-k * ((tta_condition - t) - tta_crit))), so the bound is
    b_0 / 2 when the remaining TTA equals tta_crit.
    """
    name = 'Bounds collapsing with TTA'
    required_parameters = ['b_0', 'k', 'tta_crit']
    required_conditions = ['tta_condition']

    def get_bound(self, t, conditions, **kwargs):
        remaining_tta = conditions['tta_condition'] - t
        exponent = -self.k*(remaining_tta - self.tta_crit)
        return self.b_0/(1 + np.exp(exponent))

# Column labels for the fitted parameters written to the results CSV. They must match, in
# order and number, the values returned by model_TTA_bounds.get_model_parameters():
#   drift: alpha, tta_crit, beta, d_crit; bound: b_0, k; overlay: nondectime, halfwidth.
# 'noise' is fixed (NoiseConstant(noise=1)) and is not a fitted parameter, and tta_crit is
# a single Fittable shared by drift and bound, so it is listed only once.
param_names = ['alpha', 'tta_crit', 'beta', 'd_crit', 'b_0', 'k', 'nondectime', 'halfwidth']
# Shared Fittable: the drift and the bound are fitted with the same critical TTA value.
tta_crit = ddm.Fittable(minval=3, maxval=6)
model_TTA_bounds = ddm.Model(name='5 TTA- and d-dependent drift and bounds and uniformly distributed nondecision time',
                             drift=DriftTtaDistance(alpha=ddm.Fittable(minval=0.1, maxval=3),
                                                     tta_crit=tta_crit,
                                                     beta=ddm.Fittable(minval=0, maxval=1),
                                                     d_crit=ddm.Fittable(minval=90, maxval=150)),
                             noise=ddm.NoiseConstant(noise=1),
                             bound=BoundCollapsingTta(b_0=ddm.Fittable(minval=0.5, maxval=5), 
                                                      k=ddm.Fittable(minval=0.1, maxval=2),
                                                      tta_crit=tta_crit),
                             overlay=ddm.OverlayNonDecisionUniform(nondectime=ddm.Fittable(minval=0, maxval=0.5),
                                                                   halfwidth=ddm.Fittable(minval=0, maxval=0.3)),
                             T_dur=T_dur)
# Guard against the CSV header and the parameter rows drifting apart if the model changes.
assert len(param_names) == len(model_TTA_bounds.get_model_parameters()), \
    'param_names does not match the fittable parameters of model_TTA_bounds'

In [4]:
def fit_model(model, exp_data, subj_id):
    """Fit ``model`` to the trials of one subject and return the fitted model.

    Prints ``subj_id``, keeps the rows of ``exp_data`` whose ``subj_id`` column equals it,
    wraps them in a ddm.Sample (RT column 'RT', correct column 'is_turn_decision') and
    fits with ddm.fit_adjust_model using the LossWLS loss.
    """
    print(subj_id)
    subject_rows = exp_data.loc[exp_data['subj_id'] == subj_id]
    subject_sample = ddm.Sample.from_pandas_dataframe(
        df=subject_rows, rt_column_name='RT', correct_column_name='is_turn_decision')
    fitted = ddm.fit_adjust_model(sample=subject_sample, model=model,
                                  lossfunction=LossWLS, suppress_output=True)
    return fitted

def write_to_csv(directory, filename, array):
    """Append one row to ``directory/filename`` as CSV, creating ``directory`` if needed.

    Parameters
    ----------
    directory : output directory; created with any missing parents. An empty string
        means the current working directory.
    filename : CSV file name inside ``directory``. The file is opened in append mode, so
        rows accumulate across calls.
    array : iterable of values written as a single comma-separated row.
    """
    if directory:
        # exist_ok avoids the race between checking for the directory and creating it.
        os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, filename), 'a', newline='') as csvfile:
        csv.writer(csvfile, delimiter=',').writerow(array)

In [None]:
model = model_TTA_bounds
# Number of independent fits (differential-evolution restarts) per subject.
N_FITS_PER_SUBJECT = 5

exp_data = pd.read_csv('../data/measures.csv', usecols=['subj_id', 'RT', 'is_turn_decision', 
                                                        'tta_condition', 'd_condition'])
subjects = exp_data.subj_id.unique()

directory = '../model_fit_results/'
# Output file is tagged with the first character of the model name (e.g. '5' for model_TTA_bounds).
results_file = 'model_%s_params_by_subject.csv' % (model.name[0])
# write_to_csv appends, so only write the header when starting a new file; otherwise every
# re-run of this cell would insert a duplicate header row in the middle of the results.
if not os.path.exists(os.path.join(directory, results_file)):
    write_to_csv(directory, results_file, ['subj_id', 'i', 'loss'] + param_names)

for subj_id in subjects:
    for i in range(N_FITS_PER_SUBJECT):
        fitted_model = fit_model(model, exp_data, subj_id)
        write_to_csv(directory, results_file,
                     [subj_id, i, fitted_model.get_fit_result().value()]
                     + fitted_model.get_model_parameters())

129
differential_evolution step 1: f(x)= 1.97525
differential_evolution step 2: f(x)= 1.97525
differential_evolution step 3: f(x)= 1.97525
differential_evolution step 4: f(x)= 1.3742
differential_evolution step 5: f(x)= 1.3742
differential_evolution step 6: f(x)= 1.3742
differential_evolution step 7: f(x)= 0.83524
differential_evolution step 8: f(x)= 0.657377
differential_evolution step 9: f(x)= 0.657377
differential_evolution step 10: f(x)= 0.657377
differential_evolution step 11: f(x)= 0.657377
differential_evolution step 12: f(x)= 0.657377
differential_evolution step 13: f(x)= 0.657377
differential_evolution step 14: f(x)= 0.657377
differential_evolution step 15: f(x)= 0.657377
differential_evolution step 16: f(x)= 0.652796
differential_evolution step 17: f(x)= 0.563638
differential_evolution step 18: f(x)= 0.563638
differential_evolution step 19: f(x)= 0.511279
differential_evolution step 20: f(x)= 0.502534
differential_evolution step 21: f(x)= 0.502534
differential_evolution step 

differential_evolution step 2: f(x)= 0.666733
differential_evolution step 3: f(x)= 0.666733
differential_evolution step 4: f(x)= 0.666733
differential_evolution step 5: f(x)= 0.666733
differential_evolution step 6: f(x)= 0.666733
differential_evolution step 7: f(x)= 0.640627
differential_evolution step 8: f(x)= 0.611465
differential_evolution step 9: f(x)= 0.611465
differential_evolution step 10: f(x)= 0.611465
differential_evolution step 11: f(x)= 0.611465
differential_evolution step 12: f(x)= 0.516055
differential_evolution step 13: f(x)= 0.516055
differential_evolution step 14: f(x)= 0.502999
differential_evolution step 15: f(x)= 0.494738
differential_evolution step 16: f(x)= 0.494738
differential_evolution step 17: f(x)= 0.490901
differential_evolution step 18: f(x)= 0.490901
differential_evolution step 19: f(x)= 0.490901
differential_evolution step 20: f(x)= 0.490901
differential_evolution step 21: f(x)= 0.490901
differential_evolution step 22: f(x)= 0.478061
differential_evolutio

differential_evolution step 41: f(x)= 0.452768
differential_evolution step 42: f(x)= 0.452768
differential_evolution step 43: f(x)= 0.452768
differential_evolution step 44: f(x)= 0.447669
differential_evolution step 45: f(x)= 0.444001
differential_evolution step 46: f(x)= 0.444001
differential_evolution step 47: f(x)= 0.444001
differential_evolution step 48: f(x)= 0.444001
differential_evolution step 49: f(x)= 0.443496
differential_evolution step 50: f(x)= 0.443496
differential_evolution step 51: f(x)= 0.443496
differential_evolution step 52: f(x)= 0.443496
differential_evolution step 53: f(x)= 0.443496
differential_evolution step 54: f(x)= 0.44125
differential_evolution step 55: f(x)= 0.44125
differential_evolution step 56: f(x)= 0.44125
differential_evolution step 57: f(x)= 0.439957
differential_evolution step 58: f(x)= 0.439957
differential_evolution step 59: f(x)= 0.439957
differential_evolution step 60: f(x)= 0.439957
differential_evolution step 61: f(x)= 0.439957
differential_evo

differential_evolution step 37: f(x)= 0.322435
differential_evolution step 38: f(x)= 0.322435
differential_evolution step 39: f(x)= 0.322435
differential_evolution step 40: f(x)= 0.322435
differential_evolution step 41: f(x)= 0.322435
differential_evolution step 42: f(x)= 0.322435
differential_evolution step 43: f(x)= 0.322435
differential_evolution step 44: f(x)= 0.322435
differential_evolution step 45: f(x)= 0.322435
differential_evolution step 46: f(x)= 0.322435
differential_evolution step 47: f(x)= 0.322435
differential_evolution step 48: f(x)= 0.318822
differential_evolution step 49: f(x)= 0.318822
differential_evolution step 50: f(x)= 0.318822
differential_evolution step 51: f(x)= 0.318822
differential_evolution step 52: f(x)= 0.318822
differential_evolution step 53: f(x)= 0.318822
differential_evolution step 54: f(x)= 0.318822
differential_evolution step 55: f(x)= 0.318822
differential_evolution step 56: f(x)= 0.318822
differential_evolution step 57: f(x)= 0.318822
differential_

differential_evolution step 35: f(x)= 0.328435
differential_evolution step 36: f(x)= 0.327797
differential_evolution step 37: f(x)= 0.327797
differential_evolution step 38: f(x)= 0.327797
differential_evolution step 39: f(x)= 0.327797
differential_evolution step 40: f(x)= 0.327797
differential_evolution step 41: f(x)= 0.32189
differential_evolution step 42: f(x)= 0.32189
differential_evolution step 43: f(x)= 0.319485
differential_evolution step 44: f(x)= 0.319485
differential_evolution step 45: f(x)= 0.318579
differential_evolution step 46: f(x)= 0.318579
differential_evolution step 47: f(x)= 0.318579
differential_evolution step 48: f(x)= 0.318579
differential_evolution step 49: f(x)= 0.318579
differential_evolution step 50: f(x)= 0.316338
differential_evolution step 51: f(x)= 0.316338
differential_evolution step 52: f(x)= 0.316338
differential_evolution step 53: f(x)= 0.316338
differential_evolution step 54: f(x)= 0.316338
differential_evolution step 55: f(x)= 0.316338
differential_ev

differential_evolution step 21: f(x)= 0.270213
differential_evolution step 22: f(x)= 0.270213
differential_evolution step 23: f(x)= 0.270213
differential_evolution step 24: f(x)= 0.263751
differential_evolution step 25: f(x)= 0.263751
differential_evolution step 26: f(x)= 0.263751
differential_evolution step 27: f(x)= 0.257857
differential_evolution step 28: f(x)= 0.257857
differential_evolution step 29: f(x)= 0.257857
differential_evolution step 30: f(x)= 0.257857
differential_evolution step 31: f(x)= 0.257857
differential_evolution step 32: f(x)= 0.257857
differential_evolution step 33: f(x)= 0.249823
differential_evolution step 34: f(x)= 0.243379
differential_evolution step 35: f(x)= 0.242398
differential_evolution step 36: f(x)= 0.242398
differential_evolution step 37: f(x)= 0.234347
differential_evolution step 38: f(x)= 0.234347
differential_evolution step 39: f(x)= 0.234347
differential_evolution step 40: f(x)= 0.234347
differential_evolution step 41: f(x)= 0.234347
differential_