In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import timeit

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import GaussMixture
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf

In [None]:
# Lower/upper limits for theta; reused for the uniform prior in init_g and
# for the plot limits further down.
thetalims = (-5.0, 5.0)

# Two-component mixture of Gaussians: weights a = (0.5, 0.5), ms = -1 and +1,
# and each Ss entry equal to the 1x1 identity matrix.
p_true = dd.MoG(
    a=[0.5, 0.5],
    ms=[np.asarray([-1.0]), np.asarray([1.0])],
    Ss=[np.eye(1), np.eye(1)],
)
# NOTE(review): every component above is one-dimensional, yet ndim is
# overwritten with 2 here -- presumably a workaround for some downstream
# consumer; confirm this is intentional.
p_true.ndim = 2

# Seed handed to init_g and to the inference object.
seed = 2002

# Forwarded to GaussMixture inside init_g.
return_abs = False

# Building the generator through one helper keeps all of its seeds in one place.
def init_g(seed):
    """Return a new ``dg.Default`` generator whose parts are seeded with ``seed``.

    Parameters
    ----------
    seed
        Value passed as ``seed`` to both the ``GaussMixture`` simulator and
        the ``dd.Uniform`` prior.

    Returns
    -------
    dg.Default
        Generator combining the simulator, a uniform prior spanning
        ``thetalims`` and identity summary statistics.

    Notes
    -----
    Reads the module-level names ``thetalims`` and ``return_abs``.
    """
    simulator = GaussMixture(
        dim=1,
        bimodal=True,
        return_abs=return_abs,
        noise_cov=[0.1, 0.1],
        seed=seed,
    )
    lower, upper = thetalims
    prior = dd.Uniform(lower=[lower], upper=[upper], seed=seed)
    summary = ds.Identity()
    return dg.Default(model=simulator, prior=prior, summary=summary)

g = init_g(seed=seed)

# Observation handed to the inference object as `obs`.
obs_stats = np.array([[2.]])

# Draw 1000 samples from the generator and scatter them as a quick visual
# check: element 1 of the result goes on the x axis, element 0 on the
# theta axis.
trn_data = g.gen(1000)
ax = plt.subplot(1, 2, 1)  # left panel of a 1x2 grid; the right panel stays empty
ax.plot(trn_data[1], trn_data[0], '.')
ax.set_xlabel('x')
ax.set_ylabel('$\\theta$')
plt.show()

# ---- SNPE parameters ----
# Everything below is forwarded unchanged to infer.SNPEC(...) or to its
# .run(...) call in the inference cell.

n_components = 5

# training schedule (arguments of .run)
n_train = 3000
n_rounds = 1

# fitting setup (arguments of .run)
minibatch = 100
epochs = 250

# network setup (arguments of infer.SNPEC)
n_hiddens = [20, 20]
reg_lambda = 0.01

# convenience switches (arguments of infer.SNPEC)
pilot_samples = 0
svi = False
verbose = True
prior_norm = False
init_norm = False


In [None]:
# Rebuild the generator so this cell starts from a freshly seeded one,
# independent of any samples drawn from `g` in earlier cells.
g = init_g(seed=seed)

# Set up SNPE-C inference for the observation `obs_stats`.
res_A = infer.SNPEC(g,
                 obs=obs_stats,
                 n_hiddens=n_hiddens,
                 n_components=n_components,
                 seed=seed,
                 reg_lambda=reg_lambda,
                 pilot_samples=pilot_samples,
                 svi=svi,
                 verbose=verbose,
                 init_norm=init_norm,
                 prior_norm=prior_norm)

# Time the run with timeit's documented timer. The previous
# `timeit.time.time()` reached through an undocumented attribute of the
# timeit module and used the wall clock, which is not monotonic.
t = timeit.default_timer()

logs_A, tds_A, posteriors_A = res_A.run(n_train=n_train,
                    n_rounds=n_rounds,
                    minibatch=minibatch,
                    epochs=epochs)

# Elapsed time of the run, in seconds.
print(timeit.default_timer() - t)


In [None]:
# Plot the posterior estimate returned for each round.
for r in range(n_rounds):
    fig, _ = plot_pdf(posteriors_A[r],
                      lims=[thetalims],
                      ticks=True)
    # NOTE(review): the title mentions MCMC samples and the prior, but only
    # the posterior is passed to plot_pdf here -- confirm the wording.
    fig.suptitle('final posterior estimate vs MCMC samples and prior', fontsize=14)
    # plt.show() renders the figure under the inline backend; Figure.show()
    # only works with GUI backends and otherwise just emits a warning.
    plt.show()
