# Simple Test for Hierarchical ABI with compositional score matching

In [1]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# eight schools problem
# Module-level constants read by Prior.sample, simulator and generate_synthetic_data.
J = 8  # number of groups (schools)
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])[:, np.newaxis]  # our groups, one observation per group; shape (J, 1)
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])  # assumed to be known; used as the simulator's noise scale per group
n_obs_per_group = 2  # sigma is known, so part of the observation

In [3]:
# simulator example
class Prior:
    """Hierarchical Gaussian prior over global and per-group parameters.

    Global parameters:
        mu      ~ Normal(mu_mean, mu_std**2)
        log_tau ~ Normal(log_tau_mean, log_tau_std**2)
    Local parameters (one per group, ``J`` groups; ``J`` is a module-level
    constant read at sampling time):
        theta_j ~ Normal(mu, exp(log_tau)**2)

    Besides sampling, the class exposes the scores (gradients of the log
    density) of the global prior and of the local conditional prior.
    """

    def __init__(self):
        self.mu_mean = 0
        self.mu_std = 10
        self.log_tau_mean = 5
        self.log_tau_std = 1

    def __call__(self, batch_size):
        """Shorthand for :meth:`sample`."""
        return self.sample(batch_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` joint samples.

        Returns:
            dict with NumPy arrays ``mu`` (batch_size, 1), ``log_tau``
            (batch_size, 1) and ``theta_j`` (batch_size, J).
        """
        mu = np.random.normal(loc=self.mu_mean, scale=self.mu_std, size=(batch_size, 1))
        log_tau = np.random.normal(loc=self.log_tau_mean, scale=self.log_tau_std, size=(batch_size, 1))
        theta_j = np.random.normal(loc=mu, scale=np.exp(log_tau), size=(batch_size, J))
        return dict(mu=mu, log_tau=log_tau, theta_j=theta_j)

    def _global_score(self, theta):
        """Score of the global prior for ``theta`` of shape (..., 2) = (mu, log_tau)."""
        # detach(): the score is a plain value, never part of an autograd graph.
        theta = torch.as_tensor(theta, dtype=torch.float32).detach()
        mean = torch.tensor(
            [self.mu_mean, self.log_tau_mean], dtype=torch.float32, device=theta.device
        )
        var = torch.tensor(
            [self.mu_std, self.log_tau_std], dtype=torch.float32, device=theta.device
        ) ** 2
        # d/dx log N(x | m, s^2) = -(x - m) / s^2, applied to both coordinates.
        return -(theta - mean) / var

    def log_score_global(self, theta):
        """ Computes the global score for a single parameter set (mu, log_tau).

        Args:
            theta: tensor, array or sequence of shape (2,).

        Returns:
            float32 tensor of shape (2,): gradient of the log prior w.r.t.
            (mu, log_tau).
        """
        return self._global_score(theta)

    def score_global_batch(self, theta_batch):
        """ Computes the global score for a batch of parameters.

        Args:
            theta_batch: tensor or array of shape (N, 2).

        Returns:
            float32 tensor of shape (N, 2). Computed in one vectorised
            expression instead of a per-row Python loop.
        """
        return self._global_score(theta_batch)

    @staticmethod
    def _local_score(theta):
        """Score of theta_j | mu, log_tau for ``theta`` of shape (..., 2 + n_local)."""
        # Work in float64 and cast at the end, like the NumPy-based original.
        theta = torch.as_tensor(theta).detach().to(torch.float64)
        mu = theta[..., 0:1]
        log_tau = theta[..., 1:2]
        theta_j = theta[..., 2:]
        grad_logp_theta_j = -(theta_j - mu) / torch.exp(log_tau) ** 2
        return grad_logp_theta_j.to(torch.float32)

    @staticmethod
    def log_score_local(theta):
        """ Computes the local score for a single sample theta_j.

        Args:
            theta: tensor or array of shape (2 + n_local,), laid out as
                (mu, log_tau, theta_1, ..., theta_n_local).

        Returns:
            float32 tensor of shape (1, n_local); the leading singleton axis
            is kept for backward compatibility.
        """
        return Prior._local_score(theta).unsqueeze(0)

    def score_local_batch(self, theta_batch):
        """ Computes the local score for a batch of samples.

        Args:
            theta_batch: tensor or array of shape (N, 2 + n_local).

        Returns:
            float32 tensor of shape (N, 1, n_local).
        """
        return self._local_score(theta_batch).unsqueeze(1)


def simulator(params, school_i=None):
    """Simulate noisy group observations from sampled parameters.

    Args:
        params: dict as returned by ``Prior.sample``; ``params['mu']`` gives
            the batch size and ``params['theta_j']`` the per-group means.
        school_i: optional group index. When given, only that group's column
            is returned; otherwise all ``J`` groups are returned.

    Returns:
        dict with ``observable`` and ``sigma``; both are (batch, J) arrays in
        the joint case and (batch, 1) arrays for a single group.
    """
    n_draws = params['mu'].shape[0]
    # All J groups are always simulated, even if only one column is kept.
    observations = np.random.normal(loc=params['theta_j'], scale=sigma, size=(n_draws, J))

    if school_i is not None:
        single_obs = np.expand_dims(observations[:, school_i], axis=1)
        single_sigma = np.expand_dims(np.ones(n_draws) * sigma[school_i], axis=1)
        return dict(observable=single_obs, sigma=single_sigma)

    return dict(observable=observations, sigma=np.tile(sigma, (n_draws, 1)))

# Shared prior instance; read as a global by generate_synthetic_data and langevin_sampling.
prior = Prior()
# Number of global parameters (mu, log_tau) targeted by the score model.
n_params = 2

In [4]:
# Sanity check: draw two joint samples from the prior (dict of NumPy arrays).
prior(2)

{'mu': array([[ 17.73374141],
        [-11.69914719]]),
 'log_tau': array([[6.14189388],
        [3.32845105]]),
 'theta_j': array([[-343.78364777, -240.35043144, -147.70232476,  -88.72958342,
         -476.47217915, -310.55506431, -217.53907339,  695.7704519 ],
        [ -11.89013417,    3.62698016,  -15.28757231,  -38.68051892,
          -27.38451374,   18.97765665,  -33.15087166,  -14.43514689]])}

In [5]:
# Sanity check: simulate all J groups jointly for two prior draws.
simulator(prior(2), school_i=None)

{'observable': array([[ -923.87378975,  -413.6135652 ,  -555.35685577,  1330.28983981,
          1016.50080995,  -406.26270439, -1115.9590014 ,   928.0417485 ],
        [   89.37077683,    36.74712177,  -214.62680435,   104.55501516,
            57.245894  ,   -18.87897402,   -13.36876791,  -231.88165267]]),
 'sigma': array([[15, 10, 16, 11,  9, 11, 10, 18],
        [15, 10, 16, 11,  9, 11, 10, 18]])}

In [6]:
# Define the Score Model based on F-NPSE
class ResidualBlock(nn.Module):
    """Linear -> LayerNorm -> ReLU layer with an additive skip connection.

    When the input and output widths differ, the skip path goes through a
    learned linear projection; otherwise it is the identity.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.norm = nn.LayerNorm(out_features)
        self.activation = nn.ReLU()
        if in_features == out_features:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return ``activation(norm(linear(x))) + proj(x)``."""
        transformed = self.activation(self.norm(self.linear(x)))
        return transformed + self.proj(x)

class ScoreModel(nn.Module):
    """Score network taking parameters, diffusion time and one observation.

    ``theta`` and ``x`` are each embedded by their own stack of three
    residual blocks into ``input_dim_theta`` features. Both embeddings are
    concatenated with the raw time ``t`` and mapped by a third stack to an
    output of width ``input_dim_theta``.
    """

    def __init__(self, input_dim_theta, hidden_dim_theta,
                 input_dim_x, hidden_dim_x,
                 hidden_dim_emb):
        super().__init__()

        def make_stack(dim_in, dim_hidden, dim_out):
            # Three residual blocks: in -> hidden -> hidden -> out.
            return nn.Sequential(
                ResidualBlock(dim_in, dim_hidden),
                ResidualBlock(dim_hidden, dim_hidden),
                ResidualBlock(dim_hidden, dim_out),
            )

        self.net_theta = make_stack(input_dim_theta, hidden_dim_theta, input_dim_theta)
        self.net_x = make_stack(input_dim_x, hidden_dim_x, input_dim_theta)
        # Input width: theta embedding + x embedding + one time feature.
        self.net_emb = make_stack(input_dim_theta*2+1, hidden_dim_emb, input_dim_theta)

    def forward(self, theta, t, x):
        """Return the predicted score for ``theta`` at time ``t`` given ``x``.

        The three inputs must agree on all leading dimensions; ``t`` adds a
        single feature column to the concatenation.
        """
        # todo: add positional encoding (t is currently used as-is)
        joint_features = torch.cat((self.net_theta(theta), self.net_x(x), t), dim=-1)
        return self.net_emb(joint_features)

In [7]:
def generate_synthetic_data(n_samples, schools_joint=False):
    """Simulate ``n_samples`` (parameter, observation) training pairs.

    Each iteration draws one prior sample and simulates from it.

    Args:
        n_samples: number of pairs to generate.
        schools_joint: if True, every pair holds all J groups and ``xs`` has
            shape (n_samples, J, 2). If False, pair ``i`` holds only group
            ``i % J`` and ``xs`` has shape (n_samples, 2). The last axis is
            (observable, sigma).

    Returns:
        ``(thetas, xs)`` as float32 tensors; ``thetas`` has shape
        (n_samples, 2) holding (mu, log_tau).
    """
    theta_rows, x_rows = [], []
    for idx in range(n_samples):
        draw = prior(1)
        if schools_joint:
            sim = simulator(draw)
            combine = np.stack  # adds a trailing (observable, sigma) axis
        else:
            sim = simulator(draw, school_i=idx % J)
            combine = np.concatenate  # (1, 1) + (1, 1) -> (1, 2)

        theta_np = np.concatenate([draw['mu'], draw['log_tau']], axis=-1)
        x_np = combine((sim['observable'], sim['sigma']), axis=-1)

        theta_rows.append(torch.tensor(theta_np, dtype=torch.float32))
        x_rows.append(torch.tensor(x_np, dtype=torch.float32))

    return torch.cat(theta_rows), torch.cat(x_rows)


# Gaussian kernel for log pdf and sampling
def gaussian_kernel_score(theta, theta_prime, gamma):
    """Score of the Gaussian kernel N(sqrt(gamma) * theta_prime, 1 - gamma) at ``theta``.

    ``gamma`` must be a tensor because ``torch.sqrt`` is applied to it.
    """
    kernel_mean = torch.sqrt(gamma) * theta_prime
    kernel_var = 1 - gamma
    return -(theta - kernel_mean) / kernel_var


def gaussian_kernel_sample(theta_prime, gamma):
    """Draw one sample per row from N(sqrt(gamma) * theta_prime, 1 - gamma).

    ``gamma`` must be a tensor because ``torch.sqrt`` is applied to it.
    """
    standard_normal = torch.randn_like(theta_prime)
    scaled_noise = standard_normal * torch.sqrt(1 - gamma)
    return torch.sqrt(gamma) * theta_prime + scaled_noise

# Generate diffusion time and step size
def generate_diffusion_time(T, step_scale=0.3):
    """Build the diffusion time grid, the kernel schedule and per-level step sizes.

    Args:
        T: number of grid intervals on [0, 1]. The two end points are
            dropped, so ``T - 1`` levels are returned.
        step_scale: multiplier for the step sizes (default 0.3).

    Returns:
        ``(time, gamma, delta_t)``, three NumPy arrays of length ``T - 1``:
        ``time`` is increasing in (0, 1); ``gamma`` is ``time`` reversed, so
        ``gamma[i] == 1 - time[i]``; ``delta_t[i]`` is the step size for
        level ``i``.

    Raises:
        ValueError: if ``T < 2``, which would leave no interior grid point.
    """
    if T < 2:
        raise ValueError(f"T must be at least 2 to leave an interior grid point, got {T}")

    time = np.linspace(0, 1, T+1)[1:-1]  # interior points only: exclude 0 and 1
    # Copy the reversed view; presumably so callers can hand the array to
    # torch.tensor, which needs non-negative strides - TODO confirm.
    gamma = time[::-1].copy()  # first index is close to 1

    # Ratio of consecutive gamma values; the first level is taken relative to gamma = 1.
    alpha_t = np.concatenate(([gamma[0]], gamma[1:] / gamma[:-1]))
    delta_t = step_scale * (1-alpha_t) / np.sqrt(alpha_t)  # used as step size in Langevin dynamics
    return time, gamma, delta_t

In [None]:
# Weighted MSE loss
def weighted_mse_loss(pred, target, weights):
    """Return the weighted mean of squared errors: sum(w * (pred - target)^2) / sum(w)."""
    squared_error = (pred - target)**2
    weighted_total = torch.sum(weights * squared_error)
    return weighted_total / torch.sum(weights)

# Training loop for Score Model
def train_score_model(model, dataloader, T, epochs=100, lr=1e-3):
    """Train ``model`` by denoising score matching over all diffusion levels.

    For every mini-batch, each clean ``theta`` is noised once per diffusion
    level with the Gaussian kernel. The unweighted MSE between the predicted
    score and the kernel score is summed over all levels, then one Adam step
    is taken on that sum.

    Args:
        model: module called as ``model(theta=..., t=..., x=...)``.
        dataloader: yields ``(theta, x)`` batches.
        T: grid size passed to ``generate_diffusion_time``.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Prints the mean per-batch loss every 10 epochs; returns nothing.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # The step sizes returned by the schedule are not used during training.
    time_grid, gamma_grid, _ = generate_diffusion_time(T)
    time_grid = torch.tensor(time_grid, dtype=torch.float32)
    gamma_grid = torch.tensor(gamma_grid, dtype=torch.float32)
    # Pair each kernel parameter with its time value once, reused for every batch.
    levels = list(zip(gamma_grid, time_grid))

    for epoch in range(epochs):
        total_loss = 0.0
        for theta_clean, x_batch in dataloader:
            optimizer.zero_grad()
            batch_loss = 0.0
            n_rows = theta_clean.shape[0]
            for g, t in levels:
                # Noise the clean parameters at this level.
                theta_noisy = gaussian_kernel_sample(theta_clean, g)
                # Broadcast the scalar time to one column per row.
                t_column = torch.ones((n_rows, 1))*t
                predicted = model(theta=theta_noisy, t=t_column, x=x_batch)
                # Regression target: score of the noising kernel.
                target = gaussian_kernel_score(theta_noisy, theta_clean, g)
                batch_loss = batch_loss + criterion(predicted, target)
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")

# Hyperparameters
n_samples = 1000  # number of simulated (theta, x) training pairs
batch_size = 128
T = 100  # diffusion grid size; also read as a global by langevin_sampling

# Score network: theta has n_params features, each observation has
# n_obs_per_group features (observable and sigma).
score_model = ScoreModel(
    input_dim_theta=n_params, hidden_dim_theta=16,
    input_dim_x=n_obs_per_group, hidden_dim_x=16,
    hidden_dim_emb=16
)

# Create model and dataset
# schools_joint defaults to False, so each training pair holds a single group.
thetas, xs = generate_synthetic_data(n_samples)
dataset = TensorDataset(thetas, xs)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Train model
train_score_model(score_model, dataloader, T=T, epochs=100, lr=1e-3)

Epoch 10/100, Loss: 514.8203
Epoch 20/100, Loss: 501.5399
Epoch 30/100, Loss: 500.4289
Epoch 40/100, Loss: 492.0512
Epoch 50/100, Loss: 485.9502


In [None]:
# Annealed Langevin Dynamics for Sampling
def langevin_sampling(model, x_obs, n_post_samples, steps=5):
    """Draw posterior samples with annealed Langevin dynamics.

    The score used at each level combines the per-observation model scores
    with the prior score:
    ``(1 - n_obs) * (1 - t) * prior_score + sum_i model(theta, t, x_i)``.

    Args:
        model: trained score network called as ``model(theta, t=..., x=...)``.
        x_obs: observations of shape (n_obs, d), as a tensor or array-like.
        n_post_samples: number of posterior samples (chains) to run.
        steps: Langevin updates per diffusion level.

    Returns:
        NumPy array of shape (n_post_samples, n_params).

    Reads the module-level ``T``, ``n_params`` and ``prior``.
    """
    # Ensure x_obs is a float32 PyTorch tensor
    x_obs = torch.as_tensor(x_obs, dtype=torch.float32)
    n_obs = x_obs.shape[0]

    # Generate diffusion time parameters (the gamma schedule is not needed here)
    diffusion_time, _, delta_t = generate_diffusion_time(T)

    # Sampling needs no gradients. Without no_grad the autograd graph would
    # grow with every Langevin update, since theta depends on model outputs.
    with torch.no_grad():
        # Initialize parameters
        theta = torch.randn(n_post_samples, n_params) / torch.sqrt(torch.tensor(n_obs, dtype=torch.float32))

        # Pair every chain with every observation: (n_post_samples * n_obs, d)
        x_expanded = x_obs.unsqueeze(0).expand(n_post_samples, -1, -1).reshape(-1, x_obs.shape[-1])

        # Reverse iterate over diffusion times and step sizes
        for step_size, t in zip(delta_t[::-1], diffusion_time[::-1]):
            # Plain Python floats avoid mixing NumPy scalars into tensor arithmetic.
            step_size, t = float(step_size), float(t)
            t_exp = torch.full((n_post_samples * n_obs, 1), t, dtype=torch.float32)
            prior_weight = (1 - n_obs) * (1 - t)
            noise_scale = math.sqrt(step_size)

            for _ in range(steps):
                # Sample Gaussian noise
                eps = torch.randn_like(theta)

                # Compute prior score of the global parameters
                prior_score = prior.score_global_batch(theta)

                # Compute model scores: one per (chain, observation), summed over observations
                theta_exp = theta.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, n_params)
                model_scores = model(theta_exp, t=t_exp, x=x_expanded)
                model_scores = model_scores.reshape(n_post_samples, n_obs, -1).sum(dim=1)

                # Compute updated scores and perform Langevin step
                scores = prior_weight * prior_score + model_scores
                theta = theta + (step_size / 2) * scores + noise_scale * eps

    return theta.detach().numpy()

# Validation

In [None]:
from bayesflow import diagnostics

In [None]:
# 10 validation sets with all J groups each: valid_prior is (10, 2), valid_data is (10, J, 2).
valid_prior, valid_data = generate_synthetic_data(10, schools_joint=True)

In [None]:
# Run the sampler once per validation set; result has shape (10, 100, n_params).
posterior_samples_valid = np.array([langevin_sampling(score_model, vd, n_post_samples=100) for vd in valid_data])

In [None]:
# Compare posterior samples against the parameters that generated each validation set.
# NOTE(review): argument order and return value follow the installed bayesflow version - confirm.
diagnostics.plot_recovery(posterior_samples_valid, np.array(valid_prior), param_names=['mu', 'log_tau']);

In [None]:
# Calibration diagnostic on the same samples and generating parameters.
# NOTE(review): only 10 validation sets are used here - confirm that is enough for this plot.
diagnostics.plot_sbc_ecdf(posterior_samples_valid, np.array(valid_prior), difference=True, param_names=['mu', 'log_tau']);

# Apply on Data

In [None]:
# Generate posterior samples
# Stack the observed y and the known sigma into one (J, 2) row per group.
test_data = torch.tensor(np.concatenate((y, sigma[:, np.newaxis]), axis=-1), dtype=torch.float32)
theta_samples = langevin_sampling(score_model, test_data, n_post_samples=100)
# theta_samples has shape (100, n_params); columns are (mu, log_tau).
print("Sampled posterior parameters:", theta_samples)