In [None]:
try:
    import tinygp
except ImportError:
    !pip install -q tinygp

# Custom kernels: Radial velocity example

To demonstrate the flexibility of the `tinygp` kernel building interface, we will build a custom kernel that is commonly used when studying radial velocity observations of exoplanet-hosting stars, as described by [Rajpaul et al. (2015)](https://arxiv.org/abs/1506.07304).
Take a look at that paper for more details about the math, but the tl;dr is that we want to model a set of parallel, but qualitatively different, time series using a latent Gaussian process.
The interesting part of this model is that Rajpaul et al. model the observations as arbitrary linear combinations of the process and its first time derivative, and they work through the derivation of the resulting kernel function.

In this tutorial, we will implement this kernel using `tinygp`—something that is significantly more annoying to do with other Gaussian process frameworks (trust me, I've tried!)—and demonstrate a few key features along the way.
Besides describing the interface for implementing new kernels and kernel transforms, we also show how `tinygp` can support arbitrary [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) as input.

## The kernel

The kernel matrix described by [Rajpaul et al. (2015)](https://arxiv.org/abs/1506.07304) is a block matrix in which each element is a linear combination of the latent kernel and its first and second derivatives, with coefficients that depend on the "class" of each observation in the pair.
This means that our input data needs to include, at each observation, the time `t` (our input coordinate in this case) and an integer class label `label`.
As discussed below, we will structure our data in such a way that we can treat each input as being a tuple `(t, label)`.
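Concretely, suppose (our notation, chosen to match the structure of the implementation below) that an observation of class $\ell$ at time $t$ is $y_\ell(t) = a_\ell\,G(t) + b_\ell\,\dot{G}(t)$, where $G$ is the latent process with kernel $k(t_1, t_2)$, and $a_\ell$ and $b_\ell$ are the primal and derivative coefficients for that class. Because differentiation is linear, covariances such as $\mathrm{cov}[G(t_1), \dot{G}(t_2)]$ become derivatives of $k$ (here $\partial k / \partial t_2$), so each matrix element works out to:

$$
k_{\ell_1 \ell_2}(t_1, t_2) =
a_{\ell_1} a_{\ell_2}\,k(t_1, t_2)
+ a_{\ell_1} b_{\ell_2}\,\frac{\partial k}{\partial t_2}(t_1, t_2)
+ b_{\ell_1} a_{\ell_2}\,\frac{\partial k}{\partial t_1}(t_1, t_2)
+ b_{\ell_1} b_{\ell_2}\,\frac{\partial^2 k}{\partial t_1\,\partial t_2}(t_1, t_2)
$$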

Now, let's implement this kernel in a way that `tinygp` understands.
When doing this, you will subclass {class}`tinygp.kernels.Kernel` and implement the {func}`tinygp.kernels.Kernel.evaluate` method.
One very important thing to note here is that `evaluate` will always be called via `vmap`, so you should write your `evaluate` method to operate on a **single pair of inputs** and let `vmap` handle the broadcasting semantics for you.
In this case, we will unpack our inputs `t, label = X` and treat `t` and `label` as scalars.
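To make the single-pair contract concrete, here is a minimal sketch in plain JAX of how a pairwise `evaluate` can be lifted to a full kernel matrix with two nested `jax.vmap` calls, where each set of inputs is a `(t, label)` tuple of arrays. This illustrates the idea only and is not `tinygp`'s actual internals; the toy kernel and the per-class amplitudes `amp` are hypothetical.

```python
import jax
import jax.numpy as jnp


def evaluate(X1, X2):
    # Operates on a *single pair* of inputs: each input is a (t, label)
    # tuple whose leaves are scalars
    t1, label1 = X1
    t2, label2 = X2
    # Hypothetical toy kernel: squared-exponential in time, scaled by a
    # per-class amplitude looked up with the integer label
    amp = jnp.array([1.0, 0.5])
    return amp[label1] * amp[label2] * jnp.exp(-0.5 * (t1 - t2) ** 2)


def kernel_matrix(X1, X2):
    # Inner vmap maps over the second set of inputs, the outer one over the
    # first; vmap slices the tuple pytree leaf by leaf along axis 0
    row = jax.vmap(evaluate, in_axes=(None, 0))
    return jax.vmap(row, in_axes=(0, None))(X1, X2)


t = jnp.array([0.0, 1.0, 2.0])
label = jnp.array([0, 1, 0])
X = (t, label)
K = kernel_matrix(X, X)
print(K.shape)  # (3, 3)
```

Because `evaluate` only ever sees scalars, indexing with `label1` and `label2` stays simple, and the same function works for any number of observations.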

Here's our implementation:

In [None]:
import jax
import jax.numpy as jnp

import tinygp


class DerivativeKernel(tinygp.kernels.Kernel):
    """A custom kernel based on Rajpaul et al. (2015)

    Args:
        kernel: The kernel function describing the latent process. This can be any other
            ``tinygp`` kernel.
        coeff_prim: The primal coefficients for each class. This can be thought of as how
            much the latent process itself projects into the observations for that class.
            This should be an array with an entry for each class of observation.
        coeff_deriv: The derivative coefficients for each class. This should have the same
            shape as ``coeff_prim``.
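    """

    # Declare the fields up front: in recent versions of tinygp, kernels are
    # equinox modules, which expect their fields as class annotations
    kernel: tinygp.kernels.Kernel
    coeff_prim: jax.Array
    coeff_deriv: jax.Array

    def __init__(self, kernel, coeff_prim, coeff_deriv):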
        self.kernel = kernel
        self.coeff_prim, self.coeff_deriv = jnp.broadcast_arrays(
            jnp.asarray(coeff_prim), jnp.asarray(coeff_deriv)
        )

    def evaluate(self, X1, X2):
        t1, label1 = X1
        t2, label2 = X2

        # Differentiate the kernel function: the first derivative wrt x1
        Kp = jax.grad(self.kernel.evaluate, argnums=0)

        # ... and the mixed second derivative, d^2K / (dx1 dx2)
        Kpp = jax.grad(Kp, argnums=1)

        # Evaluate the kernel matrix and all of its relevant derivatives
        K = self.kernel.evaluate(t1, t2)
        d2K_dx1dx2 = Kpp(t1, t2)

        # For stationary kernels, these are related just by a minus sign, but we'll
        # evaluate them both separately for generality's sake
        dK_dx2 = jax.grad(self.kernel.evaluate, argnums=1)(t1, t2)
        dK_dx1 = Kp(t1, t2)

        # Extract the coefficients
        a1 = self.coeff_prim[label1]
        a2 = self.coeff_prim[label2]
        b1 = self.coeff_deriv[label1]
        b2 = self.coeff_deriv[label2]

        # Construct the matrix element
        return (
            a1 * a2 * K
            + a1 * b2 * dK_dx2
            + b1 * a2 * dK_dx1
            + b1 * b2 * d2K_dx1dx2
        )
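
As a sanity check on the derivative construction in `evaluate`, the following standalone sketch applies the same `jax.grad` calls to a hand-written squared-exponential kernel (a stand-in, not `tinygp.kernels.ExpSquared` itself; the length scale `ell` and the evaluation points are arbitrary) and compares the results against the closed-form derivatives, including the sign relation for stationary kernels mentioned in the comments.

```python
import jax
import jax.numpy as jnp

ell = 1.5  # arbitrary length scale for this check


def k(t1, t2):
    # Hand-written squared-exponential kernel on scalar inputs
    return jnp.exp(-0.5 * (t1 - t2) ** 2 / ell**2)


# Same construction as in DerivativeKernel.evaluate
dk_dt1 = jax.grad(k, argnums=0)
dk_dt2 = jax.grad(k, argnums=1)
d2k_dt1dt2 = jax.grad(dk_dt1, argnums=1)

t1, t2 = 0.3, 1.1
r = t1 - t2

# Stationary kernel: the two first derivatives differ only by a sign
print(dk_dt1(t1, t2), -dk_dt2(t1, t2))

# Closed form of the mixed second derivative: (1/ell^2 - r^2/ell^4) * k
analytic = (1 / ell**2 - r**2 / ell**4) * k(t1, t2)
print(d2k_dt1dt2(t1, t2), analytic)
```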

Now that we have this definition, we can plot what the kernel functions look like for different latent processes.
Don't worry too much about the plotting syntax here. We're plotting two classes of observations: the first class is a direct observation of the latent process, and the second observes its time derivative.
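To see which term each plotted curve picks out, this small sketch (plain NumPy, reusing the coefficient values passed in `plot_kernel` and the term names from `DerivativeKernel.evaluate`) lists the non-zero coefficient products for each pair of labels.

```python
import numpy as np

# Coefficients used in plot_kernel: class 0 sees the latent process,
# class 1 sees its time derivative
a = np.array([1.0, 0.0])  # coeff_prim
b = np.array([0.0, 1.0])  # coeff_deriv

terms = ["K", "dK_dx2", "dK_dx1", "d2K_dx1dx2"]
selected = {}
for label1 in (0, 1):
    for label2 in (0, 1):
        # Same coefficient products as in DerivativeKernel.evaluate
        weights = [
            a[label1] * a[label2],
            a[label1] * b[label2],
            b[label1] * a[label2],
            b[label1] * b[label2],
        ]
        selected[(label1, label2)] = [
            name for name, w in zip(terms, weights) if w != 0
        ]
        print(f"k_{label1}{label2}:", selected[(label1, label2)])
```

With these coefficients each label pair keeps exactly one term, so the four curves are the latent kernel, its two first derivatives, and the mixed second derivative.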

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def plot_kernel(latent_kernel):
    kernel = DerivativeKernel(latent_kernel, [1.0, 0.0], [0.0, 1.0])

    N = 500
    dt = np.linspace(-7.5, 7.5, N)

    plt.figure()
    for label1 in (0, 1):
        for label2 in (0, 1):
            # Kernel between one point at t = 0 with class label1 and a grid
            # of points at t = dt with class label2
            X1 = (jnp.zeros(1), jnp.full(1, label1, dtype=int))
            X2 = (dt, np.full(N, label2, dtype=int))
            k = kernel(X1, X2)[0]
            plt.plot(dt, k, label=f"$k_{{{label1}{label2}}}$", lw=1)
    plt.legend()
    plt.xlabel(r"$\Delta t$")
    plt.xlim(dt.min(), dt.max())


plot_kernel(tinygp.kernels.Matern52(scale=1.5))
plt.title("Matern-5/2")

plot_kernel(
    tinygp.kernels.ExpSquared(scale=2.5)
    * tinygp.kernels.ExpSineSquared(period=2.5, gamma=0.5)
)
_ = plt.title("Quasiperiodic")