In [None]:
import numpy as np
import pickle as p
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import to_hex
from nest_utils import data_processing
from nest_utils import utils, visualizer as vsl

sns.set_theme(style="white")
# The original also called sns.despine(offset=5, trim=True) at this point, but no
# figure exists yet, so it could only act on a freshly created, empty figure.
# Spines are styled inside each plotting cell instead.

#plt.rc('font', weight='bold')
savings=0  # truthy -> the plotting cells also write their figures to disk


def _anchored_cmap(anchor_values, rgb_colors):
    """Return a linear colormap with 0-255 RGB colours pinned at ``anchor_values``.

    The anchors are rescaled to [0, 1] with their own minimum and maximum, so the
    smallest anchor lands on 0 and the largest on 1, which is what
    ``LinearSegmentedColormap.from_list`` needs for (value, colour) pairs.
    ``anchor_values`` must be given in increasing order.
    """
    anchor_norm = plt.Normalize(min(anchor_values), max(anchor_values))
    return LinearSegmentedColormap.from_list(
        '',
        [
            (anchor_norm(value), tuple(np.array(color) / 255))
            for value, color in zip(anchor_values, rgb_colors)
        ],
    )


# Diverging map used for relative variations in [-0.6, 0.6].
values = [-0.6,0.,0.6]
cmap_ct = _anchored_cmap(values, [(1, 138, 101),(245,245,245),(212, 112, 11)])

# These two maps used to be positioned with the [-0.6, 0.6] normalisation above,
# which placed their end points at 0.25 and 0.75. That is not a valid
# (value, colour) list (it has to run from 0 to 1), so the maps could not be
# evaluated. They now span their own [-0.3, 0.3] range.
cmap_th = _anchored_cmap([-0.3,0.3], [(36, 46, 14),(229, 247, 186)])
cmap_nrt = _anchored_cmap([-0.3,0.3], [(84, 47, 21), (242, 194, 160)])

# Sequential maps looked up by trial number (1..100) in the EBCC figures.
values = [1, 100]
colors_dcn = [(50,40,8), (253, 201, 43)]
colors_ebcc = [(20, 64, 53), (69, 216, 181)]
# An earlier colors_pc definition was overwritten immediately; only this one was ever used.
colors_pc = [(54, 53, 19), (212, 210, 77)]

norm = plt.Normalize(min(values), max(values))
cmap_dcn = _anchored_cmap(values, colors_dcn)
cmap_pc = _anchored_cmap(values, colors_pc)
cmap_ebcc = _anchored_cmap(values, colors_ebcc)

# One "#rrggbbaa" string per integer in [values[0], values[-1]].
all_values = np.arange(values[0], values[-1] + 1)
colors_dcn_cmap = [to_hex(cmap_dcn(norm(val)), keep_alpha=True) for val in all_values]
colors_ebcc_cmap = [to_hex(cmap_ebcc(norm(val)), keep_alpha=True) for val in all_values]
colors_pc_cmap = [to_hex(cmap_pc(norm(val)), keep_alpha=True) for val in all_values]

# Mappables handed to colorbars; note their range is 0..101, not the 1..100 of `norm`.
sm_dcn = plt.cm.ScalarMappable(cmap=cmap_dcn, norm=plt.Normalize(vmin=0, vmax=101))
sm_pc = plt.cm.ScalarMappable(cmap=cmap_pc, norm=plt.Normalize(vmin=0, vmax=101))


In [None]:

# Root directory of the saved simulation outputs; figures are written here as well.
path = './last_results/'
# Dopamine-depletion sites. The order matters: later cells look modes up by index
# (0 -> BGs, 1 -> Cerebellum, 2 -> Cerebellum + BGs in the figure titles).
mode_list = [
    "external",
    "internal",
    "both",
]
# Directory-name suffixes of the EBCC runs: no suffix, then the two depleted runs.
dopa_list_ebcc = [
    "",
    "_dopadepl_4",
    "_dopadepl_8",
]


### Figure 3

In [None]:
from matplotlib.legend_handler import HandlerTuple
import seaborn as sns
import pandas as pd

In [None]:

# Font sizes (points) referenced by the rcParams below.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 16, 18, 20

# (rc group, option, size), applied in this order.
for rc_group, rc_option, rc_size in (
    ('font', 'size', SMALL_SIZE),          # default text size
    ('axes', 'titlesize', SMALL_SIZE),     # axes title
    ('axes', 'labelsize', MEDIUM_SIZE),    # x and y labels
    ('xtick', 'labelsize', SMALL_SIZE),    # x tick labels
    ('ytick', 'labelsize', SMALL_SIZE),    # y tick labels
    ('legend', 'fontsize', SMALL_SIZE),    # legend entries
    ('figure', 'titlesize', BIGGER_SIZE),  # figure title
):
    plt.rc(rc_group, **{rc_option: rc_size})

# NOTE(review): seaborn's theme also assigns font-size rcParams, so this call
# probably overrides most of the sizes configured just above -- confirm which of
# the two is intended. The order is left untouched so existing figures render the same.
sns.set_theme(style="white")

# Hex colour per dopamine-depletion site.
COLORS = dict(
    cereb="#FFCA29",
    bg="#106787",
    both="#BA2D0B",
)


# Experiment-related constants
START_TIME = 1000.0  # passed as t_start when firing rates are collected

# Literature reference ranges per population, written as [value - offset, value + offset].
# "dcn" lists three such pairs back to back. (Units are not stated here --
# presumably firing rates in Hz; confirm against the cited sources.)
literature_control_reference_values = {
    "STN": [15 - 2.5, 15 + 2.5],
    "GPe": [33.7 - 1.3, 33.7 + 1.3],
    "dcn": [9.6 - 5, 9.6 + 5, 10.1 - 5.9, 10.1 + 5.9, 9.9 - 4.2, 9.9 + 4.2],
    "purkinje": [43.1 - 7.41, 43.1 + 7.41],
    "glomerulus": [0.96 - 0.88, 0.96 + 0.88],
    "SNr": [0.0, 0.0],
    "Striatum": [0.0, 0.0],
}
literature_dopa_reference_values = {
    "STN": [31 - 2.5, 31 + 2.5],
    "GPe": [20.2 - 0.5, 20.2 + 0.5],
    "dcn": [19.3 - 9.2, 19.3 + 9.2, 16.7 - 8.5, 16.7 + 8.5, 16 - 6.7, 16 + 6.7],
    "purkinje": [27.4 - 8.48, 27.4 + 8.48],
    "glomerulus": [0.47 - 0.39, 0.47 + 0.39],
    "SNr": [0.0, 0.0],
    "Striatum": [0.0, 0.0],
}

# (A preceding `desired_modes = ["internal"]` was dead: it was overwritten right away.)
desired_modes = ["both", "external"]#, "internal"]
rasters_data = data_processing.read_saved_rasters(
    desired_modes,
    shared_dir=path+"/complete_3000ms_x_1_sol18",
)
df = data_processing.collect_firing_rates_dataframe(
    *rasters_data,
    t_start=START_TIME,
)
ref_df = data_processing.reference_fr_dataframe_from_literature(
    literature_control_reference_values, literature_dopa_reference_values
)
# Merge the GP and striatum subpopulations
df = data_processing.average_subpopulations(
    df,
    [["GPeTI", "GPeTA"], ["MSND1", "MSND2", "FSN"]],
    ["GPe", "Striatum"],
    desired_modes,
)
#%%
# Filter df based on desired population
to_be_plotted = ["Striatum", "GPe", "SNr", "STN", "dcn", "glomerulus", "purkinje"]
# Display names, in the same order as to_be_plotted.
plot_names = [
    "Striatum",
    "GPe",
    "SNr",
    "STN",
    "DCNp",
    "Glomeruli",
    "Purkinje Cells",
]

df = data_processing.normalize_firing_rates(df, to_be_plotted)
# .copy() gives an independent frame, so the column assignment below does not
# write into a slice of the unfiltered data (pandas SettingWithCopyWarning).
df = df.loc[(df["dopa_lvl"] > 0.0) & (df["name"].isin(to_be_plotted))].copy()
# Order df to reflect plotting ordering
df["name"] = pd.Categorical(df["name"], categories=to_be_plotted, ordered=True)
# Adding reference values here
ref_df = data_processing.normalize_firing_rates(ref_df, reference=True)
#%%

# Re-binding instead of inplace=True keeps the cell free of hidden in-place mutation.
df = df.sort_values(by=["name", "dopa_lvl"])

In [None]:
# Figure 3b: bar plot of the firing-rate variation per population for one depletion site.
mode = "both"
plt.rcParams["hatch.linewidth"] = 0.25

# Keep the rows of the selected depletion site, then append the literature reference rows.
df = df[df["depletion_mode"] == mode]
if ref_df is not None:
    reference_rows = ref_df.loc[ref_df["dopa_lvl"] == "Reference"]
    df = pd.concat([df, reference_rows], ignore_index=True)

fig, ax = plt.subplots(figsize=(11, 5), dpi=96)

# Placeholder colours: every bar is re-coloured right after plotting.
hue_palette = {level: "#BA2D0B" for level in (0.1, 0.2, 0.4, 0.8)}
hue_palette["Reference"] = "#ffffffff"
sns.barplot(
    data=df,
    x="name",
    y="firing_rate",
    hue="dopa_lvl",
    palette=hue_palette,
    errcolor=".50",
    errwidth=0.75,
    capsize=0.1,
    ax=ax,
)

# Update colors according to areas involved
BOTH_COL = "BA2D0B"
colors = [BOTH_COL] * 7  # STR, GPe, SNr, STN, DCN, Glom, PC
# One (alpha suffix, hatch) pair per hue level; the last pair styles the reference bars.
# Alpha suffixes 40/80/bf/ff correspond to 25, 50, 75, 100 % opacity.
group_styles = [("40", ""), ("80", ""), ("bf", ""), ("ff", ""), ("00", "///")]
for bar_group, (hex_alpha_value, hatch) in zip(ax.containers, group_styles):
    for bar, color in zip(bar_group, colors):
        bar.set_linewidth(1.5)
        bar.set_edgecolor([0.25, 0.25, 0.25])
        bar.set_facecolor(f"#{color}{hex_alpha_value}")
        if hatch:
            bar.set_hatch(hatch)

# Update the legend accordingly
legend_labels = ["-0.1", "-0.2", "-0.4", "-0.8", "Expected variation"]
legend_handles = [tuple(bar_group) for bar_group in ax.containers]
leg = ax.legend(
    handles=legend_handles,
    labels=legend_labels,
    title="Dopamine depletion level",
    handler_map={tuple: HandlerTuple(ndivide=2, pad=0.1)},
    loc='upper left', bbox_to_anchor=(1.,1.2)
)

sns.despine(offset=5, trim=True)
ax.set(xlabel="", ylabel="Relative variation from physiological firing rate")
ax.set_xticks(
    range(len(to_be_plotted)),
    labels=plot_names,
    weight="bold",
    rotation=45,
    ha="right",
)
for axis in ("bottom", "left"):
    ax.spines[axis].set_linewidth(1.5)
# Horizontal line through y = 0.
ax.axline((0.0, 0.0), slope=0.0, linewidth=1.5, color="black")
# sns.despine()
if savings:
    plt.savefig(path+"figure_3_b.png",dpi=300, bbox_inches="tight")


### Figure 4

In [None]:
from scipy.ndimage import gaussian_filter


In [None]:
settling_time = 1000.
sim_time = 3000.
sim_period = 1.  # ms
trials = 1

sol_n = 18

experiment_list = ['active', 'EBCC']
experiment = experiment_list[0]

# Every condition is read from five run directories, "<dir>_trial_1" ... "<dir>_trial_5".
n_repetitions = 5
dopa_depl_levels = [-0.1, -0.2, -0.4, -0.8]


def _load_pickle(file_path):
    """Unpickle one saved simulation artefact and close the file afterwards.

    pickle can execute arbitrary code while loading: only use on trusted files.
    """
    with open(file_path, 'rb') as pickle_file:
        return p.load(pickle_file)


for area in ["bg", "cereb", "mass"]:
    # Per-area analysis settings: populations, which keys of the activity dicts hold
    # times and data, the analysis start time, and the peak widths forwarded to
    # utils.average_wavelet_transform.
    if area == "bg":
        TARGET_POP = ['GPeTA', 'STN', 'SNr']
        times_key='times'
        data_key='instant_fr'
        t_start = settling_time
        peaks_width = [[2, 8], [2, 8], [2, 8]]
    elif area == "cereb":
        TARGET_POP = ['glomerulus', 'purkinje', 'dcn']
        times_key='times'
        data_key='instant_fr'
        t_start = settling_time
        peaks_width = [[1, 8], [1, 8], [1, 8]]
    elif area == "mass":
        times_key='mass_frs_times'
        data_key="mass_frs"
        TARGET_POP = ['CTX', 'thalamus', 'nRT']
        t_start = 0.
        peaks_width = [[2, 8], [2, 8], [2, 8]]
    else:
        raise ValueError(f"Unknown area: {area}")


    # for mode in ["both"]:
    for mode in mode_list:
        # The undepleted baseline (dopa_depl = 0) is always read from the "both"
        # directory, whatever the current mode; the depleted runs come from the
        # mode-specific directories.
        control_dir = f'{path}complete_{int(sim_time)}ms_x_{trials}_sol{sol_n}_both_dopa_{experiment}'  # f'savings/{date_time}'
        savings_dir = f'{path}complete_{int(sim_time)}ms_x_{trials}_sol{sol_n}_{mode}_dopa_{experiment}'  # f'savings/{date_time}'
        saving_dir_list = [control_dir] + [
            savings_dir + f'_dopadepl_{(str(int(-dopa_depl_level*10)))}'
            for dopa_depl_level in dopa_depl_levels
        ]

        wavelet_per_trial_list = []

        for sd, dopa_depl in zip(saving_dir_list, [0] + dopa_depl_levels):
            activity_list = []
            for trial_idx in range(1, n_repetitions + 1):
                sdt = sd + f'_trial_{trial_idx}'
                print(f'Simulation results loaded from {sdt}')

                # Only the files this area actually uses are read.
                if area == "mass":
                    mass_frs = _load_pickle(f'{sdt}/mass_models_sol')
                    activity_list += [mass_frs]
                    continue

                model_dic = _load_pickle(f'{sdt}/model_dic')
                rasters = _load_pickle(f'{sdt}/rasters')

                instant_fr = utils.fr_window_step(rasters, model_dic['pop_ids'], settling_time + sim_time * trials,
                                                window=1., step=1., start_time=1.)
                instant_fr = [i_f for i_f in instant_fr if i_f['name'] in TARGET_POP]

                # Per population: Gaussian smoothing along axis 1 only (sigma 0 on
                # axis 0, 2 on axis 1), then a sum over axis 0 scaled by 1/1000.
                instant_fr_array = np.array([
                    gaussian_filter(i_f['instant_fr'], [0, 2]).sum(axis=0) / 1000.
                    for i_f in instant_fr
                ])

                # The first population's dict is reused as the container for the stacked
                # array; swapaxes turns (population, sample) into (sample, population).
                instant_fr_dic = instant_fr[0]
                instant_fr_dic['instant_fr'] = instant_fr_array.swapaxes(0, 1)
                activity_list += [instant_fr_dic]

            y_val = utils.average_wavelet_transform(activity_list, sim_period, TARGET_POP, t_start=t_start,
                                                        dopa_depl=dopa_depl, times_key=times_key, data_key=data_key, peaks_width = peaks_width)
            wavelet_per_trial_list += [y_val]

        # The context manager closes the output file; it was previously left open.
        with open(f"{path}/{mode}_wavelet_per_trial_list_{area}.p", "wb") as out_file:
            p.dump(wavelet_per_trial_list, out_file)


In [None]:
# Regions whose saved wavelet lists are read back for plotting.
regions = ["bg", "mass", "cereb"]

# Both dicts are handed to the helper, whose return value is ignored; they are
# indexed by region afterwards, so the helper presumably fills them in place.
wavelet_dict, wavelet_dict_norm = dict(), dict()

for region in regions:
    data_processing.load_and_process_wavelet(
        path,
        region,
        wavelet_dict,
        wavelet_dict_norm,
    )


In [None]:
# One entry per panel: [region key, population index within that region, panel title].
list_subplots = [
    ["bg", 1, "STN"],
    ["bg", 2, "SNr"],
    ["cereb", 1, "PC"],
    ["cereb", 2, "DCNp"],
    ["mass", 0, "Cortex"],
    ["mass", 1, "Thalamus"],
    ["mass", 2, "nRT"],
]
# Number of panels (rows) in the figures that follow.
n = len(list_subplots)

In [None]:

# NOTE: vmin/vmax are not used by the heatmaps below, whose colour limits are fixed
# at +/-0.6; the names are kept in case other cells read them.
vmin = -0.8
vmax = .8

fig, ax = plt.subplots(n,1,figsize=(7,10), sharex=True)#, gridspec_kw={'height_ratios': [1, 1,3]})


ax = ax.reshape(n,)
for n_ax, data_ax in enumerate(list_subplots):
    region, pop_idx, panel_title = data_ax
    # Rows 1: skip the first (undepleted) row, leaving the four depletion levels;
    # the last 20 frequency bins are dropped.
    # The population index used to be hard-coded to 1, so every panel of a region
    # showed the same data under different titles; it now follows list_subplots.
    im = sns.heatmap(
        wavelet_dict_norm[region]["both"][1:,:-20,pop_idx],
        vmax=0.6, vmin=-0.6, cmap=cmap_ct,
        cbar=False,  # the colorbar is added explicitly below
        ax = ax[n_ax],
    ) #seismic #BrBG_r
    ax[n_ax].set_xticks(ticks = np.linspace(0,119-20,10),labels=np.linspace(0,50,10).astype(int) )
    ax[n_ax].set_yticks(ticks = [0.5,1.5,2.5,3.5],labels = ["-0.1", "-0.2", "-0.4", "-0.8"],rotation=0)
    ax[n_ax].set_ylabel("Dopamine  \n level")
    # The first child of the heatmap axes is used as the colorbar's mappable.
    mappable = im.get_children()[0]
    plt.colorbar(mappable, ax = ax[n_ax] ,orientation = 'vertical', label="Relative\nvariation")
    ax[n_ax].set_title(panel_title)

ax[-1].set_xlabel("Frequencies [Hz]")
# Laid out once, after every artist (including the x label) exists, instead of
# once per loop iteration.
plt.tight_layout(pad=2.01)
if savings:
        plt.savefig(f'{path}/heatmap_variation_both.png',dpi=300,bbox_inches="tight")



In [None]:
#sns.set_theme(context="notebook",style="whitegrid",font_scale=0.7, rc={ "grid.linestyle": "-."})
labels_both =["","-0.1", "-0.2", "-0.4", "-0.8"]
# y-limits per panel, in list_subplots order.
lims =[[-0.2,0.4],[-0.4,0.8],[-0.6,0.6],[-0.4,0.4],[-0.8,0.8],[-0.8,0.8],[-0.8,0.8],]

# Row of the "both" array that is drawn (index 4, i.e. "-0.8" in labels_both).
# It replaces a former `for i in range(4,5)` loop that only ever ran once.
both_level = 4

fig, ax = plt.subplots(n,1,figsize=(7,10), sharex = True)#, sharey = True)
ax = ax.reshape(n,)
for n_ax, (region, pop_idx, panel_title) in enumerate(list_subplots):
    panel = ax[n_ax]
    # log(1 + v) of the last "external" row and of the selected "both" row;
    # the last 20 frequency bins are dropped.
    sns.lineplot(np.log(wavelet_dict_norm[region]["external"][-1,:-20,pop_idx] +1), label="BG",  color = "#106787", alpha = 0.9, ax=panel)
    sns.lineplot(np.log(wavelet_dict_norm[region]["both"][both_level,:-20,pop_idx]+1), label="Cerebellum + BG", alpha = 0.4+0.15*both_level, color = "#BA2D0B", ax=panel) #seismic #BrBG_r
    panel.set_xticks(ticks = np.linspace(0,119-20,10),labels=np.linspace(0,50,10).astype(int) )
    panel.set_xlabel("Frequencies [Hz]")
    panel.set_ylabel("Wavelet \n variation")
    panel.grid(False)
    panel.set_ylim(lims[n_ax])

    panel.set_title(panel_title)
    # Per-panel legends are removed; a single one is attached to the last panel below.
    panel.get_legend().remove()
# plt.tight_layout()
# ax[-1] rather than a hard-coded ax[6], so the cell follows len(list_subplots).
ax[-1].legend(title = "Dopamine depletion site", ncol = 2,loc="best", bbox_to_anchor=(0.,-0.))
# A separate loop name keeps `ax` bound to the axes array (it used to be `for ax in ax`).
for panel in ax:
    for axis in ["bottom", "left"]:
            panel.spines[axis].set_linewidth(1.5)
    for axis in ["top", "right"]:
            panel.spines[axis].set_linewidth(0.0)

plt.subplots_adjust(hspace = 0.8)

if savings:
    plt.savefig(path+"/spectral_variation_3_no_cereb.png",dpi=300,bbox_inches="tight")


In [None]:
#sns.set_theme(context="notebook",style="whitegrid",font_scale=0.7, rc={ "grid.linestyle": "-."})
labels_both =["","-0.1", "-0.2", "-0.4", "-0.8"]
fig, ax = plt.subplots(n,1,figsize=(7,7), sharex = True)#, sharey = True)
ax = ax.reshape(n,)
# y-limits per panel, in list_subplots order.
lims =[[-0.2,0.4],[-0.4,0.8],[-0.6,0.4],[-0.6,0.4],[-1.2,0.8],[-1.,0.8],[-1.5,0.8],]

for n_ax, (region, pop_idx, panel_title) in enumerate(list_subplots):
    panel = ax[n_ax]
    # log(1 + v) of the last "internal" row; the last 20 frequency bins are dropped.
    sns.lineplot(np.log(wavelet_dict_norm[region]["internal"][-1,:-20,pop_idx]+1),label="Cerebellum", color = "#FFCA29", alpha = 0.9, ax=panel) #cmap=cmap_ct,
    panel.set_xticks(ticks = np.linspace(0,119-20,10),labels=np.linspace(0,50,10).astype(int) )
    panel.set_xlabel("Frequencies [Hz]")
    panel.set_ylabel("Wavelet \n variation")
    panel.set_ylim(lims[n_ax])
    panel.grid(False)
    panel.set_title(panel_title)
    panel.get_legend().remove()
# A bare `plt.tight_layout` (no parentheses) stood here and never ran. It is removed
# rather than called: the legend anchor and hspace below were tuned without it, and
# the sibling cell has the call commented out as well.
# The legend goes on the last panel, which is the axes plt.legend() acted on.
ax[-1].legend(title = "Dopamine depletion site",loc='upper left', bbox_to_anchor=(1.,3.95))
# A separate loop name keeps `ax` bound to the axes array (it used to be `for ax in ax`).
for panel in ax:
    for axis in ["bottom", "left"]:
            panel.spines[axis].set_linewidth(1.5)
    for axis in ["top", "right"]:
            panel.spines[axis].set_linewidth(0.0)

plt.subplots_adjust(hspace = 0.8)

if savings:
    plt.savefig(path+"/spectral_variation_3_cereb.png",dpi=300,bbox_inches="tight")


### Figure 5

In [None]:
import os
# NOTE(review): this rebinds the built-in iter() for every later cell, and the name
# is not read anywhere in the visible cells. The value matches the "x_101" part of
# the EBCC directory names -- presumably the number of conditioning trials; confirm
# before renaming or removing it.
iter = 101 

In [None]:
# For every depletion site and EBCC depletion level: read the five repeated runs,
# compute the CR measure and the PC/DCN spike-density functions via `vsl`, and
# store the results in the condition's directory.
# pickle can execute arbitrary code while loading: only use on trusted files.
for mode in mode_list:
# for mode in ["internal"]:
    for dopa_lvl in dopa_list_ebcc:

        name =f"/complete_580ms_x_101_sol18_{mode}_dopa_EBCC{dopa_lvl}"
        # Output directory of this condition; the runs themselves are read from the
        # sibling "<name>_trial_<i>" directories.
        if not os.path.exists(path+name):
            os.makedirs(path+name)
            print(f'\nWriting to {path+name}\n')

        ratios = [1.]  # only relevant to the commented-out cr_thr variant below
        threshold = 6.2

        CR_dict = {}
        CR_list = []
        for i in range(1,6):
            trial_dir = path + name + "_trial_" + str(i)
            # Context managers close the files even when unpickling fails.
            with open(trial_dir + "/rasters","rb") as f:
                rster = p.load(f)

            with open(trial_dir + "/model_dic","rb") as f:
                model = p.load(f)

            n_trials = model["trials"]
            sim_time = model["simulation_time"]
            set_time = model["settling_time"]
            len_trial = int(sim_time)

            # Stimulation timing, hard-coded here (the trailing comments name the
            # configuration entries the values were originally taken from).
            first = 100#set_time#all_data['simulations']['DCN_update']['devices']['CS']['parameters']['start_first']
            between_start = 580 #all_data['simulations']['DCN_update']['devices']['CS']['parameters']['between_start']
            last = first + between_start*(n_trials-1)
            burst_dur = 280#all_data['simulations']['DCN_update']['devices']['CS']['parameters']['burst_dur']
            burst_dur_us = 30#all_data['simulations']['DCN_update']['devices']['US']['parameters']['burst_dur']
            burst_dur_cs = burst_dur- burst_dur_us
            # One start time per trial: first, first + between_start, ..., last.
            trials_start = np.arange(first, last+between_start, between_start)

            # Trials 1 .. n_trials-1.
            selected_trials = np.linspace(1,n_trials-1,n_trials-1).astype(int) #Can specify trials to be analyzed

            maf_step = 100 #selected step for moving average filter when computing motor output from DCN SDF

            # NOTE(review): `threshold` is passed twice (1st and 4th argument) --
            # confirm against the signature of vsl.cr_isi.
            cr = vsl.cr_isi(threshold, selected_trials, maf_step, threshold, burst_dur, burst_dur_cs, trials_start, rster,between_start)
            # cr = cr_thr(threshold, ratio)
            CR_list.append(cr)
            CR_dict[i] = cr

            sdf_mean_dcn, sdf_ma_dcn = vsl.get_cell_sdf_MA("dcn", selected_trials, rster, trials_start, burst_dur, maf_step)
            sdf_mean_pc, sdf_ma_pc = vsl.get_cell_sdf_MA("purkinje", selected_trials, rster, trials_start, burst_dur, maf_step)
            # sdf_mean_STN, sdf_ma_STN = get_cell_sdf_MA("STN")
            # plt.plot(np.array(sdf_mean_STN).T)
            # Note the unpacking: for "dcn" the first return value is kept, for
            # "purkinje" the second one is stored as sdf_norm_pc.
            sdf_norm_dcn, _ = vsl.norm_sdf("dcn", selected_trials, rster, trials_start, burst_dur, burst_dur_cs, between_start, maf_step)
            sdf_norm_ma_pc, sdf_norm_pc = vsl.norm_sdf("purkinje", selected_trials, rster, trials_start, burst_dur, burst_dur_cs, between_start, maf_step)
            dict_all = {
                "sdf_mean_dcn":sdf_mean_dcn,
                "sdf_ma_dcn":sdf_ma_dcn,
                "sdf_mean_pc":sdf_mean_pc,
                "sdf_ma_pc":sdf_ma_pc,
                "sdf_norm_pc":sdf_norm_pc,
                "sdf_norm_dcn":sdf_norm_dcn
                }
            with open(path + name + "/dict_sdf_trial"+str(i),"wb") as f:
                p.dump(dict_all, f)


        # plt.show()

        # Despite its name, "dict_sdf" holds the list of CR results (one per run);
        # the CR section reads it back under the same file name.
        with open(path + name + "/dict_sdf","wb") as f:
            p.dump(CR_list, f)

In [None]:
# dict_sdf[dopa_i][mode_i] -> SDF arrays and title fragments for one condition, with
# dopa_i indexing dopa_list_ebcc and mode_i indexing mode_list.
dict_sdf = {0: {},
            1: {},
            2: {}}

# Title fragments, looked up by index.
_title_modes = (" in BGs", " in Cerebellum", " in Cerebellum + BGs")
_title_dopas = ("physiological", "dopamine level -0.4", "Dopamine level -0.8")

# Not used in the visible cells; assigned once here instead of on every iteration.
alphas = np.linspace(20,80,100)

for dopa_i in range(3):
    for mode_i in range(3):

        mode = mode_list[mode_i]
        dopa = dopa_list_ebcc[dopa_i]
        #%%
        ext = f'/complete_580ms_x_101_sol18_{mode}_dopa_EBCC{dopa}/'
        name = 'dict_sdf_trial'
        trial = 1  # only the first repeated run is loaded
        entry = dict_sdf[dopa_i][mode_i] = {}

        # The physiological condition carries no site suffix.
        entry["title_mode"] = "" if dopa_i == 0 else _title_modes[mode_i]
        entry["title_dopa"] = _title_dopas[dopa_i]

        file = path + ext + name+str(trial)

        # pickle can execute arbitrary code while loading: only use on trusted files.
        # The context manager closes the file even when unpickling fails.
        with open(file,'rb') as f:
            data = p.load(f)

        entry["sdf_pc"] = np.array(data["sdf_mean_pc"])
        entry["sdf_dcn"] = np.array(data["sdf_mean_dcn"])

        entry["sdf_pc_ma"] = np.array(data["sdf_ma_pc"])
        entry["sdf_dcn_ma"] = np.array(data["sdf_ma_dcn"])
        entry["sdf_dcn_norm"] = np.array(data["sdf_norm_dcn"])
        entry["sdf_pc_norm"] = np.array(data["sdf_norm_pc"])

In [None]:
# EBCC in physiological conditions: one trace per trial, Purkinje cells on top,
# deep cerebellar nuclei below.
colors = sns.color_palette("husl", 3)
x = range(100,381)

# Undepleted run (dopa index 0), taken from the mode_list[1] entry.
physio_entry = dict_sdf[0][1]
sdf_dcn_norm = physio_entry["sdf_dcn_norm"]
sdf_pc_norm = physio_entry["sdf_pc_norm"]

fig, axes = plt.subplots(2, 1,sharex=True, figsize=(7,7))
pc_ax, dcn_ax = axes[0], axes[1]
fig.suptitle('EBCC in physiological conditions')
for panel, panel_title in ((pc_ax, 'Purkinje cell'), (dcn_ax, 'Deep cerebellar nuclei')):
    panel.set_title(panel_title)
    panel.set_ylabel("SDF [Hz]")
dcn_ax.set_xlabel("Time [ms]")

# Stimulus markers; only the ones on the lower panel carry legend labels.
cs_color, us_color = colors[0], colors[1]
pc_ax.axvline(350, color=us_color)
pc_ax.axvline(100, color=cs_color)
dcn_ax.axvline(100, label="CS start", color=cs_color)
dcn_ax.axvline(350, label="US start", color=us_color)

# Both panels use the DCN colour scale, matching the single colorbar below.
for i in range(100):
    trial_color = colors_dcn_cmap[i]
    sns.lineplot(ax=dcn_ax, x=x, y=sdf_dcn_norm[i,:], color=trial_color, alpha=0.8)
    sns.lineplot(ax=pc_ax, x=x, y=sdf_pc_norm[i,99:], color=trial_color, alpha=0.8)

# The lower panel is the current axes after plt.subplots, so it is the one the
# pyplot-level text and legend calls acted on.
dcn_ax.text(110, 12, "Threshold", horizontalalignment='left', size='medium', color='b')
dcn_ax.axhline(6, color="b")
dcn_ax.legend(loc='center', bbox_to_anchor=(0.5, 1.3), ncol=2)
fig.subplots_adjust(hspace=0.5)

cbar_ax = fig.add_axes([0.95,0.2,0.01,0.6])
fig.colorbar(sm_dcn, cax=cbar_ax, orientation='vertical', label="Trial")
dcn_ax.set_ylim((-25,20))
pc_ax.set_xlim((95,385))
dcn_ax.set_xlim((95,385))

# Thicker bottom/left spines, invisible top/right ones.
spine_widths = {"bottom": 1.5, "left": 1.5, "top": 0.0, "right": 0.0}
for ax in axes:
    for axis, width in spine_widths.items():
        ax.spines[axis].set_linewidth(width)
if savings:
    plt.savefig("fig/ebcc_physio.png",dpi=300, bbox_inches="tight")


In [None]:
# EBCC with the strongest depletion (dopa index 2) at the "both" site (mode index 2):
# one trace per trial, Purkinje cells on top, deep cerebellar nuclei below.
colors = sns.color_palette("husl", 3)
x = range(100,381)

depleted_entry = dict_sdf[2][2]
sdf_dcn_norm, sdf_pc_norm = depleted_entry["sdf_dcn_norm"], depleted_entry["sdf_pc_norm"]

fig, axes = plt.subplots(2, 1,sharex=True, figsize=(7,7))
top_ax = axes[0]      # Purkinje cells
bottom_ax = axes[1]   # deep cerebellar nuclei
fig.suptitle('EBCC with dopamine depletion -0.8 in Cerebellum + BGs')
top_ax.set(title='Purkinje cell', ylabel="SDF [Hz]")
bottom_ax.set(title='Deep cerebellar nuclei', xlabel="Time [ms]", ylabel="SDF [Hz]")

# Stimulus markers; only the ones on the lower panel carry legend labels.
top_ax.axvline(350, color=colors[1])
top_ax.axvline(100, color=colors[0])
for marker_time, marker_label, marker_color in (
    (100, "CS start", colors[0]),
    (350, "US start", colors[1]),
):
    bottom_ax.axvline(marker_time, label=marker_label, color=marker_color)

# Both panels use the DCN colour scale, matching the single colorbar below.
trace_style = dict(x=x, alpha=0.8)
for i in range(100):
    sns.lineplot(ax=bottom_ax, y=sdf_dcn_norm[i,:], color=colors_dcn_cmap[i], **trace_style)
    sns.lineplot(ax=top_ax, y=sdf_pc_norm[i,99:], color=colors_dcn_cmap[i], **trace_style)

# The lower panel is the current axes after plt.subplots, so it is the one the
# pyplot-level text and legend calls acted on.
bottom_ax.text(110, 12, "Threshold", horizontalalignment='left', size='medium', color='b')
bottom_ax.axhline(6, color="b")

bottom_ax.legend(loc='center', bbox_to_anchor=(0.5, 1.3), ncol=2)
fig.subplots_adjust(hspace=0.5)

cbar_ax = fig.add_axes([0.95,0.2,0.01,0.6])
fig.colorbar(sm_dcn, cax=cbar_ax, orientation='vertical', label="Trial")

bottom_ax.set_ylim((-25,20))
for ax in axes:
    ax.set_xlim((95,385))

# Thicker bottom/left spines, invisible top/right ones.
for ax in axes:
    for axis, width in (("bottom", 1.5), ("left", 1.5), ("top", 0.0), ("right", 0.0)):
        ax.spines[axis].set_linewidth(width)
if savings:
    plt.savefig(path+"ebcc_8 in both.png",dpi=300, bbox_inches="tight")


In [None]:
# 2 x 3 grid for one depletion level: Purkinje cells (top row) and deep cerebellar
# nuclei (bottom row), one column per depletion site in mode_list order.
colors = sns.color_palette("husl", 3)
x = range(100,381)
fig, axes = plt.subplots(2, 3,sharex=True,sharey=False, figsize=(10,7))
dopa_i = 2
legend_col = 1  # middle column: hosts the depletion-level caption and the legend

# Set once; these calls used to be repeated on every loop iteration.
axes[0,0].set_ylabel("Adj SDF [Hz]")
axes[1,0].set_ylabel("Adj SDF [Hz]")
# The caption depends on dopa_i only, so it is written a single time instead of
# three identical texts stacked on top of each other.
axes[0,legend_col].text(240, 320, dict_sdf[dopa_i][legend_col]["title_dopa"], horizontalalignment='center', size='large', color='black' )

for i in [2,1,0]:
    sdf_dcn_norm = dict_sdf[dopa_i][i]["sdf_dcn_norm"]
    sdf_pc_norm = dict_sdf[dopa_i][i]["sdf_pc_norm"]
    axes[0,i].set_title('Purkinje cell')
    axes[1,i].set_title('Deep cerebellar nuclei')
    axes[1,i].set_xlabel("Time [ms]")

    # Only the legend panel's markers carry labels. Previously that panel got a
    # second, labelled copy of each marker drawn over the traces.
    if i == legend_col:
        cs_label, us_label, thr_label = "CS start", "US start", "Threshold"
    else:
        cs_label = us_label = thr_label = None

    axes[0,i].axvline(350,color = colors[1])
    axes[0,i].axvline(100,color = colors[0])
    axes[1,i].axvline(100,color = colors[0], label=cs_label)
    axes[1,i].axvline(350,color = colors[1], label=us_label)
    axes[1,i].axhline(6,color = "b", label=thr_label)
    axes[0,i].text(240, 290, dict_sdf[dopa_i][i]["title_mode"], horizontalalignment='center', size='large', color='black' )

    # One trace per trial; both rows use the DCN colour scale, matching the colorbar.
    for j in range(100):
        sns.lineplot(ax=axes[1,i],x=x,y=sdf_dcn_norm[j,:], color = colors_dcn_cmap[j], alpha=0.8)
        sns.lineplot(ax=axes[0,i],x=x,y=sdf_pc_norm[j,99:], color = colors_dcn_cmap[j], alpha=0.8)
    axes[1,i].set_ylim((-25,20))
    axes[0,i].set_ylim((-50,250))
    axes[0,i].set_xlim((95,385))
    axes[1,i].set_xlim((95,385))

cbar_ax = fig.add_axes([0.95,0.2,0.01,0.6])
fig.colorbar(sm_dcn, cax=cbar_ax,orientation = 'vertical', label = "Trial")

axes[1,legend_col].legend(loc='center', bbox_to_anchor=(0.5, -0.3),ncol=3)
# y tick labels are shown on the first column only.
for k in range(1,3):
    axes[0,k].get_yaxis().set_ticklabels([])
    axes[1,k].get_yaxis().set_ticklabels([])
for ax in axes.reshape(6):
    for axis in ["bottom", "left"]:
            ax.spines[axis].set_linewidth(1.5)
    for axis in ["top", "right"]:
            ax.spines[axis].set_linewidth(0.0)
if savings:
    plt.savefig(path+"ebcc_norm_"+str(dopa_list_ebcc[dopa_i])+".png",dpi=300,bbox_inches="tight")



### CR

In [None]:
# dict_cr[i_mode][i_dopa] -> contents of the saved "dict_sdf" file of that condition,
# with i_mode indexing mode_list and i_dopa indexing dopa_list_ebcc.
# (A preceding `dict_cr = {}` was dead: it was overwritten right away.)
dict_cr = {0:{},
            1:{},
            2:{}}

for i_mode in range(3):
    for i_dopa in range(3):
        mode = mode_list[i_mode]
        dopa = dopa_list_ebcc[i_dopa]
        ext = f'/complete_580ms_x_101_sol18_{mode}_dopa_EBCC{dopa}/'

        file = path + ext + 'dict_sdf'
        # pickle can execute arbitrary code while loading: only use on trusted files.
        # The context manager closes the file even when unpickling fails.
        with open(file,'rb') as f:
            data = p.load(f)

        dict_cr[i_mode][i_dopa] = data

In [None]:
#%%
# Greys for the three dopamine conditions. Two earlier `colors` assignments were
# dead: each was overwritten before being used.
colors = ["#050505", "#737272", "#bababa"]

labels= ["physiological", "-0.4", "-0.8"]

fig, axes = plt.subplots(1, 3,sharey=True, figsize=(21,5))
# NOTE(review): the y axis reads "CRs [%]" -- confirm whether "Conditioned responses"
# is meant here rather than "Complex responses".
fig.suptitle('Complex responses in different dopamine depletion conditions')
axes[2].set_title('Dopamine depletion in BGs and Cerebellum')
axes[1].set_title('Dopamine depletion in Cerebellum')
axes[0].set_title('Dopamine depletion in BGs')
for panel in axes:
    panel.set_xlabel("Trial set [#]")
axes[0].set_ylabel("CRs [%]")

# Baseline: the undepleted results (dopa index 0) of all three sites, stacked.
d = np.concatenate([np.array(dict_cr[i_mode][0]) for i_mode in range(3)])

err = ('ci', 95)
#err = "se"
# On every panel: the pooled baseline first, then the two depletion levels of that
# panel's site (panel index == index into mode_list).
for i_mode, panel in enumerate(axes):
    sns.pointplot(ax=panel, data= d, color = colors[0], errorbar=err)
    for i_dopa in (1, 2):
        sns.pointplot(ax=panel, data= np.array(dict_cr[i_mode][i_dopa]), color = colors[i_dopa], errorbar=err)
    # 1-based labels on every panel; the third one used to keep the default ones.
    panel.set_xticklabels(range(1,11))

# Label the plotted collections of the last panel in plotting order for its legend.
for curve, label in zip(axes[2].collections, labels):
    curve.set_label(label)
axes[2].legend(title='Dopamine level', loc='upper right', bbox_to_anchor=(1.25, 1.01))
for ax in axes:
    for axis in ["bottom", "left"]:
            ax.spines[axis].set_linewidth(1.5)
    for axis in ["top", "right"]:
            ax.spines[axis].set_linewidth(0.0)
    # ax.set_xticks(
    #     weight="bold",
    #        )
plt.tight_layout()
if savings:
    plt.savefig("fig/cr.png",dpi=300,bbox_inches="tight")

#%%