In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import time 

from lfimodels.balancednetwork.BalancedNetworkSimulator import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats
from lfimodels.balancednetwork.BalancedNetworkGenerator import BalancedNetworkGenerator

%matplotlib inline

## Setup delfi objects 

First we define the simulator model with its number of parameters, the prior over those parameters, the summary statistics, and the generator that combines these three objects. 

The generator also takes the proposal as an argument. This is the proposal prior from which new parameters are sampled instead of sampling from the overall prior. For now it is set to None, so that .gen() samples from the overall prior. 

In [None]:
# Dimensionality of the parameter space and number of cores handed to the
# simulator (as servers) and to the summary-statistics object (as workers).
n_params = 1
n_cores_to_use = 4

# Simulator.
m = BalancedNetwork(
    dim=n_params,
    first_port=8010,
    verbose=True,
    n_servers=n_cores_to_use,
    duration=3.,
)
# Uniform prior over the single parameter on the interval [1, 5].
p = dd.Uniform(lower=[1.], upper=[5.])
# Summary statistics.
s = BalancedNetworkStats(n_workers=n_cores_to_use)
# The generator ties model, prior and summary statistics together.
g = BalancedNetworkGenerator(model=m, prior=p, summary=s)

## Start the server and make a first test run - our observation

For SNPE we can define an actual observation of the data by running the simulator once. The resulting summary statistics are $x_{obs}$, and the underlying parameters are the true $\theta$ that we want to discover. 

When SNPE is run over more than one round, the estimated posterior after one round is evaluated at $x_{obs}$ to give the new proposal prior for the next round. Besides the use of SVI, this is the main difference to the basic inference scheme. 

In [None]:
# Ground-truth parameters: a nested list holding one parameter set with a
# single value, so the scalar itself is true_params[0][0].
true_params = [[2.5]]
# Run the forward model (simulator) once with the true parameters.
data = m.gen(true_params)
# Reduce the first entry of the simulator output to summary statistics;
# this is the observation x_obs that the inference is conditioned on.
stats_obs = s.calc(data[0])

In [None]:
# Show the ground-truth parameters next to the observed summary statistics.
print(true_params, stats_obs)

## Define the inference method as SNPE

In [None]:
# Inference object: SNPE conditioned on the observed summary statistics,
# configured with three mixture components and no pilot samples.
res = infer.SNPE(
    g,
    obs=stats_obs,
    n_components=3,
    pilot_samples=0,
)

In [None]:
# Training budget: n_train = ntrain, n_rounds = nrounds.
ntrain = 10
nrounds = 1

# Train; `out` is iterated per round below (each entry carries a 'loss'
# series) and `trn_data` is stored alongside it when saving the results.
out, trn_data = res.run(
    n_train=ntrain,
    n_rounds=nrounds,
    minibatch=10,
)

In [None]:
# Draw the loss curves of all rounds into ONE figure. The figure is created
# once, outside the loop: creating it per iteration opens a new figure for
# every round, and the title/legend below would then only reach the last one.
plt.figure(figsize=(15, 5))
for i, r in enumerate(out):
    # One curve per round, labelled with its 1-based round index.
    plt.plot(r['loss'], label='round {}'.format(i + 1))
plt.title('loss over iterations')
plt.legend();

## Done 

We now have an estimate of the posterior over the parameter $R_{ee}$ given the observed data $x_{obs}$. How can we check the performance? 

## Compare to true parameter 

We have generated the observed data ourselves, so we do have the true parameter. The mean of the posterior should be close to it when evaluated at $x=x_{obs}$.

In [None]:
# Evaluate the trained model at the observed summary statistics. The result
# is used below as a mixture: `posterior.a` holds the component weights and
# `posterior.xs[c]` the components, each with a mean `.m` and covariance `.S`.
posterior = res.predict(stats_obs)

In [None]:
# Mean and standard deviation of the *whole* mixture posterior along the
# first (here: only) parameter dimension. Reading just posterior.xs[0]
# would report a single component and ignore the mixture weights, which is
# not "the posterior mean" once there is more than one component.
weights = np.asarray(posterior.a, dtype=float)
comp_means = np.array([posterior.xs[c].m[0] for c in range(len(weights))], dtype=float)
comp_vars = np.array([posterior.xs[c].S[0][0] for c in range(len(weights))], dtype=float)

# Mixture mean: weighted average of the component means.
mean = np.sum(weights * comp_means)
# Mixture variance (law of total variance): within-component variance plus
# the spread of the component means around the mixture mean. This form is
# non-negative by construction.
std = np.sqrt(np.sum(weights * (comp_vars + (comp_means - mean) ** 2)))
print(mean, std)

In [None]:
# Persist everything needed to reproduce the analysis in a timestamped pickle.
save_data = True
path_to_save_folder = 'data/'  # created below if it does not exist yet

if save_data:
    # Create the target folder instead of silently skipping the save when it
    # is missing.
    os.makedirs(path_to_save_folder, exist_ok=True)

    # `nrounds` and `ntrain` are taken from the training cell as-is, so the
    # stored metadata always matches the run that produced `out`/`trn_data`.
    # 'nrounds' is the correctly spelled key; the misspelled 'nrouns' is kept
    # as an alias so that existing loaders of older result files keep working.
    result_dict = dict(true_params=true_params, stats_obs=stats_obs,
                       nrounds=nrounds, nrouns=nrounds, ntrain=ntrain,
                       posterior=posterior, out=out, trn_data=trn_data)

    # File name: <unix timestamp without the dot>_snpe_ntrain<ntrain>.p
    filename = os.path.join(path_to_save_folder,
                            '{}_snpe_ntrain{}'.format(time.time(), ntrain).replace('.', '') + '.p')
    with open(filename, 'wb') as handle:
        pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print(filename)

In [None]:
# Unpack the mixture posterior: the number of components is given by the
# number of mixture weights; collect each component's mean (.m) and
# covariance (.S) in component order.
n_components = len(posterior.a)
components = [posterior.xs[c] for c in range(n_components)]
means = [component.m for component in components]
Ss = [component.S for component in components]

In [None]:
# Evaluation grid: 1000 points on [1, 5], the same interval as the prior.
theta = np.linspace(1, 5, 1000)

# Rebuild a mixture of Gaussians from the first coordinate of every component
# mean together with the component covariances and the original weights.
sub_means = [[component_mean[0]] for component_mean in means]
sub_cov = np.asarray(list(Ss))
pdf = dd.mixture.MoG(a=posterior.a, ms=sub_means, Ss=sub_cov)

# Evaluate the (non-log) density on the grid, passed as a column of points.
post_pdf = pdf.eval(theta.reshape(-1, 1), log=False)

In [None]:
# Posterior density over the grid, with the true parameter marked.
plt.figure(figsize=(10, 5))
# Raw string: keeps the LaTeX backslashes intact ('\h' is an invalid escape
# in a normal string) and '\theta' now renders as the Greek letter instead
# of the italic word "theta".
plt.plot(theta, post_pdf, label=r'$\hat{p}(\theta | x=x_{obs})$')
# axvline expects a scalar x position; true_params is a nested list
# ([[value]]), so index down to the number itself.
plt.axvline(x=true_params[0][0], label='true theta', linestyle='--', color='C1')
plt.legend()
plt.xlabel('$R_{ee}$');