# Hierarchical AR(1) on a Grid Test with Stan

In this notebook, we test compositional score matching on a hierarchical problem defined on a grid.
- The observations lie on a grid with `n_grid` x `n_grid` points.
- The global parameters are the same for all grid points with hyper-priors:
$$ \alpha \sim \mathcal{N}(0, 1) \quad
  \mu_\beta \sim \mathcal{N}(0, 1) \quad
  \log\text{std}_\beta \sim \mathcal{N}(0, 1);$$

- The local parameters are different for each grid point (with $\text{std}_\beta = \exp(\log\text{std}_\beta)$):
$$ \eta_{i,j}^\text{raw} \sim \mathcal{N}(0, I), \qquad \eta_{i,j} = 2\operatorname{sigmoid}(\mu_\beta + \text{std}_\beta\cdot\eta_{i,j}^\text{raw})-1$$

- At each grid point $(i,j)$, we have a time series of `T` observations indexed by $t$:
$$ y_{i,j,t} \sim \mathcal{N}(\alpha + \eta_{i,j}\,y_{i,j,t-1}, 0.1 I), \qquad y_{i,j,0} \sim \mathcal{N}(0, 0.1 I)$$
- We observe $T=5$ time points for each grid point. We can also amortize over the time dimension.

In [None]:
import os
os.environ['KERAS_BACKEND'] = 'torch'

import time
import numpy as np
from scipy.stats import median_abs_deviation as mad

from bayesflow import diagnostics
from joblib import Parallel, delayed

from experiments.problems import get_stan_posterior
from experiments.problems import AR1GridPrior
from experiments.problems import visualize_simulation_output

# Run configuration for the cells below.
# CPU budget: taken from SLURM when running as a cluster job, otherwise 10.
n_procs = int(os.getenv('SLURM_CPUS_PER_TASK', 10))
# n_datasets: how many datasets are simulated and fitted.
# n_samples: not referenced in the visible cells — presumably the number of
# posterior draws per dataset (TODO confirm against get_stan_posterior).
n_datasets, n_samples = 100, 1000

In [None]:
# Candidate problem sizes; the first entry is the one used in this run.
# N is passed to the prior as the number of local samples (grid points).
problem_sizes = (4 * 4, 16 * 16, 128 * 128)
N = problem_sizes[0]

prior = AR1GridPrior()

# Parameter labels, used later for plot titles.
global_param_names = prior.global_param_names
local_param_names = prior.get_local_param_names(N)

In [None]:
# Draw n_datasets independent simulations from the prior (one per call) and
# keep only the simulated data of each draw.
simulated_data = [
    prior.sample(1, n_local_samples=N)['data'][0]
    for _ in range(n_datasets)
]
results = np.stack(simulated_data, axis=0)

# Visualize the average over all simulated datasets.
visualize_simulation_output(np.mean(results, axis=0))

In [None]:
@delayed
def wrapper_fun(sample_i):
    """Simulate one dataset from the prior and compute its Stan posterior.

    The function is wrapped in ``joblib.delayed``, so ``wrapper_fun(i)`` only
    records the call; the body runs when the result is consumed by ``Parallel``.

    Parameters
    ----------
    sample_i : int
        Index of the dataset. It is used to derive a per-dataset random seed
        so that repeated runs simulate the same datasets.

    Returns
    -------
    tuple
        ``(global_posterior, local_posterior, true_global, true_local)``: the
        two posterior arrays returned by ``get_stan_posterior`` and the
        flattened ground-truth global and local parameters.
    """
    # A seed set in the notebook process is not inherited by joblib's worker
    # processes and does not touch torch's RNG, so seed explicitly here.
    # NOTE(review): assumes prior.sample draws from the global NumPy and/or
    # torch RNG — confirm against AR1GridPrior.
    import torch  # local import: torch is not imported at notebook level

    seed = 42 + int(sample_i)
    np.random.seed(seed)
    torch.manual_seed(seed)

    sample = prior.sample(1, n_local_samples=N)
    true_global = sample['global_params'].flatten().cpu().numpy()
    true_local = sample['local_params'].flatten().cpu().numpy()
    sim_test = sample['data'][0].cpu().numpy()
    global_posterior, local_posterior = get_stan_posterior(sim_test, sigma_noise=prior.simulator.sigma_noise)
    return global_posterior, local_posterior, true_global, true_local

In [None]:
# NOTE(review): this seeds NumPy's global RNG in the notebook process only;
# it does not seed torch, and joblib's default backend runs the jobs in
# separate worker processes.
np.random.seed(42)

# One .npy file per array, keyed by the problem size N.
cache_dir = 'metrics/ar1'
cache_paths = {
    name: f'{cache_dir}/{name}_{N}.npy'
    for name in ('global_posterior', 'local_posterior', 'true_global', 'true_local')
}

# Only reuse the cache when all four files are present; a partial cache would
# otherwise fail halfway through loading.
if all(os.path.exists(path) for path in cache_paths.values()):
    global_posterior = np.load(cache_paths['global_posterior'])
    local_posterior = np.load(cache_paths['local_posterior'])
    true_global = np.load(cache_paths['true_global'])
    true_local = np.load(cache_paths['true_local'])
else:
    # n_procs // 4 is 0 when fewer than four CPUs are available, and joblib
    # rejects n_jobs=0, so always request at least one worker.
    # NOTE(review): the division by 4 presumably leaves cores for Stan's
    # chains — confirm against get_stan_posterior.
    n_jobs = max(1, n_procs // 4)

    start = time.time()
    out_parallel = Parallel(n_jobs=n_jobs, verbose=1)(wrapper_fun(i) for i in range(n_datasets))
    end = time.time()
    print(end - start)  # wall-clock seconds for all fits

    # make numpy arrays (first axis indexes the dataset)
    global_posterior = np.stack([out[0] for out in out_parallel], axis=0)
    # swap the last two axes of the stacked local posteriors
    local_posterior = np.stack([out[1] for out in out_parallel], axis=0).transpose(0, 2, 1)
    true_global = np.stack([out[2] for out in out_parallel], axis=0)
    true_local = np.stack([out[3] for out in out_parallel], axis=0)

    # Create the output directory first so the expensive results are not lost
    # to a FileNotFoundError on a fresh checkout.
    os.makedirs(cache_dir, exist_ok=True)
    np.save(cache_paths['global_posterior'], global_posterior)
    np.save(cache_paths['local_posterior'], local_posterior)
    np.save(cache_paths['true_global'], true_global)
    np.save(cache_paths['true_local'], true_local)

# Plotting

In [None]:
# Make sure the output directory exists; savefig does not create it.
os.makedirs('plots/ar1', exist_ok=True)

# Recovery plot of the global parameters: posterior draws vs. ground truth.
fig = diagnostics.recovery(global_posterior, true_global, variable_names=global_param_names)
fig.savefig(f'plots/ar1/recovery_global_STAN_n_grid_{N}.png')

# Calibration ECDF of the global parameters (difference=True variant).
fig = diagnostics.calibration_ecdf(global_posterior, true_global, difference=True, variable_names=global_param_names)
fig.savefig(f'plots/ar1/ecdf_global_STAN_n_grid_{N}.png')

In [None]:
# One row per printed line: (label, metric function, posterior draws, ground truth).
metric_table = (
    ('Global RMSE:', diagnostics.metrics.root_mean_squared_error, global_posterior, true_global),
    ('Local RMSE:', diagnostics.metrics.root_mean_squared_error, local_posterior, true_local),
    ('Global Contraction:', diagnostics.posterior_contraction, global_posterior, true_global),
    ('Local Contraction:', diagnostics.posterior_contraction, local_posterior, true_local),
)

for label, metric_fn, draws, truth in metric_table:
    # Each metric is evaluated twice — once aggregated with the median and once
    # with the median absolute deviation — then averaged over the returned
    # 'values' and rounded to two decimals.
    # NOTE: the names global_rmse / global_rmse_mad are reused for every row
    # (including the local and contraction ones); after the loop they hold the
    # values of the last row.
    global_rmse = metric_fn(draws, np.array(truth), aggregation=np.median)['values'].mean().round(2)
    global_rmse_mad = metric_fn(draws, np.array(truth), aggregation=mad)['values'].mean().round(2)
    print(label, global_rmse, global_rmse_mad)

In [None]:
# Pool all local parameters into a single recovery panel: swap the last two
# axes of the local posterior, then merge the first two axes into one so that
# every (dataset, local parameter) pair becomes its own row.
local_draws = np.transpose(local_posterior, (0, 2, 1))
# Take the pooled size from the array itself rather than hard-coding 100
# datasets, so the cell stays correct when n_datasets or the cache changes.
n_pooled = local_draws.shape[0] * local_draws.shape[1]

# variable_names is passed as a one-element list (one pooled variable).
# NOTE(review): assumes diagnostics.recovery expects a sequence of names — confirm.
fig = diagnostics.recovery(local_draws.reshape(n_pooled, -1, 1),
                           true_local.reshape(-1, 1), variable_names=[r'$\eta$'])

# Make sure the output directory exists; savefig does not create it.
os.makedirs('plots/ar1', exist_ok=True)
fig.savefig(f'plots/ar1/recovery_local_STAN_n_grid_{N}.png')