In [None]:
import os
import pickle

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

import model.utils as utils
from model.HodgkinHuxley import HodgkinHuxley
from model.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments

%matplotlib inline

## Log-transform helpers (used only when the prior is defined in log space, i.e. `prior_log=True`)

In [None]:
def param_transform(prior_log, x):
    """Map parameter values into the space the prior is defined in.

    Parameters
    ----------
    prior_log : bool
        When truthy, the prior is treated as living in log space and the
        natural log of ``x`` is returned.
    x : scalar or array_like
        Parameter value(s) to transform.

    Returns
    -------
    ``np.log(x)`` if ``prior_log`` is truthy, otherwise ``x`` itself
    (the very same object, not a copy).
    """
    return np.log(x) if prior_log else x

def param_invtransform(prior_log, x):
    """Inverse of ``param_transform``: map values back from prior space.

    Parameters
    ----------
    prior_log : bool
        When truthy, ``x`` is assumed to be in log space and is exponentiated.
    x : scalar or array_like
        Value(s) in prior space.

    Returns
    -------
    ``np.exp(x)`` if ``prior_log`` is truthy, otherwise ``x`` itself
    (the very same object, not a copy).
    """
    return np.exp(x) if prior_log else x

## Define the prior, then the models, summary statistics and generator for each Allen cell

In [None]:
# --- configuration ----------------------------------------------------------
reduced_model = False
true_params, labels_params = utils.obs_params(reduced_model=reduced_model)

seed = 1
prior_uniform = True
prior_log = False
prior_extent = True
n_xcorr = 0
n_mom = 4
cython = True
n_summary = 7
summary_stats = 1

# Constant offset added to every recorded trace before computing statistics
# (assumed to be in the same units as obs['data'], presumably mV -- TODO confirm).
junction_potential = -14

# Number of simulator instances handed to each MPGenerator (presumably one per
# worker process). Model k is seeded with k + 1 so runs are reproducible.
n_processes = 6
seeds_model = np.arange(1, n_processes + 1, 1)

# List of all Allen recordings used for inference.
# Each entry is [ephys_cell, sweep_number, A_soma]; A_soma is forwarded to
# utils.allen_obs_data (presumably the somatic membrane area -- verify the
# units in model.utils).
list_cells_AllenDB = [[518290966,57,0.0234/126],[509881736,39,0.0153/184],[566517779,46,0.0195/198],
                      [567399060,38,0.0259/161],[569469018,44,0.033/403],[532571720,42,0.0139/127],
                      [555060623,34,0.0294/320],[534524026,29,0.027/209],[532355382,33,0.0199/230],
                      [526950199,37,0.0186/218]]

n_cells = len(list_cells_AllenDB)

# --- prior (shared by all cells) ---------------------------------------------
p = utils.prior(true_params=true_params, prior_uniform=prior_uniform,
                prior_extent=prior_extent, prior_log=prior_log, seed=seed)

# --- per-cell observed statistics, models, summary statistics, generators ----
# Lists are indexed by cell position in list_cells_AllenDB.
obs_stats_ls = []
m_ls = []
s_ls = []
g_ls = []
for cell_num, (ephys_cell, sweep_number, A_soma) in enumerate(list_cells_AllenDB):
    obs = utils.allen_obs_data(ephys_cell=ephys_cell, sweep_number=sweep_number, A_soma=A_soma)

    # Shift the recorded trace; the shifted trace feeds both the observed
    # summary statistics and the models' initial voltage V0 below.
    obs['data'] = obs['data'] + junction_potential
    I = obs['I']
    dt = obs['dt']
    t_on = obs['t_on']
    t_off = obs['t_off']

    obs_stats = utils.allen_obs_stats(data=obs, ephys_cell=ephys_cell, sweep_number=sweep_number,
                                      n_xcorr=n_xcorr, n_mom=n_mom,
                                      summary_stats=summary_stats, n_summary=n_summary)
    obs_stats_ls.append(obs_stats)

    # One independently seeded simulator per entry of seeds_model.
    m = [HodgkinHuxley(I, dt, V0=obs['data'][0], seed=model_seed, cython=cython, prior_log=prior_log)
         for model_seed in seeds_model]
    m_ls.append(m)
    s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr, n_mom=n_mom, n_summary=n_summary)
    s_ls.append(s)
    g = dg.MPGenerator(models=m, prior=p, summary=s)
    g_ls.append(g)

## Run SNPE inference for each cell and save the results

In [None]:
seed = 1
svi = False
impute_missing = False
pilot_samples = 1000
n_sims = 125000
n_rounds = 2
n_components = 1
n_hiddens = [100]*2
n_epochs = 1000

# Output directory. Create it before the (long) simulation/training runs so a
# missing directory does not make io.save_pkl fail only after all the work.
results_dir = './results'
os.makedirs(results_dir, exist_ok=True)

# Filename tag recording whether SVI was used (loop-invariant).
svi_flag = '_svi' if svi else '_nosvi'

for cell_num in range(n_cells):
    # setup inference and run pilot run
    res = infer.SNPE(g_ls[cell_num], obs=obs_stats_ls[cell_num], pilot_samples=pilot_samples, n_hiddens=n_hiddens,
                     seed=seed, prior_norm=True, n_components=n_components, svi=svi, impute_missing=impute_missing)

    # run the rounds; n_sims is presumably per round, since the filename
    # records n_sims*n_rounds as the total
    log, train_data, posterior = res.run(n_sims, n_rounds=n_rounds, epochs=n_epochs)

    # save results
    ephys_cell, sweep_number = list_cells_AllenDB[cell_num][:2]

    # NOTE: 'run_1', 'round2', 'prior0013' and 'param8' are fixed labels, not
    # derived from the settings above; they are kept verbatim so filenames stay
    # compatible with previously saved results.
    file_stem = (results_dir + '/allen_' + str(ephys_cell) + '_' + str(sweep_number) +
                 '_run_1_round2_prior0013_param8' + svi_flag + '_ncomp' + str(n_components) +
                 '_nsims' + str(n_sims*n_rounds) + '_snpe')
    filename1 = file_stem + '.pkl'        # (log, train_data, posterior)
    filename2 = file_stem + '_res.pkl'    # the SNPE object itself
    io.save_pkl((log, train_data, posterior), filename1)
    io.save(res, filename2)