# Lotka-Volterra model

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit

from delfi.inference import SNPEC as APT

from util import init_g_lv as init_g
from util import load_setup_lv as load_setup
from util import load_gt_lv as load_gt
from util import draw_sample_uniform_prior_52 as rej_sampler

from snl.util.plot import plot_hist_marginals

# One seed is reused for the generators and the APT object built in this notebook.
seed = 42

# simulation setup
setup_dict = load_setup()

# Reference parameters (pars_true) and the statistics that are later handed
# to APT as `obs` (obs_stats), loaded with a freshly seeded generator.
gt_generator = init_g(seed=seed)
pars_true, obs_stats = load_gt(generator=gt_generator)
print('pars_true : ', pars_true)
print('obs_stats : ', obs_stats)


# fit APT

In [None]:
# Epoch schedule: with 'train_on_all' set, build one entry per round where
# round r (0-based) gets epochs // (r + 1); otherwise the scalar setting is
# passed through unchanged.
epochs = (
    [setup_dict['epochs'] // (r + 1) for r in range(setup_dict['n_rounds'])]
    if setup_dict['train_on_all']
    else setup_dict['epochs']
)

# control MAF seed: this seeds NumPy's *global* RNG, and the np.random module
# itself is what gets handed to APT as `rng`
np.random.seed(seed)
rng = np.random

# generator
g = init_g(seed=seed)

# Options taken one-to-one from the setup dictionary and forwarded to the APT
# constructor under the same keyword names.
# 'upper' / 'lower' (box constraints for the support of the MAF outputs,
# maf.y) are currently not passed.
apt_options = {
    key: setup_dict[key]
    for key in (
        'n_hiddens',
        'reg_lambda',
        'pilot_samples',
        'svi',
        'n_mades',
        'act_fun',
        'mode',
        'batch_norm',
        'verbose',
        'prior_norm',
    )
}

res = APT(g, obs=obs_stats, seed=seed, rng=rng, **apt_options)

print('conditional density estimator', res.network)

# train
# Time the whole res.run call. timeit.default_timer is timeit's documented
# timer; the previous timeit.time.time() only worked because timeit happens
# to import the `time` module internally.
t = timeit.default_timer()
# APT is imported as delfi's SNPEC class, i.e. SNPE-C.
print('fitting model with SNPE-C')
# res.run returns three values; the inspection code further down indexes
# `logs` and `posteriors` once per round.
logs, tds, posteriors = res.run(
    n_train=setup_dict['n_train'],
    proposal=setup_dict['proposal'],
    moo=setup_dict['moo'],
    n_null=setup_dict['n_null'],
    n_rounds=setup_dict['n_rounds'],
    train_on_all=setup_dict['train_on_all'],
    minibatch=setup_dict['minibatch'],
    epochs=epochs)
print('fitting time : ', timeit.default_timer() - t)

# store results

In [None]:
#from util import save_results, load_results
#model_id = 'lv'
#save_path = 'results/' + model_id + '_box_validationset'
#exp_id = 'seed'+str(seed)

#save_results(logs=logs, tds=tds, posteriors=posteriors, 
#             setup_dict=setup_dict, exp_id=exp_id, path=save_path)

#logs, tds, posteriors, setup_dict = load_results(exp_id=exp_id, path=save_path)

# inspect results

In [None]:
# One figure per round: marginal histograms of samples drawn from that
# round's posterior, with pars_true passed as `gt`.
for r in range(len(logs)):

    posterior = posteriors[r]
    # fast parallel rejection sampler; 5000 is its second argument
    # (presumably the number of samples -- confirm in util)
    samples = rej_sampler(posterior, 5000)

    fig = plot_hist_marginals(samples, gt=pars_true, lims=[-5, 2])
    fig.set_size_inches(12, 12)  # width, height
    fig.suptitle('APT posterior estimates, round r = ' + str(r + 1), fontsize=14)
    fig.show()

    # Negated log-density of this round's posterior evaluated at pars_true.
    print('negative log-probability of ground-truth pars \n',
          -posterior.eval(pars_true, log=True))

# inspect forward model

In [None]:
from delfi.distribution.PointMass import PointMass

# Fresh generator whose prior is replaced by a PointMass located at pars_true
# (presumably so that all 1000 simulations share those parameters -- confirm
# against delfi's PointMass).
g = init_g(seed=seed)
g.prior = PointMass(loc=pars_true)

# gen returns a pair; only the second element (the statistics) is used here.
_, stats = g.gen(1000)

# Marginal histograms of the simulated statistics, with the flattened
# obs_stats passed as `gt`.
fig = plot_hist_marginals(stats, gt=obs_stats.flatten())
fig.set_size_inches(12, 12)  # width, height
fig.suptitle('marginal over x given theta* (red dot = xo)', fontsize=14)
fig.show()