## Figure: inference in the Hodgkin-Huxley (HH) model on simulated data

In [None]:
%run -i common.ipynb

import lfimodels.hodgkinhuxley.utils as utils

from lfimodels.hodgkinhuxley.HodgkinHuxley import HodgkinHuxley
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsSpikes_mf import HodgkinHuxleyStatsSpikes_mf

# Figure layout: a 3-row grid whose rows share the figure height equally;
# rows 1 and 2 are split into 2 equal columns, row 3 into 4. Sizes in mm.
FIG_HEIGHT_MM = 90
# FIG_WIDTH_MM is provided by the common notebook (one default width for all
# figures); it is re-bound here only to make that dependency visible.
FIG_WIDTH_MM = FIG_WIDTH_MM
FIG_N_ROWS = 3

# row 1 (panels A and B in the composed figure)
ROW_1_NCOLS = 2
ROW_1_HEIGHT_MM = FIG_HEIGHT_MM / FIG_N_ROWS
ROW_1_WIDTH_COL_1_MM = ROW_1_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_1_NCOLS

# row 2 (panel C in the composed figure)
ROW_2_NCOLS = 2
ROW_2_HEIGHT_MM = FIG_HEIGHT_MM / FIG_N_ROWS
ROW_2_WIDTH_COL_1_MM = ROW_2_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_2_NCOLS

# row 3 (panels D1-D4 in the composed figure)
ROW_3_NCOLS = 4
ROW_3_HEIGHT_MM = FIG_HEIGHT_MM / FIG_N_ROWS
ROW_3_WIDTH_COL_1_MM = ROW_3_WIDTH_COL_2_MM = \
    ROW_3_WIDTH_COL_3_MM = ROW_3_WIDTH_COL_4_MM = FIG_WIDTH_MM / ROW_3_NCOLS


# shrink factors applied to a grid cell's width / height when sizing a panel
W_FACT = H_FACT = 0.85

PATH_DROPBOX_FIGS = PATH_DROPBOX + 'figs/'

# output SVG of every individual panel
PANEL_A, PANEL_B, PANEL_C = (PATH_DROPBOX_FIGS + fname for fname in
                             ('fig_hh_a.svg', 'fig_hh_b.svg', 'fig_hh_c.svg'))
PANEL_D1, PANEL_D2, PANEL_D3, PANEL_D4 = (
    PATH_DROPBOX_FIGS + fname for fname in
    ('fig_hh_d1.svg', 'fig_hh_d2.svg', 'fig_hh_d3.svg', 'fig_hh_d4.svg'))

## Choose whether to save figures


In [None]:
save_fig = False  # True: write each panel to its SVG path (and display it) instead of plt.show()

## panels A-C

### load data

In [None]:
# Ground-truth parameters (and their labels) used for the simulated observation.
true_params, labels_params = utils.obs_params()

# simulation / prior / summary-statistics settings
seed = 1
prior_uniform = True
prior_log = False
prior_extent = True
n_xcorr = 0
n_mom = 4
cython = True
summary_stats = 1

# input current and timing values as returned by utils.syn_current
I, t_on, t_off, dt = utils.syn_current()
A_soma = np.pi * ((70. * 1e-4) ** 2)  # cm2

# synthetic "observed" data generated at the true parameters
obs = utils.syn_obs_data(I, dt, true_params, seed=seed, cython=cython)
y_obs, t = obs['data'], obs['time']
duration = np.max(t)

# simulator, started from the first observed voltage sample
m = HodgkinHuxley(I, dt, V0=y_obs[0], seed=seed, cython=cython,
                  prior_log=prior_log)

p = utils.prior(
    true_params=true_params,
    prior_uniform=prior_uniform,
    prior_extent=prior_extent,
    prior_log=prior_log,
    seed=seed,
)


# one posterior per entry: the number of summary features it was trained with
n_summary_ls = [1, 4, 7]
n_post = len(n_summary_ls)

# SNPE settings; here they are used to rebuild the stored result file names
n_components = 2
n_sims = 50000
n_rounds = 2
svi = False
svi_flag = '_svi' if svi else '_nosvi'

s_ls, posterior_ls, res_ls = [], [], []
for nsum in n_summary_ls:
    # summary statistics based on voltage moments.
    # Spike-feature alternative used previously:
    #   s = HodgkinHuxleyStatsSpikes_mf(t_on=t_on, t_off=t_off, n_summary=nsum)
    #   with results in '../hodgkinhuxley/results/sim_run_2_round2_prior0013_param8_statspikes_nsum'
    #   + str(nsum) + svi_flag + '_ncomp' + str(n_components)
    #   + '_nsims' + str(n_sims*n_rounds) + '_snpe_rej_res.pkl'
    s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr,
                                  n_mom=n_mom, n_summary=nsum)
    s_ls.append(s)

    # stored SNPE result for this number of features
    filename1 = ('../hodgkinhuxley/results/sim_run_1_round2_prior0013_param8_nsum'
                 '{}{}_ncomp{}_nsims{}_snpe_res.pkl').format(
                     nsum, svi_flag, n_components, n_sims * n_rounds)
    res = io.load(filename1)
    res_ls.append(res)

    # posterior evaluated at the observation stored in the result
    posterior = res.predict(res.obs)
    posterior_ls.append(posterior)

### panel A

In [None]:
# Panel A is not drawn in this notebook: the existing SVG at PANEL_A is only
# handed to `svg` (from the common notebook). `fig_inches` is the size of the
# panel's grid slot, converted to inches for matplotlib (SCALE_IN and
# mm2inches also come from the common notebook).
fig_inches = tuple(
    SCALE_IN * mm2inches(factor * size_mm)
    for factor, size_mm in ((W_FACT, ROW_1_WIDTH_COL_1_MM),
                            (H_FACT, ROW_1_HEIGHT_MM))
)

svg(PANEL_A)

### panel B

In [None]:
import plot_multipdf

# Panel B: plot of the all-feature posterior (posterior_ls[-1]) via
# plot_multipdf, with the true parameters and a lower-probability
# parameter set overlaid as reference points.

# prior bounds as an (n_params, 2) array of [lower, upper] columns
prior_min = p.lower
prior_max = p.upper

prior_lims = np.concatenate((prior_min.reshape(-1,1),
                             prior_max.reshape(-1,1)),
                            axis=1)

# mean of the component with the largest entry in `.a` (presumably the
# mixture weights) -- used throughout as a stand-in for the posterior mode
mn_post = posterior_ls[-1].xs[np.argmax(posterior_ls[-1].a)].m
    
# less likely parameter set from posterior than mode (every parameter
# scaled by 0.8)
post_low = mn_post*0.8

# reference points, one parameter set per column
post_samp = np.concatenate((mn_post.reshape(-1,1),post_low.reshape(-1,1)), axis=1)  # not used below in this notebook
post_samp1 = post_low.reshape(-1,1)  # not used below in this notebook
post_samp2 = np.concatenate((true_params.reshape(-1,1),post_low.reshape(-1,1)), axis=1)  # [true, low] -> gt= of panel B
col_samp = [COL['MODE'],COL['SAMPLES']]  # colours for [mode, low]; used in panel C
col_samp1 = [COL['GT'],COL['SAMPLES']]  # colours for the columns of post_samp2


# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
# NOTE: height is twice the row height -- presumably so panel B extends into
# row 2 next to panel C; confirm against the composed figure.
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_1_WIDTH_COL_2_MM), SCALE_IN*mm2inches(2*H_FACT*ROW_1_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)
    plot_multipdf.plot_pdf_multipts(posterior_ls[-1], lims=prior_lims,gt=post_samp2,labels_params=LABELS_HH,
                                    figsize=fig_inches,fontscale=0.5,ticks=True,col2=COL['SNPE'],col_samp=col_samp1)
    if save_fig:
        plt.savefig(PANEL_B, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_B)
    else:
        plt.show()

### panel C

In [None]:
# Repeatedly simulate at two fixed parameter sets -- the mode stand-in
# `mn_post` and the down-scaled `post_low` -- and collect the summary
# statistics of every run, to estimate their mean and spread.
num_rep = 100  # number of simulations for same parameter set

x_high_ls, sum_stats_high_ls = [], []
x_low_ls, sum_stats_low_ls = [], []
for i_rep in range(num_rep):
    # keep the high/low interleaving: both calls go through the same
    # simulator `m` (presumably sharing its random state)
    x_high = m.gen_single(mn_post)
    x_high_ls.append(x_high)
    sum_stats_high = res_ls[-1].generator.summary.calc([x_high])[0]
    sum_stats_high_ls.append(sum_stats_high)

    x_low = m.gen_single(post_low)
    x_low_ls.append(x_low)
    sum_stats_low = res_ls[-1].generator.summary.calc([x_low])[0]
    sum_stats_low_ls.append(sum_stats_low)

# per-feature mean and std over the repetitions, ignoring NaN statistics.
# x_high / x_low keep the traces of the last repetition (plotted in panel C).
mn_stats_high, std_stats_high = (np.nanmean(sum_stats_high_ls, axis=0),
                                 np.nanstd(sum_stats_high_ls, axis=0))
mn_stats_low, std_stats_low = (np.nanmean(sum_stats_low_ls, axis=0),
                               np.nanstd(sum_stats_low_ls, axis=0))

In [None]:
n_summary_stats = n_summary_ls[-1]

# Observed summary statistics of the last (all-feature) SNPE result, divided
# by the per-feature `stats_std` stored with that result.
obs_stats_norm_mat = res_ls[-1].obs[0]/res_ls[-1].stats_std
# order in which the features are shown (currently the identity order)
# arg_sort_stats = np.argsort(obs_stats_norm_mat)
arg_sort_stats = np.arange(n_summary_stats)
LABELS_HH_SUMSTATS1 = np.array(LABELS_HH_SUMSTATS)

# simulated statistics (mean / std over the repetitions) on the same scale
mn_stats_high_norm = mn_stats_high/res_ls[-1].stats_std
mn_stats_low_norm = mn_stats_low/res_ls[-1].stats_std

std_stats_high_norm = std_stats_high/res_ls[-1].stats_std
std_stats_low_norm = std_stats_low/res_ls[-1].stats_std

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(1.5*W_FACT*ROW_2_WIDTH_COL_1_MM), SCALE_IN*mm2inches(H_FACT*ROW_2_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    # left: observed trace and the last simulated repetition at each of the
    # two parameter sets (x_high / x_low from the simulation cell)
    ax = plt.subplot(121)
    plt.plot(t, obs['data'], color=COL['OBS'], lw=2, label='observation')
    plt.plot(t, x_high['data'], color=col_samp[0], lw=2, label='mode')
    plt.plot(t, x_low['data'], color=col_samp[1], lw=2, label='low prob')
    plt.xlabel('time (ms)')
    plt.ylabel('voltage (mV)')
    ax.set_xticks([0, duration/2, duration])
    ax.set_yticks([-80, -20, 40])

    # right: one group of three bars per feature
    # (observation | mode | low probability); error bars are the std over
    # the repeated simulations. align='center' is explicit so the offsets
    # below do not depend on matplotlib's default bar alignment.
    width = 0.3
    x_pos = np.arange(n_summary_stats)
    ax = plt.subplot(122)
    plt.bar(x_pos, obs_stats_norm_mat[arg_sort_stats],
            width, align='center', color=COL['OBS'], label='observation')
    plt.bar(x_pos+width, mn_stats_high_norm[arg_sort_stats],
            width, align='center', color=col_samp[0],
            yerr=std_stats_high_norm[arg_sort_stats], label='mode')
    plt.bar(x_pos+2*width, mn_stats_low_norm[arg_sort_stats],
            width, align='center', color=col_samp[1],
            yerr=std_stats_low_norm[arg_sort_stats], label='low probability')
    ax.set_xlim(-1.5*width, n_summary_stats+width/2)
    # put each feature label under the middle bar of its three-bar group
    # (an offset of width/2 would fall between the first two bars)
    ax.set_xticks(x_pos+width)
    ax.set_xticklabels(LABELS_HH_SUMSTATS1[arg_sort_stats])
    plt.ylabel(r'$\frac{f}{\sigma_{f \ PRIOR}}$')
    plt.legend(bbox_to_anchor=(1.51, 1.1), loc='upper right')

    ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))

    if save_fig:
        plt.savefig(PANEL_C, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_C)
    else:
        plt.show()

### panel D

In [None]:
# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_3_WIDTH_COL_1_MM), SCALE_IN*mm2inches(H_FACT*ROW_3_HEIGHT_MM))

# subset shown in each partial plot (forwarded to plot_pdf_multipts as
# partial_ls; presumably parameter indices -- confirm in plot_multipdf)
partial_ls = [1,2,4,7]

# title suffix for each posterior, pluralised from its actual feature count
# (not from its position in the list)
label_feature = [' feature' if nsum == 1 else ' features' for nsum in n_summary_ls]

# one output file per posterior (panels D1-D3)
PANEL_D_ls = [PANEL_D1,PANEL_D2,PANEL_D3]
if save_fig and len(PANEL_D_ls) < n_post:
    # fail before drawing anything rather than part-way through the loop
    raise ValueError('PANEL_D_ls needs one output path per posterior')

for i in range(n_post):
    with mpl.rc_context(fname=MPL_RC):
        fig = plt.figure(figsize=fig_inches)
        # posterior i with the true parameters marked
        plot_multipdf.plot_pdf_multipts(posterior_ls[i], lims=prior_lims,gt=true_params.reshape(-1,1),
                                        labels_params=LABELS_HH,
                                        figsize=fig_inches,fontscale=0.5,ticks=True,partial=True,
                                        partial_ls = partial_ls,col2=COL['SNPE'],col_samp=[COL['GT']])
        # title placed relative to the data range of whichever axes
        # plot_pdf_multipts leaves current; -2.7 / 4.8 are presumably
        # hand-tuned multipliers
        x0, xmax = plt.xlim()
        y0, ymax = plt.ylim()
        data_width = xmax - x0
        data_height = ymax - y0
        plt.text(x0+data_width*-2.7, y0+data_height*4.8,
                 str(n_summary_ls[i])+label_feature[i],fontsize=12)
        if save_fig:
            plt.savefig(PANEL_D_ls[i], facecolor='None', transparent=True)  # the figure is saved as svg
            plt.close()
            svg(PANEL_D_ls[i])
        else:
            plt.show()

In [None]:
# Mode estimate for each of the 3 posteriors, plus one simulation at each.
#
# For the 1- and 4-feature posteriors the mode is found numerically by
# minimising minus `posterior.eval` (presumably the mixture log-density --
# confirm). The optimiser is started from the mean of every mixture
# component and the best local optimum is kept, so the result does not hinge
# on one arbitrarily chosen starting component.
from scipy.optimize import minimize


def _neg_post(x, posterior):
    """Return minus ``posterior.eval`` at the single point ``x`` as a float."""
    # eval takes a batch of points; squeeze the one-point result to a scalar,
    # which is what scipy.optimize.minimize expects from the objective
    return -float(np.squeeze(posterior.eval([x])))


def mog_fun1(x):
    """Objective for the 1-feature posterior (``posterior_ls[0]``)."""
    return _neg_post(x, posterior_ls[0])


def mog_fun4(x):
    """Objective for the 4-feature posterior (``posterior_ls[1]``)."""
    return _neg_post(x, posterior_ls[1])


def _optimise_mode(posterior, objective, tol=1e-9):
    """Minimise ``objective`` starting from each component mean of ``posterior``.

    Returns the ``scipy.optimize.OptimizeResult`` with the lowest ``fun``.
    """
    best = None
    for k in range(len(posterior.a)):
        # x0 must be 1-D (a list-wrapped array is 2-D, rejected by newer SciPy)
        start = np.asarray(posterior.xs[k].m, dtype=float).ravel()
        res_k = minimize(objective, start, tol=tol)
        if best is None or res_k.fun < best.fun:
            best = res_k
    return best


# mode of posterior with 1 feature (1st posterior)
res_xopt = _optimise_mode(posterior_ls[0], mog_fun1)
mode_feat1 = res_xopt['x']

# mode of posterior with 4 features (2nd posterior)
res_xopt = _optimise_mode(posterior_ls[1], mog_fun4)
mode_feat4 = res_xopt['x']

# posterior with all features (last posterior): not optimised -- the mean of
# the component with the largest `.a` entry is used (same expression as
# `mn_post` in panel B)
mode_feat7 = posterior_ls[-1].xs[np.argmax(posterior_ls[-1].a)].m

# one column per posterior: [1 feature, 4 features, 7 features]
post_modes = np.concatenate((mode_feat1.reshape(-1,1),mode_feat4.reshape(-1,1),mode_feat7.reshape(-1,1)), axis=1)

# one simulation at each mode estimate (plotted in panel D4)
x_feat1 = m.gen_single(mode_feat1)
x_feat4 = m.gen_single(mode_feat4)
x_feat7 = m.gen_single(mode_feat7)

In [None]:
# Panel D4: observed trace vs. one simulation at the mode estimate of each of
# the three posteriors (x_feat7 / x_feat4 / x_feat1 from the previous cell).

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_3_WIDTH_COL_4_MM), SCALE_IN*mm2inches(H_FACT*ROW_3_HEIGHT_MM))

# two colours sampled from the Reds colormap; only referenced by the
# commented-out bar plot below
col_min = 1
num_colors = 2+col_min
cm1 = mpl.cm.Reds
col1 = [cm1(1.*i/num_colors) for i in range(col_min,num_colors)]

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

#     ax = plt.subplot(121)
    # same mode colour for all three; distinguished by alpha and line style
    plt.plot(t, obs['data'], color=COL['OBS'], lw=2, label='observation')
    plt.plot(t, x_feat7['data'], color=COL['MODE'], alpha=1, lw=2, label='7 features')
    plt.plot(t, x_feat4['data'], '--', color=COL['MODE'], alpha=0.6, lw=2, label='4 features')
    plt.plot(t, x_feat1['data'], color=COL['MODE'], alpha=0.5, lw=2, label='1 feature')
    plt.xlabel('time (ms)')
    plt.ylabel('voltage (mV)')
    plt.legend(bbox_to_anchor=(1.35, 1.1), loc='upper right', fontsize=9)

    ax = plt.gca()
    ax.set_xticks([0, duration/2, duration])
    ax.set_yticks([-80, -20, 40])
    
#     n_summary_stats = n_summary_ls[-1]
#     width = 0.3
#     ax = plt.subplot(122)
#     plt.bar(np.linspace(0,n_summary_stats-1,n_summary_stats),res_ls[-1].obs[0]/res_ls[-1].stats_std,
#             width,color=COL['OBS'],label='observation')
#     plt.bar(np.linspace(0,n_summary_stats-1,n_summary_stats)+width,mn_stats_samp1/res_ls[-1].stats_std,
#             width, color=col1[1],yerr=std_stats_samp1/res_ls[-1].stats_std,label='sample 1')
#     plt.bar(np.linspace(0,n_summary_stats-1,n_summary_stats)+2*width,mn_stats_samp2/res_ls[-1].stats_std,
#             width, color=col1[0],yerr=std_stats_samp2/res_ls[-1].stats_std,label='sample 2')
#     ax.set_xlim(-1.5*width,n_summary_stats+width/2)
#     ax.set_xticks(np.linspace(0,n_summary_stats-1,n_summary_stats)+width/2)
#     ax.set_xticklabels(LABELS_HH_SUMSTATS)
#     plt.ylabel(r'$\frac{f}{\sigma_{PRIOR}}$')

# #     plt.yscale('log')
#     ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
    
        
    if save_fig:
        plt.savefig(PANEL_D4, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_D4)
    else:
        plt.show()

### compose figure

In [None]:
# Assemble the saved panel SVGs into one figure and write it to PATH_SVG.
# Positions use the same units as FIG_WIDTH_MM / FIG_HEIGHT_MM passed to
# create_fig (presumably mm -- create_fig/add_svg come from the common notebook).
fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)

# global shift applied to every panel position
xoffset = 0.
yoffset = 0.

# row 1: panels A and B
fig = add_svg(fig, PANEL_A, 0 + xoffset, 0 + yoffset, scale=0.75)
fig = add_svg(fig, PANEL_B, ROW_1_WIDTH_COL_1_MM + xoffset, 0 + yoffset, scale=.98/W_FACT)

# row 2: panel C
fig = add_svg(fig, PANEL_C, 0 + xoffset, ROW_1_HEIGHT_MM + yoffset)

# row 3: panel D
# D panels were drawn W_FACT * column width wide, so they are placed
# back-to-back at that spacing rather than at the full column width
fig = add_svg(fig, PANEL_D1, 0 + xoffset, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_D2, W_FACT*ROW_3_WIDTH_COL_1_MM + xoffset, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_D3, 2*W_FACT*ROW_3_WIDTH_COL_1_MM + xoffset, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_D4, 3*W_FACT*ROW_3_WIDTH_COL_1_MM + xoffset, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)

###########
# panel letters: same anchors as the panels, shifted by a fixed vertical offset
yoffset = 2.3

# row 1: panels A and B
fig = add_label(fig, 'A', 0, 0 + yoffset)
fig = add_label(fig, 'B', ROW_1_WIDTH_COL_1_MM, 0 + yoffset)

# row 2: panel C
fig = add_label(fig, 'C', 0 , ROW_1_HEIGHT_MM + yoffset)

# row 3: panel D
fig = add_label(fig, 'D', 0, ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)

###########
# write the composed figure and display it
PATH_SVG = PATH_DROPBOX_FIGS + 'fig_hh.svg'
fig.save(PATH_SVG)


svg(PATH_SVG)

!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/fig_hh.pdf $PATH_SVG