In [None]:
import os
import pickle
import time

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st

from lfimodels.balancednetwork.BalancedNetwork import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats

%matplotlib inline

## Setup delfi objects 

First we define the simulator model with the number of parameters, the prior over those parameters, the summary stats, and the generator that combines those three objects. 

The generator also takes the proposal as an argument. This is the proposal prior from which new parameters are sampled instead of sampling from the overall prior. For now it is set to None, so that .gen() samples from the overall prior. 

In [None]:
# Number of simulator parameters to infer.
n_params = 1

# Simulator model.
m = BalancedNetwork(
    dim=n_params,
    first_port=8010,
    verbose=False,
    n_servers=3,
    duration=3.,
)

# Uniform prior over the parameter, bounded by [1, 5].
p = dd.Uniform(lower=[1.], upper=[5.])

# Summary statistics computed from the simulated data.
s = BalancedNetworkStats()

# Generator tying model, prior and summary stats together. With
# proposal=None, parameters are sampled from the prior itself.
g = dg.Default(
    model=m,
    prior=p,
    summary=s,
    proposal=None,
)

## Start the server and make a first test run - our observation

For SNPE we can define an actual observation of the data by running the simulator once. The resulting summary stats are $x_{obs}$, and the underlying parameters are the true $\theta$ that we want to discover. 

When SNPE is run over more than one round, the estimated posterior after one round is evaluated at $x_{obs}$ to give the new proposal prior for the next round. Besides the use of SVI, this is the main difference to the basic inference scheme. 

In [None]:
# Start the simulation server(s) of the model; it is stopped again once
# inference has finished.
m.start_server()

In [None]:
# Draw a single sample from the generator (no proposal is set, so the
# parameters come from the prior). The sampled parameters serve as ground
# truth and the resulting summary stats as the observation x_obs.
true_params, stats_obs = g.gen(1)
print(true_params, stats_obs)

## Define the inference method as SNPE

In [None]:
# Inference object: SNPE on the generator, conditioned on the observed
# summary stats.
res = infer.SNPE(g, obs=stats_obs)

In [None]:
# Run the inference machine.
# ntrain is passed to run() as n_train and nrounds as n_rounds.
ntrain = 50
nrounds = 1
out = res.run(n_train=ntrain, n_rounds=nrounds)
# Simulations for inference are finished, so shut the server down.
m.stop_server()

In [None]:
# Plot the training loss of every inference round, one panel per round.
# The previous version always indexed out[1] ("second round") and therefore
# raised an IndexError whenever nrounds == 1.
# NOTE(review): assumes `out` holds one training log per round, each with
# 'trn_iter' and 'trn_val' entries, as the original indexing implied - confirm.
fig, axes = plt.subplots(1, nrounds, figsize=(15, 5), squeeze=False)
for round_idx, ax in enumerate(axes[0]):
    ax.plot(out[round_idx]['trn_iter'], out[round_idx]['trn_val'])
    ax.set_title('loss on round {}'.format(round_idx + 1))

## Done 

We now have an estimate of the posterior over the parameter $R_{ee}$ given the observed data $x_{obs}$. How can we check the performance? 

## Compare to true parameter 

We have generated the observed data ourselves, so we do have the true parameter. The mean of the posterior should be close to it when evaluated for $x=x_{obs}$.

In [None]:
# Get the posterior at the observed data, i.e. evaluate the trained
# inference object at x = x_obs.
posterior = res.predict(stats_obs)

In [None]:
# Summarise the first posterior component by the first entry of its `m`
# attribute and the square root of the [0][0] entry of its `S` attribute.
# NOTE(review): presumably `m` is the mean vector and `S` the covariance
# matrix of a Gaussian component - confirm against delfi.
component = posterior.xs[0]
mean = component.m[0]
variance = component.S[0][0]
std = np.sqrt(variance)
print(mean, std)

In [None]:
# Set up a dict with the run configuration and the posterior summary.
result_dict = dict(
    true_params=true_params,
    stats_obs=stats_obs,
    nrounds=nrounds,
    # Legacy misspelled key, kept as an alias so code reading the old
    # 'nrouns' field keeps working.
    nrouns=nrounds,
    ntrain=ntrain,
    posterior=dict(mean=mean, std=std),
)

# Timestamped file name; dots (from the float timestamp) are stripped before
# the '.p' extension is appended.
filename = 'data/{}_snpe_r{}_ntrain{}'.format(time.time(), nrounds, ntrain).replace('.', '') + '.p'

# Create the output directory if needed so that open() cannot fail with
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, 'wb') as handle:
    pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Evaluate a normal density with the posterior mean and std on a grid of
# parameter values, for plotting.
# `st` is scipy.stats; it was previously used here without being imported,
# which raised a NameError on a fresh kernel.
theta = np.linspace(0, 5, 1000)
post_pdf = st.norm.pdf(theta, loc=mean, scale=std)

In [None]:
# Plot the estimated posterior density together with the true parameter.
plt.figure(figsize=(10, 5))
# Raw string: '\h' is an invalid escape sequence in a normal string literal.
# '\theta' renders the Greek letter instead of the letters t-h-e-t-a.
plt.plot(theta, post_pdf, label=r'$\hat{p}(\theta | x=x_{obs})$')
# axvline expects a scalar position, so take the first (and, with
# n_params = 1, only) entry of the sampled parameter array.
plt.axvline(x=np.ravel(true_params)[0], label='true theta', linestyle='--', color='C1')
plt.legend()
plt.xlabel('$R_{ee}$');

## Posterior predictive checking 

Generate samples from the posterior and simulate them. The resulting data should be near the observed data. 

In [None]:
# Parameter values at the posterior mean and +-1, +-2, +-3 standard
# deviations around it.
thetas = [mean + i * std for i in [-3, -2, -1, 0, 1, 2, 3]]

m.start_server()
try:
    # Simulate every theta and reduce each simulated datum to summary stats.
    data = m.gen(thetas)
    sum_stats = [s.calc(datum) for datum in data]
finally:
    # Always shut the server down, even if a simulation fails, so that no
    # server is left running.
    m.stop_server()

In [None]:
# Plot the simulated summary stats around the posterior mean together with
# the observed stats.
plt.figure(figsize=(10, 4))
sum_stats = np.array(sum_stats).squeeze()
plt.axvline(x=mean, linestyle='--', color='C4')
# axvline expects a scalar position, so take the first (and, with
# n_params = 1, only) entry of the sampled parameter array.
plt.axvline(x=np.ravel(true_params)[0], linestyle='--', color='C5')
plt.plot(thetas, sum_stats, '-o')
plt.plot(true_params, stats_obs, '*')
# Legend entries follow plotting order: the two vertical lines first, then
# one line per summary statistic.
plt.legend(['posterior mean', 'true theta', 'ff1', 'ff2', 'ff3', 'rate mean', 'rate median'])
plt.xlabel('theta')
plt.ylabel('stats')
# Trailing semicolon suppresses the Text repr in the cell output.
plt.title('Summary stats +-3 std around the posterior mean');