In [1]:
import sbibm
from sbi.inference import SNRE_A
import torch
from torch import nn
from torch import Tensor, nn, ones
from sbibm.tasks.bernoulli_glm.task import BernoulliGLM

In [2]:
def simulator(theta):
    """Toy Gaussian simulator: x = theta + eps with eps ~ N(0, I).

    Args:
        theta: floating-point tensor of parameters (any shape).

    Returns:
        Tensor with the same shape, dtype and device as ``theta``.
    """
    # randn_like draws the noise with theta's dtype and device. Plain
    # torch.randn(theta.shape) always produced CPU float32 noise, which fails
    # for CUDA tensors and throws away precision for float64 inputs.
    return theta + torch.randn_like(theta)

In [3]:
# The two_moons task is used only for its prior: the training data come from the
# Gaussian `simulator` defined above, not from the two-moons simulator.
task = sbibm.get_task("two_moons")
prior = task.get_prior_dist()

torch.manual_seed(0)  # make the simulated training set (and everything downstream) reproducible
NUM_SIMULATIONS = 1_000_000
KERNEL_BANDWIDTH = 6.0  # larger -> weights concentrate more tightly around the observation

thetas = prior.sample((NUM_SIMULATIONS,))
xs = simulator(thetas)
observation = torch.zeros(1, 2)


def kernel(x, bandwidth=KERNEL_BANDWIDTH):
    """Unnormalised Gaussian kernel exp(-bandwidth * ||x - observation||^2).

    The squared distance is summed over the last dimension of ``x``, so the
    result has ``x``'s shape without its last axis. ``observation`` is moved to
    ``x``'s device so the kernel also works on batches living on a GPU.
    """
    return torch.exp(-bandwidth * torch.sum((x - observation.to(x.device)) ** 2, -1))

In [4]:
def correction_factor(theta, N=1000, chunk_size=None, simulator_fn=None, kernel_fn=None):
    """Monte-Carlo estimate of E_{x ~ p(x | theta)}[kernel(x)] for each row of ``theta``.

    Args:
        theta: tensor of shape (M, D), one parameter vector per row.
        N: number of simulations per parameter vector (must be >= 1).
        chunk_size: if given, simulate at most this many replicates (of all M
            rows) per step. This bounds peak memory at roughly
            ``chunk_size * M * D`` elements instead of ``N * M * D``. ``None``
            simulates all N replicates at once.
        simulator_fn: simulator to use; defaults to the notebook's ``simulator``.
        kernel_fn: weighting kernel to use; defaults to the notebook's ``kernel``.

    Returns:
        Tensor of shape (M,) with the kernel value averaged over the N simulations.

    Raises:
        ValueError: if ``N`` or ``chunk_size`` is smaller than 1.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")
    if chunk_size is None:
        chunk_size = N
    elif chunk_size < 1:
        # A zero chunk would never make progress (infinite loop).
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    simulate = simulator if simulator_fn is None else simulator_fn
    weigh = kernel if kernel_fn is None else kernel_fn

    dim = theta.shape[-1]  # parameter dimension, previously hard-coded to 2
    total = 0.0
    done = 0
    while done < N:
        reps = min(chunk_size, N - done)
        # theta.repeat tiles the whole batch reps times, so after the reshape
        # sims[r, m] is the r-th simulation for theta[m].
        sims = simulate(theta.repeat(reps, 1)).reshape(reps, -1, dim)
        total = total + weigh(sims).sum(0)
        done += reps
    return total / N

In [5]:
from sbi import utils as utils
def _loss(self, theta, x, num_atoms=2):
    r"""
    Returns the kernel-weighted binary cross-entropy loss for the trained classifier.

    The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
    if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the
    pair was sampled from the marginals $p(\theta)p(x)$. Every pair's loss term is
    multiplied by ``kernel(x)``, so pairs whose ``x`` is close to ``observation``
    dominate the objective.

    Args:
        self: the SNRE_A instance this function is bound to; provides
            ``_classifier_logits`` and ``_device``.
        theta: batch of parameters; ``theta.shape[0]`` is the batch size.
        x: batch of simulator outputs with the same batch size as ``theta``.
        num_atoms: number of thetas scored per ``x``. The first is its joint
            partner (label 1); the remaining ``num_atoms - 1`` are contrastive
            (label 0). The original labelling assumed 2.

    Returns:
        Scalar tensor: the weighted BCE averaged over all ``num_atoms * batch_size``
        terms. PyTorch's ``reduction="mean"`` divides by the number of terms, not
        by the sum of the weights.
    """
    assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
    batch_size = theta.shape[0]

    # NOTE(review): assumes one logit per (theta, x) pair, shaped
    # (num_atoms * batch_size, 1). Squeeze only that trailing axis so a tiny
    # batch is never collapsed to a 0-d tensor.
    logits = self._classifier_logits(theta, x, num_atoms).squeeze(-1)

    # Atoms are grouped per x: entry k * num_atoms is the joint pair (label 1)
    # and the other num_atoms - 1 entries of the group are marginal pairs
    # (label 0). This assumes `_classifier_logits` orders atoms this way, as the
    # alternating 1, 0, 1, 0, ... scheme for two atoms did.
    labels = torch.zeros(num_atoms * batch_size, device=self._device)
    labels[::num_atoms] = 1.0
    # Repeat each x once per atom so the weights line up with the logits.
    weights = kernel(utils.repeat_rows(x, num_atoms))

    # Binary cross entropy to learn the likelihood (AALR-specific). Working on
    # logits fuses the sigmoid into the loss, which is numerically more stable
    # than a sigmoid followed by BCELoss.
    return nn.BCEWithLogitsLoss(weight=weights)(logits, labels)

In [6]:
# Neural ratio estimator (sbi SNRE_A, a.k.a. AALR) over the prior defined earlier.
inf = SNRE_A(prior=prior)

In [7]:
# Bind the kernel-weighted `_loss` to this instance so that `inf.train()` uses it
# (assumes sbi's training loop calls `self._loss`). MethodType forwards both
# positional and keyword arguments, so calls such as `_loss(..., num_atoms=2)`
# also work.
from types import MethodType

inf._loss = MethodType(_loss, inf)

In [8]:
# Register the simulated (theta, x) training pairs with the estimator.
inf.append_simulations(theta=thetas, x=xs)

<sbi.inference.snre.snre_a.SNRE_A at 0x7f2e4d9d9610>

In [None]:
# Train the ratio classifier on the appended (theta, x) pairs using whatever
# `_loss` is currently bound to `inf`.
# NOTE(review): with 1e6 simulations this is slow; consider caching the trained
# estimator so Restart & Run All stays feasible.
inf.train()

Training neural network. Epochs trained:  13

In [None]:
# Posterior built from the trained ratio estimator (MCMC used for sampling).
# Conditioning on `observation` by default lets later `log_prob` calls omit `x`.
posterior = inf.build_posterior(
    sample_with="mcmc",
)
posterior.set_default_x(observation)

In [None]:
import numpy as np

# Evaluate the learned posterior density on a 500 x 500 grid over [-1, 1]^2.
grid_axis = np.linspace(-1, 1, 500)
X, Y = np.meshgrid(grid_axis, grid_axis)
pos = np.stack((X, Y), axis=-1)  # shape (500, 500, 2); pos[..., 0] = X, pos[..., 1] = Y

grid_thetas = torch.from_numpy(pos).reshape(-1, 2).float()
Z = posterior.log_prob(grid_thetas).reshape(500, 500).exp()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Learned NRE posterior on the grid, before the kernel correction.
# (`zdir` is a 3D-axes option; 2D contourf does not use it, so it is dropped.)
fig, ax = plt.subplots()
contours = ax.contourf(X, Y, Z, cmap="viridis", levels=20)
fig.colorbar(contours, ax=ax)
ax.set(title="NRE posterior (uncorrected)", xlabel=r"$\theta_1$", ylabel=r"$\theta_2$");

In [None]:
# corrected
# Recompute the uncorrected NRE density here instead of `Z = Z * ...`, so that
# re-running this cell does not apply the correction twice.
correction_thetas = torch.tensor(pos.reshape(-1, 2))
Z_uncorrected = posterior.log_prob(correction_thetas.float()).reshape(500, 500).exp()
# Evaluate the correction 1000 grid points at a time. All 250,000 points x
# 10,000 simulations in one call would materialise ~2.5e9 two-dimensional
# float64 samples (tens of GB).
correction = torch.cat(
    [correction_factor(chunk, N=10000) for chunk in correction_thetas.split(1000)]
).reshape(500, 500)
Z = Z_uncorrected * correction

In [None]:
# Corrected posterior: the NRE density multiplied by the kernel correction factor
# (this plots the corrected density, not the correction factor itself).
fig, ax = plt.subplots()
contours = ax.contourf(X, Y, Z, cmap="viridis", levels=20)
fig.colorbar(contours, ax=ax)
ax.set(title="NRE posterior (corrected)", xlabel=r"$\theta_1$", ylabel=r"$\theta_2$");

In [None]:
# Reference density N(theta; observation, I) on the grid: since the simulator is
# x = theta + N(0, I), the likelihood viewed as a function of theta is this Gaussian.
# NOTE(review): it equals the true posterior on the plotted square only if the
# two_moons prior is flat there (believed Uniform([-1, 1]^2)); confirm.
grid_points = torch.tensor(pos.reshape(-1, 2)).float()
reference = torch.distributions.MultivariateNormal(observation, torch.eye(2))
Z = reference.log_prob(grid_points).exp().reshape(500, 500)

In [None]:
# closed form posterior
# (`zdir` is a 3D-axes option; 2D contourf does not use it, so it is dropped.)
fig, ax = plt.subplots()
contours = ax.contourf(X, Y, Z, cmap="viridis", levels=20)
fig.colorbar(contours, ax=ax)
ax.set(title="Closed-form reference posterior", xlabel=r"$\theta_1$", ylabel=r"$\theta_2$");