In [1]:
import numpy as np
import pandas as pd
import scipy.stats as stats
import plotly.express as px
import plotly.graph_objects as go
from plotly import figure_factory as ff
from ddm_model import DDM, Simulation, Threshold

# Base model: DDM with a collapsing threshold

In [2]:
# Collapsing decision threshold used by the base model.
# NOTE(review): the meaning of the positional 15 and of the three
# collapse_arg values isn't visible here — TODO document from ddm_model.Threshold.
th1 = Threshold(
    15,
    type="collapsing",
    collapse_arg=(15, 0.005, 8),
)

In [3]:
# Base drift-diffusion model built on the collapsing threshold above.
ddm_one = DDM(
    drift_rate=0.017348,  # author's note: drift rate max is around 0.07
    noise_mag=0.5,
    threshold=th1,
    initial_condition=0.03,
)

In [4]:
# Simulate the base model.
# NOTE(review): no random seed is set anywhere in this notebook, so the
# trajectories and RTs below will change between runs unless ddm_model seeds itself.
simulation_one = Simulation(
    ddm_one,
    1,     # NOTE(review): meaning not shown here — TODO confirm (time step?)
    1000,  # presumably the simulated duration in ms — TODO confirm
    15,    # presumably the number of trials — TODO confirm
)

In [5]:
# Plot the simulated trajectories of the base model.
simulation_one.plot_trajectories()

In [6]:
# Mean simulated response time (ms), averaged over every entry of reaction_times.
np.mean(simulation_one.reaction_times)

np.float64(403.1333333333333)

In [9]:
# Kernel density estimate of the base model's simulated response times.
# Built directly with scipy's gaussian_kde + graph_objects instead of the
# deprecated ff.create_distplot (which, with show_hist/show_rug off, drew only
# this KDE curve), so it matches the empirical-vs-simulated plot further down.
rts_one = simulation_one.reaction_times.flatten()
kde_one = stats.gaussian_kde(rts_one, bw_method='scott')

# Extend the grid 3 kernel SDs past the data so the curve's tails aren't clipped.
# For 1-D data, covariance[0, 0] is the squared kernel bandwidth.
pad_one = 3 * np.sqrt(kde_one.covariance[0, 0])
x_one = np.linspace(rts_one.min() - pad_one, rts_one.max() + pad_one, 500)

fig = go.Figure(go.Scatter(x=x_one, y=kde_one(x_one), mode='lines', name='Response Times'))
fig.update_layout(
    title="Simulated Response Time Distribution (base model)",
    xaxis=dict(title="Response Time (ms)"),
    yaxis=dict(title="Density"),
)
fig.show()

In [10]:
# Error rate of the base-model simulation, as computed by the Simulation object.
simulation_one.error_rate

0.26666666666666666

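Simulated trajectories of the base model: one column per trial, one row per simulation step. Trials that end earlier are padded with empty (NaN) cells.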

In [12]:
# Trajectories as a table, transposed so each simulated trial is a column and
# each row is a simulation step (shorter trials end in NaN padding).
pd.DataFrame(simulation_one.simulated_trajectories).transpose()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
0,-0.007864,1.256519,0.549253,0.001516,-0.606516,-0.923377,-0.553480,0.271592,0.196788,-0.223483,-0.269776,-0.606365,0.034640,0.225830,0.771803
1,-0.424772,0.410853,1.087343,0.174871,-0.336102,-1.053599,-0.298959,-0.469600,0.947787,-0.567670,0.199303,-0.463040,-0.949986,-0.731368,1.110385
2,-0.345776,0.011337,0.978367,-0.402458,-1.087463,-1.622253,0.158844,-0.196091,0.351697,-0.184749,0.844316,0.140805,-0.888204,-0.934242,0.564039
3,-0.032709,-0.100171,1.483486,-1.099218,-1.242602,-1.636808,0.727294,-0.715497,0.828449,-0.711665,1.020651,0.473590,-1.287601,-0.785840,-0.960855
4,0.233730,-0.038552,1.732746,-1.435824,-2.299390,-1.289724,0.455793,-0.428983,0.716344,-0.757325,0.470312,-0.379876,-1.172933,-1.114958,-1.997283
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
765,,,,,,,,,,,,,6.686211,,
766,,,,,,,,,,,,,7.523215,,
767,,,,,,,,,,,,,7.322017,,
768,,,,,,,,,,,,,7.077627,,


# Fitting the model to empirical data with `DDM.fit()` (Optuna)

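Empirical data: Myers et al. (2022), *Frontiers in Psychology* — https://www.frontiersin.org/journals/psychology/articles/10.3389/fpsyg.2022.1039172/full. Response times in the `rt` column are given in seconds and are rescaled to milliseconds below.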

In [13]:
# Load the empirical response-time data (Myers et al. 2022).
# The path is relative to the notebook's working directory.
EMPDATA_PATH = "empdata1.csv"
empdata = pd.read_csv(EMPDATA_PATH)

In [15]:
# The `rt` column is in seconds; rescale to milliseconds to match the simulated RTs.
# .to_numpy() is pandas' recommended replacement for .values.
empdata_ms = empdata["rt"].to_numpy() * 1000

# A NaN/inf RT would silently corrupt the mean/std and the KDE computed later.
assert np.isfinite(empdata_ms).all(), "empirical RTs contain NaN/inf values"

In [50]:
# Fit the model to the empirical RTs (ms). The log output shows an Optuna
# study searching over drift_rate, noise_mag and threshold.
# NOTE(review): the next simulation cell reuses `ddm_one` to simulate the
# fitted model, which implies fit() updates ddm_one in place — confirm in
# ddm_model. If so, re-running the base-model cells after this one will not
# reproduce their recorded outputs (hidden state).
# The repeated `suggest_uniform` deprecation warnings are raised from code this
# notebook doesn't contain (presumably ddm_model). Optuna >= 3.0 asks for
# suggest_float instead; the fix belongs there, not here.
study = ddm_one.fit(trials_data=empdata_ms)

[I 2025-02-23 00:27:15,562] A new study created in memory with name: no-name-2e3abe1f-2925-4f03-90f6-f153f9524498

suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.


suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.


suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.

[I 2025-02-23 00:27:15,839] Trial 0 finished with value: 141840.83576282882 and parameters: {'drift_rate': 0.08716765620193072, 'noise_mag': 0.7739385133036925, 'threshold': 10.71862425095039}. Best is trial 0 with value: 141840.83576282882.
[I 2025-02-23 00:27:15,878] Trial 1 finished with value: 182675.07537085112 and parameters: {'drift_rate': 0.0218935979360

Best parameters: {'drift_rate': 0.042685399485695844, 'noise_mag': 0.04828106710937197, 'threshold': 18.792181079825358}


In [51]:
# Simulate 20 trials for 1000 ms with the new fitted parameters
# (assumes fit() above updated ddm_one in place — TODO confirm).
simulation_two = Simulation(
    ddm_one,
    1,
    1000,  # duration in ms
    20,    # number of trials
)

In [52]:
# Plot the simulated trajectories of the fitted model.
simulation_two.plot_trajectories()

In [58]:
# Compare the fitted model's simulated RTs with the empirical RTs (both in ms).
# A labelled table replaces four bare print() calls whose output did not say
# which number was which. Both SDs use numpy's default population SD (ddof=0),
# as the original calls did.
pd.DataFrame(
    {
        "mean_ms": [simulation_two.reaction_times.mean(), empdata_ms.mean()],
        "std_ms": [simulation_two.reaction_times.std(), empdata_ms.std()],
    },
    index=["simulated", "empirical"],
)

439.45
432.673686686
21.88258439947165
186.83221306387455


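Visual inspection of the response-time (RT) distributions: empirical data vs. the model simulated with the fitted parameters. In the run recorded above, the fitted model reproduces the empirical mean RT (≈439 ms simulated vs. ≈433 ms empirical) but not its spread (SD ≈22 ms vs. ≈187 ms). The simulated distribution should therefore look far narrower, likely because the fitted noise_mag (≈0.048) is very small.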

In [54]:
# Gaussian KDEs (Scott's-rule bandwidth) of the empirical and simulated RTs (ms).
sim_rts = simulation_two.reaction_times.flatten()  # flatten once, reuse below
kde1 = stats.gaussian_kde(empdata_ms, bw_method='scott')
kde2 = stats.gaussian_kde(sim_rts, bw_method='scott')

# Shared evaluation grid covering both samples. It extends 3 kernel SDs (the
# larger of the two bandwidths; covariance[0, 0] is the squared bandwidth for
# 1-D data) beyond the pooled min/max, so neither density curve is cut off at
# the data extremes. Array .min()/.max() avoid the element-by-element Python
# iteration that the builtin min()/max() do on numpy arrays.
grid_pad = 3 * np.sqrt(max(kde1.covariance[0, 0], kde2.covariance[0, 0]))
x_range = np.linspace(
    min(empdata_ms.min(), sim_rts.min()) - grid_pad,
    max(empdata_ms.max(), sim_rts.max()) + grid_pad,
    1000,
)

In [55]:
# Overlay the empirical and simulated RT densities on one set of axes.
combined_fig = go.Figure()

# (density estimator, legend label, line colour) — traces are added in this order.
trace_specs = (
    (kde1, 'Empirical Response Times', 'blue'),
    (kde2, 'Simulated Response Times', 'red'),
)
for density, label, colour in trace_specs:
    combined_fig.add_trace(
        go.Scatter(
            x=x_range,
            y=density(x_range),
            mode='lines',
            name=label,
            line=dict(color=colour, width=2),
        )
    )

combined_fig.update_layout(
    title="Combined Response Time Distributions",
    xaxis=dict(title="Response Time (ms)"),
    yaxis=dict(title="Density"),
    showlegend=True,
)

combined_fig.show()