In [1]:
# Imports
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr
import gpjax as gpx
from gpjax.typing import Array, Float

from dataclasses import dataclass

from jax import jacfwd, jacrev

import tensorflow_probability as tfp
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


# 3D Curl-free

## Data processing

We will use the following latent function, which is the field of an electric dipole.

$$\mathbf f:\mathbb{R}^D\to\mathbb{R}^K$$

where $D=K=3$, defined by:

$$\mathbf f(\mathbf x) = \dfrac 1 {4\pi\varepsilon_0}\left(\dfrac{3\mathbf p\cdot \mathbf x}{r^5}\mathbf x - \dfrac{\mathbf p}{r^3}\right)$$

with $\varepsilon_0=1$ and $\mathbf p = (0,0,1)^\top$. The derivation of this may be found online easily, e.g. on [Wikipedia](https://en.wikipedia.org/wiki/Electric_dipole_moment#Potential_and_field_of_an_electric_dipole) or in any elementary electromagnetism textbook, for example the end of section 4.1 in _Classical Electrodynamics_ by Jackson.

Importantly, $\boldsymbol \nabla\times \mathbf f = 0$, and in this notebook, we will exploit this linear constraint to improve Gaussian Process Regression.

In [None]:
# electric dipole field
def curl_free_3d_example(x, y, z):
    """Evaluate the electric field of a point dipole at ``(x, y, z)``.

    The permittivity is taken to be 1 and the dipole moment is fixed to
    ``p = (0, 0, 1)``, so the field is
    ``(3 (p . r) r / |r|**5 - p / |r|**3) / (4 pi)``.

    Args:
        x: First Cartesian coordinate (scalar or array).
        y: Second Cartesian coordinate, same shape as ``x``.
        z: Third Cartesian coordinate, same shape as ``x``.

    Returns:
        Tuple ``(f_x, f_y, f_z)`` holding the three field components.
        The expression divides by powers of ``|r|``, so it is singular
        at the origin.
    """
    dipole = (0, 0, 1)
    coords = (x, y, z)

    radius = jnp.sqrt(x**2 + y**2 + z**2)

    # 3 (p . r) / r^5 multiplies the position, 1 / r^3 multiplies the dipole moment
    radial_factor = 3 * (dipole[0] * x + dipole[1] * y + dipole[2] * z) / radius**5
    moment_factor = 1 / radius**3

    f_x, f_y, f_z = (
        (coord * radial_factor - moment * moment_factor) / (4 * jnp.pi)
        for coord, moment in zip(coords, dipole)
    )
    return f_x, f_y, f_z

## In-depth specification

We restrict our focus to the following region: $[-2,2]\times [-2,2]\times [0.5,4.5]$
### Train points
Choose 50 train points uniformly randomly from the region
### Test points
A grid of $n_{\text{divisions}}^3$ points evenly spaced (in each dimension)

### Processing data
Initially, each datum is of the form $(\mathbf x,\mathbf y)$, where $\mathbf x,\mathbf y\in\mathbb{R}^3$.

This is modified to three measurements: $\{((\mathbf x, i), y_i)\}_{i=1}^3$. This ensures that the output is one dimensional, and so the kernel function is scalar valued. Explicitly, denoting the matrix-valued kernel function as $\mathbf K$ and the scalar-valued kernel function as $\tilde K$:

$$K_{ij}(\mathbf x,\mathbf x') = \tilde K((\mathbf x,i),(\mathbf x',j))$$

In [None]:
# number of test-grid points along each axis (the test grid has n_divisions**3 points)
n_divisions = 7

In [None]:
def label_position_3d(data, inducing=0.0):
    """Expand 3D positions into labelled inputs, one row per output component.

    Args:
        data: Array of shape ``(3, N)`` with one row per spatial coordinate.
        inducing: Value written in the last column; it distinguishes
            observations from inducing points.

    Returns:
        Array of shape ``(3 * N, 5)``. Each position appears three times in a
        row, with columns ``(x0, x1, x2, component label, inducing flag)`` and
        the component label cycling through 0, 1, 2.
    """
    n_points = len(data[0])

    # every position is repeated three times in a row, once per output component
    repeated_positions = jnp.repeat(data, repeats=3, axis=1)
    component_label = jnp.tile(jnp.array([0.0, 1.0, 2.0]), n_points)
    inducing_flag = inducing * jnp.ones((n_points * 3, 1))

    labelled_positions = jnp.vstack((repeated_positions, component_label)).T
    return jnp.concatenate((labelled_positions, inducing_flag), axis=1)

def stack_vector(data):
    """Flatten vector-valued observations into a single column.

    For ``data`` of shape ``(3, N)`` the result has shape ``(3 * N, 1)`` and is
    ordered point by point: all components of point 0, then all components of
    point 1, and so on. A 1-D input is simply turned into a column.
    """
    transposed = data.T
    return jnp.reshape(transposed, (-1, 1))

def dataset_5d(pos, obs, inducing=0.0):
    """Build a ``gpx.Dataset`` with labelled 5-column inputs and scalar outputs.

    Args:
        pos: Positions, passed to ``label_position_3d``.
        obs: Observations, passed to ``stack_vector``.
        inducing: Flag value forwarded to ``label_position_3d``.

    Returns:
        ``gpx.Dataset`` whose inputs are the labelled positions and whose
        outputs are the stacked observations.
    """
    labelled_inputs = label_position_3d(pos, inducing)
    stacked_outputs = stack_vector(obs)
    return gpx.Dataset(labelled_inputs, stacked_outputs)

In [None]:
# Regular test grid with n_divisions points per axis over
# [-2, 2] x [-2, 2] x [0.5, 4.5], flattened to one row per coordinate.
positions = jnp.reshape(
    jnp.mgrid[
        -2 : 2 : n_divisions * 1j,
        -2 : 2 : n_divisions * 1j,
        0.5 : 4.5 : n_divisions * 1j,
    ],
    (3, -1),
)

# True dipole field at every grid point, one row per field component.
observations = jnp.reshape(
    jnp.stack(curl_free_3d_example(*positions), axis=0), (3, -1)
)

# Test dataset in the labelled (input, scalar output) layout.
dataset = dataset_5d(positions, observations)

In [None]:
# fixed PRNG key so the training draw is reproducible
simulation_key = jr.PRNGKey(0)

# Training inputs are not restricted to a grid: 50 points are drawn uniformly at random.
# The draw (minval=0, maxval=4) is reshaped to one row per coordinate, shape (3, 50),
# then shifted by (-2, -2, 0.5) so the points lie in [-2, 2] x [-2, 2] x [0.5, 4.5].
train_positions = jr.uniform(key=simulation_key, minval=0.0, maxval=4.0, shape=(50, 3)).reshape(3,-1) + jnp.array([-2., -2., 0.5]).reshape(3,1)
# dipole field at the training inputs, one row per field component
train_observations = jnp.stack(
    curl_free_3d_example(train_positions[0], train_positions[1], train_positions[2]), axis=0
).reshape(3, -1)
# training dataset in the labelled (input, scalar output) layout
training_data = dataset_5d(train_positions, train_observations)

## Diagonal kernel
The diagonal kernel is where 
$$K_{ij}(\mathbf x,\mathbf x') = \tilde K((\mathbf x,i),(\mathbf x',j)) = \delta_{ij}k(\mathbf x,\mathbf x')$$

for some $k(\mathbf x,\mathbf x')$.

Therefore, the outputs of the predicted function are independent. This is equivalent to totally separating the dataset into three different datasets: $\mathcal D_i:=\{(\mathbf x_n, y^{(i)}_n)\}_{n=1}^{N}$ for $i=1,2,3$ and performing a Gaussian Process Regression on each dataset separately.

This does not use any prior information about the curl-free nature of the underlying latent function, so it is expected to perform worse than other methods.

In [None]:
@dataclass
class VectorKernel_3d(gpx.kernels.AbstractKernel):
    """Diagonal multi-output kernel: outputs with different labels are uncorrelated.

    Inputs carry the three spatial coordinates in entries 0-2 and the output
    component label in entry 3.
    """

    # scalar kernel applied to the spatial coordinates of both inputs
    kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1, 2])

    def __call__(
        self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        """Return ``kernel(x, x')`` when both inputs share an output label, else 0."""
        label = jnp.array(X[3], dtype=int)
        label_p = jnp.array(Xp[3], dtype=int)
        same_output = label == label_p

        # only the spatial coordinates are handed to the scalar kernel
        spatial = jnp.array(X[0:3])
        spatial_p = jnp.array(Xp[0:3])

        return same_output * self.kernel(spatial, spatial_p)

## GPJax implementation

In [None]:
def initialise_gp(kernel, mean, dataset):
    """Build a GP posterior from a kernel, a mean function and a dataset.

    Args:
        kernel: Kernel used for the GP prior.
        mean: Mean function used for the GP prior.
        dataset: Dataset whose size ``dataset.n`` is given to the likelihood.

    Returns:
        The product of the prior and a Gaussian likelihood whose observation
        standard deviation is initialised to 1e-3.
    """
    initial_noise = jnp.array([1.0e-3], dtype=jnp.float64)
    likelihood = gpx.likelihoods.Gaussian(
        num_datapoints=dataset.n, obs_stddev=initial_noise
    )
    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
    return prior * likelihood

In [None]:
# zero prior mean; the same `mean` object is reused for the curl-free model below
mean = gpx.mean_functions.Zero()
# diagonal kernel: output components are modelled independently
kernel = VectorKernel_3d()
diagonal_posterior = initialise_gp(kernel, mean, training_data)

In [None]:
def optimise_mll(posterior, dataset, NIters=1000):
    """Fit the posterior by minimising the negative conjugate marginal log-likelihood.

    Args:
        posterior: Model handed to ``gpx.fit_scipy``.
        dataset: Training data the objective is evaluated on.
        NIters: Maximum number of optimiser iterations.

    Returns:
        The optimised posterior; the second value returned by
        ``gpx.fit_scipy`` is discarded.
    """
    # negative=True turns the MLL into a loss, so minimising it maximises the MLL
    negative_mll = gpx.objectives.ConjugateMLL(negative=True)
    fitted_posterior, _unused = gpx.fit_scipy(
        model=posterior,
        objective=negative_mll,
        train_data=dataset,
        max_iters=NIters,
    )
    return fitted_posterior

In [None]:
# fit the diagonal model on the training data
opt_diagonal_posterior = optimise_mll(diagonal_posterior, training_data)

## Comparison

The models are evaluated by calculating the RMSE between the predicted and true outputs. In particular,

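$$\text{RMSE} = \sqrt{\frac{1}{N_P D}\sum_{i=1}^{N_P} \left\|\mathbf y_p^{(i)}-\mathbf y_t^{(i)}\right\|_2^2}$$

where $\mathbf y_p^{(i)}$ is the predicted value and $\mathbf y_t^{(i)}$ is the true value at the $i$-th of the $N_P$ test points, and $D=3$ is the output dimension.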

In [None]:
def rmse(predictions, truth):
    """Root-mean-square error between predictions and true values.

    The squared errors are summed over every entry and divided by the number
    of rows of ``truth`` (``truth.shape[0]``). For the ``(N, 1)`` columns used
    in this notebook that is the mean over all scalar outputs.
    Note: in the paper the RMS is computed per vector component.

    Args:
        predictions: Predicted values.
        truth: True values; must have at least one axis.

    Returns:
        Scalar RMSE.
    """
    predictions = jnp.asarray(predictions)
    truth = jnp.asarray(truth)

    # A flat (N,) vector against an (N, 1) column would silently broadcast to an
    # (N, N) matrix of differences. When the shapes differ only by singleton
    # axes, align the predictions with the truth before subtracting.
    if (
        predictions.shape != truth.shape
        and jnp.squeeze(predictions).shape == jnp.squeeze(truth).shape
    ):
        predictions = jnp.reshape(predictions, truth.shape)

    return jnp.sqrt(jnp.sum((predictions - truth) ** 2) / truth.shape[0])

In [None]:
def latent_distribution(opt_posterior, prediction_locations, dataset_train):
    """Mean and standard deviation of the posterior at the given locations.

    Args:
        opt_posterior: Object exposing ``predict(locations, train_data=...)``.
        prediction_locations: Inputs at which to predict.
        dataset_train: Training data forwarded as ``train_data``.

    Returns:
        Tuple ``(mean, stddev)`` taken from the predicted distribution.
    """
    predictive = opt_posterior.predict(prediction_locations, train_data=dataset_train)
    return predictive.mean(), predictive.stddev()

In [None]:
# predictive mean and standard deviation of the diagonal model at the test-grid inputs
diagonal_mean, diagonal_std = latent_distribution(
    opt_diagonal_posterior, dataset.X, training_data
)

In [None]:
# wrap the predicted mean in a Dataset so its outputs form a column like dataset.y
dataset_latent_diagonal = dataset_5d(positions, diagonal_mean)

In [None]:
# RMSE of the diagonal model on the test grid
rmse(dataset_latent_diagonal.y, dataset.y)

## Curl free kernel

### Derivation of curl free kernel
The latent function $\mathbf f$ was chosen such that it is curl-free, in other words:
$$\boldsymbol \nabla\times \mathbf f := \begin{pmatrix}\dfrac{\partial f_3}{\partial x_2} - \dfrac{\partial f_2}{\partial x_3}\\
\dfrac{\partial  f_1}{\partial x_3} - \dfrac{\partial  f_3}{\partial x_1}\\
\dfrac{\partial  f_2}{\partial x_1}-\dfrac{\partial  f_1}{\partial x_2}\end{pmatrix}=\mathbf 0$$

Note that*, given any twice-differentiable scalar function $g$, the function:
$$\mathbf f(\mathbf x):= \mathscr G_\mathbf x g := \nabla _{\mathbf x} g$$
automatically satisfies the required constraint.

As is the case with multivariate Gaussians, linear transformations of GPs are GPs (and they transform in much the same way).
In particular:
$$g \sim \mathcal{GP}(0, k_g) \implies \mathscr G_\mathbf x g \sim \mathcal{GP}(\mathbf 0, \mathscr G_\mathbf x k_g \mathscr G_{\mathbf x'}^\top)$$

In our case, we choose $k_g$ to be the squared exponential kernel. Therefore, any $\mathbf f$ picked from the distribution $\mathbf f \sim \mathcal {GP}(\mathbf 0, \mathscr G_\mathbf x  k_g(\mathbf x, \mathbf x') \mathscr G_{\mathbf x'}^\top)$, where
$$\mathscr G_\mathbf x  k_g(\mathbf x, \mathbf x') \mathscr G_{\mathbf x'}^\top= \begin{pmatrix}\dfrac{\partial^2}{\partial x_1 \partial x_1'} & \dfrac{\partial^2}{\partial x_1 \partial x_2'} & \dfrac{\partial^2}{\partial x_1 \partial x_3'}\\
\dfrac{\partial^2}{\partial x_2 \partial x_1'} & \dfrac{\partial^2}{\partial x_2 \partial x_2'} & \dfrac{\partial^2}{\partial x_2 \partial x_3'}\\
\dfrac{\partial^2}{\partial x_3 \partial x_1'} & \dfrac{\partial^2}{\partial x_3 \partial x_2'} & \dfrac{\partial^2}{\partial x_3 \partial x_3'}
\end{pmatrix}k_g(\mathbf x, \mathbf x')$$
will satisfy the required constraint. This is the curl-free kernel.

*Much of the paper is dedicated to devising a systematic way to construct $\mathscr{G}_\mathbf x$ for arbitrary linear constraints.

In [None]:
def small_hessian(
    kernel, X: Float[Array, " D"], Xp: Float[Array, " D"]
) -> Float[Array, "D D"]:
    """Matrix of mixed second derivatives of ``kernel`` with respect to both inputs.

    Args:
        kernel: Callable ``kernel(X, Xp)``; presumably scalar-valued, as the
            kernels used in this notebook are.
        X: First input point.
        Xp: Second input point.

    Returns:
        Array ``H`` with ``H[i][j] = d^2 kernel / dX[i] dXp[j]``, cast to float64.
    """
    # Differentiate with respect to X first (reverse mode, argnums=0), then take
    # the Jacobian of that gradient with respect to Xp (forward mode, argnums=1),
    # so all mixed second derivatives are obtained in a single call.
    return jnp.array(
        jacfwd(jacrev(kernel, argnums=0), argnums=1)(X, Xp), dtype=jnp.float64
    )

@dataclass
class CurlFreeKernel(gpx.kernels.AbstractKernel):
    """Curl-free multi-output kernel built from mixed second derivatives of a scalar kernel.

    Inputs carry the three spatial coordinates in entries 0-2 and the output
    component label in entry 3.
    """

    # scalar potential kernel that is differentiated twice
    kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1, 2])

    def __call__(
        self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        """Return the entry of the mixed-derivative matrix selected by the two output labels."""
        row = jnp.array(X[3], dtype=int)
        col = jnp.array(Xp[3], dtype=int)

        # only the spatial coordinates are differentiated
        spatial = jnp.array(X[0:3])
        spatial_p = jnp.array(Xp[0:3])

        mixed_derivatives = small_hessian(self.kernel, spatial, spatial_p)
        return mixed_derivatives[row, col]

In [None]:
# curl-free kernel; note this rebinds the name `kernel` used for the diagonal model above
kernel = CurlFreeKernel()
curl_free_posterior = initialise_gp(kernel, mean, training_data)

In [None]:
# fit the curl-free model on the same training data
opt_curl_free_posterior = optimise_mll(curl_free_posterior, training_data)

In [None]:
# predictive mean and standard deviation of the curl-free model at the test-grid inputs
curl_free_mean, curl_free_std = latent_distribution(
    opt_curl_free_posterior, dataset.X, training_data
)

In [None]:
# wrap the predicted mean in a Dataset so its outputs form a column like dataset.y
dataset_latent_curl_free = dataset_5d(positions, curl_free_mean)

In [None]:
# RMSE of the curl-free model on the test grid
rmse(dataset_latent_curl_free.y, dataset.y)

## NLPD (Negative Log Predictive Density)
An alternative to RMSE for measuring how well the predicted model matches the true values. Formally, it is the negative log-likelihood of the true values under the model's predictive distribution, so lower values are better. It is calculated using the following formula:

$$\text{NLPD} = -\sum_{i=1}^{N_PD} \log p(\mathbf f(\mathbf x_i)|\mathbf x_i)$$

Here $p(\mathbf y|\mathbf x)$ is a Gaussian distribution with mean given by the posterior mean and standard deviation given by the posterior standard deviation.

In [None]:
# the true values are flattened point by point (components 0, 1, 2 of the first
# point, then of the second, ...), matching the row order produced by label_position_3d
def nlpd(mean, std, true_observations):
    """Negative log predictive density of the true observations.

    Args:
        mean: Predictive mean for every scalar output.
        std: Predictive standard deviation for every scalar output.
        true_observations: True values; rows 0, 1 and 2 hold the three components.

    Returns:
        Negated sum of the log-densities of the true values under independent
        normal distributions with the given means and standard deviations.
    """
    components = tuple(true_observations[i] for i in range(3))
    targets = jnp.column_stack(components).flatten()
    predictive = tfp.substrates.jax.distributions.Normal(loc=mean, scale=std)
    log_density = predictive.log_prob(targets)
    return -jnp.sum(log_density)


# compute NLPD for the diagonal and curl-free models on the test grid
nlpd_diagonal = nlpd(diagonal_mean, diagonal_std, observations)
nlpd_curl_free = nlpd(curl_free_mean, curl_free_std, observations)

print(f"NLPD for diagonal: {nlpd_diagonal:.2E} \nNLPD for curl free: {nlpd_curl_free:.2E}")
