In [4]:
import numpy as np
import pandas as pd
import plotly.express as px

In [69]:
class DDM():
    def __init__(self,dt,input_a,input_b,noise_mag,threshold):
        self.dt = dt
        self.input_a = input_a
        self.input_b = input_b
        self.noise_mag = noise_mag
        self.threshold = threshold

        self.simulated_trajectories = None #you have to use them with .T
        self.simulated_results = None
    
    def simulate(self,trials, timesteps):
        results = []
        trajectories = []
        for trial in range(trials):
            x = 0
            traj = [x]
            for y in range(timesteps):
                noise_term = np.random.randn()
                dxdt = (self.input_a - self.input_b) + (self.noise_mag * noise_term)
                x += dxdt
                traj.append(x)

                if x >= self.threshold:
                    results.append("A")
                    break
                elif x <= -self.threshold:
                    results.append("B")
                    break
            else:
                results.append("None")
            trajectories.append(traj)
        
        self.simulated_trajectories = trajectories
        self.simulated_results = results
    
    def plot_trajectories(self):
        s = pd.DataFrame(self.simulated_trajectories).T
        fig = px.line(s, title="Drift-Diffusion Model Simulations", labels={"index": "Time Step", "value": "Decision Variable x(t)"})
        fig.add_hline(y=self.threshold, line_dash="dash", line_color="red", annotation_text="Decision A")
        fig.add_hline(y=-self.threshold, line_dash="dash", line_color="blue", annotation_text="Decision B")
        fig.show()


In [70]:
ddm_one = DDM(dt=0.1, input_a =0.95, input_b =1, noise_mag=7, threshold=20)

In [73]:
ddm_one.simulate(trials=10,timesteps=10000)

In [74]:
ddm_one.plot_trajectories()

In [77]:
pd.DataFrame(ddm_one.simulated_trajectories).T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,-14.436292,10.408394,-5.502496,-3.441903,2.938223,0.960068,1.311577,-4.045355,-5.802765,-1.413008
2,-8.373956,20.244675,-4.310542,-8.709091,-1.104822,0.134632,3.298087,0.311574,-13.373633,3.197481
3,-9.135445,,10.271141,-1.272477,-11.089575,4.844482,7.501756,15.066547,-5.641817,8.675726
4,-8.084335,,9.960715,6.690855,-11.211853,-2.557679,13.613355,9.532958,-11.250574,7.06695
5,-10.601674,,9.330708,6.255353,0.525998,-1.145591,9.152937,7.232426,-15.870941,11.04001
6,-0.654704,,7.457781,0.147157,1.78643,-7.749782,2.788509,9.205368,-17.928212,11.970052
7,10.347633,,8.401298,8.586389,-6.566251,-1.44044,-22.631994,2.220663,-19.447886,15.048649
8,15.024409,,8.065857,17.302075,-22.341337,-12.291211,,-1.210307,-24.182511,19.153153
9,9.518559,,-3.479402,18.472395,,-16.520501,,-2.597077,,18.905938


In [99]:
ddm_2 = DDM(dt=0.1, input_a =0.95, input_b =1, noise_mag=7, threshold=20)

In [100]:
ddm_2.simulate(trials=10,timesteps=10000)

In [101]:
ddm_2.plot_trajectories()