In [None]:
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn.objects as so
import utils
from ssms.basic_simulators.simulator import simulator
from statsmodels.tsa.stattools import acf

from drift_diffusion.model import pdf
from drift_diffusion.model import DriftDiffusionModel

# Global figure styling. Helvetica is preferred, but it is often not installed (e.g. on Linux);
# listing sans-serif fallbacks avoids a stream of "findfont" warnings while rendering the same
# way on machines that do have Helvetica.
plt.rcParams.update(
    {
        "font.size": 12,
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica", "Arial", "DejaVu Sans"],
    }
)
# Make seaborn.objects plots inherit the same matplotlib styling.
so.Plot.config.theme.update(plt.rcParams)

Experimental Design

In [None]:
# Load rat 195's trial table, keep valid trials that have a recorded RT, index by trial time,
# and derive the columns used throughout the notebook.
df195 = (
    utils.mat_to_pd("datasets/Rat195Vectors_241025.mat")
    .query("Valid == 1")
    .dropna(subset=["RT"])  # drop trials with a missing reaction time
    # MATLAB datenums count days from year 0; 719529 is the datenum of 1970-01-01.
    .assign(trialDate=lambda d: pd.to_datetime(d["trialDate"] - 719529, unit="D"))
    .set_index("trialDate")
    .sort_index()
    # Shift RTs so the smallest one equals 1e-2, then sign them by outcome:
    # positive when correct == 1, negative when correct == 0.
    .assign(
        RT=lambda d: d["RT"] - d["RT"].min() + 1e-2,
        y=lambda d: d["RT"] * d["correct"].map({1: 1, 0: -1}),
        coh_bins=lambda d: pd.cut(d["coherence"], bins=5),
    )
    # A "day" runs from 2pm to 2pm; "hour" is 1-based (1..24).
    .assign(
        day=lambda d: (d.index - pd.Timedelta(hours=14)).floor("D") + pd.Timedelta(hours=14),
        hour=lambda d: d.index.hour + 1,
    )
)

# Number of trials per (day, hour) over the analysis window, drawn as a day x hour heatmap.
# The count lives in the "RT" column (count of non-null RTs).
df195_heatmap = (
    df195.loc["2008-12-03":"2009-03-12"]
    .groupby(["day", "hour"], as_index=False)["RT"]
    .count()
)

# Named so they do not shadow the built-in min()/max() for the rest of the session.
count_min, count_max = df195_heatmap["RT"].min(), df195_heatmap["RT"].max()

cmap = "YlOrRd"
# Colour-scale floor well below the smallest count, so even sparse cells are drawn with a
# visibly saturated colour rather than near-white (presumably the intent — tune as needed).
color_floor = -300

fig, ax = plt.subplots(figsize=(9, 2))
sc = ax.scatter(
    df195_heatmap["day"],
    df195_heatmap["hour"],
    c=df195_heatmap["RT"],
    marker="s",
    cmap=cmap,
    s=10,
    vmin=color_floor,
)
ax.set_ylabel("Hour")
ax.set_yticks([1, 6, 12, 18, 24])
ax.set_xlabel("Day")
ax.xaxis.set_major_formatter(mdates.DateFormatter("%b-%d"))

# Label only the observed extremes on the colourbar.
cbar = fig.colorbar(
    sc, ax=ax, location="top", ticks=(count_min, count_max), anchor=(1, 1), fraction=0.06, aspect=15, pad=0.05
)
cbar.set_label("# Trials")

Decision-Making Process

In [None]:
# Simulate a DDM at 21 evenly spaced drift rates in [0, 2] (1000 single-sample trials each)
# to show how the probability of a rightward choice grows with motion strength.
vs = np.repeat(np.linspace(0, 1, 21) * 2, 1000)
sim_params = {"a": 1.5, "t": 0, "v": vs, "z": 0.5}

sims = simulator(model="ddm", theta=sim_params, n_samples=1, random_state=1)
dfsim = pd.DataFrame(
    {
        "coherence": vs / 2,  # drift mapped back to a 0–1 coherence
        "coherence_pct": (vs / 2) * 100,  # 0–100 for the x-axis
        "rts": sims["rts"].squeeze(),
        # Positive simulator choices are coded 1 (rightward), everything else 0.
        "choices": np.where(sims["choices"].squeeze() > 0, 1, 0),
    }
)

# Psychometric curve: the mean of the 0/1 choice indicator is the proportion of rightward
# choices, matching the y-axis label (previously mean RT was plotted under this label).
(
    so.Plot(dfsim, x="coherence_pct", y="choices")
    .add(so.Dot(color="k"), so.Agg(func="mean"))
    .add(so.Range(color="k"), so.Est(errorbar=("ci", 95)))
    .label(x="% Rightward Motion", y="Proportion\nRightward Choices")
    .theme({"font.size": 18})
    .layout(size=(5, 4))
)

Drift Diffusion Model

In [None]:
# Model RT densities for two drift rates (v = 0 solid, v = 1 dashed): positive signed RTs in
# red, negative signed RTs in blue. The shared base parameters are never mutated.
params = {"a": 0.5, "t0": 0.01, "v": 0, "z": 0}

num = 1000
rts = params["t0"] + np.linspace(1e-5, 1, num)  # grid starts just after t0
y = np.r_[-rts[::-1], rts]  # signed grid: first `num` points negative, last `num` positive

fig, ax = plt.subplots(figsize=(11, 2))

for v, ls in [(0, "-"), (1, "--")]:
    curve_params = {**params, "v": v}  # per-curve copy; leaves `params` as written above
    ps = pdf(y, **curve_params)
    ps /= ps.sum()  # normalise so each curve sums to 1 over the grid
    label = (
        rf"$\theta = (a={curve_params['a']}, t_0={curve_params['t0']}, "
        rf"\mathbf{{v={v}}}, z={curve_params['z']})$"
    )
    ax.plot(y[num:], ps[num:], color="r", ls=ls, label=label)
    ax.plot(y[:num], ps[:num], color="b", ls=ls)

ax.legend()
ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 1))
ax.set_xlabel(r"$RT$")
ax.set_ylabel(r"$f\;(RT\mid\theta)$")

Decisions as Time Series

In [None]:
# Autocorrelation of RTs across consecutive trials, plotted backwards in time (lag k at x = -k).
fig, ax = plt.subplots(figsize=(4, 3))

n_lags = 500
# acf returns lags 0..n_lags; lag 0 is dropped (it is always 1), leaving lags 1..n_lags.
# x must therefore run -1..-n_lags (previously it was 0..-(n_lags-1), shifting every point).
x = -np.arange(1, n_lags + 1)
ax.plot(x, acf(df195["RT"], adjusted=True, nlags=n_lags)[1:], c="k", lw=0.5)
ax.set_ylim(0, 0.16)
ax.set_xlabel("Trial")
ax.set_ylabel("Corr$(|RT|)$")

Estimation and Inference

In [None]:
# Fit the DDM separately for each "day" of the heatmap window.
df195_fit_by_day = utils.fit_ddm(df195.loc["2008-12-03":"2009-03-12"], "day")

# Rescale the fitted drift coefficient and its "v-"/"v+" companions by the same factor.
# NOTE(review): the meaning of 0.85 is not visible here (a reference coherence?) — TODO document.
df195_fit_by_day = df195_fit_by_day.assign(
    **{col: df195_fit_by_day[f"beta_{col}"] * 0.85 for col in ("v", "v-", "v+")}
)

In [None]:
# Plot configuration for utils.plot_fits: which two fitted parameters to compare, their axis
# limits, and LaTeX labels. `corr_label` is a str.format template (hence the doubled braces)
# with a `corr` placeholder; the other labels are used verbatim.
ddm_config = dict(
    var1="a",
    var2="v",
    var1_lim=[1, 2.05],
    var2_lim=[0.25, 2],
    var1_label=r"$\hat{a}$",
    var2_label=r"$\hat{v}$",
    corr_label=r"$\text{{corr}}(\hat{{a}},\hat{{v}})={corr:.2f}$",
    acf_label1=r"$\text{corr}(\hat{a})$",
    acf_label2=r"$\text{corr}(\hat{v})$",
)
# Plot the per-day DDM fits together with the (day, hour) trial-count table, laid out per
# ddm_config. NOTE(review): the exact figure layout is defined in utils.plot_fits — confirm there.
fig = utils.plot_fits(df195_heatmap, df195_fit_by_day, x="day", y="hour", config=ddm_config)

Maximum Likelihood Estimator

In [None]:
# Compare the empirical signed-RT histogram of simulated trials with the model density at an
# initial guess theta^(0) and at the "MLE" parameters.
sim_params = {"a": 0.5, "t": 0.1, "v": 0.5, "z": 0.5}

sims = simulator(model="ddm", theta=sim_params, n_samples=10_000, random_state=1)
y = sims["rts"].squeeze() * sims["choices"].squeeze()  # signed RTs of the simulated trials

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(y, bins=200, color="k", alpha=0.25, density=True, label=r"$\hat{f}\;(RT)$")

num = 1000
rts = np.linspace(1e-5, 1, num)
y_grid = np.r_[-rts[::-1], rts]  # evaluation grid, kept separate from the data `y`

# Positional arguments are (a, t0, v, z), matching the keyword use of pdf elsewhere.
# NOTE(review): the "MLE" curve uses the generating values (a=0.5, t=0.1, v=0.5) rather than
# a fitted estimate — consider fitting DriftDiffusionModel to `y` instead.
curves = [
    ((0.6, 0.01, 0, 0), "--", r"$f\;(RT \mid \theta^{(0)})$"),
    ((0.5, 0.1, 0.5, 0), "-", r"$f\;(RT\mid\hat{\theta}_\text{MLE})$"),
]
for theta, ls, label in curves:
    ps = pdf(y_grid, *theta)
    ax.plot(y_grid[num:], ps[num:], color="r", ls=ls, label=label)
    ax.plot(y_grid[:num], ps[:num], color="b", ls=ls)

ax.legend()
ax.set_xlim((-1.01, 1.01))
ax.set_xlabel(r"$RT$")
ax.set_ylabel(r"$f$")

In [None]:
def log_likelihood(y, params):
    """Return the total log density of signed RTs ``y`` under ``pdf(y, *params)``.

    ``params`` is passed positionally to ``pdf`` as ``(a, t0, v, z)``.
    """
    return np.sum(np.log(pdf(y, *params)))


n_samples = 1000
# Column of ones as the design matrix; with v="+1" below this is presumably an
# intercept-only drift model, i.e. a single free parameter v.
X = pd.DataFrame(np.ones(n_samples))

# Same drift, two boundary separations: a wider boundary should give a sharper likelihood.
sim_params = [
    {"a": 2, "t": 0.1, "v": 0.5, "z": 0.5},
    {"a": 4, "t": 0.1, "v": 0.5, "z": 0.5},
]

levels = 20
vs = np.linspace(-2, 2.5, levels)
colors = ["gray", "k"]

fig, ax = plt.subplots(figsize=(8, 4))

for theta, color in zip(sim_params, colors):
    sims = simulator(model="ddm", theta=theta, n_samples=n_samples, random_state=2)
    y = sims["rts"].squeeze() * sims["choices"].squeeze()
    # Profile the log-likelihood over v with a and t held at their generating values.
    log_likelihoods = [log_likelihood(y, [theta["a"], theta["t"], v, 0]) for v in vs]
    ddm = DriftDiffusionModel(a=theta["a"], t0=theta["t"], v="+1", z=0)
    # NOTE(review): assumes covariance_ is the estimated parameter covariance; with only v free
    # it is Var(v_hat), so the standard error shown in the legend is its square root.
    stderr = float(np.sqrt(ddm.fit(X, y).covariance_.squeeze()))
    ax.plot(vs, log_likelihoods, color=color, label=f"$\\widehat{{se}}(\\hat{{v}}_\\text{{MLE}})$ = {stderr:.1e}")

ax.legend()
ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 1))
ax.set_xlabel(r"$v$")
ax.set_ylabel(r"$\ell(v \mid RT_t)$")