In [None]:
try:
    import tinygp
except ImportError:
    !pip install -q tinygp

try:
    import flax
except ImportError:
    !pip install -q flax
    
try:
    import optax
except ImportError:
    !pip install -q optax

# Deep kernel learning with flax

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Seeded generator so the synthetic dataset is identical on every run.
random = np.random.default_rng(567)

# Standard deviation of the Gaussian noise added to the targets.
noise = 0.1

# Training set: 200 sorted inputs drawn uniformly from [-1, 1], with noisy
# targets scattered around a step that jumps from -1 to +1 at the origin.
x = np.sort(random.uniform(low=-1, high=1, size=200))
y = np.where(x > 0, 1, -1) + random.normal(loc=0.0, scale=noise, size=x.size)

# Evenly spaced grid that extends beyond the range of the training inputs.
t = np.linspace(-1.5, 1.5, num=500)

# Noisy samples (black dots) on top of the noise-free step function.
plt.plot(x, y, ".k")
plt.plot(t, np.where(t > 0, 1, -1))
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3);

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros

from tinygp import kernels, GaussianProcess


class FeatureExtractor(nn.Module):
    """Small fully connected network that maps each input to one feature.

    Two hidden layers (100 and 20 units) with ReLU activations are followed
    by a linear layer with a single output unit.
    """

    @nn.compact
    def __call__(self, x):
        """Return the network output for ``x`` (last axis is the feature axis)."""
        hidden = x
        # Layers are created in the same order as if written out one by one,
        # so the automatically assigned submodule names are unaffected.
        for width in (100, 20):
            hidden = nn.relu(nn.Dense(features=width)(hidden))
        return nn.Dense(features=1)(hidden)


class GPLoss(nn.Module):
    """Gaussian process whose inputs are first warped by a ``FeatureExtractor``.

    A single call returns everything the notebook needs: the training
    objective, the predictions at the test inputs, and the warped coordinates.
    """

    @nn.compact
    def __call__(self, x, y, t):
        """Evaluate the training loss and the predictions in one pass.

        Args:
            x: Training inputs; expected to be 1-D, a trailing feature axis is
                added here before the network is applied.
            y: Observations the Gaussian process is evaluated on.
            t: Test inputs, treated exactly like ``x``.

        Returns:
            A tuple ``(loss, prediction, (x_warped, t_warped))``. ``loss`` is
            the negated log probability of ``y`` under the GP, ``prediction``
            is the result of ``gp.predict(y, t, return_var=True)`` (unpacked
            by the caller as a mean and a variance), and the final pair holds
            the rescaled network outputs for ``x`` and ``t``.
        """
        # One extractor instance is applied to both inputs so that training
        # and test points go through the same network.
        extr = FeatureExtractor()
        x = extr(x[:, None])
        t = extr(t[:, None])

        # Rescale with the range of the *training* features only; the test
        # features reuse the same offset and span.
        # NOTE(review): xmax == xmin (a constant network output) would divide
        # by zero here — confirm whether that case needs guarding.
        xmin = jnp.min(x, axis=0, keepdims=True)
        xmax = jnp.max(x, axis=0, keepdims=True)
        x = (x - xmin) / (xmax - xmin)
        t = (t - xmin) / (xmax - xmin)

        # All hyperparameters are stored in log space and start at zero.
        mean = self.param("mean", zeros, ())
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, (x.shape[1],))
        log_jitter = self.param("log_jitter", zeros, ())
        # NOTE(review): the Matern scale is exp(2 * log_rho), i.e. the square
        # of exp(log_rho) — confirm this parameterisation is intended.
        kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(
            jnp.exp(2 * log_rho)
        )

        # The diagonal combines the known observation noise (notebook global
        # `noise`) with a learned jitter term.
        gp = GaussianProcess(
            kernel, x, diag=noise ** 2 + jnp.exp(2 * log_jitter), mean=mean
        )

        # Current tinygp exposes the marginal log likelihood through
        # `log_probability`; `condition` returns a result object there, which
        # cannot be negated and used as a scalar loss.
        neg_log_prob = -gp.log_probability(y)
        return neg_log_prob, gp.predict(y, t, return_var=True), (x, t)

In [None]:
import optax

# Instantiate the loss module; its parameters are created separately by
# `model.init` and passed in explicitly on every `model.apply` call.
model = GPLoss()


def loss(params):
    """Return the scalar training objective for ``params``.

    Applies the module-level ``model`` to the notebook's ``x``, ``y`` and
    ``t`` arrays and keeps only the first of its outputs.
    """
    outputs = model.apply(params, x, y, t)
    objective = outputs[0]
    return objective


# Create the initial parameter tree from a fixed PRNG key.
init_key = jax.random.PRNGKey(0)
params = model.init(init_key, x, y, t)

# optax SGD optimiser with a fixed learning rate, plus its initial state.
tx = optax.sgd(learning_rate=1e-4)
opt_state = tx.init(params)

# One jitted function that returns both the loss value and its gradient.
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

for step in range(1001):
    loss_val, grads = loss_grad_fn(params)

    # Report the loss evaluated at the parameters *before* this step's update.
    if step % 100 == 0:
        print("Loss step {}: ".format(step), loss_val)

    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

In [None]:
# The second model output is the prediction at `t`, unpacked as mean and variance.
prediction = model.apply(params, x, y, t)[1]
mu, var = prediction

# Half-width of the shaded band: the square root of the predicted variance.
std = np.sqrt(var)

plt.plot(x, y, ".k")
plt.plot(t, mu)
plt.fill_between(t, mu + std, mu - std, alpha=0.5)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3);

In [None]:
# The third model output holds the rescaled network outputs for the training
# inputs and for the test grid.
warped = model.apply(params, x, y, t)[2]
xp, tp = warped

# Learned mapping from the original coordinate to the warped one.
plt.plot(t, tp)
plt.xlim(-1.5, 1.5)
plt.xlabel("x")
plt.ylabel("warped x");