## Figure on inference in HH model on simulated and Allen data

In [None]:
%run -i common.ipynb

import lfimodels.hodgkinhuxley.utils as utils

from lfimodels.hodgkinhuxley.HodgkinHuxley import HodgkinHuxley
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsSpikes_mf import HodgkinHuxleyStatsSpikes_mf

# FIGURE and GRID
# All sizes are in millimetres. The figure is split into FIG_N_ROWS rows of
# equal height, and each row is split into ROW_k_NCOLS columns of equal width.
FIG_HEIGHT_MM = 150
FIG_WIDTH_MM = FIG_WIDTH_MM  # set in common notebook to a default value for all figures
FIG_N_ROWS = 5

# number of columns per row
ROW_1_NCOLS = 2
ROW_2_NCOLS = 2
ROW_3_NCOLS = 4
ROW_4_NCOLS = 2
# Row 5 only defines three column widths (COL_1..COL_3 below), but divides the
# figure width by 4.2, so each of its panels is narrower than a third.
ROW_5_NCOLS = 4.2

# every row gets the same height
ROW_1_HEIGHT_MM = ROW_2_HEIGHT_MM = ROW_3_HEIGHT_MM = ROW_4_HEIGHT_MM = ROW_5_HEIGHT_MM = \
    FIG_HEIGHT_MM / FIG_N_ROWS

# equal-width columns within each row
ROW_1_WIDTH_COL_1_MM = ROW_1_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_1_NCOLS
ROW_2_WIDTH_COL_1_MM = ROW_2_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_2_NCOLS
ROW_3_WIDTH_COL_1_MM = ROW_3_WIDTH_COL_2_MM = ROW_3_WIDTH_COL_3_MM = ROW_3_WIDTH_COL_4_MM = \
    FIG_WIDTH_MM / ROW_3_NCOLS
ROW_4_WIDTH_COL_1_MM = ROW_4_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_4_NCOLS
ROW_5_WIDTH_COL_1_MM = ROW_5_WIDTH_COL_2_MM = ROW_5_WIDTH_COL_3_MM = FIG_WIDTH_MM / ROW_5_NCOLS

# scale factors applied to panel widths / heights in the plotting cells below
W_FACT = H_FACT = 0.85

PATH_DROPBOX_FIGS = PATH_DROPBOX + 'figs/'

# PATHS: one SVG per panel
PANEL_A = PATH_DROPBOX_FIGS + 'fig_hh_a.svg'
PANEL_B = PATH_DROPBOX_FIGS + 'fig_hh_b.svg'
PANEL_C = PATH_DROPBOX_FIGS + 'fig_hh_c.svg'
PANEL_D1 = PATH_DROPBOX_FIGS + 'fig_hh_d1.svg'
PANEL_D2 = PATH_DROPBOX_FIGS + 'fig_hh_d2.svg'
PANEL_D3 = PATH_DROPBOX_FIGS + 'fig_hh_d3.svg'
PANEL_D4 = PATH_DROPBOX_FIGS + 'fig_hh_d4.svg'
PANEL_E = PATH_DROPBOX_FIGS + 'fig_hh_e.svg'
PANEL_F = PATH_DROPBOX_FIGS + 'fig_hh_f.svg'
PANEL_F1 = PATH_DROPBOX_FIGS + 'fig_hh_f1.svg'
PANEL_G = PATH_DROPBOX_FIGS + 'fig_hh_g.svg'
PANEL_H = PATH_DROPBOX_FIGS + 'fig_hh_h.svg'
PANEL_I = PATH_DROPBOX_FIGS + 'fig_hh_i.svg'

## save figures or not


In [None]:
# Output switch read by every plotting cell below. If truthy, each panel is
# written with plt.savefig to its PANEL_* SVG path and then passed to svg()
# (defined in common.ipynb; presumably renders it inline). If falsy, the
# panel is shown with plt.show() instead.
save_fig = 1 # 1: save figures

## panels A-C

### load data

In [None]:
# Synthetic observation from the ground-truth parameters, the simulator/prior
# used for it, and the SNPE results for posteriors trained on 1, 4 and 7
# summary features (panels A-D).
true_params, labels_params = utils.obs_params()

seed = 1

# prior configuration
prior_uniform = True
prior_log = False
prior_extent = True

# simulator / summary-statistics configuration
n_xcorr = 0
n_mom = 4
cython = True
summary_stats = 1

# stimulus and observed (simulated) voltage trace
I, t_on, t_off, dt = utils.syn_current()
A_soma = np.pi*((70.*1e-4)**2)  # cm2

obs = utils.syn_obs_data(I, dt, true_params, seed=seed, cython=cython)
y_obs, t = obs['data'], obs['time']
duration = np.max(t)

m = HodgkinHuxley(I, dt, V0=y_obs[0], seed=seed, cython=cython, prior_log=prior_log)

p = utils.prior(true_params=true_params, prior_uniform=prior_uniform,
                prior_extent=prior_extent, prior_log=prior_log, seed=seed)

# number of summary features of each trained posterior
n_summary_ls = [1, 4, 7]
n_post = len(n_summary_ls)

# SNPE parameters (only used to build the result file names)
n_components = 2
n_sims = 50000
n_rounds = 2
svi = False
svi_flag = '_svi' if svi else '_nosvi'

# Alternative setup kept by the author: HodgkinHuxleyStatsSpikes_mf summaries with
# '.../sim_run_2_round2_prior0013_param8_statspikes_nsum..._snpe_rej_res.pkl' results.
s_ls, posterior_ls, res_ls = [], [], []
for nsum in n_summary_ls:
    s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr, n_mom=n_mom,
                                  n_summary=nsum)
    s_ls.append(s)

    # SNPE results for this number of summary features
    filename1 = ('../hodgkinhuxley/results/sim_run_1_round2_prior0013_param8_nsum' + str(nsum)
                 + svi_flag + '_ncomp' + str(n_components)
                 + '_nsims' + str(n_sims*n_rounds) + '_snpe_res.pkl')
    res = io.load(filename1)
    res_ls.append(res)

    # posterior evaluated at the observation stored with the result
    posterior = res.predict(res.obs)
    posterior_ls.append(posterior)

### panel A

In [None]:
# Panel A: nothing is drawn in this cell. The existing SVG at PANEL_A is passed
# to svg() unchanged, and fig_inches below is computed but not used here.
# NOTE(review): presumably PANEL_A is a hand-made schematic whose target size
# fig_inches records; confirm.
# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_1_WIDTH_COL_1_MM), SCALE_IN*mm2inches(H_FACT*ROW_1_HEIGHT_MM))

svg(PANEL_A)

### panel B

In [None]:
# Panel B: posterior of the run with the most summary features
# (posterior_ls[-1]), with two parameter sets marked: the posterior "mode"
# and a scaled-down variant of it.
import plot_multipdf

prior_min = p.lower
prior_max = p.upper

# (n_params, 2) array whose rows are [lower, upper] prior bounds; used as plot limits
prior_lims = np.concatenate((prior_min.reshape(-1,1),
                             prior_max.reshape(-1,1)),
                            axis=1)

# mean of the mixture component with the largest weight (.a), used as the "mode"
mn_post = posterior_ls[-1].xs[np.argmax(posterior_ls[-1].a)].m
    
# less likely parameter set from posterior than mode
# (every parameter scaled by 0.8; assumed, not checked, to have lower posterior density)
post_low = mn_post*0.8

# (n_params, 2): column 0 = mode, column 1 = low-probability set
post_samp = np.concatenate((mn_post.reshape(-1,1),post_low.reshape(-1,1)), axis=1)
col_samp = [COL['EFREE'],COL['SAMPLES']]
# col_samp = [(255/255,255/255,255/255),(255/255,255/255,255/255)]

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_1_WIDTH_COL_2_MM), SCALE_IN*mm2inches(2*H_FACT*ROW_1_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)
    plot_multipdf.plot_pdf_multipts(posterior_ls[-1], lims=prior_lims,gt=post_samp,labels_params=LABELS_HH,
                                    figsize=fig_inches,fontscale=0.5,ticks=True,col2=COL['SNPE'],col_samp=col_samp)
    if save_fig:
        plt.savefig(PANEL_B, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_B)
    else:
        plt.show()

### panel C

In [None]:
# Panel C simulations: simulate the two parameter sets from panel B (posterior
# mode `mn_post` and `post_low`) num_rep times each, and summarise every trace
# with the summary-statistics object of the all-features SNPE run.

# number of simulations for same parameter set
num_rep = 100

x_high_ls, x_low_ls = [], []
sum_stats_high_ls, sum_stats_low_ls = [], []
summary_calc = res_ls[-1].generator.summary.calc

for rep in range(num_rep):
    # The high/low simulations stay interleaved as before, in case the
    # simulator draws from shared random state.
    x_high = m.gen_single(mn_post)
    x_high_ls.append(x_high)
    sum_stats_high = summary_calc([x_high])[0]
    sum_stats_high_ls.append(sum_stats_high)

    x_low = m.gen_single(post_low)
    x_low_ls.append(x_low)
    sum_stats_low = summary_calc([x_low])[0]
    sum_stats_low_ls.append(sum_stats_low)

# NaN-ignoring mean and spread of each summary feature across repetitions
mn_stats_high, std_stats_high = (np.nanmean(sum_stats_high_ls, axis=0),
                                 np.nanstd(sum_stats_high_ls, axis=0))
mn_stats_low, std_stats_low = (np.nanmean(sum_stats_low_ls, axis=0),
                               np.nanstd(sum_stats_low_ls, axis=0))

In [None]:
# Panel C: left, the observed trace with one simulated trace at the posterior
# mode and one at the low-probability set. Right, each summary feature divided
# by the SNPE run's stats_std: observation, plus mean +/- std over repetitions
# for the two parameter sets (three bars per feature).
n_summary_stats = n_summary_ls[-1]

obs_stats_norm_mat = res_ls[-1].obs[0]/res_ls[-1].stats_std
# arg_sort_stats = np.argsort(obs_stats_norm_mat)
arg_sort_stats = np.arange(n_summary_stats)  # identity order (features unsorted)
LABELS_HH_SUMSTATS1 = np.array(LABELS_HH_SUMSTATS)

mn_stats_high_norm = mn_stats_high/res_ls[-1].stats_std
mn_stats_low_norm = mn_stats_low/res_ls[-1].stats_std

std_stats_high_norm = std_stats_high/res_ls[-1].stats_std
std_stats_low_norm = std_stats_low/res_ls[-1].stats_std

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(1.5*W_FACT*ROW_2_WIDTH_COL_1_MM), SCALE_IN*mm2inches(H_FACT*ROW_2_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    ax = plt.subplot(121)
    plt.plot(t, obs['data'], color = COL['GT'], lw=2, label='observation')
    plt.plot(t, x_high['data'], color = col_samp[0], lw=2, label='mode')
    plt.plot(t, x_low['data'], color = col_samp[1], lw=2, label='low prob')
    plt.xlabel('time (ms)')
    plt.ylabel('voltage (mV)')

    ax = plt.gca()
    ax.set_xticks([0, duration/2, duration])
    ax.set_yticks([-80, -20, 40])

    width = 0.3
    ax = plt.subplot(122)
    x_pos = np.arange(n_summary_stats)
    # Bars are centred explicitly (independent of the matplotlib default) at
    # x, x+width and x+2*width for each feature.
    plt.bar(x_pos, obs_stats_norm_mat[arg_sort_stats],
            width, align='center', color=COL['GT'], label='observation')
    plt.bar(x_pos+width, mn_stats_high_norm[arg_sort_stats],
            width, align='center', color=col_samp[0], yerr=std_stats_high_norm[arg_sort_stats], label='mode')
    plt.bar(x_pos+2*width, mn_stats_low_norm[arg_sort_stats],
            width, align='center', color=col_samp[1], yerr=std_stats_low_norm[arg_sort_stats],
            label='low probability')
    ax.set_xlim(-1.5*width, n_summary_stats+width/2)
    # centre each feature label under the middle of its three-bar group
    ax.set_xticks(x_pos+width)
    ax.set_xticklabels(LABELS_HH_SUMSTATS1[arg_sort_stats])
    plt.ylabel(r'$\frac{f}{\sigma_{f \ PRIOR}}$')
    plt.legend(bbox_to_anchor=(1.51, 1.1), loc='upper right')

    ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))

    if save_fig:
        plt.savefig(PANEL_C, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_C)
    else:
        plt.show()

### panel D

In [None]:
# Panel D1-D3: the posterior for each number of summary features
# (n_summary_ls), each drawn with the same two marked parameter sets.
# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_3_WIDTH_COL_1_MM), SCALE_IN*mm2inches(H_FACT*ROW_3_HEIGHT_MM))


# two samples from posteriors:
# sample 1 is the mean of the largest-weight mixture component of the posterior
# with all features (posterior_ls[-1]);
# sample 2 is the mean of the SMALLEST-weight mixture component (np.argmin of the
# weights .a) of the posterior with 4 features (posterior_ls[1], since
# n_summary_ls = [1, 4, 7])
samp_post1 = posterior_ls[-1].xs[np.argmax(posterior_ls[-1].a)].m
samp_post2 = posterior_ls[1].xs[np.argmin(posterior_ls[1].a)].m
# (n_params, 2): one column per marked sample
post_samp = np.concatenate((samp_post1.reshape(-1,1),samp_post2.reshape(-1,1)), axis=1)

# parameter indices handed to plot_pdf_multipts as partial_ls (with partial=True);
# presumably selects which marginals are drawn — confirm in plot_multipdf
partial_ls = [1,2,4,7]

label_feature = [' feature']+[' features']*(n_post-1)

# one output file per posterior; assumes n_post == 3
PANEL_D_ls = [PANEL_D1,PANEL_D2,PANEL_D3]

for i in range(n_post):
    with mpl.rc_context(fname=MPL_RC):
        fig = plt.figure(figsize=fig_inches)
#         plot_pdf(posterior_ls[i], lims=prior_lims,labels_params=LABELS_HH,
#              figsize=fig_inches,fontscale=0.5,ticks=True,partial=True,col2=COL['SNPE'])
        plot_multipdf.plot_pdf_multipts(posterior_ls[i], lims=prior_lims,gt=post_samp,labels_params=LABELS_HH,
                                        figsize=fig_inches,fontscale=0.5,ticks=True,partial=True,
                                        partial_ls = partial_ls,col2=COL['SNPE'])
        # Title "<n> feature(s)" placed relative to the data limits of whichever
        # axes is current after plot_pdf_multipts; the -2.7 / 4.8 multipliers
        # are presumably hand-tuned positions.
        x0, xmax = plt.xlim()
        y0, ymax = plt.ylim()
        data_width = xmax - x0
        data_height = ymax - y0
        plt.text(x0+data_width*-2.7, y0+data_height*4.8,
                 str(n_summary_ls[i])+str(label_feature[i]),fontsize=12)
        if save_fig:
            plt.savefig(PANEL_D_ls[i], facecolor='None', transparent=True)  # the figure is saved as svg
            plt.close()
            svg(PANEL_D_ls[i])
        else:
            plt.show()

In [None]:
# Panel D simulations: simulate the two marked posterior samples (samp_post1,
# samp_post2) num_rep times each, and summarise every trace with the
# summary-statistics object of the all-features SNPE run.

# number of simulations for same parameter set
num_rep = 100

x_samp1_ls, x_samp2_ls = [], []
sum_stats_samp1_ls, sum_stats_samp2_ls = [], []
summary_calc = res_ls[-1].generator.summary.calc

for rep in range(num_rep):
    # The sample-1/sample-2 simulations stay interleaved as before, in case
    # the simulator draws from shared random state.
    x_samp1 = m.gen_single(samp_post1)
    x_samp1_ls.append(x_samp1)
    sum_stats_samp1 = summary_calc([x_samp1])[0]
    sum_stats_samp1_ls.append(sum_stats_samp1)

    x_samp2 = m.gen_single(samp_post2)
    x_samp2_ls.append(x_samp2)
    sum_stats_samp2 = summary_calc([x_samp2])[0]
    sum_stats_samp2_ls.append(sum_stats_samp2)

# NaN-ignoring mean and spread of each summary feature across repetitions
mn_stats_samp1, std_stats_samp1 = (np.nanmean(sum_stats_samp1_ls, axis=0),
                                   np.nanstd(sum_stats_samp1_ls, axis=0))
mn_stats_samp2, std_stats_samp2 = (np.nanmean(sum_stats_samp2_ls, axis=0),
                                   np.nanstd(sum_stats_samp2_ls, axis=0))

In [None]:
# Panel D4: the observed voltage trace together with the last simulated trace
# of each marked posterior sample from panel D.
# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_3_WIDTH_COL_4_MM), SCALE_IN*mm2inches(H_FACT*ROW_3_HEIGHT_MM))

# two shades taken from the Reds colormap at 1/3 and 2/3 (0 is skipped via col_min)
col_min = 1
num_colors = col_min + 2
cm1 = mpl.cm.Reds
col1 = [cm1(float(k)/num_colors) for k in range(col_min, num_colors)]

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)
    ax = plt.gca()

    traces = ((obs['data'], COL['GT'], 'observation'),
              (x_samp1['data'], col1[1], 'sample 1'),
              (x_samp2['data'], col1[0], 'sample 2'))
    for trace_data, trace_col, trace_label in traces:
        ax.plot(t, trace_data, color=trace_col, lw=2, label=trace_label)

    ax.set_xlabel('time (ms)')
    ax.set_ylabel('voltage (mV)')
    ax.legend(bbox_to_anchor=(1.35, 1.1), loc='upper right', fontsize=9)

    ax.set_xticks([0, duration/2, duration])
    ax.set_yticks([-80, -20, 40])

    if save_fig:
        plt.savefig(PANEL_D4, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_D4)
    else:
        plt.show()

## panels E-G

### load data

In [None]:
# Allen Cell Types Database recordings (panels E-I): load each recording, shift
# it by the junction potential, compute its summary statistics, build a matching
# simulator, and load the SNPE and IBEA results stored for it.
true_params, labels_params = utils.obs_params()

seed = 1

# prior configuration
prior_uniform = True
prior_log = False
prior_extent = True

# simulator / summary-statistics configuration
n_xcorr = 0
n_mom = 4
cython = True
n_summary = 7
summary_stats = 1

# one entry per recording: [ephys_cell, sweep_number, value passed as A_soma]
list_cells_AllenDB = [[518290966,57,0.0234/126],[509881736,39,0.0153/184],[566517779,46,0.0195/198],
                      [567399060,38,0.0259/161],[569469018,44,0.033/403],[532571720,42,0.0139/127],
                      [555060623,34,0.0294/320],[534524026,29,0.027/209],[532355382,33,0.0199/230],
                      [526950199,37,0.0186/218]]

n_post = len(list_cells_AllenDB)

# SNPE parameters (only used to build the result file names)
n_components = 1
n_sims = 25000
n_rounds = 2
svi = False
svi_flag = '_svi' if svi else '_nosvi'

# IBEA parameters (only used to build the result file names)
algo = 'ibea'
offspring_size = 500
max_ngen = 100

# offset added to every recorded trace before any further use
junction_potential = -14

obs_ls, I_ls, dt_ls, t_on_ls, t_off_ls = [], [], [], [], []
obs_stats_ls, m_ls, s_ls = [], [], []
posterior_ls, res_ls, halloffame_ls = [], [], []

for cell_num, (ephys_cell, sweep_number, A_soma) in enumerate(list_cells_AllenDB):
    obs = utils.allen_obs_data(ephys_cell=ephys_cell, sweep_number=sweep_number, A_soma=A_soma)
    obs['data'] = obs['data'] + junction_potential
    I, dt, t_on, t_off = obs['I'], obs['dt'], obs['t_on'], obs['t_off']

    obs_ls.append(obs)
    I_ls.append(I)
    dt_ls.append(dt)
    t_on_ls.append(t_on)
    t_off_ls.append(t_off)

    obs_stats = utils.allen_obs_stats(data=obs, ephys_cell=ephys_cell, sweep_number=sweep_number,
                                      n_xcorr=n_xcorr, n_mom=n_mom,
                                      summary_stats=summary_stats, n_summary=n_summary)
    obs_stats_ls.append(obs_stats[0])

    # simulator driven by this recording's stimulus, started at its first voltage sample
    m = HodgkinHuxley(I=I, dt=dt, V0=obs['data'][0], seed=seed, cython=cython, prior_log=prior_log)
    m_ls.append(m)

    # (alternative kept by the author: HodgkinHuxleyStatsSpikes_mf summaries)
    s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr, n_mom=n_mom,
                                  n_summary=n_summary)
    s_ls.append(s)

    # SNPE results (the '_statspikes' / '_snpe_rej_res.pkl' variant was the alternative)
    filename1 = ('../hodgkinhuxley/results/allen_' + str(ephys_cell) + '_' + str(sweep_number)
                 + '_run_1_round2_prior0013_param8' + svi_flag + '_ncomp' + str(n_components)
                 + '_nsims' + str(n_sims*n_rounds) + '_snpe_res.pkl')
    res = io.load(filename1)
    res_ls.append(res)

    # posterior evaluated at this recording's summary statistics
    posterior = res.predict(obs_stats)
    posterior_ls.append(posterior)

    # IBEA results: only the hall of fame is kept
    _, halloffame, _, _ = io.load_pkl('../hodgkinhuxley/results/allen_' + str(ephys_cell) + '_'
                                      + str(sweep_number) + '_run_1_offspr' + str(offspring_size)
                                      + '_max_gen' + str(max_ngen) + '_param8_' + algo + '.pkl')
    halloffame_ls.append(halloffame)

### panel E

In [None]:
# Panel E: for a subset of Allen recordings, the recorded trace (top row) and a
# trace simulated at that recording's posterior "mode" (bottom row; mean of the
# largest-weight mixture component), one column per recording.

# recordings (indices into the Allen lists) shown as columns
cell_ls = [0,1,2]
# number of cell recordings plotted; defines the number of subplot columns
num_rec = len(cell_ls)
# the legends are attached to the middle column
legend_col = num_rec // 2

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(1.1*ROW_4_WIDTH_COL_1_MM), SCALE_IN*mm2inches(H_FACT*ROW_4_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    for i, cell_num in enumerate(cell_ls):
        y_obs = obs_ls[cell_num]['data']
        t = obs_ls[cell_num]['time']
        duration = np.max(t)

        mn_post = posterior_ls[cell_num].xs[np.argmax(posterior_ls[cell_num].a)].m
        x_post = m_ls[cell_num].gen_single(mn_post)

        # top row: recording
        plt.subplot(2, num_rec, i+1)
        plt.plot(t, y_obs, color = COL['GT'], lw=2, label='Allen Cell Types Database')
        ax = plt.gca()
        ax.set_xticks([])
        # only the first column carries y ticks
        ax.set_yticks([-80, -20, 40] if i == 0 else [])
        if i == legend_col:
            plt.legend(bbox_to_anchor=(0.52, 1.4), loc='upper center')

        # bottom row: simulation at the posterior mode
        plt.subplot(2, num_rec, num_rec+i+1)
        plt.plot(t, x_post['data'], color = COL['SNPE'], lw=2, label='mode')
        plt.xlabel('time (ms)')

        ax = plt.gca()
        ax.set_xticks([0, duration/2, duration])
        if i == 0:
            # one y label shared by both rows, shifted up between them
            plt.ylabel('voltage (mV)',x=0, y=1.2)
            ax.set_yticks([-80, -20, 40])
        else:
            ax.set_yticks([])

        if i == legend_col:
            plt.legend(bbox_to_anchor=(0.35, 1.35), loc='upper center')

    if save_fig:
        plt.savefig(PANEL_E, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_E)
    else:
        plt.show()

### panel F

In [None]:
# Panel F1: a stand-alone horizontal Blues colorbar labelled "recording", with
# no ticks or outline, saved separately and placed next to panel F in the
# composed figure.
# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(0.5*W_FACT*ROW_4_WIDTH_COL_2_MM), SCALE_IN*mm2inches(H_FACT*ROW_4_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    ax1 = fig.add_axes([0.05, 0.80, 0.6, 0.1])
    # [left, bottom, width, height]

    cmap = mpl.cm.Blues
    # NOTE(review): vmax=10 matches the 10 Allen recordings (n_post) loaded above;
    # ticks are hidden below, so the range is not shown on the bar
    norm = mpl.colors.Normalize(vmin=1, vmax=10)

    cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap,norm=norm,orientation='horizontal')
    # label placed to the right of the bar (x=1.25, pulled up with a negative labelpad)
    cb1.set_label('recording', x=1.25, labelpad=-13)
    cb1.outline.set_visible(False)
    cb1.set_ticks([])

    if save_fig:
        plt.savefig(PANEL_F1, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_F1)
    else:
        plt.show()

In [None]:
# Panel F: the posteriors of all Allen recordings overlaid with plot_multipdf,
# one Blues shade per recording.
# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_4_WIDTH_COL_2_MM), SCALE_IN*mm2inches(2*H_FACT*ROW_4_HEIGHT_MM))

# n_post shades cm1(i/(n_post+1)) for i = 1..n_post (the value 0 is skipped via col_min)
col_min = 1
num_colors = n_post+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(1.*i/num_colors) for i in range(col_min,num_colors)]

# partial_ls = [5,6,7]

# Plot limits are taken from the prior of recording 0 and used for all recordings.
# NOTE(review): assumes every recording's SNPE run used the same prior.
cell_num = 0
prior_min = res_ls[cell_num].generator.prior.lower
prior_max = res_ls[cell_num].generator.prior.upper

# (n_params, 2) array of [lower, upper] bounds
prior_lims = np.concatenate((prior_min.reshape(-1,1),
                             prior_max.reshape(-1,1)),
                            axis=1)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    plot_multipdf.plot_multipdf(posterior_ls, lims=prior_lims,
                 labels_params=LABELS_HH, partial=False,
                 figsize=fig_inches,fontscale=0.5,ticks = True,colrs=col1,imageshow=False)

    if save_fig:
        plt.savefig(PANEL_F, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_F)
    else:
        plt.show()

### panel G

In [None]:
# Panel G: for each parameter, the standard deviation across recordings of the
# posterior "mode" (mean of the largest-weight mixture component), divided by
# the prior standard deviation. Parameters are sorted by that ratio.
# The prior-normalised covariances collected here are also used by panels H and I.

mn_post_ls = []
cov_post_ls = []
for cell_num in range(n_post):
    # largest-weight mixture component of this recording's posterior
    comp = posterior_ls[cell_num].xs[np.argmax(posterior_ls[cell_num].a)]
    mn_post_ls.append(comp.m)

    # Divide the covariance by the prior standard deviations:
    # cov_ij / (sigma_i * sigma_j), i.e. D^-1 S D^-1 with D = diag(sigma).
    prior_std = res_ls[cell_num].generator.prior.std
    cov_post_ls.append(comp.S / np.outer(prior_std, prior_std))

mn_post_mat = np.asarray(mn_post_ls)   # (n_post, n_params)
cov_post_mat = np.asarray(cov_post_ls)

# Normalise by the prior of the last recording. Previously this was done through
# the loop variable leaking out of the loop above; it is now explicit.
# NOTE(review): assumes all recordings share the same prior.
std_mn_post_mat = np.std(mn_post_mat,axis=0)/res_ls[-1].generator.prior.std

arg_sort_std = np.argsort(std_mn_post_mat)
LABELS_HH1 = np.array(LABELS_HH)
n_params = len(LABELS_HH1)
param_pos = np.arange(1, n_params+1)

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_5_WIDTH_COL_1_MM), SCALE_IN*mm2inches(H_FACT*ROW_5_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    plt.plot(param_pos, std_mn_post_mat[arg_sort_std])
    ax = plt.gca()
    ax.set_xticks(param_pos)
    ax.set_xticklabels(LABELS_HH1[arg_sort_std])
    plt.locator_params(axis='y', nbins=3)
    plt.ylabel(r'$\frac{\sigma_{\bar{\theta}}}{\sigma_{PRIOR}}$')

    if save_fig:
        plt.savefig(PANEL_G, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_G)
    else:
        plt.show()

### panel H

In [None]:
# Panel H: eigen-decomposition of the prior-normalised posterior covariances
# (cov_post_ls, from the panel G cell), plotted as the cumulative percentage of
# variance explained by the leading eigen-directions, one line per recording.
# eig_vec_mat and num_eig are reused by panel I.

# The analysis treats each posterior as a single Gaussian, so it is only valid
# when every SNPE network has exactly one mixture component.
n_components_crit = int(all(res_ls[i].network.n_components <= 1 for i in range(n_post)))
if not n_components_crit:
    # Stop here: everything below, and panel I, needs eig_val_mat / eig_vec_mat.
    raise ValueError('some posteriors have more than 1 component, so this analysis is invalid')

eig_val_ls = []
eig_vec_ls = []
for cell_num in range(n_post):
    # The normalised covariance is symmetric, so use eigh (real eigenvalues,
    # orthonormal eigenvectors) instead of the general-purpose eig.
    w, v = np.linalg.eigh(cov_post_ls[cell_num])
    arg_sort = np.argsort(w)[::-1]  # largest eigenvalue first
    eig_val_ls.append(w[arg_sort])
    eig_vec_ls.append(v[:,arg_sort])  # columns are eigenvectors
eig_val_mat = np.asarray(eig_val_ls)  # (n_post, num_eig)
eig_vec_mat = np.asarray(eig_vec_ls)  # (n_post, n_params, num_eig)

num_eig = eig_val_mat.shape[1]

col_min = 1
num_colors = n_post+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(1.*i/num_colors) for i in range(col_min,num_colors)]

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_5_WIDTH_COL_2_MM), SCALE_IN*mm2inches(H_FACT*ROW_5_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    components = np.linspace(1,num_eig,num_eig)
    for i in range(n_post):
        # cumulative fraction of the total variance, in percent to match the y label
        frac_explained = np.cumsum(eig_val_mat[i]/np.sum(eig_val_mat[i]))
        plt.plot(components, 100*frac_explained,
                 color=col1[i],lw=1,marker='o',label=str(i+1))
    plt.xlabel('component')
    plt.ylabel('variance explained (%)')
    plt.locator_params(axis='y', nbins=3)

    if save_fig:
        plt.savefig(PANEL_H, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_H)
    else:
        plt.show()

### panel I

In [None]:
# Panel I: absolute loadings of the first eigenvector (largest eigenvalue, from
# the panel H cell) for every recording. Rows are recordings; columns are
# parameters, sorted by their summed |loading| across recordings. Log colour scale.

from matplotlib.colors import LogNorm

eig_number = 0
# |eigenvector eig_number| per recording: (n_post, n_params)
abs_eig_vec = np.abs(eig_vec_mat[:,:,eig_number])
sum_eig1 = np.sum(abs_eig_vec,axis=0)
arg_sort_sum_eig1 = np.argsort(sum_eig1)
LABELS_HH1 = np.array(LABELS_HH)

vmin = np.min(abs_eig_vec)
vmax = np.max(abs_eig_vec)

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_5_WIDTH_COL_3_MM), SCALE_IN*mm2inches(H_FACT*ROW_5_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    # The extent reaches half a cell beyond the first/last index, so column j
    # is centred on tick j (1..num_eig) and rows on 1..n_post. An extent of
    # [1, num_eig] would squeeze num_eig columns into num_eig-1 units and
    # shift the parameter labels off their columns.
    plt.imshow(abs_eig_vec[:,arg_sort_sum_eig1],
               extent=[0.5,num_eig+0.5,0.5,n_post+0.5],
               norm=LogNorm(vmin=vmin, vmax=vmax), aspect='auto')
    cb1 = plt.colorbar()
    cb1.outline.set_visible(False)
    ax = plt.gca()
    ax.set_xticks(np.linspace(1,num_eig,num_eig))
    ax.set_xticklabels(LABELS_HH1[arg_sort_sum_eig1])
    ax.set_yticks([])
    plt.ylabel('recording')
    plt.title('|eigenvector 1|',fontsize=12)

    if save_fig:
        plt.savefig(PANEL_I, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_I)
    else:
        plt.show()

### compose figure

In [None]:
# Compose the final figure: place each saved panel SVG at millimetre offsets
# derived from the grid constants, add the panel letters, save the result as
# fig_hh.svg and convert it to PDF with Inkscape. create_fig, add_svg,
# add_label, svg and $INKSCAPE come from common.ipynb (not visible here).
fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)

# global shift (mm) applied to every panel position
xoffset = 0.
yoffset = 2.

# row 1: panels A and B
fig = add_svg(fig, PANEL_A, 0 + xoffset, 0 + yoffset, scale=0.75)
fig = add_svg(fig, PANEL_B, ROW_1_WIDTH_COL_1_MM + xoffset, 0 + yoffset, scale=.98/W_FACT)

# row 2: panel C
fig = add_svg(fig, PANEL_C, 0 + xoffset, ROW_1_HEIGHT_MM + yoffset)

# row 3: panel D (four panels side by side, spaced by the W_FACT-shrunk column width)
fig = add_svg(fig, PANEL_D1, 0 + xoffset, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_D2, W_FACT*ROW_3_WIDTH_COL_1_MM + xoffset, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_D3, 2*W_FACT*ROW_3_WIDTH_COL_1_MM + xoffset, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_D4, 3*W_FACT*ROW_3_WIDTH_COL_1_MM + xoffset, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)

# row 4: panels E and F (F1 is the stand-alone colorbar for F)
yoffset_F = 2
fig = add_svg(fig, PANEL_E, 0 + xoffset , ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + yoffset, scale=.9)
fig = add_svg(fig, PANEL_F, ROW_4_WIDTH_COL_1_MM + xoffset,
              ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + yoffset + yoffset_F, scale=.98/W_FACT)
fig = add_svg(fig, PANEL_F1, 1.27*ROW_4_WIDTH_COL_1_MM + xoffset,
              ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + yoffset - 0.4, scale=.8)

# row 5: panels G, H and I
fig = add_svg(fig, PANEL_G, 0 + xoffset,
              ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + ROW_4_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_H, ROW_5_WIDTH_COL_1_MM + xoffset,
              ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + ROW_4_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_I, 2*ROW_5_WIDTH_COL_1_MM + xoffset,
              ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + ROW_4_HEIGHT_MM + yoffset)

###########
# panel letters: placed at each row/column origin with their own vertical shift
yoffset = 2.3

# row 1: panels A and B
fig = add_label(fig, 'A', 0, 0 + yoffset)
fig = add_label(fig, 'B', ROW_1_WIDTH_COL_1_MM, 0 + yoffset)

# row 2: panel C
fig = add_label(fig, 'C', 0 , ROW_1_HEIGHT_MM + yoffset)

# row 3: panel D
fig = add_label(fig, 'D', 0, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)

# row 4: panels E and F
fig = add_label(fig, 'E', 0, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + yoffset)
fig = add_label(fig, 'F', ROW_4_WIDTH_COL_1_MM, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + yoffset)

# row 5: panels G, H and I
fig = add_label(fig, 'G', 0, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + ROW_4_HEIGHT_MM + yoffset)
fig = add_label(fig, 'H', ROW_5_WIDTH_COL_1_MM,
                ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + ROW_4_HEIGHT_MM + yoffset)
fig = add_label(fig, 'I', 2*ROW_5_WIDTH_COL_1_MM,
                ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + ROW_4_HEIGHT_MM + yoffset)

###########
PATH_SVG = PATH_DROPBOX_FIGS + 'fig_hh.svg'
fig.save(PATH_SVG)


svg(PATH_SVG)

# IPython shell escape: convert the composed SVG to PDF.
# NOTE(review): --export-pdf is the Inkscape 0.92 CLI; Inkscape 1.x replaced it
# with --export-filename=<file>.pdf, so confirm which version $INKSCAPE points to.
!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/fig_hh.pdf $PATH_SVG