## Plotting inference results for multidimensional posteriors 

To use this notebook, each result file must be stored as `lfi-experiments/balancednetwork/results/<filename>/<filename>.p`, where `<filename>` is one of the entries of `filenames` below.

The summary figure is saved to `path_to_save_folder` (by default `../results/`, i.e. the `results` folder itself, not the per-run subfolders).

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib as mpl
import numpy as np
import pickle
import time
import scipy.stats as st
import os 
from lfimodels.balancednetwork.BalancedNetworkSimulator import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats
import delfi.distribution as dd

# Global matplotlib style for this notebook: larger legend, title, axis-label and
# tick-label fonts, and a wide default figure.
mpl_params = {
    # font sizes
    'legend.fontsize': 14,
    'axes.titlesize': 20,
    'axes.labelsize': 17,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    # default (width, height) of new figures
    'figure.figsize': (15, 5),
}
mpl.rcParams.update(mpl_params)

In [None]:
from delfi.utils.viz import probs2contours, plot_pdf

## Load data and extract posterior object

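Each data file is loaded from `lfi-experiments/balancednetwork/results/<filename>/<filename>.p` (`../results/` relative to this notebook); list the runs to plot in `filenames` below.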

In [None]:
# Which runs to plot. Each entry names one result folder; the `J....` part names the
# pair of weights inferred in that run (e.g. `Jeeei` -> EE and EI).
inference_method = 'snpe'
filenames = [
    '15147527039953778_snpe_Jeeei_r10_n4000_rcl5',
    '15149090519689238_snpe_Jeeie_r10_n4000_rcl5',
    '15136776630205421_snpe_Jeeii_r5_n3000_rcl3',
    '15141421200117958_snpe_Jeiie_r10_n4000_rcl5',
    '15137875951938703_snpe_Jeiii_r10_n2000_rcl5',
    '1513918570360665_snpe_Jieii_r10_n2000_rcl5',
]
n_files = len(filenames)

# LaTeX labels of the four weights, in corner-plot order (EE, EI, IE, II).
weight_labels = [r'$J^{EE}$', r'$J^{EI}$', r'$J^{IE}$', r'$J^{II}$']

# Output settings: the figure goes to `path_to_save_folder`, prefixed with a
# YYYYmmddHHMM timestamp taken when this cell runs.
save_figure = True
path_to_save_folder = '../results/'
time_str = time.strftime('%Y%m%d%H%M')

In [None]:
# set params for plotting
# evaluate densities (False) rather than log-densities (True); the plotting helpers read this global
log_flag = False
# not read by any of the cells shown here
use_custom_ranges = False

# probability mass(es) enclosed by the contour line(s) drawn in the 2D panels
levels = [0.68]

# get grid of sampling points: n_steps points on [lims[0], lims[1]]
lims = [0., .1]
n_grid = 1000
resolution = (lims[1] - lims[0]) / n_grid
# round rather than truncate: the float ratio can come out as e.g. 999.9999999999999,
# which int() would turn into 999 and silently shrink every grid built from n_steps
n_steps = int(round((lims[1] - lims[0]) / resolution))
theta = np.linspace(lims[0], lims[1], n_steps)
x, y = np.meshgrid(theta, theta)

# arrange samples in rows: one (x, y) grid point per row
v = np.vstack((x.flatten(), y.flatten())).T

In [None]:
def plot_1D_posterior(prior, param_idx, lims, n_steps, post, true_params, weight_label, log=None): 
    """Plot the 1D posterior of one parameter into the current axes.

    Parameters
    ----------
    prior : not used; kept so existing calls keep working.
    param_idx : index of the parameter to plot.
    lims : [lower, upper] range of the evaluation grid.
    n_steps : number of grid points on ``lims``.
    post : posterior object exposing ``eval(x=..., log=..., ii=...)``.
    true_params : sequence of true parameter values; ``true_params[param_idx]`` is
        drawn as a vertical line.
    weight_label : label of the parameter, used in the legend and on the x axis.
    log : evaluate the log-density if True; defaults to the global ``log_flag``.
    """
    if log is None:
        log = log_flag

    theta = np.linspace(lims[0], lims[1], n_steps)

    # ii restricts the evaluation to parameter `param_idx` (its marginal in delfi's MoG)
    plt.plot(theta, post.eval(x=theta, log=log, ii=[param_idx]), label='posterior')
    plt.axvline(x=true_params[param_idx], color='C1', label='true ' + weight_label)
    plt.xlabel(weight_label)
    plt.legend()

def plot_2D_posterior(lims_i, lims_j, x, y, v, post, true_params, plot_idx, levels, subplotsize=4, log=None): 
    """Plot the joint posterior over parameters 0 and 1 as a filled-contour panel.

    Parameter 1 is drawn on the horizontal axis and parameter 0 on the vertical axis.
    The panel also shows the contour line(s) enclosing ``levels`` of the probability
    mass on the grid, and a cross at the true parameters. The estimated probability
    mass on the grid is printed as a sanity check.

    Parameters
    ----------
    lims_i, lims_j : [lower, upper] ranges of parameter 0 (sampled by ``x``) and
        parameter 1 (sampled by ``y``); the panel is cropped to these ranges.
    x, y : 2D grids as returned by ``np.meshgrid(theta_i, theta_j)``; ``x`` varies
        along columns and ``y`` along rows.
    v : the grid points in rows, ``np.vstack((x.flatten(), y.flatten())).T``.
    post : posterior object exposing ``eval(x=..., log=..., ii=...)``.
    true_params : sequence whose first two entries are marked with a cross.
    plot_idx : 1-based subplot index in a ``subplotsize`` x ``subplotsize`` grid.
    levels : probability-mass levels passed to ``probs2contours`` and ``plt.contour``.
    subplotsize : number of rows and columns of the subplot grid.
    log : evaluate the log-density if True; defaults to the global ``log_flag``.
        The printed mass (and presumably the mass contours) assume densities, i.e. log=False.
    """
    if log is None:
        log = log_flag

    z = post.eval(x=v, log=log, ii=[0, 1]).reshape(x.shape)

    # Riemann-sum estimate of the mass on the grid. The cell size is read off the grid
    # itself (np.linspace puts n points, i.e. n - 1 intervals, on a range) instead of a
    # hard-coded `/ 1000`, which only fit a 1000-point grid.
    n_rows, n_cols = x.shape
    dx = (x[0, -1] - x[0, 0]) / (n_cols - 1) if n_cols > 1 else 0.
    dy = (y[-1, 0] - y[0, 0]) / (n_rows - 1) if n_rows > 1 else 0.
    print('mass: ', z.sum() * abs(dx * dy))

    # mass contours are computed on the full, uncropped grid
    cl = probs2contours(z.flatten(), levels=levels).reshape(x.shape)

    # Crop to lims: keep the columns whose x lies in lims_i and the rows whose y lies in
    # lims_j. Selecting rows and columns separately keeps the result rectangular even when
    # different numbers of rows and columns survive; a square reshape would fail then.
    keep_cols = (x[0, :] >= lims_i[0]) & (x[0, :] <= lims_i[1])
    keep_rows = (y[:, 0] >= lims_j[0]) & (y[:, 0] <= lims_j[1])
    crop = np.ix_(keep_rows, keep_cols)
    x_new, y_new, z_new, cl_new = x[crop], y[crop], z[crop], cl[crop]

    plt.subplot(subplotsize, subplotsize, plot_idx)
    # parameter 1 (y) on the horizontal axis, parameter 0 (x) on the vertical axis
    plt.contourf(y_new, x_new, z_new)
    plt.contour(y_new, x_new, cl_new, levels, linewidths=1, colors=['C1'])
    plt.plot([true_params[1]], [true_params[0]], 'x', color='C1', markersize=7)

## Loop over the result files and assemble the corner plot

In [None]:
# Corner plot of the four weights (order of `weight_labels`) on a 4 x 4 grid. Every result
# file holds the posterior over one pair of weights and contributes one off-diagonal 2D
# panel. The diagonal shows 1D posteriors taken from some of the files.
n_panels = 4
plt.figure(figsize=(15, 10))

# we assume the order of files to be: EEEI, EEIE, EEII, EIIE, EIII, IEII
# (parameter 0 of each file is the first weight of its pair, parameter 1 the second)

# Diagonal panels: file index -> (diagonal position = index into weight_labels,
# index of that weight within the file). A file can only supply a weight it contains
# (e.g. the EIII run has no J^IE), so the source of each diagonal panel is listed explicitly:
#   J^EE <- EEEI param 0,  J^EI <- EIIE param 0,  J^IE <- IEII param 0,  J^II <- EIII param 1
diag_sources = {0: (0, 0), 3: (1, 0), 5: (2, 0), 4: (3, 1)}

# 1-based subplot indices of the off-diagonal 2D panels, one per file
# (row = weight of parameter 0, column = weight of parameter 1)
plot_idxs = [2, 3, 4, 7, 8, 12]
assert len(plot_idxs) == len(filenames), 'need one off-diagonal panel index per file.'

results_folder = '../results/'

for file_idx, filename in enumerate(filenames):
    fullname = os.path.join(results_folder, filename, filename + '.p')

    # load data (pickle can execute arbitrary code: only load result files from trusted runs)
    assert os.path.exists(fullname), 'path not found: {}.'.format(fullname)
    with open(fullname, 'rb') as handle:
        result_dict = pickle.load(handle)

    # unpack values; this relies on the insertion order of the pickled dict's keys
    # NOTE(review): unpacking by key would be safer — confirm the key names in the saving script
    true_params, stats_obs, nrounds, ntrain, posterior, out, trn_data, prior, posterior_list = result_dict.values()
    dim_params = len(true_params)
    assert dim_params > 1, 'this notebook is for inference on more than 1 parameter.'

    # rebuild the posterior as a mixture of Gaussians from its components' means and covariances
    n_components = len(posterior.a)
    means = [posterior.xs[c].m for c in range(n_components)]
    Ss = [posterior.xs[c].S for c in range(n_components)]
    post = dd.mixture.MoG(posterior.a, ms=means, Ss=Ss)

    # diagonal: 1D posterior of one of this file's weights, if the file is a diagonal source
    if file_idx in diag_sources:
        diag_pos, param_idx = diag_sources[file_idx]
        plt.subplot(n_panels, n_panels, diag_pos * (n_panels + 1) + 1)

        # limits from this parameter's prior, padded by 0.05 on both sides
        lims = [prior.lower[param_idx] - 0.05, prior.upper[param_idx] + .05]
        plot_1D_posterior(prior, param_idx, lims, n_steps, post, true_params, weight_labels[diag_pos])

    # off-diagonal: joint 2D posterior over the file's two weights, on a grid spanning the prior
    lims_i = [prior.lower[0], prior.upper[0]]
    lims_j = [prior.lower[1], prior.upper[1]]

    x, y = np.meshgrid(np.linspace(lims_i[0], lims_i[1], n_steps),
                       np.linspace(lims_j[0], lims_j[1], n_steps))
    v = np.vstack((x.flatten(), y.flatten())).T

    plot_2D_posterior(lims_i, lims_j, x, y, v, post, true_params, plot_idxs[file_idx], levels, subplotsize=n_panels)

plt.tight_layout()

if save_figure:
    if os.path.exists(path_to_save_folder):
        addon = ''
        figure_filename = time_str + '_2Dweightinference_summary_{}'.format(inference_method) + addon + '.pdf'
        plt.savefig(os.path.join(path_to_save_folder, figure_filename))
    else:
        # report instead of silently skipping the save
        print('figure not saved, folder not found: {}'.format(path_to_save_folder))