# Gaussian on a Grid Test with compositional score matching

In this notebook, we test compositional score matching on a hierarchical problem defined on a grid.
- The observations lie on a grid of `n_grid` x `n_grid` points.
- The global parameters are the same for all grid points with hyper-priors:
$$ \mu \sim \mathcal{N}(0, 3^2),\qquad \log\sigma \sim \mathcal{N}(0, 1^2)$$

- The local parameters are different for each grid point
$$ \theta_{i,j} \sim \mathcal{N}(\mu, \sigma^2)$$

- At each grid point, we have a Brownian motion with drift:
$$ dx_t = \theta \cdot dt + \sqrt{dt} \cdot dW_t$$
- We observe $T=10$ time points for each grid point over a time horizon of `max_time=1`. We can also amortize over the time dimension.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

os.environ['KERAS_BACKEND'] = 'torch'
from bayesflow import diagnostics

from torch.utils.data import DataLoader

from diffusion_model import HierarchicalScoreModel, SDE, euler_maruyama_sampling, adaptive_sampling, \
    probability_ode_solving, langevin_sampling, \
    generate_diffusion_time, count_parameters, train_score_model
from diffusion_model.helper_networks import LSTM, GaussianFourierProjection
from problems.gaussian_grid import GaussianGridProblem, Prior, Simulator, visualize_simulation_output, plot_shrinkage, \
    generate_synthetic_data

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so that the notebook
# still runs end to end on machines without CUDA.
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Prior and simulator are both configured for 10 time points.
prior = Prior(n_time_points=10)
simulator_test = Simulator(n_time_points=10)

# Smoke test: draw one full prior sample with n_grid=8, simulate it and visualize the observable.
prior_draw_test = prior.sample_full(1, n_grid=8)
sim_test = simulator_test(prior_draw_test)['observable']
visualize_simulation_output(sim_test)

In [None]:
batch_size = 128
number_of_obs = 1 #[64]

# Available SDE settings; the index selects the active one.
kernel_type_options = ['variance_preserving', 'sub_variance_preserving']
noise_schedule_options = ['linear', 'cosine', 'flow_matching']
current_sde = SDE(kernel_type=kernel_type_options[0], noise_schedule=noise_schedule_options[1])

# Arguments shared by the training and the validation set.
shared_dataset_kwargs = dict(prior=prior, sde=current_sde, number_of_obs=number_of_obs)

dataset = GaussianGridProblem(n_data=100000, online_learning=False, **shared_dataset_kwargs)
dataset_valid = GaussianGridProblem(n_data=1000, **shared_dataset_kwargs)

# Create dataloader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False)

# Sanity check: print the shapes of the first five entries of the first training batch.
for test in dataloader:
    print(*(test[k].shape for k in range(5)))
    break

In [None]:
# Define diffusion model
summary_dim = 8
#summary_net = LSTM(input_size=1, hidden_dim=summary_dim)

# Separate diffusion-time embeddings for the local and the global score network:
# an 8-dimensional Fourier projection followed by a linear layer and a Mish activation.
time_embedding_local = nn.Sequential(
    GaussianFourierProjection(8),
    nn.Linear(8, 8),
    nn.Mish()
)
time_embedding_global = nn.Sequential(
    GaussianFourierProjection(8),
    nn.Linear(8, 8),
    nn.Mish()
)

score_model = HierarchicalScoreModel(
    input_dim_theta_global=prior.n_params_global,
    input_dim_theta_local=prior.n_params_local,
    input_dim_x=10,  # presumably the number of time points per grid cell -- TODO confirm
    #summary_net=summary_net,
    time_embedding_local=time_embedding_local,
    time_embedding_global=time_embedding_global,
    hidden_dim=256,
    n_blocks=5,
    # number_of_obs may be an int or a list of ints; in the latter case the maximum is used
    max_number_of_obs=number_of_obs if isinstance(number_of_obs, int) else max(number_of_obs),
    prediction_type=['score', 'e', 'x', 'v'][3],
    sde=current_sde,
    weighting_type=[None, 'likelihood_weighting', 'flow_matching', 'sigmoid'][1],
    prior=prior,
    name_prefix='grid_'
)

print(score_model.name)
count_parameters(score_model)

# make dir for plots (exist_ok avoids the separate existence check and its race condition)
os.makedirs(f"plots/{score_model.name}", exist_ok=True)

In [None]:
# train model
loss_history = train_score_model(score_model, dataloader, dataloader_valid=dataloader_valid, hierarchical=True,
                                 epochs=500, device=torch_device, lr=1e-4)
score_model.eval()

# Make sure the checkpoint directory exists; torch.save does not create missing directories.
os.makedirs("models", exist_ok=True)
torch.save(score_model.state_dict(), f"models/{score_model.name}.pt")

# plot loss history (column 0: training loss, column 1: validation loss)
plt.figure(figsize=(16, 4), tight_layout=True)
plt.plot(loss_history[:, 0], label='Training', color="#132a70", lw=2.0, alpha=0.9)
plt.plot(loss_history[:, 1], label='Validation', linestyle="--", marker="o", color='black')
plt.grid(alpha=0.5)
plt.xlabel('Training epoch #')
plt.ylabel('Value')
plt.legend()
plt.savefig(f'plots/{score_model.name}/loss_training.png')

In [None]:
# Restore the trained weights. map_location remaps the stored tensors to the active device,
# so a checkpoint written on a GPU can also be loaded on a machine without one.
score_model.load_state_dict(
    torch.load(f"models/{score_model.name}.pt", map_location=torch_device, weights_only=True)
)
score_model.eval();

# Validation

In [None]:
# Validation settings: 20 synthetic data sets on an n_grid x n_grid grid, 100 posterior draws each.
n_grid = 8
n_post_samples = 100
param_names = [r'$\mu$', r'$\log \sigma$']

valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(
    prior,
    n_samples=20,
    grid_size=n_grid,
    n_time_points=10,
    normalize=False,
    random_seed=0,
)
#score_model.current_number_of_obs = 4  # we can choose here, how many observations are passed together through the score

In [None]:
# Visualize the synthetic validation data generated above.
visualize_simulation_output(valid_data)

In [None]:
# Global posterior draws via the Langevin sampler.
posterior_global_samples_valid = langevin_sampling(
    score_model,
    valid_data,
    n_post_samples=n_post_samples,
    diffusion_steps=300,
    langevin_steps=5,
    step_size_factor=0.05,
    device=torch_device,
    verbose=True,
)

In [None]:
# Recovery and calibration diagnostics for the Langevin draws; both figures are saved to disk.
true_global_valid = np.array(valid_prior_global)

fig = diagnostics.recovery(posterior_global_samples_valid, true_global_valid, variable_names=param_names)
fig.savefig(f'plots/{score_model.name}/recovery_global_langevin_sampler.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid, true_global_valid, difference=True, variable_names=param_names
)
fig.savefig(f'plots/{score_model.name}/ecdf_global_langevin_sampler.png')

In [None]:
# Global posterior draws via Euler-Maruyama with 1000 diffusion steps.
posterior_global_samples_valid = euler_maruyama_sampling(
    score_model,
    valid_data,
    n_post_samples=n_post_samples,
    diffusion_steps=1000,
    device=torch_device,
    verbose=True,
)

In [None]:
# Recovery and calibration diagnostics for the Euler-Maruyama draws; both figures are saved to disk.
true_global_valid = np.array(valid_prior_global)

fig = diagnostics.recovery(posterior_global_samples_valid, true_global_valid, variable_names=param_names)
fig.savefig(f'plots/{score_model.name}/recovery_global_euler_sampler.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid, true_global_valid, difference=True, variable_names=param_names
)
fig.savefig(f'plots/{score_model.name}/ecdf_global_euler_sampler.png')

In [None]:
mini_batch_size = 10
# Ratio of the mini-batch size to the number of score evaluations needed for the full grid.
t1_value = mini_batch_size / ((n_grid * n_grid) // score_model.current_number_of_obs)
t0_value = 1


def constant_damping_factor(t):
    """Return the constant damping factor 0.1, independent of the diffusion time `t`."""
    return 0.1  #t1_value


# Alternative damping schedules (kept for reference):
#'damping_factor': lambda t: t1_value + (t0_value - t1_value) * 0.5 * (1 + torch.cos(torch.pi * t)),
#'damping_factor': lambda t: t0_value + (t1_value - t0_value) * score_model.sde.kernel(log_snr=score_model.sde.get_snr(t))[1],
#'damping_factor': lambda t: t0_value * torch.exp(-np.log(t0_value / t1_value) * 2*t),
#'damping_factor': lambda t: t0_value + (t1_value - t0_value) * torch.sigmoid(20*(t-0.3))
mini_batch_arg = {
    'size': mini_batch_size,
    'damping_factor': constant_damping_factor,
}
#plt.plot(torch.linspace(0, 1, 100), mini_batch_arg['damping_factor'](torch.linspace(0, 1, 100)))
#plt.show()

t0_value, t1_value

In [None]:
# Global posterior draws via Euler-Maruyama using the mini-batch settings in mini_batch_arg.
posterior_global_samples_valid = euler_maruyama_sampling(
    score_model,
    valid_data,
    n_post_samples=n_post_samples,
    mini_batch_arg=mini_batch_arg,
    diffusion_steps=200,
    device=torch_device,
    verbose=True,
)

In [None]:
# Recovery and calibration diagnostics for the mini-batched Euler-Maruyama draws; figures are saved.
true_global_valid = np.array(valid_prior_global)

fig = diagnostics.recovery(posterior_global_samples_valid, true_global_valid, variable_names=param_names)
fig.savefig(f'plots/{score_model.name}/recovery_global_euler_sub_sampler.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid, true_global_valid, difference=True, variable_names=param_names
)
fig.savefig(f'plots/{score_model.name}/ecdf_global_euler_sub_sampler.png')

In [None]:
# Global posterior draws via the adaptive sampler (parallel sampling switched off).
posterior_global_samples_valid = adaptive_sampling(
    score_model,
    valid_data,
    n_post_samples,
    run_sampling_in_parallel=False,
    device=torch_device,
    verbose=True,
)

In [None]:
# Recovery and calibration diagnostics for the adaptive-sampler draws; both figures are saved to disk.
true_global_valid = np.array(valid_prior_global)

fig = diagnostics.recovery(posterior_global_samples_valid, true_global_valid, variable_names=param_names)
fig.savefig(f'plots/{score_model.name}/recovery_global_adaptive_sampler.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid, true_global_valid, difference=True, variable_names=param_names
)
fig.savefig(f'plots/{score_model.name}/ecdf_global_adaptive_sampler.png')

In [None]:
# Global posterior draws by solving the probability-flow ODE (parallel sampling switched off).
posterior_global_samples_valid = probability_ode_solving(
    score_model,
    valid_data,
    n_post_samples=n_post_samples,
    run_sampling_in_parallel=False,
    device=torch_device,
    verbose=True,
)

In [None]:
# Recovery and calibration diagnostics for the probability-flow ODE draws; both figures are saved.
true_global_valid = np.array(valid_prior_global)

fig = diagnostics.recovery(posterior_global_samples_valid, true_global_valid, variable_names=param_names)
fig.savefig(f'plots/{score_model.name}/recovery_global_ode.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid, true_global_valid, difference=True, variable_names=param_names
)
fig.savefig(f'plots/{score_model.name}/ecdf_global_ode.png')

In [None]:
# Condition the local sampler on the full set of global posterior draws.
# (Alternative that was toggled off: np.median(posterior_global_samples_valid, axis=0).)
conditions_global = posterior_global_samples_valid
posterior_local_samples_valid = euler_maruyama_sampling(
    score_model,
    valid_data,
    n_post_samples=n_post_samples,
    conditions=conditions_global,
    diffusion_steps=50,
    device=torch_device,
    verbose=True,
)

In [None]:
# Recovery plot for the local parameters: flatten the grid so each cell becomes one variable.
n_valid_datasets = valid_data.shape[0]
local_names_valid = [f'$\\theta_{{{i}}}$' for i in range(n_grid**2)]
diagnostics.recovery(
    posterior_local_samples_valid.reshape(n_valid_datasets, n_post_samples, -1),
    np.array(valid_prior_local).reshape(n_valid_datasets, -1),
    variable_names=local_names_valid,
);

In [None]:
# Local posterior draws via the probability-flow ODE, conditioned on all global posterior draws.
#conditions_global = np.median(posterior_global_samples_valid, axis=1)
posterior_local_samples_valid = probability_ode_solving(
    score_model,
    valid_data,
    n_post_samples=n_post_samples,
    run_sampling_in_parallel=False,
    conditions=posterior_global_samples_valid,
    device=torch_device,
    verbose=True,
)

In [None]:
# Recovery plot for the local parameters: flatten the grid so each cell becomes one variable.
n_valid_datasets = valid_data.shape[0]
local_names_valid = [f'$\\theta_{{{i}}}$' for i in range(n_grid**2)]
diagnostics.recovery(
    posterior_local_samples_valid.reshape(n_valid_datasets, n_post_samples, -1),
    np.array(valid_prior_local).reshape(n_valid_datasets, -1),
    variable_names=local_names_valid,
);

In [None]:
# Shrinkage plot built from the global and local posterior draws of the validation data.
plot_shrinkage(posterior_global_samples_valid, posterior_local_samples_valid)

In [None]:
# Inspect a single validation data set: data, posterior summary of the global parameters, truth.
valid_id = 2
global_draws_valid_id = posterior_global_samples_valid[valid_id]

print('Data')
visualize_simulation_output(valid_data[valid_id])

print('Global Estimates')
for col, label in enumerate(('mu:', 'log sigma:')):
    # posterior median followed by posterior standard deviation
    print(label, np.median(global_draws_valid_id[:, col]), np.std(global_draws_valid_id[:, col]))

print('True')
for col, label in enumerate(('mu:', 'log sigma:')):
    print(label, valid_prior_global[valid_id][col].item())

In [None]:
# Per-cell posterior summaries of the local parameters for the selected validation data set.
local_draws_valid_id = posterior_local_samples_valid[valid_id].reshape(n_post_samples, n_grid, n_grid)
med = np.median(local_draws_valid_id, axis=0)
std = np.std(local_draws_valid_id, axis=0)
error = (med-valid_prior_local[valid_id].numpy())**2  # squared error of the posterior median
visualize_simulation_output(np.stack((med, valid_prior_local[valid_id], )),
                            title_prefix=['Posterior Median', 'True'])

visualize_simulation_output(np.stack((std, error)), title_prefix=['Uncertainty', 'Error'], same_scale=False)


# Scatter of true vs. estimated local parameters with +/- 1.96 posterior standard deviations.
plt.figure(figsize=(4, 4), tight_layout=True)
plt.errorbar(x=valid_prior_local[valid_id].flatten(), y=med.flatten(), yerr=1.96*std.flatten(), fmt='o')
plt.plot([np.min(med), np.max(med)], [np.min(med), np.max(med)], 'k--')
# The horizontal line is the posterior *median* of the first global parameter (mu); the label
# previously said "mean", which did not match the plotted statistic.
plt.axhline(np.median(posterior_global_samples_valid[valid_id, :], axis=0)[0], color='red', linestyle='--',
            label='Global posterior median', alpha=0.75)
plt.ylabel('Prediction')
plt.xlabel('True')
plt.legend()
plt.show()

# Compare to STAN

In [None]:
# Stored Stan posteriors and the ground-truth parameters they were fitted to.
global_posterior_stan = np.load('problems/grid/global_posterior.npy')#[:, -100:]
local_posterior_stan = np.load('problems/grid/local_posterior.npy')#[:, -100:]
true_global = np.load('problems/grid/true_global.npy')
true_local = np.load('problems/grid/true_local.npy')

# The local parameters are stored flat; recover the side length of the square grid.
n_grid_stan = int(np.sqrt(true_local.shape[1]))

# Simulate one data set per ground-truth parameter draw and stack them along the first axis.
test_data = np.concatenate([
    simulator_test({
        'mu': g_true[0].reshape(1,1),
        'log_sigma': g_true[1].reshape(1,1),
        'theta': l_true.reshape(1, n_grid_stan, n_grid_stan)
    })['observable']
    for g_true, l_true in zip(true_global, true_local)
])
test_data.shape

In [None]:
# Damping endpoints; only referenced by the commented-out schedule inside mini_batch_arg.
t0_value = 1
t1_value = 0.01
mini_batch_arg = {
    #'size': 2,
    #'damping_factor': lambda t: t0_value * torch.exp(-np.log(t0_value / t1_value) * 2*t)
}
#score_model.sde.s_shift_cosine = 6
n_post_samples = 25

# Display names for the global and local parameters (score-based and Stan).
param_names = [r'$\mu$', r'$\log \sigma$']
local_param_names = [rf'$\theta_{{{i}}}$' for i in range(n_grid_stan**2)]
param_names_stan = [f'STAN {p}' for p in param_names]
score_model.max_number_of_obs

In [None]:
# Global posterior draws for the Stan comparison data via the adaptive sampler.
#score_model.sde.s_shift_cosine = 0
posterior_global_samples_test = adaptive_sampling(
    score_model,
    test_data,
    n_post_samples,
    mini_batch_arg=mini_batch_arg,
    run_sampling_in_parallel=False,
    device=torch_device,
    verbose=True,
)

In [None]:
# Local posterior draws for the Stan comparison data, conditioned on the global draws.
#score_model.sde.s_shift_cosine = 0
posterior_local_samples_test = euler_maruyama_sampling(
    score_model,
    test_data,
    n_post_samples,
    conditions=posterior_global_samples_test,
    #run_sampling_in_parallel=False,
    diffusion_steps=50,
    device=torch_device,
    verbose=True,
)

In [None]:
# Recovery of the global parameters: score-based posterior first, then Stan.
#diagnostics.recovery(posterior_global_samples_test, np.median(global_posterior_stan, axis=1),
#                     variable_names=param_names, xlabel='STAN Median Estimate')
for global_draws, names in ((posterior_global_samples_test, param_names),
                            (global_posterior_stan, param_names_stan)):
    diagnostics.recovery(global_draws, true_global, variable_names=names)

In [None]:
# Calibration (ECDF difference) of the global parameters: score-based posterior first, then Stan.
for global_draws, names in ((posterior_global_samples_test, param_names),
                            (global_posterior_stan, param_names_stan)):
    diagnostics.calibration_ecdf(global_draws, true_global, difference=True, variable_names=names)

In [None]:
# Shrinkage per data set: 1 - Var_cells(posterior median of theta) / sigma^2, where sigma^2 is
# exp(posterior median of log sigma)^2. Computed for the score-based posterior, Stan and the truth.

# Flatten the local draws to (n_datasets, n_post_samples, n_cells) first. Without this the result
# depends on whether the sampler returns a trailing singleton axis: dividing a (n,) variance by a
# (n, 1) global variance silently broadcasts to (n, n) and the later indexing picks wrong entries.
local_draws_flat = posterior_local_samples_test.reshape(test_data.shape[0], n_post_samples, -1)
global_var = np.exp(np.median(posterior_global_samples_test, axis=1)[:, 1]) ** 2
shrinkage = 1-np.var(np.median(local_draws_flat, axis=1), axis=1)/global_var

global_var_stan = np.exp(np.median(global_posterior_stan, axis=1)[:, 1])**2
shrinkage_stan = 1-np.var(np.median(local_posterior_stan, axis=1), axis=1)/global_var_stan

true_var = np.exp(true_global)[:, 1]**2
shrinkage_true = 1-np.var(true_local, axis=1)/true_var

# Sort all three curves by the true shrinkage so they can be compared in one plot.
s_order = np.argsort(shrinkage_true)
shrinkage = shrinkage.flatten()[s_order]
shrinkage_stan = shrinkage_stan.flatten()[s_order]
shrinkage_true = shrinkage_true[s_order]

# Clip very negative values from below so that they do not dominate the plot.
min_s = -10
shrinkage = np.maximum(shrinkage, min_s)
shrinkage_stan = np.maximum(shrinkage_stan, min_s)
shrinkage_true = np.maximum(shrinkage_true, min_s)

plt.title('Shrinkage')
plt.plot(shrinkage, label='score', alpha=0.75)
plt.plot(shrinkage_stan, label='STAN', alpha=0.75)
plt.plot(shrinkage_true, label='true', alpha=0.75)
plt.legend()
plt.show()

print('Correlation shrinkage score and STAN:', np.corrcoef(shrinkage, shrinkage_stan)[0, 1])
print('Correlation shrinkage true and score:', np.corrcoef(shrinkage_true, shrinkage)[0, 1])
print('Correlation shrinkage true and STAN:', np.corrcoef(shrinkage_true, shrinkage_stan)[0, 1])

print(f"Score shrinkage < STAN shrinkage: {(shrinkage < shrinkage_stan).sum() / shrinkage.shape[0]*100}%")

In [None]:
# Recovery of the local parameters: score-based posterior (flattened per data set) first, then Stan.
local_draws_flat_test = posterior_local_samples_test.reshape(test_data.shape[0], n_post_samples, -1)
diagnostics.recovery(local_draws_flat_test, true_local, variable_names=local_param_names)
diagnostics.recovery(local_posterior_stan, true_local, variable_names=local_param_names);

In [None]:
# Compare the score-based local posterior against the Stan posterior medians.
local_draws_flat_test = posterior_local_samples_test.reshape(test_data.shape[0], n_post_samples, -1)
stan_local_medians = np.median(local_posterior_stan, axis=1)
diagnostics.recovery(
    local_draws_flat_test,
    stan_local_medians,
    ylabel='Score Based Estimates',
    xlabel='STAN Median Estimate',
);

In [None]:
# Shrinkage plot for the first 10 test data sets (score-based posterior); a trailing axis is added
# to the local draws before they are passed to plot_shrinkage.
plot_shrinkage(posterior_global_samples_test[:10], posterior_local_samples_test[:, :, :, np.newaxis][:10], min_max=(-10,10))

In [None]:
# Shrinkage plot for the first 10 test data sets (Stan posterior); a trailing axis is added
# to the local draws before they are passed to plot_shrinkage.
plot_shrinkage(global_posterior_stan[:10], local_posterior_stan[:, :, :, np.newaxis][:10], min_max=(-10,10))

# Visualize the Score

In [None]:
# Regenerate a small synthetic validation set for the score visualisation.
# NOTE(review): this overwrites n_grid, valid_prior_global, valid_prior_local, valid_data and valid_id
# from the validation section above.
# NOTE(review): n_time_points is not passed here (the earlier call used n_time_points=10); the
# reshape to 10 rows in the score-field cell relies on the default matching -- confirm.
n_grid = 8
valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(prior, n_samples=10, grid_size=n_grid,
                                                                            normalize=False, random_seed=0)

# Index of the data set whose score field is visualized.
valid_id = 3

# Diffusion times at which the score field is evaluated (10 values).
diffusion_time = generate_diffusion_time(size=10, device=torch_device)
x_valid = valid_data[valid_id].to(torch_device)
x_valid_norm = score_model.prior.normalize_data(x_valid)
theta_global = score_model.prior.normalize_theta(valid_prior_global[valid_id], global_params=True).cpu().numpy()  # normalized, because the score model operates in normalized parameter space
print(valid_id, 'theta global', theta_global)

In [None]:
# Draw a single global posterior sample for the selected data set (t_end=0, fixed seed) and
# map it to normalized parameter space for comparison with theta_global.
test_sample = adaptive_sampling(score_model, x_valid[None], conditions=None, n_post_samples=1, #e_abs=0.00078,
                                t_end=0, random_seed=0, device=torch_device)
test_sample = score_model.prior.normalize_theta(torch.tensor(test_sample), global_params=True).cpu().numpy()
print(test_sample)

In [None]:
# Trace one posterior sample through diffusion time: rerun each sampler with the same seed and
# stop it at every diffusion time except the last entry of diffusion_time.
path_end_times = diffusion_time[:-1]

sde_path_raw = [
    adaptive_sampling(score_model, x_valid[None], conditions=None, n_post_samples=1, t_end=t_end,
                      random_seed=0, device=torch_device)
    for t_end in path_end_times
]
# we normalize as the score is normalized space
posterior_sample_path = score_model.prior.normalize_theta(
    torch.tensor(np.array(sde_path_raw)), global_params=True
).cpu().numpy()

ode_path_raw = [
    probability_ode_solving(score_model, x_valid[None], conditions=None, n_post_samples=1, t_end=t_end,
                            random_seed=0, device=torch_device)
    for t_end in path_end_times
]
# we normalize as the score is normalized space
posterior_sample_path2 = score_model.prior.normalize_theta(
    torch.tensor(np.array(ode_path_raw)), global_params=True
).cpu().numpy()

print('theta global', theta_global, posterior_sample_path[0].flatten())

In [None]:
# Evaluate the compositional global score and the probability-flow drift on a regular 2D grid of
# (normalized) global parameters, once per diffusion time.

# Define grid boundaries and resolution for your 2D space.
x_min, x_max, y_min, y_max = -1.5, 1.5, -1.5, 1.5
grid_res = 10  # Number of points per dimension

# Create a meshgrid of points
x_vals = np.linspace(x_min, x_max, grid_res)
y_vals = np.linspace(y_min, y_max, grid_res)
xx, yy = np.meshgrid(x_vals, y_vals)
# Stack into (N,2) where N = grid_res*grid_res
grid_points = np.vstack([xx.ravel(), yy.ravel()]).T

# Convert grid to a torch tensor and move to device
grid_tensor = torch.tensor(grid_points, dtype=torch.float32, device=torch_device)
# Flatten the spatial grid: 10 rows (presumably the time points -- TODO confirm) x grid cells,
# then repeat the data once per evaluation point.
x_valid_norm_e = x_valid_norm.reshape(10, -1).to(torch_device)
x_valid_norm_ext = x_valid_norm_e.unsqueeze(0).repeat(grid_tensor.shape[0], 1, 1)

# Dictionary to hold score outputs for each time
scores = {}
scores_smoothed = {}
drifts = {}
drifts_smoothed = {}

with torch.no_grad():
    # Evaluate the score model for each time value
    for t in diffusion_time:
        # Create a tensor of time values for each grid point
        t_tensor = torch.full((grid_tensor.shape[0], 1), t.item(), dtype=torch.float32, device=torch_device)
        epsilon = torch.randn_like(grid_tensor, dtype=torch.float32, device=torch_device)

        # perturb theta (currently disabled: the grid points are used directly as z)
        snr = score_model.sde.get_snr(t=t_tensor)
        alpha, sigma = score_model.sde.kernel(log_snr=snr)
        z = grid_tensor #alpha * grid_tensor + sigma * epsilon

        # Evaluate the score model
        # Allocate on the same device as the model outputs and prior scores; a CPU buffer would
        # make the subtraction below fail with a device mismatch when running on the GPU.
        score_indv = torch.zeros((x_valid_norm_ext.shape[2], grid_tensor.shape[0], 2), device=torch_device)

        prior_scores = (1 - t) * score_model.prior.score_global_batch(z)
        prior_scores_indv = prior_scores.unsqueeze(0)
        for i in range(x_valid_norm_ext.shape[2]):
            score_indv[i] = score_model.forward_global(theta_global=z, time=t_tensor, x=x_valid_norm_ext[:, :, i].unsqueeze(-1),
                                                       pred_score=True, clip_x=False)
        # Remove the prior score from every individual score ...
        score_indv = score_indv - prior_scores_indv

        # ... sum over grid cells and add the prior score back exactly once.
        score = score_indv.sum(axis=0)
        score = score + prior_scores
        scores[t.item()] = score.cpu().numpy()

        # score_pareto = torch.zeros_like(score)
        # for i in range(score_indv.shape[1]):
        #     score_pareto[i] = pareto_smooth_sum(score_indv[:, i].unsqueeze(0),
        #                                            tail_fraction=0.3)[0]  # expects dim to be the batch
        #
        # score_pareto = score_pareto + prior_scores
        # scores_smoothed[t.item()] = score_pareto.cpu().numpy()

        f, g = score_model.sde.get_f_g(x=z, t=t_tensor)
        # Use the score tensor directly: scores[t.item()] is a NumPy array, and mixing it with
        # torch tensors breaks on the GPU.
        drift = f - 0.5 * torch.square(g) * score
        drifts[t.item()] = drift.cpu().numpy()

        # drift_pareto = f - 0.5 * torch.square(g) * scores_smoothed[t.item()]
        # drifts_smoothed[t.item()] = drift_pareto.cpu().numpy()

In [None]:
# Plot the vector field (score) for each time step using subplots
# One panel per diffusion time, ordered from the largest to the smallest time.
nrows = 2
fig, axes = plt.subplots(nrows, len(diffusion_time) // nrows, sharex=True, sharey=True,
                         figsize=(15, 3*nrows), tight_layout=True)
axes = axes.flatten()

for i, (t_val, score_val) in enumerate(sorted(scores.items(), reverse=True)):
    # Reshape score components back to (grid_res, grid_res) for quiver plotting
    U = score_val[:, 0].reshape(grid_res, grid_res)
    V = score_val[:, 1].reshape(grid_res, grid_res)  # NOTE(review): an earlier comment mentioned negating V, but no sign flip is applied here

    ax = axes[i]

    h0, = ax.plot(0, 0, 'o', color='black', label='Latent Prior')
    h1, = ax.plot(theta_global[0], theta_global[1], 'ro', label='True Parameter')

    # The first panel shows no path; h2 and h4 (used in the legend below) are therefore only
    # defined once a later panel has been drawn.
    if i != 0:
        j = i
        h2, = ax.plot(posterior_sample_path[:-j, 0, 0, 0], posterior_sample_path[:-j, 0, 0, 1], 'o-', label='Posterior Path Sampling', alpha=0.5)
        h4, = ax.plot(posterior_sample_path2[:-j, 0, 0, 0], posterior_sample_path2[:-j, 0, 0, 1], 'o-', label='Posterior Path ODE', alpha=0.5)
    # Arrow lengths are scaled down by 5 * n_grid^2.
    h5 = ax.quiver(xx, yy, U, V, color='blue', angles='xy', scale_units='xy', scale=5*n_grid*n_grid, alpha=.75, label='Score')
    ax.set_title(f"Diffusion t = {t_val:.3f}")
    ax.set_xlim(x_min-0.5, x_max+0.5)
    ax.set_ylim(y_min-0.5, y_max+0.5)
    ax.set_aspect('equal')
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    #ax.legend()
    #print(posterior_sample_path[i, 0, 0], posterior_sample_path[i, 0, 1])
# Shared legend built from the handles of the last panel.
fig.legend(handles=[h5, h0, h1, h2, h4], loc='lower center', ncols=5, bbox_to_anchor=(0.5, -0.05))
plt.savefig(f'plots/{score_model.name}/score.png', bbox_inches='tight')
plt.show()
print('theta global', theta_global, posterior_sample_path[0].flatten())

In [None]:
# Plot the vector field (probability-flow drift) for each time step using subplots
# One panel per diffusion time, ordered from the largest to the smallest time.
nrows = 2
fig, axes = plt.subplots(nrows, len(diffusion_time) // nrows, sharex=True, sharey=True,
                         figsize=(15, 3*nrows), tight_layout=True)
axes = axes.flatten()

for i, (t_val, score_val) in enumerate(sorted(drifts.items(), reverse=True)):
    # Reshape drift components back to (grid_res, grid_res) for quiver plotting
    U = score_val[:, 0].reshape(grid_res, grid_res)
    V = score_val[:, 1].reshape(grid_res, grid_res)

    ax = axes[i]
    h0, = ax.plot(0, 0, 'o', color='black', label='Latent Prior')
    h1, = ax.plot(theta_global[0], theta_global[1], 'ro', label='True Parameter')

    # Show the tail of the stored paths starting at index j; for the first panel the slice is empty.
    # NOTE(review): the score plot slices the paths as [:-j] instead -- confirm which end of the
    # path corresponds to the current diffusion time.
    j = len(diffusion_time)-i-1
    h2, = ax.plot(posterior_sample_path[j:, 0, 0, 0], posterior_sample_path[j:, 0, 0, 1], 'o-', label='Posterior Path Sampling', alpha=0.5)
    h4, = ax.plot(posterior_sample_path2[j:, 0, 0, 0], posterior_sample_path2[j:, 0, 0, 1], 'o-', label='Posterior Path ODE', alpha=0.5)
    # Arrow lengths are scaled down by 5 * n_grid^2.
    h5 = ax.quiver(xx, yy, U, V, color='blue', angles='xy', scale_units='xy', scale=5*n_grid*n_grid, alpha=.75,
                   label='Probability Flow')
    ax.set_title(f"Diffusion t = {t_val:.3f}")
    ax.set_xlim(x_min-0.5, x_max+0.5)
    ax.set_ylim(y_min-0.5, y_max+0.5)
    ax.set_aspect('equal')
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    #ax.legend()
    #print(posterior_sample_path[i, 0, 0], posterior_sample_path[i, 0, 1])
# Shared legend built from the handles of the last panel.
fig.legend(handles=[h5, h0, h1, h2, h4], loc='lower center', ncols=5, bbox_to_anchor=(0.5, -0.05))
plt.savefig(f'plots/{score_model.name}/drift.png', bbox_inches='tight')
plt.show()
print('theta global', theta_global, posterior_sample_path[0].flatten())