In [1]:
# Imports
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr
import gpjax as gpx
from gpjax.typing import Array, Float

from dataclasses import dataclass

from jax import jacfwd, jacrev

import tensorflow_probability as tfp

  from .autonotebook import tqdm as notebook_tqdm


Problem overview, reference to Linearly Constrained GPs paper

# 2D div-free

3rd dimension of 0 (x), 1 (y), 2 (0 div)

## Data processing

Example div-free function: insert math here

In [2]:
def div_free_2d_example(x, y, a=0.01):
    exp_term = jnp.exp(-a * x * y)
    trig_term_x = a * x * jnp.sin(x * y) - x * jnp.cos(x * y)
    trig_term_y = -a * y * jnp.sin(x * y) + y * jnp.cos(x * y)
    return exp_term * trig_term_x, exp_term * trig_term_y

plot of the above field

## In-depth problem specification

In [3]:
n_divisions = 20

In [4]:
def label_position_2d(data):
    # introduce alternating z label
    n_points = len(data[0])
    label = jnp.tile(jnp.array([0.0, 1.0]), n_points)
    return jnp.vstack((jnp.repeat(data, repeats=2, axis=1), label)).T

def stack_vector(data):
    return data.T.flatten().reshape(-1, 1)

def dataset_3d(pos, obs):
    return gpx.Dataset(label_position_2d(pos), stack_vector(obs))

In [5]:
positions = jnp.mgrid[0:4:n_divisions*1j, 0:4:n_divisions*1j].reshape(2, -1)

observations = jnp.stack(
    div_free_2d_example(positions[0], positions[1]), axis=0
).reshape(2, -1)

dataset = dataset_3d(positions, observations)

In [21]:
default_key = jr.PRNGKey(0)

def train_test_split(positions, observations, n_train, n_split, key=default_key):
    permutation = jr.permutation(key, jnp.arange(positions.shape[1]))
    train_indices = permutation[:n_train]
    test_indices = permutation[n_train:n_train+n_split]  # fmt: skip

    n_dims = positions.shape[0]

    match n_dims:
        case 2:
            dataset_train = dataset_3d(
                positions[:, train_indices], observations[:, train_indices]
            )
            dataset_test = dataset_3d(
                positions[:, test_indices], observations[:, test_indices]
            )
        case 3:
            # CR TODO: implement me
            assert False
        case _:
            raise ValueError(f"Invalid number of dimensions: {n_dims}")

    return dataset_train, dataset_test, positions[:, test_indices]


### Diagonal kernel

In [17]:
@dataclass
class VectorKernel_2d(gpx.kernels.AbstractKernel):

    # CR TODO: should this allow specification of lengthscale and variance?
    kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1])

    def __call__(
        self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        # standard RBF-SE kernel if x and x' are on the same output, otherwise returns 0

        w = jnp.array(X[2], dtype=int)
        wp = jnp.array(Xp[2], dtype=int)

        # drop output label to reduce resource usage
        X = X[:2]
        Xp = Xp[:2]

        K = (w == wp) * self.kernel(X, Xp)

        return K


## GPJax implementation

In [18]:
def initialise_gp(kernel, mean, dataset):
    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
    likelihood = gpx.likelihoods.Gaussian(
        num_datapoints=dataset.n, obs_stddev=jnp.array([1.0e-3], dtype=jnp.float64)
    )
    posterior = prior * likelihood
    return posterior

In [22]:
simulation_key = jr.PRNGKey(0)

training_data, _, _ = train_test_split(
    positions, observations, 50, 1, key=simulation_key
)


mean = gpx.mean_functions.Zero()
kernel = VectorKernel_2d()
diagonal_posterior = initialise_gp(kernel, mean, training_data)


In [25]:
def optimise_mll(posterior, dataset, NIters=1000):
    # define the MLL using dataset_train
    objective = gpx.objectives.ConjugateMLL(negative=True)
    # Optimise to minimise the MLL
    opt_posterior, _ = gpx.fit_scipy(
        model=posterior,
        objective=objective,
        train_data=dataset,
        max_iters=NIters,
    )
    return opt_posterior

In [26]:
opt_diagonal_posterior = optimise_mll(diagonal_posterior, training_data)

  torch.utils._pytree._register_pytree_node(cls, tree_flatten, tree_unflatten)


Initial loss is 42021.82397017794
Optimization was successful
Final loss is 138.90764227379722 after 17 iterations


## Comparison

Field + residuals plot

In [27]:
def rmse(predictions, truth):
    # in the paper they compute RMS per vectror component
    return jnp.sqrt(jnp.sum((predictions - truth) ** 2) / truth.shape[0])

In [28]:
def latent_distribution(opt_posterior, prediction_locations, dataset_train):
    latent = opt_posterior.predict(prediction_locations, train_data=dataset_train)
    latent_mean = latent.mean()
    latent_std = latent.stddev()
    return latent_mean, latent_std

In [31]:
diagonal_mean, diagonal_std = latent_distribution(
    opt_diagonal_posterior, dataset.X, training_data
)

In [32]:
dataset_latent_diagonal = dataset_3d(positions, diagonal_mean)

In [35]:
rmse(dataset_latent_diagonal.y, dataset.y)

Array(0.89089591, dtype=float64)

## Divergence free kernel

Short derivation of the div-free kernel

In [37]:
def small_hessian(
    kernel, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
) -> Float[Array, "1"]:
    # compute all relevant second derivatives at once
    # eg small_hessian(k)[0][1] is d2k/dx1dy2
    return jnp.array(
        jacfwd(jacrev(kernel, argnums=0), argnums=1)(X, Xp), dtype=jnp.float64
    )

@dataclass
class DivFreeKernel(gpx.kernels.AbstractKernel):
    kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1])

    def __call__(
        self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        # the third dimension switches between 00, 01 and 11 kernels

        z = jnp.array(X[2], dtype=int)
        zp = jnp.array(Xp[2], dtype=int)

        # achieve the correct value via 'switches' that are either 1 or 0
        k00_switch = ((z + 1) % 2) * ((zp + 1) % 2)
        k01_switch = ((z + 1) % 2) * zp
        k10_switch = z * ((zp + 1) % 2)
        k11_switch = z * zp

        # drop output label to reduce resource usage
        X = jnp.array(X[0:2])
        Xp = jnp.array(Xp[0:2])

        hess = small_hessian(self.kernel, X, Xp)

        K = (
            k00_switch * hess[1][1]
            - k01_switch * hess[1][0]
            - k10_switch * hess[0][1]
            + k11_switch * hess[0][0]
        )

        return K


In [38]:
kernel = DivFreeKernel()
div_free_posterior = initialise_gp(kernel, mean, training_data)

In [39]:
opt_div_free_posterior = optimise_mll(div_free_posterior, training_data)

Initial loss is 12187.049155670646
Optimization was successful
Final loss is 65.0245418085733 after 31 iterations


In [40]:
div_free_mean, div_free_std = latent_distribution(
    opt_div_free_posterior, dataset.X, training_data
)

In [41]:
dataset_latent_div_free = dataset_3d(positions, div_free_mean)

In [42]:
rmse(dataset_latent_div_free.y, dataset.y)

Array(0.47258673, dtype=float64)

In [45]:
# ensure testing data alternates between x0 and x1 components
def nlpd(mean, std, test_positions):
    test_grid = jnp.column_stack((test_positions[0], test_positions[1])).flatten()
    normal = tfp.substrates.jax.distributions.Normal(loc=mean, scale=std)
    return -jnp.sum(normal.log_prob(test_grid))


# compute nlpd for velocity and helmholtz
nlpd_diagonal = nlpd(diagonal_mean, diagonal_std, positions)
nlpd_div_free = nlpd(div_free_mean, div_free_std, positions)

print(f"NLPD for diagonal: {nlpd_diagonal:.2E} \nNLPD for div free: {nlpd_div_free:.2E}")


NLPD for diagonal: 2.17E+08 
NLPD for div free: 1.10E+07
