# Simple Test for Hierarchical ABI with compositional score matching

In [1]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Device used for the tensors and the model in this notebook.
# NOTE: despite the name `mps_device`, this is the CPU device.
mps_device = torch.device("cpu")
mps_device

device(type='cpu')

In [3]:
# eight schools problem
J = 8  # number of groups (schools)
# observed effects as a column vector of shape (J, 1)
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])[:, np.newaxis]  # our groups, one observation per group
# per-school standard deviation of the observation noise, shape (J,)
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])  # assumed to be known
# each group is presented to the score network as the pair (observation, sigma)
n_obs_per_group = 2  # sigma is known, so part of the observation

In [4]:
# simulator example
def simulator(params, school_i=None):
    """Simulate one noisy observation per school for every parameter set in the batch.

    Args:
        params (dict): Must contain 'theta_j', an array of shape (batch_size, J)
            holding the school-level means.
        school_i (int, optional): If given, only this school's column is returned.

    Returns:
        dict: 'observable' and 'sigma' arrays, each of shape (batch_size, J), or
        of shape (batch_size, 1) when `school_i` is given.
    """
    school_means = params['theta_j']
    n_draws = school_means.shape[0]
    # All J schools are always simulated; a single school is selected afterwards.
    draws = np.random.normal(loc=school_means, scale=sigma, size=(n_draws, J))
    if school_i is not None:
        return dict(
            observable=draws[:, school_i][:, np.newaxis],
            sigma=(np.ones(n_draws) * sigma[school_i])[:, np.newaxis],
        )
    return dict(observable=draws, sigma=np.tile(sigma, (n_draws, 1)))


class Prior:
    """Hierarchical eight-schools prior plus the statistics used for standardization.

    Generative model:
        mu      ~ Normal(mu_mean, mu_std)
        log_tau ~ Normal(log_tau_mean, log_tau_std)
        theta_j ~ Normal(mu, exp(log_tau))    for each of the J schools

    Attributes:
        x_mean, x_std: mean / std of simulated (observable, sigma) values.
        prior_mean, prior_std: mean / std of prior draws of (mu, log_tau).
    """

    def __init__(self):
        self.mu_mean = 0
        self.mu_std = 10
        self.log_tau_mean = 5
        self.log_tau_std = 1

        # NOTE: this reseeds NumPy's *global* RNG, so constructing a Prior
        # affects every later np.random draw.
        np.random.seed(0)
        # Estimate the normalization statistics from 1000 simulated draws.
        test_prior = self.sample(1000)
        test = simulator(test_prior, school_i=None)
        self.x_mean = torch.tensor(np.array([np.mean(test['observable']), np.mean(test['sigma'])]),
                                   dtype=torch.float32, device=mps_device)
        self.x_std = torch.tensor(np.array([np.std(test['observable']), np.std(test['sigma'])]),
                                  dtype=torch.float32, device=mps_device)
        self.prior_mean = torch.tensor(np.array([np.mean(test_prior['mu']), np.mean(test_prior['log_tau'])]),
                                       dtype=torch.float32, device=mps_device)
        self.prior_std = torch.tensor(np.array([np.std(test_prior['mu']), np.std(test_prior['log_tau'])]),
                                      dtype=torch.float32, device=mps_device)

    def __call__(self, batch_size):
        """Shorthand for `sample(batch_size)`."""
        return self.sample(batch_size)

    def sample(self, batch_size):
        """Draw `batch_size` joint samples from the hierarchical prior.

        Returns:
            dict: 'mu' and 'log_tau' of shape (batch_size, 1) and 'theta_j' of
            shape (batch_size, J).
        """
        mu = np.random.normal(loc=self.mu_mean, scale=self.mu_std, size=(batch_size,1))
        log_tau = np.random.normal(loc=self.log_tau_mean, scale=self.log_tau_std, size=(batch_size,1))
        theta_j = np.random.normal(loc=mu, scale=np.exp(log_tau), size=(batch_size, J))
        return dict(mu=mu, log_tau=log_tau, theta_j=theta_j)

    def score_global_batch(self, theta_batch_norm):
        """Score of the global prior with respect to the *standardized* parameters.

        Args:
            theta_batch_norm (torch.Tensor): Standardized (mu, log_tau) in the last
                dimension, i.e. (theta - prior_mean) / prior_std.

        Returns:
            torch.Tensor: Gradient of the log prior density with respect to
            `theta_batch_norm`, same shape as the input.
        """
        theta_batch = theta_batch_norm * self.prior_std + self.prior_mean
        mu, log_tau = theta_batch[..., 0], theta_batch[..., 1]
        grad_logp_mu = -(mu - self.mu_mean) / (self.mu_std**2)
        grad_logp_tau = -(log_tau - self.log_tau_mean) / (self.log_tau_std**2)
        score = torch.stack([grad_logp_mu, grad_logp_tau], dim=-1)
        # Chain rule for theta = theta_norm * std + mean:
        #   d/d(theta_norm) log p = std * d/d(theta) log p,
        # so the score must be multiplied (not divided) by the std.
        return score * self.prior_std

    def score_local_batch(self, theta_batch_norm):
        """ Computes the local score for a batch of samples.

        NOTE(review): this method is unfinished and is not called in this file.
        `prior_std` / `prior_mean` hold only the two global statistics, so they
        cannot de-normalize an input that also carries the J local parameters
        (see the todo below). The final correction would also need the local
        std and, by the chain rule, a multiplication rather than a division.
        """
        theta_batch = theta_batch_norm * self.prior_std + self.prior_mean  # todo: should differ between global and local prior
        mu, log_tau, theta_j = theta_batch[..., 0], theta_batch[..., 1], theta_batch[..., 2:]
        # Gradient w.r.t theta_j
        grad_logp_theta_j = -(theta_j - mu) / (torch.exp(log_tau)**2)
        # correct the score for the normalization
        score = grad_logp_theta_j / self.prior_std
        return score

# Shared prior instance; the data generation, normalization and sampling code reads it as a global.
prior = Prior()
n_params = 2  # number of global parameters (mu, log_tau)

In [5]:
# Draw two joint prior samples to inspect the returned dictionary (mu, log_tau, theta_j).
prior(2)

{'mu': array([[0.0601421 ],
        [7.52771209]]),
 'log_tau': array([[5.23244057],
        [5.59150044]]),
 'theta_j': array([[-191.14834832,   27.55194558,  186.33058514,  -21.5385896 ,
         -412.61555756,  -68.55850401,  155.48726917,  141.2585979 ],
        [-317.10900255,   78.89788105,  369.67340495,  527.58649296,
          -44.04420209, -147.9635381 ,  304.84245453, -146.04505225]])}

In [6]:
def positional_encoding(t, d_model, max_t=10000.0):
    """
    Computes the sinusoidal positional encoding for a given time t.

    Args:
        t (torch.Tensor): The input time tensor of shape (batch_size, 1).
        d_model (int): The dimensionality of the embedding. For odd values the
            output has d_model - 1 columns.
        max_t (float, optional): Frequencies are spaced geometrically from 1
            down to 1 / max_t. Defaults to 10000.0.

    Returns:
        torch.Tensor: The positional encoding of shape (batch_size, 2 * (d_model // 2)),
        sine features first, then cosine features.
    """
    half_dim = d_model // 2
    # With a single frequency (d_model of 2 or 3) `half_dim - 1` is zero; clamp the
    # denominator so the lone frequency is 1 instead of raising ZeroDivisionError.
    denom = max(half_dim - 1, 1)
    div_term = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=t.device) *
                         -(math.log(max_t) / denom))
    t_proj = t * div_term
    pos_enc = torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)
    return pos_enc

# Define the Score Model based on F-NPSE
class ResidualBlock(nn.Module):
    """Linear -> LayerNorm -> ReLU layer with a skip connection.

    When the input and output widths differ, the skip path is a learned linear
    projection; otherwise it is the identity.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        same_width = in_features == out_features
        self.linear = nn.Linear(in_features, out_features)
        self.norm = nn.LayerNorm(out_features)
        self.activation = nn.ReLU()
        self.proj = nn.Identity() if same_width else nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return activation(norm(linear(x))) plus the (projected) input."""
        skip = self.proj(x)
        transformed = self.activation(self.norm(self.linear(x)))
        return transformed + skip

class ScoreModel(nn.Module):
    """
        Score network built from three residual stacks.

        theta and x are each embedded to `input_dim_theta` features, concatenated
        with a sinusoidal time embedding and mapped to an `input_dim_theta`-wide output.

        Args:
            input_dim_theta (int): Input dimension for theta.
            hidden_dim_theta (int): Hidden dimension for theta network.
            input_dim_x (int): Input dimension for x.
            hidden_dim_x (int): Hidden dimension for x network.
            hidden_dim_emb (int): Hidden dimension for the embedding network.
            time_embed_dim (int, optional): Dimension of time embedding. Defaults to 4.
    """

    @staticmethod
    def _stack(d_in, d_hidden, d_out):
        """Three residual blocks: d_in -> d_hidden -> d_hidden -> d_out."""
        return nn.Sequential(
            ResidualBlock(d_in, d_hidden),
            ResidualBlock(d_hidden, d_hidden),
            ResidualBlock(d_hidden, d_out),
        )

    def __init__(self, input_dim_theta, hidden_dim_theta,
                 input_dim_x, hidden_dim_x,
                 hidden_dim_emb, time_embed_dim=4):
        super().__init__()
        # Construction order (theta, x, emb) is kept so seeded initialization is unchanged.
        self.net_theta = self._stack(input_dim_theta, hidden_dim_theta, input_dim_theta)
        self.net_x = self._stack(input_dim_x, hidden_dim_x, input_dim_theta)

        self.time_embed_dim = time_embed_dim
        joint_width = input_dim_theta * 2 + time_embed_dim
        self.net_emb = self._stack(joint_width, hidden_dim_emb, input_dim_theta)

    def forward(self, theta, t, x):
        """
        Forward pass of the ScoreModel.

        Args:
            theta (torch.Tensor): Input theta tensor of shape (batch_size, input_dim_theta).
            t (torch.Tensor): Input time tensor of shape (batch_size, 1).
            x (torch.Tensor): Input x tensor of shape (batch_size, input_dim_x).

        Returns:
            torch.Tensor: Output of the score model, shape (batch_size, input_dim_theta).
        """
        features = [
            self.net_theta(theta),
            self.net_x(x),
            positional_encoding(t, self.time_embed_dim),
        ]
        return self.net_emb(torch.cat(features, dim=-1))

In [7]:
def generate_synthetic_data(n_samples, schools_joint=False, device=None):
    """Simulate (theta, x) pairs one prior draw at a time.

    Args:
        n_samples (int): Number of prior draws / simulated data sets.
        schools_joint (bool): If True, every sample holds the observations of all
            J schools (x has shape (n_samples, J, 2)). Otherwise sample i holds
            only school `i % J` (x has shape (n_samples, 2)).
        device: Torch device for the returned tensors.

    Returns:
        tuple: (thetas, xs) where thetas has shape (n_samples, 2) with columns
        (mu, log_tau) and xs pairs each observation with its sigma.
    """
    def to_tensor(array):
        return torch.tensor(array, dtype=torch.float32, device=device)

    theta_chunks, x_chunks = [], []
    for sample_idx in range(n_samples):
        draw = prior(1)
        if schools_joint:
            sim = simulator(draw)
            combine = np.stack  # adds a trailing axis: (1, J, 2)
        else:
            sim = simulator(draw, school_i=sample_idx % J)
            combine = np.concatenate  # joins the two columns: (1, 2)

        theta_chunks.append(to_tensor(np.concatenate([draw['mu'], draw['log_tau']], axis=-1)))
        x_chunks.append(to_tensor(combine((sim['observable'], sim['sigma']), axis=-1)))
    return torch.concatenate(theta_chunks), torch.concatenate(x_chunks)


# Gaussian kernel for log pdf and sampling
def gaussian_kernel_score(theta, theta_prime, gamma):
    """Score w.r.t. theta of a Gaussian with mean sqrt(gamma) * theta_prime and variance 1 - gamma."""
    kernel_mean = torch.sqrt(gamma) * theta_prime
    kernel_var = 1 - gamma
    return -(theta - kernel_mean) / kernel_var


def gaussian_kernel_sample(theta_prime, gamma, device=None):
    """Sample from a Gaussian with mean sqrt(gamma) * theta_prime and variance 1 - gamma."""
    signal = torch.sqrt(gamma) * theta_prime
    noise_scale = torch.sqrt(1 - gamma)
    eps = torch.randn_like(theta_prime, dtype=torch.float32, device=device)
    return signal + eps * noise_scale

# Generate diffusion time and step size
def generate_diffusion_time(max_t, steps, device=None):
    """Build the diffusion time grid, the kernel coefficients and the step sizes.

    Returns:
        tuple: (time, gamma, delta_t), each a 1-D tensor of length steps - 1.
        `time` runs over the interior of [0, max_t], `gamma` over the interior
        of [1, 0] (first entry closest to 1), and `delta_t` is used both as
        loss weight and as Langevin step size.
    """
    grid_kwargs = dict(dtype=torch.float32, device=device)
    # Both endpoints of each (steps + 1)-point grid are dropped.
    time = torch.linspace(0, max_t, steps + 1, **grid_kwargs)[1:-1]
    gamma = torch.linspace(1, 0, steps + 1, **grid_kwargs)[1:-1]

    # alpha_t is the ratio of consecutive gammas (the first entry is gamma itself).
    ratios = gamma[1:] / gamma[:-1]
    alpha_t = torch.cat((gamma[:1], ratios), dim=0)
    delta_t = 0.3 * (1 - alpha_t) / alpha_t.sqrt()
    return time, gamma, delta_t

In [8]:
 # Loss function for weighted MSE
def weighted_mse_loss(inputs, targets, weights):
    """Sum over all elements of `weights * (inputs - targets) ** 2` (no averaging)."""
    squared_error = (inputs - targets).pow(2)
    return (weights * squared_error).sum()


def compute_score_loss(x_batch, theta_prime_batch, model, diffusion_time, gamma, delta_t, device=None):
    """Denoising score-matching loss for one batch.

    Each sample gets its own random diffusion index. The clean parameters are
    perturbed with the Gaussian kernel and the model's prediction is compared
    with the kernel's score, weighted by `delta_t`.

    Args:
        x_batch (torch.Tensor): Conditioning data, first dimension is the batch.
        theta_prime_batch (torch.Tensor): Clean parameters for the batch.
        model: Callable taking keyword arguments `theta`, `t` and `x`.
        diffusion_time, gamma, delta_t (torch.Tensor): Diffusion grids indexed
            along their first dimension.
        device: Torch device for the sampled indices and noise.

    Returns:
        torch.Tensor: Scalar loss (summed over the batch).
    """
    n_batch = x_batch.shape[0]
    idx = torch.randint(0, len(diffusion_time), size=(n_batch,), device=device)
    t, g, w = diffusion_time[idx], gamma[idx], delta_t[idx]

    # Perturb the clean parameters, then score the perturbed values.
    noisy_theta = gaussian_kernel_sample(theta_prime_batch, g, device=device)
    predicted = model(theta=noisy_theta, t=t, x=x_batch)
    target = gaussian_kernel_score(noisy_theta, theta_prime_batch, g)
    return weighted_mse_loss(predicted, target, weights=w)


# Training loop for Score Model
def train_score_model(model, dataloader, dataloader_valid=None, T=400, epochs=100, lr=1e-3, steps_diffusion_time=100, device=None):
    """Train the score model with AdamW and a cosine-annealed learning rate.

    Args:
        model: Score network, called as model(theta=..., t=..., x=...).
        dataloader: Yields (theta_prime_batch, x_batch) training batches.
        dataloader_valid: Optional loader with validation batches. When None,
            validation is skipped and the validation column of the returned
            history is NaN.
        T: Maximum diffusion time.
        epochs (int): Number of training epochs.
        lr (float): Initial learning rate.
        steps_diffusion_time (int): Number of steps of the diffusion grid.
        device: Torch device for the diffusion grid and sampled noise.

    Returns:
        np.ndarray: Array of shape (epochs, 2) with the mean per-batch training
        loss and validation loss of every epoch.
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # Add Cosine Annealing Scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

    # Generate diffusion time and step size
    diffusion_time, gamma, delta_t = generate_diffusion_time(max_t=T, steps=steps_diffusion_time, device=device)

    # Add a new dimension so that each tensor has shape (steps, 1)
    diffusion_time = diffusion_time.unsqueeze(1)
    gamma = gamma.unsqueeze(1)
    delta_t = delta_t.unsqueeze(1)

    has_valid = dataloader_valid is not None

    # Training loop
    loss_history = np.zeros((epochs, 2))
    for epoch in range(epochs):
        total_loss = 0.0
        # for each sample in the batch, calculate the loss for a random diffusion time
        for theta_prime_batch, x_batch in dataloader:
            # initialize the gradients
            optimizer.zero_grad()
            # calculate the loss
            loss = compute_score_loss(x_batch, theta_prime_batch, model, diffusion_time, gamma, delta_t, device=device)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        train_loss = total_loss / len(dataloader)

        # validate the model (only if validation data was supplied; previously
        # the default dataloader_valid=None crashed on len(None))
        if has_valid:
            valid_loss = 0.0
            with torch.no_grad():
                for theta_prime_batch, x_batch in dataloader_valid:
                    loss = compute_score_loss(x_batch, theta_prime_batch, model, diffusion_time, gamma, delta_t, device=device)
                    valid_loss += loss.item()
            valid_loss /= len(dataloader_valid)
        else:
            valid_loss = float('nan')

        loss_history[epoch] = [train_loss, valid_loss]
        if (epoch + 1) % 10 == 0:
            message = f"Epoch {epoch+1}/{epochs}, Loss: {train_loss:.4f}"
            if has_valid:
                message += f", Valid Loss: {valid_loss:.4f}"
            print(message)
    return loss_history

In [None]:
# Hyperparameters
n_samples = 250000  # number of simulated (theta, x) training pairs
batch_size = 256
T = 400  # maximum diffusion time; also read as a global by langevin_sampling
steps_time = T  # number of diffusion grid steps; also read as a global by langevin_sampling

# Create the score network: theta is (mu, log_tau), x is one (observation, sigma) pair
score_model = ScoreModel(
    input_dim_theta=n_params, hidden_dim_theta=64,
    input_dim_x=n_obs_per_group, hidden_dim_x=64,
    hidden_dim_emb=64
)
score_model.to(mps_device)

# Simulate the training set; schools_joint defaults to False, so each sample holds a single school
thetas, xs = generate_synthetic_data(n_samples, device=mps_device)
# Standardize parameters and data with the statistics stored on the prior
thetas = (thetas - prior.prior_mean) / prior.prior_std
xs = (xs - prior.x_mean) / prior.x_std
# Create dataloader
dataset = TensorDataset(thetas, xs)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# create validation data (1000 fresh simulations, standardized the same way)
valid_prior, valid_data = generate_synthetic_data(1000, device=mps_device)
valid_data = (valid_data - prior.x_mean) / prior.x_std
valid_prior = (valid_prior - prior.prior_mean) / prior.prior_std
dataset_valid = TensorDataset(valid_prior, valid_data)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=True)

# Train model; loss_history has one row per epoch with columns (train, valid)
loss_history = train_score_model(score_model, dataloader, dataloader_valid=dataloader_valid,
                  T=T, epochs=200, lr=1e-3, steps_diffusion_time=steps_time, device=mps_device)

In [None]:
# Plot the per-epoch training (column 0) and validation (column 1) loss.
plt.plot(loss_history[:, 0], label='Train')
plt.plot(loss_history[:, 1], label='Valid')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Annealed Langevin Dynamics for Sampling
def langevin_sampling(model, x_obs, n_post_samples, steps=5, device=None):
    """Sample the global parameters with annealed Langevin dynamics.

    The per-observation model scores are summed and combined with the prior
    score, weighted by (1 - n_obs) * (T - t) / T, at every diffusion time.

    Reads the module-level globals `prior`, `T`, `steps_time` and `n_params`.

    Args:
        model: Trained score network, called as model(theta, t=..., x=...).
        x_obs: Unnormalized observations of shape (n_obs, d); a tensor or
            anything `torch.tensor` accepts.
        n_post_samples (int): Number of posterior samples to draw.
        steps (int): Langevin updates per diffusion time.
        device: Torch device used for sampling.

    Returns:
        np.ndarray: Samples of shape (n_post_samples, n_params) on the original
        (de-normalized) parameter scale.
    """
    # Convert to a tensor *before* normalizing, so the arithmetic below is pure torch.
    if not isinstance(x_obs, torch.Tensor):
        x_obs = torch.tensor(x_obs, dtype=torch.float32, device=device)
    x_obs_norm = (x_obs - prior.x_mean) / prior.x_std  # assumes x_obs is not standardized

    # Initialize parameters
    n_obs = x_obs_norm.shape[0]
    theta = torch.randn(n_post_samples, n_params, device=device) / torch.sqrt(torch.tensor(n_obs, dtype=torch.float32, device=device))

    # Generate diffusion time parameters on the sampling device
    diffusion_time, _, delta_t = generate_diffusion_time(max_t=T, steps=steps_time, device=device)

    # Expand x_obs_norm to match the number of posterior samples
    x_exp = x_obs_norm.unsqueeze(0).expand(n_post_samples, -1, -1)  # Shape: (n_post_samples, n_obs, d)
    x_expanded = x_exp.reshape(-1, x_obs_norm.shape[-1])

    # Sampling needs no gradients; without no_grad the autograd graph would
    # keep growing across every Langevin update.
    with torch.no_grad():
        # Reverse iterate over diffusion times and step sizes
        for step_size, t in zip(delta_t.flip(0), diffusion_time.flip(0)):
            # Create tensor for current time step
            t_tensor = torch.full((n_post_samples, 1), t, dtype=torch.float32, device=device)
            t_exp = t_tensor.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, 1)
            # Weight of the prior score at this diffusion time (constant over the inner loop)
            prior_weight = (1 - n_obs) * (T - t) / T

            for _ in range(steps):
                # Sample Gaussian noise
                eps = torch.randn_like(theta, dtype=torch.float32, device=device)

                # Compute prior score
                prior_score = prior.score_global_batch(theta)

                # Compute model scores, one per (sample, observation) pair, summed over observations
                theta_exp = theta.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, n_params)
                model_scores = model(theta_exp, t=t_exp, x=x_expanded)
                model_scores = model_scores.reshape(n_post_samples, n_obs, -1).sum(dim=1)

                # Compute updated scores and perform Langevin step
                scores = prior_weight * prior_score + model_scores
                theta = theta + (step_size / 2) * scores + torch.sqrt(step_size) * eps
    # correct for normalization
    theta = theta * prior.prior_std + prior.prior_mean
    # convert to numpy (move to host memory first so this also works off-CPU)
    return theta.detach().cpu().numpy()

# Validation

In [None]:
from bayesflow import diagnostics

In [None]:
# Draw 10 validation cases; with schools_joint=True each case holds the observations of all J schools.
# NOTE: this overwrites the (standardized) valid_prior / valid_data from the training cell with unnormalized values.
valid_prior, valid_data = generate_synthetic_data(10, schools_joint=True, device=mps_device)

In [None]:
# One Langevin run per validation data set, 100 posterior samples each;
# the result is stacked to shape (n_datasets, 100, n_params).
posterior_samples_valid = np.array([langevin_sampling(score_model, vd, n_post_samples=100, device=mps_device)
                                    for vd in valid_data])

In [None]:
# Recovery plot: posterior samples against the parameters that generated each validation data set.
diagnostics.plot_recovery(posterior_samples_valid, np.array(valid_prior), param_names=[r'$\mu$', r'$\log \tau$']);

In [None]:
# Calibration check (SBC ECDF, difference=True) on the same validation runs.
diagnostics.plot_sbc_ecdf(posterior_samples_valid, np.array(valid_prior),
                          difference=True, param_names=[r'$\mu$', r'$\log \tau$']);

# Apply on Data

In [None]:
# Generate posterior samples for the real eight-schools data.
# Each row of test_data is one school: (observed effect, known sigma).
test_data = torch.tensor(np.concatenate((y, sigma[:, np.newaxis]), axis=-1), dtype=torch.float32, device=mps_device)
# Pass the device explicitly, as the validation cell does, so sampling runs where test_data and the model live.
posterior_samples = langevin_sampling(score_model, test_data, n_post_samples=valid_prior.shape[0], device=mps_device)
print("Sampled posterior parameters:", posterior_samples)

In [None]:
# Pairwise posterior plot with the validation prior draws as reference.
# NOTE(review): valid_prior is passed as a torch tensor here, whereas the other diagnostics calls wrap it in np.array — confirm this is accepted.
diagnostics.plot_posterior_2d(posterior_samples, prior_draws=valid_prior, param_names=[r'$\mu$', r'$\log \tau$']);

In [None]:
# Posterior mean of (mu, log_tau) over the drawn samples.
posterior_samples.mean(axis=0)

## Stan Inference Results
$\mu = 5.836806$

$\log\tau = 2.450053$

$\theta =
[ 0.64940756,  0.09001582, -0.23279844,  0.04471902, -0.33542507, -0.2041105,  0.53249937,  0.14456798]$

https://github.com/blei-lab/edward/blob/master/notebooks/eight_schools.ipynb