<a href="https://colab.research.google.com/github/gergogomori/brain_wide_pes/blob/main/pes_shuffling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the ONE-api package (needed when running on Colab).
# NOTE(review): no module from ONE-api is imported in the visible cells — confirm it is still required,
# and consider pinning a version (e.g. `%pip install ONE-api==<version>`) for reproducibility.
!pip install ONE-api


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import uuid

import pandas  as pd
import seaborn as sns

from scipy.stats import permutation_test, binomtest

from scipy.ndimage import gaussian_filter1d


In [None]:
import matplotlib as mpl

# Enlarge axis-label, tick-label and legend fonts for every figure drawn in this notebook.
for _rc_key, _rc_size in [('axes.labelsize',  15),
                          ('xtick.labelsize', 14),
                          ('ytick.labelsize', 14),
                          ('legend.fontsize', 14)]:
    mpl.rcParams[_rc_key] = _rc_size


# Loading the data

In [None]:
# Load the pre-computed inputs; update each placeholder path to match your directory structure.
# NOTE: pickle.load can execute arbitrary code — only load files you created or otherwise trust.
_loaded_inputs = []

for _input_path in ['*** Update the folder path here to match your directory structure ***',   # matched_trials
                    '*** Update the folder path here to match your directory structure ***',   # spikes_all
                    '*** Update the folder path here to match your directory structure ***']:  # neuron_region
    with open(_input_path, 'rb') as f:
        _loaded_inputs.append(pickle.load(f))

matched_trials, spikes_all, neuron_region = _loaded_inputs


In [None]:
# Re-key the comparisons of every session/side entry: the stored 'error' comparison is used as the
# "slowing" comparison and the stored 'choice' comparison as the "error" comparison.
# Entries under any other key are not carried over.
repr_renaming = {'error' : 'slowing', 'choice' : 'error'}

matched_trials_mod = {}

for curr_eid_contr_side in sorted(matched_trials.keys()):

    curr_matched = matched_trials[curr_eid_contr_side]

    matched_trials_mod[curr_eid_contr_side] = {repr_renaming[old_repr] : curr_matched[old_repr]
                                               for old_repr in sorted(curr_matched.keys())
                                               if old_repr in repr_renaming}


# Parameters

In [None]:
# Sliding-window analysis parameters (times in ms relative to the Go Cue).
win_len, win_shift = 40, 20   # window width and step between consecutive windows (ms)
t_lims   = (-300, 300)        # analysed interval (ms)
sign_lvl = 0.05               # significance threshold applied to each window's p-value


In [None]:
def _centred_windows(mids):
    """Return [start, end] edges (Python floats, ms) of a win_len-wide window around each midpoint."""
    half_width = win_len / 2
    return [[float(mid - half_width), float(mid + half_width)] for mid in mids]


# Windows before the Go Cue end at 0 at the latest, windows after it start at 0 at the earliest,
# so no window straddles t = 0.
t_mids_pre  = np.arange(t_lims[0] + (win_len / 2), 0, win_shift).astype(float)
t_mids_post = np.arange((win_len / 2), t_lims[1], win_shift).astype(float)

t_wins_pre  = _centred_windows(t_mids_pre)
t_wins_post = _centred_windows(t_mids_post)

t_mids = np.concatenate((t_mids_pre, t_mids_post))
t_wins = t_wins_pre + t_wins_post

for _caption, _value in [('The middle of the time bins:', t_mids), ('The time windows:', t_wins)]:
    print(_caption)
    print(_value, 'ms')


In [None]:
# Tick positions: every other window midpoint on each side of the Go Cue (as ms values and as indices into t_mids).
time_plot     = np.hstack([t_mids_pre[1::2], t_mids_post[::2]]).astype(int)
ind_time_plot = np.flatnonzero(np.isin(t_mids, time_plot)).tolist()


# Shuffling (permutation tests of spike counts per time window)

In [None]:
def perm_stat(x, y, axis):
    """Difference of sample means, mean(x) - mean(y), computed along `axis`.

    Used as the vectorized test statistic passed to scipy.stats.permutation_test.
    """
    mean_x = np.mean(x, axis=axis)
    mean_y = np.mean(y, axis=axis)
    return mean_x - mean_y


In [None]:
pvals_all = {'slowing' : {}, 'error' : {}}

# Permutation-test settings. n_resamples must be an int (or np.inf) for scipy. A single seeded
# Generator is shared by all tests, so re-running the cell reproduces the p-values while
# successive tests still draw different permutations.
n_perm_resamples = 100_000
perm_rng         = np.random.default_rng(0)


def _window_spike_counts(neuron_spiketrains, trials, windows):
    """Count spikes per trial and time window.

    neuron_spiketrains[trial] is indexed like a numpy array of spike times (ms). A spike is
    counted in window `win` when win[0] < t < win[1] (both bounds exclusive).

    Returns an int array with one row per trial and one column per window.
    """
    return np.array([[np.count_nonzero((neuron_spiketrains[trial] > win[0]) & (neuron_spiketrains[trial] < win[1]))
                      for win in windows]
                     for trial in trials])


for ind_curr_eid_contr_side, curr_eid_contr_side in enumerate(sorted(list(matched_trials_mod.keys()))):

    print('Done with', np.around(ind_curr_eid_contr_side / len(matched_trials_mod.keys()) * 100, 1), '% of the data.')

    # spikes_all is keyed by the first two elements of the session/side key.
    curr_spikes = spikes_all[(curr_eid_contr_side[0], curr_eid_contr_side[1])]

    for curr_repr in sorted(list(matched_trials_mod[curr_eid_contr_side].keys())):

        pvals_all[curr_repr][curr_eid_contr_side] = {}

        curr_trials = matched_trials_mod[curr_eid_contr_side][curr_repr]

        for neuron in sorted(list(curr_spikes)):

            curr_sp_pes  = _window_spike_counts(curr_spikes[neuron], curr_trials['pes'],  t_wins)
            curr_sp_npes = _window_spike_counts(curr_spikes[neuron], curr_trials['npes'], t_wins)

            assert(curr_sp_pes.shape[1]  == len(t_mids))
            assert(curr_sp_npes.shape[1] == len(t_mids))

            # Only test neurons with at least one spike in every window (pooled over both trial groups).
            if np.all((np.sum(curr_sp_pes, axis=0) + np.sum(curr_sp_npes, axis=0)) > 0):

                # One two-sided test per window (vectorized along the trial axis).
                pvals_all[curr_repr][curr_eid_contr_side][neuron] = permutation_test((curr_sp_pes, curr_sp_npes), perm_stat, axis=0, vectorized=True,
                                                                                      alternative='two-sided', n_resamples=n_perm_resamples,
                                                                                      random_state=perm_rng).pvalue

                assert(len(pvals_all[curr_repr][curr_eid_contr_side][neuron]) == len(t_mids))

with open('*** Update the folder path here to match your directory structure ***', 'wb') as f:
    pickle.dump(pvals_all, f, pickle.HIGHEST_PROTOCOL)


# Selecting one neuron for each trial condition

In [None]:
# Group the session/side combinations by session (EID):
#   'both'  - sessions with at least one combination that has slowing results (and therefore error results too);
#   'error' - the remaining sessions, whose combinations only have error results.
sess_cond = {'error' : {}, 'both' : {}}

slowing_sess_keys = set(pvals_all['slowing'].keys())
error_sess_keys   = set(pvals_all['error'].keys())

all_sess_contr_side = sorted(slowing_sess_keys | error_sess_keys)

for curr_eid_contr_side in all_sess_contr_side:

    if curr_eid_contr_side not in slowing_sess_keys:
        continue

    # A slowing result without a matching error result is not expected.
    assert(curr_eid_contr_side in error_sess_keys)

    sess_cond['both'].setdefault(curr_eid_contr_side[0], []).append(curr_eid_contr_side)


for curr_eid_contr_side in all_sess_contr_side:

    if curr_eid_contr_side[0] in sess_cond['both']:
        continue

    in_slowing = curr_eid_contr_side in slowing_sess_keys
    in_error   = curr_eid_contr_side in error_sess_keys

    assert(not(in_slowing and in_error))

    if in_error:

        assert(not in_slowing)

        sess_cond['error'].setdefault(curr_eid_contr_side[0], []).append(curr_eid_contr_side)


In [None]:
# For every neuron of an error-only session (sess_cond['error']), keep the per-window p-values of exactly one
# session/side combination in pvals_neurons, and record which combination was chosen in neurons2sess_contr_side.
pvals_neurons = {'slowing' : {}, 'error' : {}}

neurons2sess_contr_side = {'slowing' : {}, 'error' : {}}

for EID in sess_cond['error'].keys():

    # Only one combination for this session: every neuron tested in it is taken as-is.
    if len(sess_cond['error'][EID]) == 1:

        for neuron in pvals_all['error'][sess_cond['error'][EID][0]].keys():
            pvals_neurons['error'][neuron] = pvals_all['error'][sess_cond['error'][EID][0]][neuron]

            neurons2sess_contr_side['error'][neuron] = sess_cond['error'][EID][0]

    else:

        # neuron -> {combination -> p-value array} over all combinations in which the neuron was tested.
        all_sess_res = {}

        for curr_eid_contr_side in sess_cond['error'][EID]:

            for neuron in sorted(list(pvals_all['error'][curr_eid_contr_side].keys())):

                if neuron not in all_sess_res.keys():
                    all_sess_res[neuron] = {}

                all_sess_res[neuron][curr_eid_contr_side] = pvals_all['error'][curr_eid_contr_side][neuron]

        for neuron in sorted(list(all_sess_res.keys())):

            # Number of significant windows in each combination (inserted in sorted combination order).
            curr_neuron_n_sign = {}

            for curr_eid_contr_side in sorted(list(all_sess_res[neuron].keys())):
                curr_neuron_n_sign[curr_eid_contr_side] = np.sum(all_sess_res[neuron][curr_eid_contr_side] < sign_lvl)

            max_n_sign  = max(curr_neuron_n_sign.values())

            # Combination(s) with the largest number of significant windows.
            chosen_sess = [k for k, v in curr_neuron_n_sign.items() if v == max_n_sign]

            if len(chosen_sess) == 1:
                pvals_neurons['error'][neuron] = pvals_all['error'][chosen_sess[0]][neuron]

                assert(neuron not in neurons2sess_contr_side['error'].keys())

                neurons2sess_contr_side['error'][neuron] = chosen_sess[0]

            else:

                # Tie on the number of significant windows: pick the combination with the smallest mean p-value.
                # If the means are also equal, min() keeps the first one in sorted combination order.
                curr_neuron_avg_pval = {}

                for curr_eid_contr_side in chosen_sess:

                    curr_neuron_avg_pval[curr_eid_contr_side] = np.mean(all_sess_res[neuron][curr_eid_contr_side])

                sess_smallest_avg_pval = min(curr_neuron_avg_pval, key=curr_neuron_avg_pval.get)

                pvals_neurons['error'][neuron] = pvals_all['error'][sess_smallest_avg_pval][neuron]

                assert(neuron not in neurons2sess_contr_side['error'].keys())

                neurons2sess_contr_side['error'][neuron] = sess_smallest_avg_pval


In [None]:
# For every neuron of a session that has slowing results (sess_cond['both']), decide which session/side
# combination(s) and which condition(s) ('slowing' / 'error') it contributes to, then copy the selected
# p-values into pvals_neurons and the chosen combination into neurons2sess_contr_side.
for EID in sess_cond['both'].keys():

    # Single combination for this session: take every tested neuron for each condition it was tested in.
    if len(sess_cond['both'][EID]) == 1:

        curr_neurons_slow  = sorted(list(pvals_all['slowing'][sess_cond['both'][EID][0]].keys()))
        curr_neurons_error = sorted(list(pvals_all['error'][sess_cond['both'][EID][0]].keys()))

        for neuron in curr_neurons_slow:
            pvals_neurons['slowing'][neuron] = pvals_all['slowing'][sess_cond['both'][EID][0]][neuron]

            neurons2sess_contr_side['slowing'][neuron] = sess_cond['both'][EID][0]

        for neuron in curr_neurons_error:
            pvals_neurons['error'][neuron] = pvals_all['error'][sess_cond['both'][EID][0]][neuron]

            assert(neuron not in neurons2sess_contr_side['error'].keys())

            neurons2sess_contr_side['error'][neuron] = sess_cond['both'][EID][0]


    else:

        # All neurons tested in any combination of this session, for either condition.
        curr_neurons_all = set()

        for curr_eid_contr_side in sess_cond['both'][EID]:
            curr_neurons_all |= set(pvals_all['slowing'][curr_eid_contr_side].keys())
            curr_neurons_all |= set(pvals_all['error'][curr_eid_contr_side].keys())

        # neuron -> {'slowing': {combination -> p-values}, 'error': {combination -> p-values}};
        # entries are removed below until at most one combination per condition remains.
        curr_neurons_res = {}

        for neuron in sorted(list(curr_neurons_all)):

            curr_neurons_res[neuron] = {'slowing' : {}, 'error' : {}}

            for curr_eid_contr_side in sess_cond['both'][EID]:

                if neuron in pvals_all['slowing'][curr_eid_contr_side].keys():
                    curr_neurons_res[neuron]['slowing'][curr_eid_contr_side] = pvals_all['slowing'][curr_eid_contr_side][neuron]

                if neuron in pvals_all['error'][curr_eid_contr_side].keys():
                    curr_neurons_res[neuron]['error'][curr_eid_contr_side] = pvals_all['error'][curr_eid_contr_side][neuron]

            # Tested only for 'error': keep that (expected to be a single combination) and drop 'slowing'.
            if len(curr_neurons_res[neuron]['slowing'].keys()) == 0:

                assert(len(curr_neurons_res[neuron]['error'].keys()) == 1)

                curr_neurons_res[neuron].pop('slowing')

            # Tested only for 'slowing': keep that (expected to be a single combination) and drop 'error'.
            elif len(curr_neurons_res[neuron]['error'].keys()) == 0:

                assert(len(curr_neurons_res[neuron]['slowing'].keys()) == 1)

                curr_neurons_res[neuron].pop('error')

            # Tested for both conditions, but not in the same set of combinations.
            elif (sorted(list(curr_neurons_res[neuron]['slowing'].keys())) != sorted(list(curr_neurons_res[neuron]['error'].keys()))):

                not_selected_sess = sorted(list(set(curr_neurons_res[neuron]['slowing'].keys()).symmetric_difference(set(curr_neurons_res[neuron]['error'].keys()))))

                # One condition has two combinations, the other one: drop the combination that is not
                # shared, so both conditions use the common combination.
                if len(not_selected_sess) == 1:

                    assert(((len(curr_neurons_res[neuron]['slowing'].keys()) == 2) and (len(curr_neurons_res[neuron]['error'].keys()) == 1)) or
                           ((len(curr_neurons_res[neuron]['slowing'].keys()) == 1) and (len(curr_neurons_res[neuron]['error'].keys()) == 2)))

                    if not_selected_sess[0] in curr_neurons_res[neuron]['slowing'].keys():
                        curr_neurons_res[neuron]['slowing'].pop(not_selected_sess[0])

                    if not_selected_sess[0] in curr_neurons_res[neuron]['error'].keys():
                        curr_neurons_res[neuron]['error'].pop(not_selected_sess[0])

                    assert(len(curr_neurons_res[neuron]['slowing'].keys()) == 1)
                    assert(len(curr_neurons_res[neuron]['error'].keys())   == 1)

                # Slowing and error come from two different combinations: keep only one condition.
                else:

                    assert(len(not_selected_sess) == 2)

                    assert(len(curr_neurons_res[neuron]['slowing'].keys()) == 1)
                    assert(len(curr_neurons_res[neuron]['error'].keys())   == 1)

                    curr_slowing_sess = list(curr_neurons_res[neuron]['slowing'].keys())[0]
                    curr_error_sess   = list(curr_neurons_res[neuron]['error'].keys())[0]

                    assert(not(curr_slowing_sess == curr_error_sess))

                    curr_n_sign_win_slowing = np.sum(curr_neurons_res[neuron]['slowing'][curr_slowing_sess] < sign_lvl)
                    curr_n_sign_win_error   = np.sum(curr_neurons_res[neuron]['error'][curr_error_sess]     < sign_lvl)

                    # Keep the condition that has any significant window.
                    # NOTE(review): if BOTH conditions have significant windows, the assert in the first
                    # branch below fails — confirm this cannot happen in the data.
                    if (curr_n_sign_win_slowing > 0) or (curr_n_sign_win_error > 0):

                        if (curr_n_sign_win_slowing > 0):

                            assert(not(np.any(curr_neurons_res[neuron]['error'][curr_error_sess] < sign_lvl)))

                            curr_neurons_res[neuron].pop('error')

                        elif (curr_n_sign_win_error > 0):

                            assert(not(np.any(curr_neurons_res[neuron]['slowing'][curr_slowing_sess] < sign_lvl)))

                            curr_neurons_res[neuron].pop('slowing')

                    # No significant window in either: keep the condition with the smaller mean p-value
                    # (ties keep 'error').
                    else:

                        avg_slowing_pvals = np.mean(curr_neurons_res[neuron]['slowing'][curr_slowing_sess])
                        avg_error_pvals   = np.mean(curr_neurons_res[neuron]['error'][curr_error_sess])

                        if avg_slowing_pvals < avg_error_pvals:
                            curr_neurons_res[neuron].pop('error')
                        else:
                            curr_neurons_res[neuron].pop('slowing')

            # Tested for both conditions in both combinations: keep the combination with more significant
            # windows (summed over the two conditions) for both conditions.
            elif (len(curr_neurons_res[neuron]['slowing'].keys()) == 2) and (len(curr_neurons_res[neuron]['error'].keys()) == 2):

                n_sign_win = {}

                for curr_sess_contr_side in sess_cond['both'][EID]:

                    n_sign_win[curr_sess_contr_side] = np.sum(curr_neurons_res[neuron]['slowing'][curr_sess_contr_side] < sign_lvl) + np.sum(curr_neurons_res[neuron]['error'][curr_sess_contr_side] < sign_lvl)

                assert(len(n_sign_win.keys()) == 2)

                if n_sign_win[sess_cond['both'][EID][0]] > n_sign_win[sess_cond['both'][EID][1]]:

                    curr_neurons_res[neuron]['slowing'].pop(sess_cond['both'][EID][1])
                    curr_neurons_res[neuron]['error'].pop(sess_cond['both'][EID][1])

                    assert(len(curr_neurons_res[neuron]['slowing'].keys()) == 1)
                    assert(len(curr_neurons_res[neuron]['error'].keys())   == 1)

                elif n_sign_win[sess_cond['both'][EID][0]] < n_sign_win[sess_cond['both'][EID][1]]:

                    curr_neurons_res[neuron]['slowing'].pop(sess_cond['both'][EID][0])
                    curr_neurons_res[neuron]['error'].pop(sess_cond['both'][EID][0])

                    assert(len(curr_neurons_res[neuron]['slowing'].keys()) == 1)
                    assert(len(curr_neurons_res[neuron]['error'].keys())   == 1)

                # Same number of significant windows: keep the combination with the smaller sum of mean p-values.
                # NOTE(review): an exact tie in avg_pval removes nothing, and the `len == 1` assert below then fails.
                else:

                    assert(n_sign_win[sess_cond['both'][EID][0]] == n_sign_win[sess_cond['both'][EID][1]])

                    avg_pval = {}

                    for curr_sess_contr_side in sess_cond['both'][EID]:
                        avg_pval[curr_sess_contr_side] = np.mean(curr_neurons_res[neuron]['slowing'][curr_sess_contr_side]) + np.mean(curr_neurons_res[neuron]['error'][curr_sess_contr_side])

                    assert(len(avg_pval.keys()) == 2)

                    if avg_pval[sess_cond['both'][EID][0]] > avg_pval[sess_cond['both'][EID][1]]:

                        curr_neurons_res[neuron]['slowing'].pop(sess_cond['both'][EID][0])
                        curr_neurons_res[neuron]['error'].pop(sess_cond['both'][EID][0])

                        assert(len(curr_neurons_res[neuron]['slowing'].keys()) == 1)
                        assert(len(curr_neurons_res[neuron]['error'].keys())   == 1)

                    elif avg_pval[sess_cond['both'][EID][0]] < avg_pval[sess_cond['both'][EID][1]]:

                        curr_neurons_res[neuron]['slowing'].pop(sess_cond['both'][EID][1])
                        curr_neurons_res[neuron]['error'].pop(sess_cond['both'][EID][1])

                        assert(len(curr_neurons_res[neuron]['slowing'].keys()) == 1)
                        assert(len(curr_neurons_res[neuron]['error'].keys())   == 1)

        # Copy the surviving (single) combination per condition into the global result dictionaries.
        for neuron in sorted(list(curr_neurons_res.keys())):

            if 'slowing' in curr_neurons_res[neuron].keys():

                sess_contr_side_all = list(curr_neurons_res[neuron]['slowing'].keys())

                assert(len(sess_contr_side_all) == 1)

                pvals_neurons['slowing'][neuron] = curr_neurons_res[neuron]['slowing'][sess_contr_side_all[0]]

                assert(neuron not in neurons2sess_contr_side['slowing'].keys())

                neurons2sess_contr_side['slowing'][neuron] = sess_contr_side_all[0]


            if 'error' in curr_neurons_res[neuron].keys():

                sess_contr_side_all = list(curr_neurons_res[neuron]['error'].keys())

                assert(len(sess_contr_side_all) == 1)

                pvals_neurons['error'][neuron] = curr_neurons_res[neuron]['error'][sess_contr_side_all[0]]

                assert(neuron not in neurons2sess_contr_side['error'].keys())

                neurons2sess_contr_side['error'][neuron] = sess_contr_side_all[0]


In [None]:
# Every chosen session/side combination must be a 3-element key.
for curr_case in ['error', 'slowing']:
    for neuron in sorted(neurons2sess_contr_side[curr_case].keys()):
        assert(len(neurons2sess_contr_side[curr_case][neuron]) == 3)


# Prominent regions

In [None]:
# Minimum number of recorded neurons for a region to be kept in the region-level analysis.
min_region_neurons = 30

# Per case ('slowing', 'error', 'both'): per-window counts of all / significant neurons, pooled over
# everything (global_temp) and per brain region (regions_temp).
regions_temp = {'slowing' : {}, 'error' : {}, 'both' : {}}
global_temp  = {'slowing' : {'sign' : np.full(len(t_mids), 0), 'all' : np.full(len(t_mids), 0)},
                'error'   : {'sign' : np.full(len(t_mids), 0), 'all' : np.full(len(t_mids), 0)},
                'both'    : {'sign' : np.full(len(t_mids), 0), 'all' : np.full(len(t_mids), 0)}}

final_neurons_all = sorted(list(set(pvals_neurons['slowing'].keys()) | set(pvals_neurons['error'].keys())))

for neuron in final_neurons_all:

    # Per-window significance of this neuron for every case it contributes to; 'both' requires
    # significance in the slowing AND the error comparison in the same window.
    sign_masks = {}

    if (neuron in pvals_neurons['slowing'].keys()) and (neuron in pvals_neurons['error'].keys()):
        sign_masks['both'] = (pvals_neurons['slowing'][neuron] < sign_lvl) & (pvals_neurons['error'][neuron] < sign_lvl)

    if neuron in pvals_neurons['slowing'].keys():
        sign_masks['slowing'] = pvals_neurons['slowing'][neuron] < sign_lvl

    if neuron in pvals_neurons['error'].keys():
        sign_masks['error'] = pvals_neurons['error'][neuron] < sign_lvl

    for curr_case, curr_mask in sign_masks.items():
        global_temp[curr_case]['all']  = global_temp[curr_case]['all'] + 1
        global_temp[curr_case]['sign'] = global_temp[curr_case]['sign'] + curr_mask.astype(int)

    if neuron in neuron_region.keys():

        curr_region = neuron_region[neuron]

        for curr_case, curr_mask in sign_masks.items():

            if curr_region not in regions_temp[curr_case]:
                regions_temp[curr_case][curr_region] = {'all'  : np.full(len(t_mids), 0),
                                                        'sign' : np.full(len(t_mids), 0)}

            regions_temp[curr_case][curr_region]['all']  = regions_temp[curr_case][curr_region]['all'] + 1
            regions_temp[curr_case][curr_region]['sign'] = regions_temp[curr_case][curr_region]['sign'] + curr_mask.astype(int)


regions_not_involved = {'slowing' : [], 'error' : [], 'both' : []}

for curr_case in ['slowing', 'error', 'both']:
    for region in sorted(list(regions_temp[curr_case].keys())):

        # Validate the region being iterated over.
        assert(len(regions_temp[curr_case][region]['sign']) == len(t_mids))
        assert(len(regions_temp[curr_case][region]['all'])  == len(t_mids))

        # Every window sees the same set of neurons, so the per-window totals must all be equal.
        assert(np.all(regions_temp[curr_case][region]['all'] == regions_temp[curr_case][region]['all'][0]))

        # Drop regions without any significant window or with too few recorded neurons.
        if ((np.sum(regions_temp[curr_case][region]['sign']) == 0) or (regions_temp[curr_case][region]['all'][0] < min_region_neurons)):
            regions_not_involved[curr_case].append(region)

for curr_case in ['slowing', 'error', 'both']:
    for region in regions_not_involved[curr_case]:
        del regions_temp[curr_case][region]

# Per-window ratio of significant neurons for every remaining region.
ratios_temp = {'slowing' : {}, 'error' : {}, 'both' : {}}

for curr_case in ['slowing', 'error', 'both']:
    for region in regions_temp[curr_case]:

        ratios_temp[curr_case][region] = regions_temp[curr_case][region]['sign'] / regions_temp[curr_case][region]['all']

        assert(np.all(ratios_temp[curr_case][region] <= 1.0))

# The three regions with the highest peak ratio for each case.
top_slowing_regions = sorted(regions_temp['slowing'].keys(), key=lambda k : np.max(ratios_temp['slowing'][k]), reverse=True)[ : 3]
top_error_regions   = sorted(regions_temp['error'].keys(),   key=lambda k : np.max(ratios_temp['error'][k]),   reverse=True)[ : 3]
top_both_regions    = sorted(regions_temp['both'].keys(),    key=lambda k : np.max(ratios_temp['both'][k]),    reverse=True)[ : 3]

print(f'Relevant for SLOWING: {top_slowing_regions}')
print(f'Relevant for ERROR: {top_error_regions}')
print(f'Relevant for BOTH: {top_both_regions}')

assert(len(global_temp['slowing']['sign']) == len(t_mids))
assert(len(global_temp['slowing']['all'])  == len(t_mids))
assert(len(global_temp['error']['sign'])   == len(t_mids))
assert(len(global_temp['error']['all'])    == len(t_mids))
assert(len(global_temp['both']['sign'])    == len(t_mids))
assert(len(global_temp['both']['all'])     == len(t_mids))

# Population-average ratio of significant neurons per window (all neurons, regardless of region).
avg_ratios = {curr_case : global_temp[curr_case]['sign'] / global_temp[curr_case]['all'] for curr_case in global_temp.keys()}

assert(len(avg_ratios['slowing']) == len(t_mids))
assert(len(avg_ratios['error'])   == len(t_mids))
assert(len(avg_ratios['both'])    == len(t_mids))

assert(np.all(avg_ratios['slowing'] < 1.0))
assert(np.all(avg_ratios['error']   < 1.0))
assert(np.all(avg_ratios['both']    < 1.0))


# General figure

In [None]:
# Peak ratio, time of the peak, and number of recorded neurons (largest over the three cases) for every kept region.
all_regions_gen = set(regions_temp['slowing'].keys()) | set(regions_temp['error'].keys()) | set(regions_temp['both'].keys())

n_neurons_gen  = dict.fromkeys(all_regions_gen, 0)
max_ratios_gen = {}
t_max_ratios   = {}

for test_case in ['slowing', 'error', 'both']:

    case_regions = sorted(ratios_temp[test_case].keys())

    max_ratios_gen[test_case] = {reg : np.max(ratios_temp[test_case][reg]) for reg in case_regions}
    t_max_ratios[test_case]   = {reg : int(t_mids[np.argmax(ratios_temp[test_case][reg])]) for reg in case_regions}

    for region in case_regions:
        if region in regions_temp[test_case].keys():
            n_neurons_gen[region] = max(n_neurons_gen[region], regions_temp[test_case][region]['all'][0])


# Ascending order, so that horizontal bar charts show the largest peak at the top.
slowing_regions_sorted = sorted(regions_temp['slowing'].keys(), key=lambda k : np.max(ratios_temp['slowing'][k]))
error_regions_sorted   = sorted(regions_temp['error'].keys(),   key=lambda k : np.max(ratios_temp['error'][k]))
both_regions_sorted    = sorted(regions_temp['both'].keys(),    key=lambda k : np.max(ratios_temp['both'][k]))

print('Number of neurons recorded:')

for region in sorted(n_neurons_gen):
    print(f'{region} : {n_neurons_gen[region]}')


In [None]:
fig, axs = plt.subplots(2, 2, figsize=(13, 10))

# Merge the left column into one tall axis for the "error" panel.
axs[0, 0].remove()
axs[1, 0].remove()

gs      = axs[0, 0].get_gridspec()
ax_left = fig.add_subplot(gs[:, 0])


def _max_ratio_barh(ax, regions_sorted, test_case, bar_color, x_label):
    """Draw one horizontal bar per region (peak ratio of significant neurons) and write the time of
    that peak next to each bar, starting at the axis' original right x-limit.

    Returns the bar container and that original right x-limit.
    """
    bars = ax.barh([reg for reg in regions_sorted], [max_ratios_gen[test_case][reg] for reg in regions_sorted],
                   edgecolor='black', linewidth=1.2, color=bar_color)

    x_uplim = ax.get_xlim()[1]

    for bar, peak_time in zip(bars, [t_max_ratios[test_case][reg] for reg in regions_sorted]):
        ax.text(1.0 * x_uplim, bar.get_y() + bar.get_height() / 2, str(peak_time) + ' ms',
                va='center', ha='left', fontsize=13)

    # Widen the x-range by 20 % so the annotations fit.
    ax.set_xlim([None, 1.2 * x_uplim])
    ax.set_xlabel(x_label)

    return bars, x_uplim


bars_err,  x_uplim_error = _max_ratio_barh(ax_left,   error_regions_sorted,   'error',   '#E6C27A', 'Maximum ratio of significant "error" \nneurons')
bars_slow, x_uplim_slow  = _max_ratio_barh(axs[0, 1], slowing_regions_sorted, 'slowing', '#E6A8A1', 'Maximum ratio of significant "slowing" \nneurons')
bars_both, x_uplim_both  = _max_ratio_barh(axs[1, 1], both_regions_sorted,    'both',    '#A8B97A', 'Maximum ratio of significant neurons \ninvolved in both processes')

for _ax, _regions_sorted in [(ax_left, error_regions_sorted), (axs[0, 1], slowing_regions_sorted), (axs[1, 1], both_regions_sorted)]:
    _ax.set_ylim(-0.5, len(_regions_sorted) - 0.5)


plt.subplots_adjust(hspace=0.37)
plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


# Overview charts

In [None]:
# Colour per brain region, plus grey for the population average, used in the figures below.
# A region missing here raises KeyError when it is plotted.
curr_colors = dict(RN='#4A90E2',
                   MOs='#FF7F0E',
                   MRN='#2CA02C',
                   IP='#D62728',
                   DEC='#9467BD',
                   avg='#7D7D7D')


In [None]:
_, axs = plt.subplots(3, 1, figsize=(14, 15), sharex=True)

smooth_sig = 1.1

# (case, regions to draw, y-label, legend keyword arguments, show x tick labels on this panel)
overview_panels = [
    ('slowing', top_slowing_regions, 'Ratio of "slowing" neurons',                     {},                    True),
    ('error',   top_error_regions,   'Ratio of "error" neurons',                       {'loc': 'upper left'}, True),
    ('both',    top_both_regions,    'Ratio of neurons involved in \nboth processes', {},                    False),
]

for ax, (curr_case, curr_top_regions, y_label, legend_kwargs, show_xticklabels) in zip(axs, overview_panels):

    # Gaussian-smoothed ratio of significant neurons over time: population average plus the top regions.
    ax.plot(t_mids, gaussian_filter1d(avg_ratios[curr_case], sigma=smooth_sig), color=curr_colors['avg'], label='Average', linewidth=0.5)

    for region in curr_top_regions:
        ax.plot(t_mids, gaussian_filter1d(ratios_temp[curr_case][region], sigma=smooth_sig), label=region, color=curr_colors[region])

    ax.set_xticks(time_plot)
    ax.vlines(0, 0, ax.get_ylim()[1], color='k', linestyle=(0, (5, 10)), zorder=0, linewidth=0.8)
    ax.set_xlabel('Time from Go Cue (ms)')
    ax.set_ylabel(y_label)
    ax.legend(**legend_kwargs)

    if show_xticklabels:
        ax.tick_params(axis='x', labelbottom=True)


plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


# RN charts

In [None]:
# Two-sided binomial tests of the RN neuron count against the chance rate of significance
# (sign_lvl for one comparison, sign_lvl**2 for both, assuming independent comparisons).
# NOTE(review): the observed counts (5, 5, 2) are hard-coded here and in the RN charts — confirm where
# they come from and consider deriving them from regions_temp[...]['RN']['sign'].
rn_binom_slowing = binomtest(5, regions_temp['slowing']['RN']['all'][0], p=sign_lvl, alternative='two-sided')
print(f"p-value of the slowing binomial test: {rn_binom_slowing.pvalue}")

rn_binom_error = binomtest(5, regions_temp['error']['RN']['all'][0], p=sign_lvl, alternative='two-sided')
print(f"p-value of the error binomial test: {rn_binom_error.pvalue}")

rn_binom_both = binomtest(2, regions_temp['both']['RN']['all'][0], p=sign_lvl * sign_lvl, alternative='two-sided')
print(f"p-value of the both binomial test: {rn_binom_both.pvalue}")


In [None]:
_, axs = plt.subplots(3, 1, figsize=(14, 15), sharex=True)

# (case, count used for the red reference line, y-label, show x tick labels on this panel)
# NOTE(review): the counts 5 / 5 / 2 match the hard-coded values of the binomial-test cell — confirm their origin.
rn_panels = [
    ('slowing', 5, 'Ratio of "slowing" neurons',                     True),
    ('error',   5, 'Ratio of "error" neurons',                       True),
    ('both',    2, 'Ratio of neurons involved in \nboth processes', False),
]

for ind_ax, (ax, (curr_case, n_ref, y_label, show_xticklabels)) in enumerate(zip(axs, rn_panels)):

    ax.bar(range(len(t_mids)), ratios_temp[curr_case]['RN'], color=curr_colors['RN'], edgecolor='black', linewidth=1)

    if ind_ax == 0:
        # With sharex=True this x-range applies to all three panels.
        ax.set_xlim(0 - 0.5, len(t_mids) - 0.5)

    ax.set_xticks(ind_time_plot, time_plot)
    ax.hlines(n_ref / regions_temp[curr_case]['RN']['all'][0], 1, len(t_mids) - 1, color='#FF3B00', zorder=0, linewidth=0.8)
    # Dashed separator between the last pre-cue and the first post-cue window.
    ax.vlines(len(t_mids_pre) - 0.5, 0, ax.get_ylim()[1], color='k', linestyle=(0, (5, 10)), zorder=1, linewidth=0.8)
    ax.set_xlabel('Time from Go Cue (ms)')
    ax.set_ylabel(y_label)

    if show_xticklabels:
        ax.tick_params(axis='x', labelbottom=True)


plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


# Raster plot

In [None]:
# Inputs for the raster plot; update each placeholder path to match your directory structure.
# NOTE: pickle.load can execute arbitrary code — only load files you created or otherwise trust.
_raster_inputs = []

for _raster_path in ['*** Update the folder path here to match your directory structure ***',   # fr_z
                     '*** Update the folder path here to match your directory structure ***']:  # trials_all
    with open(_raster_path, 'rb') as f:
        _raster_inputs.append(pickle.load(f))

fr_z, trials_all = _raster_inputs


In [None]:
# IMPORTED PARAMETERS

t_lims     = (-350.0, 350.0) # ms
sampl_rate = 10.0 # ms
time_axis  = np.arange(t_lims[0] + sampl_rate / 2, t_lims[1] - sampl_rate / 2 + 1, sampl_rate)

print(time_axis)


In [None]:
# Map every neuron labelled 'RN' in neuron_region to its key in
# neurons2sess_contr_side. Keep only neurons present in both the "error" and the
# "slowing" mappings with the *same* key in each.
rn_neurons = {}

for neuron in sorted(neuron_region):

    if neuron_region[neuron] != 'RN':
        continue

    if neuron not in neurons2sess_contr_side['error'] or neuron not in neurons2sess_contr_side['slowing']:
        continue

    if not neurons2sess_contr_side['error'][neuron] == neurons2sess_contr_side['slowing'][neuron]:
        continue

    rn_neurons[neuron] = neurons2sess_contr_side['error'][neuron]


In [None]:
# Example RN neuron for the raster plot, given as (PID, cluster id).
# NOTE(review): despite the "rand_" prefix, this neuron is fixed rather than sampled.
rand_rn_neuron = ('3c283107-7012-48fc-a6c2-ed096b23974f', np.int64(710))
rand_triplet   = rn_neurons[rand_rn_neuron]

print(rand_rn_neuron)
print(rand_triplet)

# Trial sets of the three conditions plotted in the raster cell.
rand_trials_pes, rand_trials_pens, rand_trials_pcorr = (
    matched_trials_mod[rand_triplet]['error']['pes'],     # slow post-error
    matched_trials_mod[rand_triplet]['slowing']['npes'],  # non-slow post-error
    matched_trials_mod[rand_triplet]['error']['npes'],    # slow post-correct
)

# The slow post-error trials must be the same set in both matched comparisons.
assert rand_trials_pes == matched_trials_mod[rand_triplet]['slowing']['pes']


In [None]:
def _collect_condition_data(trial_ids, neuron_spikes, neuron_rates, sess_trials, window):
    """Gather one neuron's windowed spikes, firing-rate traces and RTs for a trial set.

    Parameters
    ----------
    trial_ids : iterable
        Trial identifiers used to index the three per-trial containers below.
    neuron_spikes : mapping
        trial -> spike-time array of the neuron (same time units as ``window``).
    neuron_rates : mapping
        trial -> z-scored firing-rate trace of the neuron.
    sess_trials : mapping
        trial -> per-trial record; its first field is used as the reaction time.
    window : tuple of float
        (start, stop); only spikes strictly inside this open interval are kept.

    Returns
    -------
    spikes, rates, rts : list, list, list
        One entry per trial, in the order of ``trial_ids``.
    """
    spikes, rates, rts = [], [], []

    for trial in trial_ids:
        spiketrain = neuron_spikes[trial]
        spikes.append(spiketrain[np.logical_and(spiketrain > window[0], spiketrain < window[1])])
        rates.append(neuron_rates[trial])
        rts.append(sess_trials[trial][0])

    return spikes, rates, rts


def _nearest_time_index(times, value):
    """Return the index of the entry of ``times`` closest to ``value`` (first one on ties)."""
    return int(np.argmin(np.abs(times - value)))


# Per-neuron / per-session containers of the example neuron (looked up once).
rand_sess_key      = (rand_triplet[0], rand_triplet[1])
rand_neuron_spikes = spikes_all[rand_sess_key][rand_rn_neuron]
rand_neuron_fr     = fr_z[rand_sess_key][rand_rn_neuron]
rand_sess_trials   = trials_all[rand_triplet[0]][rand_triplet[1]]

# Spikes are cropped to t_lims, the (-350, 350) ms window set in the imported-parameters cell.
rand_sp_pes,   rand_fr_pes,   rand_rt_pes   = _collect_condition_data(rand_trials_pes,   rand_neuron_spikes, rand_neuron_fr, rand_sess_trials, t_lims)
rand_sp_pens,  rand_fr_pens,  rand_rt_pens  = _collect_condition_data(rand_trials_pens,  rand_neuron_spikes, rand_neuron_fr, rand_sess_trials, t_lims)
rand_sp_pcorr, rand_fr_pcorr, rand_rt_pcorr = _collect_condition_data(rand_trials_pcorr, rand_neuron_spikes, rand_neuron_fr, rand_sess_trials, t_lims)

# Marker position: the time bin closest to each condition's median RT.
# Assumes the RTs in trials_all are in ms, like time_axis (TODO confirm).
# A median RT outside time_axis is clamped to the nearest edge bin.
ind_pes   = _nearest_time_index(time_axis, np.median(rand_rt_pes))
ind_pens  = _nearest_time_index(time_axis, np.median(rand_rt_pens))
ind_pcorr = _nearest_time_index(time_axis, np.median(rand_rt_pcorr))

# (label, colour, spike trains, firing-rate traces, median-RT bin index) per condition.
raster_conditions = [
    ('Slow post-error',     '#FF9F45', rand_sp_pes,   rand_fr_pes,   ind_pes),
    ('Non-slow post-error', '#9A7DCC', rand_sp_pens,  rand_fr_pens,  ind_pens),
    ('Slow post-correct',   '#4682B4', rand_sp_pcorr, rand_fr_pcorr, ind_pcorr),
]

fig, axs = plt.subplots(4, 1, figsize=(8, 12), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1, 3]})

for raster_ax, (cond_label, cond_color, cond_spikes, cond_rates, cond_rt_ind) in zip(axs[:3], raster_conditions):

    # Top three panels: one raster per condition, one row per trial.
    raster_ax.eventplot(cond_spikes, color='k')
    raster_ax.set_ylabel('Trials')
    raster_ax.text(1.02, 0.5, f'{cond_label}\nspiketrains', transform=raster_ax.transAxes, va='center', ha='left', rotation=0, fontsize=14)

    # Bottom panel: trial-averaged z-scored rate with a dot at the median-RT bin.
    cond_mean_rate = np.mean(cond_rates, axis=0)
    axs[3].plot(time_axis, cond_mean_rate, label=cond_label, color=cond_color)
    axs[3].scatter([time_axis[cond_rt_ind]], [cond_mean_rate[cond_rt_ind]], marker='o', color=cond_color)

# Go Cue at t = 0. axvline spans the full panel height; a vlines call bounded by
# the pre-autoscale y-limits would stop short of the final axis edges.
axs[3].axvline(0, linestyle='--', color='#FF6F61', zorder=0)

axs[3].legend()

axs[3].set_xlabel('Time from the Go Cue (ms)')
axs[3].set_ylabel('Average z-scored\nfiring rate')

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
from one.api import ONE

# Read-only connection to the public IBL OpenAlyx server. 'international' is the
# publicly documented OpenAlyx password, not a personal credential.
one = ONE(base_url='https://openalyx.internationalbrainlab.org', password='international', silent=True)

# Session of the example neuron (hard-coded).
# NOTE(review): confirm it matches the session in rand_triplet printed above.
ex_sess = uuid.UUID('1b9e349e-93f2-41cc-a4b5-b212d7ddc8df')
info_ex = one.get_details(ex_sess)

print('EID:', ex_sess)
# NOTE(review): contrast and side are hard-coded; keep them in sync with rand_triplet.
print('Contrast:', 1.0)
print('Side:', 'right')
print('Lab:', info_ex['lab'])
print('Animal:', info_ex['subject'])
print('Starting time:', info_ex['start_time'])
# PID and cluster id are read from the example neuron itself, so they cannot
# drift if rand_rn_neuron is changed.
print('PID:', rand_rn_neuron[0])
print('Cluster ID:', rand_rn_neuron[1])

