## SMC-ABC on a Hodgkin-Huxley model for a cell from the Allen Cell Types Database

In [None]:
import delfi.distribution as dd
import delfi.distribution.mixture.GaussianMixture as GaussianMixture
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.hodgkinhuxley.utils as utils
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pickle

from lfimodels.abc_methods import run_abc
from lfimodels.hodgkinhuxley.HodgkinHuxley import HodgkinHuxley
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsSpikes_mf import HodgkinHuxleyStatsSpikes_mf
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsSpikes import HodgkinHuxleyStatsSpikes
from delfi.utils.viz import plot_pdf
from sklearn import mixture
from sklearn.neighbors.kde import KernelDensity

%matplotlib inline

In [None]:
def param_transform(prior_log, x):
    """Map parameters into the space the prior is defined on.

    Returns ``np.log(x)`` when ``prior_log`` is truthy; otherwise ``x`` is
    returned unchanged (the same object, no copy).
    """
    return np.log(x) if prior_log else x

def param_invtransform(prior_log, x):
    """Undo ``param_transform``.

    Returns ``np.exp(x)`` when ``prior_log`` is truthy; otherwise ``x`` is
    returned unchanged (the same object, no copy).
    """
    return np.exp(x) if prior_log else x

In [None]:
# Reference parameter vector and the matching parameter labels.
true_params, labels_params = utils.obs_params()
n_params = len(true_params)

# Settings forwarded to the prior, model and summary-statistics helpers below.
seed = 1  # NOTE: rebound to None further down, before it is first used
prior_uniform = True
prior_log = False
prior_extent = True
n_xcorr = 0
n_mom = 4
cython = True
n_summary = 10
summary_stats = 1

# One row per recording: [ephys_cell, sweep_number, A_soma].
list_cells_AllenDB = [
    [518290966, 57, 0.0234/126],
    [509881736, 39, 0.0153/184],
    [566517779, 46, 0.0195/198],
    [567399060, 38, 0.0259/161],
    [569469018, 44, 0.033/403],
    [532571720, 42, 0.0139/127],
    [555060623, 34, 0.0294/320],
    [534524026, 29, 0.027/209],
    [532355382, 33, 0.0199/230],
    [526950199, 37, 0.0186/218],
]

# Select which recording of the table above to fit.
cell_num = 0
ephys_cell, sweep_number, A_soma = list_cells_AllenDB[cell_num]
junction_potential = -14

obs = utils.allen_obs_data(ephys_cell=ephys_cell,
                           sweep_number=sweep_number,
                           A_soma=A_soma)

# Offset the recorded trace by the junction potential; the protocol values
# are then read from the same observation dict.
obs['data'] = obs['data'] + junction_potential
I, dt = obs['I'], obs['dt']
t_on, t_off = obs['t_on'], obs['t_off']

obs_stats = utils.allen_obs_stats(data=obs,
                                  ephys_cell=ephys_cell,
                                  sweep_number=sweep_number,
                                  n_xcorr=n_xcorr,
                                  n_mom=n_mom,
                                  summary_stats=summary_stats,
                                  n_summary=n_summary)

# Define model, prior, summary statistics and generator.
# The seed is reset to None here, so neither the model nor the prior is
# given a fixed seed.
seed = None
m = HodgkinHuxley(I, dt, V0=obs['data'][0], seed=seed, cython=cython,
                  prior_log=prior_log)
p = utils.prior(true_params=true_params,
                prior_uniform=prior_uniform,
                prior_extent=prior_extent,
                prior_log=prior_log,
                seed=seed)
# Alternative summary classes imported at the top of the notebook:
#     s = HodgkinHuxleyStatsSpikes_mf(t_on=t_on, t_off=t_off, n_summary=n_summary)
#     s = HodgkinHuxleyStatsSpikes(t_on=t_on, t_off=t_off, n_summary=n_summary)
s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr, n_mom=n_mom)
g = dg.Default(model=m, prior=p, summary=s)

# One row per parameter: [lower, upper] limits taken from the prior.
bounds = np.asarray([p.lower, p.upper]).T

In [None]:
# Plot the observed voltage trace (obs['data']) against time (obs['time']).
plt.plot(obs['time'],obs['data'])

In [None]:
# Show t_on, t_off and dt of the observation, one value per line.
print(t_on, t_off, dt, sep='\n')

In [None]:
# Display the summary statistics computed for the observation.
obs_stats

### z-scoring summary statistics

In [None]:
# Draw pilot simulations from the generator and record, per summary
# statistic, the mean and standard deviation used for z-scoring.
n_summary_stats = len(obs_stats[0])
pilot_samples = 1000
_, pilots = g.gen(pilot_samples)

# Keep both as row vectors of shape (1, n_summary_stats).
# NOTE(review): a statistic with zero spread across the pilots would lead to
# a division by zero wherever stats_std is used as a divisor; confirm this
# cannot happen for the chosen statistics.
stats_mean = np.mean(pilots, axis=0).reshape(1, n_summary_stats)
stats_std = np.std(pilots, axis=0).reshape(1, n_summary_stats)

class normed_summary():
    """Summary-statistics wrapper that z-scores the output of ``g.summary``.

    Relies on the notebook-level names ``g``, ``stats_mean`` and
    ``stats_std`` being defined at the time ``calc`` is called.
    """

    def calc(self, y):
        """Return ``g.summary.calc(y)`` shifted by ``stats_mean`` and scaled by ``stats_std``."""
        raw_stats = g.summary.calc(y)
        centred = raw_stats - stats_mean
        return centred / stats_std

# z-score the observed summary statistics with stats_mean and stats_std.
obs_statz =  (obs_stats.flatten() - stats_mean) /  stats_std

## SMC-ABC

### initial epsilon

In [None]:
# Initial tolerance: distance between the z-scored observation and the
# z-scored per-statistic median of the pilot simulations.
stats_median = np.median(pilots, axis=0)
stats_medianz = (stats_median - stats_mean) / stats_std
eps_init = run_abc.calc_dist(obs_statz, stats_medianz)

### run algorithm

In [None]:
# Particle count and simulation limit passed to run_smc.
# Both are written as floats (1e3, 1e6).
n_particles = 1e3
maxsim = 1e6

# Run SMC-ABC on the z-scored summary statistics, starting from eps_init.
ps_smc, logweights_smc, eps_smc, all_nsims_smc = run_abc.run_smc(
    model=m,
    prior=p,
    summary=normed_summary(),
    obs_stats=obs_statz,
    n_params=n_params,
    seed=None,
    n_particles=n_particles,
    eps_init=eps_init,
    maxsim=maxsim,
    fn=None,
)

In [None]:
import os

# Path of the pickle holding the SMC-ABC output for this cell and sweep.
filename1 = './results/allen_'+str(ephys_cell)+'_'+str(sweep_number)+\
'_run_1_prior0013_param8_smc_abc.pkl'

# Make sure the output directory exists before writing, so that the result
# of the (long) SMC run is not lost to a missing './results' folder in case
# the saver does not create it itself.
os.makedirs(os.path.dirname(filename1), exist_ok=True)
io.save_pkl((ps_smc, logweights_smc, eps_smc, all_nsims_smc), filename1)

In [None]:
# Rebuild the result path for this cell and sweep, then restore the four
# SMC-ABC outputs from it.
# NOTE(review): io.load_pkl presumably unpickles the file; only load
# result files you trust.
filename1 = ('./results/allen_' + str(ephys_cell) + '_' + str(sweep_number)
             + '_run_1_prior0013_param8_smc_abc.pkl')
ps_smc, logweights_smc, eps_smc, all_nsims_smc = io.load_pkl(filename1)

### mean and covariance

In [None]:
# Exponentiate the log-weights and turn the simulation counts into an array.
weights_smc = np.exp(logweights_smc)
nsims_smc = np.asarray(all_nsims_smc)

# For every entry of ps_smc: weight-particle dot product and the
# weight-aware covariance of the particles.
m_smc = [np.dot(weights_smc[k], particles)
         for k, particles in enumerate(ps_smc)]
cov_smc = [np.cov(particles.T, aweights=weights_smc[k])
           for k, particles in enumerate(ps_smc)]

### weighted samples

In [None]:
# Turn the last set of weighted particles into plain samples: each particle
# is repeated round(weight * n_particles) times.
num_rep_samples = np.round(weights_smc[-1] * n_particles).astype('int')
weighted_samples = np.repeat(ps_smc[-1], num_rep_samples, axis=0)

# Per-parameter z-scoring of the repeated samples.
params_mean = weighted_samples.mean(axis=0)
params_std = weighted_samples.std(axis=0)
weighted_samples_zscored = (weighted_samples - params_mean) / params_std

### kernel density estimation for finding MAP

In [None]:
# Gaussian kernel density estimate on the weighted samples; the sample with
# the highest estimated log-density is taken as the posterior mode.
# NOTE(review): the KDE is fitted on the unscaled samples with a fixed
# bandwidth of 0.2 - confirm this is intended given that the parameters may
# live on different scales.
kde = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(weighted_samples)
log_dens = kde.score_samples(weighted_samples)
ind_max = log_dens.argmax()
mn_post = weighted_samples[ind_max]

### mixture of Gaussians

In [None]:
# Fit a 4-component full-covariance Gaussian mixture to the z-scored samples
# (k-means initialisation, 10 restarts).
clf = mixture.GaussianMixture(
    n_components=4,
    covariance_type='full',
    init_params='kmeans',
    n_init=10,
)
clf.fit(weighted_samples_zscored)

# Wrap the fitted mixture as a delfi MoG and apply ztrans_inv with the
# sample mean and standard deviation.
pdf1 = GaussianMixture.MoG(
    a=clf.weights_,
    ms=clf.means_,
    Ss=clf.covariances_,
).ztrans_inv(params_mean, params_std)

### plot posterior

In [None]:
# Plot limits: one [lower, upper] row per parameter, taken from the prior.
prior_min = g.prior.lower
prior_max = g.prior.upper
prior_lims = np.stack((prior_min.ravel(), prior_max.ravel()), axis=1)

# Alternative density: a single Gaussian from the last SMC mean/covariance.
# pdf1 = dd.Gaussian(m=m_smc[-1], S=cov_smc[-1])
plot_pdf(pdf1,
         lims=prior_lims,
         samples=weighted_samples.T,
         figsize=(15, 15),
         labels_params=labels_params,
         ticks=True);

In [None]:
fig = plt.figure()

n_params = len(mn_post)
y_obs, t = obs['data'], obs['time']
duration = np.max(t)

# Line colours: 'GT' for the observed trace, 'SNPE' for the mode trace.
COL = {
    'GT': (35/255, 86/255, 167/255),
    'SNPE': (0, 174/255, 239/255),
}

num_samp = 3

# Samples with the highest KDE log-density; argsort is ascending, so the
# last row of x_samp has the highest density.
density_sort = np.argsort(log_dens)
x_samp = weighted_samples[density_sort[-num_samp:]]

# Use the number of rows actually selected from here on.
num_samp = x_samp.shape[0]
num_colors = num_samp + 1
cm1 = mpl.cm.Oranges
col1 = [cm1(1. * k / num_colors) for k in range(num_colors)]

# Row 0 holds the mode, the remaining rows the high-density samples.
params = param_invtransform(prior_log,
                            np.concatenate((np.array([mn_post]), x_samp)))

# Simulate one voltage trace per parameter row (column 0 = mode).
V = np.zeros((len(t), 1 + num_samp))
for i in range(1 + num_samp):
    # NOTE: this rebinds the notebook-level name `m` on every iteration.
    m = HodgkinHuxley(I=I, dt=dt, V0=obs['data'][0], seed=230 + i,
                      cython=True, prior_log=prior_log)
    x = m.gen_single(param_transform(prior_log, params[i, :]))
    V[:, i] = x['data']
    if i == 0:
        # The mode trace is drawn after the loop.
        continue
    plt.plot(t, V[:, i], color=col1[i - 1], lw=2,
             label='sample ' + str(num_samp - i + 1))

# Mode and observation are drawn last.
plt.plot(t, V[:, 0], color=COL['SNPE'], lw=2, label='mode')
plt.plot(t, y_obs, color=COL['GT'], lw=2, label='observation')

plt.xlabel('time (ms)')
plt.ylabel('voltage (mV)')

# Legend entries in reverse plotting order, anchored outside the axes.
ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1],
          bbox_to_anchor=(1.3, 1), loc='upper right')

ax.set_xticks([0, duration/2, duration])
ax.set_yticks([-80, -20, 40]);

In [None]:
# Plot only the first column of V (labelled 'mode') against time.
plt.plot(t, V[:, 0], color=COL['SNPE'], lw=2, label='mode')
plt.xlabel('time (ms)')
plt.ylabel('voltage (mV)');

In [None]:
# Tick labels for the summary statistics shown on the x-axis.
labels_sum_stats = ['sp_t', 'r_pot', 'r_pot_std', 'm1', 'm2', 'm3', 'm4']
n_summary_stats = len(labels_sum_stats)

# Summary statistics of one simulation at mn_post, using whichever model
# object the name `m` currently refers to.
sum_stats_post = g.summary.calc([m.gen_single(mn_post)])[0]

fig = plt.figure(figsize=(20, 5))

# Left panel: observed statistics and statistics of the mode simulation.
ax = plt.subplot(1, 2, 1)
plt.plot(obs_stats[0], color=COL['GT'], lw=2, label='observation')
plt.plot(sum_stats_post, color=COL['SNPE'], lw=2, label='mode')
ax.set_xticks(np.linspace(0, n_summary_stats - 1, n_summary_stats))
ax.set_xticklabels(labels_sum_stats)
plt.ylabel('feature value')
ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1],
          bbox_to_anchor=(1.2, 1), loc='upper right')

# Right panel (log y-axis): |observed statistics| (dashed) and the absolute
# difference between mode and observed statistics.
ax = plt.subplot(1, 2, 2)
plt.semilogy(np.abs(obs_stats[0]), color=COL['GT'], linestyle='--', lw=2,
             label='observation')
plt.semilogy(np.abs(sum_stats_post - obs_stats[0]), color=COL['SNPE'], lw=2,
             label='mode')
ax.set_xticks(np.linspace(0, n_summary_stats - 1, n_summary_stats))
ax.set_xticklabels(labels_sum_stats);
plt.ylabel(r'$f^*$ - f');