In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
%matplotlib inline
import delfi.distribution as dd
import numpy as np
import pickle
import time
import scipy.stats as st
import os 
from lfimodels.balancednetwork.BalancedNetworkSimulator import BalancedNetwork
from lfimodels.balancednetwork.BalancedNetworkStats import BalancedNetworkStats, Identity
import matplotlib.colors as colors

# Notebook-wide matplotlib defaults: font sizes for legends, titles, axis
# labels and tick labels, plus a wide default figure size.
mpl_params = {
    'legend.fontsize': 14,
    'axes.titlesize': 20,
    'axes.labelsize': 17,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'figure.figsize': (15, 5),
}

mpl.rcParams.update(mpl_params)

In [None]:
# Figure export settings.
save_figure = True
fileformat = '.png'
dpi = 300

# The run name doubles as the name of its sub-folder inside 'results'.
simulation_name = '1513042700676306_bruteforce_n6561'
path_to_save_folder = os.path.join('results', simulation_name)

In [None]:
# Prefix of the run name up to (not including) the first underscore.
# NOTE(review): str.find returns -1 when there is no underscore, in which case
# this slice silently drops the last character instead of failing.
time_str = simulation_name[:simulation_name.find('_')]
# The results file is '<simulation_name>.p' inside the run's results folder.
fullname = os.path.join(path_to_save_folder, simulation_name + '.p')


# Load the stored brute-force results.
# SECURITY: pickle.load can execute arbitrary code; only open result files
# that come from a trusted source.
with open(fullname, 'rb') as handle:
    result_dict = pickle.load(handle)

In [None]:
# Unpack the stored results positionally. This depends on the order of the
# entries in the pickled dict and on it holding exactly four values.
# NOTE(review): confirm the stored order is (true_params, stats, data, params),
# or switch to explicit key lookups.
true_params, stats, data, params = result_dict.values()

In [None]:
# Simulate forward once at the true parameters and take the 'data' entry of
# the first result as the reference ("observed") summary statistics.
# NOTE(review): the notebook text states the simulation is noisy, so
# true_stats presumably changes between runs unless the simulator is seeded.
# NOTE(review): first_port=8010 / n_servers=1 suggest an external simulation
# server is required — confirm before running on a fresh machine.
m = BalancedNetwork(inference_params=['wxy'], n_servers=1, duration=3., first_port=8010,
                    calculate_stats=True, dim=4)
params_list = [true_params]
true_stats = m.gen(params_list)[0][0]['data']

In [None]:
# Arrange the flat brute-force results on the 9 x 9 x 9 x 9 parameter grid.
# Grid axes are ordered (wee, wei, wie, wii); the trailing axis holds the
# 4 parameter values and the 19 summary statistics respectively.
pmat = np.array(params).reshape(9, 9, 9, 9, 4)
smat = np.array(stats).reshape(9, 9, 9, 9, 19)
# Mean squared error between simulated and true summary stats, averaged over
# the statistics axis so that one value remains per grid point.
mse = np.mean(np.square(smat - true_stats), axis=4)

In [None]:
# Number of grid points per axis used for the coordinate vectors in the
# plotting cells (same value as the hard-coded 9-point axes of the grid).
n_steps = 9

In [None]:
# The four true weights, in the same order as the grid axes.
wee, wei, wie, wii = true_params
# Grid indices used below to look up the entry "at the true parameters".
# NOTE(review): these are hard-coded; confirm they match the grid positions
# of the values in true_params.
idx_wee, idx_wei, idx_wie, idx_wii = 3, 5, 1, 5

In [None]:
# 4-D grid index of the smallest MSE over the whole brute-force grid.
opt_idx = np.unravel_index(mse.argmin(), mse.shape)
# Parameters and summary statistics stored at that MSE minimum.
opt_params = pmat[opt_idx]
opt_stats = smat[opt_idx]
# Summary statistics stored at the grid index of the true parameters.
truly_opt_stats = smat[idx_wee, idx_wei, idx_wie, idx_wii]
print(opt_params, opt_idx)

## Looking at all 6 combinations of 2D inference problems and their error landscapes: 

In [None]:
# One 9 x 9 MSE slice per pair of weights. The two weights that are not shown
# are held fixed at their grid index of the global MSE minimum (opt_idx).
mses_2D = np.zeros((9, 9, 6))
# imshow draws the first array axis vertically and the second horizontally,
# so ylabels name the first free axis of each slice and xlabels the second.
xlabels = ['wei', 'wie', 'wii', 'wie', 'wii', 'wii']
ylabels = ['wee', 'wee', 'wee', 'wei', 'wei', 'wie']
mses_2D[:, :, 0] = mse[:, :, opt_idx[2], opt_idx[3]]
mses_2D[:, :, 1] = mse[:, opt_idx[1], :, opt_idx[3]]
mses_2D[:, :, 2] = mse[:, opt_idx[1], opt_idx[2], :]
mses_2D[:, :, 3] = mse[opt_idx[0], :, :, opt_idx[3]]
mses_2D[:, :, 4] = mse[opt_idx[0], :, opt_idx[2], :]
mses_2D[:, :, 5] = mse[opt_idx[0], opt_idx[1], :, :]

# common color norm: log scale over the full grid so all panels are comparable
norm = colors.LogNorm(vmin=np.min(mse), vmax=np.max(mse))


plt.figure(figsize=(15, 10))
for idx in range(6):
    plt.subplot(2, 3, idx + 1)
    plt.imshow(mses_2D[:, :, idx], norm=norm)
    plt.ylabel(ylabels[idx])
    plt.xlabel(xlabels[idx])
    plt.colorbar()

#plt.suptitle('MSE landscapes of all six combinations of 2D inference problems.');
plt.tight_layout()

# Honour the export settings from the configuration cell instead of always
# writing a hard-coded PNG, and make sure the target folder exists so that
# savefig does not fail on a fresh checkout.
if save_figure:
    os.makedirs('figures', exist_ok=True)
    plt.savefig(os.path.join('figures', '2d_mse_landscapes' + fileformat), dpi=dpi)

## There is a clear minimum - at a cliff

We see that there is a clear minimum in the brute-force search in the 2D space for each combination of weight pairs. However, it is often right at a cliff in the MSE landscape, indicating that for slightly different parameter combinations the network produces very different summary statistics. This could make the inference problem very hard. The flattest landscape around the minimum seems to occur for the combination of $w_{ei}$ and $w_{ii}$. So it might be a good start to try the inference on this combination of weights. 

## Trying to visualize 3 dimensions at once

In [None]:
from mpl_toolkits.mplot3d import Axes3D

# Coordinate vectors for the first three grid axes of mse.
# NOTE(review): the ranges are assumed to be the brute-force grids of
# wee, wei and wie (in that order) — confirm against the simulation script.
X = np.linspace(0.009, 0.049, n_steps)
Y = np.linspace(0.02, 0.06, n_steps)
Z = np.linspace(0.009, 0.049, n_steps)

# indexing='ij' makes X vary along array axis 0 and Y along axis 1, so the
# coordinate arrays line up element-wise with mse[:, :, :, k]. The default
# ('xy') swaps the first two axes, which would pair every MSE value with
# transposed X/Y coordinates.
X, Y, Z = np.meshgrid(X, Y, Z, indexing='ij')

In [None]:
# 3D scatter of the grid points spanned by X, Y, Z, coloured by the MSE of
# the slice at index 7 of the last (wii) grid axis.
fig = plt.figure()
# fig.gca(projection=...) was deprecated in Matplotlib 3.4 and later removed;
# add_subplot creates the same single 3D axes on all versions.
ax = fig.add_subplot(111, projection='3d')
surf = ax.scatter(X, Y, Z, c=mse[:, :, :, 7], cmap='jet')

In [None]:
# Smallest MSE within the slice at index 2 of the third (wie) grid axis.
mse[:, :, 2, :].min()

In [None]:
# Coordinate values along the last (wii) grid axis, one per panel.
jiis = np.linspace(0.032, 0.072, n_steps)
# define a common norm: log colour scale over the whole slice at index 2 of
# the third (wie) axis, so the nine panels are directly comparable
norm = colors.LogNorm(vmin=mse[:, :, 2, :].min(), vmax=mse[:, :, 2, :].max())

plt.figure(figsize=(15, 15))
for idx, w in enumerate(jiis):
    # one panel per wii grid value: MSE over the first two grid axes
    plt.subplot(3, 3, idx + 1)
    plt.imshow(mse[..., 2, idx], norm=norm)
    plt.title('w={}'.format(round(w, 3)))
    plt.colorbar();

In [None]:
# Per-statistic squared error against the observed summary statistics:
# first for the MSE minimum, then for the grid entry at the true parameters.
se_opt = np.square(opt_stats - true_stats)
se_truly_opt = np.square(truly_opt_stats - true_stats)

In [None]:
# Compare the observed summary statistics with those at the MSE minimum and
# those stored at the grid index of the true parameters.
# The first two legend labels were swapped: true_stats is the observed
# ("true") data and opt_stats belongs to the MSE optimum.
plt.plot(true_stats, 'o-', label='true')
plt.plot(opt_stats, 'o-', label='opt')
plt.plot(truly_opt_stats, 'o-', label='truly_opt')
plt.legend()
plt.title('Summary statistics of observed data, the MSE optimum and the theoretical optimum');

In [None]:
# Overlay the per-statistic squared errors of the MSE minimum and of the
# grid entry at the true parameters (semi-transparent so both stay visible).
for bar_heights, bar_label in ((se_opt, 'opt'), (se_truly_opt, 'theoretical opt')):
    plt.bar(np.arange(19), bar_heights, label=bar_label, alpha=.5)
plt.legend()
plt.ylabel('Squared error')
plt.title('SE of the actual minimum vs. that of the true parameter idx');

In theory, the squared error between the observed stats and the entry in the stats matrix at the true parameters should be zero because both are based on the same simulation parameters. However, due to randomness in the simulation this is not the case. The small differences in the indices of the optimum can therefore be explained by the noise in the system. 

## Look at the optimum along the $w_{ie}$ dimension
This is the inference problem that seems to be very difficult in 1D already. 

In [None]:
# MSE along the wie grid axis with the other three weights fixed at the grid
# indices of the true parameters; log-scaled y axis.
plt.semilogy(mse[idx_wee, idx_wei, :, idx_wii], 'o-')
plt.xlabel(r'$w_{ie}$')
plt.ylabel('MSE');

It shows a clear minimum at the true value of the parameter. However, one can already see the very large MSE at index 0. During inference the prior reached down to 0.008, covering a whole range of values for which the network behaves pathologically. This might explain the strong uncertainty in the resulting posteriors. 