# GPJax Review

## Gaussian Process Refresher

### Maths recap

Consider data $(X, f, y)=\{x_i, f_i, y_i\}^N_{i=1}$ where $x_i\in\mathbb{R}^d$, and $y_i \in \mathbb{R}$ is a stochastic observation depending on $f_i=f(x_i)$ for some latent function $f$. Let $k_{\theta}(x, x')$ be a positive definite kernel function parameterised by a set of hyperparameters $\theta$ with resultant Gram matrix $K_{xx}=k_{\theta}(x, x')$. Following standard practice in the literature and assuming a zero mean function, we can posit the hierarchical GP framework as 
$$p(y \lvert f , \theta) = \prod_{i=1}^N p(y_i \lvert f_i, \theta), \qquad f\lvert X, \theta \sim \mathcal{N}(0, K_{xx}), \qquad \theta \sim p_0.$$

From the above generative model, we can see that the posterior distribution of a GP is
$$p(f, \theta \lvert y) = \frac{1}{C}p(y \lvert f,\theta) p(f \lvert \theta)p_0(\theta)$$
where $C$ denotes the unknown normalisation constant of the posterior. Often we are interested in using the posterior to make new function predictions $f^{\star}$ for test data $X^{\star}$,
$$p\left(f^{\star} | y\right)=\iint p\left(f^{\star} | f, \theta\right) p(f, \theta | y) \,\mathrm{d} \theta \,\mathrm{d} f$$

When the likelihood $p(y_i\lvert f_i,\theta)$ is Gaussian, the posterior predictive distribution conditional on $\theta$ is analytically available, as we can marginalise $f$ out of the predictive distribution. Inference methods then focus on $\theta$.
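For reference, the Gaussian case can be written in a few lines of Jax. This is a minimal sketch and not GPJax's implementation: it assumes a squared-exponential kernel and i.i.d. Gaussian noise with variance `obs_noise`, and the names `rbf` and `predict` are purely illustrative.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf(x, z, lengthscale=1.0, variance=1.0):
    # Squared-exponential kernel between the rows of x (n, d) and z (m, d)
    sq_dist = jnp.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def predict(x, y, x_star, obs_noise=0.1):
    # Mean and covariance of f* | y, theta under a zero-mean GP prior
    Kxx = rbf(x, x) + obs_noise * jnp.eye(x.shape[0])
    Ksx = rbf(x_star, x)
    Kss = rbf(x_star, x_star)
    chol = cho_factor(Kxx, lower=True)
    mean = Ksx @ cho_solve(chol, y)
    cov = Kss - Ksx @ cho_solve(chol, Ksx.T)
    return mean, cov


x = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y = jnp.sin(3.0 * x).squeeze()
x_star = jnp.linspace(-1.0, 1.0, 3).reshape(-1, 1)
mean, cov = predict(x, y, x_star)
```

Everything here is conditional on fixed hyperparameters; handling $\theta$ is the part left to the inference method.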

### [Demo](intro_to_gps.jl) (note: this is a [Pluto](https://github.com/fonsp/Pluto.jl) notebook)

## GPJax overview

### Aims of GPJax

* Provide a GP API that represents the underlying maths
* Provide a low-level API for research purposes
* Give me a good excuse to play around in Jax

### Principles

* _Functional_ in design
    * Multiple dispatch helps with this
    * There appears to be some synergy between Jax and multiple dispatching
        * It's not perfect though...
* Try not to hide any GP trickery
* Fancy way to make Gaussian random variables


### [Documentation source](https://gpjax.readthedocs.io/en/latest/index.html)

## The good

### Abstraction choice

A Jax API with multiple dispatch makes choices around how much of the API to expose really simple.

Almost everything is a function, so a more heavily abstracted API can be made by simply piping one function through another. This is achieved without having to remember multiple function names.

For example, in GPJax, the posterior GP is represented by a Gaussian random variable. For people just wanting to fit a 

### Automatic differentiation

Working with Jax's automatic differentiation means gradient and optimisation steps can be written out explicitly.

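```python
import jax
from jax import jit
from jax.example_libraries import optimizers  # jax.experimental.optimizers in older Jax releases

# Negative marginal log-likelihood of the posterior, compiled once
mll = jit(marginal_ll(posterior, negative=True))

opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
opt_state = opt_init(params)

def step(i, opt_state):
    p = get_params(opt_state)
    # Objective value and its gradient w.r.t. the parameters
    value, g = jax.value_and_grad(mll)(p, x, y)
    return opt_update(i, g, opt_state), value

for i in range(100):
    opt_state, mll_estimate = step(i, opt_state)
```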

Given the aim of GPJax is to simply provide the building blocks for building GPs, this means we don't have to wrap the optimisation up in a mysterious `posterior.fit()` style method.

### Wider ecosystem

* GPJax heavily uses Chex and TensorFlow Probability
* Both packages are almost seamless to compose with Jax

#### Chex

Provides a _struct-like_ backbone to the code e.g., [GPs](https://github.com/thomaspinder/GPJax/blob/master/gpjax/gps.py) and [Kernels](https://github.com/thomaspinder/GPJax/blob/master/gpjax/kernels/base.py)

#### TensorFlow Probability

Provides a clean way to state distributions e.g., [Priors](https://gpjax.readthedocs.io/en/latest/nbs/tfp_interface.html#State-priors) and also to return them e.g., [GP random variables](https://gpjax.readthedocs.io/en/latest/nbs/regression.html#Realising-the-random-variable).

### Testing

Jax, Chex and PyTest together make writing unit tests really easy. 

Chex has some handy Jax-specific unit test functions. Straight out of the Chex README:
```python
assert_tree_all_close(tree_x, tree_y)  # values and structure of trees match
assert_tree_all_finite(tree_x)         # all tree_x leaves are finite
```

Similar to PyTest parameterisations, Chex has some decorator functions:
```python
@chex.variants(with_jit=True, without_jit=True)
```

## The bad

Disclaimer: none of these packages/approaches is inherently bad; they just were not a great fit for GPJax's usage!

### ObJax

* V0.1-0.2 of GPJax used ObJax to enable a more modular code base.
* Unlike in Jax, quantities are stateful in ObJax.

* Unfortunately, anything new in Jax/Jax's ecosystem takes a while to propagate into ObJax e.g., [Optax](https://github.com/deepmind/optax) and [Elegy](https://github.com/poets-ai/elegy).
* Gradients, particularly second-order, are tricky/messy to compute in ObJax as they're w.r.t. the object's `.vars()`.
    * This was particularly problematic for implementing HMC and Laplace posterior approximations
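By contrast, in plain Jax the parameters are just a pytree passed as a function argument, so first- and second-order derivatives compose directly. A minimal sketch with a toy objective standing in for a log-density (the function and parameter names are illustrative and not taken from GPJax):

```python
import jax
import jax.numpy as jnp


def objective(params):
    # Toy scalar objective of a parameter dictionary
    a, b = params["a"], params["b"]
    return jnp.sum(a**3) + a[0] * b[0]


params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}

# Gradient has the same dictionary structure as params
grads = jax.grad(objective)(params)

# Hessian is a nested dictionary: hess["a"]["b"] holds the mixed second derivatives
hess = jax.hessian(objective)(params)
```

The nested-dictionary Hessian is the kind of quantity a Laplace approximation would need.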

### Modularity

Currently the only modularity in GPJax is provided through Chex's `dataclass` objects. The entire conjugate posterior is
```python
@dataclass
class ConjugatePosterior:
    prior: Prior
    likelihood: Gaussian
    name: Optional[str] = "ConjugatePosterior"

    def __repr__(self):
        meanf_string = self.prior.mean_function.__repr__()
        kernel_string = self.prior.kernel.__repr__()
        likelihood_string = self.likelihood.__repr__()
        return f"Conjugate Posterior\n{'-'*80}\n- {meanf_string}\n- {kernel_string}\n- {likelihood_string}"
```

I'm unsure how annoying/undesirable this is to all users(?).

### Parameter transforms

Different parameters have different constraints, if any at all.

The way this is currently handled is for the user to define some parameters and a corresponding transform e.g.,
```python
params = {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0]), "obs_noise": jnp.array([1.0])}
params = transform(params, SoftplusTransformation)
# Do your optimisation...
final_params = untransform(params, SoftplusTransformation)
```

It's fragile and inelegant though e.g., [ignoring unconstrained parameters](https://github.com/thomaspinder/GPJax/blob/master/gpjax/parameters/transforms.py#L8).

It'd be nice to have a parameter-transformation lookup, but it's hard to see how to make this work in Jax without defining custom gradients.
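As a rough illustration of what such a lookup could look like, the sketch below maps parameter names to a pair of constrain/unconstrain functions and applies the constraining map inside the objective, so that `jax.grad` differentiates through it. The `TRANSFORMS` table and the helper names are hypothetical, and whether this shape would fit GPJax's internals is not something the sketch settles.

```python
import jax
import jax.numpy as jnp


def inverse_softplus(y):
    # Inverse of softplus(x) = log(1 + exp(x)), valid for y > 0
    return jnp.log(jnp.expm1(y))


def identity(x):
    return x


# Hypothetical lookup: parameter name -> (constrain, unconstrain)
TRANSFORMS = {
    "lengthscale": (jax.nn.softplus, inverse_softplus),
    "variance": (jax.nn.softplus, inverse_softplus),
}


def constrain(raw_params):
    # Parameters without an entry are left untouched
    return {k: TRANSFORMS.get(k, (identity, identity))[0](v) for k, v in raw_params.items()}


def unconstrain(params):
    return {k: TRANSFORMS.get(k, (identity, identity))[1](v) for k, v in params.items()}


def loss(raw_params):
    # Constrain inside the objective so gradients flow through the transform
    p = constrain(raw_params)
    return jnp.sum((p["lengthscale"] - 2.0) ** 2) + jnp.sum(p["mean"] ** 2)


params = {"lengthscale": jnp.array([1.0]), "mean": jnp.array([0.5])}
raw = unconstrain(params)      # optimise in this space
grads = jax.grad(loss)(raw)    # gradients w.r.t. the unconstrained values
round_trip = constrain(raw)    # maps back to the constrained values
```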

## The ugly

### Parameter handling

#### Trade off 

1. __Dictionaries__ give a clean and safe interface. 
2. Less value coercion is needed with __arrays__.

#### Current approach

All parameters, regardless of whether they're fixed or trainable, are dictionaries.  

This does make the underlying code more readable e.g. for scaling an input by a lengthscale parameter

```python
def scale(x: Array, params: dict):
    return x/params['lengthscale']
```
vs.
```python
def scale(x: Array, params: Array):
    return x/params[0]
```


Coercion is a pain though... 

For example, to integrate with TFP's HMC sampler, the parameters must be an array:

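```python
from functools import partial
from typing import List

import jax.numpy as jnp


def array_to_dict(varray: jnp.DeviceArray, keys: List):
    # Pair each array entry with its parameter name, in order
    pdict = {}
    for val, key in zip(varray, keys):
        pdict[key] = val
    return pdict


def build_log_pi(params, target_fn):
    param_keys = list(params.keys())

    def target(params: jnp.DeviceArray):
        # The sampler passes an array; the target function expects a dictionary
        coerced_params = array_to_dict(params, param_keys)
        return target_fn(coerced_params)

    return target


# `posterior`, `params`, `priors`, `x` and `y` are assumed to be defined already
mll = marginal_ll(posterior, negative=False)
target_dist = partial(mll, x=x, y=y, priors=priors)
log_pi = build_log_pi(params, target_dist)
```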
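Jax ships a utility that performs this kind of coercion for arbitrary pytrees: `jax.flatten_util.ravel_pytree` returns a flat array together with a function that rebuilds the original structure. The sketch below shows how it could stand in for a hand-written `array_to_dict`; `target_fn` is a hypothetical placeholder for the marginal log-likelihood, and this is not how GPJax currently does it.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {
    "lengthscale": jnp.array([1.0]),
    "variance": jnp.array([2.0]),
    "obs_noise": jnp.array([0.5]),
}

# flat is a 1-D array; unravel maps such an array back to the dictionary
flat, unravel = ravel_pytree(params)


def target_fn(p):
    # Hypothetical dictionary-consuming log-density
    return -(jnp.sum(p["lengthscale"] ** 2) + jnp.sum(p["variance"] ** 2) + jnp.sum(p["obs_noise"] ** 2))


def log_pi(flat_params):
    # Array-in, scalar-out: the form an array-based sampler expects
    return target_fn(unravel(flat_params))
```

Note that the entries of `flat` follow the sorted key order (`lengthscale`, `obs_noise`, `variance`), not the insertion order.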


Jax also sorts dictionaries whenever it operates on them...

```python
import jax.numpy as jnp
from jax import grad, jit

parameters = {"b": jnp.array(1.0), "a": jnp.array(2.0)}

def f(params: dict) -> jnp.DeviceArray:
    # Note: "b" is read from the global `parameters`, hence its zero gradient below
    return parameters["b"] * jnp.square(params["a"])

print(grad(f)(parameters))
```

```
{'a': DeviceArray(4., dtype=float32), 'b': array(0., dtype=float32)}
```
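The ordering comes from how Jax flattens dictionaries into pytree leaves: keys are visited in sorted order rather than insertion order. A small sketch that makes this visible without involving gradients:

```python
import jax

parameters = {"b": 1.0, "a": 2.0}

# Leaves come out ordered by key, so "a" precedes "b"
leaves, treedef = jax.tree_util.tree_flatten(parameters)

# Rebuilding from the tree definition gives back an equal dictionary
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Here `leaves` is `[2.0, 1.0]`, which matters as soon as parameters are matched to array positions by hand.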



Although safer in a dictionary, parameters stored in this way are also more fragile. For example, this would break GPJax
```python
params = {
    "lenghtscale": jnp.array([1.0]),
    "variance": jnp.array([1.0]),
    "obs_noise": jnp.array([1.0]),
}
```
and it's not immediately obvious why...

GPJax does provide some initialisers to try and mitigate this imperfection - [example](https://gpjax.readthedocs.io/en/latest/nbs/regression.html#Stating-parameters)
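Another way to soften this failure mode is to check the keys up front and name the offending ones. The following is a hypothetical helper, not part of GPJax, and the set of required keys is assumed purely for illustration:

```python
REQUIRED_KEYS = {"lengthscale", "variance", "obs_noise"}  # assumed for this sketch


def validate_keys(params: dict, required: set = REQUIRED_KEYS) -> None:
    # Fail early and name the bad keys, rather than erroring deep inside a kernel
    missing = required - set(params)
    unexpected = set(params) - required
    if missing or unexpected:
        raise KeyError(
            f"missing keys: {sorted(missing)}; unexpected keys: {sorted(unexpected)}"
        )


params = {"lenghtscale": [1.0], "variance": [1.0], "obs_noise": [1.0]}
try:
    validate_keys(params)
    message = ""
except KeyError as err:
    message = str(err)
```

With the misspelt dictionary, `message` reports `lengthscale` as missing and `lenghtscale` as unexpected.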

### Multiple dispatch isn't perfect (yet)

Dispatch types must be explicitly stated e.g., 
```python
@dispatch(jnp.DeviceArray, SpectralKernel, int, int)
def sample_frequencies(
    key, kernel: SpectralKernel, n_frequencies: int, input_dimension: int
) -> jnp.DeviceArray:
    density = spectral_density(kernel)
    return density.sample(sample_shape=(n_frequencies, input_dimension), seed=key)
```

It would be cleaner if the dispatch decorator was just
```python
@dispatch
def sample_frequencies(
    key, kernel: SpectralKernel, n_frequencies: int, input_dimension: int
) -> jnp.DeviceArray:
```
and types inferred from the function's typing.
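For comparison, the standard library's `functools.singledispatch` already infers the dispatch type from annotations, although only for the first argument, so it is not a substitute for full multiple dispatch. A minimal sketch with hypothetical kernel classes (not GPJax's):

```python
from dataclasses import dataclass
from functools import singledispatch


@dataclass
class RBF:
    lengthscale: float = 1.0


@dataclass
class Matern12:
    lengthscale: float = 1.0


@singledispatch
def describe(kernel) -> str:
    # Fallback when no implementation is registered for the argument's type
    raise NotImplementedError(f"No implementation for {type(kernel).__name__}")


@describe.register
def _(kernel: RBF) -> str:
    # The dispatch type is read from the annotation on the first argument
    return f"RBF kernel, lengthscale={kernel.lengthscale}"


@describe.register
def _(kernel: Matern12) -> str:
    return f"Matern-1/2 kernel, lengthscale={kernel.lengthscale}"
```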


Arrays are not the easiest type to dispatch on in Jax.

Example:

```python
def f(x: jnp.DeviceArray):
    print(type(x))
    return jnp.square(x)

x = jnp.array(1.)
y = f(x)
```

```
<class 'jax.interpreters.xla._DeviceArray'>
```

```python
dydx = grad(f)(x)
```

```
<class 'jax.interpreters.ad.JVPTracer'>
```

```python
jf = jit(f)
jit_dydx = jf(x)
```

```
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
```
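One workaround is to avoid matching on the concrete class and use an `isinstance` check against `jnp.ndarray`, which holds for concrete arrays and tracers alike. This is a sketch of the check only; whether a given dispatch library can be made to use it is a separate question.

```python
import jax
import jax.numpy as jnp

seen = []


def f(x):
    # Record whether x passes the array check in each execution context
    seen.append(isinstance(x, jnp.ndarray))
    return jnp.square(x)


x = jnp.array(1.0)
f(x)            # called directly on a concrete array
jax.grad(f)(x)  # called on a tracer during differentiation
jax.jit(f)(x)   # called on a tracer during compilation
```

All three entries of `seen` are `True`, even though `type(x)` differs in each case.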


## Miscellaneous

### Jitted functions

* Currently, no function in GPJax is jitted. 

* Given everything is returned as a function, I've left the decision about what needs jitting in the user's hands.
    * Some guidance is given in notebooks.
    
* Thought-process here is that having lots of jitted code will make it tricky for users to debug their own code
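In practice this means the user wraps whatever a library hands back. A minimal sketch with a hypothetical function factory (`make_scaler` is illustrative and not a GPJax function):

```python
import jax
import jax.numpy as jnp


def make_scaler(lengthscale):
    # Library side: return a plain, un-jitted function
    def scale(x):
        return x / lengthscale

    return scale


scale = make_scaler(2.0)

# User side: compile only once the function is known to behave as expected
fast_scale = jax.jit(scale)

x = jnp.ones(3)
eager_out = scale(x)
jitted_out = fast_scale(x)
```

The un-jitted `scale` can be stepped through with ordinary Python tooling, while `fast_scale` produces the same values once compiled.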

### Other packages/frameworks used

* [PyTest](https://docs.pytest.org/en/stable/) and [CodeCov](https://app.codecov.io/gh/thomaspinder/gpjax): Both fantastic. CodeCov can be a little finicky, but it's really helpful so worth the odd frustration
* [Sphinx](https://www.sphinx-doc.org/en/master/): Pain to set up, but just works (I hope!) afterwards
* [ReadTheDocs](https://readthedocs.org/projects/gpjax/): Builds and hosts documentation in one click
* [Black](https://black.readthedocs.io/en/stable/) and [iSort](https://github.com/PyCQA/isort): Formats code. Provided in a PR by someone else - seems to work fine though and salvages my messy code
* [BumpVersion](https://pypi.org/project/bumpversion/) and [Twine](https://pypi.org/project/twine/): For semantic versioning and PyPI submission

### Next steps

* Fully integrate the spectral kernel approximation i.e., [Sparse spectrum Gaussian processes](https://quinonero.net/Publications/lazaro-gredilla10a.pdf)
* Now ObJax has been abstracted out, provide scope for Laplace approximations to the GP's latent variables when the likelihood is non-Gaussian
* Provide an interface to NumPyro (work in progress)
* Integrate work on graph kernels and GP-based dimensionality reduction

![](graphs.png)

## Summary

* Writing APIs in Jax is a net-positive experience
    * Some mental gymnastics are involved with code structuring
    * A more modular framework might (?) be more intuitive
    * It makes it easy to decide how much abstraction to provide
* Jax lends itself quite nicely to multiple dispatch
* The wider Jax ecosystem is already quite mature for a new-ish package