In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import model.utils as utils
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pickle

from delfi.simulator import TransformedSimulator
from delfi.utils.bijection import named_bijection
from model.HodgkinHuxley import HodgkinHuxley
from model.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments

import sys; sys.path.append('../')
from common import plot_pdf, samples_nd, col, svg

%matplotlib inline

!mkdir -p support_files


# --- Inference settings ---------------------------------------------------
seed = 1
pilot_samples = 1000   # forwarded to infer.APT
n_sims = 100000        # forwarded to res.run as n_train
n_rounds = 1           # forwarded to res.run
n_components = 2       # not referenced in the visible cells
n_hiddens = [50, 50]   # forwarded to infer.APT

# --- Summary-statistics settings ------------------------------------------
n_xcorr = 0
n_mom = 4
n_summary = 7
summary_stats = 1

# --- Synthetic observation ------------------------------------------------
# The helper calls are kept in their original order: several of them take
# `seed`, so the order is preserved in case they share random state.
true_params, labels_params = utils.obs_params(reduced_model=False)
I, t_on, t_off, dt = utils.syn_current()
obs = utils.syn_obs_data(I, dt, true_params, seed=seed, cython=True)
obs_stats = utils.syn_obs_stats(
    data=obs,
    I=I,
    t_on=t_on,
    t_off=t_off,
    dt=dt,
    params=true_params,
    seed=seed,
    n_xcorr=n_xcorr,
    n_mom=n_mom,
    cython=True,
    summary_stats=summary_stats,
    n_summary=n_summary,
)

# --- Prior ----------------------------------------------------------------
p = utils.prior(
    true_params=true_params,
    prior_uniform=True,
    prior_extent=True,
    prior_log=False,
    seed=seed,
)

# --- Simulators: one HodgkinHuxley instance per worker, seeds 1..n_processes
n_processes = 5
seeds_model = np.arange(1, n_processes + 1)
m = []
for i, model_seed in enumerate(seeds_model):
    sim = HodgkinHuxley(
        I,
        dt,
        V0=obs['data'][0],
        reduced_model=False,
        seed=model_seed,
        cython=True,
        prior_log=False,
    )
    m.append(sim)

# --- Generators: one MPGenerator per summary-statistic count ---------------
# Every generator shares the same simulator list `m` and prior `p`; only the
# `n_summary` argument of the summary-statistics object differs.
n_summary_ls = [7, 4, 1]

g_ls = []
for nsum in n_summary_ls:
    stats = HodgkinHuxleyStatsMoments(
        t_on=t_on,
        t_off=t_off,
        n_xcorr=n_xcorr,
        n_mom=n_mom,
        n_summary=nsum,
    )
    g = dg.MPGenerator(models=m, prior=p, summary=stats)
    g_ls.append(g)

In [None]:
import os

density = 'maf'  # 'mog' or 'maf'

# The outputs below are written under ./results, but the notebook setup only
# creates `support_files`. Create the directory up front so that a long
# simulation/training run cannot be lost to a missing output directory at
# save time.
os.makedirs('./results', exist_ok=True)

# Train and save one APT posterior estimate per entry of n_summary_ls.
# Entry i uses generator g_ls[i] and the first `nsum` observed summary
# statistics from row 0 of obs_stats.
for i, nsum in enumerate(n_summary_ls):
    res = infer.APT(
        g_ls[i],
        obs=obs_stats[0, 0:nsum],
        pilot_samples=pilot_samples,
        n_hiddens=n_hiddens,
        seed=seed,
        n_mades=5,
        prior_norm=True,
        impute_missing=False,
        density=density,
    )

    # silent_fail=False: presumably makes failures raise instead of being
    # swallowed -- confirm against the delfi documentation.
    log, train_data, posterior = res.run(
        n_train=n_sims,
        n_rounds=n_rounds,
        minibatch=500,
        epochs=1000,
        silent_fail=False,
        proposal='mog',
        val_frac=0.1,
        monitor_every=1,
    )

    # First file: the (log, train_data, posterior) tuple returned by run().
    filename = './results/posterior_{nsum}_single_round_{density}_lfs.pkl'.format(density=density, nsum=nsum)
    io.save_pkl((log, train_data, posterior), filename)

    # Second file: the inference object itself, passed to io.save.
    filename = './results/posterior_{nsum}_single_round_{density}_res_lfs.pkl'.format(density=density, nsum=nsum)
    io.save(res, filename)