In [None]:
try:
    import tinygp
except ImportError:
    %pip install -q tinygp

try:
    import jaxopt
except ImportError:
    %pip install -q jaxopt

(classification)=

## GP Classification

In [None]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

# One explicit PRNG key drives the synthetic data below; a fixed seed keeps
# Restart-&-Run-All reproducible.
SEED = 0
key = jax.random.PRNGKey(SEED)

In [None]:
import numpy as np

In [None]:
# NumPy's global RNG is intentionally not used: every random draw in this
# notebook goes through an explicit JAX PRNG key (`key`, and the key passed to
# `model.init`), so no NumPy seeding is needed here.

In [None]:
# 200 points drawn from a standard 2-D normal. The label is the XOR of the
# signs of the two coordinates, so the two classes occupy alternating quadrants.
X = jax.random.normal(key, (200, 2))
y = jnp.logical_xor(X[:, 0] > 0, X[:, 1] > 0)

fig, ax = plt.subplots()
ax.scatter(
    X[:, 0], X[:, 1], s=30, c=y, cmap=plt.cm.Paired, edgecolors=(0, 0, 0)
)
ax.set_aspect("equal")
# The coordinate axes are the true decision boundaries of the XOR labels.
ax.axhline(0, color="k")
ax.axvline(0, color="k")
_ = ax.set(xlabel="$x_1$", ylabel="$x_2$", title="XOR training data")

In [None]:
# Evaluation grid: 100 x 100 points covering [-2, 2]^2, flattened into an
# (N, 2) array of test inputs T.
xs = jnp.linspace(-2, 2, num=100)
ys = jnp.linspace(-2, 2, num=100)

# "ij" indexing puts xs along axis 0 and ys along axis 1. That is exactly
# what transposing the default "xy" meshgrid gives.
xx, yy = jnp.meshgrid(xs, ys, indexing="ij")
T = jnp.column_stack((xx.ravel(), yy.ravel()))

In [None]:
import jax
import optax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros
from tinygp import kernels, transforms, GaussianProcess


class RBFLoss(nn.Module):
    """GP negative log-likelihood with a squared-exponential (RBF) kernel.

    The labels are fed directly to ``GaussianProcess.condition`` as targets,
    so this fits a GP regression to them rather than a likelihood-based
    classifier.

    Learned parameters are all stored in log space and initialised to 0:
        log_sigma: log kernel amplitude; the kernel is scaled by sigma**2.
        log_ell: log length scale of the ExpSquared kernel.
        log_jitter: log jitter; exp(2 * log_jitter) is passed as ``diag``.
    """

    @nn.compact
    def __call__(self, X, y, T):
        """Condition the GP on ``(X, y)`` and evaluate it at ``T``.

        Args:
            X: training inputs.
            y: training targets for ``X`` (in this notebook, the boolean XOR
                labels).
            T: test inputs at which the conditioned GP is evaluated.

        Returns:
            ``(nll, (mean, variance))``. ``nll`` is the negated log probability
            returned by ``gp.condition``. ``mean`` and ``variance`` are the
            ``loc`` and ``variance`` of the conditioned GP at ``T``.
        """
        # Squared-exponential (RBF) kernel with amplitude sigma**2 and length
        # scale ell, both parameterised in log space to keep them positive.
        log_sigma = self.param("log_sigma", zeros, ())
        log_ell = self.param("log_ell", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.ExpSquared(
            jnp.exp(log_ell)
        )

        # Condition on the training data. This returns the log probability
        # of y and the GP conditioned on it, evaluated at the test inputs T.
        gp = GaussianProcess(base_kernel, X, diag=jnp.exp(2 * log_jitter))
        log_prob, gp_cond = gp.condition(y, T)
        return -log_prob, (gp_cond.loc, gp_cond.variance)


def loss(model, params):
    """Scalar training objective: the first output of ``model.apply``.

    Applies ``model`` with ``params`` to the module-level data ``X``, ``y``
    and ``T``. For ``RBFLoss``, the first output is the GP negative log
    likelihood.
    """
    outputs = model.apply(params, X, y, T)
    # For RBFLoss, outputs[1] holds the predictive (mean, variance) on T,
    # which is not needed for the objective.
    neg_log_prob = outputs[0]
    return neg_log_prob

In [None]:
# Instantiate the module definition; its parameters are created separately
# by `model.init` in the next cell.
model = RBFLoss()

In [None]:
# Initialise the kernel hyperparameters (RBFLoss starts every log-parameter at 0).
params = model.init(jax.random.PRNGKey(1234), X, y, T)

LEARNING_RATE = 1e-4
NUM_STEPS = 1000
tx = optax.sgd(learning_rate=LEARNING_RATE)
opt_state = tx.init(params)

# `loss` has signature (model, params). Passing it straight to value_and_grad
# would differentiate w.r.t. argument 0 (`model`), and calling the result with
# `params` alone would be missing an argument. Binding `model` here gives a
# function of `params` only, so the gradients are w.r.t. the hyperparameters.
loss_grad_fn = jax.jit(jax.value_and_grad(lambda p: loss(model, p)))
for _ in range(NUM_STEPS):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)