## Figure on inference in HH model on simulated and Allen data

In [None]:
%run -i common.ipynb

import lfimodels.hodgkinhuxley.utils as utils

from lfimodels.hodgkinhuxley.HodgkinHuxley import HodgkinHuxley
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsSpikes_mf import HodgkinHuxleyStatsSpikes_mf

# ---------------------------------------------------------------------------
# Figure geometry. Every *_MM name below is later passed through mm2inches()
# (panel cells) or to create_fig/add_svg/add_label (compose cell).
# ---------------------------------------------------------------------------
FIG_WIDTH_MM = FIG_WIDTH_MM  # set in common notebook to a default value for all figures
FIG_HEIGHT_MM = 90
FIG_N_ROWS = 3

# The three rows split the figure height equally.
ROW_1_HEIGHT_MM = ROW_2_HEIGHT_MM = ROW_3_HEIGHT_MM = FIG_HEIGHT_MM / FIG_N_ROWS

# Row 1 (panels A-C): three columns with relative widths 1 : 0.8 : 1.2.
ROW_1_NCOLS = 3
ROW_1_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_1_NCOLS
ROW_1_WIDTH_COL_2_MM = 0.8*FIG_WIDTH_MM / ROW_1_NCOLS
ROW_1_WIDTH_COL_3_MM = 1.2*FIG_WIDTH_MM / ROW_1_NCOLS

# Row 2 (panels D-F): same column scheme as row 1.
ROW_2_NCOLS = 3
ROW_2_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_2_MM = 0.8*FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_3_MM = 1.2*FIG_WIDTH_MM / ROW_2_NCOLS

# Row 3 (panels G-H): two equally wide columns.
ROW_3_NCOLS = 2
ROW_3_WIDTH_COL_1_MM = ROW_3_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_3_NCOLS

# Each panel is drawn at this fraction of its grid cell's width / height.
W_FACT = 0.85
H_FACT = 0.85

# ---------------------------------------------------------------------------
# Output locations: one SVG file per panel, all in the same folder.
# ---------------------------------------------------------------------------
PATH_DROPBOX_FIGS = PATH_DROPBOX + 'figs/'

PANEL_A = PATH_DROPBOX_FIGS + 'fig_hh_a.svg'
PANEL_B = PATH_DROPBOX_FIGS + 'fig_hh_b.svg'
PANEL_C = PATH_DROPBOX_FIGS + 'fig_hh_c.svg'
PANEL_D = PATH_DROPBOX_FIGS + 'fig_hh_d.svg'
PANEL_E = PATH_DROPBOX_FIGS + 'fig_hh_e.svg'
PANEL_F = PATH_DROPBOX_FIGS + 'fig_hh_f.svg'
PANEL_G = PATH_DROPBOX_FIGS + 'fig_hh_g.svg'
PANEL_H = PATH_DROPBOX_FIGS + 'fig_hh_h.svg'

## save figures or not


In [None]:
# Toggle read by every panel cell below: when truthy, the panel is written to
# its SVG path, the matplotlib figure is closed and the file is re-displayed
# via svg(); otherwise the figure is only shown with plt.show().
save_fig = 0 # 1: save figures

## panels A-C

### load data

In [None]:
# --- ground-truth parameters and general settings ---------------------------
true_params, labels_params = utils.obs_params()

seed = 1
cython = True
prior_uniform, prior_log, prior_extent = True, False, True
n_xcorr, n_mom = 0, 4
summary_stats = 3

# --- stimulus, soma area and the simulated observation ----------------------
I, t_on, t_off, dt = utils.syn_current()
A_soma = np.pi*((70.*1e-4)**2)  # cm2

obs = utils.syn_obs_data(I, dt, true_params, seed=seed, cython=cython)
y_obs, t = obs['data'], obs['time']
duration = np.max(t)

# --- simulator (V0 is the first sample of the observed trace) and prior -----
m = HodgkinHuxley(I, dt, V0=obs['data'][0], seed=seed, cython=cython,
                  prior_log=prior_log)

p = utils.prior(true_params=true_params, prior_uniform=prior_uniform,
                prior_extent=prior_extent, prior_log=prior_log, seed=seed)

# --- one posterior is loaded per number of summary features -----------------
n_summary_ls = [1, 4, 8]
n_post = len(n_summary_ls)

# --- SNPE settings; they only enter through the name of the results file ----
n_components = 2
n_sims = 50000
n_rounds = 2
svi = False
svi_flag = '_svi' if svi else '_nosvi'

s_ls, posterior_ls, res_ls = [], [], []
for nsum in n_summary_ls:
    # summary-statistics object restricted to the first `nsum` features
    s = HodgkinHuxleyStatsSpikes_mf(t_on=t_on, t_off=t_off, n_summary=nsum)
    s_ls.append(s)

    # stored SNPE run for this number of features
    filename1 = ''.join([
        '../hodgkinhuxley/results/sim_run_2_round2_prior0013_param8_statspikes_nsum',
        str(nsum),
        svi_flag,
        '_ncomp', str(n_components),
        '_nsims', str(n_sims*n_rounds),
        '_snpe_rej_res.pkl',
    ])
    res = io.load(filename1)
    res_ls.append(res)

    # posterior evaluated at the observation stored with the run
    posterior = res.predict(res.obs)
    posterior_ls.append(posterior)

### panel A

In [None]:
# Panel A: simulated voltage trace (top) and injected current (bottom).
# matplotlib wants the figure size in inches; SCALE_IN and mm2inches are
# defined in the common notebook.
panel_w_in = SCALE_IN*mm2inches(W_FACT*ROW_1_WIDTH_COL_1_MM)
panel_h_in = SCALE_IN*mm2inches(H_FACT*ROW_1_HEIGHT_MM)
fig_inches = (panel_w_in, panel_h_in)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    # two stacked axes with a 4:1 height ratio
    gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])

    # --- voltage trace ---
    ax = plt.subplot(gs[0])
    ax.plot(t, obs['data'], color=COL['GT'], lw=2, label='')
    ax.set_ylabel('voltage (mV)')
    ax.legend(bbox_to_anchor=(1.15, 1), loc='upper right')
    ax.set_xticks([])
    ax.set_yticks([-80, -20, 40])
    # NOTE: numbered boxes marking the summary features (1-8) on this axis
    # used to be sketched here as commented-out code; they are not drawn.

    # --- input current, scaled by the soma area and 1e3 ---
    input_trace = I*A_soma*1e3
    ax = plt.subplot(gs[1])
    ax.plot(t, input_trace, color=COL['I'], lw=2)
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('input (nA)')
    ax.set_xticks([0, duration/2, duration])
    ax.set_yticks([0, 1.1*np.max(input_trace)])
    ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.2f'))

    if not save_fig:
        plt.show()
    else:
        # write the SVG, drop the matplotlib figure, show the saved file
        plt.savefig(PANEL_A, facecolor='None', transparent=True)
        plt.close()
        svg(PANEL_A)

### panel B

In [None]:
import plot_multipdf

# Panel B: partial view of the posteriors for the different numbers of
# summary features, drawn over the prior range.

# one shade of the Blues colormap per posterior, sampled at i/num_colors
col_min = 1
num_colors = n_post+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(float(i)/num_colors) for i in range(col_min, num_colors)]

# (n_params, 2) array holding [lower, upper] of the prior for each parameter
prior_min, prior_max = p.lower, p.upper
prior_lims = np.hstack((prior_min.reshape(-1, 1), prior_max.reshape(-1, 1)))

# reversed copies; only referenced by an alternative (disabled) plotting call
posterior_ls_rev = posterior_ls[::-1]
col1_rev = col1[::-1]

# parameter indices handed to plot_multipdf as `partial_ls`
partial_ls = [5, 6, 7]

# matplotlib wants inches; SCALE_IN and mm2inches come from the common notebook
panel_w_in = SCALE_IN*mm2inches(W_FACT*ROW_1_WIDTH_COL_2_MM)
panel_h_in = SCALE_IN*mm2inches(H_FACT*ROW_1_HEIGHT_MM)
fig_inches = (panel_w_in, panel_h_in)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)
    plot_multipdf.plot_multipdf(
        posterior_ls,
        lims=prior_lims,
        labels_params=LABELS_HH,
        partial=True,
        partial_ls=partial_ls,
        figsize=fig_inches,
        fontscale=0.6,
        ticks=True,
        colrs=col1,
    )

    if not save_fig:
        plt.show()
    else:
        # write the SVG, drop the matplotlib figure, show the saved file
        plt.savefig(PANEL_B, facecolor='None', transparent=True)
        plt.close()
        svg(PANEL_B)

### supplementary figure with full posterior

In [None]:
fig_inches = (SCALE_IN*mm2inches(FIG_WIDTH_MM), SCALE_IN*mm2inches(FIG_WIDTH_MM))

SUPP = 'supp_fig_hh'  # name for appendix figure
SUPP_FIG_SVG = PATH_DROPBOX_FIGS + SUPP + '.svg'
SUPP_FIG_PDF = PATH_DROPBOX_FIGS + SUPP + '.pdf'

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

#     plot_multipdf.plot_multipdf(posterior_ls_rev, lims=prior_lims,
#                  labels_params=LABELS_HH,
#                  figsize=fig_inches,ticks = True,colrs=col1_rev);
    plot_multipdf.plot_multipdf(posterior_ls, lims=prior_lims,
                 labels_params=LABELS_HH,
                 figsize=fig_inches,ticks = True,colrs=col1);
    
    if save_fig:
        plt.savefig(SUPP_FIG_SVG, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(SUPP_FIG_SVG)
        !$INKSCAPE --export-pdf $SUPP_FIG_PDF $SUPP_FIG_SVG
    else:
        plt.show()

### panel C

In [None]:
# Simulations for panel C.
# For every posterior in posterior_ls, pick the component with the largest
# entry in `.a` and use its `.m` (presumably the mean of the highest-weight
# mixture component) as parameter set, simulate it num_rep times with `m`, and
# compute summary statistics of each simulation with the summary object of the
# LAST results entry, so all posteriors are evaluated on the same features.
#
# Outputs:
#   x_ls[i][rep]         -- simulation output for posterior i, repetition rep
#   sum_stats_ls[i][rep] -- its summary statistics

# number of simulations for same parameter set
num_rep = 100

# loop-invariant: summary-statistics calculator of the last loaded run
full_summary = res_ls[-1].generator.summary

x_ls = []
sum_stats_ls = []
for i in range(n_post):
    post_i = posterior_ls[i]
    # The weights must come from the same posterior whose components are
    # indexed; using the leftover `posterior` variable from the loading loop
    # would always select by the weights of the last posterior.
    mn_post = post_i.xs[np.argmax(post_i.a)].m

    x_ls1 = []
    sum_stats_ls1 = []
    for rep in range(num_rep):
        x = m.gen_single(mn_post)
        x_ls1.append(x)
        sum_stats_ls1.append(full_summary.calc([x])[0])

    x_ls.append(x_ls1)
    sum_stats_ls.append(sum_stats_ls1)

In [None]:
# Panel C: mean absolute deviation (+/- one standard deviation over the
# repeated simulations) between simulated and observed summary statistics,
# one curve per posterior.

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_1_WIDTH_COL_3_MM), SCALE_IN*mm2inches(H_FACT*ROW_1_HEIGHT_MM))

# one shade of the Blues colormap per posterior
col_min = 1
num_colors = n_post+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(1.*i/num_colors) for i in range(col_min,num_colors)]

# legend suffix: singular for the first entry, plural for the others
label_feature = [' feature']+[' features']*(n_post-1)

# the x axis spans the largest feature set
n_summary_stats = n_summary_ls[-1]
feature_pos = np.linspace(0,n_summary_stats-1,n_summary_stats)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    for i in range(n_post):
        # absolute error w.r.t. the observation stored with the last run,
        # computed once and reused for both mean and spread
        abs_err = np.abs(sum_stats_ls[i]-res_ls[-1].obs[0])
        mean_err = np.nanmean(abs_err,axis=0)
        std_err = np.nanstd(abs_err,axis=0)

        plt.fill_between(feature_pos,
                         mean_err-std_err,
                         mean_err+std_err,
                         facecolor=col1[i], alpha=0.3)

        plt.plot(mean_err,color = col1[i], lw=2,label=str(n_summary_ls[i])+str(label_feature[i]))

    ax = plt.gca()
    ax.set_xticks(feature_pos)
    ax.set_xticklabels(LABELS_HH_SUMSTATS)
    plt.locator_params(axis='y', nbins=3)
    # NOTE(review): the curves show |f - f*| (absolute value), while the label
    # omits the absolute value -- confirm which is intended for the paper.
    plt.ylabel(r'$f^*$ - f')

    ax.legend(bbox_to_anchor=(1.3, 1), loc='upper right')

    if save_fig:
        plt.savefig(PANEL_C, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_C)
    else:
        plt.show()

## panels D-H

### load data

In [None]:
# --- general settings (same values as for panels A-C, plus n_summary) -------
true_params, labels_params = utils.obs_params()

seed = 1
cython = True
prior_uniform, prior_log, prior_extent = True, False, True
n_xcorr, n_mom = 0, 4
n_summary = 8
summary_stats = 3

# One row per recording: [ephys_cell, sweep_number, A_soma]
list_cells_AllenDB = [[518290966,57,0.0234/126],[509881736,39,0.0153/184],[566517779,46,0.0195/198],
                      [567399060,38,0.0259/161],[569469018,44,0.033/403],[532571720,42,0.0139/127],
                      [555060623,34,0.0294/320],[534524026,29,0.027/209],[532355382,33,0.0199/230],
                      [526950199,37,0.0186/218]]

n_post = len(list_cells_AllenDB)

# --- SNPE settings; they only enter through the results file name -----------
n_components = 1
n_sims = 50000
n_rounds = 2
svi = False
svi_flag = '_svi' if svi else '_nosvi'

# --- IBEA settings; they only enter through the results file name -----------
algo = 'ibea'
offspring_size = 1000
max_ngen = 100

# --- per-cell containers, all indexed by position in list_cells_AllenDB -----
obs_ls, I_ls, dt_ls, t_on_ls, t_off_ls = [], [], [], [], []
obs_stats_ls = []
m_ls, s_ls = [], []
posterior_ls, res_ls = [], []
halloffame_ls = []

for cell_num, (ephys_cell, sweep_number, A_soma) in enumerate(list_cells_AllenDB):
    # offset added to every recorded trace (presumably mV, matching the
    # 'voltage (mV)' axis label used for these data)
    junction_potential = -14

    obs = utils.allen_obs_data(ephys_cell=ephys_cell, sweep_number=sweep_number,
                               A_soma=A_soma)
    obs['data'] = obs['data'] + junction_potential

    I, dt = obs['I'], obs['dt']
    t_on, t_off = obs['t_on'], obs['t_off']

    obs_ls.append(obs)
    I_ls.append(I)
    dt_ls.append(dt)
    t_on_ls.append(t_on)
    t_off_ls.append(t_off)

    # summary statistics of the recording; only the first row is stored
    obs_stats = utils.allen_obs_stats(data=obs, ephys_cell=ephys_cell,
                                      sweep_number=sweep_number,
                                      n_xcorr=n_xcorr, n_mom=n_mom,
                                      summary_stats=summary_stats,
                                      n_summary=n_summary)
    obs_stats_ls.append(obs_stats[0])

    # simulator driven by this cell's stimulus, started at its first sample
    m = HodgkinHuxley(I=I, dt=dt, V0=obs['data'][0], seed=seed, cython=cython,
                      prior_log=prior_log)
    m_ls.append(m)

    s = HodgkinHuxleyStatsSpikes_mf(t_on=t_on, t_off=t_off, n_summary=n_summary)
    s_ls.append(s)

    # both result files share this cell/sweep prefix
    cell_prefix = '../hodgkinhuxley/results/allen_'+str(ephys_cell)+'_'+str(sweep_number)

    # SNPE results (the '_res' file; a sibling file without that suffix was
    # referenced by earlier, disabled code and is not loaded)
    filename2 = (cell_prefix +
                 '_run_1_round2_prior0013_param8_statspikes'+svi_flag +
                 '_ncomp'+str(n_components) +
                 '_nsims'+str(n_sims*n_rounds)+'_snpe_rej_res.pkl')
    res = io.load(filename2)
    res_ls.append(res)

    posterior = res.predict(obs_stats)
    posterior_ls.append(posterior)

    # IBEA results: only the second element of the stored tuple is kept
    _, halloffame, _, _ = io.load_pkl(cell_prefix +
                                      '_run_1_offspr'+str(offspring_size) +
                                      '_max_gen'+str(max_ngen) +
                                      '_param8_statspikes_' + algo + '.pkl')
    halloffame_ls.append(halloffame)

### panel D

In [None]:
# Panel D: recorded voltage trace (top) and stimulus (bottom) for one cell.

# index into list_cells_AllenDB and the per-cell lists
cell_num = 0
y_obs = obs_ls[cell_num]['data']
t = obs_ls[cell_num]['time']
duration = np.max(t)
I = I_ls[cell_num]
A_soma = list_cells_AllenDB[cell_num][2]

# matplotlib wants inches; SCALE_IN and mm2inches come from the common notebook
panel_w_in = SCALE_IN*mm2inches(W_FACT*ROW_2_WIDTH_COL_1_MM)
panel_h_in = SCALE_IN*mm2inches(H_FACT*ROW_2_HEIGHT_MM)
fig_inches = (panel_w_in, panel_h_in)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    # two stacked axes with a 4:1 height ratio
    gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])

    # --- voltage trace ---
    ax = plt.subplot(gs[0])
    ax.plot(t, y_obs, color=COL['GT'], lw=2, label='')
    ax.set_ylabel('voltage (mV)')
    ax.legend(bbox_to_anchor=(1.15, 1), loc='upper right')
    ax.set_xticks([])
    ax.set_yticks([-80, -20, 40])

    # --- input current, scaled by the soma area and 1e3 ---
    input_trace = I*A_soma*1e3
    ax = plt.subplot(gs[1])
    ax.plot(t, input_trace, color=COL['I'], lw=2)
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('input (nA)')
    ax.set_xticks([0, duration/2, duration])
    ax.set_yticks([0, 1.1*np.max(input_trace)])
    ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.2f'))

    if not save_fig:
        plt.show()
    else:
        # write the SVG, drop the matplotlib figure, show the saved file
        plt.savefig(PANEL_D, facecolor='None', transparent=True)
        plt.close()
        svg(PANEL_D)

### panel E

In [None]:
# Panel E: partial view of the SNPE posterior for one cell, with the IBEA
# hall-of-fame entries passed to the plotting routine as `gt` points.

# index into the per-cell lists
cell_num = 0

# (n_params, 2) array holding [lower, upper] of the prior stored with the run
cell_prior = res_ls[cell_num].generator.prior
prior_min, prior_max = cell_prior.lower, cell_prior.upper
prior_lims = np.hstack((prior_min.reshape(-1, 1), prior_max.reshape(-1, 1)))

# hall-of-fame entries, transposed so that each row is one parameter
ibea_samp = np.array(halloffame_ls[cell_num][:]*1).T

# matplotlib wants inches; SCALE_IN and mm2inches come from the common notebook
panel_w_in = SCALE_IN*mm2inches(W_FACT*ROW_2_WIDTH_COL_2_MM)
panel_h_in = SCALE_IN*mm2inches(H_FACT*ROW_2_HEIGHT_MM)
fig_inches = (panel_w_in, panel_h_in)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    plot_multipdf.plot_pdf_multipts(
        posterior_ls[cell_num],
        lims=prior_lims,
        figsize=fig_inches,
        partial=True,
        gt=ibea_samp,
        labels_params=LABELS_HH,
        fontscale=0.6,
        ticks=True,
    )

    if not save_fig:
        plt.show()
    else:
        # write the SVG, drop the matplotlib figure, show the saved file
        plt.savefig(PANEL_E, facecolor='None', transparent=True)
        plt.close()
        svg(PANEL_E)

In [None]:
# Inspection only: as the last expression of the cell, the notebook displays
# the repr of the first posterior object; nothing is assigned or plotted.
posterior_ls[0]

### supplementary figure: comparison SNPE with IBEA (Allen Cell Type Database)

In [None]:
# cell id
cell_num = 0

fig_inches = (SCALE_IN*mm2inches(FIG_WIDTH_MM), SCALE_IN*mm2inches(FIG_WIDTH_MM))
    
SUPP = 'supp_fig_hh_allen_ibea_cell'+str(cell_num)  # name for appendix figure
SUPP_FIG_SVG = PATH_DROPBOX_FIGS + SUPP + '.svg'
SUPP_FIG_PDF = PATH_DROPBOX_FIGS + SUPP + '.pdf'

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)
    plot_multipdf.plot_pdf_multipts(posterior_ls[cell_num], lims=prior_lims,
                                    figsize=fig_inches,
                                    gt=ibea_samp, labels_params=LABELS_HH,
                                    fontscale=0.6,ticks=True)        

    if save_fig:
        plt.savefig(SUPP_FIG_SVG, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()

        svg(SUPP_FIG_SVG)
        !$INKSCAPE --export-pdf $SUPP_FIG_PDF $SUPP_FIG_SVG
    else:
        plt.show()

### panel F

In [None]:
# Simulations for panels F-H. For every cell three parameter sources are
# simulated num_rep times each and summarised with that cell's statistics:
#   * SNPE mode    -- `.m` of the posterior component with the largest `.a`
#   * SNPE samples -- posterior samples that fall inside the prior box
#   * IBEA         -- first hall-of-fame entry
#
# Outputs (all indexed [cell][rep] unless noted):
#   mn_post_ls[cell], samp_post_ls[cell] (array with num_rep rows),
#   x_snpe_ls, x_snpe_samp_ls, x_ibea_ls,
#   sum_stats_snpe_ls, sum_stats_snpe_samp_ls, sum_stats_ibea_ls

# number of simulations for same parameter set
num_rep = 100

# number of posterior draws generated before rejecting those outside the prior
num_rep1 = 10000*num_rep

mn_post_ls = []
x_snpe_ls = []
samp_post_ls = []
x_snpe_samp_ls = []
x_ibea_ls = []
sum_stats_snpe_ls = []
sum_stats_snpe_samp_ls = []
sum_stats_ibea_ls = []
for cell_num in range(n_post):
    cell_post = posterior_ls[cell_num]
    cell_model = m_ls[cell_num]
    cell_stats = s_ls[cell_num]
    best_ibea = halloffame_ls[cell_num][0]

    mn_post = cell_post.xs[np.argmax(cell_post.a)].m
    mn_post_ls.append(mn_post)

    # Reject samples outside the prior box. The bounds are read from this
    # cell's own results object rather than from prior_min/prior_max, which
    # are whatever the panel-E cell last computed (for its single cell_num).
    cell_prior = res_ls[cell_num].generator.prior
    samp_post = cell_post.gen(num_rep1)
    inside = np.all((samp_post > cell_prior.lower) & (samp_post < cell_prior.upper), axis=1)
    samp_post = samp_post[inside]

    # Fail here with a clear message instead of with an IndexError inside the
    # simulation loop below.
    if len(samp_post) < num_rep:
        raise ValueError(
            'only {} of {} posterior samples for cell index {} lie inside the '
            'prior box; {} are needed'.format(len(samp_post), num_rep1, cell_num, num_rep))

    samp_post_ls.append(samp_post[0:num_rep])

    x_snpe_ls1 = []
    x_snpe_samp_ls1 = []
    x_ibea_ls1 = []
    sum_stats_snpe_ls1 = []
    sum_stats_snpe_samp_ls1 = []
    sum_stats_ibea_ls1 = []
    # The three simulations stay interleaved per repetition so the order of
    # calls into the simulator is the same as before.
    for rep in range(num_rep):
        x_snpe = cell_model.gen_single(mn_post)
        x_snpe_ls1.append(x_snpe)
        sum_stats_snpe_ls1.append(cell_stats.calc([x_snpe])[0])

        x_snpe_samp = cell_model.gen_single(samp_post_ls[cell_num][rep,:])
        x_snpe_samp_ls1.append(x_snpe_samp)
        sum_stats_snpe_samp_ls1.append(cell_stats.calc([x_snpe_samp])[0])

        x_ibea = cell_model.gen_single(best_ibea)
        x_ibea_ls1.append(x_ibea)
        sum_stats_ibea_ls1.append(cell_stats.calc([x_ibea])[0])

    x_snpe_ls.append(x_snpe_ls1)
    sum_stats_snpe_ls.append(sum_stats_snpe_ls1)
    x_snpe_samp_ls.append(x_snpe_samp_ls1)
    sum_stats_snpe_samp_ls.append(sum_stats_snpe_samp_ls1)
    x_ibea_ls.append(x_ibea_ls1)
    sum_stats_ibea_ls.append(sum_stats_ibea_ls1)

In [None]:
# Panel F: for one cell, deviation of simulated from observed summary
# statistics, divided per feature by the standard deviation of the
# posterior-sample simulations. Median +/- standard deviation over the
# repetitions is drawn for the SNPE mode, the SNPE samples and IBEA.

# cell id
cell_num = 0

sum_stats_snpe_mat = np.asarray(sum_stats_snpe_ls[cell_num])
sum_stats_ibea_mat = np.asarray(sum_stats_ibea_ls[cell_num])
sum_stats_snpe_samp_mat = np.asarray(sum_stats_snpe_samp_ls[cell_num])

######################################
# feature positions 1..n_summary_stats on the x axis
n_summary_stats = len(obs_stats_ls[cell_num])
xx = np.linspace(1,n_summary_stats,n_summary_stats).astype(int)

######################################
# deviations from the observed statistics, and the normalising spread
yy_samp = sum_stats_snpe_samp_mat-obs_stats_ls[cell_num]
yy_snpe = sum_stats_snpe_mat-obs_stats_ls[cell_num]
yy_ibea = sum_stats_ibea_mat-obs_stats_ls[cell_num]
yy_samp_err = np.nanstd(yy_samp,axis=0)

######################################
# (legend label, colour, normalised deviations), in drawing/legend order;
# each ratio is computed once and reused for median and spread
curves = [
    ('SNPE', COL['SNPE'], yy_snpe/yy_samp_err),
    ('samples', COL['SAMPLES'], yy_samp/yy_samp_err),
    ('IBEA', COL['IBEA'], yy_ibea/yy_samp_err),
]

######################################
# Features are shown in their original order. A sort by the SNPE median used
# to be computed here but was overwritten by this identity ordering before
# use, so only the identity ordering is kept.
median_ord = xx - 1
LABELS_HH_SUMSTATS_reord = [ LABELS_HH_SUMSTATS[i] for i in median_ord]

## matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_2_WIDTH_COL_3_MM), SCALE_IN*mm2inches(H_FACT*ROW_2_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    for curve_label, curve_col, ratio in curves:
        ratio_median = np.nanmedian(ratio,axis=0)[median_ord]
        ratio_std = np.nanstd(ratio,axis=0)[median_ord]
        plt.fill_between(xx, ratio_median-ratio_std, ratio_median+ratio_std,
                         facecolor=curve_col, alpha=0.3)
        plt.plot(xx, ratio_median, color=curve_col, lw=2, label=curve_label)

    plt.ylabel(r'$\frac{f^* - f}{\sigma_{SNPE}}$')

    # reference lines at +1 and -1
    plt.plot(xx,np.ones(n_summary_stats),'--k')
    plt.plot(xx,-np.ones(n_summary_stats),'--k')

    ax = plt.gca()
    ax.set_xticks(xx)
    ax.set_xticklabels(LABELS_HH_SUMSTATS_reord)
    plt.locator_params(axis='y', nbins=3)
    ax.legend(bbox_to_anchor=(1.32, 1), loc='upper right')

    if save_fig:
        plt.savefig(PANEL_F, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_F)
    else:
        plt.show()

## panels G, H

### panel G

In [None]:
# Panel G: per cell and parameter, difference between the SNPE estimate and
# the first IBEA hall-of-fame entry, divided by the square root of the
# diagonal of `.S` of the selected posterior component.

# diagonal of `.S` of the component with the largest `.a`, one row per cell
var_post_ls = [
    np.diag(posterior_ls[k].xs[np.argmax(posterior_ls[k].a)].S)
    for k in range(n_post)
]
var_post_mat = np.asarray(var_post_ls)

# SNPE estimates collected by the panel-F simulation cell
mn_post_snpe_mat = np.asarray(mn_post_ls)

# first hall-of-fame entry of every cell
best_ibea_ls = [halloffame_ls[k][0] for k in range(n_post)]
best_ibea_mat = np.asarray(best_ibea_ls)

# x positions 0 .. num_params-1
num_params = len(true_params)
xx_params = np.linspace(0, num_params-1, num_params)

# one shade of the Blues colormap per cell
col_min = 1
num_colors = n_post+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(float(i)/num_colors) for i in range(col_min, num_colors)]

# matplotlib wants inches; SCALE_IN and mm2inches come from the common notebook
panel_w_in = SCALE_IN*mm2inches(W_FACT*ROW_3_WIDTH_COL_1_MM)
panel_h_in = SCALE_IN*mm2inches(H_FACT*ROW_3_HEIGHT_MM)
fig_inches = (panel_w_in, panel_h_in)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    for cell_num in range(n_post):
        param_diff = mn_post_snpe_mat[cell_num, :]-best_ibea_mat[cell_num, :]
        yy = param_diff/np.sqrt(var_post_mat[cell_num, :])
        plt.scatter(xx_params, yy, color=col1[cell_num])

    plt.ylabel(r'$\frac{\theta_{SNPE} - \theta_{IBEA}}{\sigma_{SNPE}}$')

    # reference lines at +1 and -1
    for level in (np.ones(num_params), -np.ones(num_params)):
        plt.plot(xx_params, level, '--k')

    ax = plt.gca()
    ax.set_xticks(xx_params)
    ax.set_xticklabels(LABELS_HH)

    if not save_fig:
        plt.show()
    else:
        # write the SVG, drop the matplotlib figure, show the saved file
        plt.savefig(PANEL_G, facecolor='None', transparent=True)
        plt.close()
        svg(PANEL_G)

### panel H

In [None]:
# Panel H: for every cell, median (over repetitions) of the deviation between
# simulated and observed summary statistics, divided per feature by the
# standard deviation of the posterior-sample simulations; SNPE vs IBEA.

# indexed [cell, repetition, feature]
sum_stats_snpe_mat = np.asarray(sum_stats_snpe_ls)
sum_stats_snpe_samp_mat = np.asarray(sum_stats_snpe_samp_ls)
sum_stats_ibea_mat = np.asarray(sum_stats_ibea_ls)

# Feature positions are derived here rather than taken from whatever the
# panel-F cell last left in `xx` / `n_summary_stats`.
n_summary_stats = len(obs_stats_ls[0])
xx = np.linspace(1,n_summary_stats,n_summary_stats).astype(int)

## matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_3_WIDTH_COL_2_MM), SCALE_IN*mm2inches(H_FACT*ROW_3_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)
    ax = plt.gca()

    for cell_num in range(n_post):
        # The observed statistics broadcast over the repetition axis, so no
        # tiled copy (and no dependence on num_rep) is needed.
        obs_stats_cell = obs_stats_ls[cell_num]
        yy_samp_err = np.nanstd(sum_stats_snpe_samp_mat[cell_num,:,:],axis=0)
        yy_snpe = sum_stats_snpe_mat[cell_num,:,:]-obs_stats_cell
        yy_ibea = sum_stats_ibea_mat[cell_num,:,:]-obs_stats_cell

        ratio_snpe_median = np.nanmedian(yy_snpe/yy_samp_err,axis=0)
        ratio_ibea_median = np.nanmedian(yy_ibea/yy_samp_err,axis=0)

        # Only the first cell's points carry legend labels, so the legend
        # shows exactly one 'SNPE' and one 'IBEA' entry.
        snpe_label = 'SNPE' if cell_num == 0 else '_nolegend_'
        ibea_label = 'IBEA' if cell_num == 0 else '_nolegend_'
        ax.scatter(xx, ratio_snpe_median,color=COL['SNPE'], label=snpe_label)
        ax.scatter(xx, ratio_ibea_median,color=COL['IBEA'], label=ibea_label)

    # Axis decoration and the +1 / -1 reference lines do not depend on the
    # cell, so they are drawn once instead of once per loop iteration.
    ax.plot(xx,np.ones(n_summary_stats),'--k')
    ax.plot(xx,-np.ones(n_summary_stats),'--k')
    ax.set_ylabel(r'$\frac{f^* - f}{\sigma_{SNPE}}$')
    ax.set_xticks(xx)
    ax.set_xticklabels(LABELS_HH_SUMSTATS)
    plt.locator_params(axis='y', nbins=3)
    if n_post > 0:
        ax.legend(bbox_to_anchor=(1.3, 1), loc='upper right')

    if save_fig:
        plt.savefig(PANEL_H, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_H)
    else:
        plt.show()

### compose figure

In [None]:
fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)

yoffset = 2.
xoffset = 0.
xoffset_CF = -3.
xoffset_CF1 = -3.5
xoffset_H = xoffset_CF
fig = add_svg(fig, PANEL_A, 0 + xoffset, 0 + yoffset - 2)
fig = add_svg(fig, PANEL_B, ROW_1_WIDTH_COL_1_MM + xoffset, 0 + yoffset)
fig = add_svg(fig, PANEL_C, ROW_1_WIDTH_COL_1_MM + ROW_1_WIDTH_COL_2_MM + xoffset, 0 + yoffset)
yoffset = 2.
fig = add_svg(fig, PANEL_D, 0 + xoffset, ROW_1_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_E, ROW_2_WIDTH_COL_1_MM + xoffset , ROW_1_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_F, ROW_2_WIDTH_COL_1_MM + ROW_2_WIDTH_COL_2_MM + xoffset + xoffset_CF1,
              ROW_2_HEIGHT_MM + yoffset)
yoffset = 2.
fig = add_svg(fig, PANEL_G, 0 + xoffset, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_H, ROW_3_WIDTH_COL_1_MM + xoffset + xoffset_H, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)

yoffset = 2.3
fig = add_label(fig, 'A', 0, 0 + yoffset)
fig = add_label(fig, 'B', ROW_1_WIDTH_COL_1_MM, 0 + yoffset)
fig = add_label(fig, 'C', ROW_1_WIDTH_COL_1_MM + ROW_1_WIDTH_COL_2_MM + xoffset_CF, 0 + yoffset)
yoffset = 2.3
fig = add_label(fig, 'D', 0, ROW_1_HEIGHT_MM + yoffset)
fig = add_label(fig, 'E', ROW_2_WIDTH_COL_1_MM, ROW_1_HEIGHT_MM + yoffset)
fig = add_label(fig, 'F', ROW_2_WIDTH_COL_1_MM + ROW_2_WIDTH_COL_2_MM + xoffset_CF, ROW_1_HEIGHT_MM + yoffset)
yoffset = 2.3
fig = add_label(fig, 'G', 0, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_label(fig, 'H', ROW_3_WIDTH_COL_1_MM + xoffset + xoffset_H, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)


PATH_SVG = PATH_DROPBOX_FIGS + 'fig_hh.svg'
fig.save(PATH_SVG)


svg(PATH_SVG)

!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/fig_hh.pdf $PATH_SVG