In [39]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import torch

# Experimental data

*Roitman, JD and Shadlen, MN (2002), “Response of neurons in the lateral intraparietal area during a combined visual discrimination reaction time task”, Journal of Neuroscience, Vol. 22(21), 9475-9489.*


#### Additional information
- `coh`: coherence of the trial, multiplied by 10 (i.e. 32 means a coherence of 3.2%). In the cleaned CSV loaded below this column appears as `coherence`.
- `correct`: whether the subject was correct (1 = correct, 0 = error). In the cleaned CSV this presumably corresponds to the `decision` column.
- `rt`: reaction time, expected in ms (range 5–1762, mostly within [200, 1200]).
- Shinn et al. use animal N (coded as `animal = 0` below) for all main figures.
- Trials per monkey and coherence level: n > 500 for animal 0 and n > 400 for animal 1.



In [40]:
# Load the cleaned Roitman & Shadlen (2002) trial table from disk.
ddm_data = pd.read_csv(filepath_or_buffer="../data/roitman_data_clean.csv")

In [41]:
# Fail early if the loaded table lacks the columns seen in the saved output
# (presumably the ones `filter_roitman_data` relies on), then preview it.
EXPECTED_COLUMNS = {"rt", "coherence", "decision", "animal"}
missing_columns = EXPECTED_COLUMNS - set(ddm_data.columns)
assert not missing_columns, f"ddm_data is missing expected columns: {sorted(missing_columns)}"
ddm_data

Unnamed: 0,rt,coherence,decision,animal
0,464.0,256.0,1.0,0.0
1,318.0,64.0,1.0,0.0
2,531.0,128.0,1.0,0.0
3,567.0,0.0,1.0,0.0
4,398.0,0.0,1.0,0.0
...,...,...,...,...
6143,743.0,64.0,1.0,1.0
6144,704.0,0.0,1.0,1.0
6145,490.0,512.0,1.0,1.0
6146,558.0,256.0,1.0,1.0


# Generative model 
Fit a drift-diffusion model (DDM) to the data with the *pyddm* toolbox, then sample synthetic trials from the fitted model:

https://pyddm.readthedocs.io/en/stable/

In [42]:
# load pyddm 

from pyddm import Model
from pyddm.models import DriftConstant, NoiseConstant, BoundConstant, OverlayNonDecision, ICPointSourceCenter
from pyddm.functions import fit_adjust_model, display_model

import pyddm.plot

from pyddm import Fittable
from pyddm.models import LossRobustBIC
from pyddm.functions import fit_adjust_model, fit_model

from pyddm.models import (
    BoundCollapsingExponential,
    DriftLinear,
    ICPointSourceCenter,
    NoiseConstant,
    OverlayNonDecision,
)


from roitman_utils import filter_roitman_data 
# this filters the data and puts it into the format we need for pyddm

In [43]:
# Trials to fit: a single animal at a single coherence level.
# Coherence is stored multiplied by 10, so 128 corresponds to 12.8% coherence.
FIT_COHERENCE = 128
# Animal 0 is animal N, which Shinn et al. use for all main figures.
FIT_ANIMAL = 0

# Filter the raw table and convert it to the format pyddm's fitting functions expect.
data = filter_roitman_data(
    ddm_data,
    coherence=FIT_COHERENCE,
    animal=FIT_ANIMAL,
    n_trial="all",
    attach_model_mask=False,
    partition=None,
    data_mode="pyddm",
)

In [44]:
# DDM with:
# - DriftLinear drift: the `drift` and `x` coefficients are fitted, the time coefficient `t` is fixed at 0
# - constant unit noise
# - exponentially collapsing bounds (fitted `B` and `tau`)
# - fitted non-decision time
# - a centred point-source initial condition
# Discretisation: dx = 0.001, dt = 0.01, simulated for T_dur = 2 (time units as used by pyddm).
model_fit = Model(
    name='Simple model (fitted)',
    drift=DriftLinear(
        drift=Fittable(minval=0, maxval=5),
        t=0,
        # NOTE(review): in the saved run the second reported parameter (-19.33,
        # presumably `x`) landed close to this lower bound of -20; consider
        # widening the range and refitting.
        x=Fittable(minval=-20, maxval=10),
    ),
    noise=NoiseConstant(noise=1),
    bound=BoundCollapsingExponential(
        B=Fittable(minval=0.5, maxval=4),
        tau=Fittable(minval=0.1, maxval=4),
    ),
    overlay=OverlayNonDecision(nondectime=Fittable(minval=0.1, maxval=0.4)),
    IC=ICPointSourceCenter(),
    dx=.001,
    dt=.01,
    T_dur=2,
)

# Fit all Fittable parameters to `data` by differential evolution, minimising
# pyddm's robust BIC loss. The fitted values are written back into `model_fit`,
# which later cells use directly.
# NOTE(review): no optimiser seed is passed here. Check whether pyddm seeds
# differential evolution internally if refits must be exactly reproducible.
fit_adjust_model(
    data,
    model_fit,
    fitting_method="differential_evolution",
    lossfunction=LossRobustBIC,
    verbose=False,
)

# The fitted model is solved once, in the "solve model" cell below; solving it
# here as well only produced a result that was immediately overwritten.


Info: Params [  2.56341775 -19.33219281   0.64081395   0.86808267   0.1500179 ] gave 109.49685410101176


In [45]:
# pyddm.plot.plot_fit_diagnostics(model=model_fit, sample=data)
# plt.show()


In [46]:
# Solve the fitted model with its fitted parameters; `sol` is sampled from in the "generate data" cell below.
sol = model_fit.solve()


In [47]:
# Generate synthetic trials by sampling from the fitted model's solution.
# NOTE(review): no seed is passed explicitly. Check pyddm's default for
# `resample` if the generated data must be exactly reproducible.
N_SIMULATED_TRIALS = 1000
generated_data = sol.resample(k=N_SIMULATED_TRIALS)

In [48]:
# Convert the simulated and the observed samples to float32 column tensors of
# shape (n, 1), assuming `choice_upper` is 1-D.
# Only the upper-boundary RTs (`choice_upper`) are kept; lower-boundary trials are dropped.
# NOTE(review): this rebinds `generated_data` from the pyddm sample to a tensor,
# so re-running this cell without first re-running the "generate data" cell fails.
generated_data, real_data = [
    torch.tensor(sample.choice_upper, dtype=torch.float32).unsqueeze(-1)
    for sample in (generated_data, data)
]

In [49]:
# Persist both RT tensors for downstream use. Create the output directory first:
# torch.save fails when the parent directory does not exist.
DDM_OUTPUT_DIR = Path("../data/ddm")
DDM_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
torch.save(generated_data, DDM_OUTPUT_DIR / "generated_data.pt")
torch.save(real_data, DDM_OUTPUT_DIR / "real_data.pt")

In [50]:
# generated_data

tensor([[0.7097],
        [0.5650],
        [1.1352],
        [0.7177],
        [0.7140],
        [0.6700],
        [0.4412],
        [0.6106],
        [0.6330],
        [0.9345],
        [0.9981],
        [0.2461],
        [0.6374],
        [0.4970],
        [0.5197],
        [0.6032],
        [0.5877],
        [0.6987],
        [0.8002],
        [0.5168],
        [0.7588],
        [0.5272],
        [0.5055],
        [0.6410],
        [0.7639],
        [0.7163],
        [0.8408],
        [0.2743],
        [0.5751],
        [0.9398],
        [0.7643],
        [0.8414],
        [0.4087],
        [0.7382],
        [0.6843],
        [0.7866],
        [0.7429],
        [0.9640],
        [1.0074],
        [0.6836],
        [0.8594],
        [0.8915],
        [0.2571],
        [0.6818],
        [0.8716],
        [0.5505],
        [0.9229],
        [0.7973],
        [0.6618],
        [0.5870],
        [0.7654],
        [0.4451],
        [0.8201],
        [0.7431],
        [0.4237],
        [0