<a href="https://colab.research.google.com/github/gergogomori/brain_wide_pes/blob/main/pes_behavior.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the ONE-api client package, which provides the `one.api.ONE` class imported below.
!pip install ONE-api


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle
import uuid

from sklearn.preprocessing   import PowerTransformer, StandardScaler
from sklearn.linear_model    import LogisticRegression, LinearRegression
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
from scipy.stats             import linregress, permutation_test, ttest_1samp
from matplotlib.ticker       import MaxNLocator

from one.api import ONE

# Connect to the OpenAlyx server with its public password and list every session
# that has a 'spikes.times.npy' dataset; these session ids (eids) drive all cells below.
one  = ONE(base_url='https://openalyx.internationalbrainlab.org', password='international', silent=True)
eids = one.search(datasets='spikes.times.npy')

print('Number of sessions available:', len(eids))

# NOTE(review): `rng` is not referenced anywhere else in the visible notebook; the
# cross-validation splitters below are unseeded, so results are not reproducible run to run.
rng = np.random.default_rng()

# Unsigned stimulus contrasts. They are used as dictionary keys throughout, so they must
# equal the values found in trials['contrastLeft'] / trials['contrastRight'] exactly.
contrasts = np.array([0.0, 0.0625, 0.125, 0.25, 1.0])


In [None]:
import matplotlib as mpl

# Larger fonts for axis labels, tick labels and legends in every figure produced below.
mpl.rcParams['axes.labelsize']  = 15
mpl.rcParams['xtick.labelsize'] = 14
mpl.rcParams['ytick.labelsize'] = 14
mpl.rcParams['legend.fontsize'] = 14


# Parameters

In [None]:
# Fixed reaction-time (RT) window in ms; RTs outside it are excluded from the RT analyses.
min_RT, max_RT = 125.0, 500.0
# Number of trials discarded from the end of every session.
n_end_disc = 100
# |z| cut-off applied to power-transformed RTs when deriving per-contrast RT limits.
z_lim = 2.0
# Candidate learning rates for the choice-history prior: 0.01, 0.02, ..., 0.99.
alphas = np.linspace(0.01, 0.99, 99)

# Minimum number of trials per condition required for a behavioural comparison.
min_n_trials_behav = 10


# Setting the RT limits and extracting the priors

In [None]:
# For every session: collect per-contrast RTs, fit a PowerTransformer per contrast to
# derive session-specific RT limits, compute choice-history priors for every candidate
# learning rate, and keep the learning rate whose prior best predicts choices.
# Accepted sessions are stored in RT_lims_priors[EID] with keys
# 'RT_lims', 'alpha', 'perf', 'priors' and 'choice_mod'.
RT_lims_priors = {}

# Per contrast: each session's back-transformed z = 0 ("typical") RT, as length-1 arrays.
avg_RT_contr   = {contr : [] for contr in contrasts}
# Per session: RTs inside the fixed [min_RT, max_RT] window, keyed by contrast.
ex_plot_RTs    = {}

# Per session: every valid (positive, finite) RT keyed by contrast, before windowing.
RT_pre_trans   = {}
# Per session and contrast: the power-transformed in-window RTs.
RT_z_all       = {}

for i, EID in enumerate(eids):

    print('Done with', np.around(i / len(eids) * 100, 1), '% of the data.')

    # Any exception (missing data, a contrast value that is not a dict key, a failed
    # assertion, ...) skips the whole session.
    try:
        trials = one.load_object(EID, 'trials')

        # Require at least 10 trials after dropping the last n_end_disc.
        if (len(trials['rewardVolume']) - n_end_disc) >= 10:

            rts        = {contr : []   for contr in contrasts}
            priors     = {alpha : []   for alpha in alphas}
            curr_o_bar = {}

            rts_all    = {contr : []   for contr in contrasts}

            for trialID in range(len(trials['rewardVolume']) - n_end_disc):

                # Exactly one side shows a stimulus, and a finite choice was recorded.
                assert(np.logical_xor(np.isfinite(trials['contrastLeft'][trialID]), np.isfinite(trials['contrastRight'][trialID])))
                assert(np.isfinite(trials['choice'][trialID]))

                currRT    = (trials['firstMovement_times'][trialID] - trials['goCue_times'][trialID]) * 1000 # ms
                currRTVal = (trials['firstMovement_times'][trialID] > trials['goCue_times'][trialID]) and np.isfinite(currRT)
                # Unsigned contrast: NaN on the empty side is ignored by nansum.
                currSimp  = np.nansum([trials['contrastLeft'][trialID], trials['contrastRight'][trialID]])

                if currRTVal:
                    rts_all[currSimp].append(currRT)

                # Keep RTs inside the fixed [min_RT, max_RT] window
                if currRT > min_RT and currRT < max_RT and currRTVal:
                    rts[currSimp].append(currRT)

                # Computing the priors
                # Exponentially weighted estimate of how often the previous choice was -1
                # (target 1.0) versus +1 (target 0.0). curr_o_bar starts at 0 and rises
                # toward 1, so the effective step alpha / curr_o_bar is 1 on the first
                # update and decays to alpha, removing the pull of the 0.5 start value.
                # NOTE(review): curr_o_bar also advances when the previous choice is neither
                # -1 nor +1, although the prior itself is left unchanged on those trials.
                for alpha in alphas:
                    if trialID == 0:
                        curr_o_bar[alpha] = 0.0
                        priors[alpha].append(0.5)
                    else:
                        prevChoice = trials['choice'][trialID - 1]

                        curr_o_bar[alpha] = curr_o_bar[alpha] + alpha * (1 - curr_o_bar[alpha])

                        if prevChoice == -1.0:
                            priors[alpha].append((1 - (alpha / curr_o_bar[alpha])) * priors[alpha][-1] + (alpha / curr_o_bar[alpha]) * 1.0)
                        elif prevChoice == 1.0:
                            priors[alpha].append((1 - (alpha / curr_o_bar[alpha])) * priors[alpha][-1] + (alpha / curr_o_bar[alpha]) * 0.0)
                        else:
                            priors[alpha].append(priors[alpha][-1])

            for alpha in alphas:
                assert(len(priors[alpha]) == (len(trials['rewardVolume']) - n_end_disc))

            ex_plot_RTs[EID]  = rts
            RT_pre_trans[EID] = rts_all

            # Computing the RT limits
            # Per contrast, fit a power transform to the in-window RTs and map z = +/-z_lim
            # back to ms. Limits are clipped to [min_RT, max_RT] and fall back to those
            # bounds when the back-transformed value is non-positive or non-finite.

            curr_lims     = {}
            RT_z_all[EID] = {}

            for contr in contrasts:

                if len(rts[contr]) > 1:

                    pt = PowerTransformer()

                    RT_z_all[EID][contr] = pt.fit_transform(np.array(rts[contr]).reshape(-1, 1))

                    curr_RT_limits = pt.inverse_transform(np.array([-z_lim, z_lim]).reshape(-1, 1)).flatten()

                    if curr_RT_limits[0] > 0 and np.isfinite(curr_RT_limits[0]):
                        curr_min_RT = np.max([curr_RT_limits[0], min_RT])
                    else:
                        curr_min_RT = min_RT

                    if curr_RT_limits[1] > 0 and np.isfinite(curr_RT_limits[1]):
                        curr_max_RT = np.min([curr_RT_limits[1], max_RT])
                    else:
                        curr_max_RT = max_RT

                    curr_lims[contr] = (curr_min_RT, curr_max_RT)


                    # Back-transform z = 0 to get this session's "typical" RT for the contrast
                    # (a length-1 array; appended even if the session is rejected later).
                    currAvgRT = pt.inverse_transform(np.array([0.0]).reshape(-1, 1)).flatten()

                    if currAvgRT > 0.0 and np.isfinite(currAvgRT):
                        avg_RT_contr[contr].append(currAvgRT)

            # Only sessions with RT limits for every contrast continue to the choice model.
            if sorted(list(curr_lims.keys())) == contrasts.tolist():

                # Collecting data for the choice model
                # Features per trial: [signed contrast, prior - 0.5]; target: 1.0 for
                # choice -1 and 0.0 for choice +1. Only trials with a -1/+1 choice and an
                # RT inside the session's contrast-specific limits are used.

                curr_data    = {alpha : [] for alpha in alphas}
                curr_choices = []

                for trialID in range(len(trials['rewardVolume']) - n_end_disc):

                    currRT     = (trials['firstMovement_times'][trialID] - trials['goCue_times'][trialID]) * 1000 # ms
                    currRTVal  = (trials['firstMovement_times'][trialID] > trials['goCue_times'][trialID]) and np.isfinite(currRT)
                    currSimp   = np.nansum([trials['contrastLeft'][trialID], trials['contrastRight'][trialID]])
                    currChoice = trials['choice'][trialID]

                    if (currChoice == -1.0 or currChoice == 1.0) and currRT > curr_lims[currSimp][0] and currRT < curr_lims[currSimp][1] and currRTVal:

                        # Extracting the signed contrasts: 0 for zero contrast, negative for
                        # left stimuli, positive for right stimuli.
                        if trials['contrastLeft'][trialID] == 0.0 or trials['contrastRight'][trialID] == 0.0:
                            currSignContr = 0.0
                        else:
                            if np.isfinite(trials['contrastLeft'][trialID]):
                                currSignContr = -1 * trials['contrastLeft'][trialID]

                            elif np.isfinite(trials['contrastRight'][trialID]):
                                currSignContr = trials['contrastRight'][trialID]

                        # Extracting the priors (centred on 0)
                        for alpha in alphas:
                            curr_data[alpha].append([currSignContr, priors[alpha][trialID] - 0.5])

                        # Extracting the choices
                        if currChoice == -1.0:
                            curr_choices.append(1.0)
                        elif currChoice == 1.0:
                            curr_choices.append(0.0)

                curr_choices = np.array(curr_choices)

                for alpha in alphas:
                    curr_data[alpha] = np.array(curr_data[alpha])
                    assert(len(curr_choices) == len(curr_data[alpha]))

                # Finding the right learning rate (alpha)
                # Needs at least 10 trials of each choice. Each alpha is scored by the mean
                # ROC-AUC of an intercept-free, unpenalised logistic regression under 5x
                # repeated stratified 10-fold CV; each alpha's model is also fit on all
                # trials and the best-scoring one is stored.
                # NOTE(review): the CV splitter is unseeded, so the chosen alpha may vary between runs.
                if (np.sum(curr_choices == 0.0) >= 10 and np.sum(curr_choices == 1.0) >= 10):

                    alpha_results = {}
                    alpha_models  = {}

                    for alpha in alphas:

                        curr_model  = LogisticRegression(max_iter=1000, penalty=None, fit_intercept=False)
                        curr_cv     = RepeatedStratifiedKFold(n_splits=10, n_repeats=5)
                        curr_scores = cross_val_score(curr_model, X=curr_data[alpha], y=curr_choices, scoring='roc_auc', cv=curr_cv, n_jobs=-1)

                        alpha_results[alpha] = np.mean(curr_scores)
                        alpha_models[alpha]  = curr_model.fit(curr_data[alpha], curr_choices)

                    best_alpha = max(alpha_results, key=alpha_results.get)

                    RT_lims_priors[EID] = {}
                    RT_lims_priors[EID]['RT_lims']    = curr_lims
                    RT_lims_priors[EID]['alpha']      = best_alpha
                    RT_lims_priors[EID]['perf']       = alpha_results[best_alpha]
                    RT_lims_priors[EID]['priors']     = priors[best_alpha]
                    RT_lims_priors[EID]['choice_mod'] = alpha_models[best_alpha]

                    assert(len(priors[best_alpha]) == (len(trials['rewardVolume']) - n_end_disc))

    except Exception as e:

        # Best-effort: report and move on to the next session.
        print(f'Skipping {EID} due to: {e}')
        continue


In [None]:
# Illustrate how the RT limits are derived for one example session and contrast:
#   top:          all valid RTs with the fixed [min_RT, max_RT] window marked;
#   bottom-left:  the power-transformed in-window RTs with the +/-z_lim cut-offs;
#   bottom-right: the in-window RTs with the back-transformed z = 0 ("typical") RT marked.
# NOTE(review): `ex_sess` is a uuid.UUID, which assumes `one.search` returned UUID objects
# so the dictionary lookups below match — confirm for the installed ONE version.
ex_sess  = uuid.UUID('a7763417-e0d6-4f2a-aa55-e382fd9b5fb8')
ex_contr = 0.125

info_ex = one.get_details(ex_sess)

print('EID:', ex_sess)
print('Contrast:', ex_contr)
print('Lab:', info_ex['lab'])
print('Animal:', info_ex['subject'])
print('Starting time:', info_ex['start_time'])


fig, axs = plt.subplots(2, 2, figsize=(12, 10))

# Replace the top row with a single full-width axis.
fig.delaxes(axs[0, 0])
fig.delaxes(axs[0, 1])


ax_top = fig.add_subplot(2, 1, 1)

ax_top.hist(RT_pre_trans[ex_sess][ex_contr], bins='auto', color='#5D7FA3', edgecolor='black', linewidth=1.0)
ax_top.yaxis.set_major_locator(MaxNLocator(integer=True))

y_max_top = ax_top.get_ylim()[1]

ax_top.vlines(min_RT, 0, y_max_top, linestyle='--', color='#FF80AB')
ax_top.vlines(max_RT, 0, y_max_top, linestyle='--', color='#FF80AB')
ax_top.set_xlabel('Reaction time (ms)')
ax_top.set_ylabel('Number of trials')

axs[1, 0].hist(RT_z_all[ex_sess][ex_contr], bins='auto', color='#5D7FA3', edgecolor='black', linewidth=1.0)
axs[1, 0].yaxis.set_major_locator(MaxNLocator(integer=True))

curr_y_max_bottom = axs[1, 0].get_ylim()[1]

axs[1, 0].vlines(-z_lim, 0, curr_y_max_bottom, linestyle='--', color='#FF80AB')
axs[1, 0].vlines(z_lim,  0, curr_y_max_bottom, linestyle='--', color='#FF80AB')
axs[1, 0].set_xlabel('Transformed values of the reaction time')
axs[1, 0].set_ylabel('Number of trials')


# Refit the transform on the in-window RTs to recover the typical (z = 0) RT in ms.
pt = PowerTransformer()
curr_rts = np.array(RT_pre_trans[ex_sess][ex_contr])
pt.fit(curr_rts[np.logical_and(curr_rts > min_RT, curr_rts < max_RT)].reshape(-1, 1))
typ_RT_ex = pt.inverse_transform(np.array([0.0]).reshape(-1, 1))


axs[1, 1].hist(ex_plot_RTs[ex_sess][ex_contr], bins='auto', color='#78B8E6', edgecolor='black', linewidth=1.0)
axs[1, 1].yaxis.set_major_locator(MaxNLocator(integer=True))
axs[1, 1].scatter([typ_RT_ex[0][0]], [axs[1, 1].get_ylim()[1] + 2], marker='d', s=70, edgecolor='black', linewidth=0.7, color='#FF80AB')
axs[1, 1].set_xlabel('Reaction time (ms)')
axs[1, 1].set_ylabel('Number of trials')

# An empty file name cannot be opened for writing; use the same placeholder as the
# other figure cells so the path is set consistently before running.
plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


# Reaction times after errors (post-error slowing) and on incoherent-stimulus-side trials

In [None]:
# For each accepted session, split correct-trial RTs at each non-zero contrast by the
# previous trial's outcome (post-error 'perr' vs post-correct 'pcorr') and by whether the
# stimulus appeared on the side the prior favoured ('coh') or not ('incoh'). Then fit a
# linear regression of RT on [normalised trial index, contrast, surprise, previous error].
pes_RT_all   = {}
incoh_RT_all = {}
curr_eids    = sorted(list(RT_lims_priors.keys()))
# Regression weights, one entry per fitted session. This dict holds only these four
# coefficient lists; per-session RT splits live in pes_RT_all / incoh_RT_all.
rt_weights   = {'trial_id' : [],
                'contrast' : [],
                'surprise' : [],
                'prev_out' : []}

for ind_eid, EID in enumerate(curr_eids):

    pes_RT_all[EID]   = {contr : {'perr'  : [], 'pcorr' : []} for contr in [0.0625, 0.125, 0.25, 1.0]}
    incoh_RT_all[EID] = {contr : {'incoh' : [], 'coh'   : []} for contr in [0.0625, 0.125, 0.25, 1.0]}

    curr_data   = []
    curr_y_vals = []

    # Outcome of the previous trial; NaN until the first trial of this session is processed,
    # so no value leaks in from the previously processed session.
    prevRes = np.nan

    trials = one.load_object(EID, 'trials')

    print('Done with', np.around(ind_eid / len(curr_eids) * 100, 1), '% of the data.')

    for trialID in range(len(trials['rewardVolume']) - n_end_disc):

        assert(np.logical_xor(np.isfinite(trials['contrastLeft'][trialID]), np.isfinite(trials['contrastRight'][trialID])))

        currRT    = (trials['firstMovement_times'][trialID] - trials['goCue_times'][trialID]) * 1000 # ms
        currRTVal = (trials['firstMovement_times'][trialID] > trials['goCue_times'][trialID]) and np.isfinite(currRT)
        currSimp  = np.nansum([trials['contrastLeft'][trialID], trials['contrastRight'][trialID]])
        currRes   = trials['feedbackType'][trialID]
        currPrior = RT_lims_priors[EID]['priors'][trialID]

        # Correct trials with a non-zero contrast, an RT inside this session's limits for
        # that contrast, and a previous trial whose feedback was +1 or -1.
        if (currSimp > 0.0 and
            currRes == 1.0 and
            currRT > RT_lims_priors[EID]['RT_lims'][currSimp][0] and
            currRT < RT_lims_priors[EID]['RT_lims'][currSimp][1] and
            currRTVal and
            trialID > 0 and
            ((prevRes == 1.0) or (prevRes == -1.0))):

            # Previous-outcome regressor: 1.0 after an error, 0.0 after a correct trial.
            if prevRes == -1.0:
                pes_RT_all[EID][currSimp]['perr'].append(currRT)
                trans_prevRes = 1.0
            else:
                pes_RT_all[EID][currSimp]['pcorr'].append(currRT)
                trans_prevRes = 0.0

            # A prior below 0.5 is treated as expecting the stimulus on the left.
            expLeft = currPrior < 0.5

            # Actual stimulus side, coded 0.0 = left, 1.0 = right.
            sideLeft = bool(np.isfinite(trials['contrastLeft'][trialID]))
            actSide  = 0.0 if sideLeft else 1.0

            if expLeft == sideLeft:
                incoh_RT_all[EID][currSimp]['coh'].append(currRT)
            else:
                incoh_RT_all[EID][currSimp]['incoh'].append(currRT)

            # Surprise = distance between the prior and the actual stimulus side.
            curr_data.append([trialID, currSimp, np.abs(currPrior - actSide), trans_prevRes])
            curr_y_vals.append(currRT)

        prevRes = currRes

    if len(curr_y_vals) == 0:
        # Nothing to regress on; skip instead of failing on indexing an empty array.
        print(f'Skipping {EID}: no trials qualify for the RT regression.')
        continue

    curr_data = np.array(curr_data)

    # Express the trial index as a fraction of the analysed session length.
    curr_data[:, 0] = curr_data[:, 0] / (len(trials['rewardVolume']) - n_end_disc)

    assert(len(curr_data) == len(curr_y_vals))

    curr_model = LinearRegression(fit_intercept=True)
    curr_res   = curr_model.fit(curr_data, curr_y_vals)

    rt_weights['trial_id'].append(curr_res.coef_[0])
    rt_weights['contrast'].append(curr_res.coef_[1])
    rt_weights['surprise'].append(curr_res.coef_[2])
    rt_weights['prev_out'].append(curr_res.coef_[3])


In [None]:
# Across-session distribution of the absolute RT-regression weight of each regressor,
# one panel per regressor on a shared 0-120 scale.
_, ax = plt.subplots(2, 2, figsize=(12, 12))

weight_panels = [
    ((0, 0), 'trial_id', 'Absolute value of \n"trial ID" coefficients'),
    ((0, 1), 'contrast', 'Absolute value of \n"contrast" coefficients'),
    ((1, 0), 'surprise', 'Absolute value of \n"surprise" coefficients'),
    ((1, 1), 'prev_out', 'Absolute value of \n"previous outcome" coefficients'),
]

for panel_pos, weight_key, x_label in weight_panels:
    panel = ax[panel_pos]
    sns.histplot(np.abs(rt_weights[weight_key]), ax=panel, binwidth=6.0, binrange=(0.0, 120.0), color='#5D7FA3')
    panel.yaxis.set_major_locator(MaxNLocator(integer=True))
    panel.set_xlabel(x_label)
    panel.set_ylabel('Number of sessions')
    panel.set_xlim(-5.0, 120.0)

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Per contrast, collect across-session differences of median RT:
#   PES:   median(post-error)      - median(post-correct)
#   INCOH: median(incoherent side) - median(coherent side)
# each only when both groups have at least min_n_trials_behav trials.
delta_medians_pes   = {contr : [] for contr in [0.0625, 0.125, 0.25, 1.0]}
delta_medians_incoh = {contr : [] for contr in [0.0625, 0.125, 0.25, 1.0]}

# Candidate example sessions for the plots below.
# top_sess_pes / top_sess_incoh: EID -> mean delta, for sessions where the delta is
# positive at all four non-zero contrasts.
top_sess_pes   = {}
top_sess_incoh = {}
# top_sess_gen: EID -> mean absolute step between consecutive contrasts' median RTs,
# for sessions whose median RT strictly decreases as contrast increases.
top_sess_gen   = {}

for EID in curr_eids:

    delta_med_pes   = []
    delta_med_incoh = []
    med_rt_gen      = []

    for contr in [0.0625, 0.125, 0.25, 1.0]:

        if (len(pes_RT_all[EID][contr]['perr']) >= min_n_trials_behav) and (len(pes_RT_all[EID][contr]['pcorr']) >= min_n_trials_behav):

            curr_delta_med_pes = np.median(pes_RT_all[EID][contr]['perr']) - np.median(pes_RT_all[EID][contr]['pcorr'])

            delta_medians_pes[contr].append(curr_delta_med_pes)

            delta_med_pes.append(curr_delta_med_pes)

        if (len(incoh_RT_all[EID][contr]['incoh']) >= min_n_trials_behav) and (len(incoh_RT_all[EID][contr]['coh']) >= min_n_trials_behav):

            curr_delta_med_incoh = np.median(incoh_RT_all[EID][contr]['incoh']) - np.median(incoh_RT_all[EID][contr]['coh'])

            delta_medians_incoh[contr].append(curr_delta_med_incoh)

            delta_med_incoh.append(curr_delta_med_incoh)

        if len(ex_plot_RTs[EID][contr]) >= min_n_trials_behav:

            med_rt_gen.append(np.median(ex_plot_RTs[EID][contr]))

    # A length of 4 means every non-zero contrast contributed a value.
    if (len(delta_med_pes) == 4) and (np.all(np.array(delta_med_pes) > 0.0)):

        top_sess_pes[EID] = np.mean(delta_med_pes)

    if (len(delta_med_incoh) == 4) and (np.all(np.array(delta_med_incoh) > 0.0)):

        top_sess_incoh[EID] = np.mean(delta_med_incoh)

    if (len(med_rt_gen) == 4) and (np.all(np.diff(med_rt_gen) < 0.0)):

        top_sess_gen[EID] = np.mean(np.abs(np.diff(med_rt_gen)))


In [None]:
# Choosing which sessions to plot
# Each pick is a fixed rank counted from the top of the ascending score order.
# NOTE(review): the ranks (-100, -2, -1) are hard-coded and raise IndexError when
# fewer sessions qualify.
chosen_sess_gen   = sorted(top_sess_gen,   key=top_sess_gen.get)[-100]
chosen_sess_pes   = sorted(top_sess_pes,   key=top_sess_pes.get)[-2]
chosen_sess_incoh = sorted(top_sess_incoh, key=top_sess_incoh.get)[-1]


def _print_session_details(header, sess):
    """Print `header`, then the EID, lab, subject and start time of `sess`; return its ONE details."""
    print(header)
    details = one.get_details(sess)
    print('EID:', sess)
    print('Lab:', details['lab'])
    print('Animal:', details['subject'])
    print('Starting time:', details['start_time'])
    return details


info_ex_gen = _print_session_details('Details of RT variability session:', chosen_sess_gen)
print()
info_ex_pes = _print_session_details('Details of PES session:', chosen_sess_pes)
print()
info_ex_incoh = _print_session_details('Details of INCOH session:', chosen_sess_incoh)


In [None]:
# Left: per-contrast RT distributions of the example session, with a least-squares fit of
# RT on log2(contrast). Right: the same for every session's typical (z = 0) RT.
# Zero contrast is excluded because log2(0) is undefined. The r / p box is drawn only
# when p < 0.001.
rt_ex_sess = pd.DataFrame([(contr, curr_RT)           for contr, curr_rts in ex_plot_RTs[chosen_sess_gen].items() if contr != 0.0 for curr_RT in curr_rts], columns=['Contrast', 'RT'])
rt_typ_all = pd.DataFrame([(contr, float(typ_RT[0]))  for contr, typ_rts  in avg_RT_contr.items()                 if contr != 0.0 for typ_RT  in typ_rts],  columns=['Contrast', 'typ_RT'])

fig, axs = plt.subplots(1, 2, figsize=(14, 6))

sns.violinplot(x='Contrast', y='RT', data=rt_ex_sess, orient='v', ax=axs[0], color='#A0D8F1', zorder=2, alpha=0.8)

slope_ex, intercept_ex, r_ex, p_val_ex, _ = linregress(np.log2(rt_ex_sess['Contrast']), rt_ex_sess['RT'])

# Violin categories sit at x = 0..3 while log2(contrast) is -4, -3, -2, 0. The fitted line is
# therefore drawn in two pieces: x = 2.5 maps to log2 = -1.5 in the first piece and to -0.5
# in the second, which leaves a step at x = 2.5 and compresses the 0.25 -> 1.0 gap.
axs[0].plot(np.array([0, 1, 2, 2.5]), slope_ex * np.log2([0.0625, 0.125, 0.25, np.power(2, -1.5)]) + intercept_ex, color='k', linestyle='--', linewidth=0.8, zorder=1)
axs[0].plot(np.array([2.5, 3.0]),     slope_ex * np.log2([np.power(2, -0.5), 1.0])                 + intercept_ex, color='k', linestyle='--', linewidth=0.8, zorder=1)

if p_val_ex < 0.001:
    textstr_ex = f'$r = {np.round(r_ex, 3)}$\n$p < 0.001$'
    axs[0].text(0.95, 0.95, textstr_ex,
                    transform=axs[0].transAxes,
                    fontsize=14,
                    verticalalignment='top',
                    horizontalalignment='right',
                    multialignment='left',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

axs[0].set_xlabel('Stimulus contrast')
axs[0].set_ylabel('Reaction time (ms)')
axs[0].grid(color='silver', axis='y', linestyle='--', zorder=0)


sns.violinplot(x='Contrast', y='typ_RT', data=rt_typ_all, orient='v', ax=axs[1], color='#FFB84D', zorder=2, alpha=0.8)

slope_gen, intercept_gen, r_gen, p_val_gen, _ = linregress(np.log2(rt_typ_all['Contrast']), rt_typ_all['typ_RT'])

# Same two-piece fitted line as in the left panel.
axs[1].plot(np.array([0, 1, 2, 2.5]), slope_gen * np.log2([0.0625, 0.125, 0.25, np.power(2, -1.5)]) + intercept_gen, color='k', linestyle='--', linewidth=0.8, zorder=1)
axs[1].plot(np.array([2.5, 3.0]),     slope_gen * np.log2([np.power(2, -0.5), 1.0])                 + intercept_gen, color='k', linestyle='--', linewidth=0.8, zorder=1)


if p_val_gen < 0.001:
    textstr_gen = f'$r = {np.round(r_gen, 3)}$\n$p  < 0.001$'
    axs[1].text(0.95, 0.95, textstr_gen,
                    transform=axs[1].transAxes,
                    fontsize=14,
                    verticalalignment='top',
                    horizontalalignment='right',
                    multialignment='left',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

axs[1].set_xlabel('Stimulus contrast')
axs[1].set_ylabel('Typical reaction time (ms)')
axs[1].grid(color='silver', axis='y', linestyle='--', zorder=0)

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', bbox_inches='tight', dpi=300)
plt.show()


In [None]:
def perm_stat(x, y, axis):
    """Return the difference of means, mean(x) - mean(y), computed along `axis`.

    Written to accept an `axis` argument so it can serve as a vectorized statistic
    for `scipy.stats.permutation_test`.
    """
    mean_x = np.mean(x, axis=axis)
    mean_y = np.mean(y, axis=axis)
    return mean_x - mean_y


In [None]:
# One-sided permutation tests on the across-session typical RTs: for each (lower, higher)
# contrast pair, test whether the mean typical RT at the lower contrast exceeds that at the
# higher contrast. Each entry of avg_RT_contr is a length-1 array, so samples are flattened
# to 1-D to obtain scalar statistics and p-values. `batch` bounds how many permutations
# are evaluated per call, keeping memory bounded with one million resamples.
contrast_pairs = [(0.0625, 0.125), (0.0625, 0.25), (0.0625, 1.0),
                  (0.125, 0.25), (0.125, 1.0), (0.25, 1.0)]

for low_contr, high_contr in contrast_pairs:
    low_rts  = np.asarray(avg_RT_contr[low_contr],  dtype=float).ravel()
    high_rts = np.asarray(avg_RT_contr[high_contr], dtype=float).ravel()
    perm_res = permutation_test((low_rts, high_rts), perm_stat, vectorized=True, alternative='greater',
                                n_resamples=1_000_000, batch=10_000)
    print(f'{low_contr} larger than {high_contr}', perm_res.pvalue)


In [None]:
# Left: split violins of the example session's RTs per contrast, after a correct vs an
# error trial. Right: across-session distribution of median(post-error) - median(post-correct).
pes_ex_sess = pd.DataFrame([
        (contr, curr_case == 'perr', curr_RT)
        for contr, inner_dict   in pes_RT_all[chosen_sess_pes].items()
        for curr_case, curr_rts in inner_dict.items()
        for curr_RT             in curr_rts], columns=['Contrast', 'Previous outcome', 'RT'])

pes_ex_sess['Previous outcome'] = pes_ex_sess['Previous outcome'].map({True: 'Error', False: 'Correct'})

pes_overview_pd = pd.DataFrame([(contr, delta_med) for contr, delta_meds in delta_medians_pes.items() for delta_med in delta_meds], columns=['Contrast', 'Delta median'])

fig, axs = plt.subplots(1, 2, figsize=(14, 6))

sns.violinplot(x='Contrast', y='RT', hue='Previous outcome', data=pes_ex_sess, split=True, hue_order=['Correct', 'Error'], ax=axs[0], zorder=2, inner='quart', gap=0.1, palette={'Correct' : '#90BE6D', 'Error' : '#FF6F61'})
axs[0].set_xlabel('Stimulus contrast')
axs[0].set_ylabel('Reaction time (ms)')
axs[0].legend(title='Previous outcome', title_fontsize=14)
axs[0].set_ylim(50, 550)
axs[0].grid(color='silver', axis='y', linestyle='--', zorder=0)


sns.violinplot(x='Contrast', y='Delta median', data=pes_overview_pd, orient='v', ax=axs[1], color='#FFB84D', zorder=2)

axs[1].set_ylim(-75, 75)
axs[1].grid(color='silver', axis='y', linestyle='--', zorder=0)

# Restyle one horizontal gridline as a reference line.
# NOTE(review): y-limits are (-75, 75), so a gridline at y = 100 is outside the visible
# range and this highlight presumably never applies; confirm the intended value (0.0?).
# It also relies on gridline y-data already holding tick positions before the figure is drawn.
for line in axs[1].get_ygridlines():

    if line.get_ydata()[0] == 100.0 and len(line.get_ydata()) == 1:
        line.set_color('tomato')
        line.set_linewidth(2.0)
        line.set_linestyle('--')

axs[1].set_xlabel('Stimulus contrast')
axs[1].set_ylabel('Change in reaction time (ms)')

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Per contrast: mean and (population) std of the post-error slowing deltas, plus the
# p-value of a one-sided one-sample t-test of the deltas being greater than 0 ms.
for contr, deltas in delta_medians_pes.items():
    print(f'Mean, std, t-test for {contr}:', np.mean(deltas), np.std(deltas), ttest_1samp(deltas, popmean=0.0, alternative='greater').pvalue)


In [None]:
# Left: split violins of the example session's RTs per contrast, for stimuli on the side
# the prior favoured ('No') vs the other side ('Yes'). Right: across-session distribution
# of median(incoherent) - median(coherent).
incoh_ex_sess = pd.DataFrame([
        (contr, curr_case == 'incoh', curr_RT)
        for contr, inner_dict   in incoh_RT_all[chosen_sess_incoh].items()
        for curr_case, curr_rts in inner_dict.items()
        for curr_RT             in curr_rts], columns=['Contrast', 'Incoherent side?', 'RT'])

incoh_ex_sess['Incoherent side?'] = incoh_ex_sess['Incoherent side?'].map({True: 'Yes', False: 'No'})

incoh_overview_pd = pd.DataFrame([(contr, delta_med) for contr, delta_meds in delta_medians_incoh.items() for delta_med in delta_meds], columns=['Contrast', 'Delta median'])

fig, axs = plt.subplots(1, 2, figsize=(14, 6))

sns.violinplot(x='Contrast', y='RT', hue='Incoherent side?', data=incoh_ex_sess, split=True, hue_order=['No', 'Yes'], ax=axs[0], zorder=2, inner='quart', gap=0.1, palette={'Yes' : '#A1C9F4', 'No' : '#C6A1F4'})
axs[0].set_xlabel('Stimulus contrast')
axs[0].set_ylabel('Reaction time (ms)')
axs[0].legend(title='Incoherent?', title_fontsize=14)
axs[0].set_ylim(50, 550)
axs[0].grid(color='silver', axis='y', linestyle='--', zorder=0)


sns.violinplot(x='Contrast', y='Delta median', data=incoh_overview_pd, orient='v', ax=axs[1], color='#D2B48C', zorder=2)

axs[1].set_ylim(-75, 75)
axs[1].grid(color='silver', axis='y', linestyle='--', zorder=0)

# Restyle one horizontal gridline as a reference line.
# NOTE(review): this targets y = -50 while the PES figure targets y = 100; confirm which
# reference value is intended (a zero-change line would be 0.0). It also relies on
# gridline y-data already holding tick positions before the figure is drawn.
for line in axs[1].get_ygridlines():

    if line.get_ydata()[0] == -50.0 and len(line.get_ydata()) == 1:
        line.set_color('tomato')
        line.set_linewidth(2.0)
        line.set_linestyle('--')

axs[1].set_xlabel('Stimulus contrast')
axs[1].set_ylabel('Change in reaction time (ms)')

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


# Extracting trial information

In [None]:
# For each session, gather correct, non-zero-contrast trials with an RT inside the session's
# limits, stored per contrast and trial index as [RT, side, prior, previous outcome]. A
# contrast is kept only if every side x previous-outcome cell has at least
# min_n_trials_behav trials. The kept trials are pickled.
trials_all = {}
# (EID, contrast) -> {'perr': [...], 'pcorr': [...]}: RTs split by previous outcome.
# NOTE: this rebinds the name `rts_all`, which an earlier cell used as per-session scratch.
rts_all    = {}

for EID in curr_eids:

    trials = one.load_object(EID, 'trials')

    curr_trials_all = {}

    # Trial counts per contrast by stimulus side (left / right) and previous outcome
    # (pc = post-correct, pe = post-error).
    n_left_pc  = {contr : 0 for contr in [0.0625, 0.125, 0.25, 1.0]}
    n_right_pc = {contr : 0 for contr in [0.0625, 0.125, 0.25, 1.0]}
    n_left_pe  = {contr : 0 for contr in [0.0625, 0.125, 0.25, 1.0]}
    n_right_pe = {contr : 0 for contr in [0.0625, 0.125, 0.25, 1.0]}

    # Outcome of the previous trial; NaN until the first trial of this session is processed.
    prevRes = np.nan

    for trialID in range(len(trials['rewardVolume']) - n_end_disc):

        assert(np.logical_xor(np.isfinite(trials['contrastLeft'][trialID]), np.isfinite(trials['contrastRight'][trialID])))

        currRT    = (trials['firstMovement_times'][trialID] - trials['goCue_times'][trialID]) * 1000 # ms
        currRTVal = (trials['firstMovement_times'][trialID] > trials['goCue_times'][trialID]) and np.isfinite(currRT)
        currSimp  = np.nansum([trials['contrastLeft'][trialID], trials['contrastRight'][trialID]])
        currRes   = trials['feedbackType'][trialID]
        currPrior = RT_lims_priors[EID]['priors'][trialID]

        # The previous feedback must be exactly +1 or -1; without this check any other value
        # would fall into the post-correct branch below.
        if (trialID > 0 and
            currSimp > 0.0 and
            currRes == 1.0 and
            currRT > RT_lims_priors[EID]['RT_lims'][currSimp][0] and
            currRT < RT_lims_priors[EID]['RT_lims'][currSimp][1] and
            currRTVal and
            (prevRes == 1.0 or prevRes == -1.0)):

            if (EID, currSimp) not in rts_all:
                rts_all[(EID, currSimp)] = {'perr' : [], 'pcorr' : []}

            # Only feedbackType == 1 trials reach here, so the side is taken from the stimulus:
            # coded -1 = left stimulus, +1 = right stimulus.
            currChoice = -1 if np.isfinite(trials['contrastLeft'][trialID]) else 1

            # Previous outcome coded +1 = error (feedback -1), -1 = correct (feedback +1).
            if prevRes == -1.0:
                prevOutcome = 1
                rts_all[(EID, currSimp)]['perr'].append(currRT)
            else:
                prevOutcome = -1
                rts_all[(EID, currSimp)]['pcorr'].append(currRT)

            if currChoice == -1 and prevOutcome == -1:
                n_left_pc[currSimp] += 1
            elif currChoice == 1 and prevOutcome == -1:
                n_right_pc[currSimp] += 1
            elif currChoice == -1 and prevOutcome == 1:
                n_left_pe[currSimp] += 1
            else:
                n_right_pe[currSimp] += 1

            curr_trials_all.setdefault(currSimp, {})[trialID] = [currRT, currChoice, currPrior, prevOutcome]

        prevRes = currRes

    for contr in sorted(curr_trials_all):

        if (n_left_pc[contr]  >= min_n_trials_behav and
            n_right_pc[contr] >= min_n_trials_behav and
            n_left_pe[contr]  >= min_n_trials_behav and
            n_right_pe[contr] >= min_n_trials_behav):

            trials_all.setdefault(EID, {})[contr] = curr_trials_all[contr]

# Placeholder path: set it to a real file location before running.
with open('*** Update the folder path here to match your directory structure ***', 'wb') as f:
    pickle.dump(trials_all, f, pickle.HIGHEST_PROTOCOL)


In [None]:
# Percentile of the post-correct RT distribution at or above which a post-error
# trial counts as "slow".
pes_rt_perc = 55


def _count_slow_post_error(rt_split, percentile):
    """Return how many post-error RTs in `rt_split` are >= the `percentile`-th percentile of its post-correct RTs."""
    threshold = np.percentile(rt_split['pcorr'], percentile)
    return sum(1 for rt in rt_split['perr'] if rt >= threshold)


# (EID, contrast) -> number of slow post-error trials, in sorted key order; pairs with
# fewer than two post-correct RTs are left out.
n_pes_sess = {
    sess_contr: _count_slow_post_error(rts_all[sess_contr], pes_rt_perc)
    for sess_contr in sorted(rts_all)
    if len(rts_all[sess_contr]['pcorr']) > 1
}


In [None]:
# Pick the (session, contrast) pair ranked 71st from the top by slow-post-error count.
# Left panel: its post-correct RT histogram, with post-error RTs marked as slow or
# non-slow. Right panel: the distribution of slow-post-error counts over all pairs.
chosen_pair_npes = sorted(n_pes_sess, key=n_pes_sess.get)[-71]
chosen_eid, chosen_contr = chosen_pair_npes

info_ex_pestrials = one.get_details(chosen_eid)

print('EID:', chosen_eid)
print('Contrast:', chosen_contr)
print('Lab:', info_ex_pestrials['lab'])
print('Animal:', info_ex_pestrials['subject'])
print('Starting time:', info_ex_pestrials['start_time'])

chosen_rts = rts_all[chosen_pair_npes]

_, ax = plt.subplots(1, 2, figsize=(18, 6))
left_ax, right_ax = ax

sns.histplot(chosen_rts['pcorr'], ax=left_ax, bins='doane', color='#90BE6D', label='Post-correct trials')
left_ax.yaxis.set_major_locator(MaxNLocator(integer=True))

# Markers are drawn at the current top of the histogram's y-range.
curr_rt_hist_lim = left_ax.get_ylim()[1]

# The slow-trial threshold does not depend on the individual post-error RT, so compute it once.
pes_threshold    = np.percentile(chosen_rts['pcorr'], pes_rt_perc)
curr_pes_trials  = [rt for rt in chosen_rts['perr'] if rt >= pes_threshold]
curr_npes_trials = [rt for rt in chosen_rts['perr'] if not rt >= pes_threshold]

left_ax.scatter(curr_pes_trials,  np.full(len(curr_pes_trials),  curr_rt_hist_lim), marker='v', s=100, edgecolors='black', color='#FF6F61', label='Slow post-error trials')
left_ax.scatter(curr_npes_trials, np.full(len(curr_npes_trials), curr_rt_hist_lim), marker='v', s=100, edgecolors='black', color='#FFB84D', label='Non-slow post-error trials')

left_ax.legend()
left_ax.set_xlabel('Reaction time (ms)')
left_ax.set_ylabel('Number of trials')


sns.histplot(list(n_pes_sess.values()), ax=right_ax, bins='rice', color='#5D7FA3')
right_ax.set_xlabel('Number of slow post-error trials')
right_ax.set_ylabel('Number of session and contrast pairs')


plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()
