In [None]:
# Install tinygp into the running kernel only when it is not already importable.
# NOTE(review): the install is unpinned, but later tinygp releases changed parts
# of the API used below (e.g. what `GaussianProcess.condition` returns) —
# consider pinning the tinygp version this notebook was written against.
try:
    import tinygp
except ImportError:
    %pip install -q tinygp

(multivariate)=

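# Multivariate

This tutorial fits a Gaussian process to noisy observations with two-dimensional inputs. It compares two kernels. The first is an `ExpSquared` kernel with an independent length scale for each input dimension. The second first passes the inputs through a learned lower-triangular `transforms.Affine` factor, which lets the length scales be correlated across dimensions.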



In [None]:
import numpy as np
import matplotlib.pyplot as plt

random = np.random.default_rng(48392)

# 100 noisy observations of f(x, y) = sin(x) * cos(x + y), inputs uniform on [-5, 5]^2
X = random.uniform(-5, 5, (100, 2))
yerr = 0.1
signal = np.sin(X[:, 0]) * np.cos(X[:, 1] + X[:, 0])
y = signal + yerr * random.normal(size=len(X))

# Regular 100 x 50 evaluation grid: the true function on it, and the same
# points flattened into an (n_points, 2) array for GP prediction
x_grid = np.linspace(-5, 5, 100)
y_grid = np.linspace(-5, 5, 50)
x_, y_ = np.meshgrid(x_grid, y_grid)
y_true = np.sin(x_) * np.cos(x_ + y_)
X_pred = np.column_stack((x_.ravel(), y_.ravel()))

# Circle of radius 0.5 as a (2, 500) array; later cells stretch it by each
# kernel's learned scale to draw a covariance ellipse
theta = np.linspace(0, 2 * np.pi, 500)[None, :]
elipse = 0.5 * np.vstack((np.cos(theta), np.sin(theta)))

# Observations drawn over the true function, on one shared colour scale
vmin, vmax = y_true.min(), y_true.max()
fig, ax = plt.subplots(figsize=(6, 6))
ax.pcolor(x_grid, y_grid, y_true, vmin=vmin, vmax=vmax)
ax.scatter(X[:, 0], X[:, 1], c=y, ec="black", vmin=vmin, vmax=vmax)
ax.set_xlabel("x")
ax.set_ylabel("y")
_ = ax.set_title("data")

In [None]:
from functools import partial

import jax
import jax.numpy as jnp
from scipy.optimize import minimize
from tinygp import GaussianProcess, kernels, transforms

# Switch JAX to 64-bit floats; GP likelihood computations are typically
# sensitive to float32 round-off. Set this at startup, before creating JAX
# arrays. (`from jax.config import config` is deprecated and removed in
# recent JAX releases; `jax.config.update` works on old and new versions.)
jax.config.update("jax_enable_x64", True)

def train_gp(nparams, build_gp_func, y_obs=None):
    """Fit GP hyperparameters by maximizing the log-likelihood of the data.

    Parameters
    ----------
    nparams : int
        Length of the flat parameter vector; optimization starts from zeros.
    build_gp_func : callable
        Maps a parameter vector to a GP object whose ``condition(y)`` yields
        the log-likelihood of the observations (either directly, or as the
        first field of a returned tuple).
    y_obs : array-like, optional
        Observations to condition on. Defaults to the notebook-global ``y``.

    Returns
    -------
    The object returned by ``build_gp_func`` at the optimized parameters.
    A ``RuntimeWarning`` is emitted if scipy reports the optimizer failed.
    """
    import warnings

    if y_obs is None:
        y_obs = y  # notebook-global observations, as in the original cell

    @jax.jit
    @jax.value_and_grad
    def loss(params):
        result = build_gp_func(params).condition(y_obs)
        # Older tinygp releases return the log-likelihood directly; newer ones
        # return a (log_probability, conditioned_gp) named tuple.
        log_prob = result[0] if isinstance(result, tuple) else result
        return -log_prob

    def objective(params):
        # Hand scipy plain float64 values instead of JAX device arrays.
        value, grad = loss(params)
        return float(value), np.asarray(grad, dtype=np.float64)

    soln = minimize(objective, np.zeros(nparams), jac=True)
    if not soln.success:
        warnings.warn(
            f"GP hyperparameter optimization did not converge: {soln.message}",
            RuntimeWarning,
            stacklevel=2,
        )
    return build_gp_func(soln.x)

def build_gp_uncorr(params):
    """Build a GP with an axis-aligned ExpSquared kernel.

    ``params`` is laid out as ``[log_amplitude, log_scale_x, log_scale_y]``;
    every entry is exponentiated so the amplitude and the per-dimension
    length scales stay positive. Uses the notebook-global inputs ``X`` and
    the observation noise level ``yerr``.
    """
    amplitude = jnp.exp(params[0])
    length_scales = jnp.exp(params[1:3])
    kernel = amplitude * kernels.ExpSquared(length_scales)
    noise_variance = yerr ** 2
    return GaussianProcess(kernel, X, diag=noise_variance)


# Three parameters: log amplitude plus one log length scale per input dimension
uncorr_gp = train_gp(nparams=3, build_gp_func=build_gp_uncorr)

In [None]:
# GP prediction on the evaluation grid, conditioned on the observations y
y_pred = uncorr_gp.predict(y, X_pred).reshape(y_true.shape)
# The radius-0.5 circle stretched per axis by the learned ExpSquared length
# scales (kernel2 is the ExpSquared factor of the amplitude * ExpSquared product)
xy = uncorr_gp.kernel.kernel2.scale[:, None] * elipse

# Both panels share the colour range of the true function, so residuals can
# be read against the scale of the data
vmin, vmax = y_true.min(), y_true.max()

fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
axes[0].plot(xy[0], xy[1], "--k", lw=0.5)
mesh = axes[0].pcolor(x_, y_, y_pred, vmin=vmin, vmax=vmax)
axes[0].scatter(X[:, 0], X[:, 1], c=y, ec="black", vmin=vmin, vmax=vmax)
axes[1].pcolor(x_, y_, y_pred - y_true, vmin=vmin, vmax=vmax)
axes[0].set_xlabel("x")
axes[0].set_ylabel("y")
axes[0].set_title("uncorrelated kernel")
axes[1].set_xlabel("x")
axes[1].set_title("residuals")
# One colour bar serves both panels because they use identical vmin/vmax
_ = fig.colorbar(mesh, ax=axes, shrink=0.8)

In [None]:
def build_gp_corr(params):
    """Build a GP whose 2-D inputs go through a learned lower-triangular factor.

    ``params`` is laid out as ``[log_amplitude, log_diag_0, log_diag_1, off_diag]``.
    The two diagonal entries of the factor are exponentiated (kept positive);
    the single below-diagonal entry is unconstrained and lets the two input
    dimensions be correlated. The factor is handed to ``transforms.Affine``
    wrapping a unit ``ExpSquared`` kernel. Uses the notebook-global ``X`` and
    ``yerr``.
    """
    factor = (
        jnp.zeros((2, 2))
        .at[jnp.diag_indices(2)].add(jnp.exp(params[1:3]))
        .at[jnp.tril_indices(2, -1)].add(params[3:])
    )
    amplitude = jnp.exp(params[0])
    kernel = amplitude * transforms.Affine(factor, kernels.ExpSquared())
    return GaussianProcess(kernel, X, diag=yerr ** 2)


# Four parameters: log amplitude, two log diagonal entries, one off-diagonal entry
corr_gp = train_gp(nparams=4, build_gp_func=build_gp_corr)

In [None]:
# GP prediction on the evaluation grid, conditioned on the observations y
y_pred = corr_gp.predict(y, X_pred).reshape(y_true.shape)
# The radius-0.5 circle mapped through the learned lower-triangular factor
# stored on the transforms.Affine kernel (kernel2 of the amplitude product)
xy = corr_gp.kernel.kernel2.scale @ elipse

# Both panels share the colour range of the true function, so residuals can
# be read against the scale of the data
vmin, vmax = y_true.min(), y_true.max()

fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
axes[0].plot(xy[0], xy[1], "--k", lw=0.5)
mesh = axes[0].pcolor(x_, y_, y_pred, vmin=vmin, vmax=vmax)
axes[0].scatter(X[:, 0], X[:, 1], c=y, ec="black", vmin=vmin, vmax=vmax)
axes[1].pcolor(x_, y_, y_pred - y_true, vmin=vmin, vmax=vmax)
axes[0].set_xlabel("x")
axes[0].set_ylabel("y")
axes[0].set_title("correlated kernel")
axes[1].set_xlabel("x")
axes[1].set_title("residuals")
# One colour bar serves both panels because they use identical vmin/vmax
_ = fig.colorbar(mesh, ax=axes, shrink=0.8)