
# Example: Gaussian Process

In this example we show how to use NUTS to sample from the posterior
over the hyperparameters of a gaussian process.

<img src="file://../_static/img/examples/gp.png" align="center">


In [1]:
import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import jax
from jax import vmap
import jax.numpy as jnp
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
    MCMC,
    NUTS,
    init_to_feasible,
    init_to_median,
    init_to_sample,
    init_to_uniform,
    init_to_value,
)
from scipy.spatial.distance import squareform, cdist


# matplotlib.use("Agg")  # noqa: E402



In [110]:

# squared exponential kernel with diagonal noise term
def kernel(X,Xs, var, length, noise, jitter=1.0e-6, include_noise=True):
    delta_sq =  jnp.power(cdist(X, Xs, metric="euclidean") / length, 2.0)
    k = var * jnp.exp(-0.5 * delta_sq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(delta_sq.shape[0])
    return k


def model(Xgp,Xlin, Y):
    # set uninformative log-normal priors on our three kernel hyperparameters
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))

    # compute kernel
    k = kernel(Xgp,Xgp, var, length, noise)
    
    # Linear mean
    beta = numpyro.sample("beta",dist.Normal(0,5))
    mu = Xlin*beta
    
    # sample Y according to the standard gaussian process formula
    numpyro.sample(
        "Y",
        # dist.MultivariateNormal(loc=jnp.zeros(Y.shape[0]), covariance_matrix=k),
        dist.MultivariateNormal(loc=mu, covariance_matrix=k),
        obs=Y,
    )


# helper function for doing hmc inference
def run_inference(model, rng_key, Xgp,Xlin, Y):
    start = time.time()
    # demonstrate how to use different HMC initialization strategies
    # if args.init_strategy == "value":
    #     init_strategy = init_to_value(
    #         values={"kernel_var": 1.0, "kernel_noise": 0.05, "kernel_length": 0.5}
    #     )
    # elif args.init_strategy == "median":
    #     init_strategy = init_to_median(num_samples=10)
    # elif args.init_strategy == "feasible":
    #     init_strategy = init_to_feasible()
    # elif args.init_strategy == "sample":
    #     init_strategy = init_to_sample()
    # elif args.init_strategy == "uniform":
    #     init_strategy = init_to_uniform(radius=1)
    # kernel = NUTS(model, init_strategy=init_strategy)
    kernel = NUTS(model, init_strategy=init_to_median(num_samples=10))
    mcmc = MCMC(
        kernel,
        num_warmup=1000,
        num_samples=2000,
        num_chains=2,
        # thinning=2,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(rng_key, Xgp, Xlin, Y)
    mcmc.print_summary()
    print("\nMCMC elapsed time:", time.time() - start)
    return mcmc.get_samples()


# do GP prediction for a given set of hyperparameters. this makes use of the well-known
# formula for Gaussian process predictions
def predict(rng_key, Xgp, Y, Xgp_test, var, length, noise, use_cholesky=True):
    # compute kernels between train and test data, etc.
    k_pp = kernel(Xgp_test, Xgp_test, var, length, noise, include_noise=True) # mXm
    k_pX = kernel(Xgp_test, Xgp, var, length, noise, include_noise=False) # mXn
    k_XX = kernel(Xgp, Xgp, var, length, noise, include_noise=True) #nXn

    # since K_xx is symmetric positive-definite, we can use the more efficient and
    # stable Cholesky decomposition instead of matrix inversion
    if use_cholesky:
        K_xx_cho = jax.scipy.linalg.cho_factor(k_XX)
        K = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T))
        mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, Y))
    else:
        K_xx_inv = jnp.linalg.inv(k_XX) #nXn
        K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) #mXm 
        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y)) 
        # mean += np.dot(beta[:,np.newaxis],np.transpose(Xlin_test))
        
    # mean += np.dot(beta[:,np.newaxis], np.transpose(Xlin_test[:,np.newaxis]))
    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
        rng_key, X_test.shape[:1]
    )

    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise


# create artificial regression dataset
def get_data(N=150, sigma_obs=0.25):
    np.random.seed(0)
    # X = jnp.linspace(-1, 1, N)
    X = np.random.uniform(size=N, low=0, high=1)
    # X = np.sort(X,axis=0)
    X2 = np.random.uniform(size=N,low=-1,high=0)
    X3 = np.random.binomial(n=1,p=0.5,size=N)
    # X2 = np.random.normal(size=N,loc=1,scale=1)
    # Y = X + 0.2 * jnp.power(X, 3.0) + 0.5 * jnp.power(0.5 + X, 2.0) * jnp.sin(4.0 * X)
    Y = 2*X3 + 1.5*X + X2 + 2*X*X2 + 4*np.exp(X2*X)
    # Y = X + 2*X2 + 0.2 * jnp.power(X, 3.0) + 0.5 * jnp.power(0.5 + X, 2.0)
    # Y += sigma_obs * np.random.randn(N)
    Y += sigma_obs * np.random.standard_normal(N)
    # Y -= jnp.mean(Y)
    # Y /= jnp.std(Y)

    # assert X.shape == (N,)
    # assert Y.shape == (N,)

    # X_test = np.random.uniform(size=N, low=-0.2, high=1.2)
    # X_test = np.sort(X_test, axis=0)
    # Y_test = X_test + 0.2 * jnp.power(X_test, 3.0) + 0.5 * jnp.power(0.5 + X_test, 2.0) * jnp.sin(4.0 * X_test)
    # Y_test += sigma_obs * np.random.randn(N)
    X_t = np.random.uniform(size=N, low=-0.25, high=1.25)
    X2_t = np.random.uniform(size=N,low=-1.5,high=0)
    X3_t = np.random.binomial(n=1,p=0.7,size=N)
    # X2_t = np.random.normal(size=N,loc=1,scale=1)
    # Y = X + 0.2 * jnp.power(X, 3.0) + 0.5 * jnp.power(0.5 + X, 2.0) * jnp.sin(4.0 * X)
    # Y_t = X_t + 2*X2_t + 0.2 * np.power(X_t, 3.0) + 0.5 * np.power(0.5 + X_t, 2.0)
    # Y_t = X2_t + np.power(X2,2) + 0.5*np.power(0.5 + X_t, 2.0)*jnp.sin(4*X_t)
    Y_t = 2*X3_t + 1.5*X_t + X2_t + 2*X_t*X2_t + 0.5*np.exp(2*X_t)
    # Y_t = X_t + 2*X2_t + 0.2 * jnp.power(X_t, 3.0) + 0.5 * jnp.power(0.5 + X_t, 2.0)
    # Y_t += sigma_obs * np.random.randn(N)
    Y_t += sigma_obs * np.random.standard_normal(N)
    # Y -= jnp.mean(Y)
    # Y /= jnp.std(Y)

    # return np.transpose(np.array([X,X2])), Y, np.transpose(np.array([X_t,X2_t])), Y_t
    return X,X2,X3,Y, X_t,X2_t,X3_t,Y_t



In [111]:
# DATA

# X, Y, X_test = get_data(N=args.num_data)
X, X2, X3, Y, X_test, X2_test,X3_test, Y_test = get_data()

Xgp = np.transpose(np.array([X,X2]))
Xgp_test = np.transpose(np.array([X_test,X2_test]))

print(Xgp.shape,X3.shape, Y.shape)
print(Xgp_test.shape,X3_test.shape, Y_test.shape)

# Xarr = np.transpose(np.array([X,X2]))
# Xarr_tst = np.transpose(np.array([X_test,X2_test]))




(150, 2) (150,) (150,)
(150, 2) (150,) (150,)


In [112]:
# do inference
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
samples = run_inference(model, rng_key, Xgp, X3, Y)



  mcmc = MCMC(
sample: 100%|██████████| 3000/3000 [00:26<00:00, 111.37it/s, 15 steps of size 3.87e-01. acc. prob=0.94]
sample: 100%|██████████| 3000/3000 [00:24<00:00, 121.67it/s, 7 steps of size 3.49e-01. acc. prob=0.93] 


                     mean       std    median      5.0%     95.0%     n_eff     r_hat
           beta      2.04      0.04      2.04      1.97      2.11   2694.75      1.00
  kernel_length      1.36      0.37      1.30      0.81      1.88   1440.96      1.00
   kernel_noise      0.06      0.01      0.06      0.05      0.07   2562.47      1.00
     kernel_var     61.08    189.48     25.19      2.64    114.88   1262.07      1.01

Number of divergences: 0

MCMC elapsed time: 51.98292517662048





In [79]:
# print(samples['beta'][:,np.newaxis].shape)
# print(np.transpose(X2_test[:,np.newaxis]).shape)
# print(np.dot(samples['beta'][:,np.newaxis],np.transpose(X2_test[:,np.newaxis])).shape)

KeyError: 'beta'

In [113]:
# do prediction
vmap_args = (
    random.split(rng_key_predict, samples["kernel_var"].shape[0]),
    samples["kernel_var"],
    samples["kernel_length"],
    samples["kernel_noise"],
    # samples["beta"],
)
means, predictions = vmap(
    lambda rng_key, var, length, noise: predict(
        rng_key, Xgp, Y, Xgp_test, var, length, noise, use_cholesky=True
    )
)(*vmap_args)



In [114]:
lin_pred = np.dot(samples['beta'][:,np.newaxis],np.transpose(X3_test[:,np.newaxis]))

means += lin_pred
predictions += lin_pred


In [115]:
mean_prediction = np.mean(means, axis=0)
percentiles = np.percentile(predictions, [2.5, 97.5], axis=0)
mean_percentiles = np.percentile(means, [2.5, 97.5], axis=0)

print(mean_prediction.shape, '\n',
      means.shape, 
    'mean:', np.mean(mean_prediction), '\n',
      'mean ytest', np.mean(Y_test), '\n',
      'mean lin_pred', np.mean(lin_pred), '\n',
      'perc mean:', np.percentile(means, [2.5, 97.5]), '\n',
      'perc:', np.percentile(predictions, [2.5, 97.5]), '\n',
      'rmse:', np.mean(np.power(mean_prediction - Y_test,2))) 


(150,) 
 (4000, 150) mean: 4.50788 
 mean ytest 2.3724102412687276 
 mean lin_pred 1.3584474298556646 
 perc mean: [-0.68139954  7.92929771] 
 perc: [-0.71212576  7.96811142] 
 rmse: 12.220756


In [118]:
matplotlib.use('Qt5Agg')

# make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

# plot training data
# ax.scatter(X_test.flatten(), Y_test, c=X2_test, cmap = "prism")
ax.scatter(X_test.flatten(), Y_test-np.mean(lin_pred,axis=0))
# plot 90% confidence level of predictions
# ax.fill_between(X_test.flatten(), percentiles[0,:], percentiles[1,:], color="lightblue")
# ax.fill_between(X_test.flatten(), mean_percentiles[0, :], mean_percentiles[1, :], color="lightblue")
# plot mean prediction
# ax.scatter(X_test.flatten(), mean_prediction, c="red")
ax.scatter(X_test.flatten(), mean_prediction-np.mean(lin_pred,axis=0), c="red")
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 95% CI")
plt.show()

# plt.savefig("gp_plot.pdf")
# plt.savefig("gp_plot2.pdf")

In [119]:
# make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

# plot training data
# ax.scatter(X_test.flatten(), Y_test, c=X2_test, cmap = "prism")
ax.scatter(X2_test.flatten(), Y_test)
# plot 90% confidence level of predictions
# ax.fill_between(X_test.flatten(), percentiles[0,:], percentiles[1,:], color="lightblue")
# ax.fill_between(X_test.flatten(), mean_percentiles[0, :], mean_percentiles[1, :], color="lightblue")
# plot mean prediction
# ax.scatter(X_test.flatten(), mean_prediction, c="red")
ax.scatter(X2_test.flatten(), mean_prediction, c="red")
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 95% CI")
plt.show()


In [109]:
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

# plot training data
# ax.scatter(X_test.flatten(), Y_test, c=X2_test, cmap = "prism")
ax.scatter(X2_test.flatten(), Y_test)
# plot 90% confidence level of predictions
# ax.fill_between(X_test.flatten(), percentiles[0,:], percentiles[1,:], color="lightblue")
# ax.fill_between(X_test.flatten(), mean_percentiles[0, :], mean_percentiles[1, :], color="lightblue")
# plot mean prediction
# ax.scatter(X_test.flatten(), mean_prediction, c="red")
ax.scatter(X2_test.flatten(), mean_prediction, c="red")
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 95% CI")
plt.show()