# Toy example

Code for the toy example in "Importance nested sampling with normalising flows". Produces all of the relevant plots and results.


In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.special import logsumexp
from scipy import optimize
import seaborn as sns

# Make the project-local `utils` module importable; it is looked up relative
# to `basedir`, so the import must come after the sys.path append.
basedir = "../"
sys.path.append(basedir)
from utils import configure_plotting

# Project helper defined outside this notebook; presumably applies the shared
# matplotlib style - verify against utils.
configure_plotting(basedir)
# Seed NumPy's legacy global RNG so the random draws below are reproducible
# on a fresh Restart & Run All.
np.random.seed(1234)

# Render text with LaTeX (NOTE(review): likely needs a working LaTeX install).
plt.rcParams["text.usetex"] = True

# Default figure size and a double-width variant with the same height.
figsize = plt.rcParams['figure.figsize']
double_figsize = (2.0 * figsize[0], figsize[1])

# Output directory for the saved figure(s).
os.makedirs("figures", exist_ok=True)

## Define the problem

In [None]:
# Standard deviation of the zero-mean isotropic Gaussian prior
prior_std = 2.0
# Standard deviation of the zero-mean isotropic Gaussian likelihood
likelihood_std = 1.0
# Number of dimensions of the parameter space
dims = 2

Compute the posterior standard deviation

In [None]:
# Precisions (inverse variances) of the prior and likelihood add, so the
# posterior variance is the reciprocal of their sum.
post_std = np.sqrt(1 / ((1 / prior_std ** 2) + (1 / likelihood_std ** 2)))
print(f"Posterior standard deviation: {post_std:.8f}")

Define the distributions using scipy

In [None]:
# All three distributions are zero-mean and isotropic; they differ only in
# the variance multiplying the identity covariance.
prior_dist = stats.multivariate_normal(
    mean=np.zeros(dims),
    cov=prior_std ** 2 * np.eye(dims)
)
likelihood_dist = stats.multivariate_normal(
    mean=np.zeros(dims),
    cov=likelihood_std ** 2 * np.eye(dims)
)
# Analytic posterior, used as the reference for the checks and plots below.
post_dist = stats.multivariate_normal(
    mean=np.zeros(dims),
    cov=post_std ** 2 * np.eye(dims)
)

The product of two Gaussian PDFs is proportional to another Gaussian where the scaling factor is another Gaussian evaluated at $\mu_1$

$$
\mathcal{N}\left(x=\mu_1; \mu_2, \sqrt{\sigma_1^2 + \sigma_2^2}\right).
$$

Thus the evidence is this scaling factor.


In [None]:
# Analytic evidence: a zero-mean Gaussian whose variance is the sum of the
# prior and likelihood variances, evaluated at the origin.
true_evidence = stats.multivariate_normal(
    mean=np.zeros(dims),
    cov=(prior_std ** 2 + likelihood_std ** 2) * np.eye(dims),
).pdf(np.zeros(dims))
print(f"True evidence: {true_evidence}")

## Posterior distribution $p(\ln \mathcal{L})$

Determine the posterior defined in terms of $\lambda = \ln \mathcal{L}$.

Start by checking the distributions of the radius and radius squared.

In [None]:
# For posterior samples x ~ N(0, post_std^2 I) in `dims` dimensions,
# r^2 = |x|^2 follows a chi-squared distribution scaled by post_std^2 and
# r = |x| follows a chi distribution scaled by post_std.
r2_dist = stats.chi2(df=dims, scale=post_std ** 2)
r_dist = stats.chi(df=dims, scale=post_std)

In [None]:
# Draw reference samples from the analytic posterior and compute their
# squared radius (sum of squares over the coordinate axis).
post_samples = post_dist.rvs(size=100_000)
post_r2 = np.sum(post_samples ** 2, axis=1)

In [None]:
# Evaluation grids for the analytic r^2 and r densities plotted below.
r2_vec = np.linspace(0, 16, 1000)
r_vec = np.linspace(0, 8, 1000)

In [None]:
# Visual check: histograms of the sampled r and r^2 should follow the
# analytic chi and chi-squared densities.
fig, axs = plt.subplots(1, 2, sharey=True)
# Left panel: radius.
axs[0].hist(np.sqrt(post_r2), 100, density=True, histtype="step")
axs[0].plot(r_vec, r_dist.pdf(r_vec))
axs[0].set_xlabel(r"$r$")

# Right panel: squared radius.
axs[1].hist(post_r2, 100, density=True, histtype="step")
axs[1].plot(r2_vec, r2_dist.pdf(r2_vec))
axs[1].set_xlabel(r"$r^2$")

plt.show()

Functions for converting between the log-likelihood ($\lambda$) and the squared radius ($r^2$)

The equations are:

$$
r^2 = -2 \sigma^2 \left[\frac{n}{2} \ln(2 \pi \sigma^2) + \lambda \right]
$$

then 
$$
p(\lambda) = p(r^2) \left| \frac{\partial r^2}{\partial \lambda}\right|
$$

where

$$
p(r^2) = \frac{1}{\sigma_{\mathrm{post}}^2} \frac{1}{2^{k/2}\Gamma(k/2)} x^{k/2 - 1} e^{-x/2}, \qquad x = \frac{r^2}{\sigma_{\mathrm{post}}^2}
$$

In [None]:
def lambda_to_radius2(l, var, n):
    """Convert a Gaussian log-likelihood value into the squared radius.

    Inverts ``l = -0.5 * (n * log(2 * pi * var) + r2 / var)``, the log-pdf of
    a zero-mean isotropic Gaussian with variance ``var`` in ``n`` dimensions.

    Parameters
    ----------
    l : float or array_like
        Log-likelihood value(s).
    var : float
        Variance of the Gaussian (not the standard deviation).
    n : int
        Number of dimensions.

    Returns
    -------
    float or numpy.ndarray
        Squared radius ``r2`` corresponding to ``l``.
    """
    return -2 * var * (0.5 * n * np.log(2 * np.pi * var) + l)

def radius2_to_lambda(r2, var, n):
    """Convert a squared radius into the Gaussian log-likelihood value.

    Inverse of :func:`lambda_to_radius2`. The ``r2 / var`` term is added
    inside the bracket, so the log-likelihood decreases as the radius grows.

    Parameters
    ----------
    r2 : float or array_like
        Squared radius value(s).
    var : float
        Variance of the Gaussian (not the standard deviation).
    n : int
        Number of dimensions.

    Returns
    -------
    float or numpy.ndarray
        Log-likelihood ``l`` corresponding to ``r2``.
    """
    return -0.5 * (n * np.log(2 * np.pi * var) + r2 / var)

In [None]:
def max_lambda(n, var):
    """Return ``-(n / 2) * log(2 * pi * var)``.

    This is the log-pdf of a zero-mean isotropic Gaussian with variance
    ``var`` in ``n`` dimensions evaluated at zero radius.
    """
    log_two_pi_var = np.log(2 * np.pi * var)
    half_dims = -n / 2
    return half_dims * log_two_pi_var

In [None]:
# Samples from the prior and their log-likelihoods; used below only to set
# the range of the lambda grid.
prior_samples = prior_dist.rvs(10_000)
prior_ll = likelihood_dist.logpdf(prior_samples)

In [None]:
# Log-likelihood of the posterior samples (evaluated under the likelihood,
# not under the posterior density).
post_ll = likelihood_dist.logpdf(post_samples)

In [None]:
# Sanity check: inverting the log-likelihood with the likelihood variance
# must recover the squared radius computed directly from the samples.
r2_recon = lambda_to_radius2(post_ll, likelihood_std ** 2, dims)
assert np.allclose(r2_recon, post_r2)

In [None]:
# Grid of log-likelihood values spanning both the prior and posterior samples.
lambda_vec = np.linspace(
    min(prior_ll.min(), post_ll.min()),
    max(prior_ll.max(), post_ll.max()),
    1000,
)

**Converting from radius squared**

Need the Jacobian

$$
\left|\frac{\partial \lambda}{\partial r^2}\right| = \frac{1}{2\sigma^2} \quad \Rightarrow \quad \left|\frac{\partial r^2}{\partial \lambda}\right| = 2\sigma^2
$$

In [None]:
def lambda_pdf(l):
    """Density of the log-likelihood ``l`` via a change of variables from r^2.

    Maps ``l`` to the squared radius with the likelihood variance, evaluates
    the module-level ``r2_dist`` there and multiplies by the Jacobian
    ``|d r^2 / d l| = 2 * likelihood_std ** 2``.
    """
    likelihood_var = likelihood_std ** 2
    r2 = lambda_to_radius2(l, likelihood_var, dims)
    jacobian = 2 * likelihood_var
    return r2_dist.pdf(r2) * jacobian

In [None]:
# Analytic p(lambda) on the grid; reused in the summary figure later on.
lambda_pdf_values = lambda_pdf(lambda_vec)

In [None]:
# Compare the histogram of posterior-sample log-likelihoods with the
# analytic density obtained from the r^2 change of variables.
plt.figure()
plt.hist(post_ll, 100, density=True, histtype="stepfilled")
plt.plot(lambda_vec, lambda_pdf_values)
plt.xlabel(r"$\lambda$")
plt.show()

**Converting from radius**

$$
\left|\frac{\partial \lambda}{\partial r}\right| = \frac{r}{\sigma^2} \quad \Rightarrow \quad \left|\frac{\partial r}{\partial \lambda}\right| = \frac{\sigma^2}{r}
$$

In [None]:
# Same comparison as the previous plot, but derived from the radius r rather
# than r^2: p(lambda) = p(r) * |dr / dlambda| = p(r) * sigma^2 / r.
plt.figure()
plt.hist(post_ll, 100, density=True, histtype="stepfilled")
# Use a dedicated name so the `r_vec` plotting grid defined earlier is not
# overwritten; otherwise re-running the radius check cell would silently
# plot on the wrong grid.
r_from_lambda = lambda_to_radius2(lambda_vec, likelihood_std ** 2, dims) ** 0.5
plt.plot(lambda_vec, r_dist.pdf(r_from_lambda) * likelihood_std ** 2 / r_from_lambda)
plt.xlabel(r"$\lambda$")
plt.show()

We can see that both methods agree.

## Algorithm

In [None]:
# Number of samples drawn from each level (including the prior level)
nlive = 500
# Number of proposal levels built after the prior
n_levels = 3 # (+1 for prior)
# Fraction of the lowest-likelihood samples discarded when fitting the next level
rho = 0.5

In [None]:
def log_likelihood(x):
    """Evaluate the log-pdf of the module-level ``likelihood_dist`` at ``x``."""
    log_l = likelihood_dist.logpdf(x)
    return log_l

In [None]:
def log_prior(x):
    """Evaluate the log-pdf of the module-level ``prior_dist`` at ``x``."""
    log_p = prior_dist.logpdf(x)
    return log_p

In [None]:
def construct_level(x, rho=0.5):
    """Fit an isotropic Gaussian proposal to the retained samples.

    The first ``int(rho * len(x))`` rows of ``x`` are dropped. The callers in
    this notebook pass samples sorted by increasing log-likelihood, so these
    are the lowest-likelihood points. A single standard deviation is computed
    over all entries of the remaining rows (coordinates pooled) and used for
    every dimension.

    Parameters
    ----------
    x : numpy.ndarray
        Samples, one per row.
    rho : float
        Fraction of rows to discard from the start of ``x``.

    Returns
    -------
    Frozen ``scipy.stats.multivariate_normal`` with covariance
    ``scale ** 2 * I`` in ``dims`` dimensions (``dims`` is module-level).
    """
    n_discard = int(rho * x.shape[0])
    retained = x[n_discard:]
    sigma = np.std(retained)
    return stats.multivariate_normal(cov=(sigma ** 2) * np.eye(dims))

In [None]:
def log_meta_proposal(levels, samples, weights=None):
    """Log-density of the mixture ("meta") proposal at the given samples.

    Parameters
    ----------
    levels : dict
        Maps a label to a distribution object exposing ``logpdf``.
    samples : numpy.ndarray or dict
        Points to evaluate, one per row. A dict of arrays is concatenated
        in the order of its values.
    weights : float or array_like, optional
        Mixture weights passed to ``logsumexp`` as ``b``. Defaults to
        ``1 / len(levels)``, i.e. every level contributes equally.

    Returns
    -------
    numpy.ndarray
        ``log(sum_i weights_i * q_i(x))`` for each sample ``x``.
    """
    if isinstance(samples, dict):
        points = np.concatenate(list(samples.values()))
    else:
        points = samples.copy()
    # Equal weights are appropriate when every level holds the same number
    # of samples.
    mixture_weights = 1 / len(levels) if weights is None else weights
    # One column of log-densities per level; NaN-filled so that an unfilled
    # column would be obvious.
    log_q = np.full((len(points), len(levels)), np.nan)
    for column, proposal in enumerate(levels.values()):
        log_q[:, column] = proposal.logpdf(points)
    return logsumexp(log_q, b=mixture_weights, axis=1)

In [None]:
# Per-level bookkeeping, keyed by the level label (a string).
levels = {}
level_samples = {}
sample_log_likelihoods = {}
sample_log_priors = {}
sample_meta_proposal = {}

In [None]:
def update_samples(
    level_samples,
    sample_log_likelihoods,
    sample_log_priors,
    label,
):
    """Evaluate and sort the samples stored under ``label``.

    Computes the log-likelihood and log-prior of ``level_samples[label]``,
    then reorders the samples and both arrays by increasing log-likelihood.
    The three dictionaries are updated in place and also returned.
    """
    points = level_samples[label]
    log_l = log_likelihood(points)
    log_p = log_prior(points)
    # Ascending order: the lowest-likelihood samples come first.
    order = np.argsort(log_l)
    level_samples[label] = points[order]
    sample_log_priors[label] = log_p[order]
    sample_log_likelihoods[label] = log_l[order]
    return level_samples, sample_log_likelihoods, sample_log_priors

In [None]:
# The prior is stored as level '-1' so that the loop below can look up the
# previous level as str(n - 1), including for n = 0.
label = '-1'
levels[label] = prior_dist
level_samples[label] = prior_dist.rvs(size=nlive)
level_samples, sample_log_likelihoods, sample_log_priors = update_samples(
    level_samples, sample_log_likelihoods, sample_log_priors, label
)


# Each new level is fitted to the sorted samples of the previous level and
# then sampled itself.
for n in range(n_levels):
    previous = str(n - 1)
    label = str(n)
    levels[label] = construct_level(level_samples[previous], rho=rho)
    level_samples[label] = levels[label].rvs(size=nlive)
    level_samples, sample_log_likelihoods, sample_log_priors = update_samples(
        level_samples, sample_log_likelihoods, sample_log_priors, label
    )

In [None]:
# Pool the samples from every level (in dict insertion order) and evaluate
# the likelihood, prior and meta-proposal on the pooled set.
final_samples = np.concatenate([s for s in level_samples.values()], axis=0)
final_log_l = log_likelihood(final_samples)
final_log_p = log_prior(final_samples)
# Default equal mixture weights are used here.
final_log_q = log_meta_proposal(levels, final_samples)

In [None]:
# Importance weights: likelihood * prior / meta-proposal, computed in log
# space and then normalised to sum to one.
log_posterior_weights = final_log_l + final_log_p - final_log_q
post_weights = np.exp(log_posterior_weights)
post_weights /= np.sum(post_weights)

## Plots

In [None]:
def plot_level_log_likelihood(d):
    """Plot one step histogram per entry of ``d`` on a new figure.

    Parameters
    ----------
    d : dict
        Values are arrays of log-likelihoods, one array per level.

    Returns
    -------
    The newly created matplotlib figure. Histograms are coloured with evenly
    spaced viridis colours in the iteration order of ``d``.
    """
    fig = plt.figure()
    palette = plt.cm.viridis(np.linspace(0, 1, len(d)))
    for idx, level_log_l in enumerate(d.values()):
        plt.hist(level_log_l, color=palette[idx], histtype='step', lw=2.0, density=True)
    return fig

In [None]:
# Summary figure with three panels:
#   upper       - all samples, with one circle per level and the posterior
#   lower_left  - log-likelihood histogram of each level's samples
#   lower_right - importance-weighted log-likelihood histogram against the
#                 analytic p(ln L)
fig, axd = plt.subplot_mosaic(
    [["upper", "upper"], ["lower_left", "lower_right"]],
    figsize=(figsize[0], 1.5 * figsize[1]),
)
colours = plt.cm.viridis(np.linspace(0, 1, len(levels)))

# Upper panel: samples from every level in a single neutral colour.
for ls in level_samples.values():
    axd["upper"].scatter(ls[:, 0], ls[:, 1], s=1, color="silver")
theta = np.linspace(0, 2 * np.pi, 1000)
# One-sigma circle of each level; the covariances are isotropic, so the
# first diagonal entry gives the radius.
for i, level in enumerate(levels.values()):
    r = np.sqrt(np.diag(level.cov))[0]
    axd["upper"].plot(r * np.cos(theta), r * np.sin(theta), c=colours[i], ls='-')
# One-sigma circle of the analytic posterior (dashed).
r_post = np.sqrt(np.diag(post_dist.cov))[0]
axd["upper"].plot(r_post * np.cos(theta), r_post * np.sin(theta), c="C1", ls='--')
axd["upper"].set_xlabel(r"$\theta_0$")
axd["upper"].set_ylabel(r"$\theta_1$")
axd["upper"].set_xlim([-4, 4])
axd["upper"].set_ylim([-4, 4])
axd["upper"].set_aspect('equal', adjustable='box')

# Fixed log-likelihood range shared by both lower panels.
post_range = [-15, 0]

# Lower-left panel: per-level histograms on common bins.
bins = np.linspace(post_range[0], post_range[1], 32)
for i, logL in enumerate(sample_log_likelihoods.values()):
    axd["lower_left"].hist(logL, bins=bins, color=colours[i], histtype='step', density=True)
axd["lower_left"].set_xlabel(r"$\ln \mathcal{L}$")
axd["lower_left"].set_ylabel(r"$p(\ln \mathcal{L})$")
axd["lower_left"].set_xlim(post_range)

# Lower-right panel: weighted histogram of all samples vs. analytic density.
axd["lower_right"].hist(
    final_log_l, 50, density=True, weights=post_weights, histtype="stepfilled",
    range=post_range, color=colours[1]
)
axd["lower_right"].plot(lambda_vec, lambda_pdf_values, c='C1', ls='--')
axd["lower_right"].set_xlim(post_range)
axd["lower_right"].set_xlabel(r"$\ln \mathcal{L}$")

# Share the y-axis between the lower panels and hide the duplicate labels.
axd["lower_left"].sharey(axd["lower_right"])
plt.setp(axd["lower_right"].get_yticklabels(), visible=False)

plt.tight_layout()
fig.savefig("figures/toy_example.pdf")
plt.show()

## Evidence

In [None]:
def evidence(log_likelihood, log_prior, log_meta_proposal):
    """Importance-sampling estimate of the evidence.

    All three arguments are arrays of per-sample log-values. The estimate is
    the mean of the importance weights ``L * prior / Q``, with the sum
    evaluated via ``logsumexp``.
    """
    log_weights = log_likelihood + log_prior - log_meta_proposal
    weight_sum = np.exp(logsumexp(log_weights))
    return weight_sum / len(log_likelihood)


def evidence_error(log_likelihood, log_prior, log_meta_proposal):
    """Estimated variance of the evidence estimate.

    Returns ``sum((w_i - Z) ** 2) / (n * (n - 1))`` where ``w_i`` are the
    importance weights and ``Z`` is the value returned by ``evidence``.
    This is a variance; take the square root for a standard error.
    """
    n = len(log_likelihood)
    z = evidence(log_likelihood, log_prior, log_meta_proposal)
    weights = np.exp(log_likelihood + log_prior - log_meta_proposal)
    squared_deviation = (weights - z) ** 2.0
    prefactor = 1 / (n * (n - 1))
    return prefactor * np.sum(squared_deviation, axis=-1)

In [None]:
# Evidence estimate from the pooled samples; evidence_error returns a
# variance, hence the square root to quote a standard error.
Z_hat = evidence(final_log_l, final_log_p, final_log_q)
Z_hat_sigma = np.sqrt(evidence_error(final_log_l, final_log_p, final_log_q))
print(f"Final estimate: {Z_hat} +/- {Z_hat_sigma}")