In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tinygp import GaussianProcess, kernels, transforms

from paths import figures

In [None]:
# Switch JAX to 64-bit (double-precision) mode; this must run before any JAX
# arrays are created. NOTE(review): presumably wanted for the numerical
# stability of the GP linear algebra later in the notebook -- confirm.
jax.config.update("jax_enable_x64", True)

# Seaborn defaults applied to every matplotlib figure in this notebook.
sns.set_context("notebook")
sns.set_style("ticks")

In [None]:
class DerivativeKernel(kernels.Kernel):
    """Kernel over a function and its first derivative, built from a base kernel.

    Every input is a pair ``(t, d)``: ``t`` is the coordinate forwarded to the
    wrapped kernel and ``d`` is a flag that is truthy when the input refers to
    the derivative and falsy when it refers to the function value.
    """

    def __init__(self, kernel):
        # Wrapped kernel; its ``evaluate`` method is what gets differentiated.
        self.kernel = kernel

    def evaluate(self, X1, X2):
        """Return the covariance between the two flagged inputs.

        Depending on the flags ``d1`` and ``d2`` the result is the base kernel
        value, one of its first partial derivatives, or its mixed second
        derivative with respect to both coordinates.
        """
        (t1, d1), (t2, d2) = X1, X2
        base = self.kernel.evaluate

        # First partial derivatives with respect to each argument. For a
        # stationary kernel these differ only by a sign, but both are taken
        # explicitly so that non-stationary kernels are handled as well.
        grad_wrt_first = jax.grad(base, argnums=0)
        grad_wrt_second = jax.grad(base, argnums=1)

        # Mixed second derivative: differentiate d/dx1 once more, wrt x2.
        mixed_second = jax.grad(grad_wrt_first, argnums=1)

        # All four candidates are evaluated; the flags pick one of them below.
        value_value = base(t1, t2)
        deriv_deriv = mixed_second(t1, t2)
        value_deriv = grad_wrt_second(t1, t2)
        deriv_value = grad_wrt_first(t1, t2)

        # Branch on d2 first for each possible state of d1, then on d1.
        if_first_is_deriv = jnp.where(d2, deriv_deriv, deriv_value)
        if_first_is_value = jnp.where(d2, value_deriv, value_value)
        return jnp.where(d1, if_first_is_deriv, if_first_is_value)

In [None]:
def plot_kernel(base_kernel, num_points=500, max_lag=7.5, seed=0):
    """Plot the value/derivative covariances of a kernel and one joint GP draw.

    ``base_kernel`` is wrapped in ``DerivativeKernel`` and evaluated between a
    single point at the origin and a regular grid of lags. Panel "A" shows the
    four resulting covariance curves; panels "B" and "C" show one sample of
    the function and of its derivative, drawn jointly from the same Gaussian
    process.

    Parameters
    ----------
    base_kernel
        Kernel to wrap; it is passed unchanged to ``DerivativeKernel``.
    num_points : int, optional
        Number of grid points per curve. The Gaussian process is built on
        ``2 * num_points`` inputs (values followed by derivatives).
    max_lag : float, optional
        The lag grid spans ``[-max_lag, max_lag]``.
    seed : int, optional
        Seed handed to ``jax.random.PRNGKey`` for the sample.

    Returns
    -------
    The figure created by ``plt.subplot_mosaic``.
    """
    kernel = DerivativeKernel(base_kernel)
    dt = np.linspace(-max_lag, max_lag, num_points)

    # Flags marking every grid point as "function value" or "derivative".
    value_flags = np.zeros(num_points, dtype=bool)
    deriv_flags = np.ones(num_points, dtype=bool)

    def covariance_with_origin(origin_is_deriv, lag_flags):
        # One row of the kernel matrix: the single origin point against every lag.
        origin = (jnp.zeros(1), jnp.full((1,), origin_is_deriv, dtype=bool))
        return kernel(origin, (dt, lag_flags))[0]

    k00 = covariance_with_origin(False, value_flags)
    k01 = covariance_with_origin(False, deriv_flags)
    k10 = covariance_with_origin(True, value_flags)
    k11 = covariance_with_origin(True, deriv_flags)

    # Stack values and derivatives into one input set so that a single draw
    # yields a function and a derivative that belong together.
    X = (
        np.concatenate((dt, dt)),
        np.concatenate((value_flags, deriv_flags)),
    )
    gp = GaussianProcess(kernel, X)
    sample = gp.sample(jax.random.PRNGKey(seed))
    y = sample[:num_points]
    ydot = sample[num_points:]

    fig, axes = plt.subplot_mosaic(
        [["A", "B"], ["A", "C"]],
        constrained_layout=True,
        figsize=(10, 4),
    )

    ax = axes["A"]
    ax.plot(dt, k00, "-", label=r"$\mathrm{cov}(f,\,f)$", lw=1)
    ax.plot(dt, k01, "--", label=r"$\mathrm{cov}(f,\,\mathrm{d}f)$", lw=1)
    ax.plot(dt, k10, "-.", label=r"$\mathrm{cov}(\mathrm{d}f,\,f)$", lw=1)
    ax.plot(dt, k11, ":", label=r"$\mathrm{cov}(\mathrm{d}f,\,\mathrm{d}f)$", lw=1)
    ax.legend()
    ax.set_xlabel(r"$\Delta t$")
    ax.set_xlim(dt.min(), dt.max())

    # The sample panels use a time axis shifted to start at zero.
    t = dt - dt[0]
    t_max = dt.max() - dt[0]

    ax = axes["B"]
    ax.plot(t, y, "k")
    ax.set_xlim(0, t_max)
    ax.set_ylabel("$f(t)$")

    ax = axes["C"]
    ax.plot(t, ydot, "k")
    ax.set_xlim(0, t_max)
    ax.set_ylabel(r"$\mathrm{d}f(t) / \mathrm{d}t$")

    # "B" sits directly above "C" and covers the same range, so only the
    # bottom panel keeps its tick labels and axis label.
    axes["B"].set_xticklabels([])
    axes["C"].set_xlabel("t")

    return fig


# Disabled alternative for the first figure: a Matérn-5/2 kernel. It writes to
# the same "kernel-ops1.pdf" as the squared-exponential figure that follows, so
# only one of the two should be enabled at a time.
# fig = plot_kernel(kernels.Matern52(scale=1.5))
# fig.suptitle("Matérn-5/2", fontsize=14)
# plt.savefig(figures / "kernel-ops1.pdf", bbox_inches="tight")

# Figure 1: squared-exponential kernel. plt.savefig writes pyplot's current
# figure, which is the one plot_kernel has just created.
fig = plot_kernel(kernels.ExpSquared(scale=1.5))
fig.suptitle("Squared Exponential", fontsize=14)
plt.savefig(figures / "kernel-ops1.pdf", bbox_inches="tight")

# Figure 2: quasi-periodic kernel, formed as the product of a squared-exponential
# kernel and an ExpSineSquared kernel.
fig = plot_kernel(
    kernels.ExpSquared(scale=2.5) * kernels.ExpSineSquared(gamma=1.0, scale=3.5)
)
fig.suptitle("Quasi-periodic", fontsize=14)
plt.savefig(figures / "kernel-ops2.pdf", bbox_inches="tight")