# FLI with compositional score matching


In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from scipy.special import expit

os.environ['KERAS_BACKEND'] = 'torch'
from bayesflow import diagnostics

from torch.utils.data import DataLoader

from diffusion_model import HierarchicalScoreModel, SDE, weighting_function, euler_maruyama_sampling, adaptive_sampling, \
    generate_diffusion_time, count_parameters, train_score_model
from diffusion_model.helper_networks import LSTM
from problems.fli import FLIProblem, FLI_Prior, generate_synthetic_data
from problems import plot_shrinkage, visualize_simulation_output

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Prior object that is handed to the datasets, the score model and the synthetic-data generator in later cells.
prior = FLI_Prior()

In [4]:
# Training configuration
batch_size = 128
number_of_obs = [1, 2, 4, 5, 10]

# Forward-process settings.
#   kernel_type options:    'variance_preserving', 'sub_variance_preserving'
#   noise_schedule options: 'linear', 'cosine', 'flow_matching'
current_sde = SDE(kernel_type='variance_preserving', noise_schedule='cosine')

# Keyword arguments common to the training and the validation problem.
shared_problem_kwargs = dict(prior=prior, sde=current_sde, number_of_obs=number_of_obs)

dataset = FLIProblem(n_data=10000, online_learning=True, **shared_problem_kwargs)
dataset_valid = FLIProblem(n_data=1000, **shared_problem_kwargs)

# Wrap both problems in loaders; only the training loader shuffles.
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
dataloader_valid = DataLoader(dataset_valid, shuffle=False, batch_size=batch_size)

Kernel type: variance_preserving, noise schedule: cosine
t_min: 0.00035210439818911254, t_max: 0.999647855758667
alpha, sigma: (tensor(1.0000), tensor(0.0006)) (tensor(0.0006), tensor(1.0000))
Moving prior to device: cpu


In [7]:
# Define diffusion model
summary_net = LSTM(input_size=1, hidden_dim=256)

score_model = HierarchicalScoreModel(
    input_dim_theta_global=prior.n_params_global,
    input_dim_theta_local=prior.n_params_local,
    input_dim_x=256,
    summary_net=summary_net,
    hidden_dim=512,
    n_blocks=5,
    # an int is passed through unchanged; for a list the largest entry is used
    max_number_of_obs=number_of_obs if isinstance(number_of_obs, int) else max(number_of_obs),
    prediction_type='v',  # options: 'score', 'e', 'x', 'v'
    sde=current_sde,
    weighting_type='likelihood_weighting',  # options: None, 'likelihood_weighting', 'flow_matching', 'sigmoid'
    prior=prior,
    name_prefix='FLI_'
)
print(score_model.name)
count_parameters(score_model)

# make dir for plots; exist_ok avoids the race in a separate exists-then-create check
os.makedirs(f"plots/{score_model.name}", exist_ok=True)

+--------------------------------------+------------+
|               Modules                | Parameters |
+--------------------------------------+------------+
|   blocks.res_blocks.0.dense.weight   |   266240   |
|    blocks.res_blocks.0.dense.bias    |    512     |
| blocks.res_blocks.0.projector.weight |   266240   |
|   blocks.res_blocks.1.dense.weight   |   262144   |
|    blocks.res_blocks.1.dense.bias    |    512     |
|   blocks.res_blocks.2.dense.weight   |   262144   |
|    blocks.res_blocks.2.dense.bias    |    512     |
|   blocks.res_blocks.3.dense.weight   |   262144   |
|    blocks.res_blocks.3.dense.bias    |    512     |
|   blocks.res_blocks.4.dense.weight   |   262144   |
|    blocks.res_blocks.4.dense.bias    |    512     |
|    final_projection_linear.weight    |    3072    |
|     final_projection_linear.bias     |     6      |
+--------------------------------------+------------+
Total Trainable Params: 1586694
FLI_global_score_model_v_variance_preserving_cosin

In [None]:
# Create the checkpoint folder before training, so the save after the long run
# cannot fail on a missing directory and lose the trained weights.
os.makedirs("models", exist_ok=True)

# train model
loss_history = train_score_model(
    score_model,
    dataloader,
    dataloader_valid=dataloader_valid,
    hierarchical=True,
    epochs=2000,
    device=torch_device,
)
score_model.eval()
torch.save(score_model.state_dict(), f"models/{score_model.name}.pt")

# plot loss history: column 0 is drawn as the training curve, column 1 as the validation curve
plt.figure(figsize=(16, 4), tight_layout=True)
plt.plot(loss_history[:, 0], label='Training', color="#132a70", lw=2.0, alpha=0.9)
plt.plot(loss_history[:, 1], label='Validation', linestyle="--", marker="o", color='black')
plt.grid(alpha=0.5)
plt.xlabel('Training epoch #')
plt.ylabel('Value')
plt.legend()
plt.savefig(f'plots/{score_model.name}/loss_training.png')

Training v-model for 2000 epochs with learning rate 0.0005 and likelihood_weighting weighting.
Epoch 24/2000, Loss: 14.2251, Valid Loss: 14.9868

In [None]:
# Restore the trained weights. map_location remaps tensors that were saved on another
# device (e.g. a GPU checkpoint opened on a CPU-only machine) onto the current device.
score_model.load_state_dict(
    torch.load(f"models/{score_model.name}.pt", map_location=torch_device, weights_only=True)
)
score_model.eval();

In [None]:
# check the error prediction: is it close to the noise?
# Diagnostic sweep over diffusion time: for every time returned by generate_diffusion_time,
# the model is evaluated on perturbed validation parameters and several squared-error losses
# are accumulated. Each dict maps a diffusion time (as a Python float) to the loss SUMMED over
# validation batches (per-batch means are added up, they are not averaged over batches).
loss_list_target = {}
loss_list_score = {}
loss_list_error_w_global = {}
loss_list_error_w_local = {}
loss_list_error = {}

with torch.no_grad():
    # Generate the diffusion times to evaluate
    diffusion_time = generate_diffusion_time(size=100, device=torch_device)
    for t in diffusion_time:
        # start every accumulator for this diffusion time at zero
        loss_list_target[t.item()] = 0
        loss_list_score[t.item()] = 0
        loss_list_error_w_global[t.item()] = 0
        loss_list_error_w_local[t.item()] = 0
        loss_list_error[t.item()] = 0

        # the 2nd and 4th items yielded by the loader are not needed here
        for theta_global_batch, _, theta_local_batch, _, x_batch in dataloader_valid:
            theta_global_batch = theta_global_batch.to(torch_device)
            theta_local_batch = theta_local_batch.to(torch_device)
            x_batch = x_batch.to(torch_device)

            # sample from the Gaussian kernel, just learn the noise
            epsilon_global = torch.randn_like(theta_global_batch, dtype=torch.float32, device=torch_device)
            epsilon_local = torch.randn_like(theta_local_batch, dtype=torch.float32, device=torch_device)

            # same diffusion time for every element of the batch, shape (batch, 1)
            t_tensor = torch.full((theta_global_batch.shape[0], 1), t,
                                  dtype=torch.float32, device=torch_device)
            # get_snr's result is passed on as `log_snr`; the kernel returns the (alpha, sigma) coefficients
            snr = score_model.sde.get_snr(t=t_tensor)
            alpha, sigma = score_model.sde.kernel(log_snr=snr)
            # perturb the global theta batch
            z_global = alpha * theta_global_batch + sigma * epsilon_global
            if score_model.max_number_of_obs > 1:
                # global params are not factorized to the same level as local params
                # NOTE(review): presumably theta_local_batch is (batch, n_obs, dim) here, so an
                # extra axis is inserted for broadcasting -- confirm against FLIProblem
                alpha_local = alpha.unsqueeze(1)
                sigma_local = sigma.unsqueeze(1)
            else:
                alpha_local = alpha
                sigma_local = sigma
            # perturb the local theta batch
            z_local = alpha_local * theta_local_batch + sigma_local * epsilon_local

            # predict from perturbed theta
            # NOTE(review): with pred_score=False the outputs are treated below as noise
            # predictions, with pred_score=True as score predictions -- confirm in HierarchicalScoreModel
            pred_epsilon_global, pred_epsilon_local = score_model(theta_global=z_global, theta_local=z_local,
                                       time=t_tensor, x=x_batch, pred_score=False)
            pred_score_global, pred_score_local = score_model(theta_global=z_global, theta_local=z_local,
                                     time=t_tensor, x=x_batch, pred_score=True)
            # reference score from the SDE, evaluated at the perturbed sample given the clean sample
            true_score_global = score_model.sde.grad_log_kernel(x=z_global,
                                                                x0=theta_global_batch, t=t_tensor)
            if score_model.max_number_of_obs == 1:
                true_score_local = score_model.sde.grad_log_kernel(x=z_local,
                                                                   x0=theta_local_batch,
                                                                   t=t_tensor)
            else:
                # evaluate the kernel score separately for each index along dim 1, then re-stack on dim 1
                true_score_local = []
                for i in range(score_model.max_number_of_obs):
                    score_local = score_model.sde.grad_log_kernel(x=z_local[:, i],
                                                                   x0=theta_local_batch[:, i],
                                                                   t=t_tensor)
                    true_score_local.append(score_local.unsqueeze(1))
                true_score_local = torch.concatenate(true_score_local, dim=1)

            # build the target and the model's implied prediction in the parameterization
            # selected by score_model.prediction_type
            if score_model.prediction_type == 'score':
                target_global = -epsilon_global / sigma
                pred_target_global = -pred_epsilon_global / sigma
                target_local = -epsilon_local / sigma_local
                pred_target_local = -pred_epsilon_local / sigma_local
            elif score_model.prediction_type == 'e':
                target_global = epsilon_global
                pred_target_global = pred_epsilon_global
                target_local = epsilon_local
                pred_target_local = pred_epsilon_local
            elif score_model.prediction_type == 'v':
                # note: the predicted target combines the predicted noise with the TRUE theta
                target_global = alpha*epsilon_global - sigma * theta_global_batch
                pred_target_global = alpha*pred_epsilon_global - sigma * theta_global_batch
                target_local = alpha_local*epsilon_local - sigma_local * theta_local_batch
                pred_target_local = alpha_local*pred_epsilon_local - sigma_local * theta_local_batch
            elif score_model.prediction_type == 'x':
                # invert z = alpha * theta + sigma * eps using the predicted noise
                target_global = theta_global_batch
                pred_target_global = (z_global - pred_epsilon_global * sigma) / alpha
                target_local = theta_local_batch
                pred_target_local = (z_local - pred_epsilon_local * sigma_local) / alpha_local
            else:
                raise ValueError("Invalid prediction type.")

            # calculate the loss (sum over the last dimension, mean over the batch)
            loss_global = torch.mean(torch.sum(torch.square(pred_target_global - target_global), dim=-1))
            loss_local = torch.mean(torch.sum(torch.square(pred_target_local - target_local), dim=-1))
            loss = loss_global + loss_local
            loss_list_target[t.item()] += loss.item()

            # calculate the error of the true score
            loss_global = torch.mean(torch.sum(torch.square(pred_score_global - true_score_global), dim=-1))
            loss_local = torch.mean(torch.sum(torch.square(pred_score_local - true_score_local), dim=-1))
            loss = loss_global + loss_local
            loss_list_score[t.item()] += loss.item()

            # calculate the weighted loss (global and local parts are tracked separately)
            # NOTE(review): the shape of w is not visible here; check how it broadcasts against
            # the per-sample sums, which have one dimension fewer than the predictions
            w = weighting_function(t_tensor, sde=score_model.sde,
                                   weighting_type=score_model.weighting_type, prediction_type=score_model.prediction_type)
            loss_global = torch.mean(w * torch.sum(torch.square(pred_epsilon_global - epsilon_global), dim=-1))
            loss_local = torch.mean(w * torch.sum(torch.square(pred_epsilon_local - epsilon_local), dim=-1))
            #loss = loss_global + loss_local
            loss_list_error_w_global[t.item()] += loss_global.item()
            loss_list_error_w_local[t.item()] += loss_local.item()

            # check if the weighting function is correct: the same error without the weight
            loss_global = torch.mean(torch.sum(torch.square(pred_epsilon_global - epsilon_global), dim=-1))
            loss_local = torch.mean(torch.sum(torch.square(pred_epsilon_local - epsilon_local), dim=-1))
            loss = loss_global + loss_local
            loss_list_error[t.item()] += loss.item()

In [None]:
# One two-column frame (diffusion time, accumulated loss) per loss dictionary.
df_target = pd.DataFrame(loss_list_target.items(), columns=['Time', 'Loss'])
df_score = pd.DataFrame(loss_list_score.items(), columns=['Time', 'Loss'])
df_error_w_local = pd.DataFrame(loss_list_error_w_local.items(), columns=['Time', 'Loss'])
df_error_w_global = pd.DataFrame(loss_list_error_w_global.items(), columns=['Time', 'Loss'])
df_error = pd.DataFrame(loss_list_error.items(), columns=['Time', 'Loss'])

# compute snr (the notebook passes this quantity on as `log_snr`, hence the 'log snr' label below)
snr = score_model.sde.get_snr(diffusion_time)
#upper_bound_loss = (np.sqrt(2) + 1) / (std.numpy()**2)

fig, ax = plt.subplots(ncols=4, sharex=True, figsize=(16, 3), tight_layout=True)
ax[0].plot(df_target['Time'], np.log(df_target['Loss']), label=f'Unscaled {score_model.prediction_type} Loss')
ax[1].plot(df_score['Time'], np.log(df_score['Loss']), label='Score Loss')
#ax[1].plot(df_score['Time'], df_score['Loss'] / upper_bound_loss, label='Score Loss')
# diffusion_time and snr live on torch_device; matplotlib cannot convert CUDA tensors,
# so they are copied to the CPU first (as the weighting plot at the end of this cell does).
ax[1].plot(diffusion_time.cpu(), snr.cpu(), label='log snr', alpha=0.5)
ax[2].plot(df_error_w_local['Time'], np.log(df_error_w_local['Loss']), label='Local Weighted Loss')
ax[2].plot(df_error_w_global['Time'], np.log(df_error_w_global['Loss']), label='Global Weighted Loss')
ax[2].plot(df_error_w_local['Time'], np.log(df_error_w_local['Loss']+df_error_w_global['Loss']), label='Weighted Loss (as in Optimization)')
ax[3].plot(df_error['Time'], np.log(df_error['Loss']), label='Loss on Error')
for a in ax:
    a.set_xlabel('Diffusion Time')
    a.set_ylabel('Log Loss')
    a.legend()
plt.savefig(f'plots/{score_model.name}/losses_diffusion_time.png')
plt.show()

# Weighting function evaluated at the same diffusion times.
plt.figure(figsize=(6, 3), tight_layout=True)
plt.plot(diffusion_time.cpu(),
         weighting_function(diffusion_time, sde=score_model.sde, weighting_type=score_model.weighting_type,
                            prediction_type=score_model.prediction_type).cpu(),
         label='weighting')
plt.xlabel('Diffusion Time')
plt.ylabel('Weight')
plt.legend()
plt.show()

# Validation

In [None]:
# Settings for the validation run.
n_grid = 4
n_post_samples = 20
n_local = n_grid * n_grid  # number of local samples requested per data set

# Draw a fixed (seeded) synthetic validation set from the prior.
valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(
    prior=prior,
    n_data=100,
    n_local_samples=n_local,
    as_grid=True,
    random_seed=0,
)

# Names passed to the diagnostic plots as variable labels.
global_param_names = prior.global_param_names
local_param_names = prior.get_local_param_names(n_local)
#score_model.current_number_of_obs = 4  # we can choose here, how many observations are passed together through the score

In [None]:
mini_batch_size = 10
t1_value = 0.1 #/( (n_grid*n_grid) //score_model.current_number_of_obs)
t0_value = 1


def _exponential_damping(t):
    """Exponentially decaying damping factor.

    Equals ``t0_value`` at ``t = 0`` and, because the exponent uses ``2*t``,
    reaches ``t1_value`` at ``t = 0.5``. ``t0_value`` and ``t1_value`` are read
    from the notebook namespace at call time.
    """
    return t0_value * torch.exp(-np.log(t0_value / t1_value) * 2*t)


# Other damping schedules that were tried:
#'damping_factor': lambda t: t1_value + (t0_value - t1_value) * 0.5 * (1 + torch.cos(torch.pi * t)),
#'damping_factor': lambda t: t0_value + (t1_value - t0_value) * score_model.sde.kernel(log_snr=score_model.sde.get_snr(t))[1],
#'damping_factor': lambda t: 0.1, #t1_value,
#'damping_factor': lambda t: t0_value + (t1_value - t0_value) * torch.sigmoid(20*(t-0.3))
mini_batch_arg = {
    'size': mini_batch_size,
    'damping_factor': _exponential_damping,
    'damping_factor_prior': _exponential_damping,
}
#plt.plot(torch.linspace(0, 1, 100), mini_batch_arg['damping_factor'](torch.linspace(0, 1, 100)))
#plt.show()

t0_value, t1_value

In [None]:
# Set the SDE's `s_shift_cosine` attribute to 4 for the global sampler
# (presumably a shift of the cosine noise schedule -- see the SDE class).
score_model.sde.s_shift_cosine = 4

# Draw global posterior samples for every validation data set with the adaptive sampler.
posterior_global_samples_valid = adaptive_sampling(
    score_model,
    valid_data,
    n_post_samples,
    device=torch_device,
    mini_batch_arg=mini_batch_arg,
    run_sampling_in_parallel=False,
    verbose=True,
)

In [None]:
# Recovery and calibration diagnostics for the global posterior; both figures are
# written to the model's plot folder.
true_global = np.array(valid_prior_global)
for make_plot, extra_kwargs, stem in (
    (diagnostics.recovery, {}, 'recovery'),
    (diagnostics.calibration_ecdf, {'difference': True}, 'ecdf'),
):
    fig = make_plot(posterior_global_samples_valid, true_global,
                    variable_names=global_param_names, **extra_kwargs)
    fig.savefig(f'plots/{score_model.name}/{stem}_global_adaptive_sampler.png')

In [None]:
# Condition the local sampler on the full set of global posterior draws.
# The previous `(median, samples)[1]` toggle always selected the samples but still
# computed np.median(posterior_global_samples_valid, axis=0) only to discard it.
conditions_global = posterior_global_samples_valid

# The local sampler runs with the `s_shift_cosine` attribute set back to 0.
score_model.sde.s_shift_cosine = 0
posterior_local_samples_valid = euler_maruyama_sampling(score_model, valid_data,
                                                        n_post_samples=n_post_samples, conditions=conditions_global,
                                                        diffusion_steps=50, device=torch_device, verbose=True)

In [None]:
# Recovery plot for the local parameters: everything after the (data set, sample) axes of the
# draws, and after the data-set axis of the ground truth, is flattened into one variable axis.
n_valid = valid_data.shape[0]
local_draws = posterior_local_samples_valid.reshape(n_valid, n_post_samples, -1)
local_truth = np.array(valid_prior_local).reshape(n_valid, -1)
diagnostics.recovery(local_draws, local_truth, variable_names=local_param_names);

In [None]:
# Shrinkage plot for the first 12 entries along the leading axis of the global and local posterior samples.
plot_shrinkage(posterior_global_samples_valid[:12], posterior_local_samples_valid[:12], min_max=(-10, 10))

In [None]:
# Compare posterior summaries (median, std over samples) with the true values
# for a single validation data set.
valid_id = 0
global_labels = ('mu', 'log sigma')

print('Global Estimates')
for idx, label in enumerate(global_labels):
    samples = posterior_global_samples_valid[valid_id, :, idx]
    print(f'{label}:', np.median(samples), np.std(samples))

print('True')
for idx, label in enumerate(global_labels):
    print(f'{label}:', valid_prior_global[valid_id][idx].item())

In [None]:
def _transform_local(params):
    """Return a transformed copy of ``params`` (last axis holds the three local parameters).

    Entry 0 is exponentiated, entry 1 becomes (transformed entry 0) + exp(entry 1),
    and entry 2 is passed through the logistic sigmoid. The order matters because
    entry 1 uses the already transformed entry 0.
    """
    out = params.copy()
    out[..., 0] = np.exp(out[..., 0])
    out[..., 1] = out[..., 0] + np.exp(out[..., 1])
    out[..., 2] = expit(out[..., 2])
    return out


def _titles(prefix):
    """Prefix every transformed parameter name with ``prefix``."""
    return [prefix + name for name in transf_local_param_names]


# Posterior draws and ground truth for the selected data set, on the transformed scale.
ps = _transform_local(posterior_local_samples_valid[valid_id].reshape(n_post_samples, n_grid, n_grid, 3))
true = _transform_local(valid_prior_local[valid_id].numpy())
transf_local_param_names = [r'$\tau_1^L$', r'$\tau_2^L$', r'$A^L$']

# Summaries over the posterior-sample axis, plus the squared error of the median.
med = np.median(ps, axis=0)
std = np.std(ps, axis=0)
error = (med-true)**2

visualize_simulation_output(med, title_prefix=_titles('Posterior Median '), cmap='turbo')
visualize_simulation_output(true, title_prefix=_titles('True '), cmap='turbo')

visualize_simulation_output(std, title_prefix=_titles('Posterior Std '))
visualize_simulation_output(error, title_prefix=_titles('Error '))

In [None]:
# Per local parameter: posterior median (with 1.96 * std error bars) against the true value.
fig, ax = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)
for i in range(3):
    ax[i].errorbar(x=true[:, :, i].flatten(), y=med[:, :, i].flatten(), yerr=1.96*std[:, :, i].flatten(), fmt='o')
    #ax[i].plot([np.min(true[:, :, i]), np.max(true[:, :, i])], [np.min(true[:, :, i]), np.max(true[:, :, i])], 'k--')
    # The reference line is the MEDIAN of the global posterior samples, so the legend says so
    # (it was previously labelled "mean").
    # NOTE(review): the global samples are used as drawn while `true`/`med` were transformed in the
    # previous cell, and index i assumes at least 3 global parameters -- confirm both are intended.
    ax[i].axhline(np.median(posterior_global_samples_valid[valid_id, :, i], axis=0), color='red', linestyle='--',
                  label='Global posterior median', alpha=0.75)
    ax[i].set_ylabel('Prediction')
    ax[i].set_xlabel('True')
    ax[i].legend()
plt.show()