In [None]:
import os

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import matplotlib.pyplot as plt
import numpy as np
from delfi.utils.viz import plot_pdf

import util
from model.GLM import GLM
from model.GLMStats import GLMStats

%matplotlib inline

!mkdir -p results


seed = 42
m = GLM(seed=seed)
p = util.smoothing_prior(n_params=m.n_params, seed=seed)
s = GLMStats(n_summary=m.n_params)
g = dg.Default(model=m, prior=p, summary=s)

true_params, labels_params = util.obs_params()
obs = util.obs_data(true_params, seed=seed)
obs_stats = util.obs_stats(true_params, seed=seed)    

# Construct the delfi APT inference object conditioned on the observed summary
# statistics. NOTE(review): n_hiddens presumably sets two hidden layers of 50
# units and n_components=1 a single mixture component -- confirm in delfi docs.
apt_settings = dict(
    obs=obs_stats,
    n_hiddens=[50, 50],
    seed=seed,
    pilot_samples=1000,
    svi=False,
    n_components=1,
    prior_norm=True,
)
res = infer.APT(g, **apt_settings)

# A single training round; run() returns per-round logs, training data and
# posteriors (posteriors[-1] is the one saved at the end of the notebook).
run_settings = dict(
    n_train=10000,
    n_rounds=1,
    minibatch=100,
    epochs=1000,
    silent_fail=False,
    proposal='mog',
)
logs, tds, posteriors = res.run(**run_settings)

In [None]:
# MCMC reference samples for comparison with the learned posterior.
# util.pg_mcmc is run only when no cached copy exists on disk; the samples are
# stored under numpy's default 'arr_0' key so previously written caches load.
# Only a missing cache triggers recomputation: other load errors (corrupt file,
# missing key) now surface instead of being silently swallowed by a bare except.
sam_path = 'results/sam_lfs.npz'
if os.path.isfile(sam_path):
    # NpzFile holds an open file handle; the context manager closes it.
    with np.load(sam_path) as sam_cache:
        sam = sam_cache['arr_0']
else:
    sam = util.pg_mcmc(true_params, obs)
    np.savez(sam_path, sam)

In [None]:
# Quick inspection of the posterior from every training round, overlaid with
# the MCMC reference samples and the ground-truth parameters. Iterating over
# `posteriors` (instead of a hard-coded range) keeps this cell in sync with
# the n_rounds setting used for training.
for posterior in posteriors:
    plot_pdf(
        posterior,
        lims=[-2, 2],
        samples=sam,
        gt=true_params,
        figsize=(14, 14),
    )

In [None]:
# Persist the final posterior and the ground-truth setup for later analysis.
# np.save stores each dict pickled inside a 0-d object array and appends a
# '.npy' extension to the path; read back with
# np.load(path + '.npy', allow_pickle=True).item().
filename = 'results/single_round_lfs'
single_round_results = {
    'posterior': posteriors[-1],
    'prior': g.prior,
}
np.save(filename, single_round_results)

filename = 'results/ground_truth_data_lfs'
ground_truth_results = {
    'obs_stats': obs_stats,
    'pars_true': true_params,
    'labels_params': labels_params,
}
np.save(filename, ground_truth_results)