# Gaussian on a Grid Test for Hierarchical ABI with compositional score matching

In [None]:
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from bayesflow import diagnostics
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from gaussian_test_simulator import Prior, Simulator, visualize_simulation_output, generate_synthetic_data

In [None]:
# Device on which the prior, the model and all tensors of this notebook live (CPU here).
torch_device = torch.device("cpu")

In [None]:
# Build the prior on the chosen device and a simulator instance.
prior = Prior(torch_device)
simulator_test = Simulator()

# Smoke test: draw one full parameter sample from the prior, simulate from it,
# and plot the 'observable' entry of the simulator output.
sim_test = simulator_test(prior.sample_full(1))['observable']
visualize_simulation_output(sim_test)

In [None]:
class GaussianFourierProjection(nn.Module):
    """Random Fourier feature encoding for diffusion time steps.

    A frequency vector of length ``embed_dim // 2`` is drawn once from a
    zero-mean Gaussian scaled by ``scale``. It is stored as a parameter with
    ``requires_grad=False``, so it is part of the state dict but never trained.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Return ``sin`` and ``cos`` of ``2 * pi * x * W`` concatenated on the last axis."""
        # Multiplying by 2 is exact in floating point, so grouping the
        # constant factor does not change the result.
        angles = (x * self.W) * (2 * np.pi)
        features = (torch.sin(angles), torch.cos(angles))
        return torch.cat(features, dim=-1)


# Define the Residual Block
class ConditionalResidualBlock(nn.Module):
    """Residual MLP block that concatenates the conditioning vector to both linear layers.

    Args:
        in_dim (int): Dimensionality of the input hidden state.
        out_dim (int): Desired dimensionality of the output.
        cond_dim (int): Dimensionality of the conditioning vector.
        dropout_rate (float): Dropout rate applied after the first activation.
        use_spectral_norm (bool): Whether to wrap the linear layers (including
            the skip projection, if present) in spectral normalization.
    """

    def __init__(self, in_dim, out_dim, cond_dim, dropout_rate, use_spectral_norm):
        super().__init__()

        # [h; cond] -> out_dim, followed by layer norm
        self.fc1 = nn.Linear(in_dim + cond_dim, out_dim)
        self.norm1 = nn.LayerNorm(out_dim)

        # [hidden; cond] -> out_dim, followed by layer norm
        self.fc2 = nn.Linear(out_dim + cond_dim, out_dim)
        self.norm2 = nn.LayerNorm(out_dim)

        self.activation = nn.SiLU()  # SiLU is equivalent to swish
        self.dropout = nn.Dropout(dropout_rate)

        # The shortcut only needs a projection when the widths differ.
        self.skip = nn.Linear(in_dim, out_dim) if in_dim != out_dim else None

        if not use_spectral_norm:
            return
        spectral_norm = nn.utils.parametrizations.spectral_norm
        self.fc1 = spectral_norm(self.fc1)
        self.fc2 = spectral_norm(self.fc2)
        if self.skip is not None:
            self.skip = spectral_norm(self.skip)

    def forward(self, h, cond):
        """Map ``h`` ([batch, in_dim]) and ``cond`` ([batch, cond_dim]) to [batch, out_dim]."""
        # First layer sees the hidden state together with the condition.
        hidden = self.fc1(torch.cat([h, cond], dim=-1))
        hidden = self.dropout(self.activation(self.norm1(hidden)))

        # The condition is injected a second time before the second layer.
        hidden = self.norm2(self.fc2(torch.cat([hidden, cond], dim=-1)))

        # Residual connection (projected if the dimensions differ).
        shortcut = h if self.skip is None else self.skip(h)
        return self.activation(hidden + shortcut)

class FiLMResidualBlock(nn.Module):
    """Residual block whose hidden features are modulated by the condition (FiLM).

    The condition is mapped to a per-feature scale and shift that are applied
    to the linearly transformed hidden state.

    Args:
        in_dim (int): Dimensionality of the input hidden state.
        out_dim (int): Dimensionality of the output.
        cond_dim (int): Dimensionality of the conditioning vector.
        dropout_rate (float): Dropout rate applied after the first activation.
        use_spectral_norm (bool): Whether to apply spectral normalization to
            the main linear layer and to the skip projection (if present).
            The two modulation layers are never spectrally normalized.
    """

    def __init__(self, in_dim, out_dim, cond_dim, dropout_rate, use_spectral_norm):
        super().__init__()
        needs_projection = in_dim != out_dim

        self.fc = nn.Linear(in_dim, out_dim)
        self.film_gamma = nn.Linear(cond_dim, out_dim)
        self.film_beta = nn.Linear(cond_dim, out_dim)
        self.activation = nn.SiLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.skip = nn.Linear(in_dim, out_dim) if needs_projection else nn.Identity()
        self.norm = nn.LayerNorm(out_dim)

        if use_spectral_norm:
            spectral_norm = nn.utils.parametrizations.spectral_norm
            self.fc = spectral_norm(self.fc)
            if needs_projection:
                self.skip = spectral_norm(self.skip)

    def forward(self, h, cond):
        """Map ``h`` ([batch, in_dim]) and ``cond`` ([batch, cond_dim]) to [batch, out_dim]."""
        features = self.fc(h)
        # Feature-wise scale and shift computed from the condition.
        scale, shift = self.film_gamma(cond), self.film_beta(cond)
        features = self.dropout(self.activation(scale * features + shift))
        # Normalize, add the (possibly projected) shortcut and activate once more.
        return self.activation(self.norm(features) + self.skip(h))


class ScoreModel(nn.Module):
    """
        Conditional residual network that predicts a diffusion target for theta
        and can convert that prediction into a score estimate.

        The conditioning vector is the concatenation of x, the optional extra
        condition and a learned embedding of the diffusion time; it is passed
        to every residual block.

        Args:
            input_dim_theta (int): Input dimension for theta.
            input_dim_x (int): Input dimension for x.
            input_dim_condition (int): Input dimension for the condition. Can be 0 for global score.
            hidden_dim (int): Hidden dimension for theta network.
            n_blocks (int): Number of residual blocks. The last block maps back to input_dim_theta.
            marginal_prob_std_fn (callable): Called as ``marginal_prob_std_fn(t)``; must return the pair
                ``(alpha, sigma)`` that is used to convert 'x', 'v' and 'e' predictions into a score.
            prediction_type (str): Type of prediction to perform. Can be 'score', 'e', 'x', or 'v'.
            time_embed_dim (int): Dimension of time embedding. The Fourier projection produces
                2 * (time_embed_dim // 2) features that feed a Linear(time_embed_dim, ...) layer,
                so this should be even.
            use_film (bool): Whether to use FiLM-residual blocks.
            dropout_rate (float): Dropout rate.
            use_spectral_norm (bool): Whether to use spectral normalization.

        Raises:
            ValueError: If prediction_type is not one of 'score', 'e', 'x', 'v'.
    """
    def __init__(self,
                 input_dim_theta, input_dim_x, input_dim_condition,
                 hidden_dim, n_blocks, marginal_prob_std_fn, prediction_type,
                 time_embed_dim, use_film, dropout_rate, use_spectral_norm):
        super(ScoreModel, self).__init__()
        self.marginal_prob_std_fn = marginal_prob_std_fn
        if prediction_type not in ['score', 'e', 'x', 'v']:
            raise ValueError("Invalid prediction type. Must be one of 'score', 'e', 'x', or 'v'.")
        self.prediction_type = prediction_type

        # Time embedding: fixed Gaussian random Fourier features followed by
        # a trainable linear layer and a SiLU activation.
        self.embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=time_embed_dim),
            nn.Linear(time_embed_dim, time_embed_dim),
            nn.SiLU()
        )

        # Dimension of the conditioning vector [x; condition; time embedding]
        cond_dim = input_dim_x + input_dim_condition + time_embed_dim
        self.cond_dim = cond_dim

        # Project theta (only theta; the condition enters through the blocks) into hidden_dim
        self.input_layer = nn.Linear(input_dim_theta, hidden_dim)

        # Create a sequence of conditional residual blocks; every block keeps
        # the width hidden_dim except the last one, which outputs input_dim_theta.
        if not use_film:
            self.blocks = nn.ModuleList([
                ConditionalResidualBlock(in_dim=hidden_dim, out_dim=hidden_dim if b < n_blocks - 1 else input_dim_theta,
                                         cond_dim=cond_dim, dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm)
                for b in range(n_blocks)
            ])
        else:
            # Create a series of FiLM-residual blocks
            self.blocks = nn.ModuleList([
                FiLMResidualBlock(in_dim=hidden_dim, out_dim=hidden_dim if b < n_blocks - 1 else input_dim_theta,
                                  cond_dim=cond_dim, dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm)
                for b in range(n_blocks)
            ])

        # Final linear layer acting in the theta dimension
        self.final_linear = nn.Linear(input_dim_theta, input_dim_theta)
        if use_spectral_norm:
            # initialize weights close zero since we want to predict the noise (otherwise in conflict with the spectral norm)
            nn.init.xavier_uniform_(self.final_linear.weight, gain=1e-3)
        else:
            # initialize weights to zero since we want to predict the noise
            nn.init.zeros_(self.final_linear.weight)
        nn.init.zeros_(self.final_linear.bias)

        # Apply spectral normalization to the input and output layers as well
        if use_spectral_norm:
            self.input_layer = nn.utils.parametrizations.spectral_norm(self.input_layer)
            self.final_linear = nn.utils.parametrizations.spectral_norm(self.final_linear)

    def forward(self, theta, t, x, conditions, pred_score):
        """
        Forward pass of the ScoreModel.

        Args:
            theta (torch.Tensor): Input theta tensor of shape (batch_size, input_dim_theta).
            t (torch.Tensor): Input time tensor of shape (batch_size, 1).
            x (torch.Tensor): Input x tensor of shape (batch_size, input_dim_x).
            conditions (torch.Tensor or None): Input condition tensor of shape (batch_size, input_dim_condition). Can be None.
            pred_score (bool): If True, return the score; if False, return the raw network output,
                whose meaning is given by prediction_type.

        Returns:
            torch.Tensor: Output of the network (dependent on prediction_type) or the score of shape (batch_size, input_dim_theta).
        """
        # Compute a time embedding (shape: [batch, time_embed_dim])
        t_emb = self.embed(t)

        # Form the conditioning vector. If conditions is None, only x and time are used.
        if conditions is not None:
            cond = torch.cat([x, conditions, t_emb], dim=-1)
        else:
            cond = torch.cat([x, t_emb], dim=-1)

        # Keep a copy of the raw input for the outer skip connection (this bypasses the blocks)
        skip = theta.clone()

        # initial input
        h = self.input_layer(theta)

        # Pass through each block, injecting the same cond at each layer
        for block in self.blocks:
            h = block(h, cond)

        # Add the skip connection from theta; the last block already outputs input_dim_theta features
        h = h + skip

        # Final linear layer in the theta dimension
        theta_emb = self.final_linear(h)

        if not pred_score:
            # just return the prediction
            return theta_emb
        # return the score, depends on the prediction type
        if self.prediction_type == 'score':
            # prediction is the score
            return theta_emb

        # todo: check if this is correct
        # The conversions below treat theta as the noisy state
        # z = alpha * theta_0 + sigma * eps and recover the noise estimate eps.
        alpha, sigma = self.marginal_prob_std_fn(t)
        if self.prediction_type == 'x':
            # network predicts theta_0: eps = (z - alpha * theta_0) / sigma
            error = -(theta_emb * alpha - theta) / sigma
        elif self.prediction_type == 'v':
            # network predicts v: theta_0 = alpha * z - sigma * v, then as for 'x'
            # (note that the local name x shadows the argument x from here on)
            x = alpha * theta - sigma * theta_emb
            error = -(x * alpha - theta) / sigma
        else:
            # prediction is the error
            error = theta_emb

        # score = -eps / sigma
        return -error / sigma


class HierarchicalScoreModel(nn.Module):
    """
        Score network for a two-level (global / local) hierarchical model.

        A single-layer GRU summarizes x into a vector of size hidden_dim. A
        global ScoreModel is conditioned on that summary, and a local
        ScoreModel is conditioned on the summary and on the global theta.

        Args:
            input_dim_theta_global (int): Input dimension for global theta.
            input_dim_theta_local (int): Input dimension for local theta.
            input_dim_x (int): Input dimension for x (feature size of the GRU input).
            hidden_dim (int): Hidden dimension of the GRU and of both score networks.
            n_blocks (int): Number of residual blocks.
            marginal_prob_std_fn (callable): Passed unchanged to both ScoreModel instances.
            prediction_type (str): Type of prediction to perform. Can be 'score', 'e', 'x', or 'v'.
            time_embed_dim (int, optional): Dimension of time embedding. Default is 16.
            use_film (bool, optional): Whether to use FiLM-residual blocks. Default is False.
            dropout_rate (float, optional): Dropout rate. Default is 0.1.
            use_spectral_norm (bool, optional): Whether to use spectral normalization. Default is False.
    """
    def __init__(self,
                 input_dim_theta_global, input_dim_theta_local, input_dim_x,
                 hidden_dim, n_blocks, marginal_prob_std_fn, prediction_type,
                 time_embed_dim=16, use_film=False, dropout_rate=0.1, use_spectral_norm=False):
        super().__init__()
        self.prediction_type = prediction_type
        self.summary_net = nn.GRU(input_size=input_dim_x, hidden_size=hidden_dim,
                                  num_layers=1, batch_first=True)

        # Settings shared by the global and the local score network.
        shared_kwargs = dict(
            input_dim_x=hidden_dim,
            hidden_dim=hidden_dim,
            n_blocks=n_blocks,
            marginal_prob_std_fn=marginal_prob_std_fn,
            prediction_type=prediction_type,
            time_embed_dim=time_embed_dim,
            use_film=use_film,
            dropout_rate=dropout_rate,
            use_spectral_norm=use_spectral_norm,
        )

        self.n_params_global = input_dim_theta_global
        self.global_model = ScoreModel(
            input_dim_theta=input_dim_theta_global,
            input_dim_condition=0,
            **shared_kwargs
        )
        self.n_params_local = input_dim_theta_local
        self.local_model = ScoreModel(
            input_dim_theta=input_dim_theta_local,
            input_dim_condition=input_dim_theta_global,
            **shared_kwargs
        )

    def _summarize(self, x):
        """Return the final hidden state of the summary GRU for ``x``."""
        _, final_state = self.summary_net(x)
        # only one layer, not bidirectional -> take the single state
        return final_state[0]

    def forward(self, theta, t, x, pred_score=False):  # __call__ method for the model
        """Run the global and the local model on the split theta and concatenate their outputs.

        This joint pass is what training uses; by default the raw predictions
        (not the scores) are returned.
        """
        split_sizes = [self.n_params_global, self.n_params_local]
        theta_global, theta_local = torch.split(theta, split_sizes, dim=-1)
        summary = self._summarize(x)
        outputs = (
            self.global_model.forward(theta=theta_global, t=t, x=summary, conditions=None, pred_score=pred_score),
            self.local_model.forward(theta=theta_local, t=t, x=summary, conditions=theta_global, pred_score=pred_score),
        )
        return torch.cat(outputs, dim=-1)

    def forward_local(self, theta_local, theta_global, t, x, pred_score=True):
        """Evaluate only the local model, conditioned on ``theta_global``; returns the score by default."""
        summary = self._summarize(x)
        return self.local_model.forward(theta=theta_local, t=t, x=summary,
                                        conditions=theta_global, pred_score=pred_score)

    def forward_global(self, theta_global, t, x, pred_score=True):
        """Evaluate only the global model; returns the score by default."""
        summary = self._summarize(x)
        return self.global_model.forward(theta=theta_global, t=t, x=summary,
                                         conditions=None, pred_score=pred_score)

In [None]:
BETA_MIN = 0.1
BETA_MAX = 20.0  # todo: check if this is a good value, 20 before

def beta(t):
    r"""Linear noise schedule: beta(t) = beta_\text{min} + t*(beta_\text{max} - beta_\text{min})."""
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)

def variance_preserving_kernel(t):
    r"""
    Computes the variance-preserving kernel p(x_t | x_0) for the diffusion process.
    Assuming beta_t is an arithmetic sequence as in DDPM models, we get the following:
        beta(t) = beta_\text{min} + t*(beta_\text{max} - beta_\text{min})
        \int_0^t beta(s) ds = t*beta_\text{min} + 0.5*t^2*(beta_\text{max} - beta_\text{min})

    Args:
        t (torch.Tensor): The time at which to evaluate the kernel in [0,1]. Should be not too close to 0,
            because the standard deviation vanishes at t = 0.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The mean scaling alpha(t) (to be multiplied with x_0) and the
            standard deviation sigma(t) of the kernel at time t, with alpha^2 + sigma^2 = 1.
    """
    # log_integral = -\int_0^t beta(s) ds. The linear term is t * BETA_MIN
    # (not 0.5 * t * BETA_MIN), which matches beta() and log_grad_kernel().
    log_integral = -0.5 * (t**2) * (BETA_MAX - BETA_MIN) - t * BETA_MIN
    # var is 1 - exp(- \int_0^t beta(s) ds); expm1 keeps precision for small t
    std = torch.sqrt(-torch.expm1(log_integral))
    # mean is x(0) exp(-1/2 \int_0^t beta(s) ds)
    mean = torch.exp(0.5 * log_integral)  # * x_0 is the mean
    return mean, std

def weighting_function(t):
    """Return the loss weight used for diffusion time ``t``.

    The active choice is the likelihood weighting ``beta(t)``, i.e. the
    squared diffusion coefficient g(t)^2.
    """
    # Alternatives that were tried and are currently disabled: a constant
    # weight (torch.ones_like(t)) and powers of the variance-preserving kernel
    # evaluated at t or at 1 + 1e-3 - t.
    likelihood_weight = beta(t)
    return likelihood_weight

def log_grad_kernel(x, x0, t):
    """Gradient with respect to ``x`` of the log density of the Gaussian kernel p(x | x0) at time ``t``.

    The kernel has mean ``x0 * exp(log_mean)`` and variance
    ``1 - exp(2 * log_mean)`` with
    ``log_mean = -0.25 * t^2 * (BETA_MAX - BETA_MIN) - 0.5 * t * BETA_MIN``.
    """
    half_log_decay = -0.25 * (t**2) * (BETA_MAX - BETA_MIN) - 0.5 * t * BETA_MIN
    kernel_mean = torch.exp(half_log_decay) * x0
    kernel_std = torch.sqrt(1 - torch.exp(2. * half_log_decay))
    kernel_var = kernel_std ** 2

    # d/dx log N(x; mean, var) = -(x - mean) / var
    return -(x - kernel_mean) / kernel_var

In [None]:
def generate_weighted_time_points(num_points, num_grid=10000, device=None):
    """
    Generate time points between 0 and 1 with spacing proportional to weighting_function(t).

    The idea is to compute the cumulative density function (CDF)
    from the weighting function via numerical integration (trapezoidal rule)
    and then invert the CDF at equally spaced values.

    Args:
        num_points (int): Number of time points to generate.
        num_grid (int, optional): Number of points in the fine grid for numerical integration (at least 2).
        device (torch.device, optional): Device to perform computation on. Defaults to CPU.

    Returns:
        torch.Tensor: Tensor of shape (num_points,) containing the generated time points.

    Raises:
        ValueError: If num_grid < 2, if weighting_function returns negative values,
            or if the weights integrate to zero.
    """
    if num_grid < 2:
        raise ValueError("num_grid must be at least 2.")

    # Create a fine grid over [0, 1]
    t_grid = torch.linspace(0, 1, steps=num_grid, dtype=torch.float32, device=device)

    # Evaluate the weighting function on this grid
    weights = weighting_function(t_grid)

    # Ensure weights are non-negative
    if (weights < 0).any():
        raise ValueError("weighting_function returned negative values.")

    # Trapezoidal rule: each grid interval contributes the mean of its two end
    # point weights times dt. Prepending 0 makes the CDF start at exactly 0
    # for t = 0 (a plain cumsum of the weights would start at weights[0] * dt).
    dt = t_grid[1] - t_grid[0]
    increments = 0.5 * (weights[1:] + weights[:-1]) * dt
    cumulative = torch.cat([increments.new_zeros(1), torch.cumsum(increments, dim=0)])

    # Normalize so that the final value is 1 (i.e. form a proper CDF)
    total = cumulative[-1]
    if not total > 0:
        raise ValueError("weighting_function must be positive somewhere on [0, 1].")
    cumulative = cumulative / total

    # Create equally spaced values in [0,1] which serve as target CDF values
    target_cdf = torch.linspace(0, 1, steps=num_points, dtype=torch.float32, device=device)

    # Invert the CDF: for each target value, find the corresponding t
    t_points = torch.tensor(np.interp(target_cdf.cpu().numpy(),
                                      cumulative.cpu().numpy(),
                                      t_grid.cpu().numpy()), dtype=torch.float32, device=device)
    return t_points

In [None]:
def generate_diffusion_time(size, epsilon=5e-3, return_batch=False, weighted_time=False, device=None):
    """
    Generate diffusion time values in [epsilon, 1].

    Three modes are supported:
      * ``weighted_time=True``: deterministic grid whose spacing follows the
        weighting function (via generate_weighted_time_points); values below
        ``epsilon`` are raised to ``epsilon``.
      * ``return_batch=False`` (and not weighted): equally spaced grid.
      * ``return_batch=True`` (and not weighted): random times drawn as
        ``epsilon + (1 - epsilon) * Beta(1, 3)``, i.e. concentrated on small times.

    Args:
        size (int): Number of time values.
        epsilon (float): Smallest time that is returned.
        return_batch (bool): If True, the result has shape (size, 1), otherwise (size,).
        weighted_time (bool): Whether to use the weighted deterministic grid.
        device (torch.device, optional): Device of the returned tensor.

    Returns:
        torch.Tensor: float32 tensor of shape (size,) or (size, 1).
    """
    if weighted_time:
        # make time step size proportional to the weighting function
        time = generate_weighted_time_points(num_points=size, device=device)
        # The weighted grid starts at t = 0, where the kernel's standard
        # deviation is zero; keep the documented lower bound instead.
        time = time.clamp(min=epsilon)
        return time.unsqueeze(1) if return_batch else time

    if not return_batch:
        return torch.linspace(epsilon, 1, steps=size, dtype=torch.float32, device=device)

    # Build the Beta(1, 3) distribution from tensors on the requested device so
    # that the samples are created there (plain Python numbers always give CPU
    # samples, which silently ignored the device argument).
    # Alternatives that were tried before: uniform sampling and a sorted
    # low-discrepancy sequence t_i = (u_0 + i / size) mod 1.
    concentration1 = torch.tensor(1.0, dtype=torch.float32, device=device)
    concentration0 = torch.tensor(3.0, dtype=torch.float32, device=device)
    samples = torch.distributions.Beta(concentration1, concentration0).sample((size,))
    time = epsilon + (1 - epsilon) * samples

    # Add a new dimension so that each tensor has shape (size, 1)
    return time.unsqueeze(1)

In [None]:
# Histogram of 10000 randomly drawn diffusion times (the return_batch=True path)
# to inspect which times are sampled most often.
plt.hist(generate_diffusion_time(10000, return_batch=True), bins=50)

In [None]:
def _label_and_show(title):
    """Add the shared legend and axis labels, set ``title`` and display the figure."""
    plt.legend()
    plt.xlabel('Time')
    plt.ylabel('Value')
    plt.title(title)
    plt.show()


# All curves are evaluated on an equally spaced time grid in [epsilon, 1].
t = generate_diffusion_time(100)

# square root of the noise schedule beta(t)
beta_t = beta(t)
plt.plot(t, torch.sqrt(beta_t), label='beta_t')
_label_and_show('Noise Level')

# mean scaling and variance of the perturbation kernel
m, std = variance_preserving_kernel(t)
plt.plot(t, m, label='mean')
plt.plot(t, std**2, label='variance')
_label_and_show('Variance Preserving Kernel')

# loss weighting over time
plt.plot(t, weighting_function(t), label='weighting')
_label_and_show('Weighting Function')

# factor 1 / std that turns a noise prediction into a score
plt.plot(t, 1/std, label='score scaling')
_label_and_show('score scaling')

In [None]:
def compute_score_loss(theta_batch, x_batch, model):
    """Weighted denoising score-matching loss for one batch.

    For every sample a random diffusion time is drawn, theta is perturbed with
    the variance-preserving kernel and the model predicts its target
    (``model.prediction_type``: 'score', 'e', 'x' or 'v') from the perturbed
    theta. The squared error is rescaled so that it equals the squared error
    of the implied score, and then weighted with ``weighting_function``.

    Args:
        theta_batch (torch.Tensor): Parameters of shape (batch, n_params).
        x_batch (torch.Tensor): Observations belonging to theta_batch.
        model: Model exposing ``prediction_type`` and callable as
            ``model(theta=..., t=..., x=...)``.

    Returns:
        torch.Tensor: Scalar loss (mean over the batch).

    Raises:
        ValueError: If ``model.prediction_type`` is not supported.
    """
    device = theta_batch.device

    # One random diffusion time per sample, shape (batch, 1)
    diffusion_time = generate_diffusion_time(size=theta_batch.shape[0], return_batch=True, device=device)
    diffusion_time = diffusion_time.to(device)

    # sample from the Gaussian kernel, just learn the noise
    epsilon = torch.randn_like(theta_batch)

    # perturb the theta batch
    alpha, sigma = variance_preserving_kernel(t=diffusion_time)
    z = alpha * theta_batch + sigma * epsilon
    # predict from perturbed theta
    pred = model(theta=z, t=diffusion_time, x=x_batch.to(device))

    # pred_type_weight converts the squared prediction error into the squared
    # error of the score that ScoreModel derives from the prediction:
    #   'e': score = -e / sigma                   -> factor 1 / sigma^2
    #   'x': e = (z - alpha * x) / sigma          -> factor alpha^2 / sigma^4
    #   'v': x = alpha * z - sigma * v, e as above -> factor alpha^2 / sigma^2
    if model.prediction_type == 'score':
        target = log_grad_kernel(x=z, x0=theta_batch, t=diffusion_time)
        pred_type_weight = 1
    elif model.prediction_type == 'e':
        target = epsilon
        pred_type_weight = 1 / torch.square(sigma)
    elif model.prediction_type == 'x':
        target = theta_batch
        pred_type_weight = torch.square(alpha) / torch.square(torch.square(sigma))
    elif model.prediction_type == 'v':
        target = alpha * epsilon - sigma * theta_batch
        pred_type_weight = torch.square(alpha) / torch.square(sigma)
    else:
        raise ValueError("Invalid prediction type. Must be one of 'score', 'e', 'x', or 'v'.")

    effective_weight = pred_type_weight * weighting_function(diffusion_time)

    # Sum over the parameter dimension but keep it as a size-1 axis: the
    # weights have shape (batch, 1), and multiplying them with a (batch,)
    # tensor would broadcast to (batch, batch) and pair every weight with
    # every sample's error.
    squared_error = torch.sum(torch.square(pred - target), dim=-1, keepdim=True)
    loss = torch.mean(effective_weight * squared_error)
    return loss


# Training loop for Score Model
def train_score_model(model, dataloader, dataloader_valid=None, epochs=100, lr=1e-3, device=None):
    """Train ``model`` with the score-matching loss and return the loss history.

    Args:
        model: Model to train; it is moved to ``device`` if one is given.
        dataloader: Iterable yielding ``(theta_global, theta_local, x)`` batches for training.
        dataloader_valid: Optional iterable with the same batch structure for validation.
        epochs (int): Number of passes over ``dataloader``.
        lr (float): Learning rate of the AdamW optimizer.
        device (torch.device, optional): Device for the model and the batches.
            If None, nothing is moved.

    Returns:
        np.ndarray: Array of shape (epochs, 2) with the mean training loss and
        the mean validation loss per epoch (NaN where no batches were seen).
    """
    def _to_device(tensor):
        return tensor if device is None else tensor.to(device)

    def _mean_or_nan(values):
        # np.mean([]) would emit a RuntimeWarning; report NaN explicitly instead.
        return np.mean(values) if values else float('nan')

    # Use the arguments, not notebook globals, so the function works for any model/device.
    if device is not None:
        model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    # NOTE: a cosine-annealing learning-rate schedule and per-epoch
    # checkpointing were sketched here earlier but are disabled.

    loss_history = np.zeros((epochs, 2))
    for epoch in range(epochs):
        model.train()
        train_losses = []
        # for each sample in the batch, calculate the loss for a random diffusion time
        for theta_global_batch, theta_local_batch, x_batch in dataloader:
            optimizer.zero_grad()
            theta_batch = _to_device(torch.concat([theta_global_batch, theta_local_batch], dim=-1))
            loss = compute_score_loss(theta_batch=theta_batch, x_batch=_to_device(x_batch), model=model)
            loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            optimizer.step()
            train_losses.append(loss.item())

        # validate the model
        model.eval()
        valid_losses = []
        if dataloader_valid is not None:
            with torch.no_grad():
                for theta_global_batch, theta_local_batch, x_batch in dataloader_valid:
                    theta_batch = _to_device(torch.concat([theta_global_batch, theta_local_batch], dim=-1))
                    loss = compute_score_loss(theta_batch=theta_batch, x_batch=_to_device(x_batch), model=model)
                    valid_losses.append(loss.item())

        train_mean = _mean_or_nan(train_losses)
        valid_mean = _mean_or_nan(valid_losses)
        loss_history[epoch] = [train_mean, valid_mean]
        print_str = f"Epoch {epoch+1}/{epochs}, Loss: {train_mean:.4f}, "\
                    f"Valid Loss: {valid_mean:.4f}"
        # '\r' keeps the progress on a single, continuously updated line
        print(print_str, end='\r')
    return loss_history

In [None]:
# Hyperparameters: number of simulated training pairs and mini-batch size
n_data = 25000
batch_size = 128

# Simulate the training set from the prior: global parameters, local parameters and observations.
# NOTE(review): normalize=True presumably standardizes the returned tensors - confirm in gaussian_test_simulator.
thetas_global, thetas_local, xs = generate_synthetic_data(prior, n_data=n_data, normalize=True)

# Create dataloader (reshuffled every epoch)
dataset = TensorDataset(thetas_global, thetas_local, xs)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# create validation data: an independent simulation of two batches, iterated in fixed order
valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(prior, n_data=batch_size*2, normalize=True)
dataset_valid = TensorDataset(valid_prior_global, valid_prior_local, valid_data)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size)

In [None]:
# Define model
score_model = HierarchicalScoreModel(
    input_dim_theta_global=prior.n_params_global,
    input_dim_theta_local=prior.n_params_local,
    input_dim_x=1,
    hidden_dim=64,
    n_blocks=3,
    # The kernel is passed directly; wrapping it in partial() without
    # arguments added nothing.
    marginal_prob_std_fn=variance_preserving_kernel,
    # prediction_type has no default in HierarchicalScoreModel, so omitting it
    # raises a TypeError. 'e' (noise prediction) matches the comments in the
    # loss and in the model initialization ("learn the noise").
    prediction_type='e',
    time_embed_dim=16,
    use_film=True,
    use_spectral_norm=False
)

In [None]:
# train model; loss_history has one row per epoch with [train loss, validation loss]
loss_history = train_score_model(score_model, dataloader, dataloader_valid=dataloader_valid,
                                 epochs=1000, lr=1e-4, device=torch_device)
# switch to evaluation mode (disables dropout); the trailing semicolon suppresses the notebook output
score_model.eval();

In [None]:
# plot loss history: column 0 holds the mean training loss per epoch, column 1 the mean validation loss
plt.figure(figsize=(6, 3), tight_layout=True)
plt.plot(loss_history[:, 0], label='Mean Train')
plt.plot(loss_history[:, 1], label='Mean Valid')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# check the error prediction: is it close to the noise?
# For 50 equally spaced diffusion times, the validation set is perturbed at
# that fixed time and four losses are recorded (averaged over all validation
# batches): the raw prediction error, the error of the implied score, and the
# weighted versions of both.
loss_list_error = {}
loss_list_score = {}
loss_list_w = {}
loss_list_w2 = {}
with torch.no_grad():
    # Generate diffusion time and step size
    diffusion_time = generate_diffusion_time(size=50, device=torch_device)
    for t in diffusion_time:
        t_value = t.item()
        weight_t = weighting_function(t)
        # Collect the per-batch losses; storing them directly under t_value
        # would keep only the last validation batch.
        batch_error, batch_score, batch_w, batch_w2 = [], [], [], []
        for theta_global_batch, theta_local_batch, x_batch in dataloader_valid:
            theta_batch = torch.cat([theta_global_batch, theta_local_batch], dim=-1)
            theta_batch = theta_batch.to(torch_device)
            x_batch = x_batch.to(torch_device)

            # sample from the Gaussian kernel, just learn the noise
            epsilon = torch.randn_like(theta_batch, dtype=torch.float32, device=torch_device)

            # same diffusion time for the whole batch, shape (batch, 1)
            t_tensor = torch.full((theta_batch.shape[0], 1), t_value, dtype=torch.float32, device=torch_device)
            # perturb the theta batch
            alpha, sigma = variance_preserving_kernel(t=t_tensor)
            z = alpha * theta_batch + sigma * epsilon
            # predict from perturbed theta
            pred = score_model(theta=z, t=t_tensor, x=x_batch, pred_score=False)
            pred_score = score_model(theta=z, t=t_tensor, x=x_batch, pred_score=True)
            true_score = log_grad_kernel(x=z, x0=theta_batch, t=t_tensor)

            # pred_type_weight turns the squared prediction error into the
            # squared error of the score derived from it by ScoreModel:
            #   'e': 1 / sigma^2, 'x': alpha^2 / sigma^4, 'v': alpha^2 / sigma^2
            if score_model.prediction_type == 'score':
                target = true_score
                pred_type_weight = 1
            elif score_model.prediction_type == 'e':
                target = epsilon
                pred_type_weight = 1 / torch.square(sigma)
            elif score_model.prediction_type == 'x':
                target = theta_batch
                pred_type_weight = torch.square(alpha) / torch.square(torch.square(sigma))
            elif score_model.prediction_type == 'v':
                target = alpha * epsilon - sigma * theta_batch
                pred_type_weight = torch.square(alpha) / torch.square(sigma)
            else:
                raise ValueError("Invalid prediction type. Must be one of 'score', 'e', 'x', or 'v'.")

            # Squared errors summed over the parameter dimension. keepdim=True
            # keeps the shape (batch, 1) so that the (batch, 1) weights pair
            # with their own sample instead of broadcasting to (batch, batch).
            sq_error_pred = torch.sum(torch.square(pred - target), dim=-1, keepdim=True)
            sq_error_score = torch.sum(torch.square(pred_score - true_score), dim=-1, keepdim=True)

            # unweighted prediction loss and error of the true score
            batch_error.append(torch.mean(sq_error_pred).item())
            batch_score.append(torch.mean(sq_error_score).item())
            # weighted loss as in the optimization
            batch_w.append(torch.mean(weight_t * pred_type_weight * sq_error_pred).item())
            # check if the weighting function is correct: same weighting applied to the score error
            batch_w2.append(torch.mean(weight_t * sq_error_score).item())

        # average over the validation batches (simple mean of the batch means)
        loss_list_error[t_value] = float(np.mean(batch_error))
        loss_list_score[t_value] = float(np.mean(batch_score))
        loss_list_w[t_value] = float(np.mean(batch_w))
        loss_list_w2[t_value] = float(np.mean(batch_w2))

In [None]:
# Turn the {time: loss} dictionaries from the diagnostic loop into two-column data frames.
df_error = pd.DataFrame(loss_list_error.items(), columns=['Time', 'Loss'])
df_score = pd.DataFrame(loss_list_score.items(), columns=['Time', 'Loss'])
df_score_w = pd.DataFrame(loss_list_w.items(), columns=['Time', 'Loss'])
df_score_w2 = pd.DataFrame(loss_list_w2.items(), columns=['Time', 'Loss'])

# Three panels over diffusion time: raw prediction loss, score loss, and both weighted losses.
fig, ax = plt.subplots(ncols=3, sharex=True, figsize=(12, 3), tight_layout=True)
ax[0].plot(df_error['Time'], df_error['Loss'], label=f'Unscaled {score_model.prediction_type} Loss')
ax[1].plot(df_score['Time'], df_score['Loss'], label='Score Loss')
ax[2].plot(df_score_w['Time'], df_score_w['Loss'], label='Weighted Loss (as in Optimization)')
ax[2].plot(df_score_w2['Time'], df_score_w2['Loss'], label='Weighted Loss on Scores')  # should be the same as the loss in optimization
for a in ax:
    a.set_xlabel('Diffusion Time')
    a.set_ylabel('Loss')
    a.legend()
# the last panel shows weighted losses, so its y-label is overridden
ax[-1].set_ylabel('Weighted Loss')
plt.show()

# Weighting function on the same time grid that was used for the diagnostics.
plt.figure(figsize=(6, 3), tight_layout=True)
plt.plot(diffusion_time, weighting_function(diffusion_time), label='assumed weighting')
# Disabled: an 'effective weighting' curve (weighting divided by the squared kernel).
#plt.plot(diffusion_time, weighting_function(diffusion_time) / (variance_preserving_kernel(diffusion_time)**2),
#         label='effective weighting', alpha=0.5)
plt.xlabel('Diffusion Time')
plt.ylabel('Weight')
plt.legend()
plt.show()

# Sample from the Score Model

In [None]:
def euler_maruyama_step(x, score, t, dt):
    """
    Perform one Euler-Maruyama update step for the SDE
    with drift f(x, t) - g(t)^2 * score, where f(x, t) = -0.5 * beta(t) * x
    and g(t)^2 = beta(t).

    Parameters:
        x: Current state (PyTorch tensor)
        score: Score network output (PyTorch tensor)
        t: Current time (scalar tensor or float)
        dt: Time step size (tensor or float). The sign of dt sets the
            integration direction of the drift term.

    Returns:
        Tuple (x_next, x_mean): the updated state after time dt including the
        noise term, and the state after the drift update only.
    """
    # Compute g(t)^2
    g_t2 = beta(t)
    # torch.sqrt only accepts tensors, so plain floats (as allowed for t and
    # dt) are converted; tensors are left untouched.
    if not torch.is_tensor(g_t2):
        g_t2 = torch.as_tensor(g_t2, dtype=x.dtype, device=x.device)
    if not torch.is_tensor(dt):
        dt = torch.as_tensor(dt, dtype=x.dtype, device=x.device)

    # Compute f(x,t)
    f_x_t = -0.5 * g_t2 * x
    # Compute drift and diffusion
    drift = f_x_t - g_t2 * score
    diffusion = torch.sqrt(g_t2)

    # Sample Gaussian noise (same shape as x), with variance |dt|. The
    # absolute value keeps the noise scale real when dt is negative
    # (integration backwards in time); for positive dt nothing changes.
    noise = torch.randn_like(x) * torch.sqrt(torch.abs(dt))

    # Euler-Maruyama update step
    x_mean = x + drift * dt
    x_next = x_mean + diffusion * noise
    return x_next, x_mean

In [None]:
def eval_compositional_score(model, theta, diffusion_time, x_expanded, n_post_samples, n_obs, conditions_exp):
    """
    Evaluate the score used by the samplers for the current parameter state.

    With ``conditions_exp is None`` the global score is composed from the
    per-observation scores of ``model.forward_global`` plus a prior-score
    correction. Otherwise the per-observation local scores of
    ``model.forward_local`` are returned without any composition.

    Parameters:
        model: Score model exposing ``forward_global`` and ``forward_local``.
        theta: Current parameter state (one entry per posterior sample).
        diffusion_time: 2-D tensor of diffusion times, one row per posterior sample.
        x_expanded: Observations already repeated per posterior sample and
            flattened over (sample, observation).
        n_post_samples: Number of posterior samples.
        n_obs: Number of observations.
        conditions_exp: Expanded global parameters for the local model, or
            ``None`` to request the compositional global score.

    Returns:
        Tensor of scores: ``(n_post_samples, -1)`` in the global case and
        ``(n_post_samples, n_obs, -1)`` in the local case.
    """
    # One time value per (posterior sample, observation) pair
    time_per_obs = diffusion_time.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, 1)

    if conditions_exp is not None:
        # Local parameters: one score per observation, no summation
        local_theta = theta.reshape(-1, prior.n_params_local)
        local_scores = model.forward_local(
            theta_local=local_theta, t=time_per_obs, x=x_expanded, theta_global=conditions_exp
        )
        return local_scores.reshape(n_post_samples, n_obs, -1)

    # Global parameters: repeat theta for every observation and evaluate the model
    global_theta = theta.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, prior.n_params_global)
    per_obs_scores = model.forward_global(theta_global=global_theta, t=time_per_obs, x=x_expanded)
    # Add up the individual scores of all observations
    summed_scores = per_obs_scores.reshape(n_post_samples, n_obs, -1).sum(dim=1)

    # Prior-score correction, weighted by (1 - n_obs) * (1 - t)
    prior_weight = (1 - n_obs) * (1 - diffusion_time)
    return prior_weight * prior.score_global_batch(theta) + summed_scores

In [None]:
# Euler-Maruyama Sampling from Song et al. (2021)
def euler_maruyama_sampling(model, x_obs, n_post_samples, conditions=None, diffusion_steps=100, device=None):
    """
    Draw posterior samples by iterating Euler-Maruyama steps over the diffusion times.

    Without ``conditions`` the global parameters are sampled using the
    compositional score over all observations. With ``conditions`` the local
    parameters of every observation are sampled given those global parameters.

    Parameters:
        model: Score model; it is moved to ``device`` and switched to eval mode.
        x_obs: Observed data (tensor or array-like). The first axis is kept as
            the time-step axis, all remaining axes are flattened into observations.
        n_post_samples: Number of posterior samples to draw.
        conditions: Optional global parameters to condition the local sampler on;
            they are normalized internally with ``prior.normalize_theta``.
            ``None`` selects global sampling.
        diffusion_steps: Number of Euler-Maruyama steps.
        device: Torch device for tensors created inside this function.

    Returns:
        NumPy array of denormalized samples with shape
        ``(n_post_samples, prior.n_params_global)`` for global sampling or
        ``(n_post_samples, n_obs, prior.n_params_local)`` for local sampling.
        If NaNs appear, the loop stops early and the current state is returned.
    """
    with torch.no_grad():
        model.to(device)
        model.eval()

        # Ensure x_obs is a PyTorch tensor
        if not isinstance(x_obs, torch.Tensor):
            x_obs = torch.tensor(x_obs, dtype=torch.float32, device=device)

        # Flatten everything but the first axis, then put observations first:
        # (n_time_steps, n_obs) -> (n_obs, n_time_steps, 1)
        x_obs_norm = prior.normalize_data(x_obs)
        x_obs_norm = x_obs_norm.reshape(x_obs_norm.shape[0], -1)
        n_time_steps, n_obs = x_obs_norm.shape
        x_obs_norm = x_obs_norm.T[:, :, None]

        # Initialize parameters
        sample_global = conditions is None
        if sample_global:
            # Standard-normal start, scaled down by sqrt(n_obs)
            theta = torch.randn(n_post_samples, prior.n_params_global, dtype=torch.float32, device=device) / torch.sqrt(torch.tensor(n_obs, dtype=torch.float32, device=device))
            conditions_exp = None
        else:
            # Ensure conditions is a PyTorch tensor
            if not isinstance(conditions, torch.Tensor):
                conditions = torch.tensor(conditions, dtype=torch.float32, device=device)
            conditions_norm = prior.normalize_theta(conditions, global_params=True)

            theta = torch.randn(n_post_samples, n_obs, prior.n_params_local, dtype=torch.float32, device=device)
            # Same global parameters for every (posterior sample, observation) pair
            conditions_exp = conditions_norm.unsqueeze(0).expand(n_post_samples, n_obs, -1).reshape(-1, prior.n_params_global)

        # Generate diffusion time parameters
        diffusion_time = generate_diffusion_time(size=diffusion_steps+1, device=device, weighted_time=True) # todo: weighed times

        # Expand x_obs_norm to match the number of posterior samples
        x_exp = x_obs_norm.unsqueeze(0).expand(n_post_samples, n_obs, n_time_steps, -1)  # Shape: (n_post_samples, n_obs, n_time_steps, d)
        x_expanded = x_exp.reshape(n_post_samples*n_obs, n_time_steps, -1)

        # Reverse iterate over diffusion times and step sizes
        for t in tqdm(reversed(range(1, diffusion_steps+1)), total=diffusion_steps):
            t_tensor = torch.full((n_post_samples, 1), diffusion_time[t], dtype=torch.float32, device=device)
            scores = eval_compositional_score(model=model, theta=theta, diffusion_time=t_tensor, x_expanded=x_expanded,
                                              n_obs=n_obs, n_post_samples=n_post_samples,
                                              conditions_exp=conditions_exp)
            # Make Euler-Maruyama step
            theta_noisy, theta_mean = euler_maruyama_step(theta, score=scores, t=diffusion_time[t],
                                                          dt=diffusion_time[t] - diffusion_time[t-1])
            # do not include noise in the last step
            theta = theta_mean if t == 1 else theta_noisy
            # clip theta to avoid numerical issues
            #theta = torch.clamp(theta, -5, 5)
            if torch.isnan(theta).any():
                print("NaNs in theta")
                break

        # correct for normalization
        theta = prior.denormalize_theta(theta, global_params=sample_global)
        # convert to numpy; .cpu() is required when the tensor is not on the CPU
        theta = theta.detach().cpu().numpy()
        if sample_global:
            theta = theta.reshape(n_post_samples, prior.n_params_global)
        else:
            theta = theta.reshape(n_post_samples, n_obs, prior.n_params_local)
    return theta

In [None]:
# Probability flow ODE sampler from Song et al. (2021)
from scipy.integrate import solve_ivp
def probability_ode_solving(model, x_obs, n_post_samples, conditions=None, device=None, t_end=5e-3):
    """
    Draw posterior samples by solving the probability flow ODE with SciPy's RK45.

    The ODE is integrated from diffusion time 1 down to ``t_end``. Without
    ``conditions`` the global parameters are sampled using the compositional
    score; with ``conditions`` the local parameters of every observation are
    sampled given those global parameters.

    Parameters:
        model: Score model; it is moved to ``device`` and switched to eval mode.
        x_obs: Observed data (tensor or array-like). The first axis is kept as
            the time-step axis, all remaining axes are flattened into observations.
        n_post_samples: Number of posterior samples to draw.
        conditions: Optional global parameters to condition the local sampler on;
            they are normalized internally with ``prior.normalize_theta``.
            ``None`` selects global sampling.
        device: Torch device for tensors created inside this function.
        t_end: Diffusion time at which the integration stops (default 5e-3).

    Returns:
        NumPy array of denormalized samples with shape
        ``(n_post_samples, prior.n_params_global)`` for global sampling or
        ``(n_post_samples, n_obs, prior.n_params_local)`` for local sampling.

    Raises:
        RuntimeError: If the ODE solver reports failure, in which case no state
            at ``t_end`` is available.
    """
    with torch.no_grad():
        model.to(device)
        model.eval()

        # Ensure x_obs is a PyTorch tensor
        if not isinstance(x_obs, torch.Tensor):
            x_obs = torch.tensor(x_obs, dtype=torch.float32, device=device)

        # Flatten everything but the first axis, then put observations first:
        # (n_time_steps, n_obs) -> (n_obs, n_time_steps, 1)
        x_obs_norm = prior.normalize_data(x_obs)
        x_obs_norm = x_obs_norm.reshape(x_obs_norm.shape[0], -1)
        n_time_steps, n_obs = x_obs_norm.shape
        x_obs_norm = x_obs_norm.T[:, :, None]

        # Initialize parameters
        sample_global = conditions is None
        if sample_global:
            # Standard-normal start, scaled down by sqrt(n_obs)
            theta = torch.randn(n_post_samples, prior.n_params_global, dtype=torch.float32, device=device) / torch.sqrt(torch.tensor(n_obs, dtype=torch.float32, device=device))
            conditions_exp = None
        else:
            # Ensure conditions is a PyTorch tensor
            if not isinstance(conditions, torch.Tensor):
                conditions = torch.tensor(conditions, dtype=torch.float32, device=device)
            conditions_norm = prior.normalize_theta(conditions, global_params=True)

            theta = torch.randn(n_post_samples, n_obs, prior.n_params_local, dtype=torch.float32, device=device)
            # Same global parameters for every (posterior sample, observation) pair
            conditions_exp = conditions_norm.unsqueeze(0).expand(n_post_samples, n_obs, -1).reshape(-1, prior.n_params_global)

        # The solver works on flat vectors; this shape restores the tensor layout
        theta_shape = tuple(theta.shape)

        # Expand x_obs_norm to match the number of posterior samples
        x_exp = x_obs_norm.unsqueeze(0).expand(n_post_samples, n_obs, n_time_steps, -1)  # Shape: (n_post_samples, n_obs, n_time_steps, d)
        x_expanded = x_exp.reshape(n_post_samples*n_obs, n_time_steps, -1)

        def probability_ode(t, x):
            """Right-hand side of the ODE for the flat state ``x`` at time ``t``."""
            t_tensor = torch.full((n_post_samples, 1), t, dtype=torch.float32, device=device)
            x_torch = torch.tensor(x.reshape(theta_shape), dtype=torch.float32, device=device)
            scores = eval_compositional_score(model=model, theta=x_torch, diffusion_time=t_tensor, x_expanded=x_expanded,
                                              n_obs=n_obs, n_post_samples=n_post_samples,
                                              conditions_exp=conditions_exp)
            # In the local case the time needs an extra observation axis
            t_exp = t_tensor if sample_global else t_tensor.unsqueeze(1).expand(-1, n_obs, -1)
            # Compute g(t)^2
            g_t2 = beta(t_exp)
            # Compute f(x,t)
            f_x_t = -0.5 * g_t2 * x_torch
            # Deterministic drift: the score term carries a factor 0.5, no diffusion term
            drift = f_x_t - 0.5*g_t2 * scores
            return drift.detach().cpu().numpy().reshape(-1)

        x_0 = theta.detach().cpu().numpy().reshape(-1)
        # Only the state at t_end is requested from the solver
        sol = solve_ivp(probability_ode, t_span=[1, t_end], y0=x_0, method='RK45', t_eval=[t_end])
        print('ODE solved:', sol.success)
        if not sol.success:
            print(sol.message)
            # Without success the solver never reached t_end, so there is no state to return
            raise RuntimeError(f'Probability flow ODE solver failed: {sol.message}')
        theta = torch.tensor(sol.y[:, -1].reshape(theta_shape), dtype=torch.float32, device=device)

        # correct for normalization
        theta = prior.denormalize_theta(theta, global_params=sample_global)
        # convert to numpy; .cpu() is required when the tensor is not on the CPU
        theta = theta.detach().cpu().numpy().reshape(theta_shape)
    return theta

# Validation

In [None]:
# Side length of the grid; n_grid**2 is used as the number of local parameters in the plots below.
n_grid = 8
# NOTE(review): presumably returns the true global parameters, the true local parameters and
# the simulated data for n_data datasets (unnormalized) - confirm against generate_synthetic_data.
valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(prior, n_data=10, grid_size=n_grid, full_grid=True,
                                                                            normalize=False)
# Number of posterior draws per dataset used by the sampling cells below.
# NOTE(review): 2 draws is very few for the recovery / SBC diagnostics - looks like a quick-test value.
n_post_samples = 2

In [None]:
# Global posterior draws for every validation dataset via Euler-Maruyama sampling
# (1000 diffusion steps, no conditions => global parameters).
posterior_global_samples_valid = np.array([
    euler_maruyama_sampling(
        score_model,
        dataset,
        n_post_samples=n_post_samples,
        diffusion_steps=1000,
        device=torch_device,
    )
    for dataset in valid_data
])

In [None]:
# Global posterior draws for every validation dataset via the probability flow ODE.
# This assigns the same variable as the Euler-Maruyama cell, so whichever cell ran last wins.
posterior_global_samples_valid = np.array([
    probability_ode_solving(
        score_model,
        dataset,
        n_post_samples=n_post_samples,
        device=torch_device,
    )
    for dataset in valid_data
])

In [None]:
# Recovery plot: global posterior draws against the true global parameters of the validation sets.
diagnostics.plot_recovery(posterior_global_samples_valid, np.array(valid_prior_global), param_names=[r'$\mu$', r'$\log \tau$']);

In [None]:
# Simulation-based calibration ECDF (difference form) for the global parameters.
diagnostics.plot_sbc_ecdf(posterior_global_samples_valid, np.array(valid_prior_global),
                          difference=True, param_names=[r'$\mu$', r'$\log \tau$']);

In [None]:
# Point estimate of the global parameters per dataset: median over the posterior-draw axis.
conditions_global = np.median(posterior_global_samples_valid, axis=1)
# Local posterior draws, each dataset conditioned on its own global point estimate.
posterior_local_samples_valid = np.array([
    probability_ode_solving(
        score_model,
        dataset,
        n_post_samples=n_post_samples,
        conditions=global_estimate,
        device=torch_device,
    )
    for dataset, global_estimate in zip(valid_data, conditions_global)
])

In [None]:
# Recovery plot for the local parameters: the per-observation axes are flattened so that
# every grid cell becomes one parameter named theta_0 ... theta_{n_grid**2 - 1}.
diagnostics.plot_recovery(
    posterior_local_samples_valid.reshape(valid_data.shape[0], n_post_samples, -1),
    np.array(valid_prior_local).reshape(valid_data.shape[0], -1),
    param_names=[rf'$\theta_{{{cell}}}$' for cell in range(n_grid**2)],
);

In [None]:
# Debug: shapes of the two arrays handed to the local recovery plot.
posterior_local_samples_valid.reshape(valid_data.shape[0], n_post_samples, -1).shape, np.array(valid_prior_local).reshape(valid_data.shape[0], -1).shape

In [None]:
# Debug: shape of the flattened true local parameters of the first dataset.
np.array(valid_prior_local).reshape(valid_data.shape[0], -1)[0].shape

In [None]:
# Debug: show the per-dataset global medians used as conditions for local sampling.
conditions_global

In [None]:
# Inspect one validation dataset: show its data, then compare the global
# posterior summary (median, standard deviation) with the true values.
valid_id = 0
print('Data')
visualize_simulation_output(valid_data[valid_id])

print('Global Estimates')
for _col, _label in enumerate(('mu:', 'log tau:')):
    _draws = posterior_global_samples_valid[valid_id, :, _col]
    print(_label, np.median(_draws), np.std(_draws))

print('True')
for _col, _label in enumerate(('mu:', 'log tau:')):
    print(_label, valid_prior_global[valid_id][_col].item())

In [None]:
# Local posterior draws of the selected dataset laid out on the grid: (draw, row, column)
_local_draws = posterior_local_samples_valid[valid_id].reshape(n_post_samples, n_grid, n_grid)

# Per-cell posterior summary and squared error of the median against the truth
med = np.median(_local_draws, axis=0)
std = np.std(_local_draws, axis=0)
error = (med - valid_prior_local[valid_id].numpy()) ** 2

# Posterior median next to the true local parameters
visualize_simulation_output(
    np.stack([med, valid_prior_local[valid_id]]),
    title_prefix=['Posterior Median', 'True'],
)

# Uncertainty and error live on different scales, hence same_scale=False
visualize_simulation_output(
    np.stack([std, error]),
    title_prefix=['Uncertainty', 'Error'],
    same_scale=False,
)