In [214]:
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import set_matplotlib_formats
%matplotlib inline
set_matplotlib_formats('pdf')
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

keys = ['ar_intrinsic', 'ar_seasonal', 'ex_reward', 'ex_choice_side', 'ex_choice_stimulus']
paths_names = ['both', 'stable', 'volatile']
area_label = ["dlPFC", "OFC", "ACC"]
ses_area = np.hstack((np.zeros(shape=(108)), np.ones(shape=(117)), 2*np.ones(shape=(91))))



  set_matplotlib_formats('pdf')


In [215]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

def plot_parameter(parameter_vals, significant, parameter_name, figname):
    """Plot the per-area median of a fitted parameter and save it as a PDF.

    For each area, only entries flagged as significant are kept, IQR outliers
    are then dropped, and the median with a 95% CI is drawn per area.

    Parameters
    ----------
    parameter_vals : list[list[float]]
        List of length num_areas (same order as the global ``area_label``);
        each element is the list of parameter values for that area.
        The input lists are NOT modified.
    significant : list[list[bool]]
        Same structure as ``parameter_vals``; True marks entries to keep.
    parameter_name : str
        Column name and y-axis label (may contain mathtext).
    figname : str
        Output stem; the figure is written to ``../figures/<figname>.pdf``.
    """
    gen_color = '#F15946'

    # Filter into NEW lists: replacing parameter_vals[idx] in place would make
    # a second call re-apply the (non-idempotent) outlier filter.
    cleaned_vals = []
    for idx, area_vals in enumerate(parameter_vals):
        area_df = remove_insignificant(pd.DataFrame({'vals': area_vals}), significant[idx])
        area_df = remove_outlier_IQR(area_df)
        # dropna() guards against an outlier filter that masks with NaN
        # instead of dropping rows.
        cleaned_vals.append(area_df['vals'].dropna().tolist())

    # Long-format frame: one row per retained value, tagged with its area.
    records = [
        {'area': label, parameter_name: value}
        for i, label in enumerate(area_label)
        for value in cleaned_vals[i]
    ]
    df = pd.DataFrame(records, columns=['area', parameter_name])

    fig, ax = plt.subplots(figsize=(2.5, 4))
    # order= keeps the x-axis in area_label order even if an area is empty.
    sns.pointplot(data=df, x='area', y=parameter_name, order=list(area_label),
                  color=gen_color, estimator=np.median, ci=95, ax=ax)

    # Hide the right and top spines; only show ticks on the left and bottom.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    ax.set_xlabel("")
    ax.set_ylabel(parameter_name, fontsize=14)
    fig.tight_layout()
    fig.savefig('../figures/' + figname + '.pdf')
    plt.show()

def plot_parameter_comparison(parameter_vals_sta, significant_sta, parameter_vals_vol, significant_vol, parameter_name, figname):
    """Compare a fitted parameter between stable and volatile conditions, per area.

    Writes two figures:

    * ``../figures/<figname>_comparison_t.pdf``: per-area median of the
      parameter (significant entries only, IQR outliers removed), stable and
      volatile dodged side by side.
    * ``../figures/<figname>_comparison_b.pdf``: bar chart of the percent of
      entries flagged significant per area and condition.

    Parameters
    ----------
    parameter_vals_sta, parameter_vals_vol : list[list[float]]
        List of length num_areas (same order as the global ``area_label``)
        for the stable / volatile condition. The input lists are NOT modified.
    significant_sta, significant_vol : list[list[bool]]
        Same structure as the matching ``parameter_vals_*``; True marks entries
        to keep. Percentages are computed from these full (unfiltered) lists.
    parameter_name : str
        Column name and y-axis label of the point plot (may contain mathtext).
    figname : str
        Stem of the two output file names.
    """
    conditions = ['stable', 'volatile']
    hue_dict = {'volatile': '#53B3CB', 'stable': '#F9C22E'}

    def clean(vals_per_area, sig_per_area):
        # Significance filter, then IQR outlier filter, into NEW lists so the
        # caller's data is untouched (repeated calls would otherwise re-filter).
        cleaned = []
        for idx, area_vals in enumerate(vals_per_area):
            area_df = remove_insignificant(pd.DataFrame({'vals': area_vals}), sig_per_area[idx])
            area_df = remove_outlier_IQR(area_df)
            # dropna() guards against an outlier filter that masks with NaN
            # instead of dropping rows.
            cleaned.append(area_df['vals'].dropna().tolist())
        return cleaned

    def percent(flags):
        # An area with no entries yields NaN instead of ZeroDivisionError.
        return sum(flags) / len(flags) * 100 if len(flags) else float('nan')

    per_condition = {
        'stable': (clean(parameter_vals_sta, significant_sta), significant_sta),
        'volatile': (clean(parameter_vals_vol, significant_vol), significant_vol),
    }

    # Build both long-format frames in one pass (stable rows first, then volatile).
    sig_rows = []
    value_rows = []
    for volatility in conditions:
        vals_per_area, sig_per_area = per_condition[volatility]
        for i, label in enumerate(area_label):
            sig_rows.append({'area': label,
                             'percent of neurons': percent(sig_per_area[i]),
                             'volatility': volatility})
            value_rows.extend({'area': label, parameter_name: v, 'volatility': volatility}
                              for v in vals_per_area[i])
    df_sig = pd.DataFrame(sig_rows, columns=['area', 'percent of neurons', 'volatility'])
    df = pd.DataFrame(value_rows, columns=['area', parameter_name, 'volatility'])

    # --- Point plot of per-area medians ---
    fig, ax = plt.subplots(figsize=(2.5, 4))
    sns.pointplot(data=df, x='area', y=parameter_name, hue='volatility',
                  order=list(area_label), hue_order=conditions, palette=hue_dict,
                  dodge=.2, estimator=np.median, ax=ax)
    ax.legend(frameon=False)
    # Hide the right and top spines; only show ticks on the left and bottom.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.set_xlabel("")
    ax.set_ylabel(parameter_name, fontsize=14)
    fig.tight_layout()
    fig.savefig('../figures/' + figname + '_comparison_t.pdf')
    plt.show()

    # --- Bar plot of percent significant ---
    g = sns.catplot(
        data=df_sig, kind="bar",
        x="area", y="percent of neurons", hue="volatility",
        order=list(area_label), hue_order=conditions,
        ci=None, palette=hue_dict, aspect=.62, height=4, legend=False,
    )
    # catplot draws a single facet here; label it like the point plot above.
    bar_ax = g.axes.flat[0]
    bar_ax.set_xlabel("")
    bar_ax.set_ylabel("Percent of neurons", fontsize=12)
    plt.tight_layout()
    plt.savefig('../figures/' + figname + '_comparison_b.pdf')
    plt.show()

def remove_insignificant(df, significant):
    """Return only the rows of ``df`` whose matching ``significant`` flag is True.

    ``significant`` is applied positionally (its i-th flag decides the i-th
    row), so it is expected to have the same length as ``df``. Index labels of
    the kept rows are preserved.
    """
    keep_mask = pd.Series(significant).to_numpy()
    return df[keep_mask]

def remove_outlier_IQR(df, k=1.5):
    """Drop rows that contain an outlier by the IQR (Tukey fence) rule.

    A value is an outlier when it lies below ``Q1 - k*IQR`` or above
    ``Q3 + k*IQR`` of its own column. Any row with at least one outlying value
    is removed; other rows are returned unchanged with their index labels.

    Parameters
    ----------
    df : pandas.DataFrame
        Numeric data.
    k : float, default 1.5
        Fence multiplier (1.5 is the conventional value).

    Returns
    -------
    pandas.DataFrame
        Subset of ``df`` without outlier rows.
    """
    q1 = df.quantile(0.25)
    q3 = df.quantile(0.75)
    iqr = q3 - q1
    is_outlier = (df < (q1 - k * iqr)) | (df > (q3 + k * iqr))
    # Collapse to a per-row mask: indexing with the 2-D boolean frame itself
    # (as before) goes through DataFrame.where, NaN-masking cells but keeping
    # every row, so nothing was actually removed.
    return df[~is_outlier.any(axis=1)]


In [216]:
import pickle
import numpy as np

POSTPROCESSED_DIR = "../data/postprocessed"


def load_postprocessed(condition):
    """Load the ``(timescales, amplitudes, significant)`` tuple for one condition.

    condition : str
        File stem, e.g. 'both', 'stable' or 'volatile'.

    NOTE: despite the ``.npy`` extension these files are pickles. pickle can
    execute arbitrary code, so only load files produced by this project.
    """
    path = f"{POSTPROCESSED_DIR}/{condition}.npy"
    with open(path, 'rb') as input_file:
        return pickle.load(input_file)


key = 'ex_reward'
# Mathtext ignores plain spaces, so multi-word subscripts use an explicit "\ ".
key_to_label = {
    'ar_intrinsic': r'$\tau_{intrinsic}$',
    'ar_seasonal': r'$\tau_{seasonal}$',
    'ex_reward': r'$\tau_{reward}$',
    'ex_choice_side': r'$\tau_{choice\ side}$',
    'ex_choice_stimulus': r'$\tau_{choice\ stimulus}$',
}

timescales, amplitudes, significant = load_postprocessed('both')
plot_parameter(timescales[key], significant[key], key_to_label[key], key)

timescales_sta, amplitudes_sta, significant_sta = load_postprocessed('stable')
timescales_vol, amplitudes_vol, significant_vol = load_postprocessed('volatile')
plot_parameter_comparison(timescales_sta[key], significant_sta[key],
                          timescales_vol[key], significant_vol[key],
                          key_to_label[key], key)


<Figure size 180x288 with 1 Axes>

<Figure size 180x288 with 1 Axes>

<Figure size 178.56x288 with 1 Axes>