# Imports

In [None]:
import importlib
import os
import sys
import warnings

import numpy as np
import seaborn as sns
from matplotlib import lines
from matplotlib import pyplot as plt

In [None]:
# Make the shared helper modules in ../../_pythoncode importable.
# The folder is put first on sys.path. The list is updated in place and any earlier
# copy is removed, so re-running this cell does not accumulate duplicate entries.
pythoncodepath = os.path.abspath(os.path.join('..', '..', '_pythoncode'))
sys.path[:] = [pythoncodepath] + [p for p in sys.path if p != pythoncodepath]

import importhelper
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils
import plot_utils

In [None]:
# Apply the shared plotting defaults defined in plot_utils (presumably matplotlib rcParams).
plot_utils.set_rcParams()

In [None]:
# Figure number = characters 3-5 of the current folder's name (presumably the folder
# is named like 'fig03_...'). It is used in the output file names below.
# os.path.basename handles both '/' and '\\' separators, unlike split('/').
fig_num = os.path.basename(os.getcwd())[3:5]
print(fig_num)

# Get data

In [None]:
# Root folder holding the output of the step-2a optimization runs.
cbc_optim_folder = os.path.join(os.pardir, os.pardir, 'step2a_optimize_cbc', 'optim_data')

In [None]:
# Show which optimization runs are available.
available_runs = sorted(os.listdir(cbc_optim_folder))
available_runs

In [None]:
# Optimization run to load for each cell type (dict order: OFF, then ON).
cell2folder = {
    cell_type: os.path.join(cbc_optim_folder, 'optimize_' + cell_type + '_submission2')
    for cell_type in ('OFF', 'ON')
}

In [None]:
# Load, per cell type:
#   priors / posteriors : first and last entry of the saved SNPE sample distributions
#   params              : the parameter object of the optimization run
#   best_params         : the 'params' entry of the final model output
# NOTE(review): data_utils.load_var reads .pkl files (presumably pickle) - only load trusted data.
priors, posteriors, params, best_params = {}, {}, {}, {}

for cell, folder in cell2folder.items():
    snpe_dists = data_utils.load_var(os.path.join(folder, 'snpe', 'sample_distributions.pkl'))
    priors[cell], posteriors[cell] = snpe_dists[0], snpe_dists[-1]

    params[cell] = data_utils.load_var(os.path.join(folder, 'params.pkl'))

    final_output = data_utils.load_var(os.path.join(folder, 'post_data', 'final_model_output.pkl'))
    best_params[cell] = final_output['params']

# Export data

In [None]:
import pandas as pd

In [None]:
# Output folder for the CSV source data exported in the next cell
# (presumably created only if it does not exist yet).
data_utils.make_dir('source_data')

In [None]:
# Export the numbers behind the posterior figure as CSV files in source_data/, per cell type:
#   <cell>_truncation_bounds.csv           lower ('a') and upper ('b') bound of every parameter
#   <cell>_best_parameters.csv             best parameters, normalized and in simulation units
#   <cell>_prior_/_posterior_mean/covariance_normalized.csv
#                                          moments of the prior and the final posterior
for cell in cell2folder.keys():
    p_names_i = params[cell].p_names

    # Look ranges and best values up by name so every row matches the index,
    # independent of the insertion order of the p_range / best_params dicts.
    p_range_i = np.asarray([params[cell].p_range[p_name] for p_name in p_names_i])

    # Column 0 is the lower and column 1 the upper bound (the x-ticks at 0 and 1 are
    # labelled with p_range in this order). Previously 'b' was also filled from
    # column 0, so the exported upper bounds were wrong.
    bounds_df = pd.DataFrame({'a': p_range_i[:, 0], 'b': p_range_i[:, 1]}, index=p_names_i)

    best_df = pd.DataFrame(index=p_names_i)
    best_df['normalized'] = params[cell].sim_params2opt_params(best_params[cell])
    best_df['not normalized'] = [best_params[cell][p_name] for p_name in p_names_i]

    prior_mean_df = pd.DataFrame(priors[cell].mean, index=p_names_i, columns=['mean'])
    prior_cov_df = pd.DataFrame(priors[cell].S, columns=p_names_i, index=p_names_i)

    post_mean_df = pd.DataFrame(posteriors[cell].mean, index=p_names_i, columns=['mean'])
    post_cov_df = pd.DataFrame(posteriors[cell].S, columns=p_names_i, index=p_names_i)

    out_prefix = os.path.join('source_data', cell)

    bounds_df.to_csv(out_prefix + '_truncation_bounds.csv', float_format='%.3f')

    best_df.to_csv(out_prefix + '_best_parameters.csv', float_format='%.6f')

    prior_mean_df.to_csv(out_prefix + '_prior_mean_normalized.csv', float_format='%.6f')
    prior_cov_df.to_csv(out_prefix + '_prior_covariance_normalized.csv', float_format='%.6f')

    post_mean_df.to_csv(out_prefix + '_posterior_mean_normalized.csv', float_format='%.6f')
    post_cov_df.to_csv(out_prefix + '_posterior_covariance_normalized.csv', float_format='%.6f')

# Generate distribution plotters

In [None]:
import plot_sampling_dists
importlib.reload(plot_sampling_dists);  # pick up edits to the module without restarting the kernel

# One SamplingDistPlotter per cell type, built from its params, prior and final posterior.
# Each plotter also draws its 1D sampling-distribution overview right away.
PPs = {}

for cell in cell2folder:
    print(cell)

    PP = plot_sampling_dists.SamplingDistPlotter(
        params=params[cell], prior=priors[cell], posterior_list=[posteriors[cell]],
    )
    PP.set_bounds(lbs=priors[cell].lower, ubs=priors[cell].upper)
    PP.plot_sampling_dists_1D(plot_peak_lines=False, figsize=(12, 8), opt_x=True)

    PPs[cell] = PP

# Get parameter names and units

In [None]:
# Alphabetically sorted union of the parameter names of both models.
# A set union keeps the entries as plain str. np.unique would return numpy.str_
# entries, which NumPy >= 2 displays as np.str_('...').
all_p_names = sorted(set(params['ON'].p_names) | set(params['OFF'].p_names))

all_p_names

## Merge parameters

In [None]:
# Parameters that (presumably) appear under different names in the ON and OFF models
# (H1 vs. H4 variants of the HCN densities) are shown in one panel under a merged name.
# get_param_if_merged later picks whichever variant a given model actually has.
mergeparam2param  = {
    'cd_H_at': ['cd_H1_at', 'cd_H4_at'],
    'cd_H_d':  ['cd_H1_d', 'cd_H4_d'],
    'cd_H_s':  ['cd_H1_s', 'cd_H4_s'],
}

# Every variant must be a real parameter of at least one model. The message names the
# offending parameter, like the checks in the following cells.
for mergeparam, params_i in mergeparam2param.items():
    for param in params_i:
        assert param in all_p_names, param

## Units

In [None]:
# Unit string for every parameter (including the merged 'cd_H_*' names).
# Used as the x-label of the 1D panels and as the second line of the 2D axis labels.
# NOTE(review): '(n.u.)' presumably means normalized / unitless - confirm.
units = {
    'b_rrp':     r'(ves.)', 
    'bp_cm':     r'($\mu$F cm$^{-2}$)',
    'bp_gain':   r'(n.u.)',
    'bp_rm':     r'(k$\Omega$ cm$^{2}$)',
    'bp_vrev':   r'(mV)',
    'c_Kir_off': r'(mV)',
    'c_Kv_off':  r'(mV)',
    'c_Kv_taua': r'(n.u.)',
    'c_L_off':   r'(mV)',
    'c_L_taua':  r'(n.u.)',
    'c_N_offh':  r'(mV)',
    'c_N_offm':  r'(mV)',
    'c_N_tau':   r'(n.u.)',
    'c_T_off':   r'(mV)',
    'c_T_taua':  r'(n.u.)',
    'ca_PK':     r'($\mu$M)',
    'cd_H_at':   r'(mS cm$^{-2}$)',  # merged name, see mergeparam2param
    'cd_H_d':    r'(mS cm$^{-2}$)',  # merged name, see mergeparam2param
    'cd_H_s':    r'(mS cm$^{-2}$)',  # merged name, see mergeparam2param
    'cd_Kir':    r'(mS cm$^{-2}$)',
    'cd_Kv_a':   r'(mS cm$^{-2}$)',
    'cd_Kv_d':   r'(mS cm$^{-2}$)',
    'cd_Kv_pa':  r'(mS cm$^{-2}$)',
    'cd_L_at':   r'(mS cm$^{-2}$)',
    'cd_L_s':    r'(mS cm$^{-2}$)',
    'cd_N':      r'(mS cm$^{-2}$)',
    'cd_P_at':   r'($\mu$S cm$^{-2}$)',  # note: micro-siemens, unlike the other densities
    'cd_P_s':    r'($\mu$S cm$^{-2}$)',
    'cd_T_at':   r'(mS cm$^{-2}$)', 
    'cd_T_s':    r'(mS cm$^{-2}$)',
    'r_tauc':    r'(n.u.)',
    'syn_cc':    r'(mM)',
}

# Sanity check: each key must be a parameter of the ON/OFF models or a merged name.
for param in units:
    assert param in all_p_names or param in mergeparam2param, param

## Rename parameters to LaTeX labels

In [None]:
# LaTeX label for every parameter (including the merged 'cd_H_*' names).
# Used as the title of the 1D panels and as the first line of the 2D axis labels.
# For channel densities ('cd_*') the label is '<channel> @ <location>'.
renaming = {
    'b_rrp':     r'$v^{max}_{RRP}$', 
    'bp_cm':     r'$C_m$',
    'bp_gain':   r'$g_l$',
    'bp_rm':     r'$R_m$',
    'bp_vrev':   r'$V_r$',
    'c_Kir_off': r'$\Delta V_\alpha(K_{ir})$',
    'c_Kv_off':  r'$\Delta V_\alpha({K_V})$',
    'c_Kv_taua': r'$\tau_\alpha({K_V})$',
    'c_L_off':   r'$\Delta V_\alpha({Ca_L})$',
    'c_L_taua':  r'$\tau_\alpha({Ca_L})$',
    'c_N_offh':  r'$\Delta V_\alpha(Na_V)$',
    'c_N_offm':  r'$\Delta V_\gamma(Na_V)$',
    'c_N_tau':   r'$\tau_{all}(Na_V)$',
    'c_T_off':   r'$\Delta V_\alpha({Ca_T})$',
    'c_T_taua':  r'$\tau_\alpha({Ca_T})$',
    'ca_PK':     r'$Ca_{PK}$',
    'cd_H_at':   r'${HCN\,@\,AT}$',  # merged name, see mergeparam2param
    'cd_H_d':    r'${HCN\,@\,D}$',   # merged name, see mergeparam2param
    'cd_H_s':    r'${HCN\,@\,S}$',   # merged name, see mergeparam2param
    'cd_Kir':    r'${K_{ir}\,@\,S}$',
    'cd_Kv_a':   r'${K_{v}\,@\,A}$',
    'cd_Kv_d':   r'${K_{v}\,@\,D}$',
    'cd_Kv_pa':  r'${K_{v}\,@\,PA}$',
    'cd_L_at':   r'${Ca_{L}\,@\,AT}$',
    'cd_L_s':    r'${Ca_{L}\,@\,S}$',
    'cd_N':      r'${Na_{V}\,@\,DA}$',
    'cd_P_at':   r'${Ca_{P}\,@\,AT}$',
    'cd_P_s':    r'${Ca_{P}\,@\,S}$',
    'cd_T_at':   r'${Ca_{T}\,@\,AT}$', 
    'cd_T_s':    r'${Ca_{T}\,@\,S}$',
    'r_tauc':    r'$\tau_\alpha({Kainate})$',
    'syn_cc':    r'$STC$',
}

# Sanity check: each key must be a parameter of the ON/OFF models or a merged name.
for param in renaming:
    assert param in all_p_names or param in mergeparam2param, param

## Define order of parameters.

In [None]:
# Panel order of the 1D posterior figure, filled row by row (the figure cell zips
# these names with axs.flatten()). The blank-line groups of 7 match nx_sb = 7 columns.
# '_legend' marks the panel that shows the legend instead of a parameter.
# NOTE(review): this list has 32 entries, but the figure grid is ny_sb x nx_sb = 4 x 7
# = 28 panels, so the last four (c_N_offh, c_N_offm, c_N_tau, ca_PK) are not drawn.
# Confirm whether that is intended.
p_names_sorted = [
    'bp_rm',
    'bp_vrev',
    'bp_cm',
    'r_tauc',
    'bp_gain',
    'b_rrp',
    '_legend',

    'cd_L_at',
    'cd_L_s',
    'cd_T_at',
    'cd_T_s',    
    'cd_P_at',
    'cd_P_s',
    'cd_Kir',

    'cd_Kv_a',
    'cd_Kv_d',
    'cd_Kv_pa',
    'cd_N',
    'cd_H_at',
    'cd_H_d',
    'cd_H_s',

    'c_L_off',
    'c_T_off',
    'c_Kir_off',
    'c_L_taua',
    'c_T_taua',
    'c_Kv_taua',
    'c_Kv_off',
    
    'c_N_offh',
    'c_N_offm',
    'c_N_tau',
    
    'ca_PK',
]

# Every entry must be a real or merged parameter name, or the legend placeholder.
# Every real entry also needs a LaTeX label and a unit, because plot_params looks up
# both. Failing here is clearer than a KeyError halfway through drawing the figure.
for param in p_names_sorted:
    if param == '_legend':
        continue
    assert (param in all_p_names) or (param in mergeparam2param.keys()), param
    assert (param in renaming) and (param in units), param

# Plotting functions

In [None]:
# Line colors per cell type. The prior uses 'all' when ON and OFF share the same prior.
col_prior = dict(ON='darkgreen', OFF='darkblue', all='dimgray')
col_post = dict(ON='darkgreen', OFF='steelblue')

# Line styles: dotted prior, solid posterior and best-parameter markers.
ls_prior, ls_post, ls_best_params_lines = ':', '-', '-'

In [None]:
def plot_params(ax, param):
    """Draw the 1D panel of one parameter.

    The panel gets the parameter's LaTeX label as title and its unit as x-label.
    The prior and then the posterior marginals are drawn into `ax`.
    """
    ax.set_title(renaming[param])
    ax.set_xlabel(units[param], labelpad=-3)

    for draw in (plot_prior, plot_posteriors):
        draw(ax, param)

In [None]:
# Evaluation grid for the 1D marginals in normalized parameter space. It extends
# slightly beyond [0, 1], presumably so the drop-off at the bounds stays visible.
xvals = np.linspace(-0.04, 1.04, num=201)

In [None]:
def plot_prior(ax, param):
    """Plot the 1D prior marginal(s) of `param` into `ax`.

    The marginal is evaluated on `xvals` for every cell type whose model contains
    `param`; merged names are resolved by get_param_idx.

    - One cell type has the parameter, or ON and OFF give the same marginal:
      a single line in col_prior['all'] is drawn.
    - Otherwise: one line per cell type in its own color.
    - No cell type has the parameter: nothing is drawn.
    """
    yvals = {}
    for cell, PP in PPs.items():
        param_idx = get_param_idx(param, p_names=PP.params.p_names)
        if param_idx is not None:
            yvals[cell] = PP.eval_1d_marginal(dist=PP.prior, idx=param_idx, x=xvals)

    if not yvals:
        return

    # Decided explicitly rather than via assert + bare except. Asserts are stripped
    # under `python -O`, which would force merging, and a bare except also hides
    # unrelated errors.
    if len(yvals) == 1:
        merge_prior = True
    else:
        merge_prior = (
            set(yvals) == {'ON', 'OFF'}
            and np.shape(yvals['ON']) == np.shape(yvals['OFF'])
            and np.allclose(yvals['ON'], yvals['OFF'])
        )

    if merge_prior:
        ax.plot(xvals, next(iter(yvals.values())), color=col_prior['all'], ls=ls_prior)
    else:
        for cell, yvals_cell in yvals.items():
            ax.plot(xvals, yvals_cell, color=col_prior[cell], ls=ls_prior)

In [None]:
def plot_posteriors(ax, param):
    """Plot the 1D posterior marginal of `param` into `ax`.

    Draws one line per cell type whose model contains `param`, in col_post colors.
    The marginal is evaluated on `xvals`.
    """
    idx_per_cell = {
        cell: get_param_idx(param, p_names=PP.params.p_names) for cell, PP in PPs.items()
    }
    curves = {
        cell: PPs[cell].eval_1d_marginal(dist=PPs[cell].posterior_list[0], idx=idx, x=xvals)
        for cell, idx in idx_per_cell.items()
        if idx is not None
    }

    for cell, curve in curves.items():
        ax.plot(xvals, curve, color=col_post[cell], ls=ls_post)

## Helper functions

In [None]:
def get_param_if_merged(param, p_names):
    """Resolve a merged parameter name to the variant present in `p_names`.

    Parameters
    ----------
    param : str
        A key of `mergeparam2param`.
    p_names : sequence of str
        Parameter names of one model.

    Returns
    -------
    str
        The first entry of `mergeparam2param[param]` that occurs in `p_names`.

    Raises
    ------
    KeyError
        If `param` is not a merged name.
    ValueError
        If none of its variants occurs in `p_names`.
    """
    candidates = mergeparam2param[param]
    for param_i in candidates:
        if param_i in p_names:
            return param_i
    # A bare `raise` outside an except block only yields
    # "RuntimeError: No active exception to reraise". Say what actually went wrong.
    raise ValueError(
        f'None of {candidates} (merged as {param!r}) is among the given parameter names.'
    )

In [None]:
def get_param_idx(param, p_names):
    """Return the position of `param` in `p_names`, or None if it is absent.

    Merged names (keys of `mergeparam2param`) are first resolved to the variant that
    occurs in `p_names` via get_param_if_merged, which raises if no variant is present.

    Returns
    -------
    int or None
        Index of the first occurrence of the (resolved) name.
    """
    if param in mergeparam2param:
        param = get_param_if_merged(param, p_names)

    names = list(p_names)
    # list.index finds the first match directly; no boolean mask + argmax needed.
    return names.index(param) if param in names else None

## Additional plot functions

In [None]:
def plot_legend(ax):
    """Turn `ax` into a legend-only panel.

    The legend has one prior entry (col_prior['all']) and one posterior entry each for
    OFF and ON. The axis itself is hidden.
    """
    handles = [lines.Line2D([], [], color=col_prior['all'], label='prior', linestyle=ls_prior)]
    handles += [
        lines.Line2D([], [], color=col_post[cell], label=f'post: {cell}', linestyle=ls_post)
        for cell in ('OFF', 'ON')
    ]
    ax.legend(
        handles=handles, handlelength=1.3, loc='upper left', bbox_to_anchor=(0, 1),
        borderaxespad=0., labelspacing=0.1, frameon=False,
    )
    ax.axis('off')

In [None]:
def add_xticks(ax, param):
    """Label the ends (0 and 1) of the normalized x-axis with the parameter's range.

    The range comes from whichever model contains `param`. If both models contain it
    (directly, or through merged variants), both ranges must agree. Unknown names,
    e.g. the '_legend' placeholder, get '?' labels.
    """
    on_params, off_params = params['ON'], params['OFF']
    in_on = param in on_params.p_names
    in_off = param in off_params.p_names

    if in_on and in_off:
        p_range = on_params.p_range[param]
        assert p_range == off_params.p_range[param]
    elif in_on:
        p_range = on_params.p_range[param]
    elif in_off:
        p_range = off_params.p_range[param]
    elif param in mergeparam2param:
        name_on = get_param_if_merged(param, on_params.p_names)
        name_off = get_param_if_merged(param, off_params.p_names)
        p_range = on_params.p_range[name_on]
        assert p_range == off_params.p_range[name_off]
    else:
        p_range = ["?", "?"]

    ax.set_xticks([0, 1])
    ax.set_xticklabels(p_range)

In [None]:
def add_best_params(ax, param):
    """Mark each cell type's best value of `param` with a vertical line.

    The value is converted to the normalized axis via sim_param2opt_param. Merged
    names are resolved per model. Cell types whose model lacks the parameter are
    skipped.
    """
    for cell, best_params_i in best_params.items():
        cell_p_names = params[cell].p_names
        if param in mergeparam2param.keys():
            plot_param = get_param_if_merged(param, cell_p_names)
        else:
            plot_param = param

        if plot_param not in cell_p_names:
            continue

        x_best = params[cell].sim_param2opt_param(best_params_i[plot_param], plot_param)
        ax.axvline(x_best, color=col_post[cell], ls=ls_best_params_lines, alpha=0.8)

# Make figure

In [None]:
# Grid of the 1D posterior figure: ny_sb rows x nx_sb columns of panels.
ny_sb, nx_sb = 4, 7

In [None]:
fig, axs = plt.subplots(ny_sb, nx_sb, figsize=(7.9, ny_sb*0.8), sharey=False)
axs_flat = axs.flatten()
n_panels = axs_flat.size

# zip() stops at the shorter sequence. Any entries of p_names_sorted beyond the
# ny_sb*nx_sb panels would be dropped silently, so report them instead.
if len(p_names_sorted) > n_panels:
    warnings.warn(
        f'{len(p_names_sorted)} entries in p_names_sorted but only {n_panels} panels; '
        f'not plotted: {p_names_sorted[n_panels:]}'
    )

for ax, param in zip(axs_flat, p_names_sorted):
    if param == '_legend':
        plot_legend(ax)
    else:
        plot_params(ax, param)

    add_xticks(ax, param)
    add_best_params(ax, param)

# Hide panels that did not receive a parameter.
for ax in axs_flat[len(p_names_sorted):]:
    ax.axis('off')

sns.despine()

# Density panels: no y-axis, and 5 % headroom above the highest curve.
for ax in axs_flat:
    ax.set_ylim((ax.get_ylim()[0], ax.get_ylim()[1]*1.05))
    ax.set_yticks([])
    ax.spines['left'].set_visible(False)

plt.tight_layout(w_pad=1, rect=[0, -0.02, 1, 1.02], h_pad=0.3)

plt.savefig(f'../_figures/fig{fig_num}_posteriors.pdf')
plt.show()

# Make supplementary figure

## 2D helper functions

In [None]:
from matplotlib import cm

In [None]:
# Evaluation grid (per axis) for the 2D marginals, covering the normalized range [0, 1].
xvals2d = np.linspace(0.0, 1.0, num=401)

In [None]:
def plot_dist_2d(ax, param1, param2, cell, dist, cmap, levels=None):
    """Draw contour lines of the 2D marginal of `dist` over (param1, param2).

    Nothing is drawn if either parameter is missing from the model of `cell`.

    Parameters
    ----------
    levels : None, int or array-like
        None and int are passed to `ax.contour` unchanged (automatic levels /
        number of levels). An array-like is read as fractions of the marginal's
        maximum and is scaled by max(zz). The caller's object is never modified.
    """
    param_idx1 = get_param_idx(param1, p_names=params[cell].p_names)
    param_idx2 = get_param_idx(param2, p_names=params[cell].p_names)
    if (param_idx1 is None) or (param_idx2 is None):
        return

    xx, yy, zz = PPs[cell].eval_2d_marginal(
        dist=dist, idx1=param_idx1, idx2=param_idx2, x1=xvals2d, x2=xvals2d,
    )
    zmax = np.max(zz)

    # The old `levels.copy()` crashed for the default None and for ints, and the
    # in-place `*=` failed for integer arrays and lists. Build a new float array instead.
    if levels is not None and not isinstance(levels, (int, np.integer)):
        levels = np.asarray(levels, dtype=float) * zmax

    ax.contour(xx, yy, zz, cmap=cmap, vmin=0, vmax=zmax, origin='lower', levels=levels)

In [None]:
def plot_prior_2d(ax, param1, param2, cell, levels=None):
    """Contour plot of the 2D prior marginal of `cell` in gray (cm.gray_r)."""
    prior = PPs[cell].prior
    plot_dist_2d(ax, param1, param2, cell, dist=prior, cmap=cm.gray_r, levels=levels)

In [None]:
def plot_post_2d(ax, param1, param2, cell, levels=None):
    """Contour plot of the 2D marginal of the last posterior of `cell` (cm.gist_heat_r)."""
    final_posterior = PPs[cell].posterior_list[-1]
    plot_dist_2d(ax, param1, param2, cell, dist=final_posterior, cmap=cm.gist_heat_r, levels=levels)

In [None]:
def add_ticks(ax, param, cell, xaxis=True):
    """Label the ends (0 and 1) of the x-axis (xaxis=True) or y-axis with the range of `param`.

    Merged names are resolved to the variant in the model of `cell`. Tick marks
    themselves are hidden (length 0).
    """
    cell_p_names = params[cell].p_names
    resolved_name = cell_p_names[get_param_idx(param, p_names=cell_p_names)]
    p_range = params[cell].p_range[resolved_name]

    if not xaxis:
        ax.set_yticks([0, 1])
        ax.set_yticklabels(p_range)
    else:
        ax.set_xticks([0, 1])
        ax.set_xticklabels(p_range, rotation=90, ha='center')

    ax.tick_params(length=0.0)

In [None]:
def plot_pair(ax, param1, param2, cell, xlabel=True, ylabel=True, levels=None):
    """Draw one 2D panel: prior and posterior contours of (param1, param2) for `cell`.

    Parameters
    ----------
    ax : matplotlib Axes
    param1, param2 : str
        Parameter (or merged) names on the x- and y-axis.
    cell : str
        Cell type key, e.g. 'ON' or 'OFF'.
    xlabel, ylabel : bool
        Whether to show the axis label and tick labels of that axis.
    levels : None, int or array-like
        Forwarded to the contour helpers.
    """
    plot_prior_2d(ax=ax, param1=param1, param2=param2, cell=cell, levels=levels)
    plot_post_2d(ax=ax, param1=param1, param2=param2, cell=cell, levels=levels)

    # (Removed two unused get_param_idx() lookups that were computed but never read.)

    if xlabel:
        ax.set_xlabel(renaming[param1] + '\n' + units[param1], rotation=90, ha='center', va='center', labelpad=30)
    if ylabel:
        ax.set_ylabel(renaming[param2] + '\n' + units[param2], rotation=0, ha='center', va='center', labelpad=30)

    ax.set_aspect('equal')

    add_ticks(ax, param1, cell, xaxis=True)
    add_ticks(ax, param2, cell, xaxis=False)

    # Ticks are always set above. Hide their labels on interior panels.
    if not xlabel:
        ax.set_xticklabels([])
    if not ylabel:
        ax.set_yticklabels([])

In [None]:
# Contour levels for the 2D panels, as fractions of each marginal's maximum
# (plot_dist_2d scales array levels by the peak): 0.3, 0.6, 0.9.
levels = np.arange(0.3, 1.0, 0.3)
levels

In [None]:
# Preview a single 2D panel (OFF cell, R_m vs. V_r) before building the full grids.
ax = plt.subplot(111)
plot_pair(ax, 'bp_rm', 'bp_vrev', 'OFF', xlabel=True, levels=levels)

In [None]:
def plot_all_2d(p_names_sorted_2d, nparams2plot, cell, levels):
    """Lower-triangular grid of pairwise 2D prior/posterior marginals for one cell type.

    Only the first `nparams2plot` names are used. Columns show names[0 .. n-2] on the
    x-axis and rows show names[1 .. n-1] on the y-axis. Panels above the diagonal are
    hidden. Axis labels appear only on the bottom row and the first column. The
    figure is saved under ../_figures_apx/.
    """
    n_side = nparams2plot - 1
    side_inches = min(6.7, 1.5 * n_side)
    fig, axs = plt.subplots(n_side, n_side, figsize=(side_inches, side_inches), squeeze=False)

    x_names = p_names_sorted_2d[:nparams2plot - 1]
    y_names = p_names_sorted_2d[1:nparams2plot]

    for col, param1 in enumerate(x_names):
        for row, param2 in enumerate(y_names):
            ax = axs[row, col]
            if col > row:
                ax.axis('off')
                continue
            plot_pair(
                ax=ax, param1=param1, param2=param2, cell=cell,
                xlabel=(row == n_side - 1), ylabel=(col == 0), levels=levels,
            )

    fig.align_ylabels(axs[:, 0])

    plt.tight_layout(h_pad=0.05, w_pad=0.05)
    plt.savefig(f'../_figures_apx/figapx{fig_num}_2d_posterior' + cell + '.pdf')

## OFF

In [None]:
cell = 'OFF'

# Parameter order for the OFF pair plot. plot_all_2d uses only the first
# nparams2plot (13) entries: the rest are kept for reference / larger grids.
p_names_sorted_2d = [
    'bp_rm',
    'bp_vrev',
    'bp_cm',
    'bp_gain',
    'cd_L_at',
    'cd_P_at',
    'b_rrp',
    'cd_H_at',
    'cd_H_d',
    'cd_H_s',
    'cd_Kir',
    'r_tauc',
    'cd_T_at',
    'cd_Kv_a',
    'cd_Kv_d',
    'cd_Kv_pa',
    'cd_N',
    'c_L_off',
    'c_T_off',
    'c_Kir_off',
    'c_L_taua',
    'c_T_taua',
    'c_Kv_taua',
    'c_Kv_off',
    'c_N_offh',
    'c_N_offm',
    'c_N_tau',
    'ca_PK',
]

# 13 parameters -> a 12 x 12 lower-triangular grid of panels.
plot_all_2d(p_names_sorted_2d=p_names_sorted_2d, nparams2plot=13, cell=cell, levels=levels)

## ON

In [None]:
cell = 'ON'

# Parameter order for the ON pair plot. plot_all_2d uses only the first
# nparams2plot (13) entries. Unlike the OFF list, this one has no 'r_tauc',
# 'cd_T_at' or 'c_T_off'.
p_names_sorted_2d = [
    'bp_rm',
    'bp_vrev',
    'bp_cm',
    'bp_gain',
    'cd_L_at',
    'cd_P_at',
    'b_rrp',
    'cd_H_at',
    'cd_H_d',
    'cd_H_s',
    'cd_Kir',
    'cd_Kv_a',
    'cd_Kv_d',
    'cd_Kv_pa',
    'cd_N',
    'c_L_off',
    'c_Kir_off',
    'c_L_taua',
    'c_T_taua',
    'c_Kv_taua',
    'c_Kv_off',
    'c_N_offh',
    'c_N_offm',
    'c_N_tau',
    'ca_PK',
]

# 13 parameters -> a 12 x 12 lower-triangular grid of panels.
plot_all_2d(p_names_sorted_2d=p_names_sorted_2d, nparams2plot=13, cell=cell, levels=levels)