## Comparison of SNPE with SMC-ABC (Hodgkin-Huxley model on cells from the Allen Cell Type Database)

In [None]:
import delfi.distribution as dd
import delfi.distribution.mixture.GaussianMixture as GaussianMixture
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.hodgkinhuxley.utils as utils
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pickle

from lfimodels.abc_methods import run_abc
from lfimodels.hodgkinhuxley.HodgkinHuxley import HodgkinHuxley
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments
from delfi.utils.viz import plot_pdf
from sklearn import mixture
from sklearn.neighbors.kde import KernelDensity

%matplotlib inline

## Load SNPE and SMC-ABC results

In [None]:
# Recordings to analyse. Each entry is [ephys_cell, sweep_number, A_soma],
# which is how the loop below unpacks it.
list_cells_AllenDB = [[518290966,57,0.0234/126],[509881736,39,0.0153/184],[566517779,46,0.0195/198]]

# settings shared by the simulator and the summary statistics
seed = 1
prior_log = False
n_xcorr = 0
n_mom = 4
cython=True
n_summary = 10
summary_stats = 1


n_post = len(list_cells_AllenDB)

# SNPE parameters (used below to build the SNPE result file name)
n_components = 1
n_sims = 25000
n_rounds = 2

# SMC-ABC parameters
# Stored as ints: these are counts, and a float such as 1e3 can leak into
# expressions like min(100, n_particles) and then break range()/indexing.
n_particles = 1000
maxsim = 1000000

# Per-recording containers; index i always refers to list_cells_AllenDB[i].
obs_ls = []
I_ls = []
dt_ls = []
t_on_ls = []
t_off_ls = []
obs_stats_ls = []
m_ls = []
s_ls = []
posterior_ls = []
res_ls = []
ps_smc_ls = []
logweights_smc_ls = []
eps_smc_ls = []
all_nsims_smc_ls = []

# constant offset added to every observed trace (same value for all recordings)
junction_potential = -14

for cell_num, (ephys_cell, sweep_number, A_soma) in enumerate(list_cells_AllenDB):

    obs = utils.allen_obs_data(ephys_cell=ephys_cell,sweep_number=sweep_number,A_soma=A_soma)

    # shift the recorded trace in place before anything else reads it
    obs['data'] = obs['data'] + junction_potential
    I = obs['I']
    dt = obs['dt']
    t_on = obs['t_on']
    t_off = obs['t_off']

    obs_ls.append(obs)
    I_ls.append(I)
    dt_ls.append(dt)
    t_on_ls.append(t_on)
    t_off_ls.append(t_off)

    obs_stats = utils.allen_obs_stats(data=obs,ephys_cell=ephys_cell,sweep_number=sweep_number,
                                      n_xcorr=n_xcorr,n_mom=n_mom,
                                      summary_stats=summary_stats,n_summary=n_summary)

    # only the first row is stored; the full obs_stats is still passed to
    # res.predict() further down
    obs_stats_ls.append(obs_stats[0])

    # simulator, initialised at the first observed (shifted) value
    m = HodgkinHuxley(I=I, dt=dt, V0=obs['data'][0], seed=seed, cython=cython, prior_log=prior_log)
    m_ls.append(m)

    # summary-statistics calculator matching the settings used for obs_stats
    s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off,n_xcorr=n_xcorr,n_mom=n_mom)
    s_ls.append(s)

    ##############################################################################
    # SNPE results
    filename1 = ('./results/allen_'+str(ephys_cell)+'_'+str(sweep_number)
                 +'_run_1_round2_prior0013_param8_nosvi_ncomp'+str(n_components)
                 +'_nsims'+str(n_sims*n_rounds)+'_snpe_res.pkl')

    res = io.load(filename1)
    res_ls.append(res)

    # posterior conditioned on the observed summary statistics
    posterior = res.predict(obs_stats)
    posterior_ls.append(posterior)

    ##############################################################################
    # SMC-ABC results
    filename1 = './results/allen_'+str(ephys_cell)+'_'+str(sweep_number)+'_run_1_prior0013_param8_smc_abc.pkl'
    ps_smc, logweights_smc, eps_smc, all_nsims_smc = io.load_pkl(filename1)

    ps_smc_ls.append(ps_smc)
    logweights_smc_ls.append(logweights_smc)
    eps_smc_ls.append(eps_smc)
    all_nsims_smc_ls.append(all_nsims_smc)

## Collect posterior modes and samples, and compute the respective summary statistics for each recording

In [None]:
# Build the prior object `p` that later cells draw parameter sets from.

# flags forwarded unchanged to utils.prior()
prior_uniform = True
prior_extent = True
prior_log = False
seed = 1

# utils.obs_params() returns a pair; only the first element is used here
true_params, _ = utils.obs_params()

p = utils.prior(
    true_params=true_params,
    prior_uniform=prior_uniform,
    prior_extent=prior_extent,
    prior_log=prior_log,
    seed=seed,
)

In [None]:
def _simulate_with_stats(model, stats, params):
    """Simulate one trace for `params` and return (trace, its summary statistics)."""
    x = model.gen_single(params)
    return x, stats.calc([x])[0]


# number of simulations for same parameter set
# int(): n_particles may be a float (e.g. 1e3); range() and slicing below need an integer
num_rep = int(min(100,n_particles))

# one prior draw per repetition, shared by all recordings
params_prior_mat = p.gen(num_rep)

# Outer index of every *_ls list below is the recording; for the x_* and
# sum_stats_* lists the inner index is the repetition.
mn_post_ls = []
x_snpe_ls = []
samp_post_ls = []
x_snpe_samp_ls = []

mode_post_smc_ls = []
x_smc_ls = []
samp_post_smc_ls = []
x_smc_samp_ls = []
weighted_samples_ls = []
pdf_smc_ls = []

x_prior_ls = []

sum_stats_snpe_ls = []
sum_stats_snpe_samp_ls = []
sum_stats_smc_ls = []
sum_stats_smc_samp_ls = []
sum_stats_prior_ls = []

for cell_num in range(n_post):
    model_cell = m_ls[cell_num]
    stats_cell = s_ls[cell_num]
    posterior_cell = posterior_ls[cell_num]

    #################################################################
    # SNPE
    # `.m` of the component with the largest `a` (the mixture weights, cf. the
    # MoG(a=..., ms=...) construction below); used as the SNPE posterior mode
    mn_post = posterior_cell.xs[np.argmax(posterior_cell.a)].m
    mn_post_ls.append(mn_post)

    samp_post_ls.append(posterior_cell.gen(num_rep))

    #################################################################
    # SMC-ABC
    # Index -1 selects the last stored entry (presumably the final SMC
    # population -- TODO confirm against the file written by the SMC-ABC run).
    # mean and covariance
    # NOTE: m_smc, cov_smc and nsims_smc are not used by any visible cell; they
    # are kept so they remain available for interactive inspection.
    weights_smc = np.exp(logweights_smc_ls[cell_num])
    nsims_smc = np.asarray(all_nsims_smc_ls[cell_num])

    m_smc = np.dot(weights_smc[-1],ps_smc_ls[cell_num][-1])
    cov_smc = np.cov(ps_smc_ls[cell_num][-1].T,aweights = weights_smc[-1])

    # weighted samples: repeat each particle round(weight * n_particles) times,
    # turning the weighted particle set into an unweighted sample set
    num_rep_samples = np.round(weights_smc[-1]*n_particles).astype('int')
    weighted_samples = np.repeat(ps_smc_ls[cell_num][-1],num_rep_samples, axis=0)
    weighted_samples_ls.append(weighted_samples)

    mn_post_smc = np.mean(weighted_samples,axis=0)
    std_post_smc = np.std(weighted_samples,axis=0)
    weighted_samples_zscored = (weighted_samples - mn_post_smc) / std_post_smc

    # kernel density estimation for finding MAP
    # NOTE(review): the KDE is fitted on the raw samples with a fixed bandwidth,
    # whereas the mixture below is fitted on the z-scored samples -- confirm that
    # bandwidth=0.2 is appropriate for the unscaled parameters.
    kde = KernelDensity(kernel='gaussian', bandwidth=0.2)
    kde.fit(weighted_samples)
    log_dens = kde.score_samples(weighted_samples)
    ind_max = np.argmax(log_dens)
    mode_post_smc = weighted_samples[ind_max]
    mode_post_smc_ls.append(mode_post_smc)

    # most likely samples from posterior: the num_rep samples with the highest
    # estimated density (argsort is ascending, so they sit at the end)
    density_sort = np.argsort(log_dens)
    samp_post_smc_ls.append(weighted_samples[density_sort[-num_rep:]])

    # mixture of gaussians fitted in z-scored space, then mapped back to the
    # original parameter scale
    clf = mixture.GaussianMixture(n_components=4, covariance_type='full',init_params='kmeans',n_init=10)
    clf.fit(weighted_samples_zscored)
    pdf1 = GaussianMixture.MoG(a=clf.weights_,ms=clf.means_,Ss=clf.covariances_).ztrans_inv(mn_post_smc, std_post_smc)

    pdf_smc_ls.append(pdf1)

    #################################################################
    # Simulate num_rep traces per parameter source. The call order inside the
    # loop is kept fixed (SNPE mode, SNPE sample, SMC-ABC mode, SMC-ABC sample,
    # prior) because all calls go through the same simulator object.

    x_snpe_ls1 = []
    x_snpe_samp_ls1 = []
    x_smc_ls1 = []
    x_smc_samp_ls1 = []
    x_prior_ls1 = []
    sum_stats_snpe_ls1 = []
    sum_stats_snpe_samp_ls1 = []
    sum_stats_smc_ls1 = []
    sum_stats_smc_samp_ls1 = []
    sum_stats_prior_ls1 = []
    for rep in range(num_rep):

        # SNPE: posterior mode (same parameters every repetition) ...
        x_snpe, sum_stats_snpe = _simulate_with_stats(model_cell, stats_cell, mn_post)
        x_snpe_ls1.append(x_snpe)
        sum_stats_snpe_ls1.append(sum_stats_snpe)

        # ... and one posterior sample per repetition
        x_snpe_samp, sum_stats_snpe_samp = _simulate_with_stats(
            model_cell, stats_cell, samp_post_ls[cell_num][rep,:])
        x_snpe_samp_ls1.append(x_snpe_samp)
        sum_stats_snpe_samp_ls1.append(sum_stats_snpe_samp)

        # SMC-ABC: KDE mode (same parameters every repetition) ...
        x_smc, sum_stats_smc = _simulate_with_stats(model_cell, stats_cell, mode_post_smc)
        x_smc_ls1.append(x_smc)
        sum_stats_smc_ls1.append(sum_stats_smc)

        # ... and one high-density sample per repetition
        x_smc_samp, sum_stats_smc_samp = _simulate_with_stats(
            model_cell, stats_cell, samp_post_smc_ls[cell_num][rep,:])
        x_smc_samp_ls1.append(x_smc_samp)
        sum_stats_smc_samp_ls1.append(sum_stats_smc_samp)

        # PRIOR
        x_prior, sum_stats_prior = _simulate_with_stats(
            model_cell, stats_cell, params_prior_mat[rep,:])
        x_prior_ls1.append(x_prior)
        sum_stats_prior_ls1.append(sum_stats_prior)


    x_snpe_ls.append(x_snpe_ls1)
    sum_stats_snpe_ls.append(sum_stats_snpe_ls1)
    x_snpe_samp_ls.append(x_snpe_samp_ls1)
    sum_stats_snpe_samp_ls.append(sum_stats_snpe_samp_ls1)
    x_smc_ls.append(x_smc_ls1)
    sum_stats_smc_ls.append(sum_stats_smc_ls1)
    x_smc_samp_ls.append(x_smc_samp_ls1)
    sum_stats_smc_samp_ls.append(sum_stats_smc_samp_ls1)
    x_prior_ls.append(x_prior_ls1)
    sum_stats_prior_ls.append(sum_stats_prior_ls1)

## MODES OF SNPE VS. SMC-ABC

In [None]:
# One row of three panels per recording:
#   left   - observed trace vs. first simulated trace at each method's mode
#   middle - summary features: observation vs. mean +/- std over repetitions
#   right  - mean absolute feature error on a log axis
fig = plt.figure(figsize=(30,7*n_post))

n_summary_stats = len(obs_stats_ls[0])
labels_sum_stats = [r'$spikes$',r'$r\_pot$',r'$\sigma_{r\_pot}$',r'$m_1$',r'$m_2$',r'$m_3$',r'$m_4$']

# colors
COL = {
    'GT': (35/255,86/255,167/255),
    'SNPE': (0, 174/255,239/255),
    'SMC-ABC': (102/255, 179/255, 46/255),
    'PRIOR': (244/255, 152/255, 25/255),
}

# x position of each summary feature (0, 1, ..., n_summary_stats-1)
feature_pos = np.linspace(0,n_summary_stats-1,n_summary_stats)

for cell_num in range(n_post):
    y_obs = obs_ls[cell_num]['data']
    t = obs_ls[cell_num]['time']
    duration = np.max(t)

    # ---- left panel: traces ----
    ax = plt.subplot(n_post,3,3*cell_num+1)
    plt.plot(t, y_obs, color=COL['GT'], lw=2, label='observation')
    plt.plot(t, x_snpe_ls[cell_num][0]['data'], color=COL['SNPE'], lw=2, label='SNPE')
    plt.plot(t, x_smc_ls[cell_num][0]['data'], color=COL['SMC-ABC'], lw=2, label='SMC-ABC')

    plt.xlabel('time (ms)')
    plt.ylabel('voltage (mV)')
    # space after "sweep number" so the title does not read "sweep number57"
    plt.title('cell '+str(list_cells_AllenDB[cell_num][0])+'; sweep number '+str(list_cells_AllenDB[cell_num][1]))

    ax.legend(bbox_to_anchor=(1.1, 1), loc='upper right')

    ax.set_xticks([0, duration/2, duration])
    ax.set_yticks([-80, -20, 40]);


    # ---- middle panel: feature values ----
    ax = plt.subplot(n_post,3,3*cell_num+2)
    plt.plot(obs_stats_ls[cell_num], color=COL['GT'], lw=2, label='observation')

    # nan-aware statistics over the repetitions (axis 0)
    mean_sum_stats_prior = np.nanmean(sum_stats_prior_ls[cell_num],axis=0)
    std_sum_stats_prior = np.nanstd(sum_stats_prior_ls[cell_num],axis=0)
    mean_sum_stats_snpe = np.nanmean(sum_stats_snpe_ls[cell_num],axis=0)
    std_sum_stats_snpe = np.nanstd(sum_stats_snpe_ls[cell_num],axis=0)
    mean_sum_stats_smc = np.nanmean(sum_stats_smc_ls[cell_num],axis=0)
    std_sum_stats_smc = np.nanstd(sum_stats_smc_ls[cell_num],axis=0)

    # all bands first, then all mean lines, so the lines are drawn on top
    for band_mean, band_std, key in ((mean_sum_stats_prior, std_sum_stats_prior, 'PRIOR'),
                                     (mean_sum_stats_snpe, std_sum_stats_snpe, 'SNPE'),
                                     (mean_sum_stats_smc, std_sum_stats_smc, 'SMC-ABC')):
        plt.fill_between(feature_pos,
                         band_mean-band_std,
                         band_mean+band_std,
                         facecolor=COL[key], alpha=0.3)
    for line_mean, key in ((mean_sum_stats_prior, 'PRIOR'),
                           (mean_sum_stats_snpe, 'SNPE'),
                           (mean_sum_stats_smc, 'SMC-ABC')):
        plt.plot(line_mean,color=COL[key], lw=2, label=key)
    ax.set_xticks(feature_pos)
    ax.set_xticklabels(labels_sum_stats)
    plt.ylabel('feature value')

    # ---- right panel: absolute feature errors (log scale) ----
    ax = plt.subplot(n_post,3,3*cell_num+3)
    # |observed feature| as a dashed reference line
    plt.semilogy(np.abs(obs_stats_ls[cell_num]), color=COL['GT'],linestyle='--', lw=2)

    abs_err_prior = np.abs(sum_stats_prior_ls[cell_num]-obs_stats_ls[cell_num])
    abs_err_snpe = np.abs(sum_stats_snpe_ls[cell_num]-obs_stats_ls[cell_num])
    abs_err_smc = np.abs(sum_stats_smc_ls[cell_num]-obs_stats_ls[cell_num])

    mean_err_prior = np.nanmean(abs_err_prior,axis=0)
    std_err_prior = np.nanstd(abs_err_prior,axis=0)
    mean_err_snpe = np.nanmean(abs_err_snpe,axis=0)
    std_err_snpe = np.nanstd(abs_err_snpe,axis=0)
    mean_err_smc = np.nanmean(abs_err_smc,axis=0)
    std_err_smc = np.nanstd(abs_err_smc,axis=0)

    # Only the means are drawn. The std_err_* values are computed but no
    # mean +/- std band is plotted here (presumably because mean - std can be
    # non-positive, which a log axis cannot show).
    plt.semilogy(mean_err_prior,color=COL['PRIOR'], lw=2)
    plt.semilogy(mean_err_snpe,color=COL['SNPE'], lw=2)
    plt.semilogy(mean_err_smc,color=COL['SMC-ABC'], lw=2)
    ax.set_xticks(feature_pos)
    ax.set_xticklabels(labels_sum_stats);
    plt.ylabel(r'|$f^*$ - f|');

## errors per feature

In [None]:
# One scatter panel per summary feature: mean absolute error of the SMC-ABC
# mode (x) against that of the SNPE mode (y), one point per recording, with
# std over repetitions as error bars. Points below the dashed diagonal mean
# SNPE has the smaller error for that recording.
fig = plt.figure(figsize=(20,20))

# Observed statistics repeated num_rep times along a new axis 1 so they line up
# element-wise with the simulated statistics:
# (n_post, n_stats) -> (n_post, num_rep, n_stats), assuming every entry of
# obs_stats_ls is 1-D -- TODO confirm.
obs_stats_mat = np.transpose(np.tile(np.asarray(obs_stats_ls),(num_rep,1,1)),(1,0,2))
sum_stats_snpe_mat = np.asarray(sum_stats_snpe_ls)
sum_stats_smc_mat = np.asarray(sum_stats_smc_ls)

# Subplot grid: 3 columns and at least 3 rows (the original fixed 3x3 layout);
# extra rows are added when there are more than 9 features, which a fixed
# 3x3 grid could not hold.
n_cols = 3
n_rows = max(3, -(-n_summary_stats // n_cols))  # ceil division

for i in range(n_summary_stats):
    plt.subplot(n_rows,n_cols,i+1)

    # absolute error of feature i, shape (n_post, num_rep)
    abs_err_smc = np.abs(sum_stats_smc_mat[:,:,i]-obs_stats_mat[:,:,i])
    abs_err_snpe = np.abs(sum_stats_snpe_mat[:,:,i]-obs_stats_mat[:,:,i])

    # nan-aware mean / std over repetitions -> one value per recording
    xx = np.nanmean(abs_err_smc,axis=1)
    xx_err = np.nanstd(abs_err_smc,axis=1)
    yy = np.nanmean(abs_err_snpe,axis=1)
    yy_err = np.nanstd(abs_err_snpe,axis=1)

    # end points of the identity line, spanning both methods' error range
    xx1 = [np.min(np.minimum(xx,yy)),np.max(np.maximum(xx,yy))]
    plt.errorbar(xx, yy, xerr=xx_err, yerr=yy_err, fmt='.', markersize=10)
    plt.plot(xx1,xx1,'--k')
    plt.xlabel(r'|$f^*$ - $f_{SMC-ABC}$|')
    plt.ylabel(r'|$f^*$ - $f_{SNPE}$|')
    plt.title(labels_sum_stats[i]);

In [None]:
# Per recording: deviation of simulated from observed summary features,
# divided by the std of the deviations of each method's posterior samples.
# The dashed lines mark +/- 1 std.
fig = plt.figure(figsize=(10,7*n_post))

sum_stats_snpe_samp_mat = np.asarray(sum_stats_snpe_samp_ls)
sum_stats_smc_samp_mat = np.asarray(sum_stats_smc_samp_ls)

# x position of each summary feature, and the same positions repeated for
# every repetition so they match the (num_rep, n_stats) deviation arrays
feature_pos = np.linspace(0,n_summary_stats-1,n_summary_stats)
XX = np.tile(feature_pos,(num_rep,1))

for cell_num in range(n_post):
    plt.subplot(n_post,1,cell_num+1)
    obs_cell = obs_stats_mat[cell_num,:,:]

    # deviations f - f* (simulated minus observed)
    yy_samp = sum_stats_snpe_samp_mat[cell_num,:,:]-obs_cell
    yy_smc_samp = sum_stats_smc_samp_mat[cell_num,:,:]-obs_cell
    yy_snpe = sum_stats_snpe_mat[cell_num,:,:]-obs_cell
    yy_smc = sum_stats_smc_mat[cell_num,:,:]-obs_cell

    # per-feature spread of the sample deviations; also used to scale the
    # mode deviations of the same method
    yy_samp_err = np.nanstd(yy_samp,axis=0)
    yy_smc_samp_err = np.nanstd(yy_smc_samp,axis=0)

    plt.scatter(XX, yy_samp/yy_samp_err,marker='.',color=(105/255, 105/255, 105/255), label='SNPE samples')
    plt.scatter(XX, yy_snpe/yy_samp_err,color=COL['SNPE'], label='SNPE mode')
    plt.scatter(XX, yy_smc_samp/yy_smc_samp_err,marker='.',color=(200/255, 200/255, 200/255), label='SMC-ABC samples')
    # labelled "mode" for consistency with the SNPE entry
    plt.scatter(XX, yy_smc/yy_smc_samp_err,color=COL['SMC-ABC'], label='SMC-ABC mode')
    # label matches the plotted quantity: individual (f - f*) values over sigma
    plt.ylabel(r'$\frac{f - f^*}{\sigma_{f}}$')
    # space after "sweep number" so the title does not read "sweep number57"
    plt.title('cell '+str(list_cells_AllenDB[cell_num][0])+'; sweep number '+str(list_cells_AllenDB[cell_num][1]))

    plt.plot(feature_pos,np.ones(n_summary_stats),'--k')
    plt.plot(feature_pos,-np.ones(n_summary_stats),'--k')
    ax = plt.gca()
    ax.set_xticks(feature_pos)
    ax.set_xticklabels(labels_sum_stats)
    # one legend for the whole figure, attached to the first panel
    if cell_num == 0:
        ax.legend(bbox_to_anchor=(1.1, 1), loc='upper right');

## plot posteriors

In [None]:
# axis labels passed to plot_pdf (eight entries)
labels_params1 = [
    r'$g_{Na}$',
    r'$g_{K}$',
    r'$g_{l}$',
    r'$g_{M}$',
    r'$t_{max}$',
    r'$-V_{T}$',
    r'$noise$',
    r'$-E_{l}$',
]

# one figure per recording: SNPE posterior with the SMC-ABC weighted samples
# (transposed before being handed to plot_pdf)
for snpe_posterior, smc_samples in zip(posterior_ls, weighted_samples_ls):
    plot_pdf(
        snpe_posterior,
        samples=smc_samples.T,
        figsize=(15,15),
        labels_params=labels_params1,
        ticks=True,
    );

## plot posteriors with prior bounds

In [None]:
# Same posterior plots as before, but with the axis limits set to the prior bounds.
prior_min = p.lower
prior_max = p.upper

# two-column array: column 0 holds the lower bounds, column 1 the upper bounds
prior_lims = np.stack((prior_min.ravel(), prior_max.ravel()), axis=1)

# axis labels passed to plot_pdf (eight entries)
labels_params1 = [
    r'$g_{Na}$',
    r'$g_{K}$',
    r'$g_{l}$',
    r'$g_{M}$',
    r'$t_{max}$',
    r'$-V_{T}$',
    r'$noise$',
    r'$-E_{l}$',
]

# one figure per recording: SNPE posterior with the SMC-ABC weighted samples
for snpe_posterior, smc_samples in zip(posterior_ls, weighted_samples_ls):
    plot_pdf(
        snpe_posterior,
        lims=prior_lims,
        samples=smc_samples.T,
        figsize=(15,15),
        labels_params=labels_params1,
        ticks=True,
    );