# Flat Gaussian with compositional score matching


In this notebook, we use compositional score matching to learn the posterior of a flat Gaussian model.
The problem is defined as follows:
- The prior is a Gaussian distribution with mean 0 and standard deviation 0.1.
- The simulator/likelihood is a Gaussian distribution with mean equal to the parameters $\theta$ and standard deviation 0.1.
- We have an analytical solution for the posterior.
- We set the dimension of the problem to $D=10$.

In [None]:
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

os.environ['KERAS_BACKEND'] = 'torch'
from bayesflow import diagnostics
from torch.utils.data import DataLoader

from diffusion_model import ScoreModel, SDE, train_score_model, \
    adaptive_sampling, probability_ode_solving, langevin_sampling, euler_maruyama_sampling
from diffusion_model.helper_networks import GaussianFourierProjection, ShallowSet
from problems.gaussian_flat import GaussianProblem, Prior, Simulator, visualize_simulation_output, \
    generate_synthetic_data, sample_posterior

In [None]:
# Use the first available accelerator: Apple-silicon MPS, then CUDA, otherwise CPU,
# so the notebook also runs on machines that have no MPS backend.
if torch.backends.mps.is_available():
    torch_device = torch.device("mps")
elif torch.cuda.is_available():
    torch_device = torch.device("cuda")
else:
    torch_device = torch.device("cpu")

In [None]:
# Build the prior and a separate simulator instance used only for this sanity check.
prior = Prior()
simulator_test = Simulator()

# Sanity check: draw two parameter sets from the prior, simulate with n_obs=1000 and plot.
prior_test = prior.sample(2)
sim_test = simulator_test(prior_test, n_obs=1000)
observations_test = sim_test['observable']
visualize_simulation_output(observations_test)

In [None]:
batch_size = 128
# Set number_of_obs to a list such as [1, 100] to amortize over the number of conditions.
number_of_obs = 1
if isinstance(number_of_obs, list):
    max_number_of_obs = max(number_of_obs)
else:
    max_number_of_obs = number_of_obs

# Alternatives are kept in named lists so a different option can be picked by index.
kernel_options = ['variance_preserving', 'sub_variance_preserving']
schedule_options = ['linear', 'cosine', 'flow_matching', 'edm-training', 'edm-sampling']
current_sde = SDE(
    kernel_type=kernel_options[0],
    noise_schedule=schedule_options[1],
)

# Training data uses online_learning=True (presumably resampled on the fly — confirm in
# GaussianProblem); the validation set does not pass that flag.
dataset = GaussianProblem(
    n_data=10000,
    prior=prior,
    sde=current_sde,
    online_learning=True,
    number_of_obs=number_of_obs,
)
dataset_valid = GaussianProblem(
    n_data=1000,
    prior=prior,
    sde=current_sde,
    number_of_obs=number_of_obs,
)

# Create dataloaders: shuffled for training, fixed order for validation.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False)

# Peek at the first training batch and print the shapes of its first three entries.
first_batch = next(iter(dataloader))
for position in (0, 1, 2):
    print(first_batch[position].shape)

In [None]:
# Define diffusion model
time_embedding = nn.Sequential(
    GaussianFourierProjection(8),
    nn.Linear(8, 8),
    nn.Mish()
)
summary_dim = 10
summary_net = ShallowSet(dim_input=10, dim_output=summary_dim, dim_hidden=8) if isinstance(number_of_obs, list) else None

score_model = ScoreModel(
    input_dim_theta=prior.n_params_global,
    input_dim_x=summary_dim,
    summary_net=summary_net if isinstance(number_of_obs, list) else None,
    #time_embedding=time_embedding,
    hidden_dim=256,
    n_blocks=5,
    max_number_of_obs=max_number_of_obs,
    prediction_type=['score', 'e', 'x', 'v', 'F'][3],
    sde=current_sde,
    weighting_type=[None, 'likelihood_weighting', 'flow_matching', 'sigmoid', 'edm'][1],
    prior=prior,
    name_prefix=f'gaussian_flat0_{max_number_of_obs}'
)

# make dir for plots
if not os.path.exists(f"plots/{score_model.name}"):
    os.makedirs(f"plots/{score_model.name}")

In [None]:
# Train the score model, or load its checkpoint if one already exists for this configuration.
model_path = f"models/{score_model.name}.pt"
if not os.path.exists(model_path):
    # torch.save does not create parent directories. Create models/ before training so a
    # missing directory cannot make the save fail after the full training run has finished.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # train model
    loss_history = train_score_model(score_model, dataloader, dataloader_valid=dataloader_valid,
                                     epochs=1000, device=torch_device)
    score_model.eval()
    torch.save(score_model.state_dict(), model_path)

    # plot loss history (column 0: training loss, column 1: validation loss)
    plt.figure(figsize=(16, 4), tight_layout=True)
    plt.plot(loss_history[:, 0], label='Training', color="#132a70", lw=2.0, alpha=0.9)
    plt.plot(loss_history[:, 1], label='Validation', linestyle="--", marker="o", color='black')
    plt.grid(alpha=0.5)
    plt.xlabel('Training epoch #')
    plt.ylabel('Value')
    plt.legend()
    plt.savefig(f'plots/{score_model.name}/loss_training.png')
else:
    score_model.load_state_dict(torch.load(model_path, map_location=torch_device, weights_only=True))
    score_model.eval()

# Validation

In [None]:
data_size = 100  # number of observations
obs_n_time_steps = 0
# Synthetic validation data (n_samples=100 parameter draws, fixed seed for reproducibility).
valid_prior_global, valid_data = generate_synthetic_data(
    prior,
    n_samples=100,
    data_size=data_size,
    normalize=False,
    random_seed=0,
)
# LaTeX labels $D_{1}$ ... $D_{D}$ for the plotted parameters.
param_names = [f'$D_{{{i}}}$' for i in range(1, prior.D + 1)]
n_post_samples = 100

# Sampling-time settings on the trained model.
score_model.current_number_of_obs = 1
score_model.sde.s_shift_cosine = 0

In [None]:
def sample_posterior_single(vd):
    """Draw ``n_post_samples`` analytical-posterior samples for a single data set ``vd``.

    Uses the prior scale (``prior.scale``) and the simulator's scale
    (``prior.simulator.scale``). ``n_post_samples`` is read from the notebook
    globals at call time.
    """
    return sample_posterior(
        vd,
        prior_sigma=prior.scale,
        sigma=prior.simulator.scale,
        n_samples=n_post_samples,
    )


# Reference samples from the closed-form posterior, one entry per validation data set.
posterior_global_samples_true = np.array([sample_posterior_single(vd) for vd in valid_data])

In [None]:
# Diagnostics of the analytical reference posterior (a baseline for the samplers below).
true_params = np.array(valid_prior_global)
diagnostics.recovery(posterior_global_samples_true, true_params, variable_names=param_names)
diagnostics.calibration_ecdf(
    posterior_global_samples_true,
    true_params,
    difference=True,
    variable_names=param_names,
);

In [None]:
# Sampler configuration: mini-batches of size 1, no damping schedule.
mini_batch_size = 1
t1_value = 0.8  # alternative: mini_batch_size / (data_size // score_model.current_number_of_obs)
t0_value = 1
sampling_arg = dict(size=mini_batch_size)
# Optional damping schedule (currently disabled):
#   sampling_arg['damping_factor'] = lambda t: t0_value * torch.exp(-np.log(t0_value / t1_value) * 2*t)
#   sampling_arg['damping_factor_prior'] = 1
#   plt.plot(torch.linspace(0, 1, 100), sampling_arg['damping_factor'](torch.linspace(0, 1, 100)))
#   plt.show()

t0_value, t1_value

In [None]:
# Posterior samples via Langevin sampling (diffusion_steps=2000, langevin_steps=10).
langevin_kwargs = dict(
    n_post_samples=n_post_samples,
    sampling_arg=sampling_arg,
    diffusion_steps=2000,
    langevin_steps=10,
    step_size_factor=0.1,
    device=torch_device,
    verbose=True,
)
posterior_global_samples_valid = langevin_sampling(score_model, valid_data, **langevin_kwargs)

In [None]:
# Recovery and calibration diagnostics for the Langevin samples.
true_params = np.array(valid_prior_global)
fig = diagnostics.recovery(posterior_global_samples_valid, true_params, variable_names=param_names)
#fig.savefig(f'plots/{score_model.name}/recovery_global_langevin_sampler{score_model.current_number_of_obs}.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid,
    true_params,
    difference=True,
    variable_names=param_names,
)
#fig.savefig(f'plots/{score_model.name}/ecdf_global_langevin_sampler{score_model.current_number_of_obs}.png')

In [None]:
# Alternative (disabled; `sde_sampling` is not among this notebook's imports):
# posterior_global_samples_valid = sde_sampling(score_model, valid_data, obs_n_time_steps=obs_n_time_steps,
#                                               n_post_samples=n_post_samples, diffusion_steps=300,
#                                               method=['euler', 'milstein_grad_free', 'srk1w1'][1],
#                                               device=torch_device, verbose=True)

# Posterior samples via Euler-Maruyama integration with diffusion_steps=10000.
euler_kwargs = dict(
    n_post_samples=n_post_samples,
    diffusion_steps=10000,
    sampling_arg=sampling_arg,
    device=torch_device,
    verbose=True,
)
posterior_global_samples_valid = euler_maruyama_sampling(score_model, valid_data, **euler_kwargs)

In [None]:
# Recovery and calibration diagnostics for the Euler-Maruyama samples.
true_params = np.array(valid_prior_global)
fig = diagnostics.recovery(posterior_global_samples_valid, true_params, variable_names=param_names)
#fig.savefig(f'plots/{score_model.name}/recovery_global_euler_sampler{score_model.current_number_of_obs}.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid,
    true_params,
    difference=True,
    variable_names=param_names,
)
#fig.savefig(f'plots/{score_model.name}/ecdf_global_euler_sampler{score_model.current_number_of_obs}.png')

In [None]:
# Sampler configuration with an exponential damping schedule
#   damping(t) = t0 * (t1 / t0) ** (2 t),
# which starts at t0 for t=0, reaches t1 at t=0.5 and t1**2 / t0 at t=1.
mini_batch_size = 10
t1_value = 0.01  # alternative: mini_batch_size / (data_size // score_model.current_number_of_obs)
t0_value = 1
sampling_arg = {
    #'size': mini_batch_size,
    # t0/t1 are bound as default arguments so the schedule is frozen here. A bare closure over
    # the globals would silently change if another cell later reassigned t0_value/t1_value
    # (e.g. re-running the earlier sampler-settings cell, which sets t1_value = 0.8).
    'damping_factor': lambda t, t0=t0_value, t1=t1_value: t0 * torch.exp(-np.log(t0 / t1) * 2*t)
}
#plt.plot(torch.linspace(0, 1, 100), sampling_arg['damping_factor'](torch.linspace(0, 1, 100)))
#plt.show()

t0_value, t1_value

In [None]:
# Posterior samples via adaptive sampling, run sequentially (run_sampling_in_parallel=False).
adaptive_kwargs = dict(
    n_post_samples=n_post_samples,
    sampling_arg=sampling_arg,
    run_sampling_in_parallel=False,
    device=torch_device,
    verbose=True,
)
posterior_global_samples_valid = adaptive_sampling(score_model, valid_data, **adaptive_kwargs)

In [None]:
# Recovery, calibration and z-score diagnostics for the adaptive-sampler output.
true_params = np.array(valid_prior_global)
fig = diagnostics.recovery(posterior_global_samples_valid, true_params, variable_names=param_names)
#fig.savefig(f'plots/{score_model.name}/recovery_global_adaptive_sampler{score_model.current_number_of_obs}.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid,
    true_params,
    difference=True,
    variable_names=param_names,
)
#fig.savefig(f'plots/{score_model.name}/ecdf_global_adaptive_sampler{score_model.current_number_of_obs}.png')

fig = diagnostics.z_score_contraction(
    posterior_global_samples_valid,
    true_params,
    variable_names=param_names,
)
#fig.savefig(f'plots/{score_model.name}/z_score_global_adaptive_sampler{score_model.current_number_of_obs}.png')

# Mean calibration error across parameters (displayed as the cell output).
calibration = diagnostics.calibration_error(posterior_global_samples_valid, np.array(valid_prior_global))
calibration['values'].mean()

In [None]:
# Posterior samples by solving the probability-flow ODE, run in parallel.
ode_kwargs = dict(
    n_post_samples=n_post_samples,
    run_sampling_in_parallel=True,
    device=torch_device,
    verbose=True,
)
posterior_global_samples_valid = probability_ode_solving(score_model, valid_data, **ode_kwargs)

In [None]:
# Recovery and calibration diagnostics for the ODE-based samples.
true_params = np.array(valid_prior_global)
fig = diagnostics.recovery(posterior_global_samples_valid, true_params, variable_names=param_names)
#fig.savefig(f'plots/{score_model.name}/recovery_global_ode{score_model.current_number_of_obs}.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid,
    true_params,
    difference=True,
    variable_names=param_names,
)
#fig.savefig(f'plots/{score_model.name}/ecdf_global_ode{score_model.current_number_of_obs}.png')