In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import timeit

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import GaussMixture
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf
from delfi.utils.math import MoGL2sq

In [None]:
# Random seed shared by the generator factory and the inference object.
seed = 45

# Corners of the box handed to the uniform prior; the leading factor is a
# scale knob (1. leaves the box at [-1, 1] in both dimensions).
lower = 1. * np.array([-1., -1.])
upper = 1. * np.array([1., 1.])

# SNPE parameters
n_components = 5

# training schedule (n_train has one entry per round)
n_rounds = 2
n_train = [3000, 3000]

# fitting setup
epochs = 250
minibatch = 100

# network setup
reg_lambda = 0.01
n_hiddens = [20, 20]

# convenience switches forwarded unchanged to the inference constructor
verbose = True
svi = True
pilot_samples = 0
init_norm = False
prior_norm = False


In [None]:
# Two-component mixture of Gaussians with equal weights, unit covariances
# and means at (5, 5) and (-5, -5).
p_true = dd.MoG(
    a=[0.5, 0.5],
    ms=[np.asarray([5., 5.]), np.asarray([-5., -5.])],
    Ss=[1.0 * np.eye(2), 1.0 * np.eye(2)],
)
p_true.ndim = 2

# Read by init_g (as a global) and forwarded to the GaussMixture simulator.
return_abs = False

# basic approach to controlling generator seeds
def init_g(seed):
    """Build a fresh generator whose simulator and prior share one seed.

    The simulator is a 2-D bimodal ``GaussMixture`` with noise covariance
    ``[0.01, 0.01]``; the prior is uniform on the box spanned by the notebook
    globals ``lower`` and ``upper``; summary statistics are the identity.
    The global ``return_abs`` is read at call time.

    Parameters
    ----------
    seed : seed value passed to both the simulator and the prior.

    Returns
    -------
    A ``dg.Default`` generator wrapping the model, prior and summary.
    """
    # Keyword arguments are evaluated in the order model, prior, summary,
    # i.e. the objects are constructed in the same order as before.
    return dg.Default(
        model=GaussMixture(dim=2,
                           bimodal=True,
                           return_abs=return_abs,
                           noise_cov=[0.01, 0.01],
                           seed=seed),
        prior=dd.Uniform(lower=lower, upper=upper, seed=seed),
        summary=ds.Identity(),
    )

# Observation that inference is conditioned on (passed as `obs` to SNPEC).
obs_stats = np.array([[0.5, -0.5]])

# Draw a preview batch of 1000 samples from a freshly seeded generator.
g = init_g(seed=seed)
trn_data = g.gen(1000)

# Left panel of a 1x2 grid: second element of the generated pair on the
# x-axis, first element on the y-axis.
# NOTE(review): presumably this is summary statistics vs. parameters —
# confirm against the return order of g.gen before adding axis labels.
plt.subplot(121)
plt.plot(trn_data[1], trn_data[0], '.')
plt.show()


In [None]:
# Start from a freshly seeded generator so the samples already drawn from the
# previous `g` do not carry over into this run (presumably the intent of
# re-initialising here).
g = init_g(seed=seed)

# SNPE-C inference object configured from the settings cell.
res_A = infer.SNPEC(g,
                    obs=obs_stats,
                    n_hiddens=n_hiddens,
                    n_components=n_components,
                    seed=seed,
                    reg_lambda=reg_lambda,
                    pilot_samples=pilot_samples,
                    svi=svi,
                    verbose=verbose,
                    init_norm=init_norm,
                    prior_norm=prior_norm)

# Use timeit's documented timer (time.perf_counter): it is monotonic, so the
# measured interval cannot be distorted by system clock adjustments, and it
# does not rely on the undocumented `timeit.time` attribute.
t = timeit.default_timer()

# Returns training logs, training data and posteriors; the logs and
# posteriors are indexed by round in the cells that follow.
logs_A, tds_A, posteriors_A = res_A.run(n_train=n_train,
                                        proposal='mog',
                                        n_rounds=n_rounds,
                                        minibatch=minibatch,
                                        epochs=epochs)

# Elapsed time of the run in seconds.
print(timeit.default_timer() - t)


In [None]:
# One posterior figure and one loss figure per inference round.
for r in range(n_rounds):
    # Posterior estimate of round r, drawn over the box spanned by the prior
    # bounds `lower` / `upper`.
    fig, _ = plot_pdf(posteriors_A[r],
                      lims=[[lower[0], upper[0]], [lower[1], upper[1]]],
                      ticks=True,
                      resolution=100,
                      figsize=(16, 16))
    # The call above passes neither samples nor a prior, and it runs for every
    # round, so the title names what is actually shown.
    fig.suptitle('posterior estimate after round {} of {}'.format(r + 1, n_rounds),
                 fontsize=14)

    # Training-loss trace of the same round, in a figure of its own.
    plt.figure()
    plt.plot(logs_A[r]['loss'])
    plt.title('training loss, round {}'.format(r + 1))
    plt.ylabel('loss')


In [None]:
# Means of the components of the first-round posterior, one dot per component.
# Iterate over the components directly: the enumerate() index was never used.
plt.plot([comp.mean[0] for comp in posteriors_A[0].xs],
         [comp.mean[1] for comp in posteriors_A[0].xs],
         '.')
plt.xlabel('mean x_1')
plt.ylabel('mean x_2')

# Weights of the same posterior's components, in a separate figure.
plt.figure()
plt.plot(posteriors_A[0].a)
plt.xlabel('component index')
# Trailing semicolon keeps the notebook from echoing the Text(...) repr.
plt.ylabel('mixture weight at x_0');