In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
%matplotlib inline
import delfi.distribution as dd
import numpy as np
import pickle
import time
import scipy.stats as st
import os 
from lfimodels.balancednetwork.BalancedNetworkSimulator import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats

# Notebook-wide matplotlib defaults: a wide default canvas plus the font
# sizes used for titles, axis labels, tick labels and legends.
mpl_params = {
    'figure.figsize': (15, 5),
    'axes.titlesize': 20,
    'axes.labelsize': 17,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
}

# Apply the defaults globally so every figure created below picks them up.
mpl.rcParams.update(mpl_params)

In [None]:
# --- Output settings ---------------------------------------------------
# Figures are only written when `save_figure` is True and the target
# folder already exists (the plotting cells check for it before saving).
path_to_save_folder = '../figures/'
save_figure = True

inference_method = 'snpe'
fileformat = '.png'
dpi = 300

# --- Runs to compare ---------------------------------------------------
# Each entry names a result run; the plotting cell loads
# ../results/<name>/<name>.p for every entry, in this order.
filenames = [
    '15162524974308438_snpe_cJie_r5_n1000_rcl5_seed1',
    '15163027842632177_snpe_cJie_r5_n1000_rcl5_seed1',
    '15162316221014335_snpe_cJie_r5_n1000_rcl5_seed2',
    '15162664409937313_snpe_cJie_r5_n1000_rcl5_seed3',
    '15162902757021925_snpe_cJie_r5_n1000_rcl5_seed4',
]

# One subplot title per entry of `filenames` (same order).
titles = [
    r'inference on $w_{ie}$, seed 1',
    r'inference on $w_{ie}$, seed 1',
    r'inference on $w_{ie}$, seed 2',
    r'inference on $w_{ie}$, seed 3',
    r'inference on $w_{ie}$, seed 4',
]

# Alternative titles for the rounds / continual-learning comparison,
# kept for reference (use together with a matching `filenames` list):
# titles = [r'2 rounds á 2500 samples, CL on',
#           '2 rounds á 2500 samples, CL off',
#           '5 rounds á 1000 samples, CL after round 3',
#           '10 rounds á 500 samples, CL after round 5']

# Axis label of the inferred parameter, one per run.
param_names = [r'$w_{ie}$'] * len(filenames)
#param_names = [r'$w_{ee}$', '$w_{ei}$', r'$w_{ie}$', r'$w_{ii}$']

# `titles` and `param_names` are indexed in parallel with `filenames` by the
# plotting cell; fail here with a clear message instead of an IndexError
# halfway through plotting.
if not (len(titles) == len(param_names) == len(filenames)):
    raise ValueError(
        'titles ({}) and param_names ({}) must have one entry per filename ({})'.format(
            len(titles), len(param_names), len(filenames)))

## Plot a single figure with the posteriors over rounds

In [None]:
# One subplot per result run, each showing the approximate posterior obtained
# after every inference round, evaluated on a narrow grid around the true
# parameter value.
#
# This cell leaves `time_str`, `param_name`, `true_params`, `trn_data`,
# `nrounds`, `ntrain` and `posterior` of the LAST run in the namespace;
# the cells below read them.

plt.figure(figsize=(15, 10))

# Posteriors with index above `round_cl - 1` (and not the final one) get the
# 'cl' tag in the legend.
round_cl = 5

# One subplot row per run (was hard-coded to 5, which only matched the
# current length of `filenames`).
n_panels = len(filenames)

for idx, filename in enumerate(filenames):
    # Everything before the first underscore of the run name; reused as
    # prefix for the saved figure's file name.
    time_str = filename[:filename.find('_')]
    fullname = os.path.join('../results', filename, filename + '.p')
    param_name = param_names[idx]

    # load data
    # SECURITY NOTE: pickle.load can execute arbitrary code; only point this
    # at result files you produced yourself.
    with open(fullname, 'rb') as handle:
        result_dict = pickle.load(handle)

    # unpack values
    # The saved dicts are unpacked positionally, so this relies on the
    # insertion order of the dict. Three layouts are supported, told apart by
    # the number of entries.
    fields = list(result_dict.values())
    if len(fields) == 11:
        (true_params, stats_obs, nrounds, ntrain, seed, posterior, out,
         trn_data, prior, posteriors, svi) = fields
    elif len(fields) == 8:
        (true_params, stats_obs, nrounds, ntrain, posterior, out,
         trn_data, prior) = fields
        posteriors = None
    elif len(fields) == 7:
        (true_params, stats_obs, nrounds, ntrain, posterior, out,
         trn_data) = fields
        posteriors = None
    else:
        raise ValueError(
            'unexpected result layout in {}: {} entries'.format(fullname, len(fields)))

    if posteriors is None:
        # The shorter layouts carry no per-round `posteriors` entry. Previously
        # this branch left `posteriors` unbound (NameError) or, worse, silently
        # re-plotted the posteriors of the previously loaded run.
        # NOTE(review): falls back to whatever `posterior` holds; confirm
        # whether old result files store a single posterior or a list.
        posteriors = posterior if isinstance(posterior, (list, tuple)) else [posterior]

    # Use the number of entries in `out` as the round count rather than the
    # stored value.
    nrounds = len(out)

    plt.subplot(n_panels, 1, idx + 1)
    plt.title(titles[idx])

    # Evaluation grid: +/- 1% around the true parameter value.
    theta = np.linspace(0.99 * true_params[0], 1.01 * true_params[0], 1000)
    #plt.ylim([-20, 370])

    # NOTE: the loop variable is deliberately called `posterior`; after the
    # loop it refers to the last plotted posterior, which the final cell of
    # the notebook inspects.
    for idx2, posterior in enumerate(posteriors):
        n_components = len(posterior.a)

        # Rebuild the mixture from the first entry of every component mean.
        # NOTE(review): the covariances are passed through unsliced, so this
        # is only consistent for a one-dimensional posterior.
        sub_means = [[posterior.xs[c].m[0]] for c in range(n_components)]
        sub_cov = np.asarray([posterior.xs[c].S for c in range(n_components)])
        pdf = dd.mixture.MoG(a=posterior.a, ms=sub_means, Ss=sub_cov)
        post_pdf = pdf.eval(theta[:, np.newaxis], log=False)

        round_no = idx2 + 1
        if idx2 == nrounds - 1:
            # final posterior: drawn with a thicker line
            plt.plot(theta, post_pdf, lw=3., label='round {}, final'.format(round_no))
        elif idx2 > round_cl - 1:
            # posteriors tagged as continual learning
            plt.plot(theta, post_pdf, label='round {}, cl'.format(round_no), alpha=.7)
        else:
            # other posteriors
            plt.plot(theta, post_pdf, label='round {}'.format(round_no), alpha=.7)

    plt.xlabel(param_name)
    plt.ylabel(r'$\hat{p}( \theta | x=x_{o})$')

    # Dashed vertical line at the true parameter value.
    plt.axvline(x=true_params[0], label=r'observed {}'.format(param_name), linestyle='--', color='k')

    # A single legend in the top panel is enough.
    if idx == 0:
        plt.legend()

# With several runs in one figure, save under a combined name (prefixed with
# the last run's leading token); with a single run, keep that run's name.
if len(filenames) > 1:
    filename = time_str + '_combined_results_wie_seeds'

plt.tight_layout()

if save_figure and os.path.exists(path_to_save_folder):
    plt.savefig(os.path.join(path_to_save_folder, filename + fileformat), dpi=dpi);

## Plot the summary statistics

In [None]:
# Scatter every summary statistic of the training data against the sampled
# parameter value, grouped into one panel per statistic family.
#
# NOTE: this cell reads `trn_data`, `true_params`, `param_name`, `time_str`,
# `nrounds` and `ntrain` as left behind by the posterior-plot cell, i.e. the
# values belonging to the LAST run that was loaded there.
plt.figure(figsize=(15, 10))

# presumably trn_data[0] is the first round's (parameters, statistics) pair
# -- TODO confirm against the code that writes the result files
p = trn_data[0][0].flatten()
ss = trn_data[0][1]

# Panel titles live in their own name: this list used to be assigned to
# `titles`, which clobbered the per-run titles from the configuration cell and
# made a re-run of the posterior-plot cell label its panels with
# summary-statistic names.
stat_titles = ['Fano factor', 'mean rate', 'kurtosis', 'positive pairwise corr',
               '0 lag auto/cross corr', '10 lag auto/cross corr', '20 lag auto/cross corr']
ylabels = ['ff', 'rate', 'kurt', 'corr prop',
           'corr0',
           'corr10',
           'corr20']

# Legend label for every column of `ss`, in column order.
labels = ['ffE', 'ffI', 'rateE', 'rateI', 'kurtE', 'kurtI', 'corrE', 'corrI',
          'EE', 'EI', 'II',
          'EE', 'EI', 'IE', 'II',
          'EE', 'EI', 'IE', 'II']
# 1-based subplot (panel) that each column of `ss` is drawn into.
plot_idx = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7]

for idx, s in enumerate(ss.T):
    panel = plot_idx[idx]
    plt.subplot(3, 3, panel)
    plt.plot(p, s, 'o', label=labels[idx], alpha=.6)
    plt.title(stat_titles[panel - 1])
    plt.ylabel(ylabels[panel - 1])
    # Only the columns from index 8 on (panels 5-7) get an x-axis label.
    if idx > 7:
        plt.xlabel(param_name)
    plt.legend()
    # Vertical line at the true parameter value.
    plt.axvline(x=true_params[0])
plt.tight_layout()

filename = time_str + '_{}_summary_stats_r{}_ntrain{}'.format(inference_method, nrounds, ntrain)
if save_figure and os.path.exists(path_to_save_folder):
    # NOTE: saved at a fixed 200 dpi, not the `dpi` value from the configuration cell.
    plt.savefig(os.path.join(path_to_save_folder, filename + fileformat), dpi=200);

In [None]:
# Display `posterior.mean`. `posterior` is whatever the posterior-plot cell
# last bound to that name (its inner loop variable), so this cell must be run
# after that cell.
posterior.mean