In [None]:
import os

import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import lfimodels.hodgkinhuxley.utils as utils
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

from lfimodels.hodgkinhuxley.HodgkinHuxley import HodgkinHuxley
from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments
# from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsSpikes_mf import HodgkinHuxleyStatsSpikes_mf
# from lfimodels.hodgkinhuxley.HodgkinHuxleyStatsSpikes import HodgkinHuxleyStatsSpikes
from run_genetic import run_deap

%matplotlib inline

## Set up the model

In [None]:
true_params, labels_params = utils.obs_params()

# Feature labels, intended to match the output of HodgkinHuxleyStatsMoments (built below).
# Alternative label sets kept for reference (presumably for the other statistics classes):
# labels_sum_stats = ['sp_t','ISI_mn','ISI_std','c0','c1','c2','c3','c4','c5','c6','c7','c8','c9',
#                     'r_pot','mn','m2','m3','m4','m5','m6','m7','m8']
# labels_sum_stats = ['f_rate','ISI_mn','ISI_std','AP_lat','AP_oversh','r_pot','r_pot_std','AHD','A_ind','spike_w']
# labels_sum_stats = [r'$f\_rate$',r'$ISI\_mn$',r'$ISI\_std$',r'$AP\_lat$',r'$AP\_oversh$',
#                     r'$r\_pot$',r'$r\_pot\_std$',r'$AHD$',r'$A\_ind$',r'$spike\_w$']
labels_sum_stats = ['sp_t', 'r_pot', 'r_pot_std', 'm1', 'm2', 'm3', 'm4']

# ---- run configuration ----
seed = 1
prior_uniform = True
prior_log = False
prior_extent = True
n_xcorr = 0
n_mom = 4
cython = True
n_summary = 10
summary_stats = 1

# ---- Allen Cell Types recordings ----
# Each row is [ephys cell id, sweep number, A_soma]; A_soma (presumably the soma area)
# is forwarded to utils.allen_obs_data.
list_cells_AllenDB = [
    [518290966, 57, 0.0234/126],
    [509881736, 39, 0.0153/184],
    [566517779, 46, 0.0195/198],
    [567399060, 38, 0.0259/161],
    [569469018, 44, 0.033/403],
    [532571720, 42, 0.0139/127],
    [555060623, 34, 0.0294/320],
    [534524026, 29, 0.027/209],
    [532355382, 33, 0.0199/230],
    [526950199, 37, 0.0186/218],
]
cell_num = 0
ephys_cell, sweep_number, A_soma = list_cells_AllenDB[cell_num]
junction_potential = -14  # offset added to the recorded voltage trace (mV, presumably)

obs = utils.allen_obs_data(ephys_cell=ephys_cell, sweep_number=sweep_number, A_soma=A_soma)
# Apply the junction-potential shift before obs['data'] is read by the stats or the model.
obs['data'] = obs['data'] + junction_potential
I, dt = obs['I'], obs['dt']
t_on, t_off = obs['t_on'], obs['t_off']

obs_stats = utils.allen_obs_stats(data=obs, ephys_cell=ephys_cell, sweep_number=sweep_number,
                                  n_xcorr=n_xcorr, n_mom=n_mom,
                                  summary_stats=summary_stats, n_summary=n_summary)

# ---- model, prior, summary statistics and generator ----
m = HodgkinHuxley(I, dt, V0=obs['data'][0], seed=seed, cython=cython, prior_log=prior_log)
p = utils.prior(true_params=true_params, prior_uniform=prior_uniform,
                prior_extent=prior_extent, prior_log=prior_log, seed=seed)
s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr, n_mom=n_mom)
# s = HodgkinHuxleyStatsSpikes_mf(t_on=t_on, t_off=t_off, n_summary=n_summary)
# s = HodgkinHuxleyStatsSpikes(t_on=t_on, t_off=t_off, n_summary=n_summary)
g = dg.Default(model=m, prior=p, summary=s)

# Search box handed to the GA: one (lower, upper) row per parameter.
bounds = np.asarray([p.lower, p.upper]).T

In [None]:
# Visual sanity check of a hand-tuned parameter vector against the recording.
# Parameter order (assumed to match labels_params from utils.obs_params() -- confirm):
#   ['g_Na', 'g_K', 'g_leak', 'g_M', 't_max', '-V_T', 'noise', '-E_leak']
# param_allen_2 = [12.1585,1.22534,0.103795,0.49671,1200,78,0.1,83.0383]
param_allen_4 = [12.1585,10.22534,0.103795,0.40671,800,83,0.1,83.0383]

# sharey so the simulated and recorded amplitudes can be compared directly.
fig, (ax_sim, ax_obs) = plt.subplots(1, 2, figsize=(15, 5), sharey=True)

ax_sim.plot(obs['time'], m.gen_single(param_allen_4)['data'])
ax_sim.set_title('simulation (param_allen_4)')
ax_sim.set_xlabel('time (ms)')
ax_sim.set_ylabel('voltage (mV)')

ax_obs.plot(obs['time'], obs['data'])
ax_obs.set_title('observation')
ax_obs.set_xlabel('time (ms)');

## Pilot run

A short SNPE pilot run. Below, only its `res.stats_std` is used, to scale the summary statistics.

In [None]:
# Short SNPE pilot run. In this notebook only `res.stats_std` is read afterwards
# (passed to the GA and used to scale obs_stats in the analysis).
seed = 1
svi = False
res = infer.SNPE(g,
                 obs=obs_stats,
                 seed=seed,
                 svi=svi,
                 prior_norm=True,
                 pilot_samples=1000,
                 n_hiddens=[50],
                 n_components=1)

## Run genetic algorithm

In [None]:
# Genetic-algorithm settings.
algo = 'ibea'
offspring_size = 2500
max_ngen = 100

final_pop, halloffame, log, hist = run_deap(model=m,
                                            summary=s,
                                            bounds=bounds,
                                            labels_params=labels_params,
                                            obs_stats=obs_stats,
                                            labels_sum_stats=labels_sum_stats,
                                            stats_std=res.stats_std,  # feature scales from the pilot run
                                            algo=algo,
                                            offspring_size=offspring_size,
                                            max_ngen=max_ngen,
                                            seed=seed)

In [None]:
# Persist the GA output. The results directory is created first: a fresh checkout may
# not contain ./results, and io.save_pkl is not known to create parent directories.
results_dir = './results'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)
results_path = os.path.join(
    results_dir,
    'allen_{}_{}_run_1_offspr{}_max_gen{}_param8_{}.pkl'.format(
        ephys_cell, sweep_number, offspring_size, max_ngen, algo))
io.save_pkl((final_pop, halloffame, log, hist), results_path)

In [None]:
# Reload the GA output saved above (lets the analysis run without re-running the GA).
# The .pkl file is presumably a pickle: only load result files you produced yourself.
final_pop, halloffame, log, hist = io.load_pkl(
    './results/allen_{}_{}_run_1_offspr{}_max_gen{}_param8_{}.pkl'.format(
        ephys_cell, sweep_number, offspring_size, max_ngen, algo))

## Inspect output

In [None]:
# Entry 0 of the hall of fame is taken as the best individual (presumably a DEAP
# HallOfFame, which keeps its entries ordered best-first).
best_params = halloffame[0]
# NOTE(review): despite the name, these are the individual's fitness (objective) values
# as computed inside run_deap, not raw summary statistics.
best_sumstats = halloffame[0].fitness.values

In [None]:
fig = plt.figure()

# Sizes and observation data reused by later cells.
n_params = len(bounds[:,0])
n_sum_stats = len(obs_stats[0,:])
y_obs = obs['data']
t = obs['time']
duration = np.max(t)

# Plot colours: observation ("GT") and best IBEA individual.
COL = {
    'GT':   (35/255, 86/255, 167/255),
    'IBEA': (0, 174/255, 239/255),
}

# Simulate the best parameter set; `x` and `V_best` are reused by later cells.
x = m.gen_single(best_params)
V_best = x['data']

ax = plt.gca()
ax.plot(t, V_best, color=COL['IBEA'], lw=2, label='best ibea')
ax.plot(t, y_obs, color=COL['GT'], lw=2, label='observation')
ax.set_xlabel('time (ms)')
ax.set_ylabel('voltage (mV)')

# Reverse the legend so 'observation' is listed first, and anchor it outside the axes.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.3, 1), loc='upper right')

ax.set_xticks([0, duration/2, duration])
ax.set_yticks([-80, -20, 40]);

In [None]:
# Best IBEA trace on its own, without the observation overlaid.
plt.figure()
plt.plot(t, V_best, color=COL['IBEA'], lw=2, label='best ibea')
plt.xlabel('time (ms)')
plt.ylabel('voltage (mV)')
plt.legend(frameon=False);

In [None]:
n_summary_stats = len(labels_sum_stats)
stat_ticks = np.arange(n_summary_stats)

# Summary statistics of the trace simulated and plotted in the previous figure (`x`).
# Calling m.gen_single again would produce a new simulation; if the model is stochastic
# (the parameter vector includes a 'noise' entry), its features would not belong to the
# plotted trace.
sum_stats_post = s.calc([x])[0]

fig = plt.figure(figsize=(20, 5))

# Left: raw feature values, observation vs. best individual.
ax = plt.subplot(1, 2, 1)
ax.plot(obs_stats[0], color=COL['GT'], lw=2, label='observation')
ax.plot(sum_stats_post, color=COL['IBEA'], lw=2, label='best ibea')
ax.set_xticks(stat_ticks)
ax.set_xticklabels(labels_sum_stats)
ax.set_ylabel('feature value')
ax.set_title('summary statistics')
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.2, 1), loc='upper right')

# Right (log scale): magnitude of each observed feature as a reference, and the absolute
# error of the best individual's feature.
ax = plt.subplot(1, 2, 2)
ax.semilogy(np.abs(obs_stats[0]), color=COL['GT'], linestyle='--', lw=2, label='observation')
# plt.semilogy(np.abs(np.asarray(best_sumstats)*res.stats_std),color='g', lw=2, label='best ibea')
ax.semilogy(np.abs(sum_stats_post - obs_stats[0]), color=COL['IBEA'], lw=2, label='best ibea')
ax.set_xticks(stat_ticks)
ax.set_xticklabels(labels_sum_stats)
ax.set_ylabel('absolute value')
ax.set_title(r'observed $|f^*|$ vs. error $|f - f^*|$')
ax.legend(frameon=False);

## Analysis of optimization

In [None]:
def _bar_with_errors(ax, values, errors, tick_labels, ylabel):
    """Draw a centred bar chart with symmetric error bars and one tick label per bar.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    values : 1-D sequence of bar heights; its length must equal len(tick_labels).
    errors : 1-D sequence of symmetric error half-widths, same length as values.
    tick_labels : sequence of labels; bars sit at positions 0 .. len(tick_labels)-1.
    ylabel : y-axis label text.
    """
    positions = np.arange(len(tick_labels))
    ax.bar(positions, values, align='center', yerr=errors)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel(ylabel)


fig = plt.figure(figsize=(15, 15))

# ---- error across GA generations ----
ax = plt.subplot(311)
gens = log.select('gen')
ax.semilogy(gens, log.select('max'), '--k', lw=1, label='max')
ax.semilogy(gens, log.select('avg'), 'b', lw=2, label='mean')
ax.semilogy(gens, log.select('min'), '--r', lw=1, label='min')
ax.set_xlabel('iteration')
ax.set_ylabel('summed error')
ax.legend(frameon=False)

# Observed summary statistics scaled by the pilot-run feature std.
obs_stats_2 = obs_stats / res.stats_std

###############################################################################
# Best individual's fitness values relative to the scaled observed statistics,
# with the spread across the hall of fame as error bars.

diff_sum_stats = np.divide(best_sumstats, np.abs(obs_stats_2))

# Fitness values of every hall-of-fame individual: shape (n_objectives, n_individuals).
all_best_sum_stats = np.asarray([ind.fitness.values for ind in halloffame], dtype=float).T

# The error bars are symmetric, so "up" and "down" are the same array (computed once).
err_sum_stats_up_norm = np.divide(np.std(all_best_sum_stats, 1), np.abs(obs_stats_2))
err_sum_stats_down_norm = err_sum_stats_up_norm

labels_sum_stats = labels_sum_stats[0:n_sum_stats]

ax = plt.subplot(312)
_bar_with_errors(ax, diff_sum_stats[0], err_sum_stats_up_norm[0], labels_sum_stats,
                 'fitness / |obs_stats / stats_std|')

###############################################################################
# Relative deviation of the best parameters from the reference parameters returned by
# utils.obs_params(). NOTE(review): for recorded Allen data these are presumably a
# reference set, not known generating parameters.

diff_params = np.divide(true_params - best_params, true_params)

# Symmetric error bars again: one array serves as both "up" and "down".
err_params_up_norm = np.divide(np.std(halloffame, 0), true_params)
err_params_down_norm = err_params_up_norm

labels_params = labels_params[0:len(true_params)]

ax = plt.subplot(313)
_bar_with_errors(ax, diff_params, err_params_up_norm, labels_params,
                 '(reference - best) / reference')