In [5]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ddm
import os 
import csv
from model_definitions import ModelTtaBounds

Read parameters from the file, simulate model for a given parameter set, if the simulation time is short, plot p_turn and RTs, if not, save them to csv files

In [22]:
ndt = 'gaussian'

path = '../model_fit_results/%s_ndt' % (ndt)
file_name = 'all_conditions_merged.csv'
file_path = os.path.join(path, file_name)
parameters = pd.read_csv(file_path)

# idx_best_fit = parameters['loss'] == parameters.groupby(['subj_id'])['loss'].transform(max)
# parameters = parameters[idx_best_fit]

In [37]:
exp_data = pd.read_csv('../data/measures.csv', usecols=['subj_id', 'RT', 'is_turn_decision', 
                                                    'tta_condition', 'd_condition'])
subjects = exp_data.subj_id.unique()
conditions = [{'tta_condition': tta, 'd_condition': d} 
               for tta in exp_data.tta_condition.unique() 
               for d in exp_data.d_condition.unique()]

In [23]:
parameters

Unnamed: 0,subj_id,i,loss,alpha,beta,theta,noise,b_0,k,tta_crit,nondectime,halfwidth
0,129,0,0.43797,1.502183,0.019671,6.834053,1,1.25753,1.26381,3.732779,0.367761,0.032294
1,389,0,0.313379,1.54173,0.052285,11.587208,1,0.906105,0.100136,5.39933,0.48489,0.105365
2,525,0,0.221842,1.234224,0.026751,7.621752,1,0.8681,0.103047,3.26481,0.22051,0.087622
3,616,0,0.201307,1.868214,0.039785,10.511767,1,0.88713,0.297751,4.599213,0.5,0.088796
4,618,0,0.534275,1.131752,0.032144,8.225332,1,1.896079,0.330785,4.76416,0.182894,0.055941
5,642,0,0.516943,0.471747,0.111939,18.569484,1,1.081553,0.402499,3.34549,0.345195,0.158649
6,755,0,0.56208,0.552989,0.124447,19.330604,1,1.152863,0.47118,5.762067,0.384221,0.092433


In [25]:
param = parameters.loc[0]

In [21]:
modelTtaBounds = ModelTtaBounds()

In [30]:
overlay = (ddm.OverlayNonDecisionUniform(nondectime=param.nondectime,
                                        halfwidth=param.halfwidth) 
            if ndt=='uniform' else 
           ModelTtaBounds.OverlayNonDecisionGaussian(nondectime=param.nondectime,
                                                       ndsigma=param.halfwidth))

modelTtaBounds.model = ddm.Model(name='5 TTA- and d-dependent drift and bounds and random nondecision time',
                         drift=ModelTtaBounds.DriftTtaDistance(alpha=param.alpha, beta=param.beta, theta=param.theta),
                         noise=ddm.NoiseConstant(noise=param.noise),
                         bound=ModelTtaBounds.BoundCollapsingTta(b_0=param.b_0, k=param.k, tta_crit=param.tta_crit),
                         overlay=overlay,
                         T_dur=ModelTtaBounds.T_dur)

In [31]:
ddm.functions.display_model(modelTtaBounds.model)

Model 5 TTA- and d-dependent drift and bounds and random nondecision time information:
Drift component DriftTtaDistance:
    Drift depends on TTA and distance
    Fixed parameters:
    - alpha: 1.502183
    - beta: 0.019671
    - theta: 6.834053
Noise component NoiseConstant:
    constant
    Fixed parameters:
    - noise: 1.000000
Bound component BoundCollapsingTta:
    Bounds collapsing with TTA
    Fixed parameters:
    - b_0: 1.257530
    - k: 1.263810
    - tta_crit: 3.732779
IC component ICPointSourceCenter:
    point_source_center
    (No parameters)
Overlay component OverlayNonDecisionGaussian:
    Add a Gaussian-distributed non-decision time
    Fixed parameters:
    - nondectime: 0.367761
    - ndsigma: 0.032294



In [39]:
condition = conditions[0]

In [40]:
sol = modelTtaBounds.model.solve(conditions=condition)

In [44]:
def get_model_measures(model, condition):
    sol = model.solve(condition)
    return condition['tta_condition'], condition['d_condition'], sol.prob_correct(), sol.mean_decision_time()

In [45]:
model_measures = [get_model_measures(modelTtaBounds.model, condition) for condition in conditions]

In [46]:
model_measures

[(4.0, 90.0, 0.07262584077125499, 0.629672383524236),
 (4.0, 150.0, 0.358718203711926, 0.6707796359425608),
 (4.0, 120.0, 0.17682073955261116, 0.6560525454655103),
 (5.0, 90.0, 0.1663167292306485, 0.8395620104084046),
 (5.0, 150.0, 0.693725750132772, 0.842780136946839),
 (5.0, 120.0, 0.39996688779197814, 0.8613087477611082),
 (6.0, 90.0, 0.5560637835154901, 0.9578353525621883),
 (6.0, 150.0, 0.9713896319018631, 0.7962208158296877),
 (6.0, 120.0, 0.8419404998318335, 0.9007292521843867)]