In [None]:
# Install the packages this notebook needs if they are missing (e.g. on a
# fresh kernel). `%pip` is an IPython line magic, so this cell only runs
# inside IPython/Jupyter. The installed packages are imported by later cells.
# NOTE(review): installs are unpinned; consider pinning versions so a
# Restart & Run All reproduces the same results.
try:
    import tinygp
except ImportError:
    %pip install -q tinygp

try:
    # NOTE(review): jaxopt is not referenced in the cells visible here --
    # confirm it is used further down before keeping this install.
    import jaxopt
except ImportError:
    %pip install -q jaxopt

(quasisep-diff)=

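# Differentiating quasisep kernels

This notebook extends tinygp's quasiseparable `Matern32` kernel to augmented inputs `(t, flag)`. Points with `flag = 1` are meant to observe the derivative of the process rather than its value. As a sanity check, the notebook draws a sample from the augmented GP and compares the `flag = 1` values with a finite-difference estimate computed from the `flag = 0` values.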

In [None]:
import jax
import jax.numpy as jnp

import numpy as np
import matplotlib.pyplot as plt

from tinygp import GaussianProcess
from tinygp.solvers.quasisep import kernels

# Enable 64-bit floats in JAX (the tiny diag=1e-8 jitter used below needs
# double precision to be meaningful). Use the `jax.config` object directly:
# the `from jax.config import config` submodule import was deprecated and
# later removed in newer JAX releases. `config` is kept as an alias for any
# later cells that still refer to it.
config = jax.config

config.update("jax_enable_x64", True)

In [None]:
class MyKernel(kernels.Matern32):
    """Matern-3/2 quasisep kernel over augmented inputs ``X = (t, flag)``.

    Only the time coordinate ``X[0]`` reaches the parent's transition matrix.
    The integer ``flag`` (``X[1]``) switches the observation vectors ``p`` and
    ``q`` between two weightings. In this notebook the ``flag == 1`` values
    are plotted against a finite difference of the ``flag == 0`` values, so
    ``flag == 1`` presumably observes the derivative of the process.
    TODO: confirm against the quasisep ``p``/``q``/``A`` convention.
    """

    def p(self, X):
        # Left observation vector:
        #   flag == 0 -> (sigma, 0)
        #   flag == 1 -> (0, -3 * sigma / scale**2)
        _time, flag = X
        value_part = self.sigma * (1 - flag)
        slope_part = -3 * self.sigma * flag / jnp.square(self.scale)
        return jnp.array([value_part, slope_part])

    def q(self, X):
        # Right observation vector:
        #   flag == 0 -> (sigma, 0)
        #   flag == 1 -> (0, -sigma)
        _time, flag = X
        value_part = self.sigma * (1 - flag)
        slope_part = -self.sigma * flag
        return jnp.array([value_part, slope_part])

    def A(self, X1, X2):
        # Transition between two augmented inputs, using their time
        # coordinates only (flags are ignored). The parent's matrix is
        # transposed here.
        # NOTE(review): presumably this matches the orientation p/q expect
        # above -- confirm.
        parent_transition = super().A(X1[0], X2[0])
        return parent_transition.T


# Shared time grid for both samples.
t = jnp.linspace(0, 10, 500)

# Reference: a sample from the plain Matern-3/2 kernel.
gp1 = GaussianProcess(kernels.Matern32(1.5), t, diag=1e-8)
y1 = gp1.sample(jax.random.PRNGKey(1))

# Augmented inputs (t, flag), with each flag set to 1 with probability 1/2.
# A seeded Generator makes the flag pattern (and therefore the plot)
# reproducible, just as the fixed PRNGKey does for the GP samples. The
# original code used the unseeded global np.random state.
flag_rng = np.random.default_rng(0)
X = (t, (flag_rng.random(len(t)) < 0.5).astype(int))
gp2 = GaussianProcess(MyKernel(1.5), X, diag=1e-8)
y2 = gp2.sample(jax.random.PRNGKey(1))

is_flag0 = X[1] == 0
is_flag1 = X[1] == 1
t_flag0 = t[is_flag0]
y2_flag0 = y2[is_flag0]

fig, ax = plt.subplots()
ax.plot(t, y1, label="Matern32 sample")
ax.plot(t_flag0, y2_flag0, label="MyKernel sample, flag = 0")
ax.plot(t[is_flag1], y2[is_flag1], label="MyKernel sample, flag = 1")
# Difference quotient of the flag == 0 values, placed at each interval's
# midpoint (where the quotient is centred rather than at the left endpoint).
# It is plotted for comparison with the flag == 1 values.
ax.plot(
    0.5 * (t_flag0[1:] + t_flag0[:-1]),
    np.diff(y2_flag0) / np.diff(t_flag0),
    label="finite difference of flag = 0",
)
ax.set(xlabel="t", ylabel="value", title="GP samples and finite-difference check")
ax.legend();

In [None]:
# Build the symmetric quasiseparable representation of gp2's kernel at the
# augmented inputs X, and display its `.lower` component.
symm_qsm = gp2.kernel.to_symm_qsm(X)
symm_qsm.lower