# Gaussian on a Grid Test for Hierarchical ABI with compositional score matching

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from bayesflow import diagnostics
from torch.utils.data import DataLoader, TensorDataset

from helper_functions import generate_diffusion_time
from diffusion_sampling import euler_maruyama_sampling, adaptive_sampling, probability_ode_solving, langevin_sampling
from diffusion_model import HierarchicalScoreModel, SDE, weighting_function

from gaussian_test_simulator import Prior, Simulator, visualize_simulation_output, generate_synthetic_data

In [None]:
# Every model and tensor in this notebook is placed on the CPU.
torch_device = torch.device(type="cpu")

In [None]:
prior, simulator_test = Prior(), Simulator()

# Smoke-test the simulator: simulate one full prior draw and plot its observable.
sim_test = simulator_test(prior.sample_full(1))['observable']
visualize_simulation_output(sim_test)

In [None]:
# Compare the noise schedules of the variance-preserving ('') and sub-variance-preserving ('sub_')
# kernels: SNR over time, its histogram, and the kernel coefficients alpha(t) and sigma(t).
t = torch.linspace(0, 1, 1000)
bins = 100

# (legend label, noise schedule, extra SDE keyword arguments, plotted only for the 'sub_' kernel)
schedule_configs = [
    ('linear', 'linear', {}, False),
    ('cosine', 'cosine', {}, False),
    ('shifted_cosine', 'cosine', {'s_shift_cosine': 0.8}, False),
    ('flow_matching', 'flow_matching', {}, True),
]
column_titles = ['Signal-to-noise ratio', 'Signal-to-noise ratio', 'alpha(t)', 'sigma(t)',
                 'alpha(t)^2 + sigma(t)^2', 'alpha(t) + sigma(t)']

fig, ax = plt.subplots(2, 6, sharex='col', sharey='col', figsize=(12, 6), tight_layout=True)
# label -> list of Line2D; later rows overwrite the handles but keep first-insertion order
schedule_handles = {}
for a, sub_vp in zip(ax, ['', 'sub_']):
    for label, schedule, sde_kwargs, sub_vp_only in schedule_configs:
        if sub_vp_only and sub_vp != 'sub_':
            continue
        sde = SDE(sub_vp + 'variance_preserving', schedule, **sde_kwargs)
        snr = sde.get_snr(t)
        alpha, sigma = sde.kernel(t)
        schedule_handles[label] = a[0].plot(t, snr, label=label, alpha=0.75)
        a[1].hist(snr, bins=bins, density=True, alpha=0.75)
        a[2].plot(t, alpha, label=label, alpha=0.75)
        a[3].plot(t, sigma, label=label, alpha=0.75)
        a[4].plot(t, alpha**2 + sigma**2, label=label, alpha=0.75)
        a[5].plot(t, alpha + sigma, label=label, alpha=0.75)

    # dashed reference line at 0
    a[0].axhline(0, color='black', linestyle='--', alpha=0.5)

    for axis, title in zip(a, column_titles):
        axis.set_xlabel('t')
        axis.set_title(title)
    a[1].set_xlabel('snr')
    a[0].set_ylabel(f'snr\n{sub_vp}variance_preserving')

fig.legend(handles=[h for handles in schedule_handles.values() for h in handles],
           loc='lower center', ncol=4, bbox_to_anchor=(0.5, -0.05))
plt.savefig('noise schedules.png', bbox_inches='tight')
plt.show()

In [None]:
# Plot f(z, t) and g(t) returned by SDE.get_f_g (presumably the drift and diffusion coefficients
# of the forward SDE) at the fixed state z = x_0, one row per kernel type.
t = torch.linspace(0, 1, 1000)
x_0 = torch.tensor([0.01])

fig, ax = plt.subplots(2, 2, sharex='col', sharey='col', figsize=(5, 4), tight_layout=True)
for a, sub_vp in zip(ax, ['', 'sub_']):
    f, g = SDE(sub_vp+'variance_preserving', 'linear').get_f_g(t, x=x_0)
    h1 = a[0].plot(t, f, label='linear', alpha=0.75)
    a[1].plot(t, g, label='linear', alpha=0.75)

    f, g = SDE(sub_vp+'variance_preserving', 'cosine').get_f_g(t, x=x_0)
    h2 = a[0].plot(t, f, label='cosine', alpha=0.75)
    a[1].plot(t, g, label='cosine', alpha=0.75)

    # in this plot the flow-matching schedule is only drawn for the 'sub_' kernel
    if sub_vp == 'sub_':
        f, g = SDE(sub_vp+'variance_preserving', 'flow_matching').get_f_g(t, x=x_0)
        h4 = a[0].plot(t, f, label='flow_matching', alpha=0.75)
        a[1].plot(t, g, label='flow_matching', alpha=0.75)

    for axis in a:
        axis.set_xlabel('t')
    # the row label names the kernel; the plotted quantity is given by the column title
    a[0].set_ylabel(f'{sub_vp}variance_preserving')
    a[0].set_title(f'f(z={round(x_0.item(), 2)},t)')
    a[1].set_title('g(t)')
# h1/h2 come from the last row; h4 only exists for the 'sub_' row, which is drawn last
fig.legend(handles=h1+h2+h4, loc='lower center', ncol=3, bbox_to_anchor=(0.5, -0.05))
plt.show()

In [None]:
# plot the weighting functions, normalized to a maximum of 1, over diffusion time (top row)
# and over the SNR of a cosine variance-preserving SDE (bottom row)
t = generate_diffusion_time(1000)
sde_test = SDE('variance_preserving', 'cosine')
snr = sde_test.get_snr(t)

fig, ax = plt.subplots(2, 5, sharex='row', sharey='row', figsize=(12, 6), tight_layout=True)
for a, wt in zip(ax.T, [None, 'likelihood_weighting', 'flow_matching', 'sigmoid', 'min-snr']):
    w = weighting_function(t, sde_test, weighting_type=wt)
    w_normalized = w / w.max()  # comparable scale across weighting types

    a[0].plot(t, w_normalized)
    a[0].set_xlabel(r'$t$')
    a[0].set_ylabel('Normalized weight')
    a[0].set_title(wt)

    a[1].plot(snr, w_normalized)
    a[1].set_xlabel(r'$\lambda$')
    a[1].set_ylabel('Normalized weight')
    a[1].set_title(wt)
plt.show()

In [None]:
def compute_score_loss(theta_batch, x_batch, model):
    """Weighted denoising score-matching loss for one batch.

    Every sample gets its own random diffusion time. Theta is perturbed with the SDE's
    Gaussian kernel, ``z = alpha * theta + sigma * epsilon``. The network output
    (``pred_score=False``) is regressed either on the true kernel score
    (``prediction_type == 'score'``) or on the injected noise ``epsilon``.

    Args:
        theta_batch: parameters of the batch (global and local concatenated on the last axis).
        x_batch: conditioning observations for the batch.
        model: score model exposing ``sde``, ``prediction_type`` and ``weighting_type``.

    Returns:
        Scalar tensor: batch mean of the weighted squared error summed over the last dimension.
    """
    device = theta_batch.device
    diffusion_time = generate_diffusion_time(size=theta_batch.shape[0], return_batch=True, device=device)

    # noise drawn from the Gaussian kernel
    epsilon = torch.randn_like(theta_batch, dtype=theta_batch.dtype, device=device)

    # perturb the theta batch
    alpha, sigma = model.sde.kernel(t=diffusion_time)
    z = alpha * theta_batch + sigma * epsilon

    # predict from perturbed theta; if prediction_type is 'score', this is still the score
    pred = model(theta=z, time=diffusion_time, x=x_batch, pred_score=False)

    if model.prediction_type == 'score':
        target = model.sde.grad_log_kernel(x=z, x0=theta_batch, t=diffusion_time)
    else:
        target = epsilon

    effective_weight = weighting_function(diffusion_time, sde=model.sde, weighting_type=model.weighting_type,
                                          prediction_type=model.prediction_type)
    squared_error = torch.sum(torch.square(pred - target), dim=-1)
    # A per-sample weight with a trailing singleton axis, e.g. (batch, 1), would broadcast against
    # the (batch,) error into a (batch, batch) outer product and lose the per-sample pairing.
    if (torch.is_tensor(effective_weight) and effective_weight.dim() == squared_error.dim() + 1
            and effective_weight.shape[-1] == 1):
        effective_weight = effective_weight.squeeze(-1)
    return torch.mean(effective_weight * squared_error)


# Training loop for Score Model
def train_score_model(model, dataloader, dataloader_valid=None, epochs=100, lr=1e-4, device=None):
    """Train ``model`` with AdamW on the weighted score-matching loss of ``compute_score_loss``.

    Args:
        model: score model exposing ``sde``, ``prediction_type`` and ``weighting_type``.
        dataloader: yields ``(theta_global, theta_local, x)`` training batches.
        dataloader_valid: optional loader with the same batch layout, evaluated without
            gradients after every epoch.
        epochs: number of passes over ``dataloader``.
        lr: AdamW learning rate.
        device: device that the model and all batches are moved to. ``None`` leaves the
            model where it is.

    Returns:
        ``np.ndarray`` of shape ``(epochs, 2)`` with the mean training loss and the mean
        validation loss per epoch. The validation column is NaN without ``dataloader_valid``.
    """
    print(f"Training {model.prediction_type}-model for {epochs} epochs with learning rate {lr} and {model.sde.kernel_type},"
          f" {model.sde.noise_schedule} schedule and {model.weighting_type} weighting.")
    # move the model being trained (not a notebook global) to the requested device
    if device is not None:
        model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    def batch_loss(theta_global_batch, theta_local_batch, x_batch):
        # global and local parameters are modelled jointly along the last axis
        theta_batch = torch.concat([theta_global_batch, theta_local_batch], dim=-1).to(device)
        return compute_score_loss(theta_batch=theta_batch, x_batch=x_batch.to(device), model=model)

    loss_history = np.zeros((epochs, 2))
    for epoch in range(epochs):
        model.train()
        train_losses = []
        # every sample in a batch gets its own random diffusion time inside compute_score_loss
        for theta_global_batch, theta_local_batch, x_batch in dataloader:
            optimizer.zero_grad()
            loss = batch_loss(theta_global_batch, theta_local_batch, x_batch)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 3.0)  # clip the gradient norm at 3
            optimizer.step()
            train_losses.append(loss.item())

        # validate the model
        model.eval()
        valid_losses = []
        if dataloader_valid is not None:
            with torch.no_grad():
                for theta_global_batch, theta_local_batch, x_batch in dataloader_valid:
                    valid_losses.append(batch_loss(theta_global_batch, theta_local_batch, x_batch).item())

        # np.mean([]) warns and returns NaN; make the "no batches" case explicit instead
        mean_train = float(np.mean(train_losses)) if train_losses else float('nan')
        mean_valid = float(np.mean(valid_losses)) if valid_losses else float('nan')
        loss_history[epoch] = [mean_train, mean_valid]
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_train:.4f}, "
              f"Valid Loss: {mean_valid:.4f}", end='\r')
    print()  # terminate the carriage-return progress line
    return loss_history

In [None]:
# Hyperparameters
n_data = 10000
batch_size = 128

# Training set: normalized global/local parameters with their simulated observations.
thetas_global, thetas_local, xs = generate_synthetic_data(prior, n_data=n_data, normalize=True)
dataset = TensorDataset(thetas_global, thetas_local, xs)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Validation set of two batches, simulated independently of the training set.
valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(
    prior, n_data=batch_size * 2, normalize=True
)
dataset_valid = TensorDataset(valid_prior_global, valid_prior_local, valid_data)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size)

In [None]:
# Define model
# kernel_type options: 'variance_preserving', 'sub_variance_preserving'
# noise_schedule options: 'linear', 'cosine', 'flow_matching'
current_sde = SDE(kernel_type='variance_preserving', noise_schedule='linear')

# prediction_type options: 'score', 'e', 'x', 'v'
# weighting_type options: None, 'likelihood_weighting', 'flow_matching', 'sigmoid'
score_model = HierarchicalScoreModel(
    input_dim_theta_global=prior.n_params_global,
    input_dim_theta_local=prior.n_params_local,
    input_dim_x=1,
    hidden_dim=64,
    n_blocks=3,
    prediction_type='v',
    sde=current_sde,
    time_embed_dim=16,
    use_film=True,
    weighting_type='likelihood_weighting',
    prior=prior,
)

In [None]:
# train the model, then put it in evaluation mode for the diagnostics below
loss_history = train_score_model(
    score_model,
    dataloader,
    dataloader_valid=dataloader_valid,
    epochs=500,
    lr=1e-4,
    device=torch_device,
)
score_model.eval();

In [None]:
# persist the trained weights; the file name records the noise schedule used for training
torch.save(obj=score_model.state_dict(), f=f"score_model_{score_model.sde.noise_schedule}.pt")

In [None]:
# plot loss history (mean training and validation loss per epoch)
import os

# savefig does not create directories, and this is the first cell writing to loss_plots/
os.makedirs('loss_plots', exist_ok=True)
plt.figure(figsize=(6, 3), tight_layout=True)
plt.plot(loss_history[:, 0], label='Mean Train')
plt.plot(loss_history[:, 1], label='Mean Valid')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig(f'loss_plots/loss_training_{score_model.sde.noise_schedule}.png')
plt.show()

In [None]:
# check the error prediction: is it close to the noise?
# For every diffusion time t, the validation parameters are perturbed with the SDE kernel and four
# losses are recorded per t, each averaged over all validation batches:
#   loss_list_target  - unweighted squared error on the model's native target (score / e / x / v)
#   loss_list_score   - squared error between the predicted and the true kernel score
#   loss_list_error_w - weighted squared error on the noise (on the score for 'score' models), as in training
#   loss_list_error   - the same error without weighting
loss_list_target = {}
loss_list_score = {}
loss_list_error_w = {}
loss_list_error = {}

pred_list = []  # mean predicted score per (t, batch)
score_list = []  # mean true kernel score per (t, batch)


def _sq_error(pred, target):
    """Squared error summed over the last dimension, averaged over the batch, as a Python float."""
    return torch.mean(torch.sum(torch.square(pred - target), dim=-1)).item()


with torch.no_grad():
    # Generate diffusion times
    diffusion_time = generate_diffusion_time(size=100, device=torch_device)
    for t in diffusion_time:
        t_value = t.item()
        # collect per-batch losses and average them below (a plain dict assignment per batch
        # would keep only the last validation batch)
        batch_losses = {'target': [], 'score': [], 'error_w': [], 'error': []}
        for theta_global_batch, theta_local_batch, x_batch in dataloader_valid:
            theta_batch = torch.cat([theta_global_batch, theta_local_batch], dim=-1).to(torch_device)
            x_batch = x_batch.to(torch_device)

            # noise drawn from the Gaussian kernel
            epsilon = torch.randn_like(theta_batch, dtype=torch.float32, device=torch_device)

            # perturb the theta batch
            t_tensor = torch.full((theta_batch.shape[0], 1), t_value, dtype=torch.float32, device=torch_device)
            alpha, sigma = score_model.sde.kernel(t=t_tensor)
            z = alpha * theta_batch + sigma * epsilon

            # predict from perturbed theta
            pred_epsilon = score_model(theta=z, time=t_tensor, x=x_batch, pred_score=False)
            pred_score = score_model(theta=z, time=t_tensor, x=x_batch, pred_score=True)
            true_score = score_model.sde.grad_log_kernel(x=z, x0=theta_batch, t=t_tensor)
            pred_list.append(torch.mean(pred_score))
            score_list.append(torch.mean(true_score))

            if score_model.prediction_type == 'score':
                target = true_score
                pred_target = pred_epsilon  # the non-score output is still the score
                epsilon = target  # the "error" losses below compare against the score
            elif score_model.prediction_type == 'e':
                target = epsilon
                pred_target = pred_epsilon
            elif score_model.prediction_type == 'v':
                target = alpha * epsilon - sigma * theta_batch
                # use theta reconstructed from z (as in the 'x' branch), not the true theta
                pred_theta = (z - pred_epsilon * sigma) / alpha
                pred_target = alpha * pred_epsilon - sigma * pred_theta
            elif score_model.prediction_type == 'x':
                target = theta_batch
                pred_target = (z - pred_epsilon * sigma) / alpha
            else:
                raise ValueError("Invalid prediction type.")

            batch_losses['target'].append(_sq_error(pred_target, target))
            batch_losses['score'].append(_sq_error(pred_score, true_score))

            # weighted loss as in the optimization; a (batch, 1) weight is squeezed so it pairs
            # elementwise with the (batch,) error instead of broadcasting to an outer product
            w = weighting_function(t_tensor, sde=score_model.sde,
                                   weighting_type=score_model.weighting_type,
                                   prediction_type=score_model.prediction_type)
            error_per_sample = torch.sum(torch.square(pred_epsilon - epsilon), dim=-1)
            if torch.is_tensor(w) and w.dim() == error_per_sample.dim() + 1 and w.shape[-1] == 1:
                w = w.squeeze(-1)
            batch_losses['error_w'].append(torch.mean(w * error_per_sample).item())

            # unweighted error, to check the weighting function
            batch_losses['error'].append(_sq_error(pred_epsilon, epsilon))

        loss_list_target[t_value] = float(np.mean(batch_losses['target']))
        loss_list_score[t_value] = float(np.mean(batch_losses['score']))
        loss_list_error_w[t_value] = float(np.mean(batch_losses['error_w']))
        loss_list_error[t_value] = float(np.mean(batch_losses['error']))

In [None]:
# per-diffusion-time losses collected in the previous cell
df_target = pd.DataFrame(list(loss_list_target.items()), columns=['Time', 'Loss'])
df_score = pd.DataFrame(list(loss_list_score.items()), columns=['Time', 'Loss'])
df_error_w = pd.DataFrame(list(loss_list_error_w.items()), columns=['Time', 'Loss'])
df_error = pd.DataFrame(list(loss_list_error.items()), columns=['Time', 'Loss'])

# compute log(alpha^2) - log(sigma^2) of the kernel at the evaluated diffusion times
m, std = score_model.sde.kernel(diffusion_time)
snr = torch.log(torch.square(m)) - torch.log(torch.square(std))
upper_bound_loss = (np.sqrt(2) + 1) / (std.numpy()**2)  # only used by the commented-out plot below

fig, ax = plt.subplots(ncols=4, sharex=True, figsize=(16, 3), tight_layout=True)
ax[0].plot(df_target['Time'], np.log(df_target['Loss']), label=f'Unscaled {score_model.prediction_type} Loss')
ax[1].plot(df_score['Time'], np.log(df_score['Loss']), label='Score Loss')
#ax[1].plot(df_score['Time'], df_score['Loss'] / upper_bound_loss, label='Score Loss')
ax[1].plot(diffusion_time, snr, label='log snr', alpha=0.5)
ax[2].plot(df_error_w['Time'], np.log(df_error_w['Loss']), label='Weighted Loss (as in Optimization)')
ax[3].plot(df_error['Time'], np.log(df_error['Loss']), label='Loss on Error')
for a in ax:
    a.set_xlabel('Diffusion Time')
    a.set_ylabel('Log Loss')
    a.legend()
plt.savefig(f'loss_plots/losses_diffusion_time_{score_model.sde.noise_schedule}.png')
plt.show()

plt.figure(figsize=(6, 3), tight_layout=True)
plt.plot(diffusion_time.cpu(),
         weighting_function(diffusion_time, sde=score_model.sde, weighting_type=score_model.weighting_type,
                            prediction_type=score_model.prediction_type).cpu(),
         label='weighting')
plt.xlabel('Diffusion Time')
plt.ylabel('Weight')
plt.legend()
plt.show()

# Validation

In [None]:
# Validation data on the full n_grid x n_grid grid; parameters are kept unnormalized.
n_grid = 8
n_post_samples = 10
valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(
    prior, n_data=10, grid_size=n_grid, full_grid=True, normalize=False
)

In [None]:
# Global posterior samples from langevin_sampling (500 diffusion steps, 5 Langevin steps),
# one call per validation data set.
posterior_global_samples_valid = np.array([
    langevin_sampling(score_model, vd, n_post_samples=n_post_samples, diffusion_steps=500,
                      langevin_steps=5, device=torch_device)
    for vd in valid_data
])

In [None]:
# recovery of the global parameters for the Langevin sampler
fig = diagnostics.plot_recovery(
    posterior_global_samples_valid,
    np.array(valid_prior_global),
    param_names=[r'$\mu$', r'$\log \tau$'],
)
fig.savefig(f'loss_plots/recovery_global_{score_model.prediction_type}_langevin_sampler.png')

In [None]:
# Global posterior samples from euler_maruyama_sampling with 800 diffusion steps.
posterior_global_samples_valid = np.array([
    euler_maruyama_sampling(score_model, vd, n_post_samples=n_post_samples,
                            diffusion_steps=800, device=torch_device)
    for vd in valid_data
])

In [None]:
# recovery of the global parameters for the Euler-Maruyama sampler
fig = diagnostics.plot_recovery(
    posterior_global_samples_valid,
    np.array(valid_prior_global),
    param_names=[r'$\mu$', r'$\log \tau$'],
)
fig.savefig(f'loss_plots/recovery_global_{score_model.prediction_type}_euler_sampler.png')

In [None]:
# Euler-Maruyama again (400 diffusion steps), now with pareto_smooth_sum_dict set
# (presumably Pareto smoothing of the summed per-site scores - see euler_maruyama_sampling).
pareto_kwargs = {'tail_fraction': 0.1, 'alpha': 0.6}
posterior_global_samples_valid = np.array([
    euler_maruyama_sampling(score_model, vd, n_post_samples=n_post_samples,
                            pareto_smooth_sum_dict=pareto_kwargs,
                            diffusion_steps=400, device=torch_device)
    for vd in valid_data
])

In [None]:
# recovery of the global parameters for the Pareto-smoothed Euler-Maruyama sampler
fig = diagnostics.plot_recovery(
    posterior_global_samples_valid,
    np.array(valid_prior_global),
    param_names=[r'$\mu$', r'$\log \tau$'],
)
fig.savefig(f'loss_plots/recovery_global_{score_model.prediction_type}_euler_pareto_sampler.png')

In [None]:
# Global posterior samples from adaptive_sampling, one call per validation data set.
# n_post_samples is passed by keyword, as in every other sampler call in this notebook, so it
# cannot bind to a different positional parameter (e.g. conditions) of adaptive_sampling.
posterior_global_samples_valid = np.array([
    adaptive_sampling(score_model, vd, n_post_samples=n_post_samples,
                      e_rel=0.1, max_steps=2000, device=torch_device)
    for vd in valid_data
])

In [None]:
# recovery of the global parameters for the adaptive sampler
fig = diagnostics.plot_recovery(
    posterior_global_samples_valid,
    np.array(valid_prior_global),
    param_names=[r'$\mu$', r'$\log \tau$'],
)
fig.savefig(f'loss_plots/recovery_global_{score_model.prediction_type}_adaptive_sampler.png')

In [None]:
# Global posterior samples from probability_ode_solving. All posterior samples of one data set
# are solved in a single call; solving each sample individually was much slower, and most of
# the samples still came out similar.
posterior_global_samples_valid = np.zeros((len(valid_data), n_post_samples, prior.n_params_global))
for i, vd in enumerate(valid_data):
    posterior_global_samples_valid[i] = probability_ode_solving(score_model, vd, n_post_samples=n_post_samples,
                                                                device=torch_device)

In [None]:
# recovery of the global parameters for the ODE solver
fig = diagnostics.plot_recovery(
    posterior_global_samples_valid,
    np.array(valid_prior_global),
    param_names=[r'$\mu$', r'$\log \tau$'],
)
fig.savefig(f'loss_plots/recovery_global_{score_model.prediction_type}_ode.png')

In [None]:
# SBC ECDF difference plot for whatever global samples posterior_global_samples_valid currently holds
diagnostics.plot_sbc_ecdf(
    posterior_global_samples_valid,
    np.array(valid_prior_global),
    difference=True,
    param_names=[r'$\mu$', r'$\log \tau$'],
);

In [None]:
# Local posterior samples via euler_maruyama_sampling (300 diffusion steps). Each data set is
# conditioned on its own global samples currently stored in posterior_global_samples_valid.
posterior_local_samples_valid = np.array([
    euler_maruyama_sampling(score_model, vd, n_post_samples=n_post_samples, conditions=c,
                            diffusion_steps=300, device=torch_device)
    for vd, c in zip(valid_data, posterior_global_samples_valid)
])

In [None]:
# recovery of the n_grid**2 local parameters, flattened per data set
diagnostics.plot_recovery(
    posterior_local_samples_valid.reshape(valid_data.shape[0], n_post_samples, -1),
    np.array(valid_prior_local).reshape(valid_data.shape[0], -1),
    param_names=[f'$\\theta_{{{i}}}$' for i in range(n_grid**2)],
);

In [None]:
# Local posterior samples via probability_ode_solving, again conditioned per data set on its
# global samples in posterior_global_samples_valid.
posterior_local_samples_valid = np.array([
    probability_ode_solving(score_model, vd, n_post_samples=n_post_samples,
                            conditions=c, device=torch_device)
    for vd, c in zip(valid_data, posterior_global_samples_valid)
])

In [None]:
# recovery of the n_grid**2 local parameters from the ODE samples, flattened per data set
diagnostics.plot_recovery(
    posterior_local_samples_valid.reshape(valid_data.shape[0], n_post_samples, -1),
    np.array(valid_prior_local).reshape(valid_data.shape[0], -1),
    param_names=[f'$\\theta_{{{i}}}$' for i in range(n_grid**2)],
);

In [None]:
# Inspect one validation data set: its data, posterior median/std, and the true global parameters.
valid_id = 0
print('Data')
visualize_simulation_output(valid_data[valid_id])

global_names = ['mu', 'log tau']
print('Global Estimates')
for k, name in enumerate(global_names):
    draws = posterior_global_samples_valid[valid_id, :, k]
    print(f'{name}:', np.median(draws), np.std(draws))
print('True')
for k, name in enumerate(global_names):
    print(f'{name}:', valid_prior_global[valid_id][k].item())

In [None]:
# Local posterior of data set valid_id on the n_grid x n_grid grid: median vs. truth,
# then posterior std next to the squared error of the median.
local_draws = posterior_local_samples_valid[valid_id].reshape(n_post_samples, n_grid, n_grid)
med = np.median(local_draws, axis=0)
std = np.std(local_draws, axis=0)
error = (med - valid_prior_local[valid_id].numpy()) ** 2
visualize_simulation_output(np.stack((med, valid_prior_local[valid_id])),
                            title_prefix=['Posterior Median', 'True'])

visualize_simulation_output(np.stack((std, error)), title_prefix=['Uncertainty', 'Error'], same_scale=False)

# Visualize the Score

In [None]:
# Pick a random validation data set for visualizing the learned score.
valid_id = np.random.randint(0, len(valid_data))

diffusion_time = generate_diffusion_time(size=10, device=torch_device)
x_valid = valid_data[valid_id].to(torch_device)
# the score lives in the normalized parameter space, so normalize the true global parameters as well
theta_global = (prior.normalize_theta(valid_prior_global[valid_id], global_params=True)
                .cpu()
                .numpy())
print(valid_id, 'theta global', theta_global)

In [None]:
# One adaptive_sampling draw with t_end = diffusion_time[0], mapped into the normalized space.
test_sample = adaptive_sampling(
    score_model, x_valid, conditions=None, n_post_samples=1, e_rel=0.1, max_steps=2500,
    t_end=diffusion_time[0], random_seed=0, device=torch_device,
)
test_sample = prior.normalize_theta(torch.tensor(test_sample), global_params=True).cpu().numpy()
print(test_sample)

In [None]:
# Sampler states for every t_end in diffusion_time[:-1] (fixed random seed), from the adaptive
# sampler and from the ODE solver, mapped into the normalized space in which the score lives.
posterior_sample_path = np.array([
    adaptive_sampling(score_model, x_valid, conditions=None, n_post_samples=1, e_rel=0.1,
                      max_steps=2500, t_end=t, random_seed=0, device=torch_device)
    for t in diffusion_time[:-1]
])
posterior_sample_path = prior.normalize_theta(torch.tensor(posterior_sample_path),
                                              global_params=True).cpu().numpy()

# An Euler-Maruyama path (euler_maruyama_sampling with diffusion_steps=40000) can be added the same way.

posterior_sample_path3 = np.array([
    probability_ode_solving(score_model, x_valid, conditions=None, n_post_samples=1,
                            t_end=t, random_seed=0, device=torch_device)
    for t in diffusion_time[:-1]
])
posterior_sample_path3 = prior.normalize_theta(torch.tensor(posterior_sample_path3),
                                               global_params=True).cpu().numpy()

print('theta global', theta_global, posterior_sample_path[0].flatten())

In [None]:
# Evaluate the learned global score on a regular 2D grid (in the normalized global-parameter space)
# at every time in diffusion_time, conditioned on x_valid; results go into scores[t].
# Define grid boundaries and resolution for your 2D space.
x_min, x_max, y_min, y_max = -1.5, 1.5, -1.5, 1.5
grid_res = 10  # Number of points per dimension

# Create a meshgrid of points
x_vals = np.linspace(x_min, x_max, grid_res)
y_vals = np.linspace(y_min, y_max, grid_res)
xx, yy = np.meshgrid(x_vals, y_vals)
# Stack into (N,2) where N = grid_res*grid_res
grid_points = np.vstack([xx.ravel(), yy.ravel()]).T

# Convert grid to a torch tensor and move to device
grid_tensor = torch.tensor(grid_points, dtype=torch.float32, device=torch_device)
# NOTE(review): the leading 10 is hard-coded; presumably the number of observations per grid site - TODO confirm
x_valid_e = x_valid.reshape(10, -1)
# repeat the observations once per grid point -> shape (N, 10, -1)
x_valid_ext = x_valid_e.unsqueeze(0).repeat(grid_tensor.shape[0], 1, 1).to(torch_device)

# Dictionary to hold score outputs for each time
scores = {}

# Evaluate the score model for each time value
for t in diffusion_time:
    # Create a tensor of time values for each grid point
    t_tensor = torch.full((grid_tensor.shape[0], 1), t.item(), dtype=torch.float32, device=torch_device)
    # NOTE(review): epsilon, alpha and sigma are unused because z is the unperturbed grid;
    # epsilon still advances the torch RNG, so removing it changes later random draws
    epsilon = torch.randn_like(grid_tensor, dtype=torch.float32, device=torch_device)

    # perturb theta (currently disabled, see z below)
    alpha, sigma = score_model.sde.kernel(t=t_tensor)
    z = grid_tensor  # unperturbed grid; perturbed alternative: alpha * grid_tensor + sigma * epsilon

    # Evaluate the score model
    with torch.no_grad():
        # Compositional score: a prior-score term scaled by (1 - n_grid^2) * (1 - t), plus one
        # model score per slice of the last axis of x_valid_ext.
        # NOTE(review): presumably the prior term corrects for the prior being counted once per
        # site, with (1 - t) fading the correction out towards t = 1 - confirm
        score = (1 - n_grid*n_grid) * (1 - t) / 1 * prior.score_global_batch(z)
        for i in range(x_valid_ext.shape[2]):
            score += score_model.forward_global(theta_global=z, time=t_tensor, x=x_valid_ext[:, :, i].unsqueeze(-1),
                                                pred_score=True, clip_x=False)
    scores[t.item()] = score.cpu().numpy()

In [None]:
# Plot the score vector field for each diffusion time (largest t first) using subplots, together
# with the origin, the true parameter and the sample paths of the adaptive sampler and the ODE solver.
nrows = 2
fig, axes = plt.subplots(nrows, len(diffusion_time) // nrows, sharex=True, sharey=True,
                         figsize=(15, 3*nrows), tight_layout=True)
axes = axes.flatten()

# label -> handle from the most recent panel that drew it, so the legend works even
# when the path panels (i != 0) do not exist
field_handles = {}
for i, (t_val, score_val) in enumerate(sorted(scores.items(), reverse=True)):
    # Reshape score components back to (grid_res, grid_res) for quiver plotting
    U = score_val[:, 0].reshape(grid_res, grid_res)
    V = score_val[:, 1].reshape(grid_res, grid_res)

    ax = axes[i]
    field_handles['Latent Prior'], = ax.plot(0, 0, 'o', color='black', label='Latent Prior')
    field_handles['True Parameter'], = ax.plot(theta_global[0], theta_global[1], 'ro', label='True Parameter')

    # the paths hold one state per t in diffusion_time[:-1]; panel i draws the states from index j on
    j = len(diffusion_time) - 1 - i
    if i != 0:
        field_handles['Posterior Path Sampling'], = ax.plot(
            posterior_sample_path[j:, 0, 0], posterior_sample_path[j:, 0, 1], 'o-',
            label='Posterior Path Sampling', alpha=0.5)
        field_handles['Posterior Path ODE'], = ax.plot(
            posterior_sample_path3[j:, 0, 0], posterior_sample_path3[j:, 0, 1], 'o-',
            label='Posterior Path ODE', alpha=0.5)
    ax.quiver(xx, yy, U, V, color='blue', angles='xy', scale_units='xy', scale=10*n_grid*n_grid, alpha=.75)
    ax.set_title(f"Diffusion t = {t_val:.3f}")
    ax.set_xlim(x_min-0.5, x_max+0.5)
    ax.set_ylim(y_min-0.5, y_max+0.5)
    ax.set_aspect('equal')
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
fig.legend(handles=list(field_handles.values()), loc='lower center', ncol=4, bbox_to_anchor=(0.5, -0.05))
plt.savefig(f'loss_plots/score_field_{score_model.sde.noise_schedule}.png')
plt.show()
print('theta global', theta_global, posterior_sample_path[0].flatten())