## Plotting inference results for multidimensional posteriors 

To use this notebook, the pickled inference results must be stored at `../results/<filename>/<filename>.p` (relative to this notebook), where `<filename>` is set in the configuration cell below.

Figures are saved into the same run folder, `../results/<filename>`, but only when `save_figure` is `True` and that folder exists.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib as mpl
import numpy as np
import pickle
import time
import scipy.stats as st
import os 
from lfimodels.balancednetwork.BalancedNetworkSimulator import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats, Identity
import delfi.distribution as dd

# Global matplotlib style for every figure in this notebook: font sizes and
# the default figure size.
mpl_params = dict([
    ('legend.fontsize', 14),
    ('axes.titlesize', 20),
    ('axes.labelsize', 17),
    ('xtick.labelsize', 12),
    ('ytick.labelsize', 12),
    ('figure.figsize', (15, 5)),
])
mpl.rcParams.update(mpl_params)

In [None]:
from delfi.utils.viz import probs2contours, plot_pdf

## Load data and extract posterior object

The data file is expected at `../results/<filename>/<filename>.p`.

In [None]:
# Which inference run to analyse, and how to label / save its figures.
inference_method = 'snpe'
filename = '15162969944576585_snpe_cJieii_r3_n8000_rcl3'
file_string = 'Jieii'
weight_labels = [r'$J^{IE}$', r'$J^{II}$']
dpi = 300

save_figure = True
# Optional suffix for saved figure names; defined here so plotting cells do not
# depend on it having been set inside another cell's conditional.
addon = ''

results_dir = '../results/'
# Figures are written into the run's own results folder (next to the data file).
path_to_save_folder = os.path.join(results_dir, filename)

# Run identifier = prefix before the first underscore (whole name if there is none;
# slicing at find('_') == -1 would silently drop the last character instead).
time_str = filename.split('_', 1)[0]
fullname = os.path.join(results_dir, filename, filename + '.p')

In [None]:
# Load the pickled inference results for the configured run.
# NOTE: pickle.load can execute arbitrary code -- only open result files you
# produced yourself or otherwise trust.
if not os.path.exists(fullname):
    raise FileNotFoundError(
        'path not found: {}. data file should be in ../results/<filename>/'.format(fullname))
with open(fullname, 'rb') as handle:
    result_dict = pickle.load(handle)

In [None]:
# Entries stored in the results file (the next cell unpacks them by position).
stored_keys = result_dict.keys()
stored_keys

In [None]:
# Unpack the stored values into named variables.
# NOTE(review): this relies on the insertion order of `result_dict` matching the
# names below (compare with the keys listed above); a reordered results dict
# would silently mis-assign, and one with extra/missing entries fails to unpack.
stored_values = list(result_dict.values())
(true_params, stats_obs, nrounds, ntrain, seed,
 posterior, out, trn_data, prior, posterior_list) = stored_values
dim_params = len(true_params)
assert dim_params > 1, 'this notebook is for inference on more than 1 parameter.'

In [None]:
# Rebuild the posterior as a fresh mixture of Gaussians from its components.
n_components = len(posterior.a)
components = [posterior.xs[c] for c in range(n_components)]
means = [comp.m for comp in components]  # component means
Ss = [comp.S for comp in components]     # component .S matrices (presumably covariances)

# mixing coefficients (component weights)
mixing_coefs = posterior.a

post = dd.mixture.MoG(mixing_coefs, ms=means, Ss=Ss)

In [None]:
# Quick sanity check: posterior mean, true parameters, and the component S matrices.
for quantity in (posterior.mean, true_params, Ss):
    print(quantity)

## Compare to true parameter 

Because we generated the observed data ourselves, we know the true parameters. The posterior evaluated at $x = x_{obs}$ should concentrate around them; in particular, its mean should be close to the true parameters.

In [None]:
log_flag = False           # evaluate densities in log space
use_custom_ranges = False  # True: use `lims` below for every axis; False: use prior bounds

# probability mass enclosed by the plotted isolines
levels = [0.68]

# default evaluation grid (used as-is when use_custom_ranges is True)
lims = [-10., 10.]
# Grid points per axis, set directly: deriving it as int(range / resolution)
# can truncate to n - 1 through floating-point round-off.
n_steps = 1000
resolution = (lims[1] - lims[0]) / n_steps
theta = np.linspace(lims[0], lims[1], n_steps)
x, y = np.meshgrid(theta, theta)

# arrange grid points as rows of (x, y) pairs
v = np.vstack((x.flatten(), y.flatten())).T

In [None]:
# Posterior matrix: 1-D marginals on the diagonal, 2-D marginals with the
# `levels` probability-mass isoline in the upper triangle; true parameters in C1.
# Grids are rebuilt locally so re-running with a different `use_custom_ranges`
# does not pick up limits/grids left over from a previous run.


def _axis_limits(k):
    """Return the plotting range [low, high] for parameter k."""
    # custom ranges share the global `lims`; otherwise use the prior's bounds
    if use_custom_ranges:
        return lims
    return [prior.lower[k], prior.upper[k]]


def _plot_marginal_1d(i):
    """Plot the 1-D posterior marginal of parameter i into the current axes."""
    lims_i = _axis_limits(i)
    theta_i = np.linspace(lims_i[0], lims_i[1], n_steps)
    plt.plot(theta_i, post.eval(x=theta_i, log=log_flag, ii=[i]), label='posterior')
    plt.axvline(x=true_params[i], color='C1', label='true ' + weight_labels[i])
    plt.legend(prop=dict(size=14))
    plt.xlabel(weight_labels[i])


def _plot_marginal_2d(i, j):
    """Plot the 2-D marginal of (i, j) with theta_j horizontal and theta_i vertical."""
    lims_i, lims_j = _axis_limits(i), _axis_limits(j)
    # xx holds theta_i values (varying along columns), yy holds theta_j (along rows)
    xx, yy = np.meshgrid(np.linspace(lims_i[0], lims_i[1], n_steps),
                         np.linspace(lims_j[0], lims_j[1], n_steps))
    points = np.vstack((xx.flatten(), yy.flatten())).T
    z = post.eval(x=points, log=log_flag, ii=[i, j]).reshape(xx.shape)

    # Riemann-sum sanity check using the actual grid spacing, range / (n_steps - 1);
    # close to 1 when the mass lies inside the window (not meaningful if log_flag).
    cell_area = (xx[0, 1] - xx[0, 0]) * (yy[1, 0] - yy[0, 0])
    print('mass: ', z.sum() * cell_area)

    # the grid spans exactly lims_i x lims_j, so no masking is required
    cl = probs2contours(z.flatten(), levels=levels).reshape(xx.shape)
    plt.contourf(yy, xx, z)
    plt.contour(yy, xx, cl, levels)
    plt.plot([true_params[j]], [true_params[i]], 'o', color='C1')
    plt.xlabel(weight_labels[j])
    plt.ylabel(weight_labels[i])


plt.figure(figsize=(15, 10))
for i in range(dim_params):
    for j in range(i, dim_params):
        # 1-based subplot index, running row-wise over the dim x dim grid
        plt.subplot(dim_params, dim_params, i * dim_params + j + 1)
        if i == j:
            _plot_marginal_1d(i)
        else:
            _plot_marginal_2d(i, j)
plt.tight_layout()

addon = ''
if save_figure and os.path.exists(path_to_save_folder):
    # separate name so the data `filename` from the config cell is not overwritten
    fig_name = time_str + '_{}_posteriormatrix_{}_r{}_ntrain{}_e'.format(
        inference_method, file_string, nrounds, ntrain) + addon + '.png'
    plt.savefig(os.path.join(path_to_save_folder, fig_name), dpi=dpi)

## Plot summary statistics 

In [None]:
# Short and long names of the 19 summary statistics, presumably in the column
# order of the stats arrays -- TODO confirm against BalancedNetworkStats.
titles_short = ['ffE', 'ffI', 'rateE', 'rateI', 'kurtE', 'kurtI', 'corrE', 'corrI',
                'EE', 'EI', 'II',
                'EE', 'EI', 'IE', 'II',
                'EE', 'EI', 'IE', 'II']
# NOTE(review): 0-lag entries ordered EE, EI, II so that they agree with titles_short.
titles = ['Fano factor E', 'Fano factor I', 'mean rate E', 'mean rate I',
          'kurtosis E', 'kurtosis I', 'positive pairwise corr E', 'positive pairwise corr I',
          '0 lag auto corr EE', '0 lag cross corr EI', '0 lag auto corr II',
          '10 lag auto corr EE', '10 lag cross corr EI', '10 lag cross corr IE', '10 lag auto corr II',
          '20 lag auto corr EE', '20 lag cross corr EI', '20 lag cross corr IE', '20 lag auto corr II']

## Plot stat values resulting from the initial prior sampling

In [None]:
# Summary statistics of the first-round training data, one panel per statistic,
# with each (theta_0, theta_1) sample coloured by the statistic's value.
n_summary_stats = 19
# round-0 training data; the third entry is not used here (hoisted out of the loop)
prior_params, prior_stats, _ = trn_data[0]

plt.figure(figsize=(18, 10))
for stat_idx in range(n_summary_stats):
    plt.subplot(4, 5, stat_idx + 1)
    plt.scatter(x=prior_params[:, 0], y=prior_params[:, 1],
                c=prior_stats[:, stat_idx], cmap='viridis')
    plt.title(titles[stat_idx])
    plt.colorbar(fraction=0.2, shrink=1.1, pad=0.1, aspect=10, orientation='vertical')

    # y label on the first column of the 4 x 5 grid, x label on the bottom row
    if stat_idx % 5 == 0:
        plt.ylabel(weight_labels[1])
    if stat_idx > 14:
        plt.xlabel(weight_labels[0])

plt.tight_layout()
if save_figure and os.path.exists(path_to_save_folder):
    # separate name so the data `filename` from the config cell is not overwritten
    fig_name = time_str + '_{}_statsprior_{}_r{}_ntrain{}_'.format(
        inference_method, file_string, nrounds, ntrain) + '.png'
    plt.savefig(os.path.join(path_to_save_folder, fig_name), dpi=dpi)

### Plot summary stats over rounds 

In [None]:
# Evolution of the first `n_stats` summary statistics over (up to) `n_rounds`
# evenly spaced inference rounds: one row per statistic, one column per round.
n_stats = 4
n_rounds = 3
# at least 1: with fewer rounds than requested columns, a zero step would crash np.arange
step = max(1, nrounds // n_rounds)
rounds = np.arange(0, nrounds, step)
n_rounds = rounds.size

plt.figure(figsize=(18, 10))
for i in range(n_stats):
    # One colour scale per statistic, shared by all rounds of the row, so the
    # single colorbar drawn at the end of the row is valid for every panel in it.
    row_values = [trn_data[j][1][:, i] for j in rounds]
    norm = mpl.colors.Normalize(vmin=min(vals.min() for vals in row_values),
                                vmax=max(vals.max() for vals in row_values))

    for j_idx, j in enumerate(rounds):
        # training data of round j: (params, stats, third entry unused here)
        r_params, r_stats, _ = trn_data[j]

        plt.subplot(n_stats, n_rounds, (i * n_rounds) + j_idx + 1)
        if j_idx == 0:
            plt.ylabel(weight_labels[1])
        if i == 0:
            plt.title('round {}'.format(j))

        plt.scatter(x=r_params[:, 0], y=r_params[:, 1], c=r_stats[:, i],
                    cmap='viridis', norm=norm)

        if j_idx == (n_rounds - 1):
            cb = plt.colorbar(fraction=0.2, shrink=1.2, pad=0.1, aspect=10, orientation='vertical')
            cb.set_label(titles[i], fontsize=15, rotation='vertical')

        if i == (n_stats - 1):
            plt.xlabel(weight_labels[0])

plt.tight_layout();
if save_figure and os.path.exists(path_to_save_folder):
    # separate name so the data `filename` from the config cell is not overwritten
    fig_name = time_str + '_{}_statsOverRounds_{}_r{}_ntrain{}_'.format(
        inference_method, file_string, nrounds, ntrain) + '.png'
    plt.savefig(os.path.join(path_to_save_folder, fig_name), dpi=dpi)

## Posterior predictive checks

Sample parameters from the estimated posterior, simulate the network with them, and compare the resulting summary statistics with the observed ones.

In [None]:
# Simulator and summary-statistics calculator for the posterior predictive checks.
# A dedicated name keeps the `seed` loaded from the results file intact.
ppc_seed = 2
param_names = ['wie', 'wii']
m = BalancedNetwork(inference_params=param_names, dim=2, first_port=8100,
                    verbose=True, n_servers=3, duration=3., parallel=True,
                    estimate_time=False, calculate_stats=True, seed=ppc_seed)
s = Identity(seed=ppc_seed)
# Re-simulate the observed stats from the true params with the same fixed seed
# as the simulator (this replaces the stats_obs loaded from the results file).
try:
    stats_obs = s.calc_all(m.gen([true_params]))
except Exception:
    # do not leave the simulation servers running if the first simulation fails
    m.stop_server()
    raise

In [None]:
# generate a few samples and simulate 
n_samples = 5
params = []
# append the mean
params.append(post.mean)
for i in range(n_samples): 
    params.append(post.gen())
params

In [None]:
# simulate 
data = m.gen(params)
m.stop_server()

In [None]:
# calculate summary stats
stats = np.array(s.calc_all(data)).squeeze()
stats_normed = ((stats - stats_obs) / stats_obs).squeeze()

In [None]:
# Reference variability: simulate the true parameters n_seeds times with
# seed=None (intended to give different seeds per run).
n_seeds = 5
m = BalancedNetwork(inference_params=param_names, dim=2, first_port=8100,
                    verbose=True, n_servers=3, duration=3., parallel=True,
                    estimate_time=False, calculate_stats=True, seed=None)
s = Identity(seed=None)
try:
    stats_obs_var = s.calc_all(m.gen(n_seeds * [true_params]))
finally:
    # release the simulation servers started by this simulator
    m.stop_server()

In [None]:
# Deviation of each true-parameter repeat from the observed stats, relative to
# stats_obs (the same normalization as for the posterior simulations).
stats_obs_var = np.squeeze(np.array(stats_obs_var))
stats_var_normed = np.squeeze((stats_obs_var - stats_obs) / stats_obs)

In [None]:
# Predictive check: relative deviation of each simulation's summary stats from the
# observed stats. C0: posterior mean (bold) and posterior samples; C1: true
# parameters simulated with different seeds.
stat_tick_labels = ['ffE', 'ffI', 'rateE', 'rateI', 'kurtE', 'kurtI', 'corrE', 'corrI',
                    'EE', 'EI', 'II',
                    '10 EE', '10 EI', '10 IE', '10 II',
                    '20 EE', '20 EI', '20 IE', '20 II']

plt.figure(figsize=(15, 5))
plt.title('Summary stats of posterior mean and {} samples, normalized by observed stats'.format(n_samples))

# row 0 holds the posterior mean, rows 1..n_samples the posterior samples
plt.plot(stats_normed[0], 'o-', label='posterior', color='C0', lw=3., ms=8.)
for sample_stats in stats_normed[1:n_samples + 1]:
    plt.plot(sample_stats, 'o-', color='C0', alpha=0.5, label='_no_legend_')

# variability of the stats for the true params across different seeds
for k, repeat_stats in enumerate(stats_var_normed):
    plt.plot(repeat_stats, '*-', alpha=.7, color='C1',
             label='true params, different seeds' if k == 0 else '_no_legend_')

plt.grid()
plt.ylim([-3., 3.])
plt.legend()
plt.xticks(np.arange(len(stat_tick_labels)), stat_tick_labels, rotation='vertical')
plt.tight_layout();

if save_figure and os.path.exists(path_to_save_folder):
    # separate name so the data `filename` from the config cell is not overwritten
    fig_name = time_str + '_{}_predictiveChecks_{}_r{}_ntrain{}_'.format(
        inference_method, file_string, nrounds, ntrain) + '.png'
    plt.savefig(os.path.join(path_to_save_folder, fig_name), dpi=dpi)