In [2]:
        from typing import List

import torch
from torch import tensor, Tensor, exp 
from torch.distributions import Distribution


def ground_truth_mmd(
    x: Tensor,
    dists_y: List[Distribution],
    y_limits: Tensor,
    y_res: int = 100,
    scale: float = 0.01,
):
    term1 = sample_based_mmd_marginal(x, x, scale=scale)
    term2 = sample_integral_mixed_mmd_marginal(
        x=x, y_dist=dists_y, y_limits=y_limits, y_res=y_res, scale=scale
    )
    term3 = integral_based_mmd_marginal(
        x_dist=dists_y,
        y_dist=dists_y,
        x_limits=y_limits,
        y_limits=y_limits,
        x_res=y_res,
        y_res=y_res,
        scale=scale,
    )
    return term1 + term3 - 2 * term2


def sample_based_mmd(x, y, scale: float = 0.01):
    term1 = sample_based_mmd_marginal(x, x, scale=scale)
    term2 = sample_based_mmd_marginal(x, y, scale=scale)
    term3 = sample_based_mmd_marginal(y, y, scale=scale)
    return term1 + term3 - 2 * term2


def sample_based_mmd_marginal(x, y, scale: float = 0.01):
    """Assumes diagonal likelihood and sums over each dimension. Sum turns into
    product because exp(sum) = prod(exp)
    """
    dim = x.shape[1]
    term = tensor(
        [
            sample_based_mmd_term(x[:, d : d + 1], y[:, d : d + 1], scale=scale)
            for d in range(dim)
        ]
    ).prod()
    return term


def sample_based_mmd_term(x, y, scale: float = 0.01):
    num_x = x.shape[0]
    num_y = y.shape[0]
    xo1 = x.repeat((num_y, 1))
    xo2 = y.repeat_interleave((num_x), dim=0)
    distances = exp(-scale * ((xo1 - xo2) ** 2).sum(dim=1))
    average_dist = distances.mean(dim=0)
    return average_dist


def integral_based_mmd_marginal(
    x_dist: List[Distribution],
    y_dist: List[Distribution],
    x_limits: Tensor,
    y_limits: Tensor,
    x_res: int = 100,
    y_res: int = 100,
    scale: float = 0.01,
):
    """Assumes diagonal likelihood and sums over each dimension. Sum turns into
    product because exp(sum) = prod(exp)
    """
    dim = len(x_dist)
    term = tensor(
        [
            integral_mmd_term(
                x_dist[d],
                y_dist[d],
                x_limits[d],
                y_limits[d],
                x_res=x_res,
                y_res=y_res,
                scale=scale,
            )
            for d in range(dim)
        ]
    ).prod()
    return term


def integral_mmd_term(
    x_dist: Distribution,
    y_dist: Distribution,
    x_limits: Tensor,
    y_limits: Tensor,
    x_res: int = 100,
    y_res: int = 100,
    scale: float = 0.01,
):
    x_range = torch.linspace(x_limits[0].item(), x_limits[1].item(), x_res).unsqueeze(1)
    y_range = torch.linspace(y_limits[0].item(), y_limits[1].item(), y_res).unsqueeze(1)
    x_repeat = x_range.repeat((y_res, 1))
    y_repeat = y_range.repeat_interleave((x_res), dim=0)
    probs_x = x_dist.log_prob(x_repeat).exp()
    probs_y = y_dist.log_prob(y_repeat).exp()
    distances = exp(-scale * ((x_repeat - y_repeat) ** 2).sum(dim=1))
    dx = (x_limits[1].item() - x_limits[0].item()) / (x_res - 1)
    dy = (y_limits[1].item() - y_limits[0].item()) / (y_res - 1)
    integral = (probs_x * probs_y * distances).sum() * dx * dy
    return integral


def sample_integral_mixed_mmd_marginal(
    x,
    y_dist: List[Distribution],
    y_limits: Tensor,
    y_res: int = 100,
    scale: float = 0.01,
):
    """Assumes diagonal likelihood and sums over each dimension. Sum turns into
    product because exp(sum) = prod(exp)
    """
    dim = len(y_dist)
    term = tensor(
        [
            sample_integral_mixed_mmd_term(
                x[:, d : d + 1], y_dist[d], y_limits[d], y_res=y_res, scale=scale
            )
            for d in range(dim)
        ]
    ).prod()
    return term


def sample_integral_mixed_mmd_term(
    x, y_dist: Distribution, y_limits: Tensor, y_res: int = 100, scale: float = 0.01
):
    y_range = torch.linspace(y_limits[0].item(), y_limits[1].item(), y_res).unsqueeze(1)
    y_repeat = y_range.repeat((x.shape[0], 1))
    probs_y = y_dist.log_prob(y_repeat).exp()
    probs_y = torch.reshape(probs_y, (y_res, x.shape[0]))
    y_reshape = torch.reshape(y_repeat, (y_res, x.shape[0], 1))
    distances = exp(-scale * ((x - y_reshape) ** 2).sum(dim=2))
    dy = (y_limits[1].item() - y_limits[0].item()) / (y_res - 1)
    integrals = (distances * probs_y).sum(dim=0) * dy
    monte_carlo_integral = torch.mean(integrals)
    return monte_carlo_integral

In [None]:
from typing import Optional

import torch
from torch import tensor, ones, eye, Tensor
from torch.distributions import MultivariateNormal, Distribution
from sbi.utils import BoxUniform


class GaussianMixture:
    def __init__(
        self,
        x_o: Optional[Tensor] = None,
        num_trials: int = 5,
        beta: float = 1.0,
        dim: int = 2,
        seed: int = 0,
        limits: Tensor = tensor([[-14, 14], [-14, 14]]),
        resolution: int = 250,
        mmd_length_scale: float = 0.01,
    ):
        """Suggested beta: [2.0, 10.0, 50.0]"""
        # Set seed.
        _ = torch.manual_seed(seed)
        self.limits = limits
        self.resolution = resolution
        self.prior = BoxUniform(-10 * ones(dim), 10 * ones(dim))
        self.x_o = x_o
        # Ensure that shape is [5, 2], not [1, 5, 2].
        if (self.x_o != None) and (len(self.x_o.shape) == 3):
            raise ValueError("Gaussian mixture can not deal with batched observations.")
        self.num_trials = num_trials
        self.beta = beta
        self.mmd_length_scale = mmd_length_scale

    def simulate(self, theta: Tensor) -> Tensor:
        """Simulator."""
        samples1 = torch.randn((self.num_trials, *theta.shape)) + theta
        samples2 = 0.1 * torch.randn((self.num_trials, *theta.shape)) + theta
        all_samples = torch.zeros(*samples1.shape)

        bern = torch.bernoulli(0.5 * ones((self.num_trials, theta.shape[0]))).bool()

        all_samples[bern] = samples1[bern]
        all_samples[~bern] = samples2[~bern]
        all_samples = torch.permute(all_samples, (1, 0, 2))
        return all_samples

    def simulate_misspecified(self, theta: Tensor) -> Tensor:
        """Simulator."""
        # For misspecified x, push it out of the prior bounds.
        samples1 = torch.randn((self.num_trials, *theta.shape)) + theta
        samples2 = 0.5 * torch.randn((self.num_trials, *theta.shape)) + torch.sign(theta)*12.5
        all_samples = torch.zeros(*samples1.shape)

        bern = torch.bernoulli(0.5 * ones((self.num_trials, theta.shape[0]))).bool()

        all_samples[bern] = samples1[bern]
        all_samples[~bern] = samples2[~bern]
        all_samples = torch.permute(all_samples, (1, 0, 2))
        assert ((all_samples[:,:,0]>self.limits[0,0]) & (all_samples[:,:,0]<self.limits[0,1]) & (all_samples[:,:,1]>self.limits[1,0]) & (all_samples[:,:,1]<self.limits[1,1])).all()
        return all_samples
        # samples = 0.5 * torch.randn((self.num_trials, *theta.shape)) + theta
        # samples = torch.permute(samples, (1, 0, 2))
        # return samples

        

    def build_marginal_dist(self, predicted_mean):
        class MixtureDist(Distribution):
            def __init__(self, predicted_mean):
                super().__init__()
                self.dist1 = MultivariateNormal(tensor([predicted_mean]), eye(1))
                self.dist2 = MultivariateNormal(tensor([predicted_mean]), 0.01 * eye(1))

            def log_prob(self, x):
                prob1 = self.dist1.log_prob(x).exp()
                prob2 = self.dist1.log_prob(x).exp()
                return (0.5 * prob1 + 0.5 * prob2).log()

        marginals = [MixtureDist(p) for p in predicted_mean[0]]
        return marginals

    def distance_fn(self, theta):
        """Computes E_{x|t}[(x - x_o)^2]."""
        assert self.x_o is not None, "x_o not set."
        if theta.ndim == 1:
            theta = theta.unsqueeze(0)

        marginals = self.build_marginal_dist(theta)
        mmd_x = ground_truth_mmd(
            x=self.x_o,
            dists_y=marginals,
            y_limits=self.limits,
            y_res=self.resolution,
            scale=self.mmd_length_scale,
        )
        return mmd_x

    def potential(self, theta):
        """Potential for GBI ground truth posterior."""
        if theta.ndim == 1:
            theta = theta.unsqueeze(0)

        potentials = []
        for t in theta:
            term1 = -self.beta * self.distance_fn(t)
            potentials.append(term1 + self.prior.log_prob(t))
        return torch.stack(potentials)
    

simulator = GaussianMixture()
prior = simulator.prior
theta = prior.sample((1000,))
x = simulator.simulate(theta).numpy()
simulator.simulate_misspecified(theta)

AssertionError: 