In [None]:
import warnings

import pandas as pd

warnings.simplefilter("ignore", FutureWarning)
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import jit, random, vmap
from numpyro.diagnostics import hpdi
from numpyro.infer import (
    MCMC,
    NUTS,
    Predictive,
    init_to_feasible,
    init_to_median,
    init_to_sample,
    init_to_uniform,
    init_to_value,
)



In [None]:
def predict_stp(X, Y, X_test, var, length, noise, nu, kernel_fn=None):
    """Condition a zero-mean Student-t process on training data.

    Builds the train/train, test/test and test/train kernel matrices and
    returns the parameters of the conditional distribution over the test
    points.

    Args:
        X: training inputs; must be 2-D (it is unpacked as ``n1, _ = X.shape``).
        Y: training targets; a length-``n1`` zero mean vector is subtracted
            from it, so it is expected to be a 1-D array of length ``n1``.
        X_test: test inputs; must be 2-D with ``n2`` rows.
        var, length, noise: forwarded unchanged to the kernel function.
        nu: degrees of freedom of the prior process.
        kernel_fn: optional kernel callable taking the keyword arguments
            ``X, Z, include_noise, var, length, noise``. When omitted, the
            module-level name ``kernel`` is looked up at call time.

    Returns:
        A tuple ``(df, mu, K)`` where ``df = nu + n1``, ``mu`` is the
        conditional location at the test points (length ``n2``) and ``K`` is
        the ``(n2, n2)`` conditional kernel matrix rescaled by
        ``(nu + beta_1 - 2) / (nu + n1 - 2)``.
    """
    if kernel_fn is None:
        # Preserve the original behaviour: fall back to the global kernel.
        kernel_fn = kernel

    n1, _ = X.shape
    n2, _ = X_test.shape

    psi_1 = jnp.zeros(n1)  # assumption of zero mean function
    psi_2 = jnp.zeros(n2)  # assumption of zero mean function

    K_11 = kernel_fn(
        X=X, Z=X.T, include_noise=True, var=var, length=length, noise=noise
    )
    assert K_11.shape == (n1, n1)
    K_22 = kernel_fn(
        X=X_test, Z=X_test.T, include_noise=True, var=var, length=length, noise=noise
    )
    assert K_22.shape == (n2, n2)
    K_21 = kernel_fn(
        X=X_test, Z=X.T, include_noise=False, var=var, length=length, noise=noise
    )
    assert K_21.shape == (n2, n1)
    K_12 = K_21.T

    # Solve against the actual right-hand sides instead of forming K_11^{-1}
    # explicitly; this is cheaper and better conditioned for small `noise`.
    resid = Y - psi_1
    alpha = jnp.linalg.solve(K_11, resid)  # K_11^{-1} (Y - psi_1)
    K_11_inv_K_12 = jnp.linalg.solve(K_11, K_12)  # K_11^{-1} K_12

    psi_2_tilde = K_21 @ alpha + psi_2
    beta_1 = resid @ alpha  # quadratic form (Y - psi_1)^T K_11^{-1} (Y - psi_1)
    K_22_tilde = K_22 - K_21 @ K_11_inv_K_12

    df = nu + n1
    mu = psi_2_tilde
    K = K_22_tilde * (nu + beta_1 - 2) / (nu + n1 - 2)
    return df, mu, K




def predict_gaussian(X, Y, X_test, var, length, noise, kernel_fn=None):
    """Compute the Gaussian-process predictive mean and covariance.

    Args:
        X: training inputs; passed to the kernel as ``X`` and, transposed,
            as ``Z``.
        Y: training targets, one per row of ``X``.
        X_test: test inputs, used the same way as ``X``.
        var, length, noise: forwarded unchanged to the kernel function.
        kernel_fn: optional kernel callable taking the keyword arguments
            ``X, Z, include_noise, var, length, noise``. When omitted, the
            module-level name ``kernel`` is looked up at call time.

    Returns:
        A tuple ``(mean, K)``: the predictive mean at the test points and the
        predictive covariance ``k_pp - k_pX k_XX^{-1} k_pX^T``. ``k_pp`` is
        built with ``include_noise=True``.
    """
    if kernel_fn is None:
        # Preserve the original behaviour: fall back to the global kernel.
        kernel_fn = kernel

    # Kernel blocks: test/test, test/train and train/train.
    k_pp = kernel_fn(
        X=X_test, Z=X_test.T, include_noise=True, var=var, length=length, noise=noise
    )
    k_pX = kernel_fn(
        X=X_test, Z=X.T, include_noise=False, var=var, length=length, noise=noise
    )
    k_XX = kernel_fn(
        X=X, Z=X.T, include_noise=True, var=var, length=length, noise=noise
    )

    # Solve directly for the two products that are needed rather than first
    # materialising k_XX^{-1} (via solve against the identity) and multiplying.
    K = k_pp - jnp.matmul(k_pX, jnp.linalg.solve(k_XX, jnp.transpose(k_pX)))
    mean = jnp.matmul(k_pX, jnp.linalg.solve(k_XX, Y))
    return mean, K


In [None]:
def rbf_kernel(
    X, Z, length, var, noise, jitter=1.0e-6, include_noise=True, *args, **kwargs
):
    """Squared-exponential (RBF) kernel evaluated by broadcasting ``X - Z``.

    The pairwise differences come from plain broadcasting, so the callers in
    this notebook pass ``X`` with shape ``(n, 1)`` and ``Z`` with shape
    ``(1, m)`` (a transposed column) to obtain an ``(n, m)`` matrix.

    Args:
        X: first set of inputs.
        Z: second set of inputs, laid out so that ``X - Z`` broadcasts to the
            desired kernel-matrix shape.
        length: length scale dividing the differences.
        var: multiplicative variance (amplitude) of the kernel.
        noise: observation-noise variance added to the diagonal when
            ``include_noise`` is true.
        jitter: small constant added to the diagonal together with ``noise``.
        include_noise: whether to add ``(noise + jitter) * I``; the identity
            has ``X.shape[0]`` rows, so the kernel matrix must be square.
        *args, **kwargs: accepted and ignored.

    Returns:
        The kernel matrix ``var * exp(-0.5 * ((X - Z) / length) ** 2)``,
        plus the diagonal term when requested.
    """
    deltaXsq = jnp.power((X - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k = k + (noise + jitter) * jnp.eye(X.shape[0])

    return k


# The predict_* helpers call a module-level `kernel`; bind it to the RBF kernel
# so the notebook runs top-to-bottom on a fresh interpreter.
kernel = rbf_kernel



In [None]:
# Toy regression problem: four noise-free observations of sin(2x) on [-3, 3].
np.random.seed(12)  # fixed legacy global seed so the four draws are reproducible
x = np.random.uniform(low=-3, high=3, size=4)
y = np.sin(2 * x)

# 2-D (n, 1) arrays: the predict_* helpers unpack `.shape` into two dimensions
# and pass the transpose as the kernel's second argument.
X = x.reshape(-1, 1)
X_test = np.linspace(start=-3, stop=3, num=100).reshape(-1, 1)

# Quick look at the training points.
plt.plot(x, y, ".")
plt.show()


def sample_(x, n, seed):
    """Pick ``n`` entries of ``x`` along its first axis at random.

    Indices are drawn independently with ``jax.random.randint`` from
    ``[0, x.shape[0])`` using a PRNG key built from ``seed``, so the same
    entry may be selected more than once and the result is reproducible for
    a given seed.
    """
    key = random.PRNGKey(seed)
    n_rows = x.shape[0]
    picks = random.randint(key, shape=(n,), minval=0, maxval=n_rows)
    return x[picks]



In [None]:
# Hyperparameters for the GP / STP comparison.
var = 1
length_scale = 0.5
noise_gp = 1e-4
noise_stp = 1e-5
nu = 2

# Gaussian-process predictive distribution, samples and 90% HPDI band.
mean_gp, K_gp = predict_gaussian(
    X, y, X_test, var=var, length=length_scale, noise=noise_gp
)
gp = dist.MultivariateNormal(loc=mean_gp, covariance_matrix=K_gp)
# The import cell never binds the name `jax`; use the `random` module that is
# imported via `from jax import random` instead of `jax.random`.
samples_gp = gp.sample(random.PRNGKey(132), (1_000,))
y_hpdi = hpdi(samples_gp, prob=0.9)
y_05_gp = y_hpdi[0, :]
y_95_gp = y_hpdi[1, :]

# Student-t-process predictive distribution, samples and 90% HPDI band.
nu_stp, mean_stp, K_stp = predict_stp(
    X, y, X_test, var=var, length=length_scale, noise=noise_stp, nu=nu
)
# NOTE(review): predict_stp rescales its matrix by (nu + beta_1 - 2) / (nu + n1 - 2),
# which looks like a covariance-parameterised Student-t, whereas
# MultivariateStudentT takes a scale matrix. Confirm whether an extra
# (df - 2) / df factor is needed here; behaviour is left unchanged.
stp = dist.MultivariateStudentT(
    df=nu_stp, loc=mean_stp, scale_tril=np.linalg.cholesky(K_stp)
)
samples_stp = stp.sample(random.PRNGKey(12), (1_000,))
y_hpdi = hpdi(samples_stp, prob=0.9)
y_05_stp = y_hpdi[0, :]
y_95_stp = y_hpdi[1, :]

# A handful of individual draws to overlay on each panel.
subsample_gp = sample_(samples_gp, 10, 123)
subsample_stp = sample_(samples_stp, 10, 413)

# Side-by-side panels: data, predictive mean, sample paths and HPDI band.
fig, axes = plt.subplots(1, 2, figsize=(12, 8), sharey=True)
ax1, ax2 = axes

panels = (
    (ax1, "GP", mean_gp, subsample_gp, y_05_gp, y_95_gp),
    (ax2, "STP", mean_stp, subsample_stp, y_05_stp, y_95_stp),
)
for panel_ax, panel_title, panel_mean, panel_draws, panel_lo, panel_hi in panels:
    panel_ax.plot(x, y, ".")
    panel_ax.plot(X_test.ravel(), panel_mean, "tab:orange")
    for i in range(panel_draws.shape[0]):
        panel_ax.plot(X_test.ravel(), panel_draws[i, :], "green", alpha=0.4)
    panel_ax.fill_between(
        X_test.ravel(), panel_lo, panel_hi, color="green", alpha=0.2
    )
    panel_ax.set_title(panel_title)
plt.show()
