In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import helper
import os
from matplotlib.lines import Line2D
import matplotlib.ticker as ticker

In [None]:
# Experimental data: the measures table, and the processed data indexed by
# (subj_id, session, route, intersection_no).
exp_measures = pd.read_csv("../data/measures.csv")
exp_data = pd.read_csv(
    "../data/processed_data.csv",
    index_col=["subj_id", "session", "route", "intersection_no"],
)

# Model 1 outputs: directory with simulation results, and the fitted parameters.
simulation_results_path = "../model_fit_results/model_1/simulation_results/"
model_params = pd.read_csv(
    "../model_fit_results/model_1/best_fit_parameters/full_data_parameters_fitted.csv"
)

# Main paper figures

### Model diagram

In [None]:
def get_trace(t, dt, tta, d, model_params):
    """Simulate one evidence-accumulation trace, cut at the first boundary crossing.

    Parameters
    ----------
    t : array-like
        Time grid; only its length is used (one noise sample per time step).
    dt : float
        Integration time step.
    tta, d : array-like
        Time-to-arrival and distance at every point of ``t``.
    model_params : mapping
        Must provide "alpha", "beta", "theta", "noise", "b_0", "k" and "tta_crit".

    Returns
    -------
    numpy.ndarray
        Accumulated evidence ``x`` starting at 0. If ``|x|`` exceeds the collapsing
        boundary, the trace is truncated right after the first crossing sample;
        otherwise the full-length trace is returned.
    """
    noise = np.random.randn(len(t))
    drift = model_params["alpha"] * (tta + model_params["beta"] * d - model_params["theta"])
    # Use the caller-supplied step rather than a module-level `simulation_params`,
    # so the function does not depend on notebook execution order.
    dx = drift * dt + model_params["noise"] * noise * np.sqrt(dt)
    # Evidence starts at zero; the last increment is dropped to keep len(x) == len(t).
    x = np.append([0], np.cumsum(dx)[:-1])
    boundary = model_params["b_0"] / (1 + np.exp(-model_params["k"] * (tta - model_params["tta_crit"])))
    crossing_idx = np.where(abs(x) > boundary)[0]
    if len(crossing_idx) > 0:
        return x[:crossing_idx[0] + 1]
    return x
    
def generate_model_traces(t, dt, tta, d, model_params):
    """Resample triplets of traces until one is suitable for the diagram, then save it.

    A triplet is accepted when exactly two of the three traces end above zero and
    every response time (non-decision time plus trace length times ``dt``) lies
    strictly between 0.4 and 0.7. The accepted traces are written to
    "ddm_traces.csv" in the working directory, one trace per row.

    Note: the loop has no iteration cap, so it runs until a suitable triplet is drawn.
    """
    while True:
        trials = [get_trace(t, dt, tta, d, model_params) for _ in range(3)]
        n_ending_positive = sum(1 for trace in trials if trace[-1] > 0)
        rts = [model_params["ndt_location"] + len(trace) * dt for trace in trials]
        if n_ending_positive == 2 and all(0.4 < rt < 0.7 for rt in rts):
            break
    pd.DataFrame(trials).to_csv("ddm_traces.csv")

In [None]:
def plot_model_traces(model_params, simulation_params, generate_new=False):
    """Draw the model diagram: drift-rate input on top, evidence traces and bounds below.

    Parameters
    ----------
    model_params : mapping
        Model parameters; reads "ndt_location", "b_0", "k", "tta_crit", "beta", "theta"
        (and whatever ``generate_model_traces`` needs when ``generate_new`` is set).
    simulation_params : mapping
        Must provide "dt" (time step) and "duration" (length of the time grid).
    generate_new : bool
        If True, new traces are simulated and written to "ddm_traces.csv" first.
        In either case the traces that are plotted are read back from that file.
    """
    colors = {"go": "#4052ac", "stay": "#e6263d", "equation": "#50596A"}
    condition = {"tta": 6, "d": 120}
    dt = simulation_params["dt"]
    t = np.arange(0, simulation_params["duration"], dt)
    # Distance shrinks linearly so that it reaches zero when TTA does.
    d = condition["d"] - (condition["d"] / condition["tta"])*t
    tta = condition["tta"] - t
    ndt = model_params["ndt_location"]

    if generate_new:
        generate_model_traces(t, dt, tta, d, model_params)
    # One trace per CSV row, NaN-padded to equal length; the first value of each row
    # is the index column written by to_csv, hence the [1:].
    trials = [x[1][~x[1].isna()].values[1:] for x in pd.read_csv("ddm_traces.csv").iterrows()]

    fig, (ax_rate, ax) = plt.subplots(2, 1, figsize=(8,6), gridspec_kw={"height_ratios": [1, 3]}, dpi=150, sharex=True)

    for x in trials:
        color = colors["go"] if x[-1]>0 else colors["stay"]
        ax.plot(t[:len(x)]+ndt, x, lw=1.0, alpha=0.4, color=color)
        # Put the end marker on the last sample of the trace (index len(x)-1).
        # Indexing t[len(x)] was one step late and out of range for full-length traces.
        ax.plot(t[len(x)-1]+ndt, x[-1], ls="", alpha=0.8, marker="o", ms=4, color=color)

    for tta_condition in [4, 6]:
        boundary = model_params["b_0"]/(1+np.exp(-model_params["k"]*(tta_condition-t-model_params["tta_crit"])))
        ax.plot(t+ndt, boundary, ls="--", color="gray", alpha=0.6)
        ax.plot(t+ndt, -boundary, ls="--", color="gray", alpha=0.6)
        ax.text(0.78, 0.3+0.15*(tta_condition-4), r"TTA$|_{t=0}$=%is" % (tta_condition),
                color=colors["equation"], alpha=0.6, fontsize=16)

        # Add extra markers to illustrate what would happen if the boundary was lower (TTA condition is 4s)
        if tta_condition==4:
            for x in trials:
                color = colors["go"] if x[-1]>0 else colors["stay"]
                crossings = np.where(abs(x)>boundary[:len(x)])[0]
                # A trace that never reaches this boundary gets no extra marker.
                if len(crossings) == 0:
                    continue
                response_time_idx = crossings[0]
                ax.plot(t[response_time_idx]+ndt, x[response_time_idx], ls="", alpha=0.5,
                        marker="o", ms=4, color=color)

    ax.text(0.01, 0.04, "Non-decision time", color=colors["equation"], fontsize=13)
    ax.plot([0, ndt], [0,0], lw=1, color=colors["equation"])
    ax.vlines(x=[0, ndt-0.002], ymin=-0.02, ymax=0.09, ls="--", color=colors["equation"], lw=1)

    ax.annotate("", xy=(ndt,0), xytext=(ndt+0.2, 0.15),
                 arrowprops=dict(arrowstyle="<-", connectionstyle="arc3", color=colors["equation"], linewidth=2))
    ax.text(0.43, 0.02, r"$dx = \alpha(TTA + \beta d - \theta_{crit}) dt + dW$", color=colors["equation"], fontsize=16)

    ax.text(0.47, 0.6, "Go", color=colors["go"], alpha=0.7, fontsize=16)
    ax.text(0.47, -0.65, "Stay", color=colors["stay"], alpha=0.7, fontsize=16)

    ax.set_xlabel(r"Time $t$", fontsize=14)
    ax.set_ylabel(r"Accumulated evidence $x$", fontsize=14)
    ax.set_yticks([-0.6, -0.3, 0, 0.3, 0.6])
    ax.set_ylim((-0.7, 0.7))

    # Top panel: the drift-rate input TTA + beta*d against the critical value theta.
    ax_rate.plot(t+ndt, tta+model_params["beta"]*d, color="#50596A", alpha=0.8)
    ax_rate.axhline(model_params["theta"], color="grey", alpha=0.6, ls="--", xmin=0.01, xmax=0.99)
    ax_rate.set_ylabel(r"$TTA + \beta d$", fontsize=16)
    ax_rate.set_xlim((0,1))
    ax_rate.set_ylim((8,12))
    ax_rate.text(0.05, 9.5, r"$\theta_{crit}$", color="grey", alpha=0.7, fontsize=16)

    sns.despine(offset=2, trim=True)
    plt.tight_layout()

In [None]:
simulation_params = {"dt": 0.0001, "duration": 1}

# Start from the parameters fitted to all participants ("all" row) ...
all_subj_params = model_params.loc[model_params.subj_id == "all"].to_dict("records")[0]
# ... and shift two of them before plotting. NOTE(review): presumably these offsets
# are only there to make the diagram easier to read -- confirm.
all_subj_params["ndt_location"] = all_subj_params["ndt_location"] - 0.1
all_subj_params["theta"] = all_subj_params["theta"] + 0.5

plot_model_traces(all_subj_params, simulation_params, generate_new=False)
plt.savefig("../output/fig_model_diagram.png", bbox_inches="tight")

### Trial timeline

In [None]:
# Trial used for the timeline figure. The tuple is looked up in the index of
# exp_data, which was loaded with (subj_id, session, route, intersection_no).
traj_id = (129, 1, 2, 5)
# reset_index() replaces the index with a default integer one.
traj = exp_data.loc[traj_id].reset_index()

In [None]:
def plot_traj(traj, ax=None):
    """Plot the timeline of one trial: velocity, gas pedal, distance and TTA over time.

    Parameters
    ----------
    traj : pandas.DataFrame
        One trial with a default integer index. Reads columns ``t``, ``ego_v``,
        ``throttle``, ``d_ego_bot``, ``tta`` and, from the first row,
        ``idx_bot_spawn`` and ``idx_response``.
    ax : ignored
        Kept for backward compatibility; a new 4-panel figure is always created.

    The input frame is not modified: time is re-referenced on a copy so that t = 0
    is the sample at ``idx_bot_spawn``.
    """
    color = "#50596A"
    font_size = 14
    ms = 12
    lw = 3

    idx_bot_spawn = int(traj.iloc[0].idx_bot_spawn)
    idx_response = int(traj.iloc[0].idx_response)

    # Shift time on a new frame instead of writing into the caller's DataFrame.
    traj = traj.assign(t=traj.t - traj.t.values[idx_bot_spawn])

    fig, axes = plt.subplots(4, 1, figsize=(8,6), dpi=150, sharex=True)
    axes[0].plot(traj.t, traj.ego_v, color=color, lw=lw)
    axes[0].plot(traj.t[idx_bot_spawn], traj.ego_v[idx_bot_spawn], color=color, ls="", marker="o", ms=ms)
    axes[0].plot(traj.t[idx_response], traj.ego_v[idx_response], color=color, ls="", marker="x", ms=ms)
    axes[0].set_ylabel("Velocity, m/s", fontsize=font_size)
    axes[0].set_ylim((-1, 7))

    # Dashed guides at both events; clip_on=False lets them extend beyond the axes.
    for x in [traj.t[idx_bot_spawn], traj.t[idx_response]]:
        axes[0].axvline(x=x, ymin=-0.3, ymax=0.5, c=color, ls="--", clip_on=False)
        axes[1].axvline(x=x, ymin=-0.2, ymax=1, c=color, ls="--", clip_on=False)

    axes[1].annotate(text="", xy=(traj.t[idx_bot_spawn], 70), xytext=(traj.t[idx_response], 70),
                     arrowprops=dict(arrowstyle="<->", color="0.5"))
    axes[1].text(.21, .79, "Response time (RT)", transform=axes[1].transAxes, fontsize=font_size)

    axes[1].plot(traj.t, 100*traj.throttle, color=color, lw=lw)
    axes[1].plot(traj.t[idx_bot_spawn], 100*traj.throttle[idx_bot_spawn], color=color, ls="", marker="o", ms=ms)
    axes[1].plot(traj.t[idx_response], 100*traj.throttle[idx_response], color=color, ls="", marker="x", ms=ms)
    axes[1].set_ylabel("Gas pedal, %", fontsize=font_size)
    axes[1].set_ylim((-20, 120))
    axes[1].set_yticks([0, 100])

    # Distance and TTA are only drawn from the idx_bot_spawn sample onwards.
    axes[2].plot(traj.t[idx_bot_spawn:], traj.d_ego_bot[idx_bot_spawn:], color=color, lw=lw)
    axes[2].plot(traj.t[idx_bot_spawn], traj.d_ego_bot[idx_bot_spawn], color=color, ls="", marker="o", ms=ms)
    axes[2].plot(traj.t[idx_response], traj.d_ego_bot[idx_response], color=color, ls="", marker="x", ms=ms)
    axes[2].set_ylabel("Distance, m", fontsize=font_size)
    axes[2].set_ylim((80, 122))

    axes[3].plot(traj.t[idx_bot_spawn:], traj.tta[idx_bot_spawn:], color=color, lw=lw)
    axes[3].plot(traj.t[idx_bot_spawn], traj.tta[idx_bot_spawn], color=color, ls="", marker="o",
                 label="Velocity$=0$: oncoming car appears", ms=ms)
    axes[3].plot(traj.t[idx_response], traj.tta[idx_response], color=color, ls="", marker="x",
                 label="Gas pedal$>0$: decision is made", ms=ms)
    axes[3].set_ylabel("TTA, s", fontsize=font_size)
    axes[3].set_ylim((3.5, 6.5))
    axes[3].set_yticks([4, 5, 6])

    axes[3].set_xlim((-0.3, 1.5))
    axes[3].set_xticks(np.arange(0, 1.4, 0.3))

    axes[3].set_xlabel("Time $t$, s", fontsize=font_size)

    # Only the two labelled markers on the last panel appear in the legend.
    fig.legend(loc="upper left", fontsize=font_size, bbox_to_anchor=(0.4, 1.05), frameon=False)
    plt.tight_layout()

In [None]:
# Draw the timeline of the selected trial and save the current figure.
plot_traj(traj)
plt.savefig("../output/fig_trial_timeline.png", bbox_inches="tight")

### Model fit against data

In [None]:
def plot_all_subj_p_go(ax, exp_data, d_condition, marker, color, marker_offset=0):
    """Plot across-participant mean and SEM of the go-decision rate for one distance.

    Rows with the given ``d_condition`` are first averaged within each
    (subj_id, tta_condition) cell; ``helper.get_mean_sem`` then summarises those
    per-participant means, and the result is drawn as error bars against its index,
    shifted horizontally by ``marker_offset``.
    """
    # numeric_only=True keeps float/int/bool columns only, so a text column in
    # exp_data cannot make the aggregation raise on recent pandas versions.
    between_subj_mean = (exp_data[exp_data.d_condition == d_condition]
                         .groupby(["subj_id", "tta_condition"])
                         .mean(numeric_only=True))
    data_subj_d_measures = helper.get_mean_sem(between_subj_mean.reset_index(), var="is_go_decision", n_cutoff=2)
    ax.errorbar(data_subj_d_measures.index+marker_offset, data_subj_d_measures["mean"], yerr=data_subj_d_measures["sem"],
                ls="", marker=marker, ms=9, color=color)
    
def plot_subj_p_go(ax, exp_data, d_condition, subj_id, marker, color):
    """Plot one participant's go-decision probability with its interval for one distance.

    The participant's rows for ``d_condition`` are passed to ``helper.get_psf_ci``;
    the returned ``p_go`` is drawn as markers and ``ci_l``..``ci_r`` as vertical
    lines at each ``tta_condition``.
    """
    is_selected = (exp_data.subj_id == subj_id) & (exp_data.d_condition == d_condition)
    psf_ci = helper.get_psf_ci(exp_data[is_selected])
    point_style = dict(ls="", marker=marker, ms=9, color=color, zorder=10)
    ax.plot(psf_ci.tta_condition, psf_ci.p_go, **point_style)
    ax.vlines(x=psf_ci.tta_condition, ymin=psf_ci.ci_l, ymax=psf_ci.ci_r, color=color, zorder=10)

def plot_subj_rt(ax, exp_data, d_condition, subj_id, marker, color, marker_offset=0):
    """Plot mean and SEM of response time in go decisions for one distance.

    With ``subj_id == "all"`` the go trials are first averaged within each
    (subj_id, tta_condition) cell and the summary is taken over those means;
    otherwise the selected participant's go trials are summarised directly.
    Nothing is drawn when no trials match.
    """
    is_go_at_distance = (exp_data.d_condition == d_condition) & (exp_data.is_go_decision)
    if subj_id == "all":
        # numeric_only=True keeps float/int/bool columns only, so a text column in
        # exp_data cannot make the aggregation raise on recent pandas versions.
        measures = (exp_data[is_go_at_distance]
                    .groupby(["subj_id", "tta_condition"])
                    .mean(numeric_only=True)
                    .reset_index())
    else:
        measures = exp_data[(exp_data.subj_id == subj_id) & is_go_at_distance]

    if len(measures) > 0:
        measures_mean_sem = helper.get_mean_sem(measures, var="RT", n_cutoff=2)
        ax.errorbar(measures_mean_sem.index+marker_offset, measures_mean_sem["mean"], yerr=measures_mean_sem["sem"],
                    ls="", marker=marker, ms=9, color=color)
    
def plot_compare_model_exp(var, exp_data, model_measures, ylabel):
    """Plot model predictions against data, one panel per entry of ``model_measures.subj_id``.

    Parameters
    ----------
    var : {"is_go_decision", "RT"}
        Measure to compare.
    exp_data : pandas.DataFrame
        Experimental measures (``subj_id`` is compared as a string).
    model_measures : pandas.DataFrame
        Model measures with columns ``subj_id``, ``tta_condition``, ``d_condition``
        and ``var``; only TTA conditions between 4 and 6 are plotted.
    ylabel : str
        Label written along the left edge of the figure.

    Raises
    ------
    ValueError
        If ``var`` is not one of the two supported measures.
    """
    # Reject an unsupported measure before any figure is created.
    if var not in ("is_go_decision", "RT"):
        raise ValueError("var must be 'is_go_decision' or 'RT', got %r" % (var,))

    # because of the bug in matplotlib/pandas, we need to convert subj_id to str to avoid warnings
    exp_data = exp_data.astype({"subj_id": str})
    model_measures = model_measures[(model_measures.tta_condition>=4.0) & (model_measures.tta_condition<=6.0)]

    d_conditions = [90, 120, 150]
    markers = ["o", "s", "^"]
    colors = [plt.cm.viridis(r) for r in np.linspace(0.1,0.7,len(d_conditions))]

    subjects = model_measures.subj_id.unique()

    fig, axes = plt.subplots(4, 5, figsize=(14,12), sharex=True, sharey=True)

    # Panels fill the first four columns row by row; the 17th entry goes top-right.
    axes_to_plot = np.concatenate([axes[:,:4].flatten(), [axes[0,4]]])
    for subj_idx, (subj_id, ax) in enumerate(zip(subjects, axes_to_plot)):
        ax.set_title("All participants" if subj_id=="all" else "P%s" % (subj_idx+1), fontsize=16)
        for d_condition, color, marker in zip(d_conditions, colors, markers):
            model_subj_d_measures = model_measures[(model_measures.subj_id==subj_id)
                                                   & (model_measures.d_condition==d_condition)]

            # Model
            ax.plot(model_subj_d_measures.tta_condition, model_subj_d_measures[var],
                    color=color, label=d_condition)

            # Data
            if var == "RT":
                plot_subj_rt(ax, exp_data, d_condition, subj_id, marker, color)
            elif subj_id == "all":
                # for all subjects, use scipy sem estimates, since we average over continuous measures (p_go)
                plot_all_subj_p_go(ax, exp_data, d_condition, marker, color)
            else:
                # for individual subjects, use binomial proportion sem estimates, since we average over binary measures (p_go)
                plot_subj_p_go(ax, exp_data, d_condition, subj_id, marker, color)

        if var == "RT":
            ax.set_yticks([0.3, 0.6, 0.9, 1.2])

        ax.legend().remove()
        ax.set_xlabel("")
        ax.set_ylabel("")

    sns.despine(offset=5, trim=True)

    # The remaining cells of the fifth column stay empty to leave room for the legend.
    for ax in axes[1:,4]:
        fig.delaxes(ax)

    plt.tight_layout()

    legend_elements = ([Line2D([0], [0], color=color, marker=marker, ms=9, lw=0, label="Data, d=%im" % (d_condition))
                        for d_condition, color, marker in zip(d_conditions, colors, markers)]
                       + [Line2D([0], [0], color="grey", label="Model")])

    fig.legend(handles=legend_elements, loc="center", bbox_to_anchor=(0.9, 0.62), fontsize=16, frameon=False)

    fig.text(0.35, -0.02, "Time-to-arrival (TTA), s", fontsize=18)
    fig.text(-0.02, 0.39, ylabel, fontsize=18, rotation=90)

In [None]:
# Model measures for the full-data fit, read from the model 1 simulation results directory.
model_measures_all_conditions = pd.read_csv(os.path.join(simulation_results_path, "full_data_measures.csv"))

# One grid of panels per measure; each savefig call stores the figure that was just drawn.
plot_compare_model_exp("is_go_decision", exp_measures, model_measures_all_conditions, "Probability of go decision")
plt.savefig("../output/fig_p_go.png", bbox_inches="tight")

plot_compare_model_exp("RT", exp_measures, model_measures_all_conditions, "Response time, s")
plt.savefig("../output/fig_RT.png", bbox_inches="tight")

### RT distributions

In [None]:
def plot_condition_vincentized_dist(ax, condition, condition_data, kind="cdf"):
    """Plot the group-averaged (vincentized) response-time CDF for one condition.

    RT quantiles are computed per participant at 15 probabilities between 0.01 and
    0.99 and then averaged across participants. The averaged quantiles are
    interpolated onto an evenly spaced RT grid, padded on both sides, and drawn as
    star markers.

    ``condition`` and ``kind`` are accepted for backward compatibility but not used;
    only the CDF is drawn.
    """
    q = np.linspace(0.01, 0.99, 15)
    condition_quantiles = condition_data.groupby("subj_id").apply(lambda d: np.quantile(a=d.RT, q=q)).mean()

    rt_range = np.linspace(condition_quantiles.min(), condition_quantiles.max(), len(q))
    step = rt_range[1] - rt_range[0]
    # Extend the grid on both sides so the flat 0 and 1 tails of the CDF are visible.
    rt_grid = np.concatenate([rt_range[:3]-3*step, rt_range, rt_range[-3:]+step*3])
    vincentized_cdf = np.interp(rt_grid, condition_quantiles, q, left=0, right=1)

    ax.plot(rt_grid, vincentized_cdf, label="Data", color="grey", ls="", ms=9, marker="*")
    ax.set_ylim([-0.05, 1.1])
    ax.set_yticks([0.0, 0.5, 1.0])

def decorate_axis(ax, condition):
    """Write the TTA and distance labels on the edges of the panel grid.

    A "TTA=..." label goes above the panel for the (d, TTA) pairs (90, 6), (90, 5)
    and (120, 4); a rotated "d=..." label goes on the right edge of every panel
    with TTA equal to 6.
    """
    d, tta = condition["d"], condition["TTA"]
    centered = dict(horizontalalignment="center", verticalalignment="center")

    gets_column_label = (d == 90 and tta in (6, 5)) or (d == 120 and tta == 4)
    if gets_column_label:
        ax.text(0.5, 1.02, "TTA=%is" % tta, fontsize=16, transform=ax.transAxes, **centered)

    if tta == 6:
        ax.text(1.0, 0.5, "d=%im" % d, fontsize=16, transform=ax.transAxes, rotation=-90, **centered)
    
def plot_vincentized_dist(fig, axes, exp_data, model_rts, model_no):
    """Plot group-averaged RT CDFs and model RT CDFs on a grid of condition panels.

    Parameters
    ----------
    fig, axes
        Figure and array of axes; panels are filled in the order of the conditions
        (distances in ascending order, and TTAs in ascending order within each distance).
    exp_data : pandas.DataFrame
        Experimental measures; only go decisions are used.
    model_rts : pandas.DataFrame or None
        Model RT distributions (columns ``subj_id``, ``d_condition``,
        ``tta_condition``, ``t``, ``rt_corr_distr``); the "all" rows are drawn.
    model_no : int
        Data markers and panel decorations are drawn only when this is 1, so the
        function can be called once per model on the same figure.

    Returns
    -------
    (fig, axes)

    The figure legend and the two axis-label texts are added only if the figure
    does not have them yet, so repeated calls do not stack duplicates.
    """
    plot_data = (model_no == 1)
    conditions = [{"d": d, "TTA": TTA}
                  for d in sorted(exp_data.d_condition.unique())
                  for TTA in sorted(exp_data.tta_condition.unique())]

    for (ax, condition) in zip(axes.flatten(), conditions):
        condition_data = exp_data[(exp_data.is_go_decision)
                                  & (exp_data.d_condition==condition["d"])
                                  & (exp_data.tta_condition==condition["TTA"])]
        # Conditions with fewer than 25 go decisions are left blank.
        if len(condition_data) >= 25:
            # Group-averaged data
            if plot_data:
                plot_condition_vincentized_dist(ax, condition, condition_data)

            # Model
            if model_rts is not None:
                condition_rts = model_rts[(model_rts.subj_id=="all")
                                          & (model_rts.d_condition==condition["d"])
                                          & (model_rts.tta_condition==condition["TTA"])]
                ax.plot(condition_rts.t, condition_rts.rt_corr_distr, color="black", alpha=0.8, lw=2)
        else:
            ax.set_axis_off()

        if plot_data:
            decorate_axis(ax, condition)

            ax.set_xlabel("")
            ax.set_xlim((0, 1.5))
            sns.despine(offset=5, trim=True)

    if not fig.legends:
        legend_elements = ([Line2D([0], [0], color='grey', marker='*', ms=9, lw=0, label='Group-averaged data'),
                            Line2D([0], [0], color='grey', label='Model')])
        fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.2, 0.75), fontsize=16, frameon=True)

    existing_labels = {text.get_text() for text in fig.texts}
    if "Response time, s" not in existing_labels:
        fig.text(0.43, 0.04, "Response time, s", fontsize=16)
    if "Cumulative distribution function" not in existing_labels:
        fig.text(0.04, 0.15, "Cumulative distribution function", fontsize=16, rotation=90)

    return fig, axes

In [None]:
# 3x3 grid of panels with shared axes, filled by plot_vincentized_dist.
fig, axes = plt.subplots(3, 3, figsize=(10,8), sharex=True, sharey=True)

# Model RT distributions, read from the model 1 simulation results directory.
model_rts = pd.read_csv(os.path.join(simulation_results_path, "full_data_rt_cdf.csv"))
plot_vincentized_dist(fig, axes, exp_measures, model_rts, model_no=1)
# fig.suptitle("Model %i" % model_no)
plt.savefig("../output/fig_RT_dist.png", bbox_inches="tight")

### Model cross-validation

In [None]:
# Model measures stored in "cross_validation_8_measures.csv" in the model 1 simulation results directory.
model_measures_cross_validation_8 = pd.read_csv(os.path.join(simulation_results_path, "cross_validation_8_measures.csv"))

In [None]:
def plot_cross_validation_per_subj(exp_data, model_measures, subj_id):
    """Plot cross-validation predictions against data: p(go) on the left, RT on the right.

    Parameters
    ----------
    exp_data : pandas.DataFrame
        Experimental measures.
    model_measures : pandas.DataFrame
        Model predictions with columns ``subj_id``, ``tta_condition``,
        ``d_condition``, ``is_go_decision`` and ``RT``; only TTA conditions between
        4 and 6 are plotted.
    subj_id
        "all" for the across-participant summary, otherwise a participant id. For a
        single participant the p(go) data come from ``plot_subj_p_go`` (that
        participant's own data, drawn without the horizontal offset).
    """
    model_measures = model_measures[(model_measures.tta_condition>=4.0) & (model_measures.tta_condition<=6.0)]

    d_conditions = [90, 120, 150]
    colors = [plt.cm.viridis(r) for r in np.linspace(0.1,0.7,len(d_conditions))]
    markers = ["o", "s", "^"]
    marker_size = 9
    # Model and data markers are shifted in opposite directions so they do not overlap.
    marker_offset = 0.05

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6,3))

    if not subj_id=="all":
        fig.suptitle("P%s" % (subj_id), fontsize=16)

    for d_condition, marker, color in zip(d_conditions, markers, colors):
        model_subj_d_measures = model_measures[(model_measures.subj_id==subj_id)
                                               & (model_measures.d_condition==d_condition)]
        # Model
        ax1.plot(model_subj_d_measures.tta_condition+marker_offset, model_subj_d_measures["is_go_decision"],
                 color=color, label=d_condition, ls="--", lw=1, marker=marker, ms=marker_size, fillstyle="none")
        ax2.plot(model_subj_d_measures.tta_condition+marker_offset, model_subj_d_measures["RT"],
                 color=color, label=d_condition, ls="--", lw=1, marker=marker, ms=marker_size, fillstyle="none")

        # Data: use the across-participant summary only when it was asked for.
        if subj_id == "all":
            plot_all_subj_p_go(ax1, exp_data, d_condition, marker, color, -marker_offset)
        else:
            plot_subj_p_go(ax1, exp_data, d_condition, subj_id, marker, color)
        plot_subj_rt(ax2, exp_data, d_condition, subj_id, marker, color, -marker_offset)

    fig.text(0.35, -0.05, "Time-to-arrival (TTA), s", fontsize=16)

    ax1.set_xticks([4, 5, 6])
    ax2.set_xticks([4, 5, 6])

    ax1.legend().remove()
    ax2.legend().remove()

    ax1.set_ylabel("Probability of go", fontsize=16)
    ax2.set_ylabel("Response time", fontsize=16)

    ax1.set_ylim((0.0, 1.0))
    ax2.set_ylim((0.3, 0.8))

    sns.despine(offset=5, trim=True)
    plt.tight_layout()

    # Two-column legend: each row reads "Model predictions," followed by "data, d=...".
    legend_elements = ([Line2D([0], [0], color=color, marker=marker, ms=marker_size, lw=1, ls="--", fillstyle="none", label="Model predictions,")
                        for d_condition, marker, color in zip(d_conditions, markers, colors)]
                       + [Line2D([0], [0], color=color, marker=marker, ms=marker_size, lw=0, label="data, d=%im" % (d_condition))
                          for d_condition, marker, color in zip(d_conditions, markers, colors)])

    fig.legend(handles=legend_elements, loc="center left", bbox_to_anchor=(1.0, 0.5), fontsize=16, handlelength=1.5, columnspacing=0.2,
               frameon=False, ncol=2)

In [None]:
# Cross-validation figure for the across-participant ("all") summary; saves the current figure.
plot_cross_validation_per_subj(exp_measures, model_measures_cross_validation_8, subj_id="all")
plt.savefig("../output/fig_cross_validation_8_vincent.png", bbox_inches="tight")

# Supplementary figures


### ROC of p_go prediction for individual participants

In [None]:
def get_RT_means_with_cutoff(data, groupby_cols, n_cutoff=2):
    """Return the mean RT per group, keeping only sufficiently large groups.

    Groups are defined by ``groupby_cols``. A group is kept only if it has strictly
    more than ``n_cutoff`` rows. The result is a Series named "RT" indexed by the
    grouping keys.
    """
    grouped = data.groupby(groupby_cols)
    mean_rt = grouped['RT'].mean()
    group_size = grouped.size()
    return mean_rt[group_size > n_cutoff]

In [None]:
# Per-model path templates; fill in the model number with `%`.
best_fit_params_path = '../model_fit_results/model_%i/best_fit_parameters/full_data_parameters_fitted.csv'
fit_results_path = '../model_fit_results/model_%i/simulation_results/full_data_measures.csv'

groupby_cols = ['subj_id', 'tta_condition', 'd_condition']

exp_measures = pd.read_csv('../data/measures.csv')
# Other cells of this notebook read the decision from this same file as
# `is_go_decision`, so fall back to that column when `is_turn_decision` is absent.
# Downstream cells still expect the name `is_turn_decision`, so it is kept.
decision_col = 'is_turn_decision' if 'is_turn_decision' in exp_measures.columns else 'is_go_decision'
if decision_col != 'is_turn_decision':
    exp_measures = exp_measures.rename(columns={decision_col: 'is_turn_decision'})

# Average only the decision column instead of every column of the frame.
p_turn_by_condition = exp_measures.groupby(groupby_cols)['is_turn_decision'].mean()
# Mean RT over the trials with a positive decision, for cells with more than n_cutoff such trials.
rt_by_condition = get_RT_means_with_cutoff(exp_measures[exp_measures.is_turn_decision], groupby_cols, n_cutoff=2)
data_means = pd.DataFrame([p_turn_by_condition, rt_by_condition]).T
print (len(data_means), len(p_turn_by_condition), len(rt_by_condition))

In [None]:
def get_data_vs_model_means(model_no, data_means):
    """Join per-condition data means with the measures predicted by one model.

    Parameters
    ----------
    model_no : int
        Model number substituted into ``fit_results_path``.
    data_means : pandas.DataFrame
        Data means indexed by ``groupby_cols``, with the decision probability and RT.

    Returns
    -------
    pandas.DataFrame
        One row per (subj_id, tta_condition, d_condition) with ``p_turn_data``,
        ``p_turn_model``, ``RT_data``, ``RT_model`` and a ``condition`` label of
        the form "<tta>_<d>". Only individual-participant fits for TTA in
        {4, 5, 6} and d in {90, 120, 150} are included.
    """
    model_measures = pd.read_csv(fit_results_path % model_no)
    is_individual_fit = (model_measures.tta_condition.isin([4, 5, 6])
                         & model_measures.d_condition.isin([90, 120, 150])
                         & ~(model_measures.subj_id=='all'))
    # Work on an explicit copy so the column assignment below does not write into a slice.
    model_measures = model_measures[is_individual_fit].copy()
    model_measures['subj_id'] = pd.to_numeric(model_measures.subj_id)

    # Other cells of this notebook read the decision column of these files as
    # `is_go_decision`; align both sides on `is_turn_decision` so the join suffixes
    # and the rename below apply.
    if 'is_turn_decision' not in model_measures.columns:
        model_measures = model_measures.rename(columns={'is_go_decision': 'is_turn_decision'})
    if 'is_turn_decision' not in data_means.columns:
        data_means = data_means.rename(columns={'is_go_decision': 'is_turn_decision'})
    model_measures = model_measures.set_index(groupby_cols)

    data_vs_model = data_means.join(model_measures, lsuffix='_data', rsuffix='_model').reset_index()
    data_vs_model = data_vs_model.rename(columns={'is_turn_decision_data': 'p_turn_data',
                                                  'is_turn_decision_model': 'p_turn_model'})
    data_vs_model['subj_id'] = data_vs_model['subj_id'].astype(int)
    data_vs_model[['tta_condition', 'd_condition']] = data_vs_model[['tta_condition', 'd_condition']].astype(int)
    data_vs_model['condition'] = data_vs_model[['tta_condition', 'd_condition']].astype(str).agg('_'.join, axis=1).astype(str)

    return data_vs_model

In [None]:
def plot_data_vs_model(var, data_vs_model, model_no):
    """Scatter model predictions against data for one measure and save the figure.

    ``var`` selects the column pair ``<var>_data`` / ``<var>_model``. Points are
    coloured by participant and styled by condition, a grey identity line is added,
    and the title reports 1 - SS_res/SS_tot computed against the identity line.
    The figure is written to "../output/<var>_<model_no>.png".
    """
    data_col = '%s_data' % var
    model_col = '%s_model' % var

    fig = plt.figure(dpi=300)
    ax = fig.add_subplot(111)
    ax.axes.set_aspect('equal')
    g = sns.scatterplot(ax=ax, data=data_vs_model, x=data_col, y=model_col,
                        hue='subj_id', style='condition', legend='full', palette='tab20')
    g.legend(loc='center left', bbox_to_anchor=(1.25, 0.5), ncol=2)

    # Identity line: spans the observed RT range (padded by 0.05) or the unit interval.
    if var == 'RT':
        control_x = np.linspace(data_vs_model.RT_data.min()-0.05, data_vs_model.RT_data.max()+0.05, 2)
    else:
        control_x = np.linspace(0,1,2)
    ax.plot(control_x, control_x, color='grey')
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(ticker.MultipleLocator(0.2))

    observed = data_vs_model[data_col]
    residual_ss = (data_vs_model[model_col]-observed).pow(2).sum()
    total_ss = (observed-observed.mean()).pow(2).sum()
    fit_metric = 1 - residual_ss/total_ss

    fig.suptitle(r'Model %i: not really $R^2=%.2f$' % (model_no, fit_metric))
    fig.savefig('../output/%s_%i.png' % (var, model_no), bbox_inches='tight')

In [None]:
# For each of the three models, build the data-vs-model table once and draw both
# scatter plots; plot_data_vs_model saves each figure under ../output/.
for model_no in [1,2,3]:
    data_vs_model = get_data_vs_model_means(model_no=model_no, data_means=data_means)
    plot_data_vs_model('p_turn', data_vs_model, model_no=model_no)
    plot_data_vs_model('RT', data_vs_model, model_no=model_no)

The best theoretically possible ROC can be obtained from the estimated probabilities. As an aside, the nudge-based manipulation of p_turn in the follow-up work could be linked to pushing the limits of theoretically possible prediction: if we cannot be sure what the human will do, and this might affect safety, we had better sacrifice some current comfort to increase our prediction accuracy.

In [None]:
def get_positive_rates(df, gamma):
    """Return (false-positive rate, true-positive rate) of thresholded model predictions.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a boolean ``is_turn_decision`` column (observed decision) and a
        ``p_turn_model`` column (predicted probability). The frame is not modified.
    gamma : float
        Threshold: a trial is predicted positive when ``p_turn_model > gamma``.

    Returns
    -------
    tuple of float
        ``(fpr, tpr)``. A rate is NaN when its denominator is zero, i.e. when ``df``
        has no negative trials (fpr) or no positive trials (tpr).
    """
    observed = df["is_turn_decision"].astype(bool)
    predicted = df["p_turn_model"] > gamma

    n_negative = int((~observed).sum())
    n_positive = int(observed.sum())

    # Guard the divisions: a participant with only one kind of decision would
    # otherwise raise ZeroDivisionError.
    fpr = float((predicted & ~observed).sum()) / n_negative if n_negative else np.nan
    tpr = float((predicted & observed).sum()) / n_positive if n_positive else np.nan

    return (fpr, tpr)

def get_roc_curve(data_means, exp_measures, subj_id, model_no=1):
    """Compute one participant's ROC curve for the go/turn decision.

    Parameters
    ----------
    data_means : pandas.DataFrame
        Per-condition data means indexed by (subj_id, tta_condition, d_condition).
    exp_measures : pandas.DataFrame
        Trial-level measures with the observed decision.
    subj_id
        Participant whose trials are scored.
    model_no : int or "best_case"
        Model whose predicted probabilities are used. With "best_case" the
        per-condition data means themselves serve as the predictions.

    Returns
    -------
    numpy.ndarray of shape (2, 21)
        Row 0 holds false-positive rates and row 1 true-positive rates, for 21
        thresholds running from 1.001 down to -0.001.
    """
    key_cols = ["subj_id", "tta_condition", "d_condition"]
    # Other cells of this notebook read the decision from exp_measures as
    # `is_go_decision`; accept either name and normalise to `is_turn_decision`.
    decision_col = "is_turn_decision" if "is_turn_decision" in exp_measures.columns else "is_go_decision"
    subj_data_vs_model = (exp_measures.loc[exp_measures.subj_id == subj_id, key_cols + [decision_col]]
                          .rename(columns={decision_col: "is_turn_decision"}))

    if model_no == "best_case":
        model_predictions = data_means.reset_index().rename(columns={"is_turn_decision": "p_turn_model"})
    else:
        data_vs_model = get_data_vs_model_means(model_no=model_no, data_means=data_means)
        model_predictions = data_vs_model.loc[data_vs_model.subj_id == subj_id, key_cols + ["p_turn_model"]]

    subj_data_vs_model = (subj_data_vs_model.merge(model_predictions, on=key_cols)
                          .reset_index(drop=True))

    # The thresholds start above 1 and end below 0, which covers the extreme case
    # with no trial predicted positive and the one with every trial predicted positive.
    return np.array([get_positive_rates(subj_data_vs_model, gamma=gamma) for gamma in np.linspace(1.001, -0.001, 21)]).T

In [None]:
# Best-case ROC for every participant, all drawn on one axis. With
# model_no="best_case" the per-condition data means are used as the predictions.
fig, ax = plt.subplots()
for subj_id in exp_measures.subj_id.unique():
    roc = get_roc_curve(data_means, exp_measures, subj_id, model_no="best_case")
    # roc[0] is the false-positive rate, roc[1] the true-positive rate.
    ax.plot(roc[0], roc[1])
ax.set_aspect("equal")

In [None]:
subjects = exp_measures.subj_id.unique()

fig, axes = plt.subplots(4, 4, figsize=(12,14), sharex=True, sharey=True)

# One panel per participant: the ROC of each model, then the best-case curve.
for subj_idx, (subj_id, ax) in enumerate(zip(subjects, axes.flatten())):
    ax.set_aspect("equal")
    for model_no in (1, 2, 3):
        fpr, tpr = get_roc_curve(data_means, exp_measures, subj_id, model_no=model_no)
        ax.plot(fpr, tpr, label="Model %i" % model_no)

    fpr, tpr = get_roc_curve(data_means, exp_measures, subj_id, model_no="best_case")
    ax.plot(fpr, tpr, label="Best-case", ls="--", color="grey")

    ax.set_title("P%i" % (subj_idx+1))
    # A single legend, attached to the right of the fourth panel.
    if subj_idx == 3:
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

## Pseudo-R^2 of RT for individual participants

## ROC and R^2 for cross-validation on vincentized data

# Comparison with simpler models

### p_go and RT figures per participant for three models

In [None]:
# p(go) and RT comparison figures for each of the three models.
for model_no in [1,2,3]:
    # fit_results_path already ends in "full_data_measures.csv"; joining the file
    # name onto it again produced a path that does not exist.
    model_measures_all_conditions = pd.read_csv(fit_results_path % model_no)

    plot_compare_model_exp("is_go_decision", exp_measures, model_measures_all_conditions, "Probability of go")
    plt.savefig("../output/fig_p_go_model_%i.png" % model_no, bbox_inches="tight")

    plot_compare_model_exp("RT", exp_measures, model_measures_all_conditions, "Response time, s")
    plt.savefig("../output/fig_RT_model_%i.png" % model_no, bbox_inches="tight")

### Full RT distributions

In [None]:
fig, axes = plt.subplots(3, 3, figsize=(10,8), sharex=True, sharey=True)
for model_no in [1, 2, 3]:
    # fit_results_path points at the measures file itself, so take its directory
    # to locate the RT distribution file stored next to it.
    model_results_dir = os.path.dirname(fit_results_path % model_no)
    model_rts = pd.read_csv(os.path.join(model_results_dir, "full_data_rt_cdf.csv"))
    plot_vincentized_dist(fig, axes, exp_measures, model_rts, model_no)
# Save once, after all three models have been drawn on the shared figure.
fig.savefig("../output/fig_RT_dist_all_models.png", bbox_inches="tight")

- ROC and R^2 for individual participants for three models
- ROC and R^2 for cross-validation on vincentized data for three models

### Fitted parameter values
- swarmplot of parameter values per participant
- table with fitted parameters

### Robustness of fit
- swarmplot of repeated fits per two participants
