In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import scipy.stats as st
import time 

from lfimodels.balancednetwork.BalancedNetworkSimulator import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats
from lfimodels.balancednetwork.BalancedNetworkGenerator import BalancedNetworkGenerator

%matplotlib inline

### Define the objects we need for the simulation: 

The simulator model, the prior, the summary statistics, and a generator that combines them.

In [None]:
# Number of free model parameters and number of parallel simulation servers/workers.
n_params = 4
n_cores_to_use = 4

# Simulator: n_servers=n_cores_to_use; first_port=8010 is presumably the first
# of the ports the servers listen on (TODO confirm in BalancedNetwork).
m = BalancedNetwork(
    dim=n_params,
    first_port=8010,
    verbose=True,
    n_servers=n_cores_to_use,
    duration=3.,
)
# Uniform prior with bounds [0.01, 0.1] for every parameter.
p = dd.Uniform(lower=n_params * [0.01], upper=n_params * [0.1])
# Summary statistics, computed with as many workers as there are servers.
s = BalancedNetworkStats(n_workers=n_cores_to_use)
# The generator bundles model, prior and summary statistics for inference.
g = BalancedNetworkGenerator(model=m, prior=p, summary=s)

### Generate an observation by simulating the model with the parameters from the paper

In [None]:
# here we set the true params 
true_params = [[0.024, 0.045, 0.014, 0.057]]  # params from the paper 
# run forward model 
data = m.gen(true_params)
# get summary stats
stats_obs = s.calc(data[0])

In [None]:
# Show the ground-truth parameters next to the summary statistics they produced,
# labelled so the two printed values can be told apart.
print('true params:', true_params)
print('observed summary stats:', stats_obs)

In [None]:
# Set up inference with delfi's Basic algorithm.
# n_components=3: number of mixture components of the posterior estimate
# (the posterior is later read as a mixture via `.a` and `.xs`).
# pilot_samples=0: presumably disables pilot simulations — TODO confirm in delfi docs.
res = infer.Basic(
    g,
    n_components=3,
    pilot_samples=0,
)

In [None]:
# Number of simulations used to train the posterior estimator.
# NOTE(review): ntrain=10 equals the minibatch size, which looks like a quick
# smoke-test setting; increase it for a meaningful posterior.
ntrain = 10
# Run inference with `ntrain` simulations; `out` is the training log (its
# 'loss' entry is plotted below) and `trn_data` holds the training set.
out, trn_data = res.run(
    ntrain,
    epochs=1000,
    minibatch=10,
)

In [None]:
# Training loss recorded during `res.run`; a flattening curve suggests training has converged.
fig, ax = plt.subplots()
ax.plot(out['loss'])
ax.set(title='Training loss', xlabel='training step', ylabel='loss');

In [None]:
# All simulations needed for inference have been run; shut down the simulator's servers.
m.stop_server()

## Evaluate the result: plot the resulting posterior

In [None]:
# Condition the trained estimator on the observed summary statistics; the result
# is used below as a mixture of Gaussians (weights `.a`, components `.xs`).
posterior = res.predict(stats_obs)

In [None]:
# Optionally pickle the inputs and results of this run to a timestamped file.
save_data = True
path_to_save_folder = 'data/'  # has to exist on your local path; saving is skipped otherwise
nrounds = 1  # number of inference rounds, stored alongside the results

if save_data and os.path.isdir(path_to_save_folder):
    result_dict = dict(true_params=true_params, stats_obs=stats_obs, nrounds=nrounds, ntrain=ntrain,
                       posterior=posterior, out=out, trn_data=trn_data)
    # Earlier versions of this cell wrote the misspelled key 'nrouns'; keep it as
    # well so loaders that read 'nrouns' keep working.
    result_dict['nrouns'] = nrounds

    # time.time() contains a '.', which is stripped so the only dot is the extension.
    filename = os.path.join(path_to_save_folder,
                            '{}_basic_ntrain{}'.format(time.time(), ntrain).replace('.', '') + '.p')
    with open(filename, 'wb') as handle:
        pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print(filename)
elif save_data:
    # Do not skip silently: the user asked for the results to be saved.
    print('Results NOT saved: folder {!r} does not exist.'.format(path_to_save_folder))

In [None]:
# Unpack the posterior mixture: one mean vector and one covariance matrix per component.
n_components = len(posterior.a)
components = [posterior.xs[k] for k in range(n_components)]
means = [comp.m for comp in components]
Ss = [comp.S for comp in components]

In [None]:
def get_delfi_grid_pdf(theta, delfi_obj, log=False, theta_y=None):
    """
    Evaluate a 2D delfi distribution on a regular grid.

    Parameters
    ----------
    theta : array-like, 1D
        Grid values along the first (x) dimension. Also used for the second
        (y) dimension unless `theta_y` is given.
    delfi_obj : object
        Anything with an ``eval(x=..., log=...)`` method that maps an
        ``(n_points, 2)`` array to ``n_points`` values, e.g. a 2D
        ``delfi.distribution.mixture.MoG``.
    log : bool
        Forwarded unchanged to ``delfi_obj.eval``.
    theta_y : array-like, 1D, optional
        Grid values along the second (y) dimension. Defaults to `theta`.

    Returns
    -------
    x, y : ndarray
        The grids returned by ``np.meshgrid(theta, theta_y)``.
    z : ndarray of float, same shape as `x`
        ``z[r, k]`` is ``delfi_obj.eval`` evaluated at the point ``(x[r, k], y[r, k])``.
    """
    if theta_y is None:
        theta_y = theta
    x, y = np.meshgrid(theta, theta_y)
    # Evaluate every grid point in a single call: one row per point, columns (x, y).
    points = np.column_stack([x.ravel(), y.ravel()])
    values = delfi_obj.eval(x=points, log=log)
    # Force a float result: np.zeros_like(x) would inherit an integer dtype from an
    # integer `theta` and silently truncate the density values.
    z = np.asarray(values, dtype=float).reshape(x.shape)
    return x, y, z

In [None]:
dim_params = n_params

# Evaluation grid for the posterior marginals. It spans the bounds of the Uniform
# prior set up at the top of the notebook (0.01 to 0.1 for every parameter).
theta = np.linspace(0.01, 0.1, 100)
weight_labels = ['$J^{EE}$', '$J^{EI}$', '$J^{IE}$', '$J^{II}$']
assert dim_params <= min(len(weight_labels), len(true_params[0])), \
    'need one label and one true value per plotted parameter'


def marginal_mog(dims):
    """Return the posterior marginal over the parameter indices `dims` as a MoG.

    Marginalising a Gaussian mixture keeps the mixture weights and restricts each
    component's mean and covariance to the selected dimensions.
    """
    idx = np.asarray(dims)
    sub_means = [np.asarray(means[c])[idx] for c in range(n_components)]
    sub_covs = [np.asarray(Ss[c])[np.ix_(idx, idx)] for c in range(n_components)]
    return dd.mixture.MoG(a=posterior.a, ms=sub_means, Ss=sub_covs)


# Corner plot: 1D marginals on the diagonal, pairwise 2D marginals above it,
# true parameter values marked in colour C1.
fig, axes = plt.subplots(dim_params, dim_params, figsize=(15, 10), squeeze=False)
for i in range(dim_params):
    for j in range(dim_params):
        ax = axes[i, j]
        if i == j:
            post_pdf = marginal_mog([i]).eval(theta[:, np.newaxis], log=False)
            ax.plot(theta, post_pdf)
            ax.axvline(x=true_params[0][i], color='C1', label=weight_labels[i])
            ax.legend(prop=dict(size=12))
        elif i < j:
            x, y, z = get_delfi_grid_pdf(theta, delfi_obj=marginal_mog([i, j]), log=False)
            ax.contourf(x, y, z)
            ax.plot([true_params[0][i]], [true_params[0][j]], 'o', color='C1')
            ax.set_xlabel(weight_labels[i])
            ax.set_ylabel(weight_labels[j])
        else:
            # The lower triangle would only mirror the upper one, so leave it blank.
            ax.axis('off')
fig.tight_layout()