In [16]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.numpy.linalg import slogdet
from jax.scipy.stats import multivariate_normal as mvtn
import matplotlib.pyplot as plt
import numpy as np
from t import t_logpdf
from tqdm import tqdm

In [17]:
def rbf_kernel(X1: jnp.ndarray,
                X2: jnp.ndarray,
                lengthscale: float) -> jnp.ndarray:
    sqdist = jnp.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sqdist / (lengthscale ** 2))

In [18]:
# GP model

def log_pdf(y, X, lengthscale, amplitude, beta):
    return mvtn.logpdf(y, X @ beta, amplitude**2 * rbf_kernel(X, X, lengthscale))

def sample(key, X, lengthscale, amplitude, beta):
    return jax.random.multivariate_normal(key, X @ beta, amplitude**2 * rbf_kernel(X, X, lengthscale))

In [19]:
# we are working with a fixed lengthscale
lengthscale = 1.0

In [20]:
from jax.numpy.linalg import inv

# dof=n-p for right-invariant prior and dof=n for Jeffreys prior
def log_pred(yp, Xp, yo, Xo, dof):
    n_obs = len(yo)
    Xo = jnp.vstack((Xo, Xp))
    K = rbf_kernel(Xo, Xo, lengthscale)
    K_inv = inv(K)
    A = K_inv - K_inv @ Xo @ inv(Xo.T @ K_inv @ Xo) @ Xo.T @ K_inv
    Aoo, Aop, Apo, App = A[:n_obs, :n_obs], A[:n_obs, n_obs:], A[n_obs:, :n_obs], A[n_obs:, n_obs:]
    App_inv = inv(App)
    Sigma = (yo.T @ (Aoo - Aop @ App_inv @ Apo) @ yo / dof) * App_inv
    mu = - App_inv @ Apo @ yo
    return t_logpdf(yp, mu, Sigma, dof)

# dof=n-p for unbiased and dof=n for MLE
def log_pred_plug_in(yp, Xp, yo, Xo, dof):
    K = rbf_kernel(Xo, Xo, lengthscale)
    K_inv = inv(K)
    beta_hat = inv(Xo.T @ K_inv @ Xo) @ Xo.T @ K_inv @ yo
    amp_hat = jnp.sqrt((yo - Xo @ beta_hat).T @ K_inv @ (yo - Xo @ beta_hat) / dof)
    return log_pdf(yp, Xp, lengthscale, amp_hat, beta_hat)

In [None]:
# numerically evaluate predictive procedures against knowing true parameters
n, p = 3, 2
Xot = jax.random.uniform(jax.random.PRNGKey(0), (n, p))
Xpt = jax.random.uniform(jax.random.PRNGKey(0), (1, p))
# a, b = jnp.array(2, float), jax.random.uniform(jax.random.PRNGKey(42), (p,))
a, b = jnp.array(2, float), jnp.array([7, -15], float)

@jax.jit
def mc_estimate_risk(key, Xo, Xp, amplitude, beta):
    X = jnp.vstack((Xo, Xp))
    samples = sample(key, X, lengthscale, amplitude, beta)
    yo, yp = samples[:n], samples[n:]
    # true conditional likelihood
    true_pdf = log_pdf(samples, X, lengthscale, amplitude, beta) - log_pdf(yo, Xo, lengthscale, amplitude, beta)
    return jnp.array([
        true_pdf - log_pred(yp, Xp, yo, Xo, dof=n-p),
        true_pdf - log_pred(yp, Xp, yo, Xo, dof=n),
        true_pdf - log_pred_plug_in(yp, Xp, yo, Xo, dof=n-p),
        true_pdf - log_pred_plug_in(yp, Xp, yo, Xo, dof=n)])

samples = 2 ** 16
iters = 2 ** 12
key = jax.random.PRNGKey(0)
scores, nans = jnp.zeros((4,)), jnp.zeros((4,))
for i in tqdm(range(iters)):
    key_now, key = jax.random.split(key, 2)
    keys = jax.random.split(key_now, samples)
    results = jax.vmap(mc_estimate_risk, (0, None, None, None, None))(keys, Xot, Xpt, a, b)
    scores += jnp.nansum(results, axis=0)
    nans += jnp.sum(jnp.isnan(results), axis=0)
print(scores / (iters * samples - nans), nans)

100%|██████████| 4096/4096 [00:25<00:00, 163.43it/s]

[5.90440357e-01 1.28367098e+00 1.20501732e+05 3.61501651e+05] [36. 36.  0.  0.]



