# Hierarchical AR(1) on a Grid: Test with Stan

In this notebook, we test compositional score matching on a hierarchical problem defined on a grid, using Stan as the baseline.
- The observations are on a grid with `n_grid` x `n_grid` points.
- The global parameters are shared by all grid points and have the hyper-priors
$$ \alpha \sim \mathcal{N}(0, 1) \quad
  \mu_\beta \sim \mathcal{N}(0, 1) \quad
  \log\text{std}_\beta \sim \mathcal{N}(0, 1);$$

- The local parameters differ between grid points:
$$ \eta_{i,j}^\text{raw} \sim \mathcal{N}(0, 1), \qquad \eta_{i,j} = 2\operatorname{sigmoid}\left(\mu_\beta + \text{std}_\beta\cdot\eta_{i,j}^\text{raw}\right)-1 \in (-1, 1),$$
  where $\text{std}_\beta = \exp(\log\text{std}_\beta)$.

- At each grid point $(i,j)$, we have a time series of `T` observations $y_{i,j,0}, \dots, y_{i,j,T-1}$:
$$ y_{i,j,t} \sim \mathcal{N}(\alpha + \eta_{i,j}\, y_{i,j,t-1},\ 0.1), \quad t = 1, \dots, T-1, \qquad y_{i,j,0} \sim \mathcal{N}(0, 0.1).$$
- We observe $T=5$ time points for each grid point. We can also amortize over the time dimension.

In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np

os.environ['KERAS_BACKEND'] = 'torch'
from bayesflow import diagnostics

from experiments.problems.ar1_grid_stan import get_stan_posterior
from experiments.problems.ar1_grid import Prior
from experiments.problems import visualize_simulation_output

# Total number of CPUs available for the Stan runs below. Hard-coded here; on a
# SLURM cluster, use `int(os.environ.get('SLURM_CPUS_PER_TASK', 1))` instead.
n_procs = 10

In [None]:
# Total number of grid points N = n_grid * n_grid; the index selects 4x4, 32x32 or 128x128.
N = (4 * 4, 32 * 32, 128 * 128)[2]

prior = Prior()

# Parameter names used to label the diagnostic plots further down.
global_param_names, local_param_names = prior.global_param_names, prior.get_local_param_names(N)

In [None]:
# Sanity check of the simulator: draw `n_sim_checks` datasets from the prior and
# visualize their element-wise mean.
n_sim_checks = 100
# `prior.sample(...)['data'][0]` is converted with `.cpu().numpy()` when preparing the
# Stan inputs; do the same here so stacking also works when sampling on a GPU.
results = np.stack(
    [prior.sample(1, n_local_samples=N)['data'][0].cpu().numpy() for _ in range(n_sim_checks)],
    axis=0,
)

visualize_simulation_output(np.mean(results, axis=0))

In [None]:
from joblib import Parallel, delayed

In [None]:
@delayed
def wrapper_fun(sample_i):
    """Simulate one dataset from the prior and fit it with Stan.

    Decorated with joblib's ``delayed``, so ``wrapper_fun(i)`` returns a lazy task
    to be executed by ``Parallel``.

    Parameters
    ----------
    sample_i : int
        Task index. It is not used inside the function.
        NOTE(review): it could be used to seed each task so that the simulated
        datasets are reproducible across parallel workers.

    Returns
    -------
    tuple
        ``(global_posterior, local_posterior, true_global, true_local)``: the two
        outputs of ``get_stan_posterior`` followed by the flattened true global and
        local parameters as numpy arrays.
    """
    draw = prior.sample(1, n_local_samples=N)
    params_global = draw['global_params'].flatten().cpu().numpy()
    params_local = draw['local_params'].flatten().cpu().numpy()
    observed = draw['data'][0].cpu().numpy()

    post_global, post_local = get_stan_posterior(observed, sigma_noise=prior.simulator.sigma_noise)
    return post_global, post_local, params_global, params_local

In [None]:
# Stan posteriors for `n_datasets` simulated datasets, cached as .npy files under
# `problems/ar1/`. They are recomputed only if any of the four cache files is missing.
# NOTE(review): this seeds only numpy's global RNG in this kernel. The joblib worker
# processes that run `wrapper_fun` keep their own RNG state (and the prior may not use
# numpy's RNG at all), so the simulated datasets are presumably not reproducible.
np.random.seed(42)
n_datasets = 100
cache_dir = 'problems/ar1'
cache_paths = {
    name: os.path.join(cache_dir, f'{name}_{N}.npy')
    for name in ('global_posterior', 'local_posterior', 'true_global', 'true_local')
}

if all(os.path.exists(path) for path in cache_paths.values()):
    global_posterior = np.load(cache_paths['global_posterior'])
    local_posterior = np.load(cache_paths['local_posterior'])
    true_global = np.load(cache_paths['true_global'])
    true_local = np.load(cache_paths['true_local'])

else:
    # Run n_procs // 4 Stan fits concurrently (presumably each fit uses several cores
    # for its chains). Use at least one job, because joblib rejects n_jobs=0.
    n_jobs = max(1, n_procs // 4)
    start = time.time()
    out_parallel = Parallel(n_jobs=n_jobs, verbose=1)(wrapper_fun(i) for i in range(n_datasets))
    end = time.time()
    print(f'Stan fits for {n_datasets} datasets took {end - start:.1f}s')

    # Stack the per-dataset outputs; the local posterior gets its last two axes swapped.
    global_posterior = np.stack([out[0] for out in out_parallel], axis=0)
    local_posterior = np.stack([out[1] for out in out_parallel], axis=0).transpose(0, 2, 1)
    true_global = np.stack([out[2] for out in out_parallel], axis=0)
    true_local = np.stack([out[3] for out in out_parallel], axis=0)

    # np.save does not create missing directories.
    os.makedirs(cache_dir, exist_ok=True)
    np.save(cache_paths['global_posterior'], global_posterior)
    np.save(cache_paths['local_posterior'], local_posterior)
    np.save(cache_paths['true_global'], true_global)
    np.save(cache_paths['true_local'], true_local)

In [None]:
# Sequential (single-process) variant of the parallel Stan run in the previous cell,
# kept commented out for reference/debugging. If uncommented, it recomputes the Stan
# posteriors for 100 datasets one by one and overwrites the cached .npy files under
# `problems/ar1/`.
# np.random.seed(42)
#
# global_posteriors = []
# local_posteriors = []
#
# true_global = []
# true_local = []
#
# start = time.time()
# for i in range(100):
#     print(i)
#     sample = prior.sample(1, n_local_samples=N)
#     true_global.append(sample['global_params'].flatten().cpu().numpy())
#     true_local.append(sample['local_params'].flatten().cpu().numpy())
#     sim_test = sample['data'][0].cpu().numpy()
#
#     global_posterior, local_posterior = get_stan_posterior(sim_test, sigma_noise=prior.simulator.sigma_noise)
#     global_posteriors.append(global_posterior)
#     local_posteriors.append(local_posterior)
# end = time.time()
# print(end - start)
#
# # make numpy arrays
# global_posterior = np.stack(global_posteriors, axis=0)
# local_posterior = np.stack(local_posteriors, axis=0).transpose(0, 2, 1)
# true_global = np.stack(true_global, axis=0)
# true_local = np.stack(true_local, axis=0)
#
# np.save(f'problems/ar1/global_posterior_{N}.npy', global_posterior)
# np.save(f'problems/ar1/local_posterior_{N}.npy', local_posterior)
# np.save(f'problems/ar1/true_global_{N}.npy', true_global)
# np.save(f'problems/ar1/true_local_{N}.npy', true_local)

In [None]:
# Wall-clock cost of the Stan baseline, measured on 100 datasets per grid size:
# 4x4:     41s sequentially -> 0.41s/dataset; 20s with 2 parallel jobs -> 20*2/100 = 0.4s/dataset
# 32x32:   3600s sequentially -> 36s/dataset
# 128x128: 743.6min with 16 parallel jobs -> 743.6*16/100 ~ 119min/dataset (~198h in total)
data_n = [4*4, 32*32, 128*128]
# Seconds per dataset, converted to minutes.
time_n = np.array([0.4, 36, 743.6 * 60 * 16 / 100]) / 60.

fig, ax = plt.subplots()
ax.plot(data_n, time_n, marker='o')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Number of grid points')
ax.set_ylabel('Execution time per dataset (min)')
ax.set_title('Stan runtime vs. grid size')
plt.show()

# Plotting

In [None]:
# Set to True to also write the global-parameter diagnostics to `plots/`.
save_plots = False

fig = diagnostics.recovery(global_posterior, true_global, variable_names=global_param_names)
if save_plots:
    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    fig.savefig(f'plots/recovery_global_STAN_n_grid{N}.png')

fig = diagnostics.calibration_ecdf(global_posterior, true_global, difference=True, variable_names=global_param_names)
if save_plots:
    fig.savefig(f'plots/ecdf_global_STAN_n_grid{N}.png')

In [None]:
def print_metric_summary(label, metric_fn, estimates, targets):
    """Print the mean and std (rounded to 2 decimals) of a diagnostic's ``'values'``.

    ``metric_fn`` (e.g. ``diagnostics.root_mean_squared_error``) is evaluated only
    once; the original cell computed every metric twice, which is costly for the
    large local-parameter arrays.
    """
    values = metric_fn(estimates, np.asarray(targets))['values']
    print(label, values.mean().round(2), values.std().round(2))


print_metric_summary('RMSE:', diagnostics.root_mean_squared_error, global_posterior, true_global)
print_metric_summary('RMSE Local:', diagnostics.root_mean_squared_error, local_posterior, true_local)

print_metric_summary('Contraction:', diagnostics.posterior_contraction, global_posterior, true_global)
print_metric_summary('Contraction Local:', diagnostics.posterior_contraction, local_posterior, true_local)

In [None]:
# Plot recovery only for the first few local parameters to keep the figure readable.
n_show = 10
diagnostics.recovery(
    local_posterior[..., :n_show],
    true_local[:, :n_show],
    variable_names=local_param_names[:n_show],
);

In [None]:
# Pool all local parameters of all datasets into a single variable eta:
# (datasets, draws, local params) -> (datasets * local params, draws, 1).
# NOTE(review): assumes local_posterior is laid out as (datasets, draws, local params),
# consistent with `local_posterior[:, :, :10]` being labelled with local parameter names.
# The shapes are read from the array rather than hard-coding 100 datasets and N locals.
n_sets, n_draws, n_eta = local_posterior.shape
pooled_posterior = np.transpose(local_posterior, (0, 2, 1)).reshape(n_sets * n_eta, n_draws, 1)
# variable_names is a sequence of names; a bare string would be split into characters.
diagnostics.recovery(pooled_posterior, true_local.reshape(-1, 1), variable_names=[r'$\eta$']);