# Hierarchical AR(1) on a Grid: Test with Compositional Score Matching

In this notebook, we will test the compositional score matching on a hierarchical problem defined on a grid.
- The observations are on a grid with `n_grid` x `n_grid` points.
- The global parameters are the same for all grid points with hyper-priors:
$$ \alpha \sim \mathcal{N}(0, 1) \quad
  \mu_\beta \sim \mathcal{N}(0, 1) \quad
  \log\text{std}_\beta \sim \mathcal{N}(0, 1);$$

- The local parameters are different for each grid point
$$ \eta_{i,j}^\text{raw} \sim \mathcal{N}(0, I), \qquad \eta_{i,j} = 2\operatorname{sigmoid}(\beta + \sigma\cdot\eta_{i,j}^\text{raw})-1$$

-  In each grid point, we have a time series of `T` observations.
$$ y_{i,j,t} \sim \mathcal{N}(\alpha + \eta_{i,j}\,y_{i,j,t-1},\ 0.1 I), \qquad y_{i,j,0} \sim \mathcal{N}(0, 0.1 I),$$
  where $(i,j)$ indexes the grid point and $t$ the time step.
- We observe $T=5$ time points for each grid point. We can also amortize over the time dimension.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

os.environ['KERAS_BACKEND'] = 'torch'
from bayesflow import diagnostics

from torch.utils.data import DataLoader

from diffusion_model import HierarchicalScoreModel, SDE, ShallowSet, euler_maruyama_sampling, adaptive_sampling, \
    train_score_model, probability_ode_solving
from problems.ar1_grid import AR1GridProblem, Prior
from problems import visualize_simulation_output

In [None]:
# Pick the compute device by availability instead of hard-coding Apple's "mps" backend,
# so the notebook also runs on CUDA or CPU-only machines. "mps" stays first so the
# behaviour on Apple-Silicon hardware is unchanged.
_mps_backend = getattr(torch.backends, "mps", None)  # absent in older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    torch_device = torch.device("mps")
elif torch.cuda.is_available():
    torch_device = torch.device("cuda")
else:
    torch_device = torch.device("cpu")

In [None]:
# The prior object is used for sampling here and also exposes the simulator further below.
prior = Prior()

# Smoke test: draw a single data set with 16 local samples (get_grid=True) and plot it.
sim_batch = prior.sample(1, n_local_samples=16, get_grid=True)
sim_test = sim_batch['data'][0]
visualize_simulation_output(np.array(sim_test))

In [None]:
# --- Training configuration -------------------------------------------------
batch_size = 128
number_of_obs = 1  # [1, 4, 8, 16, 64, 128]  # or a list
# The model needs the largest number of observations it may see.
if isinstance(number_of_obs, int):
    max_number_of_obs = number_of_obs
else:
    max_number_of_obs = max(number_of_obs)

current_sde = SDE(kernel_type='variance_preserving', noise_schedule='cosine')

# --- Data sets ---------------------------------------------------------------
# Training set: online learning enabled, no amortization over time.
dataset = AR1GridProblem(
    n_data=10000, prior=prior, sde=current_sde, online_learning=True,
    number_of_obs=number_of_obs, amortize_time=False, as_set=True,
)

# Validation set: `online_learning` and `amortize_time` are left at the
# AR1GridProblem defaults here (NOTE(review): confirm these defaults match the training set-up).
dataset_valid = AR1GridProblem(
    n_data=1000, prior=prior, sde=current_sde,
    number_of_obs=number_of_obs, as_set=True,
)

# --- Data loaders --------------------------------------------------------------
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False)

# Sanity check: print the shape of element 4 of the first training batch only.
for test in dataloader:
    print(test[4].shape)
    break

In [None]:
# Define diffusion model
global_summary_dim = 5
global_summary_net = ShallowSet(dim_input=5, dim_output=global_summary_dim, dim_hidden=128)

score_model = HierarchicalScoreModel(
    input_dim_theta_global=prior.n_params_global,
    input_dim_theta_local=prior.n_params_local,
    input_dim_x_global=global_summary_dim,
    input_dim_x_local=global_summary_dim,
    # The summary network is only handed to the model when several observation counts are used.
    global_summary_net=global_summary_net if isinstance(number_of_obs, list) else None,
    hidden_dim=256,
    n_blocks=5,
    dropout_rate=0.1,
    max_number_of_obs=max_number_of_obs,
    prediction_type='v',
    sde=current_sde,
    weighting_type='likelihood_weighting',
    prior=prior,
    name_prefix=f'ar1_5_{max_number_of_obs}',
)

# Make the directory for plots. `exist_ok=True` replaces the separate existence check,
# so re-running the cell is a no-op and there is no check-then-create race.
os.makedirs(f"plots/{score_model.name}", exist_ok=True)

In [None]:
# Checkpoint location for this model configuration.
model_path = f"models/{score_model.name}.pt"

if not os.path.exists(model_path):
    # Create the checkpoint directory *before* the long training run: previously nothing
    # created `models/`, so `torch.save` could fail only after all epochs had finished.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # train model
    loss_history = train_score_model(score_model, dataloader, dataloader_valid=dataloader_valid, hierarchical=True,
                                     epochs=3000, device=torch_device)
    score_model.eval()
    torch.save(score_model.state_dict(), model_path)

    # plot loss history (column 0 is plotted as training loss, column 1 as validation loss)
    plt.figure(figsize=(16, 4), tight_layout=True)
    plt.plot(loss_history[:, 0], label='Training', color="#132a70", lw=2.0, alpha=0.9)
    plt.plot(loss_history[:, 1], label='Validation', linestyle="--", marker="o", color='black')
    plt.grid(alpha=0.5)
    plt.xlabel('Training epoch #')
    plt.ylabel('Value')
    plt.legend()
    plt.savefig(f'plots/{score_model.name}/loss_training.png')
else:
    # A checkpoint already exists: load the weights instead of retraining.
    score_model.load_state_dict(torch.load(model_path, map_location=torch_device, weights_only=True))
    score_model.eval()

# Validation

In [None]:
# Validation set-up: 100 data sets, each with n_grid * n_grid local samples.
n_grid = 128
n_local_valid = n_grid * n_grid
prior_dict = prior.sample(batch_size=100, n_local_samples=n_local_valid)

# Unpack ground truth and simulated data.
valid_prior_global = prior_dict['global_params']
valid_prior_local = prior_dict['local_params']
valid_data = prior_dict['data']

n_post_samples = 300
global_param_names = prior.global_param_names
local_param_names = prior.get_local_param_names(n_local_valid)

# Sampling below is done with a single observation per model call.
score_model.current_number_of_obs = 1
print(valid_data.shape, score_model.current_number_of_obs)

In [None]:
# First validation data set after `prior.normalize_data`, reshaped to (5, n_grid, n_grid) for plotting.
first_valid_normalized = prior.normalize_data(valid_data[0])
visualize_simulation_output(first_valid_normalized.reshape(5, n_grid, n_grid).numpy())

In [None]:
# Same data set without normalization, reshaped to (5, n_grid, n_grid) for plotting.
first_valid_raw = valid_data[0]
visualize_simulation_output(first_valid_raw.reshape(5, n_grid, n_grid).numpy())

In [None]:
# End points of the exponential damping schedule: the lambda below evaluates to
# t0_value at t=0 and to t1_value at t=1.
t0_value, t1_value = 1, 0.044240666535491316

sampling_arg = {
    'size': 2,
    # Exponential interpolation between t0_value and t1_value.
    # (Alternative tried before: a constant factor, (1-torch.ones_like(t)) * 1e-5 + 1e-3.)
    'damping_factor': lambda t: t0_value * torch.exp(-np.log(t0_value / t1_value) * t),
    'MC-dropout': False,
}

# No shift of the cosine schedule for validation (value tried before: 3.71213313557092-2).
score_model.sde.s_shift_cosine = 0

In [None]:
# Draw global posterior samples for every validation data set with the Euler-Maruyama sampler.
global_sampling_kwargs_valid = dict(
    n_post_samples=n_post_samples,
    sampling_arg=sampling_arg,
    diffusion_steps=300,
    device=torch_device,
    verbose=True,
)
posterior_global_samples_valid = euler_maruyama_sampling(score_model, valid_data, **global_sampling_kwargs_valid)

In [None]:
# Ground-truth global parameters as a NumPy array, shared by all diagnostics in this cell.
valid_global_targets = np.array(valid_prior_global)

# Parameter recovery and per-parameter RMSE.
fig = diagnostics.recovery(
    posterior_global_samples_valid, valid_global_targets,
    variable_names=global_param_names, figsize=(6, 2.5),
)
rmse_global_valid = diagnostics.root_mean_squared_error(posterior_global_samples_valid, valid_global_targets)
print('RMSE:', rmse_global_valid['values'])
#fig.savefig(f'plots/{score_model.name}/recovery_n_grid_{n_grid}.pdf')

# Calibration (ECDF difference plot, stacked) and per-parameter calibration error.
fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid, valid_global_targets,
    difference=True, variable_names=global_param_names, figsize=(6, 2.5), stacked=True,
)
# Create the legend on the current axes and hide it straight away.
leg = plt.legend()
leg.set_visible(False)
calibration_global_valid = diagnostics.calibration_error(posterior_global_samples_valid, valid_global_targets)
print('ECDF:', calibration_global_valid['values'])
#fig.savefig(f'plots/{score_model.name}/ecdf_n_grid_{n_grid}.pdf')

In [None]:
# Reset sampler-related model state before local sampling.
score_model.sde.s_shift_cosine = 0
score_model.current_number_of_obs = 1

# Sample the raw local parameters for the first 12 entries along axis 1 of the data,
# conditioned on the global posterior draws.
eta_raw_samples_valid = euler_maruyama_sampling(
    score_model, valid_data[:, :12],
    conditions=posterior_global_samples_valid,
    n_post_samples=n_post_samples,
    diffusion_steps=100, device=torch_device,
    verbose=True,
)

# Map the raw draws to the local parameters, using index 1 of the global samples as `beta`.
posterior_local_samples_valid = score_model.prior.transform_local_params(
    beta=posterior_global_samples_valid[..., 1:2],
    eta_raw=eta_raw_samples_valid[..., 0],
)

In [None]:
# Recovery plot for the 12 sampled local parameters.
local_samples_valid_flat = posterior_local_samples_valid.reshape(valid_data.shape[0], n_post_samples, -1)
local_targets_valid = np.array(valid_prior_local)[:, :12]
diagnostics.recovery(local_samples_valid_flat, local_targets_valid, variable_names=local_param_names[:12]);

In [None]:
# Summary table (mean and std over parameters, rounded to 2 decimals) for the validation set.
# Each diagnostic is evaluated once and reused for both statistics; previously every
# metric and every reshape was computed twice. The printed output is unchanged.
summary_global_estimates = posterior_global_samples_valid
summary_global_targets = np.array(valid_prior_global)
summary_local_estimates = posterior_local_samples_valid.reshape(
    posterior_global_samples_valid.shape[0], n_post_samples, -1)
summary_local_targets = np.array(valid_prior_local)[:, :12]  # only 12 local parameters were sampled

for metric_label, metric_fn in (('RMSE', diagnostics.root_mean_squared_error),
                                ('Contraction', diagnostics.posterior_contraction)):
    for level_suffix, estimates, targets in (('', summary_global_estimates, summary_global_targets),
                                             (' Local', summary_local_estimates, summary_local_targets)):
        metric_values = metric_fn(estimates, targets)['values']
        print(f'{metric_label}{level_suffix}:', metric_values.mean().round(2), metric_values.std().round(2))

# Compare to STAN

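First, you need to run the notebook `ar(1) STAN.ipynb` to generate the STAN posterior samples. The STAN reference samples are only loaded for grids with at most $32 \times 32$ points; for larger grids the validation data from above is reused as test data.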

In [None]:
# Number of grid points used for the comparison; the trailing index selects the configuration.
N = [4*4, 32*32, 128*128][2]
if N > 32*32:
    # This branch loads no STAN samples: reuse the validation draws as test data.
    test_data = valid_data
    true_global = np.array(valid_prior_global)
    true_local = np.array(valid_prior_local)
    # Reset the STAN arrays explicitly so that samples left over from an earlier run with a
    # smaller N can never be compared against this (different) test data.
    global_posterior_stan = None
    local_posterior_stan = None
else:
    # Posterior samples and ground truth written by the STAN notebook.
    global_posterior_stan = np.load(f'problems/ar1/global_posterior_{N}.npy')
    local_posterior_stan = np.load(f'problems/ar1/local_posterior_{N}.npy')
    true_global = np.load(f'problems/ar1/true_global_{N}.npy')
    true_local = np.load(f'problems/ar1/true_local_{N}.npy')

    # Re-simulate the observations from the stored ground-truth parameters.
    n_grid_stan = int(np.sqrt(true_local.shape[1]))
    test_data = []
    for g, l in zip(true_global, true_local):
        sim_dict = {'alpha': g[0],
                    'eta': l}
        td = prior.simulator(sim_dict)['observable']
        test_data.append(td.reshape(1, n_grid_stan*n_grid_stan, 5))
    test_data = np.concatenate(test_data)

    n_obs = n_grid_stan*n_grid_stan
    batch_size = test_data.shape[0]
    n_post_samples = 300

# Grid side length implied by the number of local parameters (valid for both branches).
n_grid_stan = int(np.sqrt(true_local.shape[1]))
print(n_grid_stan*n_grid_stan, test_data.shape)

In [None]:
import optuna

def decay(t, d0, d1, alpha, beta):
    """Shaped interpolation from ``d0`` (at ``t = 0``) to ``d1`` (at ``t = 1``).

    The interpolation weight is ``1 - (1 - t**alpha)**beta``; ``alpha`` and ``beta``
    control the curvature of the transition.

    Args:
        t: torch tensor of time points.
        d0: value returned where the weight is 0.
        d1: value returned where the weight is 1.
        alpha: exponent applied to ``t``.
        beta: exponent applied to ``1 - t**alpha``.

    Returns:
        Tensor with the same shape as ``t``.
    """
    remainder = torch.pow(1 - torch.pow(t, alpha), beta)
    weight = 1 - remainder
    span = d1 - d0
    return d0 + span * weight

def exponential_decay(t, d0, d1):
    """Exponential interpolation ``d0 * exp(-log(d0 / d1) * t)``.

    Evaluates to ``d0`` at ``t = 0`` and to ``d1`` at ``t = 1``.

    Args:
        t: torch tensor of time points.
        d0: value at ``t = 0``.
        d1: value at ``t = 1``.

    Returns:
        The result of multiplying ``d0`` with ``torch.exp`` of the scaled time.
    """
    log_ratio = np.log(d0 / d1)
    attenuation = torch.exp(-log_ratio * t)
    return d0 * attenuation

def linear_decay(t, d0, d1):
    """Linear interpolation from ``d0`` (at ``t = 0``) to ``d1`` (at ``t = 1``).

    Both end points are converted to tensors with the dtype and device of ``t``
    and then interpolated with ``torch.lerp`` using ``t`` as the weight.

    Args:
        t: torch tensor of interpolation weights.
        d0: start value.
        d1: end value.

    Returns:
        The tensor returned by ``torch.lerp``.
    """
    start, end = (torch.as_tensor(value, dtype=t.dtype, device=t.device) for value in (d0, d1))
    return torch.lerp(start, end, t)

def cosine_decay(t, d0, d1):
    """Half-cosine interpolation from ``d0`` (at ``t = 0``) to ``d1`` (at ``t = 1``).

    Computes ``d1 + 0.5 * (d0 - d1) * (1 + cos(pi * t))``.

    Args:
        t: torch tensor of time points.
        d0: value at ``t = 0``.
        d1: value at ``t = 1``.

    Returns:
        Tensor with the same shape as ``t``.
    """
    half_range = 0.5 * (d0 - d1)
    cosine_ramp = 1 + torch.cos(torch.pi * t)
    return d1 + half_range * cosine_ramp


def objective(trial):
    """Optuna objective for tuning the damping schedule and the cosine shift.

    Suggests the parameters of ``decay`` and ``s_shift_cosine``, draws global
    posterior samples for ``test_data`` and scores them against ``true_global``.
    Note that the function overwrites ``score_model.sde.s_shift_cosine``.

    Args:
        trial: Optuna trial used to suggest the hyper-parameters.

    Returns:
        Mean calibration error plus mean RMSE, or a penalty value when the
        sampler returned a 0-dimensional result instead of samples.
    """
    # Search ranges. Wider ranges tried before: t0_value in [1e-3, 1], t1_value in [1e-5, 1e-1].
    d0 = trial.suggest_float('t0_value', 1e-3, 1e-2)
    d1 = trial.suggest_float('t1_value', 1e-5, 1e-3)
    s_shift_cosine = trial.suggest_float('s_shift_cosine', 0, 4)
    alpha = trial.suggest_float('alpha', 0.3, 2)
    beta = trial.suggest_float('beta', 0.3, 2)

    def damping_factor(t):
        # Shaped decay with the parameters suggested for this trial.
        return decay(t, d0=d0, d1=d1, alpha=alpha, beta=beta)

    sampling_arg = {'size': 2, 'damping_factor': damping_factor}
    score_model.sde.s_shift_cosine = s_shift_cosine

    test_global_samples = euler_maruyama_sampling(
        score_model, test_data,
        n_post_samples=n_post_samples,
        sampling_arg=sampling_arg,
        diffusion_steps=300,
        device=torch_device, verbose=False, return_time=True,
    )

    if test_global_samples.ndim == 0:
        # A scalar means the time was returned instead of samples because of an error:
        # penalize the trial, but keep the value informative.
        return 10 + float(test_global_samples) * 10

    rmse = diagnostics.root_mean_squared_error(test_global_samples, true_global)['values'].mean()
    cerror = diagnostics.calibration_error(test_global_samples, true_global)['values'].mean()
    return cerror + rmse

# Minimize the objective over 100 trials ('minimize' is Optuna's default direction, spelled out here).
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=100)

# Show the best hyper-parameters as the cell output.
study.best_params

In [None]:
# no damping factor for N=16

In [None]:
# Plot the tuned decay function next to the alternative schedules with the same end points.
best = study.best_params
d0_best, d1_best = best['t0_value'], best['t1_value']

t = torch.linspace(0, 1, 100)
dt = decay(t, d0=d0_best, d1=d1_best, alpha=best['alpha'], beta=best['beta'])
dt_exp = exponential_decay(t, d0=d0_best, d1=d1_best)
dt_linear = linear_decay(t, d0=d0_best, d1=d1_best)
dt_cosine = cosine_decay(t, d0=d0_best, d1=d1_best)

schedules = (
    (dt, 'Decay function'),
    (dt_exp, 'Exponential decay function'),
    (dt_linear, 'Linear decay function'),
    (dt_cosine, 'Cosine decay function'),
)
for values, label in schedules:
    plt.plot(t, values, label=label)
plt.legend()
plt.xlabel('t')
plt.ylabel('Damping factor')
plt.show()

print(best)

In [None]:
# Sampling configuration for the test data, using the tuned damping schedule.
# The best parameters are snapshotted once: previously the lambda looked up
# `study.best_params` four times on every call and would silently change if the
# study was extended after this cell had run.
best_params = dict(study.best_params)

sampling_arg = {
    'size': 2,
    # `_p` binds the snapshot at definition time; the sampler still calls this with `t` only.
    'damping_factor': lambda t, _p=best_params: decay(t, d0=_p['t0_value'], d1=_p['t1_value'],
                                                       alpha=_p['alpha'], beta=_p['beta']),
    'MC-dropout': False
}
global_param_names = prior.global_param_names
local_param_names = prior.get_local_param_names(n_grid_stan*n_grid_stan)
param_names_stan = ['STAN '+ p for p in global_param_names]
score_model.sde.s_shift_cosine = best_params['s_shift_cosine']

print(sampling_arg)

In [None]:
# Draw global posterior samples for the test data with the tuned sampling configuration.
global_sampling_kwargs_test = dict(
    n_post_samples=n_post_samples,
    sampling_arg=sampling_arg,
    diffusion_steps=300,
    device=torch_device,
    verbose=True,
)
posterior_global_samples_test = euler_maruyama_sampling(score_model, test_data, **global_sampling_kwargs_test)

In [None]:
# Recovery of the global parameters with our sampler.
fig = diagnostics.recovery(posterior_global_samples_test, true_global, variable_names=global_param_names)
#fig.savefig(f'plots/{score_model.name}/recovery_global_ours.png')
print('RMSE:', diagnostics.root_mean_squared_error(posterior_global_samples_test, true_global)['values'].mean())

# The STAN reference samples are only loaded for N <= 32*32; without this guard the
# comparison raised a NameError for larger grids.
if N <= 32*32:
    fig = diagnostics.recovery(posterior_global_samples_test, np.median(global_posterior_stan, axis=1),
                         variable_names=global_param_names, xlabel='STAN Median Estimate')
    #fig.savefig(f'plots/{score_model.name}/recovery_global_ours_vs_STAN.png')
    fig = diagnostics.recovery(global_posterior_stan, true_global, variable_names=param_names_stan)
    #fig.savefig(f'plots/{score_model.name}/recovery_global_STAN.png')
    print('RMSE STAN:', diagnostics.root_mean_squared_error(global_posterior_stan, true_global)['values'].mean())
else:
    print(f'No STAN reference samples for N={N}; skipping the STAN comparison.')

In [None]:
# Calibration of the global parameters with our sampler.
fig = diagnostics.calibration_ecdf(posterior_global_samples_test, true_global, difference=True,
                             variable_names=global_param_names)
#fig.savefig(f'plots/{score_model.name}/ecdf_global_ours.png')
print('ECDF:', diagnostics.calibration_error(posterior_global_samples_test, true_global)['values'].mean())

# The STAN reference samples are only loaded for N <= 32*32; without this guard the
# comparison raised a NameError for larger grids.
if N <= 32*32:
    fig = diagnostics.calibration_ecdf(global_posterior_stan, true_global, difference=True, variable_names=param_names_stan)
    #fig.savefig(f'plots/{score_model.name}/ecdf_global_STAN.png')
    print('ECDF STAN:', diagnostics.calibration_error(global_posterior_stan, true_global)['values'].mean())
else:
    print(f'No STAN reference samples for N={N}; skipping the STAN comparison.')

In [None]:
# Reset sampler-related model state before local sampling.
score_model.sde.s_shift_cosine = 0
score_model.current_number_of_obs = 1

# Sample the raw local parameters for the full test data (no [:, :12] subset here),
# conditioned on the global posterior draws.
eta_raw_samples_test = euler_maruyama_sampling(
    score_model, test_data,
    n_post_samples=n_post_samples,
    conditions=posterior_global_samples_test,
    diffusion_steps=100,
    device=torch_device, verbose=True,
)

# Map the raw draws to the local parameters, using index 1 of the global samples as `beta`.
posterior_local_samples_test = score_model.prior.transform_local_params(
    beta=posterior_global_samples_test[..., 1:2],
    eta_raw=eta_raw_samples_test[..., 0],
)

In [None]:
# Summary table (mean and std over parameters, rounded to 2 decimals) for the test set.
# Each diagnostic is evaluated once and reused for both statistics; previously every
# metric and every reshape was computed twice. The printed output is unchanged.
test_global_estimates = posterior_global_samples_test
test_global_targets = np.array(true_global)
test_local_estimates = posterior_local_samples_test.reshape(
    posterior_global_samples_test.shape[0], n_post_samples, -1)
test_local_targets = np.array(true_local)

for metric_label, metric_fn in (('RMSE', diagnostics.root_mean_squared_error),
                                ('Contraction', diagnostics.posterior_contraction)):
    for level_suffix, estimates, targets in (('', test_global_estimates, test_global_targets),
                                             (' Local', test_local_estimates, test_local_targets)):
        metric_values = metric_fn(estimates, targets)['values']
        print(f'{metric_label}{level_suffix}:', metric_values.mean().round(2), metric_values.std().round(2))

In [None]:
# Recovery of the first 12 local parameters with our sampler.
diagnostics.recovery(posterior_local_samples_test.reshape(test_data.shape[0], n_post_samples, -1)[:, :, :12],
                     true_local[:, :12],
                     variable_names=local_param_names[:12]);

# The STAN reference samples are only loaded for N <= 32*32; without this guard the
# comparison raised a NameError for larger grids.
if N <= 32*32:
    diagnostics.recovery(local_posterior_stan[:, :, :12], true_local[:, :12],
                         variable_names=local_param_names[:12]);
else:
    print(f'No STAN reference samples for N={N}; skipping the STAN comparison.')