## Figure on inference in HH model on Allen data

In [None]:
%run -i common.ipynb

import lfimodels.hodgkinhuxley.utils as utils

from lfimodels.hodgkinhuxley.HodgkinHuxley import HodgkinHuxley
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsSpikes_mf import HodgkinHuxleyStatsSpikes_mf
from scipy.special import erf

# FIGURE and GRID
# Overall figure size in millimetres, split into two rows of equal height.
FIG_HEIGHT_MM = 60
FIG_WIDTH_MM = FIG_WIDTH_MM  # set in common notebook to a default value for all figures
FIG_N_ROWS = 2

# Row 1: two columns, each half of the figure width.
ROW_1_NCOLS = 2
ROW_1_HEIGHT_MM = FIG_HEIGHT_MM / FIG_N_ROWS
ROW_1_WIDTH_COL_1_MM = ROW_1_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_1_NCOLS

# Row 2: three columns, each FIG_WIDTH_MM / 4.2 wide, so together they span
# less than the full figure width.
ROW_2_NCOLS = 4.2
ROW_2_HEIGHT_MM = FIG_HEIGHT_MM / FIG_N_ROWS
ROW_2_WIDTH_COL_1_MM = ROW_2_WIDTH_COL_2_MM = ROW_2_WIDTH_COL_3_MM = FIG_WIDTH_MM / ROW_2_NCOLS

# Factors multiplied into the panel width / height when the panel cells build figsize.
W_FACT = 0.85
H_FACT = 0.85

# PATHS
# Every panel is written to (and later read back from) this directory as svg.
PATH_DROPBOX_FIGS = PATH_DROPBOX + 'figs/'

PANEL_A = PATH_DROPBOX_FIGS + 'fig_allen_a.svg'
PANEL_B = PATH_DROPBOX_FIGS + 'fig_allen_b.svg'
PANEL_B1 = PATH_DROPBOX_FIGS + 'fig_allen_b1.svg'
PANEL_C = PATH_DROPBOX_FIGS + 'fig_allen_c.svg'
PANEL_D = PATH_DROPBOX_FIGS + 'fig_allen_d.svg'
PANEL_E = PATH_DROPBOX_FIGS + 'fig_allen_e.svg'

## save figures or not


In [None]:
# Flag checked by every panel cell: truthy -> plt.savefig(PANEL_x) and render the
# saved svg; falsy -> plt.show() only.
save_fig = 0 # 1: save figures

### load data

In [None]:
# Parameters and their labels as provided by the model utilities.
true_params, labels_params = utils.obs_params()

# Simulator / summary-statistics settings.
seed = 1
prior_uniform = True
prior_log = False
prior_extent = True
n_xcorr = 0
n_mom = 4
cython = True
n_summary = 7
summary_stats = 1

# One entry per recording: [ephys cell id, sweep number, value passed on as A_soma].
list_cells_AllenDB = [
    [518290966, 57, 0.0234/126],
    [509881736, 39, 0.0153/184],
    [566517779, 46, 0.0195/198],
    [567399060, 38, 0.0259/161],
    [569469018, 44, 0.033/403],
    [532571720, 42, 0.0139/127],
    [555060623, 34, 0.0294/320],
    [534524026, 29, 0.027/209],
    [532355382, 33, 0.0199/230],
    [526950199, 37, 0.0186/218],
]

n_post = len(list_cells_AllenDB)

# SNPE parameters (they select which result file is loaded)
n_components = 1
n_sims = 25000
n_rounds = 2
svi = False
svi_flag = '_svi' if svi else '_nosvi'

# IBEA parameters (they select which result file is loaded)
algo = 'ibea'
offspring_size = 500
max_ngen = 100

# Per-recording containers, all filled in the same order as list_cells_AllenDB.
obs_ls, I_ls, dt_ls, t_on_ls, t_off_ls = [], [], [], [], []
obs_stats_ls, m_ls, s_ls = [], [], []
posterior_ls, res_ls, halloffame_ls = [], [], []

# Offset added to every recorded trace before statistics are computed
# (presumably in mV, like the traces -- TODO confirm).
junction_potential = -14

for cell_num, (ephys_cell, sweep_number, A_soma) in enumerate(list_cells_AllenDB):

    # observed recording, shifted by the junction potential
    obs = utils.allen_obs_data(ephys_cell=ephys_cell,
                               sweep_number=sweep_number,
                               A_soma=A_soma)
    obs['data'] = obs['data'] + junction_potential
    obs_ls.append(obs)

    # stimulus and timing information of this recording
    I = obs['I']
    I_ls.append(I)
    dt = obs['dt']
    dt_ls.append(dt)
    t_on = obs['t_on']
    t_on_ls.append(t_on)
    t_off = obs['t_off']
    t_off_ls.append(t_off)

    # summary statistics of the (shifted) recording; only the first row is stored
    obs_stats = utils.allen_obs_stats(data=obs,
                                      ephys_cell=ephys_cell,
                                      sweep_number=sweep_number,
                                      n_xcorr=n_xcorr,
                                      n_mom=n_mom,
                                      summary_stats=summary_stats,
                                      n_summary=n_summary)
    obs_stats_ls.append(obs_stats[0])

    # simulator driven by the recorded stimulus, started at the first observed voltage
    m = HodgkinHuxley(I=I, dt=dt, V0=obs['data'][0], seed=seed,
                      cython=cython, prior_log=prior_log)
    m_ls.append(m)

    # summary-statistics object (alternative kept for reference):
    #     s = HodgkinHuxleyStatsSpikes_mf(t_on=t_on, t_off=t_off,n_summary=n_summary)
    s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr,
                                  n_mom=n_mom, n_summary=n_summary)
    s_ls.append(s)

    # all result files of one recording share this path stem
    results_stem = '../hodgkinhuxley/results/allen_'+str(ephys_cell)+'_'+str(sweep_number)

    ##############################################################################
    # SNPE results
    # (spike-statistics variant:
    #     results_stem + '_run_1_round2_prior0013_param8_statspikes'+svi_flag+
    #     '_ncomp'+str(n_components)+'_nsims'+str(n_sims*n_rounds)+'_snpe_rej_res.pkl')
    filename1 = (results_stem
                 + '_run_1_round2_prior0013_param8'+svi_flag
                 + '_ncomp'+str(n_components)
                 + '_nsims'+str(n_sims*n_rounds)+'_snpe_res.pkl')

    res = io.load(filename1)
    res_ls.append(res)

    # posterior evaluated at the observed summary statistics
    posterior = res.predict(obs_stats)
    posterior_ls.append(posterior)

    ##############################################################################
    # IBEA results
    # (spike-statistics variant: ... + '_param8_statspikes_' + algo + '.pkl')
    _, halloffame, _, _ = io.load_pkl(results_stem
                                      + '_run_1_offspr'+str(offspring_size)
                                      + '_max_gen'+str(max_ngen)
                                      + '_param8_' + algo + '.pkl')

    halloffame_ls.append(halloffame)

### panel A

In [None]:
# Panel A: observed traces (top row) and simulations at the posterior mode
# (bottom row), one subplot column per selected recording.

# recordings plotted, as indices into the per-recording lists built when loading data
cell_ls = [0,1,2]
# number of cell recordings plotted; derived from cell_ls so the subplot grid
# always matches the selection
num_rec = len(cell_ls)

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(1.1*ROW_1_WIDTH_COL_1_MM), SCALE_IN*mm2inches(H_FACT*ROW_1_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    for i, cell_num in enumerate(cell_ls):
        y_obs = obs_ls[cell_num]['data']
        t = obs_ls[cell_num]['time']
        duration = np.max(t)

        # mean of the component with the largest entry in `.a`
        # (presumably the mixture weights -- TODO confirm), simulated once
        mn_post = posterior_ls[cell_num].xs[np.argmax(posterior_ls[cell_num].a)].m
        x_post = m_ls[cell_num].gen_single(mn_post)

        # top row: observed recording
        plt.subplot(2, num_rec, i+1)
        plt.plot(t, y_obs, color = COL['OBS'], lw=2, label='Allen Cell Types Database')
        ax = plt.gca()
        ax.set_xticks([])

        # voltage ticks only on the left-most column
        if i==0:
            ax.set_yticks([-80, -20, 40])
        else:
            ax.set_yticks([])

        # a single legend per row, attached to the second column
        if i==1:
            plt.legend(bbox_to_anchor=(0.52, 1.4), loc='upper center')

        # bottom row: simulation at the posterior mode, directly below the observation
        plt.subplot(2, num_rec, i+num_rec+1)
        plt.plot(t, x_post['data'], color = COL['MODE'], lw=2, label='mode')
        plt.xlabel('time (ms)')

        ax = plt.gca()
        ax.set_xticks([0, duration/2, duration])
        if i==0:
            # y > 1 moves the label up so that it is shared by both rows
            plt.ylabel('voltage (mV)',x=0, y=1.2)
            ax.set_yticks([-80, -20, 40])
        else:
            ax.set_yticks([])

        if i==1:
            plt.legend(bbox_to_anchor=(0.35, 1.35), loc='upper center')

    if save_fig:
        plt.savefig(PANEL_A, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_A)
    else:
        plt.show()

### panel B

In [None]:
# # matplotlib takes figsize specified as inches
# # in common SCALE_IN and the conversion function are defined
# fig_inches = (SCALE_IN*mm2inches(0.5*W_FACT*ROW_1_WIDTH_COL_2_MM), SCALE_IN*mm2inches(H_FACT*ROW_1_HEIGHT_MM))

# with mpl.rc_context(fname=MPL_RC):
#     fig = plt.figure(figsize=fig_inches)

#     ax1 = fig.add_axes([0.05, 0.80, 0.6, 0.1])
#     # [left, bottom, width, height]

#     cmap = mpl.cm.Blues
#     norm = mpl.colors.Normalize(vmin=1, vmax=10)

#     cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap,norm=norm,orientation='horizontal')
#     cb1.set_label('recording', x=1.25, labelpad=-13)
#     cb1.outline.set_visible(False)
#     cb1.set_ticks([])

#     if save_fig:
#         plt.savefig(PANEL_B1, facecolor='None', transparent=True)  # the figure is saved as svg
#         plt.close()
#         svg(PANEL_B1)
#     else:
#         plt.show()

In [None]:
import plot_multipdf

# Panel B: posteriors of all recordings overlaid in one multi-pdf plot.

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
panel_width_in = SCALE_IN*mm2inches(W_FACT*ROW_1_WIDTH_COL_2_MM)
panel_height_in = SCALE_IN*mm2inches(2*H_FACT*ROW_1_HEIGHT_MM)
fig_inches = (panel_width_in, panel_height_in)

# col1: one shade of the Blues colormap per recording (index 0 is skipped via col_min);
# col2: the last of these shades repeated for every recording
col_min = 1
num_colors = n_post+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(float(k)/num_colors) for k in range(col_min, num_colors)]
col2 = [col1[-1] for _ in range(n_post)]

# partial_ls = [5,6,7]

# axis limits: prior bounds of the first recording, one (lower, upper) row per parameter
cell_num = 0
prior_min = res_ls[cell_num].generator.prior.lower
prior_max = res_ls[cell_num].generator.prior.upper

prior_lims = np.hstack((prior_min.reshape(-1,1), prior_max.reshape(-1,1)))

# keyword arguments of the plot call
# (variant with one shade per recording: colrs=col1 and no alpha)
multipdf_kwargs = dict(lims=prior_lims,
                       labels_params=LABELS_HH,
                       partial=False,
                       figsize=fig_inches,
                       fontscale=0.5,
                       ticks=True,
                       colrs=col2,
                       alpha=0.3,
                       imageshow=False)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    plot_multipdf.plot_multipdf(posterior_ls, **multipdf_kwargs)

    if not save_fig:
        plt.show()
    else:
        plt.savefig(PANEL_B, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_B)

### panel C. TODO: recompute means and covariances for truncated (multivariate) posterior


In [None]:
# # recompute means and covariances for truncated normal

# def norm_cdf(x):
#     return .5*(1 + erf(x/np.sqrt(2)))

# def norm_pdf(x):
#     return (2*np.pi)**(-.5)*np.exp(-.5*x**2)


# class TruncNormal:
    
#     def __init__(self, m, s, a, b):
#         self.m = m
#         self.s = s
#         self.a = a
#         self.b = b
        
#     def z_score(self, x):
#         return (x - self.m) / self.s
    
#     @property
#     def a1(self):
#         return self.z_score(self.a)
    
#     @property
#     def b1(self):
#         return self.z_score(self.b)
    
#     @property
#     def normalizer(self):
#         return norm_cdf(self.b1) - norm_cdf(self.a1)
    
#     @property
#     def mean(self):
#         return self.m + (norm_pdf(self.a1) - norm_pdf(self.b1)) * self.s / self.normalizer
    
#     @property
#     def var(self):
#         f1 = (self.a1 * norm_pdf(self.a1) - self.b1 * norm_pdf(self.b1)) / self.normalizer
#         f2 = ((norm_pdf(self.a1) - norm_pdf(self.b1)) / self.normalizer) ** 2
#         return self.s**2 * (1 + f1 + f2)
    
#     @property
#     def entropy(self):
#         f1 = 0.5 * (self.a1 * norm_pdf(self.a1) - self.b1 * norm_pdf(self.b1)) / self.normalizer
#         f2 = np.log(np.sqrt(2 * np.pi) * self.s * self.normalizer) + .5
#         return f1 + f2

In [None]:
# posterior_marginals = []
# cv_post = []
# cv_post_nocorrection = []
# for cell_num in range(n_post):
#     mn_post = posterior_ls[cell_num].xs[np.argmax(posterior_ls[cell_num].a)].m
#     cov_post = posterior_ls[cell_num].xs[np.argmax(posterior_ls[cell_num].a)].S
#     std_post = np.sqrt(np.diag(cov_post))
    
#     posterior_marginals2 = []
#     cv_post1 = []
#     for mi, si, ai, bi in zip(mn_post, std_post, prior_min, prior_max):
#         posterior_marginals1 = TruncNormal(mi, si, ai, bi)
#         posterior_marginals2.append(posterior_marginals1)
#         cv_post1.append(np.sqrt(posterior_marginals1.var) / posterior_marginals1.mean)
    
#     cv_post.append(cv_post1)
#     cv_post_nocorrection.append(std_post / mn_post)
#     posterior_marginals.append(posterior_marginals2)

# cv_post_mat = np.asarray(cv_post)
# cv_post_nocorrection_mat = np.asarray(cv_post_nocorrection)
# LABELS_HH1 = np.array(LABELS_HH)

# plt.boxplot(cv_post_mat, labels=LABELS_HH1)
# plt.ylabel('coefficient of variation');

In [None]:
# Panel C: cumulative variance explained by the eigen-components of the
# mean-normalized posterior covariance of each recording.

# the eigen-analysis below reads a single mean / covariance per posterior, so it
# is only meaningful for single-component posteriors; stop with a clear error
# instead of failing later on missing eigenvalues
n_components_crit = 0 if any(res.network.n_components > 1 for res in res_ls) else 1
if n_components_crit != 1:
    raise ValueError('some posteriors have more than 1 component, so this analysis is invalid')

# collect means and normalized covariances across recordings
cv_post_ls = []
for cell_num in range(n_post):
    mn_post = posterior_ls[cell_num].xs[np.argmax(posterior_ls[cell_num].a)].m
    cov_post = posterior_ls[cell_num].xs[np.argmax(posterior_ls[cell_num].a)].S

    # divide covariances by respective means: coefficient of variation matrices.
    # Entry (i, j) is S_ij / (m_i * m_j), i.e. diag(m)^-1 . S . diag(m)^-1 without
    # building and inverting the diagonal matrix twice.
    cv_post_ls.append(np.asarray(cov_post) / np.outer(mn_post, mn_post))

# eigen-decomposition of the normalized covariances, sorted by decreasing eigenvalue.
# The matrices are symmetric, so eigh is used: it returns real eigenvalues and
# eigenvectors, whereas eig may return complex values from round-off asymmetry.
eig_val_ls = []
eig_vec_ls = []
for cell_num in range(n_post):
    w, v = np.linalg.eigh(cv_post_ls[cell_num])
    arg_sort = np.argsort(w)[::-1]
    eig_val_ls.append(w[arg_sort])
    eig_vec_ls.append(v[:,arg_sort])
eig_val_mat = np.asarray(eig_val_ls)   # (recording, component)
eig_vec_mat = np.asarray(eig_vec_ls)   # (recording, parameter, component)

num_eig = len(eig_val_ls[0])

# same colour for every recording
col3 = ['#1f77b4']*n_post

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_2_WIDTH_COL_1_MM), SCALE_IN*mm2inches(H_FACT*ROW_2_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    component_idx = np.arange(1, num_eig + 1)
    for i in range(n_post):
        # cumulative share of the total variance, in percent to match the axis label
        cum_var_pct = 100. * np.cumsum(eig_val_mat[i]) / np.sum(eig_val_mat[i])
        plt.plot(component_idx, cum_var_pct,
                 color=col3[i],lw=1,marker='o',label=str(i+1))
    plt.xlabel('component')
    plt.ylabel('variance explained (%)')
    plt.title('coefficient of variation matrix',fontsize=12)
    plt.locator_params(axis='y', nbins=3);
#     plt.legend(bbox_to_anchor=(1.15, 1), loc='upper right')

    if save_fig:
        plt.savefig(PANEL_C, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_C)
    else:
        plt.show()

### panel D

In [None]:
# Panel D: absolute entries of the first eigenvector (largest eigenvalue) of the
# normalized posterior covariance, one image row per recording and one column
# per parameter.

from matplotlib.colors import LogNorm

eig_number = 0
# |eigenvector| for every recording: rows = recordings, columns = parameters
abs_eig_vec = np.abs(eig_vec_mat[:,:,eig_number])

# order the parameters by their summed absolute loading across recordings (ascending)
sum_eig1 = np.sum(abs_eig_vec,axis=0)
arg_sort_sum_eig1 = np.argsort(sum_eig1)
LABELS_HH1 = np.array(LABELS_HH)

# colour range of the logarithmic colour scale
vmin = np.min(abs_eig_vec)
vmax = np.max(abs_eig_vec)

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_2_WIDTH_COL_2_MM), SCALE_IN*mm2inches(H_FACT*ROW_2_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    # extent is given from cell edge to cell edge: half a unit of padding on both
    # sides puts the column centres exactly on 1..num_eig, where the parameter
    # tick labels are placed below
    plt.imshow(abs_eig_vec[:,arg_sort_sum_eig1],
               extent=[0.5,num_eig+0.5,0.5,n_post+0.5],
               norm=LogNorm(vmin=vmin, vmax=vmax), aspect='auto')
    cb1 = plt.colorbar()
    cb1.outline.set_visible(False)
    ax = plt.gca()
    ax.set_xticks(np.linspace(1,num_eig,num_eig))
    ax.set_xticklabels(LABELS_HH1[arg_sort_sum_eig1])
    ax.set_yticks([])
#     ax.axes.get_yaxis().set_visible(False)
    plt.ylabel('recording')
    plt.title('|component 1|',fontsize=12)

    if save_fig:
        plt.savefig(PANEL_D, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_D)
    else:
        plt.show()

### panel E

In [None]:
# Panel E: absolute entries of the second eigenvector of the normalized posterior
# covariance, one image row per recording and one column per parameter.
eig_number = 1
# |eigenvector| for every recording: rows = recordings, columns = parameters
abs_eig_vec = np.abs(eig_vec_mat[:,:,eig_number])

# order the parameters by their summed absolute loading across recordings (ascending)
sum_eig1 = np.sum(abs_eig_vec,axis=0)
arg_sort_sum_eig1 = np.argsort(sum_eig1)
LABELS_HH1 = np.array(LABELS_HH)

# colour range of the logarithmic colour scale
vmin = np.min(abs_eig_vec)
vmax = np.max(abs_eig_vec)

# matplotlib takes figsize specified as inches
# in common SCALE_IN and the conversion function are defined
fig_inches = (SCALE_IN*mm2inches(W_FACT*ROW_2_WIDTH_COL_3_MM), SCALE_IN*mm2inches(H_FACT*ROW_2_HEIGHT_MM))

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    # extent is given from cell edge to cell edge: half a unit of padding on both
    # sides puts the column centres exactly on 1..num_eig, where the parameter
    # tick labels are placed below
    plt.imshow(abs_eig_vec[:,arg_sort_sum_eig1],
               extent=[0.5,num_eig+0.5,0.5,n_post+0.5],
               norm=LogNorm(vmin=vmin, vmax=vmax), aspect='auto')
    cb1 = plt.colorbar()
    cb1.outline.set_visible(False)
    ax = plt.gca()
    ax.set_xticks(np.linspace(1,num_eig,num_eig))
    ax.set_xticklabels(LABELS_HH1[arg_sort_sum_eig1])
    ax.set_yticks([])
#     ax.axes.get_yaxis().set_visible(False)
    plt.ylabel('recording')
    plt.title('|component 2|',fontsize=12)

    if save_fig:
        plt.savefig(PANEL_E, facecolor='None', transparent=True)  # the figure is saved as svg
        plt.close()
        svg(PANEL_E)
    else:
        plt.show()

### compose figure

In [None]:
# Assemble the saved panel svgs into one figure, add the panel letters, save and display it.
fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)

xoffset = 0.
yoffset = 2.
yoffset_B = 2  # extra vertical shift applied to panel B only

# (svg file, x position, y position, extra keyword arguments) for every panel
panel_placements = [
    # row 1: panels A and B
    (PANEL_A, 0 + xoffset , 0 + yoffset, dict(scale=.9)),
    (PANEL_B, ROW_1_WIDTH_COL_1_MM + xoffset, 0 + yoffset + yoffset_B, dict(scale=.98/W_FACT)),
    # (PANEL_B1, 1.27*ROW_1_WIDTH_COL_1_MM + xoffset, 0 + yoffset - 0.4, dict(scale=.8)),
    # row 2: panels C, D and E
    (PANEL_C, 0 + xoffset, ROW_1_HEIGHT_MM + yoffset, {}),
    (PANEL_D, ROW_2_WIDTH_COL_1_MM + xoffset, ROW_1_HEIGHT_MM + yoffset, {}),
    (PANEL_E, 2*ROW_2_WIDTH_COL_1_MM + xoffset, ROW_1_HEIGHT_MM + yoffset, {}),
]
for panel_path, panel_x, panel_y, panel_kwargs in panel_placements:
    fig = add_svg(fig, panel_path, panel_x, panel_y, **panel_kwargs)

###########
# the labels use their own vertical offset and no horizontal offset
yoffset = 2.3

# (letter, x position, y position) for every panel label
label_placements = [
    # row 1: panels A and B
    ('A', 0, 0 + yoffset),
    ('B', ROW_1_WIDTH_COL_1_MM, 0 + yoffset),
    # row 2: panels C, D and E
    ('C', 0, ROW_1_HEIGHT_MM + yoffset),
    ('D', ROW_2_WIDTH_COL_1_MM, ROW_1_HEIGHT_MM + yoffset),
    ('E', 2*ROW_2_WIDTH_COL_1_MM, ROW_1_HEIGHT_MM + yoffset),
]
for letter, label_x, label_y in label_placements:
    fig = add_label(fig, letter, label_x, label_y)

###########
PATH_SVG = PATH_DROPBOX_FIGS + 'fig_allen.svg'
fig.save(PATH_SVG)


svg(PATH_SVG)

!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/fig_allen.pdf $PATH_SVG