## Inference of a Hodgkin-Huxley model on a cell from the Allen Cell Types Database with multiple repeats

In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.hodgkinhuxley.utilsMultiStep as utilsMultiStep
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pickle

from lfimodels.hodgkinhuxley.HodgkinHuxleyMultiStep import HodgkinHuxleyMultiStep
from lfimodels.hodgkinhuxley.HodgkinHuxleyMultiStepStatsMoments import HodgkinHuxleyMultiStepStatsMoments
from lfimodels.hodgkinhuxley.HodgkinHuxleyMultiStepStatsSpikes_mf import HodgkinHuxleyMultiStepStatsSpikes_mf
from delfi.utils.viz import plot_pdf

%matplotlib inline

In [None]:
# Reference parameter values and their display labels.
true_params, labels_params = utilsMultiStep.obs_params()

# --- prior / simulator / summary-statistics settings -----------------------
seed = 1
prior_uniform = True
prior_log = False
prior_extent = True
cython = True

n_xcorr = 0
n_mom = 4
n_summary = 8
summary_stats = 1

# --- recording to fit --------------------------------------------------------
ephys_cell = 566517779
sweep_number_ls = [44, 56, 57, 58, 59]
num_repeats = len(sweep_number_ls)
A_soma = 0.0195/198
junction_potential = -14

obs = utilsMultiStep.allen_obs_data(
    ephys_cell=ephys_cell,
    sweep_number_ls=sweep_number_ls,
    A_soma=A_soma,
)

# shift the recorded traces by the junction potential
obs['data'] = obs['data'] + junction_potential

# overwrite the stimulus onset / offset times with corrected values
obs['t_on'] = 205.0
obs['t_off'] = 1204.99

# shorthand handles used by the cells below
I, dt = obs['I'], obs['dt']
t_on, t_off = obs['t_on'], obs['t_off']

# summary statistics of the observed sweeps
obs_stats = utilsMultiStep.allen_obs_stats(
    data=obs,
    ephys_cell=ephys_cell,
    sweep_number_ls=sweep_number_ls,
    n_xcorr=n_xcorr,
    n_mom=n_mom,
    summary_stats=summary_stats,
    n_summary=n_summary,
)

# number of simulator instances / processes used when building the generator
n_processes = 8

def rej(x):
    """Return True when `x` contains no NaN entry, False otherwise."""
    contains_nan = np.isnan(x).any()
    return np.logical_not(contains_nan)

if n_processes > 1:
    # one simulator per process, seeded 1 .. n_processes
    seeds_model = np.arange(1, n_processes+1, 1)
    m = [
        HodgkinHuxleyMultiStep(I, dt, V0=obs['data'][:,0], repeats=num_repeats,
                               seed=model_seed, cython=cython, prior_log=prior_log)
        for model_seed in seeds_model
    ]
    p = utilsMultiStep.prior(true_params=true_params, prior_uniform=prior_uniform,
                             prior_extent=prior_extent, prior_log=prior_log, seed=seed)
    s = HodgkinHuxleyMultiStepStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr,
                                           n_mom=n_mom, n_summary=n_summary)
    # alternative summary: HodgkinHuxleyMultiStepStatsSpikes_mf(t_on=t_on, t_off=t_off, n_summary=n_summary)
    # the multi-process generator is built without the `rej` filter
    g = dg.MPGenerator(models=m, prior=p, summary=s)
else:
    # single-process path: the global seed is reset to None and used for both model and prior
    seed = None
    m = HodgkinHuxleyMultiStep(I, dt, V0=obs['data'][:,0], repeats=num_repeats,
                               seed=seed, cython=cython, prior_log=prior_log)
    p = utilsMultiStep.prior(true_params=true_params, prior_uniform=prior_uniform,
                             prior_extent=prior_extent, prior_log=prior_log, seed=seed)
    s = HodgkinHuxleyMultiStepStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr,
                                           n_mom=n_mom, n_summary=n_summary)
    # alternative summary: HodgkinHuxleyMultiStepStatsSpikes_mf(t_on=t_on, t_off=t_off, n_summary=n_summary)
    g = dg.RejKernel(model=m, prior=p, summary=s, rej=rej)

In [None]:
# One panel per observed sweep, laid out on a 2x3 grid.
fig = plt.figure(figsize=(20,8))
for sweep_idx in range(num_repeats):
    ax = plt.subplot(2, 3, sweep_idx + 1)
    ax.plot(obs['time'], obs['data'][sweep_idx], lw=2)
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('voltage (mV)')

In [None]:
# Show stimulus onset, stimulus offset and time step, one per line.
for timing_value in (t_on, t_off, dt):
    print(timing_value)

In [None]:
# display the summary statistics computed from the observed sweeps
obs_stats

## SNPE

In [None]:
# --- SNPE settings -----------------------------------------------------------
seed = 1
svi = False
impute_missing = False
pilot_samples = 1000
n_sims = 125000          # simulations per round
n_rounds = 2
n_components = 1
n_hiddens = [100]*2
kernel_loss = 'x_kl'     # defined here but not passed to SNPE below

res = infer.SNPE(
    g,
    obs=obs_stats,
    pilot_samples=pilot_samples,
    n_hiddens=n_hiddens,
    seed=seed,
    prior_norm=True,
    n_components=n_components,
    svi=svi,
    impute_missing=impute_missing,
)

# run all rounds; returns the training logs, the training data and the posteriors
log, train_data, posterior = res.run(n_sims, n_rounds=n_rounds, epochs=1000)

In [None]:
# Persist the inference results. The output directory is created first so that
# a missing './results' folder cannot make the save fail after the simulations
# have already been run.
import os

svi_flag = '_svi' if svi else '_nosvi'

results_dir = './results'
# common part of both file names; encodes cell id, number of repeats and run settings
file_stem = (results_dir + '/allen_' + str(ephys_cell) + '_' + str(num_repeats)
             + 'repeats_run_1_round2_prior0013_param8' + svi_flag
             + '_ncomp' + str(n_components)
             + '_nsims' + str(n_sims*n_rounds))
filename1 = file_stem + '_snpe.pkl'       # (log, train_data, posterior)
filename2 = file_stem + '_snpe_res.pkl'   # the inference object itself

os.makedirs(results_dir, exist_ok=True)
io.save_pkl((log, train_data, posterior), filename1)
io.save(res, filename2)

In [None]:
# Reload a saved run. Uncomment the overrides to point at a run made with
# different simulation settings.
# n_sims = 50000
# n_rounds = 2
file_stem = ''.join([
    './results/allen_', str(ephys_cell), '_', str(num_repeats),
    'repeats_run_1_round2_prior0013_param8', svi_flag,
    '_ncomp', str(n_components),
    '_nsims', str(n_sims*n_rounds),
])
filename1 = file_stem + '_snpe.pkl'
filename2 = file_stem + '_snpe_res.pkl'

log, train_data, posterior = io.load_pkl(filename1)
res = io.load(filename2)
# posterior = res.predict(obs_stats)

In [None]:
# Network configuration of the loaded inference object.
print(res.network.n_components)
print(res.network.n_hiddens)

# Shapes of the first two entries of the training data for the first two rounds.
for round_idx in (0, 1):
    for entry_idx in (0, 1):
        print(np.shape(train_data[round_idx][entry_idx]))

print(res.round)

In [None]:
# imputation_values = log[-1]['imputation_values'][-1]*res.stats_std+res.stats_mean

In [None]:
# plt.plot(imputation_values,label='imputation values')
# plt.plot(obs_stats[0],label='observed features')
# plt.legend()

In [None]:
# # use samples from first round to re-learn network with different configurations

# # set network
# import theano
# import theano.tensor as tt

# from delfi.neuralnet.NeuralNet import NeuralNet

# dtype = theano.config.floatX

# # n_hiddens=[1000,1000]
# n_hiddens=[200]*3
# n_components=1
# svi = False
# impute_missing = True
# res.network = NeuralNet(n_inputs=train_data[0][1].shape[1],
#                         n_outputs = train_data[0][0].shape[1],
#                         n_hiddens=n_hiddens, n_components=n_components, svi=svi, impute_missing=impute_missing)
# res.network.iws = tt.vector('iws', dtype=dtype)

# from delfi.neuralnet.Trainer import Trainer

# def train_net(res=res, epochs=100, minibatch=50, round_cl=1, stop_on_nan=False, monitor=None, **kwargs):
#     """Run algorithm"""

#     # load training data (z-transformed params and stats)
#     _, trn_data, _ = io.load_pkl(filename1)
#     trn_data = trn_data[0]
#     n_train_round = trn_data[0].shape[0]

#     # precompute importance weights
#     iws = np.ones((n_train_round,))

#     # normalize weights
#     iws = (iws/np.sum(iws))*n_train_round

#     trn_data = (trn_data[0], trn_data[1], iws)
#     trn_inputs = [res.network.params, res.network.stats,
#                   res.network.iws]

#     t = Trainer(res.network,
#                 res.loss(N=n_train_round, round_cl=round_cl),
#                 trn_data=trn_data, trn_inputs=trn_inputs,
#                 seed=res.gen_newseed(),
#                 monitor=res.monitor_dict_from_names(monitor),
#                 **kwargs)
#     log = t.train(epochs=epochs, minibatch=minibatch,
#                         verbose=res.verbose, stop_on_nan=stop_on_nan)

#     posterior = res.predict(res.obs)

#     return log, posterior

# epochs=100
# minibatch=50
# log, posterior = train_net(epochs=epochs, minibatch=minibatch)

In [None]:
# display the posterior object(s) returned by the run (indexed per round in the plotting cell below)
posterior

In [None]:
def param_transform(prior_log, x):
    """Map `x` to log space when `prior_log` is truthy; otherwise return `x` unchanged."""
    return np.log(x) if prior_log else x

def param_invtransform(prior_log, x):
    """Undo the log transform (exponentiate `x`) when `prior_log` is truthy; otherwise return `x` unchanged."""
    return np.exp(x) if prior_log else x

In [None]:
# Lower / upper parameter bounds used as axis limits of the posterior plots.
if prior_uniform:
    prior_min, prior_max = res.generator.prior.lower, res.generator.prior.upper
else:
    lower_bounds = np.array([.5,1e-4,1e-4,1e-4,50.,40.,1e-4,35.])
    upper_bounds = np.array([80.,15.,.6,.6,3000.,90.,.15,100.])
    prior_min = param_transform(prior_log, lower_bounds)
    prior_max = param_transform(prior_log, upper_bounds)

# one row per parameter: [lower, upper]
prior_lims = np.column_stack((prior_min.ravel(), prior_max.ravel()))

# one posterior plot per inference round
for round_idx in range(res.round):
    fig = plt.figure(figsize=(15,15))
    plot_pdf(posterior[round_idx], lims=prior_lims, samples=None, figsize=(15,15))
    plt.show()
    plt.close()

In [None]:
# keep only the last entry; the cells below treat `posterior` as a single distribution
# NOTE(review): not idempotent - re-running this cell indexes into the already-reduced object
posterior=posterior[-1]

In [None]:
# Optional overlay of the best genetic-algorithm solution (kept for reference):
# _, halloffame, _, _ = io.load_pkl('allen_'+str(ephys_cell)+'_run_3_offspr100_max_gen100_ibea.pkl')
# plot_pdf(posterior, lims=prior_lims, samples=None, figsize=(15,15),
#          gt=halloffame[0], labels_params=labels_params, ticks=True);

# Lower / upper parameter bounds used as axis limits.
if prior_uniform:
    prior_min, prior_max = res.generator.prior.lower, res.generator.prior.upper
else:
    lower_bounds = np.array([.5,1e-4,1e-4,1e-4,50.,40.,1e-4,35.])
    upper_bounds = np.array([80.,15.,.6,.6,3000.,90.,.15,100.])
    prior_min = param_transform(prior_log, lower_bounds)
    prior_max = param_transform(prior_log, upper_bounds)

# one row per parameter: [lower, upper]
prior_lims = np.column_stack((prior_min.ravel(), prior_max.ravel()))

plot_pdf(posterior, lims=prior_lims, samples=None, figsize=(15,15), labels_params=labels_params, ticks=True);

In [None]:
# Training loss per round, one panel per entry of `log`.
# The number of panels follows len(log) instead of a hard-coded 2, so the cell
# also works when the run used a different number of rounds.
n_logged_rounds = len(log)

fig = plt.figure(figsize=(15,10))

# skip the lightest shades of the colormap so every curve stays visible
col_min = 1
num_colors = n_logged_rounds+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(1.*i/num_colors) for i in range(col_min,num_colors)]

for i in range(n_logged_rounds):
    plt.subplot(n_logged_rounds,1,i+1)
    plt.plot(log[i]['loss'], color=col1[i], lw=2)
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.title('round'+str(i+1));

In [None]:
# Effective sample size of each round, computed from the third entry of the
# round's training data (used here as per-sample weights) scaled by the number
# of samples in the first round.
n_sim = train_data[0][0].shape[0]
ess = np.zeros(res.round)
for round_idx in range(res.round):
    scaled_weights = train_data[round_idx][2]/n_sim
    ess[round_idx] = 1/np.sum(scaled_weights**2)
ess

In [None]:
# display the `a` attribute of the posterior (used below with argmax to pick a component)
posterior.a

In [None]:
# overwrite the round counter of the inference object with the number of entries in train_data
# NOTE(review): earlier cells already read res.round - confirm the intended execution order
res.round = len(train_data)

In [None]:
# Per-parameter ratio std / mean of the highest-weight posterior component.
# The mean is read from the posterior directly instead of from `mn_post`, which
# is only defined in a later cell and is unavailable on a top-to-bottom run.
dominant_component = posterior.xs[np.argmax(posterior.a)]
plt.plot(np.sqrt(np.diag(dominant_component.S))/dominant_component.m,'o')

In [None]:
# Posterior predictive check: simulate the model at the posterior mode and at a
# few posterior samples, and overlay the traces on the observed sweeps.

# mean of the highest-weight posterior component, used as the "mode"
# mn_post, S = posterior.calc_mean_and_cov()
mn_post = posterior.xs[np.argmax(posterior.a)].m
# S = posterior.xs[np.argmax(posterior.a)].S
n_params = len(mn_post)
y_obs = obs['data']
t = obs['time']
duration = np.max(t)

COL = {}
COL['GT']   = (35/255,86/255,167/255)
COL['SNPE'] = (0, 174/255,239/255)

# # parameter set from training data with minimum distance to observed data
# train_data_unzscored = train_data[res.round-1][1]*res.stats_std+res.stats_mean
# # param_min_arg = np.argmin(np.linalg.norm(train_data_unzscored-obs_stats[0],axis=1))
# param_min_arg = np.argmin(np.linalg.norm((train_data_unzscored-obs_stats[0])/obs_stats[0],axis=1))
# param_min = train_data[res.round-1][0][param_min_arg,:]*res.params_std + res.params_mean
# param_min_stats = train_data[res.round-1][1][param_min_arg,:]*res.stats_std + res.stats_mean

num_samp = 3

# # sampling at contour of 1 covariance away from mean (if samples from outside the prior box, contour is at prior box)
# x_samp = np.random.randn(n_params,num_samp)
# x_samp = np.divide(x_samp,np.linalg.norm(x_samp,axis=0))
# x_samp = (np.dot(S,x_samp)).T+mn_post

# sample from posterior
x_samp = posterior.gen(n_samples=num_samp)

# reject samples outside the prior box (every parameter must lie strictly inside)
inside_box = np.all((x_samp > prior_min) & (x_samp < prior_max), axis=1)
x_samp = x_samp[inside_box]

# number of samples that survived the rejection step
num_samp = x_samp.shape[0]
num_colors = num_samp+1
cm1 = mpl.cm.Oranges
col1 = [cm1(1.*i/num_colors) for i in range(num_colors)]

# row 0 is the mode, the remaining rows are the accepted posterior samples
# params = param_invtransform(prior_log,np.concatenate((np.array([param_min]),np.array([mn_post]),x_samp)))
params = param_invtransform(prior_log,np.concatenate((np.array([mn_post]),x_samp)))


fig = plt.figure(figsize=(20,10))
V = np.zeros((num_repeats,len(t),1+num_samp))
for i in range(1+num_samp):
    # V0 is the first voltage value of every sweep, matching the simulators
    # that were built for the generator (previously the whole first trace was passed)
    m = HodgkinHuxleyMultiStep(I=I, dt=dt, V0=obs['data'][:,0], repeats=num_repeats,
                               seed=230+i, cython=cython, prior_log=prior_log)
    x = m.gen_single(param_transform(prior_log,params[i,:]))
    V[:,:,i] = x['data']
    if i>0:
        for j in range(num_repeats):
            plt.subplot(2,3,j+1)
            plt.plot(t, V[j, :, i], color = col1[i-1], lw=2, label='sample '+str(num_samp-i+1))

# plotting simulation
# plt.plot(t, V[:, 0], color='r', lw=2, label='min sample')
for j in range(num_repeats):
    plt.subplot(2,3,j+1)
    plt.plot(t, V[j, :, 0], color=COL['SNPE'], lw=2, label='mode')
    plt.plot(t, y_obs[j], color=COL['GT'], lw=2, label='observation')


# # average parameter set between the two modes (if two components considered)
# if res.network.n_components == 2:
#     param_av = (posterior.xs[0].m + posterior.xs[1].m)/2
#     m = HodgkinHuxley(I=I, dt=dt, V0=obs['data'][0], seed=231+i, cython=True, prior_log=prior_log)
#     x = m.gen_single(param_transform(prior_log,param_av))
# #     plt.plot(t, x['data'], color='r', lw=2, label='modes average')

#     mn_post_small = posterior.xs[np.argmin(posterior.a)].m
#     m = HodgkinHuxley(I=I, dt=dt, V0=obs['data'][0], seed=232+i, cython=True, prior_log=prior_log)
#     x = m.gen_single(param_transform(prior_log,mn_post_small))
# #     plt.plot(t, x['data'], color='g', lw=2, label='smallest mode')
# else:
#     param_av = posterior.xs[np.argmax(posterior.a)].m


# axis labels, legend and ticks are applied to the last panel drawn
plt.xlabel('time (ms)')
plt.ylabel('voltage (mV)')

ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.3, 1), loc='upper right')

ax.set_xticks([0, duration/2, duration])
ax.set_yticks([-80, -20, 40]);

In [None]:
# Simulated traces at the posterior mode only, one panel per repeat.
fig = plt.figure(figsize=(20,10))
for repeat_idx in range(num_repeats):
    ax = plt.subplot(2, 3, repeat_idx + 1)
    ax.plot(t, V[repeat_idx, :, 0], color=COL['SNPE'], lw=2, label='mode')
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('voltage (mV)')

In [None]:
# Tick labels for the summary features shown on the x-axis.
labels_sum_stats = ['sp_t','r_pot','r_pot_std','m1','m2','m3','m4','mn_sp_t']
n_summary_stats = len(labels_sum_stats)
feature_pos = np.linspace(0,n_summary_stats-1,n_summary_stats)

# summary features of one simulation at the posterior mode
# (`m` is the last simulator created in the posterior-predictive cell)
sum_stats_post = res.generator.summary.calc([m.gen_single(mn_post)])[0]
obs_features = res.obs[0]

fig = plt.figure(figsize=(20,5))

# left panel: observed vs. simulated feature values
ax = plt.subplot(1,2,1)
ax.plot(obs_features, color=COL['GT'], lw=2, label='observation')
ax.plot(sum_stats_post, color=COL['SNPE'], lw=2, label='mode')
ax.set_xticks(feature_pos)
ax.set_xticklabels(labels_sum_stats)
ax.set_ylabel('feature value')
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.2, 1), loc='upper right')

# right panel: magnitude of the observed features (dashed) and of the
# simulated-minus-observed difference, on a log axis
ax = plt.subplot(1,2,2)
ax.semilogy(np.abs(obs_features),color=COL['GT'],linestyle='--', lw=2, label='observation')
ax.semilogy(np.abs(sum_stats_post-obs_features),color=COL['SNPE'], lw=2, label='mode')
ax.set_xticks(feature_pos)
ax.set_xticklabels(labels_sum_stats)
ax.set_ylabel(r'$f^*$ - f');

## rejection ABC

In [None]:
# Rejection ABC on the stored simulations of one round: keep the parameter
# sets whose statistics are closest to the observation.
round_num = 0
percent_accept = .1

# z-score the observed statistics with the mean / std stored on the inference object
obs_stats_zscored = (obs_stats[0]-res.stats_mean)/res.stats_std

round_params = train_data[round_num][0]
round_stats = train_data[round_num][1]

# Euclidean distance of every simulation to the observation, and the ascending order
dist_train_data = np.linalg.norm(round_stats-obs_stats_zscored, axis=1)
dist_argsort = np.argsort(dist_train_data)

# number of simulations to accept (percent_accept is given in percent)
percent_criterion = int(len(dist_train_data)*percent_accept/100)
closest_idx = dist_argsort[0:percent_criterion]

# accepted parameters, rescaled with the stored parameter mean / std
train_data_accept = round_params[closest_idx,:]*res.params_std + res.params_mean

In [None]:
# shape of the accepted parameter array
np.shape(train_data_accept)

In [None]:
# histogram of all distances to the observation, excluding NaN entries
plt.hist(dist_train_data[~np.isnan(dist_train_data)]);

In [None]:
# histogram of the distances of the accepted (smallest-distance) simulations only
plt.hist(dist_train_data[dist_argsort[0:percent_criterion]]);

In [None]:
# posterior plot with the rejection-ABC accepted parameters passed in as samples
plot_pdf(posterior, lims=prior_lims, samples=train_data_accept.T, figsize=(15,15),
         labels_params=labels_params, ticks=True);

## summary statistics for samples from the prior and posterior

In [None]:
# Summary statistics of simulations drawn from the prior and from the posterior.

# Simulator for the predictive simulations. It uses the multi-step model that
# this notebook imports (the plain `HodgkinHuxley` class is never imported),
# the per-sweep initial voltages used for the generator simulators, and a fixed
# seed instead of a loop index left over from an earlier cell.
m = HodgkinHuxleyMultiStep(I=I, dt=dt, V0=obs['data'][:,0], repeats=num_repeats,
                           seed=230, cython=cython, prior_log=prior_log)

#############################################################################
# samples from the prior
num_samp = 100
params = param_invtransform(prior_log,p.gen(num_samp))

# one summary-statistics array per simulated parameter set
sum_stats_prior = [s.calc([m.gen_single(param_transform(prior_log,theta))])
                   for theta in params]

mn_sum_stats_prior = np.nanmean(sum_stats_prior,axis=0)
std_sum_stats_prior = np.nanstd(sum_stats_prior,axis=0)

sum_stats_prior_mat = np.asarray(sum_stats_prior)

#############################################################################
# samples from the posterior
num_samp1 = 100
x_post = posterior.gen(num_samp1)

# reject samples outside the prior box; the comparison is done before the
# inverse transform because prior_min / prior_max are in the transformed space
inside_box = np.all((x_post > prior_min) & (x_post < prior_max), axis=1)
params = param_invtransform(prior_log,x_post[inside_box])

# number of posterior samples that survived the rejection step
num_samp = params.shape[0]

sum_stats = [s.calc([m.gen_single(param_transform(prior_log,theta))])
             for theta in params]

mn_sum_stats = np.nanmean(sum_stats,axis=0)
std_sum_stats = np.nanstd(sum_stats,axis=0)

sum_stats_mat = np.asarray(sum_stats)

# covariance across posterior simulations, ignoring NaN entries
sum_stats_mat1 = np.ma.array(sum_stats_mat, mask=np.isnan(sum_stats_mat))
# cov_sum_stats = np.cov(sum_stats_mat[:,0,:], rowvar=False)
cov_sum_stats = np.ma.cov(sum_stats_mat1[:,0,:], rowvar=False)

# per-feature [min, max] over the posterior simulations, for use as plot limits
# (previously built from the mean and the standard deviation)
sum_stats_min = np.nanmin(sum_stats_mat[:,0,:],axis=0).reshape(-1,1)
sum_stats_max = np.nanmax(sum_stats_mat[:,0,:],axis=0).reshape(-1,1)
sum_stats_lims = np.concatenate((sum_stats_min,sum_stats_max),axis=1)


n_summary_stats = len(mn_sum_stats[0,:])

# same 8 feature labels as in the feature-comparison cell, so the number of
# tick labels matches the number of plotted features
# NOTE(review): assumes the summary object returns these 8 features - confirm when changing n_mom / n_xcorr
labels_sum_stats = ['sp_t','r_pot','r_pot_std','m1','m2','m3','m4','mn_sp_t']
# labels_sum_stats = ['f_rate','ISI_mn','ISI_std','AP_lat','AP_oversh','r_pot','r_pot_std','AHD','A_ind','spike_w']

In [None]:
# Grouped bars per summary feature: prior mean, posterior mean (both with
# standard-deviation error bars) and the observed value.
width = 0.3
bar_pos = np.linspace(0,n_summary_stats-1,n_summary_stats)

fig = plt.figure(figsize=(10,5))
ax = plt.subplot()
ax.bar(bar_pos, mn_sum_stats_prior[0,:], width,
       yerr=std_sum_stats_prior[0,:], label='prior stats')
ax.bar(bar_pos+width, mn_sum_stats[0,:], width,
       yerr=std_sum_stats[0,:], label='posterior stats')
ax.bar(bar_pos+2*width, obs_stats[0,:], width, label='obs stats')

ax.set_xlim(-1.5*width,n_summary_stats+width/2)
ax.set_xticks(bar_pos+width/2)
ax.set_xticklabels(labels_sum_stats)
ax.legend(bbox_to_anchor=(1.2, 1), loc='upper right')
ax.set_title('summary statistics');

# plt.yscale('log')

In [None]:
# Per-feature histograms of prior and posterior simulations (NaN entries
# dropped), with the observed value marked by a dot.
fig = plt.figure(figsize=(20,16))
for stat_idx in range(n_summary_stats):
    prior_vals = sum_stats_prior_mat[:,0,stat_idx]
    post_vals = sum_stats_mat[:,0,stat_idx]

    ax = plt.subplot(4,4,stat_idx+1)
    ax.hist(prior_vals[~np.isnan(prior_vals)], label='prior stat')
    ax.hist(post_vals[~np.isnan(post_vals)], label='posterior stat')
    ax.plot(obs_stats[0,stat_idx],1,'o',markersize=10, label='obs stat')
    ax.set_title(labels_sum_stats[stat_idx])

In [None]:
# smallest non-NaN value of the feature at index 9 across the posterior simulations
# NOTE(review): index 9 is hard-coded; with n_summary = 8 set at the top of the notebook, confirm this index exists
np.min(sum_stats_mat[~np.isnan(sum_stats_mat[:,0,9]),0,9])

In [None]:
# display the observed summary statistics for comparison with the values above
obs_stats

In [None]:
# # reject samples with NaNs
# sum_stats_mat2 = sum_stats_mat[:,0,:]
# ind = ~np.isnan(sum_stats_mat2)
# sum_stats_mat2 = sum_stats_mat2[np.prod(ind,axis=1)==1]

In [None]:
# # summary statistics from posterior

# pdf1 = dd.Gaussian(m=mn_sum_stats[0], S=cov_sum_stats)
# # pdf1 = dd.Gaussian(m=mn_sum_stats[0], S=np.diag(np.ones(n_summary_stats)))
# plot_pdf(pdf1, lims=sum_stats_lims, samples=sum_stats_mat2.T, gt=res.obs[0],figsize=(20,20));