# Gaussian on a Grid Test for Hierarchical ABI with compositional score matching

In [None]:
import math
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from bayesflow import diagnostics
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [None]:
# Device handle passed to the tensors and models created in this notebook.
torch_device = torch.device("cpu")

In [None]:
class Simulator:
    """Brownian motion with drift, simulated independently for every grid cell.

    Each cell follows the Euler scheme
        x(t + dt) = x(t) + theta * dt + sqrt(dt) * eps,   eps ~ N(0, 1),
    starting from x(0) = 0. The diffusion coefficient is fixed to 1; only the
    drift ``theta`` is a parameter.
    """

    def __init__(self, n_grid=8):
        self.n_grid = n_grid
        self.max_time = 100
        self.n_time_points = 10  # number of observation points to return
        self.dt = 0.1            # simulation time step
        # round() protects against float truncation of max_time / dt
        self.n_time_steps = int(round(self.max_time / self.dt))  # number of simulation steps

    def __call__(self, params):
        """
        Simulate trajectories and return them on a coarse observation grid.

        The simulation runs for self.n_time_steps steps (with step dt) and
        returns self.n_time_points (approximately) evenly spaced observations
        between time 0 and self.max_time, both end points included. The first
        observation is therefore the initial condition x(0) = 0.

        Args:
            params (dict): must contain the key 'theta' (drift coefficient),
                given as one of:
                - a scalar or length-1 vector (one simulation, one grid element),
                - a 2D array of shape (batch_size, 1) (one grid element each),
                - a 3D array of shape (batch_size, n_grid, n_grid).

        Returns:
            dict: {'observable': array} with shape
                (batch_size, n_time_points, 1) for scalar / 2D 'theta' and
                (batch_size, n_time_points, n_grid, n_grid) for 3D 'theta'.

        Raises:
            ValueError: if 'theta' has an unsupported shape.
        """
        theta = np.array(params['theta'])

        # Determine the grid shape and bring theta to (batch_size, ...).
        if theta.ndim in (0, 1):
            # Scalar: simulate a single grid element (np.full rejects longer vectors).
            grid_shape = (1, 1)
            theta = np.full(grid_shape, theta)
        elif theta.ndim == 2:
            if theta.shape[1] != 1:
                raise ValueError("For 2D 'theta', the second dimension must be 1.")
            grid_shape = (1, 1)
        elif theta.ndim == 3:
            if theta.shape[1] != self.n_grid or theta.shape[2] != self.n_grid:
                raise ValueError("For 3D 'theta', the second and third dimensions must equal n_grid.")
            grid_shape = (self.n_grid, self.n_grid)
        else:
            raise ValueError("Parameter 'theta' must be provided as a scalar, 2D array, or 3D array.")
        batch_size = theta.shape[0]

        # Standard-normal noise for every step: (batch_size, n_time_steps, *grid_shape)
        noise = np.random.normal(loc=0, scale=1, size=(batch_size, self.n_time_steps) + grid_shape)

        # Insert a broadcastable time axis: (batch_size, 1, *grid_shape)
        drift = theta.reshape((batch_size, 1) + grid_shape)
        increments = drift * self.dt + np.sqrt(self.dt) * noise

        # Prepend the initial condition; result is (batch_size, n_time_steps + 1, *grid_shape)
        x0 = np.zeros((batch_size, 1) + grid_shape)
        traj_full = np.concatenate([x0, np.cumsum(increments, axis=1)], axis=1)

        # Step indices spanning the whole trajectory, 0 .. n_time_steps.
        # (Previously the bounds were n_time_points and max_time, i.e. a count
        # and a time used as step indices, so only t <= 10 was ever observed.)
        indices = np.linspace(0, self.n_time_steps, self.n_time_points, dtype=int)
        traj_sampled = traj_full[:, indices, ...]  # (batch_size, n_time_points, *grid_shape)

        if theta.ndim == 2:  # just one grid element
            traj_sampled = traj_sampled.reshape(batch_size, self.n_time_points, 1)
        return dict(observable=traj_sampled)

class Prior:
    """Two-level Gaussian prior with stored normalisation statistics.

    Generative model (see ``sample_single`` / ``sample_full``):
        mu      ~ N(mu_mean, mu_std)
        log_tau ~ N(log_tau_mean, log_tau_std)
        theta   ~ N(mu, exp(log_tau))

    On construction, 1000 prior draws are simulated once to estimate the
    empirical mean / std of the observations, of the global parameters
    (mu, log_tau) and of the local parameter theta. These statistics are used
    to standardise data and parameters, and by the analytic score methods.
    """

    def __init__(self):
        self.mu_mean = 0
        self.mu_std = 3
        self.log_tau_mean = 0
        self.log_tau_std = 1

        # NOTE: this reseeds NumPy's global RNG so the statistics below are reproducible.
        np.random.seed(0)
        test_prior = self.sample_single(1000)
        self.simulator = Simulator()
        test = self.simulator(test_prior,)
        self.x_mean = torch.tensor([np.mean(test['observable'])], dtype=torch.float32, device=torch_device)
        self.x_std = torch.tensor([np.std(test['observable'])], dtype=torch.float32, device=torch_device)
        self.prior_global_mean = torch.tensor(np.array([np.mean(test_prior['mu']), np.mean(test_prior['log_tau'])]),
                                              dtype=torch.float32, device=torch_device)
        self.prior_global_std = torch.tensor(np.array([np.std(test_prior['mu']), np.std(test_prior['log_tau'])]),
                                             dtype=torch.float32, device=torch_device)
        self.prior_local_mean = torch.tensor(np.array([np.mean(test_prior['theta'])]),
                                             dtype=torch.float32, device=torch_device)
        self.prior_local_std = torch.tensor(np.array([np.std(test_prior['theta'])]),
                                            dtype=torch.float32, device=torch_device)

    def __call__(self, batch_size):
        """Alias for ``sample_single``."""
        return self.sample_single(batch_size)

    def sample_single(self, batch_size):
        """Draw ``batch_size`` (mu, log_tau, theta) triples, one local theta each.

        Returns:
            dict: 'mu', 'log_tau' and 'theta', each an array of shape (batch_size, 1).
        """
        mu = np.random.normal(loc=self.mu_mean, scale=self.mu_std, size=(batch_size, 1))
        log_tau = np.random.normal(loc=self.log_tau_mean, scale=self.log_tau_std, size=(batch_size, 1))
        theta = np.random.normal(loc=mu, scale=np.exp(log_tau), size=(batch_size, 1))
        return dict(mu=mu, log_tau=log_tau, theta=theta)

    def sample_full(self, batch_size):
        """Draw global parameters and a full n_grid x n_grid field of local thetas.

        The grid size is taken from ``self.simulator.n_grid``.

        Returns:
            dict: 'mu' and 'log_tau' of shape (batch_size, 1) and 'theta' of
                shape (batch_size, n_grid, n_grid).
        """
        mu = np.random.normal(loc=self.mu_mean, scale=self.mu_std, size=(batch_size, 1))
        log_tau = np.random.normal(loc=self.log_tau_mean, scale=self.log_tau_std, size=(batch_size, 1))
        # (batch_size, 1) -> (batch_size, 1, 1) so the globals broadcast over the grid
        theta = np.random.normal(loc=mu[:, np.newaxis], scale=np.exp(log_tau)[:, np.newaxis],
                                 size=(batch_size, self.simulator.n_grid, self.simulator.n_grid))
        return dict(mu=mu, log_tau=log_tau, theta=theta)

    def score_global_batch(self, theta_batch_norm, condition_norm=None):
        """Score of the global prior with respect to the *normalised* parameters.

        Args:
            theta_batch_norm (torch.Tensor): standardised (mu, log_tau) in the last axis.
            condition_norm: unused; kept for a signature parallel to ``score_local_batch``.

        Returns:
            torch.Tensor: gradient of the log prior density w.r.t. ``theta_batch_norm``,
                same shape as the input.
        """
        theta_batch = theta_batch_norm * self.prior_global_std + self.prior_global_mean
        mu, log_tau = theta_batch[..., 0], theta_batch[..., 1]
        grad_logp_mu = -(mu - self.mu_mean) / (self.mu_std**2)
        grad_logp_tau = -(log_tau - self.log_tau_mean) / (self.log_tau_std**2)
        score = torch.stack([grad_logp_mu, grad_logp_tau], dim=-1)
        # Chain rule for theta = std * z + mean:  d/dz log p = std * d/dtheta log p.
        # (The score has to be multiplied by the std, not divided by it.)
        return score * self.prior_global_std

    def score_local_batch(self, theta_batch_norm, condition_norm):
        """Score of p(theta | mu, log_tau) with respect to the *normalised* theta.

        Args:
            theta_batch_norm (torch.Tensor): standardised local parameters.
            condition_norm (torch.Tensor): standardised (mu, log_tau) in the last axis.

        Returns:
            torch.Tensor: gradient of the conditional log density w.r.t.
                ``theta_batch_norm``, same shape as ``theta_batch_norm``.
        """
        theta = theta_batch_norm * self.prior_local_std + self.prior_local_mean
        condition = condition_norm * self.prior_global_std + self.prior_global_mean
        mu, log_tau = condition[..., 0], condition[..., 1]
        if theta.dim() == condition.dim():
            # theta carries a trailing parameter axis, e.g. (N, 1) against (N, 2);
            # keep that axis on mu / log_tau so the difference stays (N, 1)
            # instead of silently broadcasting to (N, N).
            mu, log_tau = mu.unsqueeze(-1), log_tau.unsqueeze(-1)
        # Gradient w.r.t theta conditioned on mu and log_tau
        grad_logp_theta = -(theta - mu) / torch.exp(log_tau*2)
        # Chain rule for the standardisation (multiply by the std, see score_global_batch).
        score = grad_logp_theta * self.prior_local_std
        return score

# Module-level prior shared by the cells below; constructing it runs a
# 1000-draw test simulation to obtain the normalisation statistics.
prior = Prior()
n_params_global = 2  # (mu, log_tau)
n_params_local = 1   # theta

In [None]:
# Quick look at two prior draws (dict with 'mu', 'log_tau' and 'theta').
prior(2)

In [None]:
def generate_synthetic_data(n_samples, n_grid=8, full_grid=False, device=None):
    """Draw parameters from the module-level ``prior`` and simulate observations.

    Args:
        n_samples (int): number of prior draws / simulations.
        n_grid (int): grid size handed to the ``Simulator``.
            NOTE(review): with ``full_grid=True`` the thetas are drawn on a grid of
            size ``prior.simulator.n_grid``; the simulator raises a ValueError if
            ``n_grid`` differs from it.
        full_grid (bool): sample a whole grid of local parameters
            (``prior.sample_full``) instead of a single one (``prior.sample_single``).
        device: torch device for the returned tensors.

    Returns:
        tuple: float32 tensors ``(param_global, param_local, data)`` where
            ``param_global`` stacks (mu, log_tau) along axis 1.
    """
    sampler = prior.sample_full if full_grid else prior.sample_single
    draws = sampler(n_samples)
    observations = Simulator(n_grid=n_grid)(draws)['observable']

    def as_float_tensor(array):
        return torch.tensor(array, dtype=torch.float32, device=device)

    global_block = np.concatenate([draws['mu'], draws['log_tau']], axis=1)
    return as_float_tensor(global_block), as_float_tensor(draws['theta']), as_float_tensor(observations)

In [None]:
def visualize_simulation_output(sim_output, title_prefix="Time", cmap="viridis"):
    """
    Show every observation time of one simulation as an image in a single row.

    Parameters:
        sim_output (np.ndarray): Simulation trajectory output, one of
              - 2D: shape (n_time_points, grid_size) for a 1D grid; each time
                point is drawn as a 1 x grid_size strip,
              - 3D: shape (n_time_points, n_grid, n_grid) for a 2D grid,
              - 4D: shape (batch_size, n_time_points, n_grid, n_grid); only the
                first batch element is visualized.
            A 3D input is always read as a single 2D-grid simulation; select
            the batch element yourself for batched 1D-grid output.
        title_prefix (str, list): Prefix for subplot titles, or a list with
            one title per time point.
        cmap (str): Colormap for imshow.
    """
    # If a batch dimension is present, select the first simulation.
    if sim_output.ndim == 4:
        sim_output = sim_output[0]

    n_time_points = sim_output.shape[0]

    # One subplot per time point, laid out in a single row.
    fig, axes = plt.subplots(1, n_time_points, figsize=(3 * n_time_points, 3))
    # plt.subplots returns a bare Axes for a single subplot; normalise to a flat array.
    axes = np.array(axes).reshape(-1)

    # Shared colour limits so the frames are comparable (computed once).
    vmin, vmax = sim_output.min(), sim_output.max()

    for i in range(n_time_points):
        ax = axes[i]
        frame = sim_output[i]
        if frame.ndim == 1:
            # imshow needs a 2D image: show a 1D grid as a single row.
            frame = frame[np.newaxis, :]
        im = ax.imshow(frame, cmap=cmap, vmin=vmin, vmax=vmax)
        if isinstance(title_prefix, list):
            ax.set_title(title_prefix[i])
        else:
            ax.set_title(f"{title_prefix} {i}")
        fig.colorbar(im, ax=ax)

    plt.tight_layout()
    plt.show()
    return

# Simulate one full-grid draw from the prior and plot its observation frames.
test = prior.sample_full(1)
simulator_test = Simulator()
sim_test = simulator_test(test)['observable']
visualize_simulation_output(sim_test)

In [None]:
def positional_encoding(t, d_model, max_t=1000.0):
    """
    Computes the sinusoidal positional encoding for a given time t.

    The first ``d_model // 2`` columns are sines and the next ``d_model // 2``
    columns are cosines of ``t`` at geometrically spaced frequencies from 1
    down to ``1 / max_t``. For odd ``d_model`` a zero column is appended so the
    output always has exactly ``d_model`` columns.

    Args:
        t (torch.Tensor): The input time tensor of shape (batch_size, 1).
        d_model (int): The dimensionality of the embedding.
        max_t (float): Ratio between the largest and smallest frequency.

    Returns:
        torch.Tensor: The positional encoding of shape (batch_size, d_model).
    """
    half_dim = d_model // 2
    # max(..., 1) avoids a division by zero when half_dim == 1 (d_model of 2 or 3).
    log_step = math.log(max_t) / max(half_dim - 1, 1)
    div_term = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=t.device) * -log_step)
    t_proj = t * div_term
    pos_enc = torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)
    if d_model % 2 == 1:
        # Pad to the requested width; consumers size their layers with d_model.
        padding = pos_enc.new_zeros((*pos_enc.shape[:-1], 1))
        pos_enc = torch.cat([pos_enc, padding], dim=-1)
    return pos_enc

class ConditionalResidualBlock(nn.Module):
    """Residual MLP block that re-injects a conditioning vector before each linear layer.

    Both linear layers are spectrally normalised; each is followed by a
    LayerNorm. Dropout is applied after the first activation only.
    """

    def __init__(self, hidden_dim, cond_dim, dropout=0.1):
        super().__init__()
        joint_dim = hidden_dim + cond_dim  # [hidden state; conditioning]
        self.fc1 = nn.Linear(joint_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(joint_dim, hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.activation = nn.SiLU()  # same as swish
        self.dropout = nn.Dropout(dropout)

        # Spectral normalisation is attached after all layers exist so that the
        # order of random initialisations stays the same.
        for layer_name in ("fc1", "fc2"):
            wrapped = nn.utils.parametrizations.spectral_norm(getattr(self, layer_name))
            setattr(self, layer_name, wrapped)

    @staticmethod
    def swish(x):
        """Swish / SiLU activation: x * sigmoid(x)."""
        return torch.sigmoid(x) * x

    def forward(self, h, cond):
        """Apply the block to hidden state ``h`` given conditioning ``cond``.

        Both tensors are concatenated along the last axis; the output has the
        same shape as ``h``.
        """
        first_in = torch.cat((h, cond), dim=-1)
        z = self.dropout(self.activation(self.norm1(self.fc1(first_in))))
        # Conditioning is injected again before the second transformation.
        z = self.norm2(self.fc2(torch.cat((z, cond), dim=-1)))
        # Skip connection around the whole block, then the final activation.
        return self.activation(z + h)

class ScoreModel(nn.Module):
    """
    Conditional residual MLP mapping (theta, t, x[, conditions]) to a vector
    with the dimensionality of theta.

    Args:
        input_dim_theta (int): Input dimension for theta.
        input_dim_x (int): Input dimension for x.
        input_dim_condition (int): Input dimension for the condition. Can be 0 for global score.
        hidden_dim (int): Width of the hidden layers.
        time_embed_dim (int, optional): Dimension of time embedding. Defaults to 16.
    """

    def __init__(self,
                 input_dim_theta, input_dim_x, input_dim_condition,
                 hidden_dim,
                 time_embed_dim=16):
        super().__init__()
        self.time_embed_dim = time_embed_dim

        # The conditioning vector is [x; conditions; time embedding].
        self.cond_dim = input_dim_x + input_dim_condition + time_embed_dim

        # Projects [theta; conditioning] into the hidden space.
        self.input_layer = nn.Linear(input_dim_theta + self.cond_dim, hidden_dim)

        # Three conditional residual blocks, each receiving the conditioning again.
        self.block1 = ConditionalResidualBlock(hidden_dim, self.cond_dim)
        self.block2 = ConditionalResidualBlock(hidden_dim, self.cond_dim)
        self.block3 = ConditionalResidualBlock(hidden_dim, self.cond_dim)

        # Spectrally normalised read-out back to the theta dimension.
        self.final_linear = nn.utils.parametrizations.spectral_norm(
            nn.Linear(hidden_dim, input_dim_theta)
        )

    def forward(self, theta, t, x, conditions=None):
        """
        Forward pass of the ScoreModel.

        Args:
            theta (torch.Tensor): Input theta tensor of shape (batch_size, input_dim_theta).
            t (torch.Tensor): Input time tensor of shape (batch_size, 1).
            x (torch.Tensor): Input x tensor of shape (batch_size, input_dim_x).
            conditions (torch.Tensor, optional): Input condition tensor of shape
                (batch_size, input_dim_condition). Defaults to None, in which case
                only x and the time embedding form the conditioning vector.

        Returns:
            torch.Tensor: Output of shape (batch_size, input_dim_theta).
        """
        t_emb = positional_encoding(t, self.time_embed_dim)

        cond_parts = [x, t_emb] if conditions is None else [x, conditions, t_emb]
        cond = torch.cat(cond_parts, dim=-1)

        h = self.input_layer(torch.cat([theta, cond], dim=-1))
        for block in (self.block1, self.block2, self.block3):
            h = block(h, cond)

        return self.final_linear(h)


class HierarchicalScoreModel(nn.Module):
    """Global and local ``ScoreModel`` sharing one GRU summary network.

    The GRU (batch_first) summarises the observation sequence ``x``; its final
    hidden state conditions both heads. The local head is additionally
    conditioned on the global parameters.
    """

    def __init__(self,
                 input_dim_theta_local, input_dim_theta_global, input_dim_x,
                 hidden_dim,
                 time_embed_dim=16):
        super().__init__()
        self.summary_net = nn.GRU(input_size=input_dim_x, hidden_size=hidden_dim,
                                  num_layers=1, batch_first=True)

        self.n_params_global = input_dim_theta_global
        self.global_model = ScoreModel(input_dim_theta=input_dim_theta_global,
                                       input_dim_x=hidden_dim,
                                       input_dim_condition=0,
                                       hidden_dim=hidden_dim,
                                       time_embed_dim=time_embed_dim)

        self.n_params_local = input_dim_theta_local
        self.local_model = ScoreModel(input_dim_theta=input_dim_theta_local,
                                      input_dim_x=hidden_dim,
                                      input_dim_condition=input_dim_theta_global,
                                      hidden_dim=hidden_dim,
                                      time_embed_dim=time_embed_dim)

    def _summarize(self, x):
        """Final GRU hidden state of the single, unidirectional layer."""
        _, final_hidden = self.summary_net(x)
        return final_hidden[0]

    def forward(self, theta, t, x):
        """Evaluate both heads; ``theta`` holds [global, local] along its last axis.

        Returns the global and local outputs concatenated in the same order.
        """
        sizes = [self.n_params_global, self.n_params_local]
        theta_global, theta_local = torch.split(theta, sizes, dim=-1)
        summary = self._summarize(x)
        global_out = self.global_model.forward(theta=theta_global, t=t, x=summary, conditions=None)
        local_out = self.local_model.forward(theta=theta_local, t=t, x=summary, conditions=theta_global)
        return torch.cat([global_out, local_out], dim=-1)

    def forward_local(self, theta_local, theta_global, t, x):
        """Evaluate only the local head, conditioned on ``theta_global``."""
        return self.local_model.forward(theta=theta_local, t=t, x=self._summarize(x),
                                        conditions=theta_global)

    def forward_global(self, theta_global, t, x):
        """Evaluate only the global head."""
        return self.global_model.forward(theta=theta_global, t=t, x=self._summarize(x),
                                         conditions=None)

In [None]:
def cosine_schedule_diffusion_time(t, max_t, s=0):
    """Squared-cosine schedule: cos^2(pi/2 * (t/max_t + s) / (1 + s))."""
    phase = (t / max_t + s) / (1 + s)
    return torch.cos(phase * (np.pi / 2)) ** 2


def generate_diffusion_time_old(max_t, steps, random=False, device=None):
    """Discrete cosine schedule on ``steps + 1`` time points in [0, max_t].

    Args:
        max_t: end of the time interval.
        steps (int): number of steps; ``steps + 1`` time points are produced.
        random (bool): draw sorted uniform time points instead of a regular grid.
        device: torch device for the created tensors.

    Returns:
        tuple: ``(time, gamma, beta_t)`` where ``gamma`` is the schedule value
            (called alpha_t in the paper) and ``beta_t = 1 - gamma_i / gamma_{i-1}``
            (first entry ``1 - gamma_0``), clamped to at most 0.999.
    """
    n_points = steps + 1
    if random:
        uniform_draws = torch.rand(n_points, dtype=torch.float32, device=device) * max_t
        time, _ = torch.sort(uniform_draws)
    else:
        time = torch.linspace(0, max_t, n_points, dtype=torch.float32, device=device)

    gamma = cosine_schedule_diffusion_time(time, max_t)
    step_ratio = torch.cat((gamma[:1], gamma[1:] / gamma[:-1]), dim=0)
    # Clamp to avoid numerical instability where gamma approaches zero.
    beta_t = torch.clamp(1 - step_ratio, max=0.999)
    return time, gamma, beta_t

In [None]:
# Diagnostic plots for the discrete schedule (max_t=10, 400 steps).
time_old, gamma, beta_t = generate_diffusion_time_old(10, 400)

# log(gamma / (1 - gamma)) over time
plt.plot(time_old, torch.log(gamma/(1 - gamma)), label='snr')
plt.xlabel('Time')
plt.legend()
plt.show()

# beta_t over time
plt.plot(time_old, beta_t, label='weighting')
plt.xlabel('Time')
plt.legend()
plt.show()

# sqrt(1 - gamma) over time
plt.plot(time_old, torch.sqrt(1 - gamma), label='noise')
plt.xlabel('Time')
plt.legend()
plt.show()

# beta_t / (1 - gamma) over time
plt.plot(time_old, beta_t/(1-gamma), label='actual weights')
plt.xlabel('Time')
plt.legend()
plt.show()

In [None]:
# Display the time grid of the discrete schedule.
time_old

In [None]:
S_SHIFT_COSINE = 0.0  # shift s of the cosine log-SNR schedule
LAMBDA_0 = 10         # largest log-SNR used (start of the truncated schedule)
LAMBDA_1 = -LAMBDA_0  # smallest log-SNR used (end of the truncated schedule)

def cosine_schedule_signal_to_noise_inv(snr):
    """Schedule time t in [0, 1] at which the cosine schedule has log-SNR ``snr``.

    Inverts  snr = -2 * log(tan(pi * t / 2)) + 2 * S_SHIFT_COSINE,  i.e.
        t = (2 / pi) * arctan(exp(-snr / 2 + S_SHIFT_COSINE)).
    The shift enters with a plus sign; with a minus sign the function is only
    an inverse for a shift of zero.
    """
    return (2 / torch.pi) * torch.arctan(torch.exp(-snr / 2 + S_SHIFT_COSINE))

# Schedule times at which the log-SNR equals LAMBDA_0 and LAMBDA_1.
T_0 = cosine_schedule_signal_to_noise_inv(torch.tensor(LAMBDA_0, dtype=torch.float32, device=torch_device))
T_1 = cosine_schedule_signal_to_noise_inv(torch.tensor(LAMBDA_1, dtype=torch.float32, device=torch_device))

def cosine_schedule_signal_to_noise(t, max_t):
    """
    Log signal-to-noise ratio of the truncated cosine schedule.

    ``t`` is taken relative to ``max_t``: the fraction t / max_t is mapped
    affinely onto [T_0, T_1] before the cosine schedule is evaluated.
    """
    fraction = t / max_t
    t_truncated = T_0 + (T_1 - T_0) * fraction
    tangent = torch.tan(torch.pi * t_truncated / 2)
    return -2 * torch.log(tangent) + 2 * S_SHIFT_COSINE

def sech(x):
    """Hyperbolic secant, 1 / cosh(x), evaluated element-wise."""
    cosh_x = torch.cosh(x)
    return 1 / cosh_x

def compute_continuous_weights(t, epsilon=1e-3):
    """
    Finite-difference weight ``1 - gamma(t + epsilon) / (gamma(t) + epsilon)``.

    ``gamma(t) = sqrt(sigmoid(snr(t)))`` with the unshifted cosine log-SNR
    ``snr(t) = -2 * log(tan(pi * t / 2))``.

    Args:
        t: Time values in range [0, 1]
        epsilon: Used both as the time offset and as the stabiliser added to
            the denominator

    Returns:
        weights: ``1 - ratio``; for ``t < epsilon`` the ratio is replaced by
            ``gamma(t)`` itself (the analogue of the first discrete step).
    """
    def signal_scale(time_points):
        log_snr = -2 * torch.log(torch.tan(torch.pi * time_points / 2))
        return torch.sqrt(torch.sigmoid(log_snr))

    gamma_now = signal_scale(t)
    # Look one offset ahead, staying inside [0, 1].
    gamma_ahead = signal_scale(torch.clamp(t + epsilon, 0, 1))

    step_ratio = torch.where(t < epsilon, gamma_now, gamma_ahead / (gamma_now + epsilon))
    return 1 - step_ratio

def cosine_schedule_signal_to_noise_density(snr):
    """Density of the log-SNR under the truncated cosine schedule.

    Equals ``sech(snr / 2 - S_SHIFT_COSINE) / (2 * pi * (T_1 - T_0))`` inside
    [LAMBDA_1, LAMBDA_0] and zero outside.
    """
    density = sech(snr / 2 - S_SHIFT_COSINE) / (2 * torch.pi * (T_1 - T_0))
    outside_support = (snr > LAMBDA_0) | (snr < LAMBDA_1)
    return density.masked_fill(outside_support, 0)

def weighting_function(snr, device=None):
    """Loss weight as a function of the log-SNR: sigmoid(2 - snr).

    ``device`` is accepted but not used.
    """
    shifted = -snr + 2
    return torch.sigmoid(shifted)

def generate_diffusion_time(max_t, size, return_batch=False, device=None):
    """
    Generate diffusion times with their log-SNR, loss weight and log-SNR density.

    Args:
        max_t: upper end of the diffusion time axis.
        size (int): number of time points.
        return_batch (bool): if False, a regular grid on [0, max_t] is used.
            If True, stratified random times t_i = (u0 + i / size) mod 1 with a
            single uniform offset u0 are drawn and scaled to the interval
            [T_0, T_1 * max_t].
        device: torch device for the created tensors.

    Returns:
        tuple: ``(time, snr, weight_snr, snr_density)``; each has shape (size,)
            or, with ``return_batch=True``, shape (size, 1).
    """
    if return_batch:
        u0 = torch.rand(1, dtype=torch.float32, device=device)
        i = torch.arange(0, size, dtype=torch.float32, device=device)
        time = ((u0 + i / size) % 1) * (T_1*max_t - T_0) + T_0
    else:
        time = torch.linspace(0, max_t, steps=size, dtype=torch.float32, device=device)

    snr = cosine_schedule_signal_to_noise(time, max_t)
    snr_density = cosine_schedule_signal_to_noise_density(snr)
    weight_snr = weighting_function(snr, device=device)

    outputs = (time, snr, weight_snr, snr_density)
    if return_batch:
        # Add a trailing axis so every tensor broadcasts against (size, n_params).
        return tuple(values.unsqueeze(1) for values in outputs)
    return outputs

In [None]:
# Display the truncated schedule bounds.
T_0, T_1

In [None]:
# Plot the log-SNR density over the supported range [LAMBDA_1, LAMBDA_0].
snr = torch.tensor(np.linspace(LAMBDA_1, LAMBDA_0, 100))
snr_density = cosine_schedule_signal_to_noise_density(snr)

plt.plot(snr, snr_density, label='Signal-To-Noise Ratio')
#plt.plot(time / 400, weight/p_noise, label='weight')
#plt.plot(time / 400, -1 / torch.sqrt(1 - gamma) * delta_t, label='sampling step')
plt.xlabel('SNR')
plt.ylabel('Density')
plt.legend()
plt.show()

In [None]:
# Compare the continuous log-SNR schedule (max_t=10, 400 points) with the
# discrete schedule computed earlier (time_old, gamma, beta_t).
time, snr, weight, snr_density = generate_diffusion_time(10, 400)

# log-SNR over time
plt.plot(time, snr, label='log snr')
plt.plot(time_old[:-1], torch.log(gamma/(1 - gamma))[:-1], label='log snr old', alpha=0.5)
plt.xlabel('Time')
plt.legend()
plt.show()

# weighting function vs. beta_t
plt.plot(time, weight, label='weighting function')
plt.plot(time_old, beta_t, label='weighting function old', alpha=0.5)
plt.xlabel('Time')
plt.legend()
plt.show()

# density of the log-SNR along the time grid
plt.plot(time, snr_density, label='snr density')
plt.xlabel('Time')
plt.ylabel('Density')
plt.legend()
plt.show()

# signal scale: sqrt(sigmoid(snr)) vs. sqrt(gamma)
plt.plot(time, torch.sqrt(torch.sigmoid(snr)), label='mean noise')
plt.plot(time_old, torch.sqrt(gamma), label='mean noise old', alpha=0.5)
plt.xlabel('Time')
plt.legend()
plt.show()

# noise scale: sqrt(sigmoid(-snr)) vs. sqrt(1 - gamma)
plt.plot(time, torch.sqrt(torch.sigmoid(-snr)), label='noise')
plt.plot(time_old, torch.sqrt(1 - gamma), label='noise old', alpha=0.5)
plt.xlabel('Time')
plt.legend()
plt.show()

# negative inverse noise scale
plt.plot(time, -1/torch.sqrt(torch.sigmoid(-snr)) , label='scaling')
plt.plot(time_old, -1 / torch.sqrt(1 - gamma), label='scaling old', alpha=0.5)
plt.xlabel('Time')
plt.legend()
plt.show()

# effective loss weights: weight / density vs. beta_t / (1 - gamma)
plt.plot(time, weight/snr_density, label='actual weights')
plt.plot(time_old, beta_t/(1-gamma), label='actual weights old', alpha=0.5)
plt.xlabel('Time')
plt.legend()
plt.show()

In [None]:
 # Loss function for weighted MSE
def weighted_mse_loss(inputs, targets, weights):
    """Mean over all elements of ``weights * (inputs - targets) ** 2``."""
    squared_error = (inputs - targets) ** 2
    return (weights * squared_error).mean()


def compute_score_loss(x_batch, theta_prime_batch, model, diffusion_time, alpha_noise, sigma_noise, weights_noise, device=None):
    """Weighted noise-prediction loss for one batch.

    The clean parameters are perturbed as ``alpha_noise * theta + sigma_noise * eps``
    with ``eps ~ N(0, I)``; the model output is regressed onto ``eps`` and the
    weighted mean squared error is halved.

    Args:
        x_batch: observations passed to the model as ``x``.
        theta_prime_batch: clean (un-noised) parameters.
        model: callable taking ``theta``, ``t`` and ``x`` keyword arguments.
        diffusion_time: diffusion times passed to the model as ``t``.
        alpha_noise, sigma_noise: signal and noise scales of the perturbation.
        weights_noise: per-sample loss weights.
        device: device argument forwarded to ``torch.randn_like``.

    Returns:
        torch.Tensor: scalar loss.
    """
    noise = torch.randn_like(theta_prime_batch, dtype=torch.float32, device=device)
    noisy_theta = alpha_noise * theta_prime_batch + sigma_noise * noise
    prediction = model(theta=noisy_theta, t=diffusion_time, x=x_batch)
    return 0.5 * weighted_mse_loss(prediction, noise, weights=weights_noise)


# Training loop for Score Model
def _diffusion_batch_loss(model, theta_global_prime_batch, theta_local_prime_batch, x_batch,
                          T, device=None):
    """Loss of one mini-batch at freshly drawn diffusion times (one per sample)."""
    diffusion_time, snr, weights, snr_density = generate_diffusion_time(
        max_t=T, size=x_batch.shape[0], return_batch=True, device=device)
    alpha_noise = torch.sqrt(torch.sigmoid(snr))
    sigma_noise = torch.sqrt(torch.sigmoid(-snr))
    # Importance weight: loss weighting divided by the sampling density of the log-SNR.
    weights_noise = weights / snr_density

    theta_prime_batch = torch.concat([theta_global_prime_batch, theta_local_prime_batch], dim=-1)
    return compute_score_loss(x_batch=x_batch, theta_prime_batch=theta_prime_batch,
                              model=model, diffusion_time=diffusion_time,
                              alpha_noise=alpha_noise, sigma_noise=sigma_noise,
                              weights_noise=weights_noise, device=device)


def train_score_model(model, dataloader, dataloader_valid=None,
                      T=400, epochs=100, lr=1e-3, device=None):
    """Train ``model`` with AdamW and a cosine-annealed learning rate.

    Args:
        model: network called as ``model(theta=..., t=..., x=...)``.
        dataloader: yields ``(theta_global, theta_local, x)`` batches for training.
        dataloader_valid: optional loader with the same batch layout for validation.
        T: maximum diffusion time passed to ``generate_diffusion_time``.
        epochs (int): number of passes over ``dataloader``.
        lr (float): initial learning rate (annealed to 1e-6).
        device: torch device forwarded to the time and noise sampling.

    Returns:
        np.ndarray: shape (epochs, 2) with the per-epoch median training loss
            and median validation loss (NaN when there is no validation data).
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

    loss_history = np.zeros((epochs, 2))
    for epoch in range(epochs):
        model.train()
        total_loss = []
        for theta_global_prime_batch, theta_local_prime_batch, x_batch in dataloader:
            optimizer.zero_grad()
            loss = _diffusion_batch_loss(model, theta_global_prime_batch, theta_local_prime_batch,
                                         x_batch, T=T, device=device)
            loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            optimizer.step()
            total_loss.append(loss.item())

        scheduler.step()

        # validate the model
        model.eval()
        valid_loss = []
        if dataloader_valid is not None:
            with torch.no_grad():
                for theta_global_prime_batch, theta_local_prime_batch, x_batch in dataloader_valid:
                    loss = _diffusion_batch_loss(model, theta_global_prime_batch,
                                                 theta_local_prime_batch, x_batch,
                                                 T=T, device=device)
                    valid_loss.append(loss.item())

        # np.median of an empty list warns before returning NaN; report NaN directly.
        train_median = np.median(total_loss) if total_loss else float('nan')
        valid_median = np.median(valid_loss) if valid_loss else float('nan')
        loss_history[epoch] = [train_median, valid_median]
        print(f"Epoch {epoch+1}/{epochs}, Loss: {train_median:.4f}, "
              f"Valid Loss: {valid_median:.4f}", end='\r')
    return loss_history

In [None]:
# Hyperparameters
n_samples = 10000  # number of training simulations
batch_size = 128
T = 400            # maximum diffusion time

# Hierarchical score network: 2 global parameters, 1 local parameter,
# scalar observations per time point.
score_model = HierarchicalScoreModel(
    input_dim_theta_global=n_params_global,
    input_dim_theta_local=n_params_local,
    input_dim_x=1,
    hidden_dim=64
)
score_model.to(torch_device)

# Simulate the training set (single grid element per draw)
thetas_global, thetas_local, xs = generate_synthetic_data(n_samples, device=torch_device)
# Standardise parameters and data with the statistics stored on the prior
thetas_global = (thetas_global - prior.prior_global_mean) / prior.prior_global_std
thetas_local = (thetas_local - prior.prior_local_mean) / prior.prior_local_std
xs = (xs - prior.x_mean) / prior.x_std

# Create dataloader
dataset = TensorDataset(thetas_global, thetas_local, xs)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# create validation data (1000 simulations, standardised the same way)
valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(1000, device=torch_device)
valid_data = (valid_data - prior.x_mean) / prior.x_std
valid_prior_global = (valid_prior_global - prior.prior_global_mean) / prior.prior_global_std
valid_prior_local = (valid_prior_local - prior.prior_local_mean) / prior.prior_local_std
dataset_valid = TensorDataset(valid_prior_global, valid_prior_local, valid_data)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size)

In [None]:
# Train model
# Returns an (epochs, 2) array of per-epoch median train / validation losses.
loss_history = train_score_model(score_model, dataloader, dataloader_valid=dataloader_valid,
                                 T=T, epochs=100, lr=1e-3, device=torch_device)

In [None]:
# Inspect the (spectrally normalised) output-layer weight of the global head.
score_model.global_model.final_linear.weight

In [None]:
# plot loss history
# Column 0 holds the training loss, column 1 the validation loss.
plt.figure(figsize=(6, 3), tight_layout=True)
plt.plot(loss_history[:, 0], label='Train')
plt.plot(loss_history[:, 1], label='Valid')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Signal (alpha) and noise (sigma) scales on a regular grid of 100 diffusion
# times; note that `noise` here holds the log-SNR returned by the schedule.
diffusion_time, noise, weights, p_noise = generate_diffusion_time(max_t=T, size=100)
alpha_noise = torch.sqrt(torch.sigmoid(noise))  # sometimes called \sqrt(1-\beta_t)
sigma_noise = torch.sqrt(torch.sigmoid(-noise))  # sometimes \sqrt(\beta_t)
plt.plot(diffusion_time/T, alpha_noise, label='alpha')
#plt.plot(1-alpha_noise**2, label='alpha')
plt.plot(diffusion_time/T, sigma_noise, label='sigma')
plt.legend()
plt.show()

In [None]:
# DDPM Sampling from Ho. et al. (2020)
def ddpm_sampling(model, x_obs, n_post_samples, conditions=None, diffusion_steps=100, device=None):
    """
    Draw posterior samples with a DDPM-style reverse diffusion loop.

    With ``conditions=None`` the global parameters are sampled: the outputs of
    ``model.forward_global`` are summed over observations and combined with a
    time-weighted prior score. Otherwise the local parameters are sampled, one
    set per observation, from ``model.forward_local`` given the (unstandardized)
    global parameters in ``conditions``.

    Uses the notebook-level names ``prior``, ``T``, ``n_params_global``,
    ``n_params_local`` and ``generate_diffusion_time``.

    Parameters:
        model: Score model exposing ``forward_global`` and ``forward_local``.
        x_obs: Unstandardized data. The first axis is kept as the time axis,
            all remaining axes are flattened into observations.
        n_post_samples: Number of posterior samples to draw.
        conditions: Global parameters to condition on (array or tensor), or
            None to sample the global parameters themselves.
        diffusion_steps: Number of reverse diffusion steps.
        device: Torch device for tensors created inside this function.

    Returns:
        NumPy array of shape (n_post_samples, n_params), de-standardized, with
        n_params = n_params_global (global mode) or n_params_local * n_obs
        (local mode).
    """
    x_obs_norm = (x_obs - prior.x_mean) / prior.x_std  # assumes x_obs is not standardized
    # (n_time_steps, ...) -> (n_time_steps, n_obs) -> (n_obs, n_time_steps, 1)
    x_obs_norm = x_obs_norm.reshape(x_obs_norm.shape[0], -1)
    n_obs = x_obs_norm.shape[-1]
    n_time_steps = x_obs_norm.shape[0]
    x_obs_norm = x_obs_norm.T[:, :, np.newaxis]

    # Ensure x_obs_norm is a PyTorch tensor
    if not isinstance(x_obs_norm, torch.Tensor):
        x_obs_norm = torch.tensor(x_obs_norm, dtype=torch.float32, device=device)

    # Initialize parameters
    if conditions is None:  # global
        n_params = n_params_global
        # Initial draw is shrunk by sqrt(n_obs)
        theta = torch.randn(n_post_samples, n_params_global, dtype=torch.float32, device=device) / torch.sqrt(torch.tensor(n_obs, dtype=torch.float32, device=device))
        conditions_exp = None
    else:
        # Ensure conditions is a PyTorch tensor
        if not isinstance(conditions, torch.Tensor):
            conditions = torch.tensor(conditions, dtype=torch.float32, device=device)

        n_params = n_params_local*n_obs
        theta = torch.randn(n_post_samples, n_obs, n_params_local, dtype=torch.float32, device=device)
        # The network is fed standardized global parameters
        conditions = (conditions - prior.prior_global_mean) / prior.prior_global_std
        conditions_exp = conditions.unsqueeze(0).expand(n_post_samples, n_obs, -1).reshape(-1, n_params_global)

    # Generate diffusion time parameters
    diffusion_time, snr, _, _ = generate_diffusion_time(max_t=T, size=diffusion_steps, device=device)
    alpha_sampling_noise = torch.sigmoid(snr)  # sometimes called 1-\beta_t
    sigma_sampling_noise = torch.sqrt(torch.sigmoid(-snr))  # sometimes \sqrt(\beta_t)

    # Expand x_obs_norm to match the number of posterior samples
    x_exp = x_obs_norm.unsqueeze(0).expand(n_post_samples, n_obs, n_time_steps, -1)  # Shape: (n_post_samples, n_obs, n_time_steps, d)
    x_expanded = x_exp.reshape(n_post_samples*n_obs, n_time_steps, -1)

    # Reverse iterate over diffusion times and step sizes
    with torch.no_grad():
        for t in tqdm(reversed(range(diffusion_steps)), total=diffusion_steps):
            # Create tensor for current time step, repeated for every (sample, observation) pair
            t_tensor = torch.full((n_post_samples, 1), diffusion_time[t], dtype=torch.float32, device=device)
            t_exp = t_tensor.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, 1)

            # Compute model scores
            if conditions is None:
                theta_exp = theta.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, n_params_global)
                model_scores_noise = model.forward_global(theta_global=theta_exp, t=t_exp, x=x_expanded)
                # Sum over observations
                model_scores_noise = model_scores_noise.reshape(n_post_samples, n_obs, -1).sum(dim=1)
            else:
                theta_exp = theta.reshape(-1, n_params_local)
                model_scores_noise = model.forward_local(theta_local=theta_exp, t=t_exp, x=x_expanded, theta_global=conditions_exp)
                model_scores_noise = model_scores_noise.reshape(n_post_samples, n_obs, -1)

            # scaling since we learned the scaled score network (noise)
            model_scores = -model_scores_noise / sigma_sampling_noise[t]

            # Compute updated scores
            if conditions is None:
                # Compute prior score; its weight shrinks linearly to zero as diffusion_time[t] approaches T
                prior_score = prior.score_global_batch(theta)
                w_scores = (1 - n_obs) * (T - diffusion_time[t]) / T * prior_score + model_scores
            else:
                w_scores = model_scores

            # these are the marginals to get from one time point to the next one p(x_{t-1}|x_t), not p(x_0|x_t)
            if t == 0:
                alpha_t = torch.tensor(1, dtype=torch.float32, device=device)
            else:
                alpha_t = alpha_sampling_noise[t] / alpha_sampling_noise[t-1]
            beta_t = 1 - alpha_t
            sigma_t = torch.sqrt(beta_t)

            # Sample Gaussian noise; at t == 0 sigma_t is 0, so no noise is added on the last step
            eps = torch.randn_like(theta, dtype=torch.float32, device=device)
            # Make update
            theta = 1 / torch.sqrt(alpha_t) * (theta - (1-alpha_t) * w_scores) + sigma_t * eps
            if torch.isnan(theta).any():
                print("NaNs in theta")
                break
    # correct for normalization
    if conditions is None:
        theta = theta * prior.prior_global_std + prior.prior_global_mean
    else:
        theta = theta * prior.prior_local_std + prior.prior_local_mean
    # convert to numpy; .cpu() is required because Tensor.numpy() only works on CPU tensors
    theta = theta.detach().cpu().numpy().reshape(n_post_samples, n_params)
    return theta

In [None]:
def euler_maruyama_step(z, score, t, dt):
    """
    Perform one Euler-Maruyama update

      z_next = z + [f(z,t) - 0.5 * g(t)^2 * score] * dt + g(t) * sqrt(dt) * eps,

    with eps ~ N(0, I),
      f(z,t) = -0.5 * g(t)^2 * z,
      g(t)^2 = 2 pi / sin(pi t) * sigmoid(-lambda(t)),
    and lambda(t) = cosine_schedule_signal_to_noise(t, max_t=T).

    Parameters:
        z: Current state (PyTorch tensor)
        score: Score network output (PyTorch tensor), broadcastable against z
        t: Current time (tensor or float); must broadcast against z
        dt: Time step size (tensor or float)

    Returns:
        z_next: Updated state after time dt.
    """
    # Accept plain Python floats as documented: torch.sin / torch.sqrt require tensors.
    t = torch.as_tensor(t, dtype=z.dtype, device=z.device)
    dt = torch.as_tensor(dt, dtype=z.dtype, device=z.device)

    # Log signal-to-noise ratio lambda(t) from the schedule helper.
    lambda_t = cosine_schedule_signal_to_noise(t, max_t=T)
    if torch.isnan(lambda_t).any():
        print("NaNs in lambda_t")

    # Compute g_t2 = 2 pi / sin(pi t) * sigmoid(-lambda(t))
    # NOTE(review): this closed form does not depend on T although the schedule
    # is evaluated with max_t=T; confirm both agree when T != 1.
    g_t2 = (2 * torch.pi / torch.sin(torch.pi * t)) * torch.sigmoid(-lambda_t)
    if torch.isnan(g_t2).any():
        print("NaNs in g_t2")

    # Drift: f(z,t) = -1/2 * g_t2 * z, plus the score term.
    f_z = -0.5 * g_t2 * z
    drift = f_z - 0.5 * g_t2 * score

    # Diffusion: g(t) = sqrt(g_t2)
    diffusion = torch.sqrt(g_t2)

    # Sample Gaussian noise (same shape as z)
    noise = torch.randn_like(z, dtype=z.dtype, device=z.device)

    # Euler-Maruyama update: deterministic and stochastic increments
    drift_increment = drift * dt
    diffusion_increment = diffusion * torch.sqrt(dt) * noise
    if torch.isnan(drift_increment).any():
        print("NaNs in drift")
    if torch.isnan(diffusion_increment).any():
        print("NaNs in diffusion")
    z_next = z + drift_increment + diffusion_increment
    return z_next

In [None]:
# Euler-Maruyama Sampling from Song et al. (2021)
def euler_maruyama_sampling(model, x_obs, n_post_samples, conditions=None, diffusion_steps=100, device=None):
    """
    Draw posterior samples by iterating ``euler_maruyama_step`` over the diffusion grid.

    With ``conditions=None`` the global parameters are sampled: the outputs of
    ``model.forward_global`` are summed over observations and combined with a
    time-weighted prior score. Otherwise the local parameters are sampled, one
    set per observation, from ``model.forward_local`` given the (unstandardized)
    global parameters in ``conditions``.

    Uses the notebook-level names ``prior``, ``T``, ``n_params_global``,
    ``n_params_local``, ``generate_diffusion_time`` and ``euler_maruyama_step``.

    Parameters:
        model: Score model exposing ``forward_global`` and ``forward_local``.
        x_obs: Unstandardized data. The first axis is kept as the time axis,
            all remaining axes are flattened into observations.
        n_post_samples: Number of posterior samples to draw.
        conditions: Global parameters to condition on (array or tensor), or
            None to sample the global parameters themselves.
        diffusion_steps: Number of points on the diffusion time grid.
        device: Torch device for tensors created inside this function.

    Returns:
        NumPy array of shape (n_post_samples, n_params), de-standardized, with
        n_params = n_params_global (global mode) or n_params_local * n_obs
        (local mode).
    """
    x_obs_norm = (x_obs - prior.x_mean) / prior.x_std  # assumes x_obs is not standardized
    # (n_time_steps, ...) -> (n_time_steps, n_obs) -> (n_obs, n_time_steps, 1)
    x_obs_norm = x_obs_norm.reshape(x_obs_norm.shape[0], -1)
    n_obs = x_obs_norm.shape[-1]
    n_time_steps = x_obs_norm.shape[0]
    x_obs_norm = x_obs_norm.T[:, :, np.newaxis]

    # Ensure x_obs_norm is a PyTorch tensor
    if not isinstance(x_obs_norm, torch.Tensor):
        x_obs_norm = torch.tensor(x_obs_norm, dtype=torch.float32, device=device)

    # Initialize parameters
    if conditions is None:  # global
        n_params = n_params_global
        # Initial draw is shrunk by sqrt(n_obs)
        theta = torch.randn(n_post_samples, n_params_global, dtype=torch.float32, device=device) / torch.sqrt(torch.tensor(n_obs, dtype=torch.float32, device=device))
        conditions_exp = None
    else:
        # Ensure conditions is a PyTorch tensor
        if not isinstance(conditions, torch.Tensor):
            conditions = torch.tensor(conditions, dtype=torch.float32, device=device)

        n_params = n_params_local*n_obs
        theta = torch.randn(n_post_samples, n_obs, n_params_local, dtype=torch.float32, device=device)
        # The network is fed standardized global parameters
        conditions = (conditions - prior.prior_global_mean) / prior.prior_global_std
        conditions_exp = conditions.unsqueeze(0).expand(n_post_samples, n_obs, -1).reshape(-1, n_params_global)

    # Generate diffusion time parameters
    diffusion_time, snr, _, _ = generate_diffusion_time(max_t=T, size=diffusion_steps, device=device)
    sigma_sampling_noise = torch.sqrt(torch.sigmoid(-snr))  # sometimes \sqrt(\beta_t)

    # Expand x_obs_norm to match the number of posterior samples
    x_exp = x_obs_norm.unsqueeze(0).expand(n_post_samples, n_obs, n_time_steps, -1)  # Shape: (n_post_samples, n_obs, n_time_steps, d)
    x_expanded = x_exp.reshape(n_post_samples*n_obs, n_time_steps, -1)

    # Reverse iterate over diffusion times and step sizes
    with torch.no_grad():
        for t in tqdm(reversed(range(diffusion_steps)), total=diffusion_steps):
            current_time = diffusion_time[t]
            # Create tensor for current time step, repeated for every (sample, observation) pair
            t_tensor = torch.full((n_post_samples, 1), current_time, dtype=torch.float32, device=device)
            t_exp = t_tensor.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, 1)

            # Compute model scores
            if conditions is None:
                theta_exp = theta.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, n_params_global)
                model_scores_noise = model.forward_global(theta_global=theta_exp, t=t_exp, x=x_expanded)
                # Sum over observations
                model_scores_noise = model_scores_noise.reshape(n_post_samples, n_obs, -1).sum(dim=1)
            else:
                theta_exp = theta.reshape(-1, n_params_local)
                model_scores_noise = model.forward_local(theta_local=theta_exp, t=t_exp, x=x_expanded, theta_global=conditions_exp)
                model_scores_noise = model_scores_noise.reshape(n_post_samples, n_obs, -1)

            # scaling since we learned the scaled score network (noise)
            model_scores = -model_scores_noise / sigma_sampling_noise[t]

            # Compute updated scores
            if conditions is None:
                # Compute prior score; its weight shrinks linearly to zero as diffusion_time[t] approaches T
                prior_score = prior.score_global_batch(theta)
                w_scores = (1 - n_obs) * (T - current_time) / T * prior_score + model_scores
            else:
                w_scores = model_scores

            # Make Euler-Maruyama step (nothing to step to at the first grid point).
            # The time is passed as a scalar so it broadcasts against both the
            # 2-D global state and the 3-D local state; a (n_post_samples, 1)
            # tensor does not broadcast against (n_post_samples, n_obs, n_params_local).
            if t > 0:
                theta = euler_maruyama_step(theta, score=w_scores, t=current_time,
                                            dt=current_time - diffusion_time[t-1])
            if torch.isnan(theta).any():
                print("NaNs in theta")
                break
    # correct for normalization
    if conditions is None:
        theta = theta * prior.prior_global_std + prior.prior_global_mean
    else:
        theta = theta * prior.prior_local_std + prior.prior_local_mean
    # convert to numpy; .cpu() is required because Tensor.numpy() only works on CPU tensors
    theta = theta.detach().cpu().numpy().reshape(n_post_samples, n_params)
    return theta

In [None]:
# Annealed Langevin Dynamics for Sampling
def langevin_sampling(model, x_obs, n_post_samples, conditions=None, steps=1, diffusion_steps=400, device=None):
    """
    Draw posterior samples with annealed Langevin dynamics.

    The diffusion grid is traversed from its last entry to its first; at every
    grid point ``steps`` Langevin updates are applied. With ``conditions=None``
    the global parameters are sampled (model outputs summed over observations
    plus a time-weighted prior score); otherwise the local parameters are
    sampled per observation given the (unstandardized) global parameters in
    ``conditions``.

    Uses the notebook-level names ``prior``, ``T``, ``n_params_global``,
    ``n_params_local`` and ``generate_diffusion_time``.

    Parameters:
        model: Score model exposing ``forward_global`` and ``forward_local``.
        x_obs: Unstandardized data. The first axis is kept as the time axis,
            all remaining axes are flattened into observations.
        n_post_samples: Number of posterior samples to draw.
        conditions: Global parameters to condition on (array or tensor), or
            None to sample the global parameters themselves.
        steps: Number of Langevin updates per diffusion time.
        diffusion_steps: Number of points on the diffusion time grid.
        device: Torch device for tensors created inside this function.

    Returns:
        NumPy array of shape (n_post_samples, n_params), de-standardized, with
        n_params = n_params_global (global mode) or n_params_local * n_obs
        (local mode).
    """
    x_obs_norm = (x_obs - prior.x_mean) / prior.x_std  # assumes x_obs is not standardized
    # (n_time_steps, ...) -> (n_time_steps, n_obs) -> (n_obs, n_time_steps, 1)
    x_obs_norm = x_obs_norm.reshape(x_obs_norm.shape[0], -1)
    n_obs = x_obs_norm.shape[-1]
    n_time_steps = x_obs_norm.shape[0]
    x_obs_norm = x_obs_norm.T[:, :, np.newaxis]

    # Ensure x_obs_norm is a PyTorch tensor
    if not isinstance(x_obs_norm, torch.Tensor):
        x_obs_norm = torch.tensor(x_obs_norm, dtype=torch.float32, device=device)

    # Initialize parameters
    if conditions is None:  # global
        n_params = n_params_global
        # Initial draw is shrunk by sqrt(n_obs)
        theta = torch.randn(n_post_samples, n_params_global, dtype=torch.float32, device=device) / torch.sqrt(torch.tensor(n_obs, dtype=torch.float32, device=device))
        conditions_exp = None
    else:
        # Ensure conditions is a PyTorch tensor
        if not isinstance(conditions, torch.Tensor):
            conditions = torch.tensor(conditions, dtype=torch.float32, device=device)

        n_params = n_params_local*n_obs
        theta = torch.randn(n_post_samples, n_obs, n_params_local, dtype=torch.float32, device=device)
        # The network is fed standardized global parameters
        conditions = (conditions - prior.prior_global_mean) / prior.prior_global_std
        conditions_exp = conditions.unsqueeze(0).expand(n_post_samples, n_obs, -1).reshape(-1, n_params_global)

    # Generate diffusion time parameters
    diffusion_time, snr, weight_snr, _ = generate_diffusion_time(max_t=T, size=diffusion_steps, device=device)
    scaling = torch.sqrt(torch.sigmoid(-snr))  # sometimes \sqrt(\beta_t)
    delta_t = weight_snr / torch.sigmoid(-snr)  # w = beta_t/(1-gamma)

    # Expand x_obs_norm to match the number of posterior samples
    x_exp = x_obs_norm.unsqueeze(0).expand(n_post_samples, n_obs, n_time_steps, -1)  # Shape: (n_post_samples, n_obs, n_time_steps, d)
    x_expanded = x_exp.reshape(n_post_samples*n_obs, n_time_steps, -1)

    # Reverse iterate over diffusion times and step sizes.
    # no_grad: sampling needs no gradients, and without it theta could drag an
    # autograd graph through every step (the sibling samplers do the same).
    with torch.no_grad():
        # total is the number of grid points actually iterated, not T
        for step_size, t, scale in tqdm(zip(delta_t.flip(0), diffusion_time.flip(0), scaling.flip(0)),
                                        total=len(diffusion_time)):
            # Create tensor for current time step, repeated for every (sample, observation) pair
            t_tensor = torch.full((n_post_samples, 1), t, dtype=torch.float32, device=device)
            t_exp = t_tensor.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, 1)

            for _ in range(steps):
                # Sample Gaussian noise
                eps = torch.randn_like(theta, dtype=torch.float32, device=device)

                if conditions is None:
                    # Compute prior score
                    prior_score = prior.score_global_batch(theta)
                else:
                    prior_score = 0

                # Compute model scores
                if conditions is None:
                    theta_exp = theta.unsqueeze(1).expand(-1, n_obs, -1).reshape(-1, n_params_global)
                    model_scores = model.forward_global(theta_global=theta_exp, t=t_exp, x=x_expanded)
                    # Sum over observations
                    model_scores = model_scores.reshape(n_post_samples, n_obs, -1).sum(dim=1)
                else:
                    theta_exp = theta.reshape(-1, n_params_local)
                    model_scores = model.forward_local(theta_local=theta_exp, t=t_exp, x=x_expanded, theta_global=conditions_exp)
                    model_scores = model_scores.reshape(n_post_samples, n_obs, -1)

                # Compute updated scores and perform Langevin step
                # scaling since we learned the scaled score network (noise)
                scores = (1 - n_obs) * (T - t) / T * prior_score + scale * model_scores
                theta = theta + (step_size / 2) * scores + torch.sqrt(step_size) * eps
            if torch.isnan(theta).any():
                print("NaNs in theta")
                break
    # correct for normalization
    if conditions is None:
        theta = theta * prior.prior_global_std + prior.prior_global_mean
    else:
        theta = theta * prior.prior_local_std + prior.prior_local_mean
    # convert to numpy; .cpu() is required because Tensor.numpy() only works on CPU tensors
    theta = theta.detach().cpu().numpy().reshape(n_post_samples, n_params)
    return theta

# Validation

In [None]:
# Validation setup: 20 posterior draws for each of 2 data sets simulated with full_grid=True
n_post_samples = 20
n_grid = 8
valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(
    2,
    n_grid=n_grid,
    full_grid=True,
    device=torch_device,
)

In [None]:
# Global posterior draws for every validation data set, using T diffusion steps
sample_global_posterior = partial(
    ddpm_sampling,
    score_model,
    n_post_samples=n_post_samples,
    diffusion_steps=T,
    device=torch_device,
)
posterior_global_samples_valid = np.array([sample_global_posterior(vd) for vd in valid_data])

In [None]:
# Display the sampled global posterior draws (notebook cell output)
posterior_global_samples_valid

In [None]:
# Recovery plot: global posterior draws against the global parameters used for simulation
diagnostics.plot_recovery(
    posterior_global_samples_valid,
    np.array(valid_prior_global),
    param_names=[r'$\mu$', r'$\log \tau$'],
);

In [None]:
# SBC ECDF plot (difference=True) for the global parameters
diagnostics.plot_sbc_ecdf(
    posterior_global_samples_valid,
    np.array(valid_prior_global),
    difference=True,
    param_names=[r'$\mu$', r'$\log \tau$'],
);

In [None]:
# Condition the local sampler on the simulated global parameters; a point estimate such as
# np.median(posterior_global_samples_valid, axis=0) was the alternative considered here
conditions_global = valid_prior_global
sample_local_posterior = partial(
    euler_maruyama_sampling,
    score_model,
    n_post_samples=n_post_samples,
    diffusion_steps=10,
    device=torch_device,
)
posterior_local_samples_valid = np.array([
    sample_local_posterior(vd, conditions=c)
    for vd, c in zip(valid_data, conditions_global)
])

In [None]:
# Recovery plot for the local parameters, flattened to one column per grid cell
n_valid_sets = valid_data.shape[0]
diagnostics.plot_recovery(
    posterior_local_samples_valid.reshape(n_valid_sets, n_post_samples, -1),
    np.array(valid_prior_local).reshape(n_valid_sets, -1),
    param_names=[f'$\\theta_{{{i}}}$' for i in range(n_grid**2)],
);

In [None]:
# Stacked SBC ECDF plot (difference=True) for the local parameters, one column per grid cell
n_valid_sets = valid_data.shape[0]
diagnostics.plot_sbc_ecdf(
    posterior_local_samples_valid.reshape(n_valid_sets, n_post_samples, -1),
    np.array(valid_prior_local).reshape(n_valid_sets, -1),
    difference=True,
    stacked=True,
    param_names=[f'$\\theta_{{{i}}}$' for i in range(n_grid**2)],
);

In [None]:
# Inspect one validation data set: the data, then posterior median/std next to the simulated values
valid_id = 1
print('Data')
visualize_simulation_output(valid_data[valid_id])

global_labels = ('mu:', 'log tau:')
print('Global Estimates')
for param_idx, param_label in enumerate(global_labels):
    param_draws = posterior_global_samples_valid[valid_id, :, param_idx]
    print(param_label, np.median(param_draws), np.std(param_draws))
print('True')
for param_idx, param_label in enumerate(global_labels):
    print(param_label, valid_prior_global[valid_id][param_idx].item())

In [None]:
# Per-cell posterior median and std for the selected data set, shown next to the simulated local parameters
local_draws = posterior_local_samples_valid[valid_id].reshape(n_post_samples, n_grid, n_grid)
med = np.median(local_draws, axis=0)
std = np.std(local_draws, axis=0)
cat = np.stack((med, std, valid_prior_local[valid_id]))
visualize_simulation_output(
    cat,
    title_prefix=['Posterior Median\n', 'Posterior Std\nUncertainty', 'True\n'],
)