In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import delfi.distribution as dd
import numpy as np
import pickle
import time
import scipy.stats as st
import os 
from lfimodels.balancednetwork.BalancedNetworkSimulator import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats

In [None]:
# --- Results to analyse ---------------------------------------------------
# Results pickles under data/ (names given without the '.p' suffix). The text
# before the first '_' is reused as a prefix for every saved figure name.
filenames = ['1509979643087894_snpe_ree_r5_ntrain100']
inference_method = 'snpe'
# LaTeX label of the single inferred parameter (axis labels / legends).
param_name = r'$R_{ee}$'

# --- Figure output --------------------------------------------------------
save_figure = True
path_to_save_folder = 'figures/'
fileformat = '.pdf'
dpi = 400

In [None]:
# Load each results file and plot the 1-D marginal posterior with the true
# parameter marked. Leaves these globals for later cells: true_params,
# stats_obs, nrounds, ntrain, posterior, out, trn_data, prior, posteriors,
# theta, means, Ss, time_str, legend_size, label_size.
assert filenames, 'filenames must list at least one results file'

plt.figure(figsize=(15, 5))
legend_size = 13
label_size = 16

for idx, filename in enumerate(filenames):
    # Timestamp prefix of the results file, reused in saved figure names.
    time_str = filename[:filename.find('_')]
    fullname = os.path.join('data', filename + '.p')

    # load data -- pickle.load can execute arbitrary code: only open trusted,
    # locally produced result files here.
    with open(fullname, 'rb') as handle:
        result_dict = pickle.load(handle)

    print(result_dict.keys())
    # Unpack positionally (relies on the dict's insertion order). Older result
    # files lack the trailing `prior` and/or `posteriors` entries.
    values = list(result_dict.values())
    if len(values) not in (7, 8, 9):
        raise ValueError('expected 7, 8 or 9 entries in {}, got {}'.format(fullname, len(values)))
    true_params, stats_obs, nrounds, ntrain, posterior, out, trn_data = values[:7]
    prior = values[7] if len(values) >= 8 else None
    # Without saved per-round posteriors, fall back to the final one so the
    # "posteriors over rounds" cell still runs (and never sees a stale list
    # from a previously loaded file).
    posteriors = values[8] if len(values) == 9 else [posterior]

    if prior is not None:
        theta = np.linspace(prior.lower[0], prior.upper[0], 1000)
    else:
        # no prior stored: fall back to a fixed plotting range
        theta = np.linspace(0, 5, 1000)
    assert len(true_params) == 1, 'this notebook is for inference on 1 parameter: len(params)'

    nrounds = len(out)

    print(stats_obs.shape)

    # extract the posterior: build a mixture of Gaussians over the first (only)
    # parameter dimension and evaluate its density on the theta grid
    n_components = len(posterior.a)
    means = [posterior.xs[c].m for c in range(n_components)]
    Ss = [posterior.xs[c].S for c in range(n_components)]

    sub_means = [[means[c][0]] for c in range(n_components)]
    sub_cov = np.asarray([Ss[c] for c in range(n_components)])
    pdf = dd.mixture.MoG(a=posterior.a, ms=sub_means, Ss=sub_cov)
    post_pdf = pdf.eval(theta[:, np.newaxis], log=False)

    plt.subplot(len(filenames), 1, idx + 1)
    plt.plot(theta, post_pdf, label='posterior')
    plt.axvline(x=true_params[0], label=r'observed {}'.format(param_name), linestyle='--', color='C1')
    plt.legend(prop=dict(size=legend_size))
    plt.xlabel(param_name, fontsize=label_size)
    if idx == 0:
        plt.ylabel(r'$\hat{p}( \theta | x=x_{o})$', fontsize=label_size)

if len(filenames) > 1:
    filename = time_str + '_combined_results'
else:
    filename = time_str + '_{}_posterior_r{}_ntrain{}'.format(inference_method, nrounds, ntrain)

plt.tight_layout()

if save_figure and os.path.exists(path_to_save_folder):
    plt.savefig(os.path.join(path_to_save_folder, filename + fileformat), dpi=dpi);

In [None]:
# Overlay the 1-D marginal posterior of every inference round on one axis.
# Loop variables use a `round_` prefix so this cell does not overwrite the
# global `posterior`, `means` and `Ss`, which the posterior predictive cells
# below read as the final posterior loaded from the results file.
plt.figure(figsize=(15, 5))

for round_idx, round_posterior in enumerate(posteriors):
    n_round_components = len(round_posterior.a)
    # marginal over the first (only) parameter dimension
    round_means = [[round_posterior.xs[c].m[0]] for c in range(n_round_components)]
    round_covs = np.asarray([round_posterior.xs[c].S for c in range(n_round_components)])
    round_pdf = dd.mixture.MoG(a=round_posterior.a, ms=round_means, Ss=round_covs)
    round_post_pdf = round_pdf.eval(theta[:, np.newaxis], log=False)

    plt.plot(theta, round_post_pdf, label='round {}'.format(round_idx + 1))

plt.axvline(x=true_params[0], label=r'observed {}'.format(param_name), linestyle='--', color='k')
plt.xlabel(param_name, fontsize=label_size)
plt.ylabel(r'$\hat{p}( \theta | x=x_{o})$', fontsize=label_size)
plt.legend(prop=dict(size=legend_size))
plt.title('Posteriors over rounds', fontsize=label_size)

plt.tight_layout()
filename = time_str + '_{}_posteriors_over_rounds_r{}_ntrain{}'.format(inference_method, nrounds, ntrain)
if save_figure and os.path.exists(path_to_save_folder):
    plt.savefig(os.path.join(path_to_save_folder, filename + fileformat), dpi=dpi);

## Plotting the summary stats 

In [None]:
# Scatter the first entry of trn_data: summary stats (y) against the parameter
# value (x). Assumes trn_data[0] is a (params, stats) pair with one parameter
# per sample -- TODO confirm. Descriptive names avoid reusing `s`, which later
# holds the BalancedNetworkStats object.
plt.figure(figsize=(15, 5))
trn_params_first = trn_data[0][0].flatten()
trn_stats_first = trn_data[0][1]
plt.plot(trn_params_first, trn_stats_first, 'o')
plt.xlabel(param_name, fontsize=label_size)
plt.ylabel('summary statistic value', fontsize=label_size)
plt.title('Training data (trn_data[0])', fontsize=label_size)
# NOTE(review): 4 legend names; any further stats columns get no legend entry
plt.legend(['ff', 'r', 'k', 'rho']);

In [None]:
# Display the stats array of the first trn_data entry (the y-values of the scatter plot above).
trn_data[0][1]

In [None]:
# Display the posterior mean.
# NOTE(review): which posterior this is depends on which earlier cell last assigned `posterior`.
posterior.mean

## Posterior predictive checking 

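Simulate the network at parameter values around the posterior mean: by default only the mean itself, optionally offset by multiples of the posterior standard deviation. If the posterior is accurate, the resulting summary statistics should be close to the observed ones.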

In [None]:
# Simulator (`m`) and summary-statistics calculator (`s`) for the posterior
# predictive check. Raster plots go under the notebook's figure folder rather
# than a hard-coded absolute path on one user's machine.
raster_plot_folder = os.path.join(path_to_save_folder, 'simulation_raster_plots/')
os.makedirs(raster_plot_folder, exist_ok=True)

# NOTE(review): the loaded results are for R_ee ('ree' in the filename) --
# confirm that 'wee' is the intended inference_param for this check.
m = BalancedNetwork(inference_param='wee', n_servers=1, duration=3., first_port=8010,
                    save_raster_plots=True,
                    save_folder=raster_plot_folder)
s = BalancedNetworkStats(n_workers=2)

In [None]:
# Mean and standard deviation of the 1-D mixture-of-Gaussians `posterior`
# (weights a_k, component means mu_k, component variances s_k^2):
#   mean     = sum_k a_k * mu_k
#   variance = sum_k a_k * (s_k^2 + mu_k^2) - mean^2   (law of total variance)
#   std      = sqrt(variance)
# Components are read straight from `posterior` so this cell does not depend on
# `means`/`Ss` left over from whichever plotting cell ran last.
mix_weights = np.asarray(posterior.a, dtype=float)
mix_means = np.array([posterior.xs[c].m[0] for c in range(len(mix_weights))])
mix_vars = np.array([posterior.xs[c].S[0][0] for c in range(len(mix_weights))])

mean = np.sum(mix_weights * mix_means)
# each component contributes its OWN variance (not component 0's for all)
variance = np.sum(mix_weights * (mix_vars + mix_means ** 2)) - mean ** 2
# clip tiny negative values from floating-point cancellation before the sqrt
std = np.sqrt(max(variance, 0.0))

In [None]:
# Simulate the network at parameter values offset from the posterior mean by
# multiples of the posterior std. By default only the mean itself (offset 0)
# is simulated; use e.g. [-3, -2, -1, 0, 1, 2, 3] for a sweep.
std_offsets = [0]
thetas = [mean + k * std for k in std_offsets]

m.start_server()
# NOTE(review): no matching server shutdown appears in this notebook --
# check whether BalancedNetwork needs one.
data = m.gen(thetas)

In [None]:
# Summary statistics of the simulations in `data`, computed by the
# BalancedNetworkStats instance `s` from the simulator setup cell.
sum_stats = s.calc_all(data)

In [None]:
# Posterior predictive check: summary stats simulated at each value in
# `thetas`, with dashed lines at the posterior mean and the true parameter.
# The stats are arranged as (n_thetas, n_stats) in a new variable instead of
# squeezing `sum_stats` in place. With the default single theta, squeeze()
# collapses them to 1-D, which plt.plot cannot pair with the length-1
# `thetas`, and later `sum_stats[0]` would become a single scalar.
ppc_stats = np.asarray(sum_stats).reshape(len(thetas), -1)

plt.figure(figsize=(8, 4))
plt.axvline(x=mean, linestyle='--', color='C4')
# true_params holds a single value; pass the scalar, as the posterior plots do
plt.axvline(x=true_params[0], linestyle='--', color='C5')

plt.plot(thetas, ppc_stats, 'o')

# NOTE(review): 4 stat names below; stats beyond the 4th get no legend entry
plt.legend(['posterior mean', 'true theta', 'ff1', 'rate mean', 'kurtosis', 'rel corr'])
plt.xlabel('theta')
plt.ylabel('stats')
plt.title('Summary stats around the posterior mean')

filename = time_str + '_{}_ppch_r{}_ntrain{}'.format(inference_method, nrounds, ntrain)
plt.tight_layout()
if save_figure and os.path.exists(path_to_save_folder):
    plt.savefig(os.path.join(path_to_save_folder, filename + fileformat), dpi=dpi);

In [None]:
# Display the parameter values that were simulated.
thetas

In [None]:
# Shape of the first entry of `sum_stats`.
# NOTE(review): earlier cells may have reshaped/squeezed `sum_stats`, so confirm
# this is still one theta's full stats vector.
sum_stats[0].shape

## Compare the summary stats simulated at the posterior mean with the observed stats

In [None]:
# Posterior mean of the single inferred parameter (kept for inspection; not
# used by the cells below).
jee = posterior.mean[0]
# Flat stats vector of the first simulated theta. Reshaping by len(thetas)
# works whether or not `sum_stats` was squeezed earlier. Indexing a squeezed
# single-theta array with [0] would return one scalar instead of the vector.
stats = np.asarray(sum_stats).reshape(len(thetas), -1)[0]

In [None]:
# Compare, stat by stat, the summary stats simulated at the posterior mean with
# the observed summary stats.
labels = ['FF', 'rate', 'kurtosis', '+ correlation', 'FF inh', 'rate inh', 'kurtosis inh', '+ correlation inh']
plt.figure(figsize=(15, 5))
plt.plot(stats, '-o', label='posterior mean')
plt.plot(stats_obs.T, '-o', label='observed')
plt.xticks(np.arange(len(labels)), labels)
plt.gca().tick_params(axis='x', labelsize=15)
plt.ylabel('summary statistic value', fontsize=label_size)
plt.title('Posterior predictive check', fontsize=label_size)
plt.legend(prop=dict(size=legend_size + 2))

# param_name is a LaTeX label (e.g. '$R_{ee}$'); keep only filename-safe
# characters so '$', '{' and '}' do not end up in the saved file name.
param_tag = ''.join(ch for ch in param_name if ch.isalnum() or ch == '_')
filename = time_str + '_{}_ppch_ntrain{}_{}'.format(inference_method, ntrain, param_tag)

plt.tight_layout()
if save_figure and os.path.exists(path_to_save_folder):
    plt.savefig(os.path.join(path_to_save_folder, filename + fileformat), dpi=dpi);

In [None]:
# Leftover IPython help lookup for plt.savefig; not part of the analysis, safe to delete before sharing.
plt.savefig?