Here we demonstrate how Pyro can be used to do inference on complex
probabilistic models. We will examine the problem of locating signal sources in
two dimensions.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import seaborn
import torch

from contextlib import ExitStack
from eig import elbo_learn
from pyro.contrib.util import iter_plates_to_shape


We have two signal sources, and the aim is to infer their locations
$$\theta = \{x_1, y_1, x_2, y_2\}.$$
Our data are noisy samples of signal strength taken at two-dimensional
coordinates $\xi$. The individual signal strengths follow an inverse-square law,
and the total intensity at $\xi$ is the superposition of the individual signals:

$$\mu(\theta, \xi) = b + \frac{1}{m + \|\theta_1 - \xi\|^2} + \frac{1}{m + \|\theta_2 - \xi\|^2},$$

We can plot the signal map for a given instantiation of $\theta$:

In [None]:
b, m = 1e-1, 1e-4


def get_signal(xi, s1, s2, b=b, m=m):
    """Return the log total signal intensity at each design point.

    Implements ``log(b + 1 / (m + |xi - s1|^2) + 1 / (m + |xi - s2|^2))``
    for two point sources under an inverse-square law.

    Args:
        xi: Array of measurement locations; the last axis holds the
            coordinates of each point.
        s1: Location of the first source (broadcast against ``xi``).
        s2: Location of the second source (broadcast against ``xi``).
        b: Background signal. Defaults to the module-level ``b``.
        m: Small constant added to each squared distance so a source's
            contribution is capped at ``1 / m``. Defaults to the module-level
            ``m``.

    Returns:
        Array with the shape of ``xi`` minus its last axis.
    """
    # Squared Euclidean distances: inverse-*square* law, so no sqrt.
    d1 = np.square(xi - s1).sum(axis=-1)
    d2 = np.square(xi - s2).sum(axis=-1)
    return np.log(b + 1 / (m + d1) + 1 / (m + d2))

# Evaluation grid on [-4, 4) along each axis with spacing 0.01.
x = np.arange(-4, 4, 0.01)
y = np.arange(-4, 4, 0.01)
# All (x, y) pairs with x varying fastest, so a row-major reshape to
# (len(y), len(x)) yields rows indexed by y and columns by x, which is the
# layout plt.contourf(x, y, z) expects.
xy = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])
# Example source locations used only for this illustration.
s1 = np.array([-0.2963, 2.6764])
s2 = np.array([-0.1408, -0.8441])
log_mu = get_signal(xy, s1, s2)
# Take the grid shape from the axes rather than hard-coding 800x800: the
# length of a float-step np.arange can differ by one due to rounding.
z = log_mu.reshape(len(y), len(x))

plt.figure(figsize=(16, 12))
cs = plt.contourf(x, y, z, 64, cmap=cm.Blues)
plt.grid(True)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
plt.show()

Here `b` is the background signal and `m` is a small constant that keeps each source's contribution finite, capping it at $1/m$ when $\xi = \theta_k$.
Our prior belief over the location of the sources is a standard normal:

$\theta_k \sim \mathcal{N}(0, \mathcal{I})$.

Signal measurements have Gaussian noise with standard deviation $\sigma$, and we
model the log-strength for convenience:

$\log y | \theta, \xi \sim \mathcal{N} (\log \mu(\theta, \xi), \sigma)$

Let's create our model and guide:

In [None]:
def make_source_model(theta_mu, theta_sig, observation_sd, alpha=1,
                      observation_label="y", b=1e-1, m=1e-4):
    """Build a Pyro model of noisy log-signal measurements from point sources.

    Args:
        theta_mu: Prior mean of the source locations. Its last two dims give
            the per-design shape of ``theta`` (in this notebook
            ``(n_srcs, n_dims)``).
        theta_sig: Prior standard deviation, expanded to the same shape as
            ``theta_mu``.
        observation_sd: Standard deviation of the Gaussian noise on the log
            signal.
        alpha: Scale applied to every source's inverse-square term.
        observation_label: Name of the observation sample site.
        b: Background signal added to the summed source intensities.
        m: Small constant added to each squared distance, so a single source
            contributes at most ``alpha / m``.

    Returns:
        A model function ``source_model(design)``.
    """
    def source_model(design):
        """Sample source locations, then noisy log-signals at ``design``.

        ``design``'s last axis holds point coordinates (and must match the
        last dim of ``theta_mu`` for the broadcast below); its second-to-last
        axis indexes measurement points; all earlier dims are batch dims.
        Samples the ``"theta"`` site and the observation site, and returns
        the sampled observations.
        """
        # Every dim before the last two of `design` is a batch dim.
        batch_shape = design.shape[:-2]
        with ExitStack() as stack:
            # NOTE(review): iter_plates_to_shape (pyro.contrib.util) is
            # assumed to yield one plate per dim of batch_shape; all plates
            # stay entered for both sample sites below.
            for plate in iter_plates_to_shape(batch_shape):
                stack.enter_context(plate)
            theta_shape = batch_shape + theta_mu.shape[-2:]
            # Prior over locations; to_event(2) makes the whole trailing
            # (sources, coordinates) block a single event.
            theta = pyro.sample(
                "theta",
                dist.Normal(
                    theta_mu.expand(theta_shape),
                    theta_sig.expand(theta_shape)
                ).to_event(2)
            )
            # design -> (..., points, 1, coords); theta -> (..., 1, sources,
            # coords). Summing over coords gives squared Euclidean distances
            # of shape (..., points, sources).
            distance = torch.square(
                design.unsqueeze(-2) - theta.unsqueeze(-3)
            ).sum(dim=-1)
            # Inverse-square contribution of each source at each point.
            ratio = alpha / (m + distance)
            # Superpose all sources plus background: shape (..., points).
            mu = b + ratio.sum(dim=-1)
            # Gaussian noise on the log intensity; to_event(1) treats the
            # measurements at all points of one design as one event.
            emission_dist = dist.Normal(
                torch.log(mu), observation_sd
            ).to_event(1)
            y = pyro.sample(observation_label, emission_dist)
            return y

    return source_model

def elboguide(design, dim=1, n_srcs=2, n_dims=2):
    """Mean-field Gaussian variational guide for the ``"theta"`` site.

    Args:
        design: Design tensor passed to the model. Every dim before the last
            two is treated as a batch dim with one plate each, mirroring
            ``make_source_model``.
        dim: Size of the leading dim of the variational parameters.
        n_srcs: Number of sources (default 2).
        n_dims: Coordinates per source (default 2).

    Registers the Pyro params ``"theta_mu"`` (initialised to zeros) and
    ``"theta_sig"`` (initialised to ones, constrained positive), both of
    shape ``(dim, 1, n_srcs, n_dims)``. Samples ``"theta"`` from a Normal
    with that mean and scale, expanded over the design's batch shape.
    """
    # The parameter shape follows n_srcs / n_dims so the guide can match
    # models built from a theta_mu of a different shape; the defaults keep
    # the (dim, 1, 2, 2) layout used in this notebook.
    param_shape = (dim, 1, n_srcs, n_dims)
    theta_mu = pyro.param("theta_mu", torch.zeros(param_shape))
    theta_sig = pyro.param("theta_sig", torch.ones(param_shape),
                           constraint=torch.distributions.constraints.positive)
    batch_shape = design.shape[:-2]
    with ExitStack() as stack:
        # One plate per batch dim of the design, as in the model.
        for plate in iter_plates_to_shape(batch_shape):
            stack.enter_context(plate)
        theta_shape = batch_shape + theta_mu.shape[-2:]
        # to_event(2): the (sources, coordinates) block is one event,
        # matching the model's "theta" site.
        pyro.sample("theta", dist.Normal(
            theta_mu.expand(theta_shape),
            theta_sig.expand(theta_shape)).to_event(2)
        )

Next, let's generate our training data:

In [None]:
n_srcs = 2
n_dims = 2
obs_sd = 0.5
theta_mu = torch.zeros(1, 1, n_srcs, n_dims)
theta_sig = torch.ones(1, 1, n_srcs, n_dims)
# Draw ground-truth source locations from the prior.
true_theta = torch.distributions.Normal(theta_mu, theta_sig).sample()
# Pin "theta" to the ground truth so the model only samples observations.
true_model = pyro.condition(
    make_source_model(theta_mu, theta_sig, obs_sd),
    {"theta": true_theta}
)
# true_theta has shape (1, 1, n_srcs, n_dims). Index the *source* axis
# (second to last) so s_k is the full (x, y) location of source k. Indexing
# the last axis would instead collect the x (or y) coordinates of both sources.
s1 = true_theta[..., 0, :].squeeze().cpu().numpy()
s2 = true_theta[..., 1, :].squeeze().cpu().numpy()
print(f"True \u03B8:\n\u03B8_1 = {s1}\n\u03B8_2 = {s2}")

# Regular grid of measurement locations with spacing 0.4 on [-4, 4).
xi_x = np.arange(-4, 4, 0.4)
xi_y = np.arange(-4, 4, 0.4)
xi = np.transpose([np.tile(xi_x, len(xi_y)), np.repeat(xi_y, len(xi_x))])
# Two leading singleton dims give the design the same (1, 1) batch shape as
# theta_mu.
xi = torch.tensor(xi).unsqueeze(0).unsqueeze(0)
ys = true_model(xi)


Now we do a bit of magic by calling a helper function that abstracts away the
SVI loop. What's important to understand here is the interface exposed by
`elbo_learn`. As arguments, we pass:

&nbsp; &nbsp; 1. The prior

&nbsp; &nbsp; 2. The training inputs $\xi$

&nbsp; &nbsp; 3. The list of observation site names

&nbsp; &nbsp; 4. The list of parameter (latent) site names

&nbsp; &nbsp; 5-6. The number of samples and steps for gradient descent

&nbsp; &nbsp; 7. The guide

&nbsp; &nbsp; 8. A dictionary mapping observation site names to the observed training data

&nbsp; &nbsp; 9. A Pyro optimiser (here Adam)

In [None]:
# Prior model with a standard-normal prior on the source locations. It is
# built from fresh tensors rather than the `theta_mu`/`theta_sig` globals,
# because a later cell overwrites those with the fitted posterior parameters.
prior = make_source_model(
    torch.zeros(1, 1, n_srcs, n_dims),
    torch.ones(1, 1, n_srcs, n_dims),
    obs_sd
)

# Start from an empty param store so re-running this cell does not
# warm-start the guide from the parameters of a previous fit.
pyro.clear_param_store()

elbo_n_samples, elbo_n_steps, elbo_lr = 100, 1000, 0.04
loss = elbo_learn(
    prior, xi, ["y"], ["theta"], elbo_n_samples, elbo_n_steps,
    elboguide, {"y": ys}, optim.Adam({"lr": elbo_lr})
)

`elbo_learn` minimises the ELBO loss w.r.t. the parameters of the guide,
`theta_mu` and `theta_sig`. We can now extract them from the pyro param store
and plot the posterior.

In [None]:
# Fitted variational parameters. squeeze() drops the singleton leading dims;
# assuming the guide ran with dim=1, row k holds the (x, y) mean / std of
# source k. detach() already returns a grad-free tensor, so `.data` is not
# needed.
theta_mu = pyro.param("theta_mu").detach().clone().squeeze()
theta_sig = pyro.param("theta_sig").detach().clone().squeeze()
posterior0 = dist.Normal(theta_mu[0], theta_sig[0])
posterior1 = dist.Normal(theta_mu[1], theta_sig[1])
n_samples = 10000
samples0 = posterior0.sample((n_samples,)).cpu().numpy()
samples1 = posterior1.sample((n_samples,)).cpu().numpy()
# Stack both sources' samples; `hue` labels each row with its source index.
samples = np.concatenate([samples0, samples1])
hue = np.concatenate([np.zeros(n_samples), np.ones(n_samples)])

plt.figure(figsize=(16, 12))
seaborn.kdeplot(x=samples[..., 0], y=samples[..., 1], fill=True, hue=hue,
                legend=False, thresh=0.01)
# Overlay the true source locations (one row per source) as green crosses.
np_theta = true_theta.squeeze().cpu().numpy()
plt.scatter(np_theta[:, 0], np_theta[:, 1], color="green", marker="x", s=100)
plt.grid(True)
plt.show()