# Simple Test for Hierarchical ABI with compositional score matching

In [None]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# Eight schools problem: J groups ("schools"), each with a single observed value y_j
# and an observation-noise scale sigma_j that is treated as known.
J = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12]).reshape(-1, 1)  # column vector: one observation per school
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])  # known per-school noise scales
# The network sees each school as the pair (y_j, sigma_j), because sigma is known and
# therefore part of the observation.
n_obs_per_group = 2

In [None]:
# simulator example
class Prior:
    """Independent standard-normal prior over ``log_mu`` and ``log_tau``.

    Calling an instance draws samples. ``log_score`` / ``score_batch`` return the
    gradient of the log prior density, which for a standard normal is ``-theta``.
    """

    def __call__(self, batch_size):
        return self.sample(batch_size)

    @staticmethod
    def sample(batch_size):
        """Draw ``batch_size`` prior samples.

        Returns:
            dict with keys ``'log_mu'`` and ``'log_tau'``, each an array of shape
            ``(batch_size, 1)``.
        """
        log_mu = np.random.normal(loc=0, scale=1, size=(batch_size, 1))
        log_tau = np.random.normal(loc=0, scale=1, size=(batch_size, 1))
        return dict(log_mu=log_mu, log_tau=log_tau)

    @staticmethod
    def log_score(theta):
        """Return the prior score ``[d/d log_mu, d/d log_tau]`` at ``theta``.

        Args:
            theta: dict with ``'log_mu'`` and ``'log_tau'``; values may be scalars or
                arrays.

        Returns:
            ``np.array([-log_mu, -log_tau])``. The two gradients are stacked along a
            new leading axis.
        """
        grad_logp_mu = -theta['log_mu']
        grad_logp_tau = -theta['log_tau']
        return np.array([grad_logp_mu, grad_logp_tau])

    def score_batch(self, theta_batch):
        """Prior score for a batch of parameter rows.

        Args:
            theta_batch: array-like of shape ``(n, >=2)``. Column 0 is ``log_mu`` and
                column 1 is ``log_tau``.

        Returns:
            Array of shape ``(n, 2)`` with one prior-score row per input row. An empty
            ``(0, 2)`` batch yields an empty ``(0, 2)`` result.
        """
        theta_batch = np.asarray(theta_batch)
        # Evaluate all rows at once: log_score on the two columns gives shape (2, n),
        # so transpose to get one row per sample.
        scores = self.log_score(dict(log_mu=theta_batch[:, 0], log_tau=theta_batch[:, 1]))
        return scores.T

def simulator(params, school_i=None):
    """Simulate eight-schools data for a batch of hyperparameters.

    For each row of ``params``, J school means are drawn from
    ``N(exp(log_mu), exp(log_tau))``. Each school's observation is then drawn
    around its mean with the known noise scale ``sigma``.
    NOTE(review): because the mean is ``exp(log_mu)``, simulated school means are
    centred on a strictly positive value — confirm that this is intended.

    Args:
        params: dict with ``'log_mu'`` and ``'log_tau'`` arrays of shape ``(batch, 1)``.
        school_i: if given, return only that school's column; otherwise return all J.

    Returns:
        dict with ``'observable'`` and ``'sigma'``. Both are ``(batch, J)`` when
        ``school_i`` is None and ``(batch, 1)`` otherwise.
    """
    n_sims = params['log_mu'].shape[0]
    school_means = np.random.normal(
        loc=np.exp(params['log_mu']),
        scale=np.exp(params['log_tau']),
        size=(n_sims, J),
    )
    observed = np.random.normal(loc=school_means, scale=sigma, size=(n_sims, J))
    if school_i is not None:
        return dict(
            observable=observed[:, school_i][:, np.newaxis],
            sigma=(np.ones(n_sims) * sigma[school_i])[:, np.newaxis],
        )
    return dict(observable=observed, sigma=np.tile(sigma, (n_sims, 1)))

prior = Prior()
# Parameter count = number of entries in one prior draw (log_mu and log_tau).
# NOTE: this draws (and discards) two prior samples, advancing the NumPy RNG.
n_params = len(prior(2))

In [None]:
# Display two prior draws: a dict of (2, 1) arrays for log_mu and log_tau.
prior(2)

In [None]:
# Display a joint simulation for two prior draws: 'observable' and 'sigma', each of shape (2, J).
simulator(prior(2), school_i=None)

In [None]:
# Define the Score Model based on F-NPSE
class ResidualBlock(nn.Module):
    """Linear -> LayerNorm -> ReLU, plus an additive skip connection.

    The skip path is the identity when the widths match. Otherwise it is a learned
    linear projection, so both branches always produce ``out_features`` columns.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.norm = nn.LayerNorm(out_features)
        self.activation = nn.ReLU()
        if in_features == out_features:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(in_features, out_features)

    def forward(self, x):
        skip = self.proj(x)
        return self.activation(self.norm(self.linear(x))) + skip

class ScoreModel(nn.Module):
    """Conditional score network ``s(theta, t, x)``.

    ``theta`` and ``x`` each pass through a three-block residual tower that maps
    them to ``input_dim_theta`` features. The two embeddings are concatenated with
    the raw diffusion time ``t`` (one column). A third tower then maps the result to
    ``input_dim_theta`` outputs, one score component per parameter.
    """

    def __init__(self, input_dim_theta, hidden_dim_theta,
                 input_dim_x, hidden_dim_x,
                 hidden_dim_emb):
        super().__init__()
        out_dim = input_dim_theta

        self.net_theta = nn.Sequential(
            ResidualBlock(input_dim_theta, hidden_dim_theta),
            ResidualBlock(hidden_dim_theta, hidden_dim_theta),
            ResidualBlock(hidden_dim_theta, out_dim),
        )

        self.net_x = nn.Sequential(
            ResidualBlock(input_dim_x, hidden_dim_x),
            ResidualBlock(hidden_dim_x, hidden_dim_x),
            ResidualBlock(hidden_dim_x, out_dim),
        )

        # Input width: theta embedding + x embedding + one column for t.
        joint_dim = 2 * out_dim + 1
        self.net_emb = nn.Sequential(
            ResidualBlock(joint_dim, hidden_dim_emb),
            ResidualBlock(hidden_dim_emb, hidden_dim_emb),
            ResidualBlock(hidden_dim_emb, out_dim),
        )

    def forward(self, theta, t, x):
        # TODO: replace the raw time column with a positional / Fourier encoding.
        features = torch.cat((self.net_theta(theta), self.net_x(x), t), dim=-1)
        return self.net_emb(features)

In [None]:
def generate_synthetic_data(n_samples, schools_joint=False):
    """Simulate ``n_samples`` (theta, x) training pairs, one prior draw at a time.

    Args:
        n_samples: number of pairs to generate.
        schools_joint: if True, each x holds all J schools as ``(J, 2)`` rows of
            ``(observable, sigma)``. If False, sample ``k`` simulates only school
            ``k % J`` and x is a single ``(observable, sigma)`` pair.

    Returns:
        ``(thetas, xs)`` float32 tensors. ``thetas`` is ``(n_samples, 2)`` with columns
        ``(log_mu, log_tau)``. ``xs`` is ``(n_samples, J, 2)`` when ``schools_joint``
        and ``(n_samples, 2)`` otherwise.
    """
    theta_rows, x_rows = [], []
    for idx in range(n_samples):
        draw = prior(1)
        theta_rows.append(
            torch.tensor(np.concatenate([draw['log_mu'], draw['log_tau']], axis=-1), dtype=torch.float32)
        )
        if schools_joint:
            sim = simulator(draw)
            # (1, J) and (1, J) stacked on a new last axis -> (1, J, 2)
            features = np.stack((sim['observable'], sim['sigma']), axis=-1)
        else:
            sim = simulator(draw, school_i=idx % J)
            # (1, 1) and (1, 1) side by side -> (1, 2)
            features = np.concatenate((sim['observable'], sim['sigma']), axis=-1)
        x_rows.append(torch.tensor(features, dtype=torch.float32))
    return torch.cat(theta_rows), torch.cat(x_rows)


# Gaussian kernel for log pdf and sampling
def gaussian_kernel_log_pdf(theta, theta_prime, gamma):
    """Elementwise log density of the diffusion kernel.

    The kernel is ``q(theta | theta_prime) = N(sqrt(gamma) * theta_prime, (1 - gamma) I)``.
    Its covariance is isotropic, so the joint log density factorizes over
    dimensions. This function returns the per-dimension terms, and
    ``result.sum(dim=-1)`` is the joint log density of each row.

    Args:
        theta: points at which to evaluate the kernel.
        theta_prime: kernel centres before scaling; broadcastable to ``theta``.
        gamma: signal level in (0, 1), as a Python float or a tensor broadcastable
            to ``theta``.

    Returns:
        Tensor with the broadcast shape of ``theta`` and ``theta_prime``.
    """
    gamma = torch.as_tensor(gamma, dtype=theta.dtype, device=theta.device)
    var = 1 - gamma
    diff = theta - torch.sqrt(gamma) * theta_prime
    # Per-dimension normalizer log(sqrt(2*pi*var)). Summing over the last axis then
    # gives exactly the joint normalizer (d/2) * log(2*pi*var), counted once.
    return -0.5 * diff**2 / var - 0.5 * torch.log(2 * math.pi * var)

def gaussian_kernel_sample(theta_prime, gamma):
    """Draw ``theta ~ N(sqrt(gamma) * theta_prime, (1 - gamma) I)``, one draw per element.

    Args:
        theta_prime: tensor of kernel centres before scaling.
        gamma: signal level in [0, 1], as a Python float or a tensor broadcastable to
            ``theta_prime``. Scalars are converted so that ``torch.sqrt`` accepts them.

    Returns:
        Tensor with the same shape as ``theta_prime``.
    """
    gamma = torch.as_tensor(gamma, dtype=theta_prime.dtype, device=theta_prime.device)
    noise = torch.randn_like(theta_prime) * torch.sqrt(1 - gamma)
    return torch.sqrt(gamma) * theta_prime + noise

# Generate diffusion time and step size
def generate_diffusion_time(T):
    """Build the discrete diffusion schedule.

    Args:
        T: number of intervals on [0, 1]. The T - 1 interior grid points are used,
            so both endpoints 0 and 1 are excluded.

    Returns:
        ``(time, gamma, delta_t)``, three arrays of length T - 1:
          * ``time``: increasing interior points ``1/T, ..., (T-1)/T``;
          * ``gamma``: ``time`` reversed, so the first entry is closest to 1;
          * ``delta_t``: ``0.3 * (1 - alpha) / sqrt(alpha)``, where ``alpha[0] = gamma[0]``
            and ``alpha[k] = gamma[k] / gamma[k-1]``. It is used as the Langevin step
            size and is intended as a loss weight.
    """
    interior = np.linspace(0, 1, T + 1)[1:-1]
    gamma = np.flip(interior).copy()

    ratios = np.empty_like(gamma)
    ratios[0] = gamma[0]
    ratios[1:] = gamma[1:] / gamma[:-1]

    delta_t = 0.3 * (1 - ratios) / np.sqrt(ratios)
    return interior, gamma, delta_t

In [None]:
# Weighted MSE loss
def weighted_mse_loss(pred, target, weights):
    """Return ``sum(weights * (pred - target)**2) / sum(weights)`` as a scalar tensor.

    ``weights`` is broadcast against the squared error in the numerator. The
    denominator is the plain sum of ``weights`` as given.
    """
    squared_error = (pred - target).pow(2)
    return (weights * squared_error).sum() / weights.sum()

# Training loop for Score Model
def train_score_model(model, dataloader, T, epochs=100, lr=1e-3):
    """Fit ``model`` by denoising score matching.

    For each mini-batch ``(theta', x)`` and each diffusion time ``t`` (with paired
    ``gamma``) from ``generate_diffusion_time(T)``:
      1. ``theta`` is drawn from the kernel ``N(sqrt(gamma) * theta', (1 - gamma) I)``;
      2. the network output ``model(theta, t, x)`` is regressed onto that kernel's
         score w.r.t. ``theta``, ``-(theta - sqrt(gamma) * theta') / (1 - gamma)``.
    The MSE losses over all diffusion times are summed, and one optimizer step is
    taken per mini-batch.

    Args:
        model: score network called as ``model(theta=..., t=..., x=...)``.
        dataloader: yields ``(theta_prime_batch, x_batch)`` tensors.
        T: number of diffusion intervals passed to ``generate_diffusion_time``.
        epochs: passes over ``dataloader``.
        lr: Adam learning rate.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    diffusion_time, gamma, _ = generate_diffusion_time(T)
    diffusion_time = torch.tensor(diffusion_time, dtype=torch.float32)
    gamma = torch.tensor(gamma, dtype=torch.float32)

    for epoch in range(epochs):
        total_loss = 0.0
        for theta_prime_batch, x_batch in dataloader:
            # Gradients accumulate over all diffusion times into a single backward
            # pass, so reset them once per batch.
            optimizer.zero_grad()
            loss = 0.0
            for g, t in zip(gamma, diffusion_time):
                # sample from the Gaussian kernel at signal level g
                theta_batch = gaussian_kernel_sample(theta_prime_batch, g)
                t_batch = torch.ones((theta_batch.shape[0], 1)) * t
                score_pred = model(theta=theta_batch, t=t_batch, x=x_batch)
                # Reference: the score (gradient of the log density w.r.t. theta) of
                # the kernel that produced theta_batch. It uses the same g as the
                # sampling step above.
                score_ref = -(theta_batch - torch.sqrt(g) * theta_prime_batch) / (1 - g)
                loss += criterion(score_pred, score_ref)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")

# Hyperparameters
n_samples, batch_size, T = 1000, 128, 100

# Score network: theta has n_params entries; each per-school observation is (y, sigma).
score_model = ScoreModel(
    input_dim_theta=n_params,
    hidden_dim_theta=16,
    input_dim_x=n_obs_per_group,
    hidden_dim_x=16,
    hidden_dim_emb=16,
)

# Training set: one simulated school per sample (the school index cycles through 0..J-1).
thetas, xs = generate_synthetic_data(n_samples)
dataset = TensorDataset(thetas, xs)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Fit the score network
train_score_model(score_model, dataloader, T=T, epochs=100, lr=1e-3)

In [None]:
# Annealed Langevin Dynamics for Sampling
def langevin_sampling(model, x_obs, n_post_samples, steps=5):
    """Draw approximate posterior samples with annealed Langevin dynamics.

    At diffusion time ``t``, the posterior score given ``n`` observations is
    approximated compositionally as
    ``(1 - n) * (1 - t) * prior_score(theta) + sum_i model(theta, t, x_i)``.
    ``steps`` Langevin updates are run per diffusion time, from the last entry of
    the schedule to the first.

    Args:
        model: trained score network called as ``model(theta, t=..., x=...)``.
        x_obs: tensor whose rows are individual observations, iterated row by row.
        n_post_samples: number of independent chains, i.e. posterior draws.
        steps: Langevin updates per diffusion time.

    Returns:
        numpy array of shape ``(n_post_samples, n_params)``.

    Uses the module-level ``prior``, ``n_params`` and ``T``.
    """
    n = x_obs.shape[0]
    theta = np.random.normal(loc=0, scale=np.ones(n_params)/np.sqrt(n), size=(n_post_samples, n_params))
    theta_torch = torch.tensor(theta, dtype=torch.float32)

    # Generate diffusion time and step size
    diffusion_time, gamma, delta_t = generate_diffusion_time(T)
    # Sampling needs no gradients. Without no_grad, each update would extend a single
    # autograd graph through the model parameters across all T * steps iterations.
    with torch.no_grad():
        for step_size, t in zip(delta_t[::-1], diffusion_time[::-1]):  # reverse order
            t_tensor = torch.ones((n_post_samples, 1)) * t
            for _ in range(steps):
                # sample noise
                eps = torch.tensor(np.random.normal(loc=0, scale=np.ones(n_params), size=(n_post_samples, n_params)),
                                   dtype=torch.float32)
                # Prior score at the CURRENT chain state (not at the initial draw).
                prior_score = prior.score_batch(theta_torch.numpy())
                scores = torch.tensor((1-n)*(1-t) * prior_score, dtype=torch.float32)
                for x in x_obs:
                    # repeat the observation once per chain
                    x_tensor = torch.ones((n_post_samples, 1)) * x
                    # model score for this observation at the current theta and diffusion time
                    scores += model(theta_torch, t=t_tensor, x=x_tensor)
                # Langevin update
                theta_torch = theta_torch + step_size/2 * scores + np.sqrt(step_size) * eps
    return theta_torch.detach().numpy()

In [None]:
# Generate posterior samples for the observed eight-schools data.
# Each row of test_data is one school's (y, sigma) pair, so its shape is (J, 2).
test_data = torch.tensor(np.hstack((y, sigma[:, np.newaxis])), dtype=torch.float32)
theta_samples = langevin_sampling(score_model, test_data, n_post_samples=5)
print("Sampled posterior parameters:", theta_samples)

In [None]:
# Shape of the observed data tensor: (J, 2), one (y, sigma) row per school.
test_data.shape

In [None]:
# Validation: simulate 10 complete data sets (all J schools each, so valid_data has
# shape (10, J, 2)) with known parameters, and draw 5 posterior samples per data set.
valid_prior, valid_data = generate_synthetic_data(10, schools_joint=True)
posterior_samples_valid = np.stack([
    langevin_sampling(score_model, data_set, n_post_samples=5)
    for data_set in valid_data
])

In [None]:
# Posterior draws for the validation sets: shape (10, 5, n_params).
posterior_samples_valid

In [None]:
from bayesflow import diagnostics

In [None]:
# Parameter recovery: posterior draws vs. the true simulated parameters.
# Column 0 of every parameter vector is log_mu (the prior's 'log_mu' draw), not mu.
diagnostics.plot_recovery(posterior_samples_valid, np.array(valid_prior), param_names=['log_mu', 'log_tau']);

In [None]:
# Simulation-based calibration (ECDF difference plot) for both parameters.
# Column 0 of every parameter vector is log_mu (the prior's 'log_mu' draw), not mu.
diagnostics.plot_sbc_ecdf(posterior_samples_valid, np.array(valid_prior), difference=True, param_names=['log_mu', 'log_tau']);