<a href="https://colab.research.google.com/github/gergogomori/brain_wide_pes/blob/main/pes_classifiers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import uuid

import pandas  as pd
import seaborn as sns
from matplotlib.ticker import MaxNLocator

from sklearn.linear_model    import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics         import f1_score

# Silence two warning categories for the whole notebook session.
# NOTE(review): ignoring every RuntimeWarning is broad and can also hide
# numerical problems (e.g. invalid values in numpy operations) - confirm this is intended.
from warnings           import filterwarnings
from sklearn.exceptions import ConvergenceWarning
filterwarnings('ignore', category=RuntimeWarning)
# The LogisticRegression fits further below are capped at max_iter=100;
# convergence warnings issued by those fits are suppressed here.
filterwarnings('ignore', category=ConvergenceWarning)


In [None]:
import matplotlib as mpl

# Global font sizes applied to every figure created after this cell runs.
mpl.rcParams['axes.labelsize']  = 15
mpl.rcParams['xtick.labelsize'] = 14
mpl.rcParams['ytick.labelsize'] = 14
mpl.rcParams['legend.fontsize'] = 14


In [None]:
# Load the four pickled inputs used by the rest of the notebook.
# Each placeholder path below has to be replaced with the location of the matching file.
# NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.

# Per-trial records; indexed further below as trials_all[...][...][trial].
with open('*** Update the folder path here to match your directory structure ***', 'rb') as trials_file:
    trials_all = pickle.load(trials_file)

# Trial selections per session key; relabelled in the next cell.
with open('*** Update the folder path here to match your directory structure ***', 'rb') as matched_file:
    matched_trials = pickle.load(matched_file)

# Per-neuron, per-trial time series; sliced along time_axis during training.
with open('*** Update the folder path here to match your directory structure ***', 'rb') as fr_file:
    fr_z = pickle.load(fr_file)

# Mapping from neuron to region, used for the region summary at the end.
with open('*** Update the folder path here to match your directory structure ***', 'rb') as region_file:
    neuron_region = pickle.load(region_file)


In [None]:
# Relabel the trial selections of every session key:
#   original 'choice' entry -> stored under 'error'
#   original 'error'  entry -> stored under 'slowing'
# Any other entry is dropped. The tuple order mirrors the sorted key order
# ('choice' before 'error'), so the inner dictionaries are filled in the same order.
label_renames = (('choice', 'error'), ('error', 'slowing'))

matched_trials_mod = {}

for sess_key in sorted(matched_trials.keys()):

    sess_trials = matched_trials[sess_key]

    matched_trials_mod[sess_key] = {new_label : sess_trials[old_label]
                                    for old_label, new_label in label_renames
                                    if old_label in sess_trials.keys()}


In [None]:
# IMPORTED PARAMETERS

# Offset (ms) added to time zero to locate the first time bin of the analysis window.
time_eps = 5.0
# Percentile used to locate the last time bin of the analysis window.
time_end_perc = 5

# Limits (ms) of the time range and the width (ms) of one time bin.
t_lims     = (-350.0, 350.0)
sampl_rate = 10.0

# Bin centres: half a bin inside each limit; the "+ 1" keeps the last centre
# inside the half-open range that np.arange generates.
first_center = t_lims[0] + sampl_rate / 2
last_center  = t_lims[1] - sampl_rate / 2
time_axis    = np.arange(first_center, last_center + 1, sampl_rate)

print(time_axis)


In [None]:
# NEW PARAMETERS

# Candidate inverse regularization strengths for the logistic regression: 10^-2 ... 10^2.
regul_C = [10.0 ** exponent for exponent in range(-2, 3)]

# Number of independently split-and-fitted models per neuron.
n_models = 10

# Fraction of all trials held out for testing, and fraction of the remaining
# trials held out as the development set for choosing C.
r_test = 0.2
r_dev  = 0.3

# Performance level marked in the histograms and used to count neurons per region.
perf_thres = 0.7

print(f'Regularization parameters: {regul_C}')


# Training the classifiers

In [None]:
# Per-neuron classifiers: for every session key and label set ('slowing' / 'error'),
# logistic regressions are trained to separate 'pes' trials (label 1) from 'npes'
# trials (label 0) using the neuron's time series inside the analysis window.
#
# For each neuron, n_models independent models are evaluated:
#   1. stratified split into (train + dev) and test (test fraction r_test),
#   2. stratified split of the remainder into train and dev (dev fraction r_dev),
#   3. the C from regul_C with the best F1 score on dev is selected,
#   4. a model with that C is refitted on train + dev and scored (F1) on test.
# perf_all[label][session key][neuron] holds the list of n_models test F1 scores.

# One seeded generator drives every split, so a fresh "Restart & Run All" reproduces
# the same scores. It is advanced by each call, so the n_models splits still differ.
split_seed = 0
split_rng  = np.random.RandomState(split_seed)

perf_all = {'slowing' : {}, 'error' : {}}

all_eid_contr_side = sorted(matched_trials_mod.keys())

for ind_curr_eid_contr_side, curr_eid_contr_side in enumerate(all_eid_contr_side):

    print('Done with', np.around(ind_curr_eid_contr_side / len(all_eid_contr_side) * 100, 1), '% of the data.')

    # First entry of every trial record of this session
    # (presumably the reaction time, judging by the variable name - TODO confirm).
    curr_trials = trials_all[curr_eid_contr_side[0]][curr_eid_contr_side[1]]
    curr_rts    = [curr_trials[trial][0] for trial in sorted(curr_trials.keys())]

    # Analysis window as (first index, one-past-last index) into time_axis: from the bin
    # closest to time_eps up to and including the bin closest to the time_end_perc-th
    # percentile of curr_rts.
    curr_time_inds = (int(np.argmin(np.abs(time_axis - (0.0 + time_eps)))),
                      int(np.argmin(np.abs(time_axis - np.percentile(curr_rts, time_end_perc)))) + 1)

    for curr_repr in sorted(matched_trials_mod[curr_eid_contr_side].keys()):

        perf_all[curr_repr][curr_eid_contr_side] = {}

        # Looked up once per label set instead of once per trial.
        curr_fr        = fr_z[(curr_eid_contr_side[0], curr_eid_contr_side[1])]
        curr_pes_trls  = matched_trials_mod[curr_eid_contr_side][curr_repr]['pes']
        curr_npes_trls = matched_trials_mod[curr_eid_contr_side][curr_repr]['npes']

        for neuron in sorted(curr_fr):

            curr_pes_data  = [curr_fr[neuron][trial][curr_time_inds[0] : curr_time_inds[1]] for trial in curr_pes_trls]
            curr_npes_data = [curr_fr[neuron][trial][curr_time_inds[0] : curr_time_inds[1]] for trial in curr_npes_trls]

            curr_data = np.array(curr_pes_data + curr_npes_data)
            curr_labl = np.array([1] * len(curr_pes_data) + [0] * len(curr_npes_data))

            assert(len(curr_data) == len(curr_labl))

            curr_res = []

            for ind_mod in range(n_models):

                X_not_test, X_test, y_not_test, y_test = train_test_split(curr_data,  curr_labl,  test_size=r_test, stratify=curr_labl,  random_state=split_rng)
                X_train, X_dev, y_train, y_dev         = train_test_split(X_not_test, y_not_test, test_size=r_dev,  stratify=y_not_test, random_state=split_rng)

                # Development-set F1 score for every candidate regularization strength.
                res_dev = {}

                for curr_C in regul_C:

                    curr_mod = LogisticRegression(max_iter=100, penalty='l2', C=curr_C, fit_intercept=True)
                    curr_mod.fit(X_train, y_train)

                    res_dev[curr_C] = f1_score(y_dev, curr_mod.predict(X_dev))

                # On ties max() keeps the first (i.e. smallest) C in regul_C.
                best_C   = max(res_dev, key=res_dev.get)
                best_mod = LogisticRegression(max_iter=100, penalty='l2', C=best_C, fit_intercept=True)

                # The final model uses all non-test trials; the test set was never seen during selection.
                best_mod.fit(np.concatenate((X_train, X_dev)), np.concatenate((y_train, y_dev)))

                curr_res.append(f1_score(y_test, best_mod.predict(X_test)))

            perf_all[curr_repr][curr_eid_contr_side][neuron] = curr_res

# Store all test scores; replace the placeholder with the desired output file path.
with open('*** Update the folder path here to match your directory structure ***', 'wb') as f:
    pickle.dump(perf_all, f, pickle.HIGHEST_PROTOCOL)


# Selecting one neuron for each trial condition

In [None]:
# Group the session keys by their first element (named EID in later cells):
#   sess_cond['both'][EID]  - keys that have scores for 'slowing' and for 'error'
#   sess_cond['error'][EID] - keys with 'error' scores only, for EIDs without any 'both' key
sess_cond = {'error' : {}, 'both' : {}}

slowing_sess = set(perf_all['slowing'].keys())
error_sess   = set(perf_all['error'].keys())

all_sess_contr_side = sorted(slowing_sess | error_sess)

# First pass: keys present in both score dictionaries.
for sess_key in all_sess_contr_side:

    has_slowing = sess_key in slowing_sess
    has_error   = sess_key in error_sess

    # A key with 'slowing' scores is required to have 'error' scores as well.
    if has_slowing:
        assert(has_error)

    if has_slowing and has_error:
        sess_cond['both'].setdefault(sess_key[0], []).append(sess_key)

# Second pass: 'error'-only keys whose EID was not registered under 'both'.
for sess_key in all_sess_contr_side:

    if sess_key[0] in sess_cond['both'].keys():
        continue

    has_slowing = sess_key in slowing_sess
    has_error   = sess_key in error_sess

    assert(not(has_slowing and has_error))

    if has_error:

        assert(not has_slowing)

        sess_cond['error'].setdefault(sess_key[0], []).append(sess_key)


In [None]:
# One mean test score per neuron and label set; this cell fills the 'error' part
# for the EIDs that only have 'error' scores.
perf_neurons = {'slowing' : {}, 'error' : {}}

for EID, sess_keys in sess_cond['error'].items():

    first_sess_perf = perf_all['error'][sess_keys[0]]

    # A single session key: its mean score is used directly.
    if len(sess_keys) == 1:

        for neuron, scores in first_sess_perf.items():
            perf_neurons['error'][neuron] = np.mean(scores)

        continue

    # Several session keys: keep, for every neuron, the highest mean score among them.
    best_means = {neuron : np.mean(scores) for neuron, scores in first_sess_perf.items()}

    for sess_key in sess_keys:

        sess_perf = perf_all['error'][sess_key]

        # Each session key has to share at least one neuron with the ones collected so far.
        assert(len(set(best_means.keys()) & set(sess_perf.keys())) > 0)

        for neuron in sorted(sess_perf.keys()):

            cand_mean = np.mean(sess_perf[neuron])

            if (neuron not in best_means) or (best_means[neuron] < cand_mean):
                best_means[neuron] = cand_mean

    perf_neurons['error'].update(best_means)


In [None]:
# Fill perf_neurons for the EIDs that have both 'slowing' and 'error' scores.
# With two session keys, each neuron takes both of its scores from the session key
# whose better mean score (over the two label sets) is highest.
for EID, sess_keys in sess_cond['both'].items():

    if len(sess_keys) == 1:

        only_sess    = sess_keys[0]
        curr_neurons = sorted(perf_all['slowing'][only_sess].keys())

        # Both label sets have to cover exactly the same neurons.
        assert(curr_neurons == sorted(perf_all['error'][only_sess].keys()))

        for neuron in curr_neurons:
            for repr_name in ('slowing', 'error'):
                perf_neurons[repr_name][neuron] = np.mean(perf_all[repr_name][only_sess][neuron])

        continue

    assert(len(sess_keys) == 2)

    curr_neurons_slowing = set()
    curr_neurons_error   = set()

    for sess_key in sess_keys:
        curr_neurons_slowing.update(perf_all['slowing'][sess_key].keys())
        curr_neurons_error.update(perf_all['error'][sess_key].keys())

    assert(curr_neurons_slowing == curr_neurons_error)

    for neuron in sorted(curr_neurons_slowing):

        # Best mean score (over the two label sets) of this neuron per session key.
        curr_res = {}

        for sess_key in sess_keys:

            if neuron not in perf_all['slowing'][sess_key]:
                continue

            assert(neuron in perf_all['error'][sess_key])

            mean_slowing = np.mean(perf_all['slowing'][sess_key][neuron])
            mean_error   = np.mean(perf_all['error'][sess_key][neuron])

            curr_res[sess_key] = np.max([mean_slowing, mean_error])

        # Every neuron is required to have scores under both session keys.
        assert(len(curr_res) == 2)

        best_perf_sess = max(curr_res, key=curr_res.get)

        for repr_name in ('slowing', 'error'):
            perf_neurons[repr_name][neuron] = np.mean(perf_all[repr_name][best_perf_sess][neuron])

# Every neuron with a 'slowing' score must also have an 'error' score.
assert(set(perf_neurons['slowing'].keys()) <= set(perf_neurons['error'].keys()))


# Performance distribution

In [None]:
# Pool the per-neuron mean scores of each label set (neurons in sorted order).
perf_pooled = {}

for curr_repr in ['slowing', 'error']:
    perf_pooled[curr_repr] = [perf_neurons[curr_repr][neuron] for neuron in sorted(perf_neurons[curr_repr].keys())]


# Two histograms side by side with shared axes: 'slowing' on the left, 'error' on the right.
_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)

panel_xlabels = {'slowing' : 'Performance for predicting \n"slowing"',
                 'error'   : 'Performance for predicting \n(the previous) error'}

for curr_ax, curr_repr in zip(ax, ['slowing', 'error']):

    sns.histplot(perf_pooled[curr_repr], ax=curr_ax, binwidth=0.05, binrange=(0.0, 1.0), color='#5D7FA3')

    curr_ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    curr_ax.set_xlabel(panel_xlabels[curr_repr])
    curr_ax.set_ylabel('Number of neurons')


# Dashed threshold line in both panels, drawn up to the current upper y limit.
curr_plot_y_max = ax[1].get_ylim()[1]

for curr_ax in (ax[1], ax[0]):
    curr_ax.vlines(perf_thres, 0, curr_plot_y_max, linestyle='--', color='#FF80AB')

# Shared y axes hide the tick labels of the right panel by default; show them again.
ax[1].tick_params(axis='y', labelleft=True)

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


# Location of neurons with high performance

In [None]:
# Count, per region, how many analysed neurons exceed perf_thres for 'slowing' only,
# for 'error' only, or for both; 'all' counts every analysed neuron of the region.
regions_all = {}

all_neurons = sorted(set(perf_neurons['slowing'].keys()) | set(perf_neurons['error'].keys()))

for neuron in all_neurons:

    # Neurons without a region entry are skipped.
    if neuron not in neuron_region.keys():
        continue

    curr_region = neuron_region[neuron]
    curr_counts = regions_all.setdefault(curr_region, {'slowing' : 0, 'error' : 0, 'both' : 0, 'all' : 0})

    curr_counts['all'] += 1

    in_slowing = neuron in perf_neurons['slowing']
    in_error   = neuron in perf_neurons['error']

    if in_slowing and in_error:

        slowing_above = perf_neurons['slowing'][neuron] > perf_thres
        error_above   = perf_neurons['error'][neuron] > perf_thres

        if slowing_above and error_above:
            curr_counts['both'] += 1

        else:
            if slowing_above:
                curr_counts['slowing'] += 1

            if error_above:
                curr_counts['error'] += 1

    elif in_error:

        assert(not in_slowing)

        if perf_neurons['error'][neuron] > perf_thres:
            curr_counts['error'] += 1


# Drop regions without any neuron above threshold and regions with fewer than 30 analysed neurons.
not_involved_regions = []

for region in sorted(regions_all.keys()):

    curr_counts = regions_all[region]
    no_neuron_above = (curr_counts['slowing'] == 0) and (curr_counts['error'] == 0) and (curr_counts['both'] == 0)

    if no_neuron_above or (curr_counts['all'] < 30):
        not_involved_regions.append(region)

for region in not_involved_regions:
    del regions_all[region]


# Convert the counts into ratios and sort every remaining region into one plotting group:
# 'both' if it has any 'both' neuron, otherwise 'slowing' if it has any 'slowing' neuron,
# otherwise 'error'.
sort_plot_both    = []
sort_plot_slowing = []
sort_plot_error   = []

ratios_all = {}

for region in sorted(regions_all.keys()):

    curr_counts = regions_all[region]
    n_all       = curr_counts['all']
    n_remain    = n_all - curr_counts['slowing'] - curr_counts['error'] - curr_counts['both']

    ratios_all[region] = {'slowing' : curr_counts['slowing'] / n_all,
                          'error'   : curr_counts['error']   / n_all,
                          'both'    : curr_counts['both']    / n_all,
                          'remain'  : n_remain / n_all}

    for ratio_name in ('slowing', 'error', 'both', 'remain'):
        assert(ratios_all[region][ratio_name] <= 1.0)

    if curr_counts['both'] > 0:
        sort_plot_both.append(region)

    elif curr_counts['slowing'] > 0:
        assert(ratios_all[region]['both'] == 0.0)
        sort_plot_slowing.append(region)

    else:
        assert(ratios_all[region]['both']    == 0.0)
        assert(ratios_all[region]['slowing'] == 0.0)
        assert(ratios_all[region]['error']    > 0.0)
        sort_plot_error.append(region)


# Within every group, order the regions by the ratio that defines the group (ascending).
sort_plot_both.sort(key=lambda k : ratios_all[k]['both'])
sort_plot_slowing.sort(key=lambda k : ratios_all[k]['slowing'])
sort_plot_error.sort(key=lambda k : ratios_all[k]['error'])

sort_plot_all = sort_plot_error + sort_plot_slowing + sort_plot_both

# Every remaining region has to appear exactly once in the plotting order.
assert(len(set(sort_plot_all)) == len(sort_plot_all))
assert(len(sort_plot_all)      == len(ratios_all.keys()))


In [None]:
# Colour of every bar segment.
region_cols = {'both'    : '#FFB07C',
               'slowing' : '#AAF0D1',
               'error'   : '#7FB3D5',
               'remain'  : '#D3D3D3'}

# Segments in stacking order (left to right) with their legend labels.
bar_segments = (('both',    'Both'),
                ('slowing', 'Slowing'),
                ('error',   'Error'),
                ('remain',  'Not involved'))

fig, ax = plt.subplots(1, 1, figsize=(6, 12))

# One stacked horizontal bar per region, in the order given by sort_plot_all.
for ind_region, region in enumerate(sort_plot_all):

    left_edge = 0.0

    for seg_key, seg_label in bar_segments:

        seg_width = ratios_all[region][seg_key]

        # Only the first bar carries legend labels, so every entry appears once in the legend.
        ax.barh(ind_region, seg_width, left=left_edge, color=region_cols[seg_key], label=seg_label if ind_region == 0 else '')

        left_edge = left_edge + seg_width

    # Number of analysed neurons of the region, written to the right of its bar.
    ax.text(1.01, ind_region, str(regions_all[region]['all']), va='center', ha='left', fontsize=13)

ax.set_yticks(range(len(sort_plot_all)))
ax.set_yticklabels(sort_plot_all)
ax.set_ylim(-0.5, len(sort_plot_all) - 0.5)
ax.set_xlim(0.0, 1.15)
ax.grid(axis='y', visible=False)

ax.set_xlabel('Ratio of analyzed neurons')

ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.93))

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', bbox_inches='tight', dpi=300)
plt.show()
