<a href="https://colab.research.google.com/github/gergogomori/brain_wide_pes/blob/main/pes_demixing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install ONE-api
!pip install iblatlas
!pip install ibllib


In [None]:
import numpy as np
import matplotlib.pyplot as plt

import pickle
import shutil
import random

from sklearn.linear_model import LinearRegression
from scipy.stats import permutation_test
from statsmodels.stats.multitest import multipletests
from matplotlib.ticker import MaxNLocator

from one.api import ONE
# Connect to the International Brain Laboratory OpenAlyx server (credentials are passed inline).
one = ONE(base_url='https://openalyx.internationalbrainlab.org', password='international', silent=True)

# Sessions that have spike times available; only their number is reported here.
eids = one.search(datasets='spikes.times.npy')
print(len(eids))

from iblatlas.atlas import AllenAtlas
from iblatlas.plots import plot_swanson_vector
# Atlas and its region table; the region table is used further down to map cluster
# atlas ids to Swanson acronyms.
brain_atlas   = AllenAtlas()
brain_regions = brain_atlas.regions

# Loader used further down to fetch spike sorting / cluster info per probe insertion (PID).
from brainbox.io.one import SpikeSortingLoader

# Silence ALF warnings emitted while loading datasets through ONE.
import warnings
import one.alf.exceptions as alferr
warnings.filterwarnings("ignore", category=alferr.ALFWarning)

import uuid

import pandas  as pd
import seaborn as sns

# Presumably the stimulus contrast levels. NOTE(review): not referenced in the cells shown here.
contrasts = np.array([0.0625, 0.125, 0.25, 1.0])

# Unseeded generator: anything drawn from it differs between runs.
# NOTE(review): not referenced in the cells shown here.
rng = np.random.default_rng()


In [None]:
import matplotlib as mpl

# Default font sizes for axis labels, tick labels and legends of every figure below.
mpl.rcParams['axes.labelsize']  = 15
mpl.rcParams['xtick.labelsize'] = 14
mpl.rcParams['ytick.labelsize'] = 14
mpl.rcParams['legend.fontsize'] = 14


# Loading external data

In [None]:
# Load the precomputed inputs (replace both placeholder paths before running).
# NOTE: pickle.load can execute arbitrary code - only load files you produced or trust.
#
# As used in the cells below:
#   trials_all[eid][contrast][trial] -> per-trial record; field [0] is the reaction time (ms),
#       [1] the side (-1 left / 1 right), [3] the previous outcome (-1 post-correct / 1 post-error).
#   fr_all[(eid, contrast)][neuron][trial] -> firing-rate array with one value per time_axis bin;
#       neuron keys are indexable as (PID, cluster id).
with open('*** Update the folder path here to match your directory structure ***', 'rb') as f:
    trials_all = pickle.load(f)

with open('*** Update the folder path here to match your directory structure ***', 'rb') as f:
    fr_all = pickle.load(f)


In [None]:
# IMPORTED PARAMETERS
# time_axis must have exactly one entry per bin of the fr_all arrays (checked in the regression cell).

# Minimum number of trials required in each (side, previous outcome) cell of a session
min_n_trials_neuro = 10

t_lims     = (-350.0, 350.0) # ms
sampl_rate = 10.0 # ms

# Centres of the sampl_rate-wide bins tiling t_lims. Computing the bin count explicitly
# (rather than padding the arange stop by a fixed 1 ms) stays correct for any bin width.
n_time_bins = int(round((t_lims[1] - t_lims[0]) / sampl_rate))
assert(np.isclose(n_time_bins * sampl_rate, t_lims[1] - t_lims[0]))

time_axis  = t_lims[0] + sampl_rate * (np.arange(n_time_bins) + 0.5)

print(time_axis)


In [None]:
# NEW PARAMETERS

# Minimum number of slow post-error trials on one side for either test to be run
min_n_trials    = 8
# RT percentile (per session/contrast) that closes the window searched for the peak-weight time bin
time_end_perc   = 5
# Percentile of post-correct RTs separating slow from non-slow post-error trials
pes_rt_perc     = 55
# Neurons whose firing-rate std (over all trials and time bins) does not exceed this are skipped
min_fr_std      = 0.1
# Minimum number of non-slow post-error trials on one side for the error test
min_pens_trials = 5
# The peak-weight search window starts at the time bin closest to this time
time_eps        = 5.0 # ms
# NOTE(review): time_perc_plot and perc_coeff are not used in the cells shown here
time_perc_plot  = 80
perc_coeff      = 80
# Significance level
signif_lvl      = 0.05
# Number of leading eid characters shown in the summary table
n_char_eid      = 4


# Categorizing trials within sessions

In [None]:
def _match_post_correct(pes_trials, pcorr_trials, sess_trials):
    """Return post-correct trials RT-matched one-to-one to the slow post-error trials.

    For each slow post-error trial (in the given order), the still-unused post-correct
    trial with the closest reaction time is taken. This is greedy matching without
    replacement; ties go to the earliest candidate. The matched trial ids are returned sorted.

    pes_trials   : slow post-error trial ids
    pcorr_trials : candidate post-correct trial ids (at least as many as pes_trials)
    sess_trials  : trials_all[eid][contrast]; field [0] of a record is the reaction time
    """
    candidates = [(trial, sess_trials[trial][0]) for trial in pcorr_trials]
    matched    = []

    for pes_trial in pes_trials:
        target_rt = sess_trials[pes_trial][0]
        best      = int(np.argmin(np.abs(target_rt - np.array([rt for _, rt in candidates]))))
        matched.append(candidates.pop(best)[0])

    return sorted(matched)


matched_trials = {}

for curr_eid_contr in sorted(fr_all.keys()):

    act_EID, act_contr = curr_eid_contr

    sess_trials = trials_all[act_EID][act_contr]
    act_trials  = sorted(sess_trials.keys())

    # Slow / non-slow threshold: the pes_rt_perc-th percentile of the post-correct RTs
    curr_RTs_pcorr = [sess_trials[trial][0] for trial in act_trials if sess_trials[trial][3] == -1]
    curr_pes_thres = np.percentile(curr_RTs_pcorr, pes_rt_perc)

    # Per side: slow post-error ('pes'), non-slow post-error ('npes') and post-correct ('pcorr') trials
    curr_groups = {side : {'pes' : [], 'npes' : [], 'pcorr' : []} for side in ('left', 'right')}

    for trial in act_trials:

        rt, side_code, prev_outcome = sess_trials[trial][0], sess_trials[trial][1], sess_trials[trial][3]

        if (prev_outcome == 1) and (rt >= curr_pes_thres):
            group = 'pes'
        elif (prev_outcome == 1) and (rt < curr_pes_thres):
            group = 'npes'
        elif prev_outcome == -1:
            group = 'pcorr'
        else:
            # Fits none of the categories (e.g. an RT that does not compare with the threshold)
            continue

        if side_code == -1:
            curr_groups['left'][group].append(trial)
        else:
            assert(side_code == 1)
            curr_groups['right'][group].append(trial)

    # Each trial must land in at most one of the six lists. Comparing the union size with the
    # summed lengths rules out both duplicates within a list and overlaps between any two lists.
    all_lists = [trials for side_groups in curr_groups.values() for trials in side_groups.values()]
    assert(len(set().union(*all_lists)) == sum(len(trials) for trials in all_lists))

    # Error test: slow vs non-slow post-error trials
    for side in ('left', 'right'):

        pes, npes = curr_groups[side]['pes'], curr_groups[side]['npes']

        if (len(pes) >= min_n_trials) and (len(npes) >= min_pens_trials):
            matched_trials[(act_EID, act_contr, side)] = {'error' : {'pes'  : pes,
                                                                     'npes' : npes}}

    # Choice test: slow post-error vs RT-matched post-correct trials
    for side in ('left', 'right'):

        pes, pcorr = curr_groups[side]['pes'], curr_groups[side]['pcorr']

        if (len(pes) >= min_n_trials) and (len(pcorr) >= len(pes)):

            matched_npes = _match_post_correct(pes, pcorr, sess_trials)

            assert(len(set(pes) & set(matched_npes)) == 0)
            assert(len(pes) == len(matched_npes))

            matched_trials.setdefault((act_EID, act_contr, side), {})['choice'] = {'pes'  : pes,
                                                                                  'npes' : matched_npes}

    # When both tests exist they share the slow post-error trials and use disjoint controls
    for side in ('left', 'right'):

        curr_tests = matched_trials.get((act_EID, act_contr, side), {})

        if ('error' in curr_tests) and ('choice' in curr_tests):
            assert(curr_tests['error']['pes'] == curr_tests['choice']['pes'])
            assert(len(set(curr_tests['error']['npes']) & set(curr_tests['choice']['npes'])) == 0)

n_error_tests  = sum('error'  in tests for tests in matched_trials.values())
n_choice_tests = sum('choice' in tests for tests in matched_trials.values())

print(f"Total number of ERROR tests: {n_error_tests}")
print(f"Total number of CHOICE tests: {n_choice_tests}")

with open('*** Update the folder path here to match your directory structure ***', 'wb') as f:
    pickle.dump(matched_trials, f, pickle.HIGHEST_PROTOCOL)


In [None]:
# Verification: spot-check the trial sets of one randomly drawn triplet

def _checked_rts(eid_contr_side, trial_ids, expected_outcome):
    """Check side and previous outcome of each trial and return the rounded RTs.

    eid_contr_side   : (eid, contrast, 'left' | 'right') key of matched_trials
    trial_ids        : trials to check
    expected_outcome : required value of record field [3] (1 post-error, -1 post-correct)
    """
    eid, contr, side = eid_contr_side

    rts = []

    for trial in trial_ids:

        record = trials_all[eid][contr][trial]

        if side == 'left':
            assert(record[1] == -1)
        else:
            assert(side == 'right')
            assert(record[1] == 1)

        assert(record[3] == expected_outcome)

        rts.append(int(np.round(record[0])))

    return rts


rand_eid_contr_side = random.choice(list(matched_trials.keys()))
rand_tests          = matched_trials[rand_eid_contr_side]

if 'error' in rand_tests:

    rand_pes_rts  = _checked_rts(rand_eid_contr_side, rand_tests['error']['pes'],  1)
    rand_npes_rts = _checked_rts(rand_eid_contr_side, rand_tests['error']['npes'], 1)

    # The slow post-error set must be slower on average than the non-slow set
    assert(np.mean(rand_pes_rts) > np.mean(rand_npes_rts))

    print('Error trial sets are correct.')


if 'choice' in rand_tests:

    rand_pes_rts  = _checked_rts(rand_eid_contr_side, rand_tests['choice']['pes'],   1)
    rand_npes_rts = _checked_rts(rand_eid_contr_side, rand_tests['choice']['npes'], -1)

    print(sorted(rand_pes_rts),  'ms')
    print(sorted(rand_npes_rts), 'ms')


In [None]:
# Distinct (eid, contrast, side) triplets entering at least one test
all_triplets = list(set(matched_trials.keys()))

print(f'Total number of triplets: {len(all_triplets)}')


In [None]:
# Per triplet: the 'pes' / 'npes' trial lists of each test it entered ('error' and/or 'choice')
overview_triplets = {}

for sess_contr_side in sorted(matched_trials.keys()):

    for test_case in sorted(matched_trials[sess_contr_side].keys()):

        overview_triplets.setdefault(sess_contr_side, {})[test_case] = {
            'pes'  : matched_trials[sess_contr_side][test_case]['pes'],
            'npes' : matched_trials[sess_contr_side][test_case]['npes']}

assert(len(list(overview_triplets.keys())) == len(list(set(overview_triplets.keys()))))

# Trial counts per triplet for the summary table; a test that was not run is marked 'insufficient'
table_triplets = {sess_contr_side : {} for sess_contr_side in sorted(overview_triplets.keys())}

for sess_contr_side in sorted(overview_triplets.keys()):

    curr_tests = overview_triplets[sess_contr_side]

    if ('error' in curr_tests) and ('choice' in curr_tests):

        assert(curr_tests['error']['pes'] == curr_tests['choice']['pes'])

        table_triplets[sess_contr_side]['pes']   = len(curr_tests['error']['pes'])
        table_triplets[sess_contr_side]['nslow'] = len(curr_tests['error']['npes'])
        table_triplets[sess_contr_side]['corr']  = len(curr_tests['choice']['npes'])

    elif 'choice' in curr_tests:
        table_triplets[sess_contr_side]['pes']   = len(curr_tests['choice']['pes'])
        table_triplets[sess_contr_side]['nslow'] = 'insufficient'
        table_triplets[sess_contr_side]['corr']  = len(curr_tests['choice']['npes'])

    else:
        # Only the error test ran: there were fewer post-correct than slow post-error trials,
        # so no RT matching was possible. The row must still be filled; an empty entry would
        # make the sort below fail with KeyError on 'pes'.
        print('There is triplet with only error test!')

        table_triplets[sess_contr_side]['pes']   = len(curr_tests['error']['pes'])
        table_triplets[sess_contr_side]['nslow'] = len(curr_tests['error']['npes'])
        table_triplets[sess_contr_side]['corr']  = 'insufficient'

assert(len(list(overview_triplets.keys())) == len(list(table_triplets.keys())))

# Rows ordered by decreasing number of slow post-error trials
triplets_sorted = sorted(table_triplets, key=lambda k : table_triplets[k]['pes'], reverse=True)


In [None]:
# Summary table of trial counts: one row per triplet ("eid prefix, contrast, side"), in the
# order of triplets_sorted (decreasing number of slow post-error trials).
df_table_trials = pd.DataFrame({
    'Triplet'      : [str(str(sess_contr_side[0])[ : n_char_eid] + ', ' + str(sess_contr_side[1]) + ', ' + sess_contr_side[2])  for sess_contr_side in triplets_sorted],
    'Slow post-error'     : [table_triplets[sess_contr_side]['pes']   for sess_contr_side in triplets_sorted],
    'Non-slow post-error' : [table_triplets[sess_contr_side]['nslow'] for sess_contr_side in triplets_sorted],
    'Slow post-correct'   : [table_triplets[sess_contr_side]['corr']  for sess_contr_side in triplets_sorted]})


# Render the DataFrame as a matplotlib table on an axis-free figure
plt.table(cellText=df_table_trials.values, colLabels=df_table_trials.columns, cellLoc='center', loc='center')

plt.axis('off')

# Replace the placeholder with the output file path before running
plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', bbox_inches='tight', dpi=300)
plt.show()


# Calculating coefficients for all neurons

In [None]:
# Per (eid, contrast): regression coefficients over time for every retained neuron
# (coeff_axes[eid_contr][feature][neuron], feature in 'choice' / 'error' / 'unmod' = intercept)
# and z-scored firing rates (fr_z[eid_contr][neuron][trial]).
coeff_axes = {}
fr_z       = {}

# Unique (eid, contrast) pairs that entered at least one test, in sorted order
eids_contrs = list(dict.fromkeys((key[0], key[1]) for key in sorted(matched_trials.keys())))

for ind_eid_contr, curr_eid_contr in enumerate(eids_contrs):

    print('Done with', np.around(ind_eid_contr / len(eids_contrs) * 100, 1), '% of the data.')

    EID, contr  = curr_eid_contr
    curr_trials = sorted(trials_all[EID][contr].keys())

    # Regressors, one row per trial: [side (-1 left / 1 right), previous outcome (-1 correct / 1 error)]
    curr_data = [[trials_all[EID][contr][trial][1], trials_all[EID][contr][trial][3]] for trial in curr_trials]

    # Every side x previous-outcome cell needs at least min_n_trials_neuro trials
    for side_code in (-1, 1):
        for outcome_code in (-1, 1):
            n_cell = sum(1 for side, outcome in curr_data if side == side_code and outcome == outcome_code)
            assert(n_cell >= min_n_trials_neuro)

    curr_X = np.array(curr_data)

    for neuron in sorted(fr_all[curr_eid_contr].keys()):

        assert(sorted(fr_all[curr_eid_contr][neuron].keys()) == curr_trials)

        # Firing rates as a (trials x time bins) matrix
        curr_Y = np.array([fr_all[curr_eid_contr][neuron][trial] for trial in curr_trials])

        # Skip neurons with too little variability. This is written as `not (... > ...)` so a
        # NaN std is skipped as well.
        if not (np.std(curr_Y) > min_fr_std):
            continue

        coeff_axes.setdefault(curr_eid_contr, {})
        fr_z.setdefault(curr_eid_contr, {})

        # z-score with a single mean / std over all trials and time bins of this neuron
        curr_Y = (curr_Y - np.mean(curr_Y)) / np.std(curr_Y)

        assert(curr_Y.shape == (len(curr_trials), len(time_axis)))

        # One multi-output fit: each time bin (column of curr_Y) gets its own independent
        # least-squares solution, replacing len(time_axis) separate single-target fits.
        curr_res        = LinearRegression(fit_intercept=True).fit(curr_X, curr_Y)
        curr_coeff_axes = np.column_stack((curr_res.coef_, curr_res.intercept_))

        assert(curr_coeff_axes.shape == (len(time_axis), 3))

        fr_z[curr_eid_contr][neuron] = {trial : curr_Y[ind_trial] for ind_trial, trial in enumerate(curr_trials)}

        for ind_feature, feature in enumerate(['choice', 'error', 'unmod']):
            coeff_axes[curr_eid_contr].setdefault(feature, {})[neuron] = curr_coeff_axes[ : , ind_feature]

    for curr_repr in ['choice', 'error', 'unmod']:

        assert(sorted(coeff_axes[curr_eid_contr][curr_repr].keys()) == sorted(fr_z[curr_eid_contr].keys()))

assert(sorted(coeff_axes.keys()) == sorted(fr_z.keys()))

with open('*** Update the folder path here to match your directory structure ***', 'wb') as f:
    pickle.dump(fr_z, f, pickle.HIGHEST_PROTOCOL)


# Brain regions of each neuron

In [None]:
# Swanson acronyms whose neurons are left out of neuron_region
excluded_regions = {'root', 'P', 'HPF', 'MY', 'void', 'PO', 'HY', 'POST', 'CTXsp', 'LA'}

# Subregion acronyms pooled under a common label
merged_regions = {'VPL'    : 'VP',  'VPLpc' : 'VP',  'VPM'    : 'VP',  'VPMpc' : 'VP',
                  'SCdg'   : 'SCm', 'SCdw'  : 'SCm', 'SCiw'   : 'SCm', 'SCig'  : 'SCm',
                  'CUL4'   : 'CUL', 'CUL5'  : 'CUL', 'CUL4 5' : 'CUL',
                  'LAV'    : 'VNC', 'MV'    : 'VNC', 'SPIV'   : 'VNC', 'SUV'   : 'VNC',
                  'RSPagl' : 'RSP', 'RSPd'  : 'RSP', 'RSPv'   : 'RSP'}

# Neurons to localise, grouped as {eid: {PID: set of cluster ids}}. The neurons are pooled over
# all contrasts of a session because the firing-rate std filter is applied per (eid, contrast):
# a neuron can be kept for one contrast and dropped for another.
eid_pids_neurons = {}

for curr_eid_contr in sorted(coeff_axes.keys()):

    assert(sorted(coeff_axes[curr_eid_contr]['error'].keys()) ==
           sorted(coeff_axes[curr_eid_contr]['choice'].keys()))

    curr_pids_neurons = eid_pids_neurons.setdefault(curr_eid_contr[0], {})

    for neuron in coeff_axes[curr_eid_contr]['error'].keys():
        curr_pids_neurons.setdefault(neuron[0], set()).add(neuron[1])

neuron_region = {}
extracted_eid = []

for ind_eid, curr_eid in enumerate(sorted(eid_pids_neurons.keys())):

    print('Done with', np.around(ind_eid / len(eid_pids_neurons) * 100, 1), '% of the data.')

    curr_pids_neurons = eid_pids_neurons[curr_eid]

    for PID in sorted(curr_pids_neurons.keys()):

        sl = SpikeSortingLoader(pid=PID, one=one, atlas=brain_atlas)
        spikes, clusters, channels = sl.load_spike_sorting()
        clusters = sl.merge_clusters(spikes, clusters, channels)

        for clust_num in sorted(curr_pids_neurons[PID]):

            curr_region = brain_regions.id2acronym(clusters['atlas_id'][np.where(clusters['cluster_id'] == clust_num)[0][0]], mapping='Swanson')[0]

            if curr_region not in excluded_regions:
                neuron_region[(PID, clust_num)] = merged_regions.get(curr_region, curr_region)

    extracted_eid.append(curr_eid)

    # Free disk space by deleting this session's lab folder under the ONE download path. The lab
    # is looked up from curr_eid (the session just processed), not from a variable left over
    # from an earlier cell.
    shutil.rmtree('/root/Downloads/ONE/openalyx.internationalbrainlab.org/' + one.get_details(curr_eid)['lab'], ignore_errors=True)

with open('*** Update the folder path here to match your directory structure ***', 'wb') as f:
    pickle.dump(neuron_region, f, pickle.HIGHEST_PROTOCOL)


# Selecting the weights for each neuron

In [None]:
# For each (eid, contrast) and test ('error' / 'choice'):
#   t_max_coeffs -> index of the time bin where the population weight vector has the largest norm,
#   w_max_coeffs -> every neuron's coefficient at that bin.
t_max_coeffs = {}
w_max_coeffs = {}

for curr_eid_contr in sorted(coeff_axes):

    sess_trials = trials_all[curr_eid_contr[0]][curr_eid_contr[1]]
    curr_rts    = [sess_trials[trial][0] for trial in sorted(sess_trials)]

    # Search window: from the bin closest to time_eps up to and including the bin closest to
    # the time_end_perc-th RT percentile
    win_start = int(np.argmin(np.abs(time_axis - (0.0 + time_eps))))
    win_stop  = int(np.argmin(np.abs(time_axis - np.percentile(curr_rts, time_end_perc)))) + 1

    for curr_repr in ['error', 'choice']:

        repr_coeffs = coeff_axes[curr_eid_contr][curr_repr]
        neurons     = sorted(repr_coeffs)

        curr_weights = np.array([repr_coeffs[neuron][win_start : win_stop] for neuron in neurons])

        assert(curr_weights.shape == (len(neurons), win_stop - win_start))

        global_ind = win_start + int(np.argmax(np.linalg.norm(curr_weights, axis=0)))

        assert(time_axis[global_ind] > 0.0)
        assert(time_axis[global_ind] < time_axis[-1])

        t_max_coeffs.setdefault(curr_eid_contr, {})[curr_repr] = global_ind
        w_max_coeffs.setdefault(curr_eid_contr, {})[curr_repr] = {}

        for neuron in neurons:

            assert(len(repr_coeffs[neuron]) == len(time_axis))

            w_max_coeffs[curr_eid_contr][curr_repr][neuron] = repr_coeffs[neuron][global_ind]

assert(sorted(t_max_coeffs) == sorted(w_max_coeffs))



# Computing the projections

In [None]:
# proj_all[test][triplet][trial]: z-scored population activity of the trial, projected over time
# onto the session's peak-weight vector for that test (normalised to unit length).
proj_all = {}

for curr_eid_contr_side in sorted(matched_trials):

    sess_key  = curr_eid_contr_side[:2]
    sess_fr_z = fr_z[sess_key]

    for curr_repr in sorted(matched_trials[curr_eid_contr_side]):

        sess_w = w_max_coeffs[sess_key][curr_repr]

        curr_coeffs    = [sess_w[neuron] for neuron in sorted(sess_w)]
        curr_max_w_len = np.linalg.norm(curr_coeffs)

        fr_neurons  = sorted(sess_fr_z)
        test_trials = matched_trials[curr_eid_contr_side][curr_repr]

        for trial in sorted(test_trials['pes'] + test_trials['npes']):

            weighted_rows = []

            for neuron in fr_neurons:
                weighted_rows.append(sess_w[neuron] * sess_fr_z[neuron][trial])
                assert(len(weighted_rows[-1]) == len(time_axis))

            weighted_rows = np.array(weighted_rows)

            assert(weighted_rows.shape == (len(curr_coeffs), len(time_axis)))

            curr_dot_prod_res = np.sum(weighted_rows, axis=0) / curr_max_w_len

            assert(len(curr_dot_prod_res) == len(time_axis))

            proj_all.setdefault(curr_repr, {}).setdefault(curr_eid_contr_side, {})[trial] = curr_dot_prod_res


# Testing for significant differences between the projections

In [None]:
def perm_stat(x, y, axis):
    """Return the difference of means, mean(x) - mean(y), along ``axis``.

    Intended as the statistic for ``scipy.stats.permutation_test`` with
    ``vectorized=True``, which supplies the ``axis`` keyword.
    """
    mean_x = np.mean(x, axis=axis)
    mean_y = np.mean(y, axis=axis)
    return mean_x - mean_y


In [None]:
# For every representation and triplet, compare the projections at the t_max
# index between the 'pes' and 'npes' trial groups with a two-sided permutation
# test, then FDR-correct (Benjamini-Hochberg) the p-values per representation.
#   ind2sess[repr][i]            -> triplet tested at position i
#   proj_diff[repr][triplet]     -> mean(pes) - mean(npes) at t_max
#   pvals_all[repr]              -> raw p-values, in ind2sess order
#   pvals_corr[repr]             -> corrected p-values, same order
ind2sess   = {'choice' : {}, 'error' : {}}
proj_diff  = {'choice' : {}, 'error' : {}}
pvals_all  = {'choice' : [], 'error' : []}
pvals_corr = {}

for repr_name, proj_by_triplet in proj_all.items():

    for idx, triplet in enumerate(sorted(proj_by_triplet)):

        ind2sess[repr_name][idx] = triplet

        t_idx  = t_max_coeffs[(triplet[0], triplet[1])][repr_name]
        groups = matched_trials[triplet][repr_name]

        # Projection value at t_max for each trial of each group, in trial order.
        at_tmax = {}
        for group in ('pes', 'npes'):
            at_tmax[group] = [proj_by_triplet[triplet][tr][t_idx] for tr in sorted(groups[group])]

        test_res = permutation_test((at_tmax['pes'], at_tmax['npes']), perm_stat, vectorized=True, alternative='two-sided', n_resamples=1e6)
        pvals_all[repr_name].append(test_res.pvalue)

        proj_diff[repr_name][triplet] = np.mean(at_tmax['pes']) - np.mean(at_tmax['npes'])

    pvals_corr[repr_name] = multipletests(pvals_all[repr_name], alpha=0.05, method='fdr_bh')[1]


In [None]:
# Per-representation lists of projection differences (triplets in sorted order).
diff_all = {
    repr_name: [proj_diff[repr_name][triplet] for triplet in sorted(proj_diff[repr_name])]
    for repr_name in ('choice', 'error')
}

# Corrected p-value per triplet, mapped back through the index used in the tests.
pvals_corr_sess = {
    repr_name: {ind2sess[repr_name][idx]: pvals_corr[repr_name][idx] for idx in sorted(ind2sess[repr_name])}
    for repr_name in ('choice', 'error')
}

fig, axs = plt.subplots(2, 2, figsize=(10, 8))

# Row 0: "error"; row 1: "choice". Left: projection differences; right: raw p-values.
row_specs = (
    ('error',  'Difference in projection for "error"',  'p-values for "error"'),
    ('choice', 'Difference in projection for "choice"', 'p-values for "choice"'),
)

for row, (repr_name, diff_label, pval_label) in enumerate(row_specs):

    ax_diff = axs[row, 0]
    sns.histplot(diff_all[repr_name], bins='auto', ax=ax_diff, color='#78B8E6')
    ax_diff.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax_diff.set(xlabel=diff_label, ylabel='Number of triplets')
    ax_diff.set_xlim(-5.0, 5.0)

    ax_pval = axs[row, 1]
    sns.histplot(pvals_all[repr_name], binwidth=0.05, binrange=(0.0, 1.0), ax=ax_pval, color='#78B8E6')
    ax_pval.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax_pval.set(xlabel=pval_label, ylabel='Number of triplets')

fig.subplots_adjust(hspace=0.35)

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', bbox_inches='tight', dpi=300)
plt.show()


In [None]:
# Triplets ordered by ascending corrected p-value (most significant first).
sorted_error  = sorted(pvals_corr_sess['error'],  key=pvals_corr_sess['error'].__getitem__)
sorted_choice = sorted(pvals_corr_sess['choice'], key=pvals_corr_sess['choice'].__getitem__)

# Distinct session IDs appearing among the triplets.
curr_eids_error  = list({str(triplet[0]) for triplet in sorted_error})
curr_eids_choice = list({str(triplet[0]) for triplet in sorted_choice})

# Truncated IDs shown in the tables; they must still identify sessions uniquely.
curr_abbr_error  = [eid_str[:n_char_eid] for eid_str in curr_eids_error]
curr_abbr_choice = [eid_str[:n_char_eid] for eid_str in curr_eids_choice]

assert(len(curr_abbr_error)  == len(set(curr_abbr_error)))
assert(len(curr_abbr_choice) == len(set(curr_abbr_choice)))


In [None]:
# Table of all "error" triplets in ascending corrected p-value order, with the
# projection difference at t_max; rendered as a standalone figure.
_cols_error = {'Session' : [], 'Contrast' : [], 'Side' : [], 'Corr. p-val.' : [], 'Proj. diff.' : []}

for key in sorted_error:
    _cols_error['Session'].append(str(key[0])[:n_char_eid])
    _cols_error['Contrast'].append(key[1])
    _cols_error['Side'].append(key[2])
    _cols_error['Corr. p-val.'].append(np.round(pvals_corr_sess['error'][key], 4))
    _cols_error['Proj. diff.'].append(np.round(proj_diff['error'][key], 4))

df_table_error = pd.DataFrame(_cols_error)
del _cols_error

plt.table(cellText=df_table_error.values, colLabels=df_table_error.columns, cellLoc='center', loc='center')
plt.axis('off')

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', bbox_inches='tight', dpi=300)
plt.show()


In [None]:
# Table of all "choice" triplets in ascending corrected p-value order, with the
# projection difference at t_max; rendered as a standalone figure.
_cols_choice = {'Session' : [], 'Contrast' : [], 'Side' : [], 'Corr. p-val.' : [], 'Proj. diff.' : []}

for key in sorted_choice:
    _cols_choice['Session'].append(str(key[0])[:n_char_eid])
    _cols_choice['Contrast'].append(key[1])
    _cols_choice['Side'].append(key[2])
    _cols_choice['Corr. p-val.'].append(np.round(pvals_corr_sess['choice'][key], 4))
    _cols_choice['Proj. diff.'].append(np.round(proj_diff['choice'][key], 4))

df_table_choice = pd.DataFrame(_cols_choice)
del _cols_choice

plt.table(cellText=df_table_choice.values, colLabels=df_table_choice.columns, cellLoc='center', loc='center')
plt.axis('off')

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', bbox_inches='tight', dpi=300)
plt.show()


# Method figure

In [None]:
# Triplets whose corrected p-value falls below signif_lvl, in sorted key order.
relevant_error  = [key for key in sorted(pvals_corr_sess['error'])  if pvals_corr_sess['error'][key]  < signif_lvl]
relevant_choice = [key for key in sorted(pvals_corr_sess['choice']) if pvals_corr_sess['choice'][key] < signif_lvl]

# The figure cells below expect exactly one significant triplet per representation.
assert(len(relevant_error)  == 1)
assert(len(relevant_choice) == 1)

exampl_error, exampl_choice = relevant_error[0], relevant_choice[0]

print('Only triplet found for error:',  exampl_error)
print('Only triplet found for choice:', exampl_choice)


In [None]:
# Print session metadata for the two example triplets selected above.
# Session ID, contrast and stimulus side are read from the selected triplets
# (session, contrast, side) instead of being hard-coded, so the printout always
# describes the triplet that is actually plotted below.
# str() is used because the session IDs are stringified elsewhere before
# display (they may be UUID objects).
eid_ex_error = str(exampl_error[0])
info_ex_error = one.get_details(eid_ex_error)

print('Error example:')
print('EID:', eid_ex_error)
print('Contrast:', exampl_error[1])
print('Stimulus side:', exampl_error[2])
print('Lab:', info_ex_error['lab'])
print('Animal:', info_ex_error['subject'])
print('Starting time:', info_ex_error['start_time'])

eid_ex_choice = str(exampl_choice[0])
info_ex_choice = one.get_details(eid_ex_choice)

print()
print('Choice example:')
print('EID:', eid_ex_choice)
print('Contrast:', exampl_choice[1])
print('Stimulus side:', exampl_choice[2])
print('Lab:', info_ex_choice['lab'])
print('Animal:', info_ex_choice['subject'])
print('Starting time:', info_ex_choice['start_time'])


In [None]:
# Element 0 of every trial record for the error example's (session, contrast), in trial order.
_trials_ex_err = trials_all[exampl_error[0]][exampl_error[1]]
curr_rts_plot_error = [_trials_ex_err[tr][0] for tr in sorted(_trials_ex_err)]

# Plot window: from the time_axis bin nearest time_eps to the bin nearest the
# time_perc_plot-th percentile of those values (+1 so slicing includes that bin).
_t_start_err = int(np.argmin(np.abs(time_axis - (0.0 + time_eps))))
_t_stop_err  = int(np.argmin(np.abs(time_axis - np.percentile(curr_rts_plot_error, time_perc_plot)))) + 1
time_inds_plot_error = (_t_start_err, _t_stop_err)

del _trials_ex_err, _t_start_err, _t_stop_err


In [None]:
# Element 0 of every trial record for the choice example's (session, contrast), in trial order.
_trials_ex_ch = trials_all[exampl_choice[0]][exampl_choice[1]]
curr_rts_plot_choice = [_trials_ex_ch[tr][0] for tr in sorted(_trials_ex_ch)]

# Plot window: from the time_axis bin nearest time_eps to the bin nearest the
# time_perc_plot-th percentile of those values (+1 so slicing includes that bin).
_t_start_ch = int(np.argmin(np.abs(time_axis - (0.0 + time_eps))))
_t_stop_ch  = int(np.argmin(np.abs(time_axis - np.percentile(curr_rts_plot_choice, time_perc_plot)))) + 1
time_inds_plot_choice = (_t_start_ch, _t_stop_ch)

del _trials_ex_ch, _t_start_ch, _t_stop_ch


In [None]:
# Method figure: 2x2 panels built from the two example triplets selected above.
# Row 0 uses exampl_error with the "error" axis; row 1 uses exampl_choice with
# the "choice" axis. Left column: mean projection over the plotting window for
# the 'pes' and 'npes' groups of matched_trials. Right column: histogram of the
# per-trial projection at the t_max index stored in t_max_coeffs.
_, axs = plt.subplots(2, 2, figsize=(12, 12))



# --- Panel [0, 0]: mean "error" projection, restricted to time_inds_plot_error.
curr_proj_pes  = []
curr_proj_npes = []

for trial in matched_trials[exampl_error]['error']['pes']:
    curr_proj_pes.append(proj_all['error'][exampl_error][trial][time_inds_plot_error[0] : time_inds_plot_error[1]])

for trial in matched_trials[exampl_error]['error']['npes']:
    curr_proj_npes.append(proj_all['error'][exampl_error][trial][time_inds_plot_error[0] : time_inds_plot_error[1]])

axs[0, 0].plot(time_axis[time_inds_plot_error[0] : time_inds_plot_error[1]], np.mean(curr_proj_npes, axis=0), color='#FFB84D', label='Post-error non-slow')
axs[0, 0].plot(time_axis[time_inds_plot_error[0] : time_inds_plot_error[1]], np.mean(curr_proj_pes,  axis=0), color='#FF6F61', label='Post-error slow')

# Downward arrow marking t_max for the "error" axis. The start height is a
# multiple of the current (autoscaled) top y-limit and dy/head sizes are in data
# units. NOTE(review): these values look hand-tuned for this example session and
# assume a positive top y-limit; re-check them if the data change.
axs[0, 0].arrow(x=time_axis[t_max_coeffs[(exampl_error[0], exampl_error[1])]['error']],
                y=1.5 * axs[0, 0].get_ylim()[1],
                dx=0,
                dy=-0.8,
                head_width=8.0,
                head_length=0.5,
                fc='k',
                ec='k',
                linewidth=1.5)

axs[0, 0].set_xlabel('Time from Go cue (ms)')
axs[0, 0].set_ylabel('Average projection onto the "error" axis')
axs[0, 0].legend()



# --- Panel [0, 1]: per-trial "error" projection at its t_max index.
curr_proj_pes_tmax  = []
curr_proj_npes_tmax = []

for trial in matched_trials[exampl_error]['error']['pes']:
    curr_proj_pes_tmax.append(proj_all['error'][exampl_error][trial][t_max_coeffs[(exampl_error[0], exampl_error[1])]['error']])

for trial in matched_trials[exampl_error]['error']['npes']:
    curr_proj_npes_tmax.append(proj_all['error'][exampl_error][trial][t_max_coeffs[(exampl_error[0], exampl_error[1])]['error']])

# Long-format frame so seaborn can split the histogram by 'Trial type'.
df_pes_error  = pd.DataFrame({'Projection' : curr_proj_pes_tmax,  'Trial type' : 'Post-error slow'})
df_npes_error = pd.DataFrame({'Projection' : curr_proj_npes_tmax, 'Trial type' : 'Post-error non-slow'})
df_sum_error  = pd.concat([df_pes_error, df_npes_error], ignore_index=True)

# Same colours as panel [0, 0], so no separate legend is drawn here.
sns.histplot(df_sum_error, x='Projection', hue='Trial type', palette={'Post-error non-slow' : '#FFB84D', 'Post-error slow' : '#FF6F61'}, ax=axs[0, 1], kde=True, edgecolor=None, alpha=0.5, shrink=0.85, legend=False)

axs[0, 1].yaxis.set_major_locator(MaxNLocator(integer=True))
axs[0, 1].set_xlabel('Representation of the side at $t_{max, error}$')
axs[0, 1].set_ylabel('Number of trials')



# --- Panel [1, 0]: mean "choice" projection, restricted to time_inds_plot_choice.
curr_proj_pes  = []
curr_proj_npes = []

for trial in matched_trials[exampl_choice]['choice']['pes']:
    curr_proj_pes.append(proj_all['choice'][exampl_choice][trial][time_inds_plot_choice[0] : time_inds_plot_choice[1]])

for trial in matched_trials[exampl_choice]['choice']['npes']:
    curr_proj_npes.append(proj_all['choice'][exampl_choice][trial][time_inds_plot_choice[0] : time_inds_plot_choice[1]])

# For "choice", the 'npes' group is labelled post-correct slow.
axs[1, 0].plot(time_axis[time_inds_plot_choice[0] : time_inds_plot_choice[1]], np.mean(curr_proj_npes, axis=0), color='#90BE6D', label='Post-correct slow')
axs[1, 0].plot(time_axis[time_inds_plot_choice[0] : time_inds_plot_choice[1]], np.mean(curr_proj_pes,  axis=0), color='#FF6F61', label='Post-error slow')

# Downward arrow marking t_max for the "choice" axis. NOTE(review): geometry
# hand-tuned in data units as in panel [0, 0].
axs[1, 0].arrow(x=time_axis[t_max_coeffs[(exampl_choice[0], exampl_choice[1])]['choice']],
                y=1.8 * axs[1, 0].get_ylim()[1],
                dx=0,
                dy=-0.5,
                head_width=9.0,
                head_length=0.35,
                fc='k',
                ec='k',
                linewidth=1.5)

# Fixed lower y-limit, applied after the arrow is placed.
axs[1, 0].set_ylim(bottom=-6.0, top=None)
axs[1, 0].set_xlabel('Time from Go cue (ms)')
axs[1, 0].set_ylabel('Average projection onto the "choice" axis')
axs[1, 0].legend()



# --- Panel [1, 1]: per-trial "choice" projection at its t_max index.
curr_proj_pes_tmax  = []
curr_proj_npes_tmax = []

for trial in matched_trials[exampl_choice]['choice']['pes']:
    curr_proj_pes_tmax.append(proj_all['choice'][exampl_choice][trial][t_max_coeffs[(exampl_choice[0], exampl_choice[1])]['choice']])

for trial in matched_trials[exampl_choice]['choice']['npes']:
    curr_proj_npes_tmax.append(proj_all['choice'][exampl_choice][trial][t_max_coeffs[(exampl_choice[0], exampl_choice[1])]['choice']])

df_pes_choice  = pd.DataFrame({'Projection' : curr_proj_pes_tmax,  'Trial type' : 'Post-error slow'})
df_npes_choice = pd.DataFrame({'Projection' : curr_proj_npes_tmax, 'Trial type' : 'Post-correct slow'})
df_sum_choice  = pd.concat([df_pes_choice, df_npes_choice], ignore_index=True)

# Same colours as panel [1, 0], so no separate legend is drawn here.
sns.histplot(df_sum_choice, x='Projection', hue='Trial type', palette={'Post-correct slow' : '#90BE6D', 'Post-error slow' : '#FF6F61'}, ax=axs[1, 1], kde=True, edgecolor=None, alpha=0.5, shrink=0.85, legend=False)

axs[1, 1].yaxis.set_major_locator(MaxNLocator(integer=True))
axs[1, 1].set_xlabel('Representation of the side at $t_{max, choice}$')
axs[1, 1].set_ylabel('Number of trials')



plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


# Regions involved in post-error slowing

In [None]:
# Absolute w_max coefficients of every neuron for the two example triplets,
# ordered by neuron ID, and the perc_coeff-th percentile used as a threshold
# for selecting strongly weighted neurons.
coeffs_all = {
    'error'  : [np.abs(w_max_coeffs[(exampl_error[0], exampl_error[1])]['error'][neuron])
                for neuron in sorted(w_max_coeffs[(exampl_error[0], exampl_error[1])]['error'])],
    'choice' : [np.abs(w_max_coeffs[(exampl_choice[0], exampl_choice[1])]['choice'][neuron])
                for neuron in sorted(w_max_coeffs[(exampl_choice[0], exampl_choice[1])]['choice'])],
}

coeff_thres_error  = np.percentile(coeffs_all['error'],  perc_coeff)
coeff_thres_choice = np.percentile(coeffs_all['choice'], perc_coeff)

_, ax = plt.subplots(1, 2, figsize=(10, 4))

# One panel per representation: histogram plus a dashed line at the threshold.
panel_specs = (
    ('error',  coeff_thres_error,  'Absolute value of "error" coefficients'),
    ('choice', coeff_thres_choice, 'Absolute value of "choice" coefficients'),
)

for col, (repr_name, thres, xlabel) in enumerate(panel_specs):

    sns.histplot(coeffs_all[repr_name], ax=ax[col], bins='doane', color='#5D7FA3')
    ax[col].vlines(thres, 0, ax[col].get_ylim()[1], linestyle='--', color='#FF80AB', label=str(perc_coeff) + 'th percentile')

    ax[col].yaxis.set_major_locator(MaxNLocator(integer=True))
    ax[col].set_xlabel(xlabel)
    ax[col].set_ylabel('Number of neurons')

    # The threshold line carries a label; without legend() it was never shown.
    ax[col].legend()

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
def _regions_with_large_coeffs(coeffs_by_neuron, threshold):
    """Return the regions of neurons whose |coefficient| exceeds ``threshold``.

    Neurons are visited in sorted order, neurons absent from ``neuron_region``
    are skipped, and each region appears once in first-seen order.
    """
    hits = (neuron_region[nrn] for nrn in sorted(coeffs_by_neuron)
            if np.abs(coeffs_by_neuron[nrn]) > threshold and nrn in neuron_region)
    return list(dict.fromkeys(hits))


reg_swans_error  = _regions_with_large_coeffs(w_max_coeffs[(exampl_error[0], exampl_error[1])]['error'],    coeff_thres_error)
reg_swans_choice = _regions_with_large_coeffs(w_max_coeffs[(exampl_choice[0], exampl_choice[1])]['choice'], coeff_thres_choice)

reg_swans_all = list(set(reg_swans_error) | set(reg_swans_choice))

# Category code per region: 0.0 = "error" only, 1.0 = "choice" only, 2.0 = both.
_in_error, _in_choice = set(reg_swans_error), set(reg_swans_choice)
val_swans_all = np.array([
    2.0 if (region in _in_error and region in _in_choice) else (0.0 if region in _in_error else 1.0)
    for region in reg_swans_all
])
del _in_error, _in_choice


In [None]:
from matplotlib.colors  import ListedColormap
from matplotlib.patches import Patch

# Category codes used in val_swans_all, with matching legend labels and colours.
swans_leg_vals = [0.0, 1.0, 2.0]
swans_leg_repr = ['Error', 'Choice', 'Both']
swans_leg_cols = ['#C3B1E1', '#FFD8B8', '#A4D4F4']
custom_cmap    = ListedColormap(swans_leg_cols)

legend_elements = [Patch(facecolor=colour, label=name) for colour, name in zip(swans_leg_cols, swans_leg_repr)]

fig, ax_swans = plt.subplots(1, 1, figsize=(12, 9))

# RuntimeWarnings raised while drawing the Swanson map are suppressed here.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)

    plot_swanson_vector(
        reg_swans_all,
        val_swans_all,
        ax=ax_swans,
        cmap=custom_cmap,
        vmin=np.min(swans_leg_vals),
        vmax=np.max(swans_leg_vals),
        empty_color='whitesmoke',
    )

ax_swans.set_axis_off()

ax_swans.legend(handles=legend_elements, loc='lower left', bbox_to_anchor=(0.0, 0.9), title='Differently representing', title_fontsize=14)

plt.savefig('*** Update the folder path here to match your directory structure ***', format='png', dpi=300, bbox_inches='tight')
plt.show()
