In [None]:
import os
import pickle
import time

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st

from lfimodels.balancednetwork.BalancedNetworkSimulator import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats
from lfimodels.balancednetwork.BalancedNetworkGenerator import BalancedNetworkGenerator

%matplotlib inline

### Define the objects we need for the simulation: 

The model (simulator), the prior, the summary statistics, and a generator that combines them.

The parameter we infer for the balanced-network simulation is the clustering coefficient $R_{ee}$. For now we want it to stay close to 1 (the prior below is uniform on $[1, 5]$).

In [None]:
# A single free parameter is inferred: the clustering coefficient R_ee.
n_params = 1

# Simulator for the balanced network.
# NOTE(review): first_port / n_servers presumably configure the simulation
# servers it connects to, and duration the simulated time span -- confirm
# against BalancedNetwork's documentation.
m = BalancedNetwork(
    dim=n_params,
    first_port=8010,
    n_servers=4,
    duration=3.,
    verbose=False,
)

# Uniform prior over R_ee on [1, 5].
p = dd.Uniform(lower=[1.], upper=[5.])

# Summary statistics computed from each simulation.
s = BalancedNetworkStats()

In [None]:
# Bundle prior, simulator and summary statistics into one generator that
# produces the (theta, x) training pairs.
g = BalancedNetworkGenerator(prior=p, model=m, summary=s)

### Test run: one sample = one simulation

The generator returns the parameters it used together with the corresponding summary statistics, i.e. the $(\theta, x)$ pairs on which the MDN is trained.

In [None]:
# Ground-truth parameters: a list holding a single parameter vector [R_ee].
true_params = [[2.5]]

# Simulate once with the ground-truth parameters, then reduce the first
# returned simulation to its summary statistics -> the "observed" data.
data = m.gen(true_params)
stats_obs = s.calc(data[0])

In [None]:
# Show the ground-truth parameters next to the summary statistics they produced,
# labelled so the two outputs can be told apart.
print('true params:', true_params)
print('observed summary stats:', stats_obs)

In [None]:
# Observed summary statistics plotted at the true R_ee value.
# NOTE(review): assumes stats_obs has one column per legend entry
# (rate, ff, rho) -- TODO confirm against BalancedNetworkStats.
plt.plot(true_params, stats_obs, 'o')
plt.legend(['rate', 'ff', 'rho'])
plt.xlabel('$R_{ee}$')
plt.ylabel('summary statistic value')
plt.title('Observed summary statistics at the true parameter');

In [None]:
# Number of simulated (theta, x) pairs used for training.
# With N=100 this takes about 2000 s (~0.6 h).
ntrain = 50

# delfi's Basic inference on top of the generator, then train the network.
res = infer.Basic(g)
out = res.run(ntrain, minibatch=50, epochs=1000)

In [None]:
# Training curve from res.run().
# NOTE(review): assumes 'trn_val' holds the training loss at the iterations in
# 'trn_iter' -- confirm against delfi's trainer output.
plt.plot(out['trn_iter'], out['trn_val'])
plt.xlabel('iteration')
plt.ylabel('training loss')
plt.title('MDN training curve');

## Test the result: generate an observation and compare it to simulations from the posterior

In [None]:
# Condition the trained network on the observed summary statistics.
posterior = res.predict(stats_obs)

# Location and spread of the first posterior component.
# NOTE(review): assumes a 1-D parameter (only the [0] / [0][0] entries are
# read) and, for a Gaussian interpretation, a single mixture component.
mean = posterior.xs[0].m[0]               # first entry of the mean vector
std = np.sqrt(posterior.xs[0].S[0][0])    # sqrt of the [0][0] covariance entry
print(mean, std)

In [None]:
# Optionally pickle the inputs and the posterior summary to disk.
save_data = False
path_to_save_folder = 'data/'
if save_data:
    # Create the folder rather than silently skipping the save when it is missing.
    os.makedirs(path_to_save_folder, exist_ok=True)
    nrounds = 1  # one round of inference
    # NOTE: this key was previously misspelled 'nrouns' in older result files.
    result_dict = dict(true_params=true_params, stats_obs=stats_obs, nrounds=nrounds, ntrain=ntrain,
                       posterior=dict(mean=mean, std=std))
    # Timestamped file name; dots (from the float timestamp) are stripped.
    filename = os.path.join(
        path_to_save_folder,
        '{}_basic_ntrain{}'.format(time.time(), ntrain).replace('.', '') + '.p')
    with open(filename, 'wb') as handle:
        pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Evaluate the posterior density on a grid of R_ee values
# (scipy.stats is imported as `st` in the top import cell).
# NOTE(review): N(mean, std^2) is the full posterior only if the mixture has a
# single component -- confirm.
theta = np.linspace(0, 5, 1000)
post_pdf = st.norm.pdf(x=theta, loc=mean, scale=std)

In [None]:
# Posterior density over R_ee, with the ground-truth value marked.
plt.figure(figsize=(10, 5))
# Raw string: avoids the invalid '\h' escape and lets mathtext render \theta.
plt.plot(theta, post_pdf, label=r'$\hat{p}(\theta | x=x_{obs})$')
# true_params is [[value]]; axvline needs the scalar, not the inner list
# (a list cannot be compared against the axis bounds and raises TypeError).
plt.axvline(x=true_params[0][0], label=r'true $\theta$', linestyle='--', color='C1')
plt.legend()
plt.xlabel('$R_{ee}$')
plt.ylabel('posterior density');