# Hierarchical AR(1) on a Grid Test with compositional score matching

In this notebook, we will test the compositional score matching on a hierarchical problem defined on a grid.
- The observations are on a grid with `n_grid` x `n_grid` points.
- The global parameters are the same for all grid points with hyper-priors:
$$ \alpha \sim \mathcal{N}(0, 1) \quad
  \mu_\beta \sim \mathcal{N}(0, 1) \quad
  \log\text{std}_\beta \sim \mathcal{N}(0, 1);$$

- The local parameters are different for each grid point
$$ \eta_{i,j}^\text{raw} \sim \mathcal{N}(0, I), \qquad \eta_{i,j} = 2\operatorname{sigmoid}(\beta + \sigma\cdot\eta_{i,j}^\text{raw})-1$$

-  In each grid point, we have a time series of `T` observations.
$$ y_{i,j,t} \sim \mathcal{N}(\alpha + \eta_{i,j}\, y_{i,j,t-1}, 0.1 I), \qquad y_{i,j,0} \sim \mathcal{N}(0, 0.1 I)$$
- We observe $T=5$ time points for each grid point. We can also amortize over the time dimension.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import median_abs_deviation as mad
import torch

os.environ['KERAS_BACKEND'] = 'torch'
from bayesflow import diagnostics

from torch.utils.data import DataLoader

from diffusion_model import HierarchicalScoreModel, SDE, ShallowSet, euler_maruyama_sampling, adaptive_sampling, \
    train_score_model
from problems import AR1GridProblem, AR1GridPrior
from problems import visualize_simulation_output

In [None]:
# Select the compute device used for training and sampling.
# Prefer Apple's MPS backend (the original setting), but fall back to CUDA and
# then CPU so the notebook also runs on machines without MPS support.
if torch.backends.mps.is_available():
    torch_device = torch.device("mps")
elif torch.cuda.is_available():
    torch_device = torch.device("cuda")
else:
    torch_device = torch.device("cpu")

In [None]:
# Prior object that is later handed to the datasets and to the score model.
prior = AR1GridPrior()

# Smoke-test the simulator: one draw with 16 local samples in grid layout,
# then plot the first entry of its 'data' field.
sim_draw = prior.sample(1, n_local_samples=16, get_grid=True)
sim_test = sim_draw['data'][0]
visualize_simulation_output(np.array(sim_test))

In [None]:
# Training configuration.
batch_size = 128
number_of_obs = 1

# `number_of_obs` is either a single int or a collection of ints; keep its largest value.
if isinstance(number_of_obs, int):
    max_number_of_obs = number_of_obs
else:
    max_number_of_obs = max(number_of_obs)

current_sde = SDE(kernel_type='variance_preserving', noise_schedule='cosine')

# Keyword arguments passed identically to the training and validation datasets.
shared_dataset_kwargs = dict(
    prior=prior,
    sde=current_sde,
    number_of_obs=number_of_obs,
    as_set=True,
)

# Training set: additionally fixes `online_learning` and `amortize_time` explicitly.
dataset = AR1GridProblem(
    n_data=10000,
    online_learning=False,
    amortize_time=False,
    **shared_dataset_kwargs,
)

# Validation set: relies on the class defaults for the remaining options.
dataset_valid = AR1GridProblem(n_data=1000, **shared_dataset_kwargs)

# Only the training loader shuffles its batches.
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
dataloader_valid = DataLoader(dataset_valid, shuffle=False, batch_size=batch_size)

# Sanity check: print the shape of element 4 of the first training batch.
for test in dataloader:
    print(test[4].shape)
    break

In [None]:
# Summary network: 5 input features mapped to `global_summary_dim` outputs.
global_summary_dim = 5
global_summary_net = ShallowSet(dim_input=5, dim_output=global_summary_dim, dim_hidden=128)

# The summary network is only handed to the score model when `number_of_obs` is a list.
if isinstance(number_of_obs, list):
    summary_net_arg = global_summary_net
else:
    summary_net_arg = None

# Hierarchical score model; its name is prefixed with the maximum number of observations.
score_model = HierarchicalScoreModel(
    # parameter and conditioning dimensions
    input_dim_theta_global=prior.n_params_global,
    input_dim_theta_local=prior.n_params_local,
    input_dim_x_global=global_summary_dim,
    input_dim_x_local=global_summary_dim,
    global_summary_net=summary_net_arg,
    # network size
    hidden_dim=256,
    n_blocks=5,
    dropout_rate=0.1,
    max_number_of_obs=max_number_of_obs,
    # diffusion setup
    prediction_type='v',
    sde=current_sde,
    weighting_type='likelihood_weighting',
    prior=prior,
    name_prefix=f'ar1_{max_number_of_obs}',
)

In [None]:
# Train the score model once and cache its weights; later runs load the checkpoint.
model_dir = "models/ar1"
model_path = f"{model_dir}/{score_model.name}.pt"

if not os.path.exists(model_path):
    # Create the output directory *before* training: otherwise torch.save / savefig
    # would only fail after the full training run, losing the trained weights.
    os.makedirs(model_dir, exist_ok=True)

    # train model
    loss_history = train_score_model(score_model, dataloader,
                                     dataloader_valid=dataloader_valid, hierarchical=True,
                                     epochs=1000, device=torch_device)
    score_model.eval()
    torch.save(score_model.state_dict(), model_path)

    # plot loss history (column 0: training, column 1: validation)
    plt.figure(figsize=(16, 4), tight_layout=True)
    plt.plot(loss_history[:, 0], label='Training', color="#132a70", lw=2.0, alpha=0.9)
    plt.plot(loss_history[:, 1], label='Validation', linestyle="--", marker="o", color='black')
    plt.grid(alpha=0.5)
    plt.xlabel('Training epoch #')
    plt.ylabel('Value')
    plt.legend()
    plt.savefig(f'{model_dir}/{score_model.name}_loss_training.png')
else:
    # Load cached weights directly onto the selected device.
    score_model.load_state_dict(torch.load(model_path,
                                           map_location=torch_device, weights_only=True))
    score_model.eval()

# Validation

In [None]:
# Pick the validation grid size from the available options (index 1 selects 16).
grid_size_options = [4, 16, 128]
n_grid = grid_size_options[1]
n_local = n_grid * n_grid
print(f'Grid size: {n_grid}x{n_grid}')

# Draw 100 validation data sets, each with one local sample per grid point.
prior_dict = prior.sample(batch_size=100, n_local_samples=n_local)
valid_prior_global = prior_dict['global_params']
valid_prior_local = prior_dict['local_params']
valid_data = prior_dict['data']

# Number of posterior draws per data set, and parameter names for the plots.
n_post_samples = 300
global_param_names = prior.global_param_names
local_param_names = prior.get_local_param_names(n_local)

score_model.current_number_of_obs = 1
print(valid_data.shape, score_model.current_number_of_obs)

In [None]:
# Plot the first validation data set, reshaped to (5, n_grid, n_grid).
# NOTE(review): a plain reshape only orders the entries correctly if the flattened
# layout of `valid_data[0]` already matches (5, n_grid, n_grid) — confirm against
# what `prior.sample` returns when `get_grid` is not set.
first_valid_grid = valid_data[0].reshape(5, n_grid, n_grid)
visualize_simulation_output(first_valid_grid.numpy())

In [None]:
# End points of the damping schedule: `t0_value` at t=0 and `t1_value` at t=1.
t0_value = 0.95
t1_value = 1 / np.sqrt(2 * n_grid * n_grid)
print(t1_value, t0_value)


def exponential_damping(t):
    """Exponentially interpolate from `t0_value` (t=0) to `t1_value` (t=1)."""
    return t0_value * torch.exp(-np.log(t0_value / t1_value) * t)


sampling_arg = {'size': 2, 'damping_factor': exponential_damping}
score_model.sde.s_shift_cosine = 0.03

# Draw global posterior samples for every validation data set.
posterior_global_samples_valid = adaptive_sampling(
    score_model,
    valid_data,
    n_post_samples=n_post_samples,
    sampling_arg=sampling_arg,
    device=torch_device,
    verbose=True,
)

In [None]:
# Diagnostics for the global posterior samples: recovery plot + RMSE, then the
# calibration ECDF plot + calibration error. Both figures are written to disk.

# Make sure the figure directory exists; savefig raises if it is missing.
plot_dir = 'plots/ar1'
os.makedirs(plot_dir, exist_ok=True)

# Convert the ground-truth global parameters once instead of once per call.
global_targets = np.array(valid_prior_global)

fig = diagnostics.recovery(posterior_global_samples_valid, global_targets, variable_names=global_param_names,
                           figsize=(6, 2.5))
print('RMSE:', diagnostics.metrics.root_mean_squared_error(posterior_global_samples_valid, global_targets)['values'])
fig.savefig(f'{plot_dir}/recovery_global_n_grid_{n_grid}.pdf')

fig = diagnostics.calibration_ecdf(posterior_global_samples_valid, global_targets,
                                   difference=True, variable_names=global_param_names, figsize=(6, 2.5), stacked=True)
# Create the legend on the current axes and hide it so it does not appear in the figure.
leg = plt.legend()
leg.set_visible(False)
print('ECDF:', diagnostics.calibration_error(posterior_global_samples_valid, global_targets)['values'])
fig.savefig(f'{plot_dir}/ecdf_global_n_grid_{n_grid}.pdf')

In [None]:
# Sampler settings for the local parameters.
score_model.sde.s_shift_cosine = 0
score_model.current_number_of_obs = 1

# Use every entry along axis 1 of the validation data (the slice keeps all of them).
first_n_samples = valid_data.shape[1]
local_data = valid_data[:, :first_n_samples]

# Sample the raw local parameters, conditioned on the global posterior draws.
eta_raw_samples = euler_maruyama_sampling(
    score_model,
    local_data,
    conditions=posterior_global_samples_valid,
    n_post_samples=n_post_samples,
    diffusion_steps=100,
    device=torch_device,
    verbose=True,
)

# Transform the raw draws with the prior's own mapping; column 1 of the global
# samples is passed as `beta`.
posterior_local_samples_valid = score_model.prior.transform_local_params(
    beta=posterior_global_samples_valid[..., 1:2],
    eta_raw=eta_raw_samples[..., 0],
)

In [None]:
# Recovery plot for the local parameters: (data sets, posterior draws, parameters).
local_estimates = posterior_local_samples_valid.reshape(valid_data.shape[0], n_post_samples, -1)
local_targets = np.array(valid_prior_local)[:, :first_n_samples]
local_names = local_param_names[:first_n_samples]
diagnostics.recovery(local_estimates, local_targets, variable_names=local_names);

In [None]:
def _median_and_mad(metric_fn, estimates, targets):
    """Evaluate `metric_fn` with median and with MAD aggregation.

    For each aggregation, the metric's 'values' entry is averaged and rounded to
    two decimals. Returns the pair (median-aggregated, MAD-aggregated).
    """
    summary = []
    for aggregate in (np.median, mad):
        values = metric_fn(estimates, targets, aggregation=aggregate)['values']
        summary.append(values.mean().round(2))
    return tuple(summary)


# Ground truth and posterior draws in the layout expected by the diagnostics.
valid_prior_global_ = np.array(valid_prior_global)
posterior_local_samples_reshaped = posterior_local_samples_valid.reshape(
    posterior_global_samples_valid.shape[0], n_post_samples, -1
)
valid_prior_local_ = np.array(valid_prior_local)[:, :first_n_samples]

rmse_fn = diagnostics.metrics.root_mean_squared_error
contraction_fn = diagnostics.posterior_contraction

# One row per printed line: label, metric, posterior draws, ground truth.
metric_table = [
    ('Global RMSE:', rmse_fn, posterior_global_samples_valid, valid_prior_global_),
    ('Local RMSE:', rmse_fn, posterior_local_samples_reshaped, valid_prior_local_),
    ('Global Contraction:', contraction_fn, posterior_global_samples_valid, valid_prior_global_),
    ('Local Contraction:', contraction_fn, posterior_local_samples_reshaped, valid_prior_local_),
]

for label, metric_fn, estimates, targets in metric_table:
    center, spread = _median_and_mad(metric_fn, estimates, targets)
    print(label, center, spread)

# Inference-Time Hyperparameter Optimization

In [None]:
import optuna

def decay(t, d0, d1, alpha, beta):
    """Blend from `d0` to `d1` along a curve shaped by `alpha` and `beta`.

    The blend weight is ``1 - (1 - t**alpha)**beta``: it is 0 at ``t = 0`` (result
    `d0`) and 1 at ``t = 1`` (result `d1`). `t` must support ``.pow``-style
    exponentiation (e.g. a torch tensor).
    """
    blend = 1 - (1 - t ** alpha) ** beta
    return d0 + (d1 - d0) * blend

def objective(trial):
    """Optuna objective for the inference-time sampler settings.

    Suggests the damping end points, the cosine shift and the two shape exponents,
    draws global posterior samples for the validation data with those settings and
    returns mean RMSE plus mean calibration error. Returns ``np.inf`` if any sample
    is NaN. Side effect: overwrites ``score_model.sde.s_shift_cosine``.
    """
    # Search space (the suggestion order is kept as-is).
    t0_value = trial.suggest_float('t0_value', 1e-3, 1)
    t1_value = trial.suggest_float('t1_value', 1e-5, 1e-3)
    s_shift_cosine = trial.suggest_float('s_shift_cosine', 0, 2)
    alpha = trial.suggest_float('alpha', 0.3, 2)
    beta = trial.suggest_float('beta', 0.3, 2)

    def damping_factor(t):
        # Damping schedule built from this trial's suggested values.
        return decay(t, d0=t0_value, d1=t1_value, alpha=alpha, beta=beta)

    sampling_arg = {'size': 2, 'damping_factor': damping_factor}
    score_model.sde.s_shift_cosine = s_shift_cosine

    test_global_samples = adaptive_sampling(
        score_model,
        valid_data,
        n_post_samples=n_post_samples,
        sampling_arg=sampling_arg,
        max_evals=1000,
        device=torch_device,
        verbose=False,
    )

    # Diverged sampling runs get the worst possible score.
    if np.isnan(test_global_samples).any():
        return np.inf

    targets = np.array(valid_prior_global)
    rmse = diagnostics.metrics.root_mean_squared_error(test_global_samples, targets)['values'].mean()
    cerror = diagnostics.calibration_error(test_global_samples, targets)['values'].mean()
    return rmse + cerror

# Minimise the combined RMSE + calibration error over 20 trials
# ('minimize' is Optuna's default direction, spelled out here for the reader).
study = optuna.create_study(direction='minimize')
study.optimize(func=objective, n_trials=20)
study.best_params

In [None]:
# Plot the damping schedule implied by the best trial on 100 points in [0, 1].
best = study.best_params
t = torch.linspace(0, 1, 100)
dt = decay(
    t,
    d0=best['t0_value'],
    d1=best['t1_value'],
    alpha=best['alpha'],
    beta=best['beta'],
)

plt.plot(t, dt, label='Decay function')
plt.xlabel('t')
plt.ylabel('Damping factor')
plt.legend()
plt.show()

print(best)

In [None]:
# Re-run the global sampler with the best settings found by the study.

# Read the best parameters once. `study.best_params` is a property that is
# re-evaluated on every access, and the damping callback below is presumably
# invoked many times during sampling, so the lookup is hoisted out of it.
best_params = study.best_params


def best_damping(t):
    """Damping schedule using the best trial's end points and shape exponents."""
    return decay(t,
                 d0=best_params['t0_value'],
                 d1=best_params['t1_value'],
                 alpha=best_params['alpha'],
                 beta=best_params['beta'])


sampling_arg = {
    'size': 2,
    'damping_factor': best_damping,
}
score_model.sde.s_shift_cosine = best_params['s_shift_cosine']

posterior_global_samples_valid = adaptive_sampling(score_model, valid_data,
                                                   n_post_samples=n_post_samples,
                                                   sampling_arg=sampling_arg,
                                                   device=torch_device,
                                                   verbose=True)

In [None]:
# Diagnostics for the global samples drawn with the optimised sampler settings.
# The figures are not written to disk here: the earlier diagnostics cell already
# saves under the same file names in plots/ar1/.
optimized_targets = np.array(valid_prior_global)

fig = diagnostics.recovery(
    posterior_global_samples_valid,
    optimized_targets,
    variable_names=global_param_names,
    figsize=(6, 2.5),
)
rmse_values = diagnostics.metrics.root_mean_squared_error(posterior_global_samples_valid, optimized_targets)['values']
print('RMSE:', rmse_values)

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid,
    optimized_targets,
    difference=True,
    variable_names=global_param_names,
    figsize=(6, 2.5),
    stacked=True,
)
# Create the legend on the current axes, then hide it.
leg = plt.legend()
leg.set_visible(False)
ecdf_values = diagnostics.calibration_error(posterior_global_samples_valid, optimized_targets)['values']
print('ECDF:', ecdf_values)