In [None]:
%matplotlib inline

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.glm.utils as utils
import matplotlib.pyplot as plt
import numpy as np

from lfimodels.glm.GLM import GLM
from lfimodels.glm.GLMStats import GLMStats
from delfi.utils.viz import plot_pdf

# One seed shared by the model, the prior and the observation helpers below.
seed = 42
# Simulator model; its n_params attribute sizes the prior and the summary stats.
m = GLM(seed=seed)
# Prior over the m.n_params model parameters, built by utils.smoothing_prior.
p = utils.smoothing_prior(n_params=m.n_params, seed=seed)
# Summary-statistics object; n_summary is set equal to the number of parameters.
s = GLMStats(n_summary=m.n_params)
# Generator bundling model, prior and summary statistics for the inference step.
g = dg.Default(model=m, prior=p, summary=s)

# Reference ("true") parameters and their labels, plus the observation and its
# summary statistics derived from them with the same seed.
true_params, labels_params = utils.obs_params()
obs = utils.obs_data(true_params, seed=seed)
obs_stats = utils.obs_stats(true_params, seed=seed)

rerun = False  # if False, will try loading file from disk

# Reference samples `sam` are produced by utils.pg_mcmc and cached in
# 'sam.npz', so the sampler only runs when no usable cache is available
# (or when `rerun` is set to True).
sam = None
if not rerun:
    try:
        # np.savez stores a positional array under the key 'arr_0'. The
        # context manager closes the archive once the array has been read.
        with np.load('sam.npz') as cached:
            sam = cached['arr_0']
    except Exception as err:
        # Best effort: any problem with the cache (missing file, missing key,
        # corrupt archive) falls back to recomputing. Unlike a bare `except:`
        # this lets KeyboardInterrupt / SystemExit propagate.
        print('could not load sam.npz ({}); rerunning MCMC'.format(err))

if sam is None:
    sam = utils.pg_mcmc(true_params, obs)
    np.savez('sam.npz', sam)
    
# The generator `g` built earlier from `m`, `p` and `s` is used as is: nothing
# touches it between its construction and this point, so constructing it a
# second time with identical arguments was redundant.

# CDELFI inference object conditioned on the observed summary statistics.
res = infer.CDELFI(
    g,
    obs=obs_stats,
    n_hiddens=[50],
    seed=seed,
    reg_lambda=0.01,
    pilot_samples=1000,
    svi=True,
    prior_norm=False,
)

# run() returns three values; `posteriors` is indexed per round by the
# plotting cells further down.
logs, tds, posteriors = res.run(
    n_train=5000,
    n_rounds=5,
    minibatch=100,
    epochs=1000,
    # round_cl=3
)


In [None]:
# One density plot per entry of `posteriors`, overlaid with the reference
# samples `sam` and the ground-truth parameters. Iterating the list directly
# keeps the loop in step with however many rounds were actually run instead of
# repeating the round count as a literal here.
for posterior in posteriors:
    # xs[0]: presumably the first mixture component of the posterior -- confirm.
    plot_pdf(posterior.xs[0],
             lims=[-2, 2],
             samples=sam,
             gt=true_params,
             figsize=(9, 9));

In [None]:
# Larger plot of the last round only, with explicit contour levels and wider
# axis limits than the per-round plots.
posterior = posteriors[-1]
plot_pdf(
    posterior.xs[0],  # presumably the first mixture component -- confirm
    lims=[-3, 3],
    levels=(0.01, 0.68, 0.95),
    samples=sam,
    gt=true_params,
    figsize=(10, 10),
);

In [None]:

# Persist the last-round posterior, the round before it (stored under the key
# 'proposal') and the generator's prior in a single file.
# np.save appends '.npy' when the name has no such suffix. A dict is not an
# array, so it is presumably written as a pickled object array and needs
# np.load(..., allow_pickle=True) to read back -- confirm for the NumPy in use.
filename = 'glm_5k_elife_prior_gp_run_1_round5_param10_CDELFI_posterior'
posterior_bundle = dict(
    posterior=posteriors[-1],
    proposal=posteriors[-2],
    prior=g.prior,
)
np.save(filename, posterior_bundle)

In [None]:

# Persist the observed summary statistics together with the reference
# parameters and their labels. np.save appends '.npy' to the name; the dict is
# presumably stored as a pickled object array (load with allow_pickle=True --
# confirm for the NumPy in use).
filename = 'ground_truth_data'
ground_truth_bundle = dict(
    obs_stats=obs_stats,
    pars_true=true_params,
    labels_params=labels_params,
)
np.save(filename, ground_truth_bundle)

In [None]:
# Bare expression: shows the repr of `posterior` as the cell output. When the
# cells are run in order this is posteriors[-1], bound in the final-plot cell.
posterior