In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
%matplotlib inline
import delfi.distribution as dd
import numpy as np
import pickle
import time
import scipy.stats as st
import os 
from lfimodels.balancednetwork.BalancedNetworkSimulator import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats, Identity

# Global matplotlib styling shared by every figure in this notebook.
mpl_params = {
    'figure.figsize': (15, 5),
    'axes.titlesize': 20,
    'axes.labelsize': 17,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 14,
}

mpl.rcParams.update(mpl_params)

In [None]:
# --- figure output -----------------------------------------------------------
save_figure = True   # write figures into the result folder (if it exists)
fileformat = '.png'
dpi = 300

# --- inference setup ----------------------------------------------------------
inference_method = 'snpe'
inference_param_name = 'wei'   # used in output file names
param_label = r'$w_{ei}$'      # used as axis / legend label

# --- result location ----------------------------------------------------------
# name of the result folder; the pickle inside it carries the same name
simulation_name = '15140527888578181_snpe_cJie_r10_n1000_rcl5'
path_to_save_folder = os.path.join('../results', simulation_name)

In [None]:
# Load the pickled inference results for `simulation_name`.
# NOTE(review): pickle.load can execute arbitrary code -- only open result files
# produced by your own runs.
time_str = simulation_name[:simulation_name.find('_')]  # leading id part of the simulation name
fullname = os.path.join(path_to_save_folder, simulation_name + '.p')

with open(fullname, 'rb') as handle:
    result_dict = pickle.load(handle)

print(result_dict.keys())

# The entries are unpacked positionally, so this relies on the dict's insertion
# order matching the order used when the results were saved. Older result files
# lack `posteriors` (8 entries) and, older still, also `prior` (7 entries).
result_values = list(result_dict.values())
n_entries = len(result_values)
if n_entries == 9:
    true_params, stats_obs, nrounds, ntrain, posterior, out, trn_data, prior, posteriors = result_values
elif n_entries == 8:
    true_params, stats_obs, nrounds, ntrain, posterior, out, trn_data, prior = result_values
elif n_entries == 7:
    true_params, stats_obs, nrounds, ntrain, posterior, out, trn_data = result_values
    prior = None
else:
    raise ValueError('unexpected number of entries in {}: {} (expected 7, 8 or 9)'.format(
        fullname, n_entries))

if n_entries < 9:
    # No per-round posteriors were stored: fall back to the final posterior only,
    # instead of leaving `posteriors` undefined (or stale from an earlier run).
    print('no per-round posteriors stored in this file; using only the final posterior')
    posteriors = [posterior]

# grid over the first parameter on which the posteriors are evaluated
if prior is not None:
    theta = np.linspace(prior.lower[0], prior.upper[0], 1000)
else:
    theta = np.linspace(0, 5, 1000)  # fallback range when no prior was saved

## Plot single figure with posterior over rounds 

In [None]:
# Plot the posterior of every round, marginalized onto the first parameter.
# The loop variable is deliberately NOT called `posterior`: that name holds the
# final posterior loaded from the result file and is used by the cells below.
for idx, round_posterior in enumerate(posteriors):
    n_components = len(round_posterior.a)
    # Marginal of a Gaussian mixture onto parameter 0: keep the mixture weights,
    # the first entry of each component mean and the [0, 0] block of each covariance.
    sub_means = [[round_posterior.xs[c].m[0]] for c in range(n_components)]
    sub_cov = np.asarray([round_posterior.xs[c].S[:1, :1] for c in range(n_components)])
    pdf = dd.mixture.MoG(a=round_posterior.a, ms=sub_means, Ss=sub_cov)
    post_pdf = pdf.eval(theta[:, np.newaxis], log=False)

    round_label = 'round {}'.format(idx + 1)
    if idx == nrounds - 1:
        # final round: solid and fully opaque
        plt.plot(theta, post_pdf, label=round_label + ', cl')
    elif idx > 4:
        # rounds 6 .. nrounds-1: solid, faded
        plt.plot(theta, post_pdf, label=round_label + ', cl', alpha=.6, linestyle='-')
    else:
        # first five rounds: dashed, faded
        plt.plot(theta, post_pdf, label=round_label, alpha=.6, linestyle='--')

plt.xlabel(param_label)
plt.axvline(x=true_params[0], label=r'observed {}'.format(param_label), linestyle='-', color='k', alpha=.5)
plt.legend()
plt.title('Posteriors over rounds')

plt.tight_layout()
filename = time_str + '_{}_posteriors'.format(inference_param_name)
if save_figure and os.path.exists(path_to_save_folder):
    destination = os.path.join(path_to_save_folder, filename + fileformat)
    plt.savefig(destination, dpi=dpi)
    print('saved file in {}'.format(destination))

## Plotting the summary stats 

In [None]:
# Scatter every summary statistic of the first-round training data against the
# parameter value it was simulated with; the vertical line marks the true parameter.
plt.figure(figsize=(15, 10))
p = trn_data[0][0].flatten()  # parameters of the first-round training samples
ss = trn_data[0][1]           # their summary stats, one column per statistic
titles = ['Fano factor', 'mean rate', 'kurtosis', 'positive pairwise corr',
          '0 lag auto/cross corr', '10 lag auto/cross corr', '20 lag auto/cross corr']
ylabels = ['ff', 'rate', 'kurt', 'corr prop',
           'corr0',
           'corr10',
           'corr20']

labels = ['ffE', 'ffI', 'rateE', 'rateI', 'kurtE', 'kurtI', 'corrE', 'corrI',
          'EE', 'EI', 'II',
          'EE', 'EI', 'IE', 'II',
          'EE', 'EI', 'IE', 'II']
# 1-based subplot that each summary-stat column is drawn into
plot_idx = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7]
assert ss.shape[1] == len(labels), \
    'expected {} summary stats, got {}'.format(len(labels), ss.shape[1])

for panel in range(1, max(plot_idx) + 1):
    plt.subplot(3, 3, panel)
    for idx, panel_of_stat in enumerate(plot_idx):
        if panel_of_stat == panel:
            plt.plot(p, ss[:, idx], 'o', label=labels[idx], alpha=.6)
    plt.title(titles[panel - 1])
    plt.ylabel(ylabels[panel - 1])
    if panel >= 5:
        # only the auto/cross-correlation panels (columns 8 and later) get an x label
        plt.xlabel(param_label)
    plt.legend()
    # one reference line per panel (not one per statistic)
    plt.axvline(x=true_params[0])
plt.tight_layout()

filename = time_str + '_{}_summary_stats'.format(inference_param_name)
if save_figure and os.path.exists(path_to_save_folder):
    destination = os.path.join(path_to_save_folder, filename + fileformat)
    # use the resolution configured at the top, like the other figures
    plt.savefig(destination, dpi=dpi)
    print('saved file in {}'.format(destination))

In [None]:
# Display the mean of `posterior`; the same mean is the first parameter set
# simulated in the posterior predictive check below.
posterior.mean

## Posterior predictive checking 

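Draw samples from the posterior, simulate the network with them, and compare the resulting summary statistics to the observed ones. If the posterior is good, the simulated statistics should lie close to the observed statistics.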

In [None]:
# Simulator and summary-stat calculator for the posterior predictive check.
# Make sure the values below match the settings of the loaded simulation!
ree = 2.5  # NOTE(review): not read by any visible cell -- presumably records the simulation's ree value
# Set the seed by hand, or use the one from the simulation if available.
seed = 3
param_names = [inference_param_name]

m = BalancedNetwork(inference_params=param_names, dim=1, seed=seed,
                    duration=3., calculate_stats=True, estimate_time=False,
                    first_port=8100, n_servers=3, parallel=True, verbose=True)
s = Identity(seed=seed)

# Re-simulate the true parameters with the same seed. NOTE: this replaces the
# `stats_obs` loaded from the result file.
true_data = m.gen([true_params])
stats_obs = s.calc_all(true_data)

In [None]:
# Parameter sets for the predictive check: the posterior mean first, followed by
# `n_samples` independent draws from the posterior.
n_samples = 5
params = [posterior.mean] + [posterior.gen() for _ in range(n_samples)]
params

In [None]:
# Simulate the network for every parameter set in `params`.
# The servers are stopped even if the simulation raises, so their ports
# (starting at 8100) are free for the next BalancedNetwork created below.
try:
    data = m.gen(params)
finally:
    m.stop_server()

In [None]:
# Summary stats of each simulation, expressed as the relative deviation
# (stats - stats_obs) / stats_obs from the observed stats.
raw_stats = s.calc_all(data)
stats = np.squeeze(np.array(raw_stats))
stats_normed = np.squeeze(np.subtract(stats, stats_obs) / stats_obs)

In [None]:
# Additionally simulate the true parameters several times without a fixed seed,
# to see how much the summary stats vary from the randomness alone.
n_seed_repeats = 5
param_names = [inference_param_name]
m = BalancedNetwork(inference_params=param_names, dim=1, first_port=8100,
                    verbose=True, n_servers=3, duration=3., parallel=True,
                    estimate_time=False, calculate_stats=True, seed=None)
s = Identity(seed=None)
# simulate and calc stats; always release the simulation servers afterwards
try:
    stats_obs_var = s.calc_all(m.gen(n_seed_repeats * [true_params]))
finally:
    m.stop_server()

In [None]:
# Relative deviation of the repeated true-parameter runs from the observed stats
# computed with the fixed seed.
stats_obs_var = np.squeeze(np.array(stats_obs_var))
stats_var_normed = np.squeeze(np.subtract(stats_obs_var, stats_obs) / stats_obs)

In [None]:
# Posterior predictive check: relative deviation from the observed stats of
# (a) the posterior mean (bold, C0) and posterior samples (faded, C0), and
# (b) repeated true-parameter runs with fresh seeds (C1).
stat_labels = ['ffE', 'ffI', 'rateE', 'rateI', 'kurtE', 'kurtI', 'corrE', 'corrI',
               'EE', 'EI', 'II',
               '10 EE', '10 EI', '10 IE', '10 II',
               '20 EE', '20 EI', '20 IE', '20 II']
assert stats_normed.shape[-1] == len(stat_labels), \
    'expected {} summary stats, got {}'.format(len(stat_labels), stats_normed.shape[-1])
file_string = inference_param_name

plt.figure(figsize=(15, 5))
plt.title('Summary stats of posterior mean and {} samples, normalized by observed stats'.format(n_samples))

# row 0 holds the posterior mean, rows 1..n_samples the posterior samples
plt.plot(stats_normed[0], 'o-', label='posterior', color='C0', lw=3., ms=8.)
for sample_stats in stats_normed[1:n_samples + 1]:
    plt.plot(sample_stats, 'o-', color='C0', alpha=0.5, label='_no_legend_')
plt.grid()
plt.ylim([-3., 3.])

# variability of the true parameters under different seeds (legend entry only once)
for run_idx, run_stats in enumerate(stats_var_normed):
    plt.plot(run_stats, '*-', alpha=.7, color='C1',
             label='true params, different seeds' if run_idx == 0 else '_no_legend_')

plt.legend()
plt.xticks(np.arange(len(stat_labels)), stat_labels, rotation='vertical')
plt.tight_layout()

if save_figure and os.path.exists(path_to_save_folder):
    addon = ''
    filename = time_str + '_{}_predictiveChecks_{}_r{}_ntrain{}_'.format(inference_method, file_string, nrounds, ntrain) + addon + '.pdf'
    destination = os.path.join(path_to_save_folder, filename)
    plt.savefig(destination)
    print('saved file in {}'.format(destination))