# Gaussian Process Kernel Assessment

We always tell people that they could use something like cross validation for kernel selection, but how do you actually do it?

Let's work through it here.

In [None]:
import matplotlib.pyplot as plt

# Apply the project's matplotlib style: start from matplotlib's defaults, then
# layer ./araa-gps.mplstyle on top (later entries in the list override earlier
# ones). A throwaway empty plot is drawn before, and closed after, the style
# change; the original author found this was needed for the reset to stick.
# NOTE(review): the root cause is not visible here (possibly the notebook's
# inline backend applying its own rcParams on first use) -- confirm.
plt.plot()  # Required to reset the rcParams for some reason
plt.style.use(["default", "./araa-gps.mplstyle"])
plt.close()

In [None]:
import warnings
from functools import partial

import arviz as az
import corner
import jax
import jax.numpy as jnp
import numpyro
from numpyro import distributions as dist
from numpyro import infer
from numpyro_ext.optim import optimize
from tinygp import GaussianProcess, kernels

from paths import figures  # project-local output directory for saved figures

# Silence FutureWarnings (presumably deprecation notices from the libraries
# above) so they don't clutter the notebook output.
warnings.filterwarnings("ignore", category=FutureWarning)
# Use 64-bit floats in JAX. GP likelihoods rely on Cholesky factorizations of
# kernel matrices, which are fragile in single precision. JAX recommends
# setting this at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

Next, we'll simulate some data using a GP with a known squared-exponential kernel.
When simulating, we'll hold out some of the data as our "test set": the first and last 10 points (so that both ends of the input range are tested), plus a random ~30% of the remaining points.
For cross validation more generally, we'd want to repeat this for many realizations of the held-out data, but in the name of computational efficiency, we'll just do a single realization.

In [None]:
# One seed, split into independent streams for the input locations, the GP
# draw, and the train/test split.
key_t, key_y, key_split = jax.random.split(jax.random.PRNGKey(100), 3)

N = 120  # number of simulated observations
yerr = 0.25  # observational noise standard deviation (also read by `model`)

# Uniformly scattered inputs on [0, 10], sorted so that index order is
# input order.
t = jax.random.uniform(key_t, (N,), minval=0.0, maxval=10.0)
t = jnp.sort(t)

# True generating process: a squared-exponential GP with amplitude 2.5 and
# length scale 0.5, plus white noise of variance yerr**2 on the diagonal.
kernel = 2.5**2 * kernels.ExpSquared(0.5)
gp = GaussianProcess(kernel, t, diag=yerr**2)
y = gp.sample(key_y)

# Train on a random ~70% of the points, but always hold out the first and
# last 10 (i.e. both ends of the sorted input range).
mask_train = (
    (jax.random.uniform(key_split, (N,)) < 0.7)
    .at[:10].set(False)
    .at[-10:].set(False)
)

plt.plot(t[mask_train], y[mask_train], ".k", label="training data")
plt.plot(t[~mask_train], y[~mask_train], "+C0", label="held out data")
plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("y");

Next, we set up a probabilistic model that we can sample with MCMC for a range of different kernels, which may have different numbers of parameters.

The training likelihood for this model is the usual one:

$$
y_\mathrm{train} \sim p(y_\mathrm{train} | \theta, \phi) = \mathcal{N}(m_\theta, K_\phi)
$$

Then, the validation likelihood is:

$$
y_\mathrm{test} \sim p(y_\mathrm{test} | y_\mathrm{train}, \theta, \phi) = \mathcal{N}(m_\star + K_\star^\mathrm{T} K^{-1}(y_\mathrm{train} - m_\theta), K_{\star\star} - K_\star^\mathrm{T} K^{-1} K_\star)
$$

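where $m_\star$, $K_\star$, and $K_{\star\star}$ are the mean and covariance terms of the usual GP predictive distribution at the test inputs.
Importantly, $K_{\star\star}$ should include the observational uncertainty for the test data on the diagonal (below, this is the `diag=yerr**2` passed to `gp.condition`).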

In [None]:
def model(kernel_builder, t_train, y_train, t_test=None, y_test=None, t_pred=None):
    """NumPyro model for a GP (tinygp's default mean) with sampled kernel hyperparameters.

    Parameters
    ----------
    kernel_builder : callable
        Called with no arguments inside the model; it registers the kernel's
        hyperparameter sample sites and returns a tinygp kernel.
    t_train, y_train : array
        Training inputs and observations.
    t_test, y_test : array, optional
        Held-out inputs and observations. When ``t_test`` is given, the
        training data enter the posterior through ``numpyro.factor`` and the
        predictive log probability of ``y_test`` is recorded in the
        deterministic site ``"log_prob_test"``. Otherwise ``y_train`` is an
        ordinary observed sample site.
    t_pred : array, optional
        Inputs at which to record the predictive mean (``"mean_pred"``) and
        standard deviation (``"std_pred"``), conditioned on the training data.

    Notes
    -----
    The noise variance ``yerr**2`` is read from the notebook-level ``yerr``.
    It is added to the diagonal of every covariance, including the test and
    prediction points, so ``"std_pred"`` includes observational noise.
    """
    noise_var = yerr**2
    train_gp = GaussianProcess(kernel_builder(), t_train, diag=noise_var)

    if t_test is not None:
        # Cross-validation mode. One conditioning call yields both the
        # marginal log likelihood of y_train and the predictive GP at t_test.
        train_log_like, test_gp = train_gp.condition(y_train, t_test, diag=noise_var)
        held_out_log_prob = test_gp.log_probability(y_test)
        numpyro.factor("log_prob_train", train_log_like)
        numpyro.deterministic("log_prob_test", held_out_log_prob)
    else:
        numpyro.sample("y_train", train_gp.numpyro_dist(), obs=y_train)

    if t_pred is None:
        return
    pred_gp = train_gp.condition(y_train, t_pred, diag=noise_var).gp
    numpyro.deterministic("mean_pred", pred_gp.loc)
    numpyro.deterministic("std_pred", jnp.sqrt(pred_gp.variance))


def build_exp_sq():
    """Sample hyperparameters and return a squared-exponential kernel.

    Registers ``sigma ~ HalfNormal(10)`` (amplitude) and ``rho ~ HalfNormal(5)``
    (length scale) as sample sites, so it is meant to be called inside a
    NumPyro model.
    """
    amplitude = numpyro.sample("sigma", dist.HalfNormal(10.0))
    length_scale = numpyro.sample("rho", dist.HalfNormal(5.0))
    variance = amplitude**2
    return variance * kernels.ExpSquared(length_scale)


def build_matern():
    """Sample hyperparameters and return a Matern-3/2 kernel.

    Registers ``sigma ~ HalfNormal(10)`` (amplitude) and ``rho ~ HalfNormal(5)``
    (length scale) as sample sites, so it is meant to be called inside a
    NumPyro model.
    """
    amplitude = numpyro.sample("sigma", dist.HalfNormal(10.0))
    length_scale = numpyro.sample("rho", dist.HalfNormal(5.0))
    variance = amplitude**2
    return variance * kernels.Matern32(length_scale)


def build_raquad():
    """Sample hyperparameters and return a rational-quadratic kernel.

    Registers ``sigma ~ HalfNormal(10)`` (amplitude), ``rho ~ HalfNormal(5)``
    (length scale) and ``alpha ~ HalfNormal(10)`` (the rational-quadratic
    ``alpha`` parameter) as sample sites, in that order. It is therefore
    meant to be called inside a NumPyro model.
    """
    amplitude = numpyro.sample("sigma", dist.HalfNormal(10.0))
    length_scale = numpyro.sample("rho", dist.HalfNormal(5.0))
    shape = numpyro.sample("alpha", dist.HalfNormal(10.0))
    variance = amplitude**2
    return variance * kernels.RationalQuadratic(length_scale, alpha=shape)


def _make_sampler(kernel_builder):
    """Return an MCMC object (NUTS) for ``model`` with the given kernel builder.

    Every kernel model is built through this helper, so all of them share
    exactly the same sampler settings: dense mass matrix, target acceptance
    0.9, 1000 warmup + 2000 samples per chain, and 2 chains run sequentially.
    """
    return infer.MCMC(
        infer.NUTS(
            partial(model, kernel_builder), dense_mass=True, target_accept_prob=0.9
        ),
        num_warmup=1000,
        num_samples=2000,
        num_chains=2,
        progress_bar=True,
        chain_method="sequential",
    )


sampler_exp_sq = _make_sampler(build_exp_sq)
sampler_matern = _make_sampler(build_matern)
sampler_raquad = _make_sampler(build_raquad)

In [None]:
# Dense grid for plotting the predictive distribution. It extends one unit
# beyond the simulated input range [0, 10] on each side.
t_pred = jnp.linspace(-1, 11, 500)
# Point estimate of the Matern-3/2 model fitted to the training data only.
# NOTE(review): `optimize` maximizes the model's log density, which presumably
# includes the HalfNormal priors. That would make this a MAP estimate rather
# than strict maximum likelihood (the final figure labels it "max. like.").
# Confirm against the numpyro_ext docs.
mod = partial(model, build_matern)
soln = optimize(mod)(jax.random.PRNGKey(0), t[mask_train], y[mask_train], t_pred=t_pred)
# Evaluate the model at the optimized parameters. samp["mean_pred"] and
# samp["std_pred"] are what the final figure plots.
samp = infer.Predictive(mod, soln)(
    jax.random.PRNGKey(0), t[mask_train], y[mask_train], t_pred=t_pred
)

In [None]:
# Sample the squared-exponential model on the training data, recording the
# predictive log probability of the held-out points at every step.
sampler_exp_sq.run(
    jax.random.PRNGKey(0),
    t_train=t[mask_train],
    y_train=y[mask_train],
    t_test=t[~mask_train],
    y_test=y[~mask_train],
)
inf_exp_sq = az.from_numpyro(sampler_exp_sq)
corner.corner(inf_exp_sq)  # pairwise posterior plot
az.summary(inf_exp_sq)  # last expression: shown as the cell output

In [None]:
# Sample the Matern-3/2 model on the training data, recording the predictive
# log probability of the held-out points at every step.
sampler_matern.run(
    jax.random.PRNGKey(0),
    t_train=t[mask_train],
    y_train=y[mask_train],
    t_test=t[~mask_train],
    y_test=y[~mask_train],
)
inf_matern = az.from_numpyro(sampler_matern)
corner.corner(inf_matern)  # pairwise posterior plot
az.summary(inf_matern)  # last expression: shown as the cell output

In [None]:
# Sample the rational-quadratic model on the training data, recording the
# predictive log probability of the held-out points at every step.
sampler_raquad.run(
    jax.random.PRNGKey(0),
    t_train=t[mask_train],
    y_train=y[mask_train],
    t_test=t[~mask_train],
    y_test=y[~mask_train],
)
inf_raquad = az.from_numpyro(sampler_raquad)
corner.corner(inf_raquad)  # pairwise posterior plot
az.summary(inf_raquad)  # last expression: shown as the cell output

In [None]:
# Overlay, for each kernel, the posterior distribution of the log probability
# of the held-out data (the deterministic "log_prob_test" site).
plt.figure(figsize=(5, 5))
for idata, hist_label, hist_style in (
    (inf_exp_sq, "squared exp.", {}),
    (inf_raquad, "rational quad.", {"linestyle": "dashed"}),
    (inf_matern, "Matérn-3/2", {"linestyle": "dotted"}),
):
    plt.hist(
        idata.posterior["log_prob_test"].values.flatten(),
        50,
        histtype="step",
        label=hist_label,
        density=True,
        **hist_style,
        linewidth=2,
    )
plt.legend(loc="upper left")
plt.yticks([])
plt.ylabel("posterior density")
plt.xlabel("log probability of held out data")

In [None]:
from numpyro.contrib.nested_sampling import NestedSampler

# Bayesian evidence of the squared-exponential model. It is computed from ALL
# the data (training + held-out); no test split is involved here.
ns_exp_sq = NestedSampler(partial(model, build_exp_sq))
ns_exp_sq.run(jax.random.PRNGKey(0), t_train=t, y_train=y)
# NOTE: `_results` is a private NestedSampler attribute and may change
# between numpyro versions.
ns_exp_sq._results.log_Z_mean, ns_exp_sq._results.log_Z_uncert

In [None]:
# Bayesian evidence of the Matern-3/2 model from ALL the data.
# NOTE: `_results` is a private NestedSampler attribute.
ns_matern = NestedSampler(partial(model, build_matern))
ns_matern.run(jax.random.PRNGKey(0), t_train=t, y_train=y)
ns_matern._results.log_Z_mean, ns_matern._results.log_Z_uncert

In [None]:
# Bayesian evidence of the rational-quadratic model from ALL the data.
# NOTE: `_results` is a private NestedSampler attribute.
ns_raquad = NestedSampler(partial(model, build_raquad))
ns_raquad.run(jax.random.PRNGKey(0), t_train=t, y_train=y)
ns_raquad._results.log_Z_mean, ns_raquad._results.log_Z_uncert

In [None]:
# Plot the nested-sampling log evidence, log p(y | kernel), for each kernel.
# A higher value means the data are more probable under that kernel model.
# NOTE: `_results` is a private NestedSampler attribute and may change
# between numpyro versions.
nested_runs = [ns_exp_sq, ns_raquad, ns_matern]
log_Z = [ns._results.log_Z_mean for ns in nested_runs]
log_Z_uncert = [ns._results.log_Z_uncert for ns in nested_runs]
names = ["squared exp.", "rational quad.", "Matérn-3/2"]

plt.figure(figsize=(2, 5))
plt.errorbar(names, log_Z, yerr=log_Z_uncert, fmt="o")
# Pad the categorical x axis (the three categories sit at x = 0, 1, 2).
plt.xlim(-1, 3)
# Rotate category labels with tick_params rather than a side-effect list
# comprehension over get_xticklabels(); tick_params also applies to any
# ticks matplotlib creates later.
plt.gca().tick_params(axis="x", labelrotation=45)
plt.ylabel("log(evidence)")

In [None]:
fig, axes = plt.subplot_mosaic(
    [["A", "B", "C"]],
    gridspec_kw={"width_ratios": [2.5, 1, 2], "wspace": 0.1},
    constrained_layout=True,
    figsize=(12, 6),
)

# Kernel labels, in the order shared by the evidence and cross-validation panels.
names = ["squared exp.", "rational quad.", "Matérn-3/2"]

# Panel A -- data: training and held-out points, plus the Matern-3/2 point
# estimate from `samp`. The band is mean +/- one std, and the std includes
# observational noise because `model` conditions t_pred with diag=yerr**2.
ax = axes["A"]
ax.plot(t[mask_train], y[mask_train], ".k", label="training")
ax.plot(t[~mask_train], y[~mask_train], "+C0", label="held out")
# NOTE(review): `soln` comes from numpyro_ext's `optimize`, which presumably
# includes the priors (MAP). Confirm the "max. like." label is intended.
ax.plot(t_pred, samp["mean_pred"], "C1", label="max. like.", lw=0.5)
ax.fill_between(
    t_pred,
    samp["mean_pred"] - samp["std_pred"],
    samp["mean_pred"] + samp["std_pred"],
    color="C1",
    alpha=0.2,
)
ax.legend(loc="upper right")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_xlim(t_pred.min(), t_pred.max())
ax.set_title("data")

# Panel B -- Bayesian evidence from nested sampling on the full data set.
# NOTE: `_results` is a private NestedSampler attribute.
ax = axes["B"]
nested_runs = [ns_exp_sq, ns_raquad, ns_matern]
log_Z = [ns._results.log_Z_mean for ns in nested_runs]
log_Z_uncert = [ns._results.log_Z_uncert for ns in nested_runs]
ax.errorbar(names, log_Z, yerr=log_Z_uncert, fmt="ok")
ax.set_xlim(-1, 3)
# Rotate category labels via tick_params instead of a side-effect list
# comprehension; tick_params also applies to ticks created later.
ax.tick_params(axis="x", labelrotation=45)
ax.set_ylabel("log(evidence)")
ax.set_title("Bayesian evidence")

# Panel C -- cross validation: posterior distribution of the held-out log
# probability ("log_prob_test") for each kernel.
ax = axes["C"]
hist_styles = [{}, {"linestyle": "dashed"}, {"linestyle": "dotted"}]
for idata, label, style in zip(
    [inf_exp_sq, inf_raquad, inf_matern], names, hist_styles
):
    ax.hist(
        idata.posterior["log_prob_test"].values.flatten(),
        50,
        histtype="step",
        label=label,
        density=True,
        **style,
        linewidth=2,
    )
ax.legend(loc="upper left", fontsize=12)
ax.set_yticks([])
ax.set_ylabel("posterior density")
ax.set_xlabel("log probability of held out data")
ax.set_title("cross validation")

plt.savefig(figures / "assessment.pdf", bbox_inches="tight")