# PyTrees

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/daniel-dodd/gpjax_workshop/blob/main/pytrees.ipynb)

In [None]:
!pip install gpjax==0.7.0

`GPJax` **represents all objects as JAX
[_PyTrees_](https://jax.readthedocs.io/en/latest/pytrees.html)**, giving

- A simple API with a **TensorFlow / PyTorch feel** ...
- ... whilst remaining **fully compatible** with JAX's functional paradigm ...
- ... and **working out of the box** (no filtering) with JAX's transformations
  such as `grad`.

We achieve this by providing a base `Module` abstraction that cleanly
handles parameter trainability and optimising transformations of JAX models.


### The RBF kernel


Intuitively, the kernel defines the notion of *similarity* between
the value of the function at two points, $f(\mathbf{x})$ and $f(\mathbf{x}')$, and
will be denoted as $k(\mathbf{x}, \mathbf{x}')$:

$$\begin{aligned} k(\mathbf{x}, \mathbf{x}') &= \text{Cov}[f(\mathbf{x}),
f(\mathbf{x}')] \end{aligned}$$

For the RBF kernel, we have

$$
k(x, y) = \sigma^2\exp\left(-\frac{\lVert
x-y\rVert_{2}^2}{2\ell^2} \right)
$$

- $\sigma^2\in\mathbb{R}_{>0}$ is a
variance parameter
- $\ell\in\mathbb{R}_{>0}$ is a lengthscale parameter.

Terming the evaluation of $`k(x, y)`$ the _covariance_, we can represent
this object as a Python `dataclass` as follows:

In [2]:
import jax
import jax.numpy as jnp
from dataclasses import dataclass, field


@dataclass
class RBF:
    lengthscale: float = field(default=1.0)
    variance: float = field(default=1.0)

    def __call__(self, x: float, y: float) -> jax.Array:
        return self.variance * jnp.exp(-0.5 * ((x - y) / self.lengthscale) ** 2)

Here, the Python `dataclass` is a class that simplifies the process of
creating classes that primarily store data. It reduces boilerplate code and
provides convenient methods for initialising and representing the data. An
equivalent class could be written as:

In [3]:
class RBF:
    def __init__(self, lengthscale: float = 1.0, variance: float = 1.0) -> None:
        self.lengthscale = lengthscale
        self.variance = variance

    def __call__(self, x: float, y: float) -> jax.Array:
        return self.variance * jnp.exp(-0.5 * ((x-y) / self.lengthscale)**2)

To establish some terminology, within the above RBF `dataclass`, we refer to
the lengthscale and variance as _fields_. Further, the `RBF.__call__` is a
_method_. So far so good. However, if we wanted to take the gradient of
the kernel with respect to its parameters $`\nabla_{\ell, \sigma^2} k(1.0, 2.0;
\ell, \sigma^2)`$ at inputs $`x=1.0`$ and $`y=2.0`$, then we encounter a problem:

In [4]:
kernel = RBF()

try:
    jax.grad(lambda kern: kern(1.0, 2.0))(kernel)
except TypeError as e:
    print(e)

Argument '<__main__.RBF object at 0x293b51b40>' of type <class '__main__.RBF'> is not a valid JAX type.


This issue arises because the object we have defined is not yet
compatible with JAX. To make it compatible, we must consider [JAX's _PyTree_](https://jax.readthedocs.io/en/latest/pytrees.html)
abstraction.

### PyTrees

PyTrees are JAX's abstraction for working with nested data structures in a
way that is efficient, flexible, and easy to use. A PyTree is a data structure
that is composed of other data structures, and it can be thought of as a tree
where each 'node' is either a leaf (e.g., an array or a scalar) or another
PyTree. By default, the 'node' types that are regarded as PyTrees include
Python lists, tuples, and dicts.

In [5]:
import jax.tree_util as jtu

#### Example 1:

In [6]:
tree = [3.14, {"Monte": object(), "Carlo": False}]
print(tree)

[3.14, {'Monte': <object object at 0x293d95270>, 'Carlo': False}]


is a PyTree with structure

In [7]:
print(jtu.tree_structure(tree))

PyTreeDef([*, {'Carlo': *, 'Monte': *}])


with the following leaves

In [8]:
print(jtu.tree_leaves(tree))

[3.14, False, <object object at 0x293d95270>]


#### Example 2:

In [9]:
tree = (
    jnp.array([1.0, 2.0, 3.0]),
    jnp.array([4.0, 5.0, 6.0]),
    jnp.array([7.0, 8.0, 9.0]),
)

You can perform various operations on a PyTree like this one, such as
applying a function to each of its leaves.

For example, suppose you want to square each element of the arrays. You can
do this using the `tree_map` function from the `jax.tree_util` module:

In [10]:
print(jtu.tree_map(lambda x: x**2, tree))

(Array([1., 4., 9.], dtype=float32), Array([16., 25., 36.], dtype=float32), Array([49., 64., 81.], dtype=float32))


In this example, the PyTree makes it easy to apply a function to each leaf of
a complex data structure, without having to manually traverse the data
structure and handle each leaf individually. JAX PyTrees, therefore, are a
powerful tool that can simplify many tasks in machine learning and scientific
computing. As such, most JAX functions operate over _PyTrees of JAX arrays_.
For instance, `jax.lax.scan` accepts as input and produces as output a
PyTree of JAX arrays.
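As a minimal sketch of that last point, the snippet below runs `jax.lax.scan` with a dict as the carried state. The `step` function and the `total` / `count` keys are hypothetical names chosen for illustration only.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # The carry is a dict (a PyTree) holding a running total and a counter.
    new_carry = {"total": carry["total"] + x, "count": carry["count"] + 1}
    # The second return value is stacked across iterations by `scan`.
    return new_carry, new_carry["total"]


# Initial state: a PyTree of JAX arrays.
init = {"total": jnp.array(0.0), "count": jnp.array(0)}
xs = jnp.array([1.0, 2.0, 3.0])

final_carry, running_totals = jax.lax.scan(step, init, xs)
print(final_carry)
print(running_totals)
```

The final carry has the same dict structure as `init`, and `running_totals` collects the per-step outputs into a single array.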

Another key advantage of JAX PyTrees is that they are designed to work
efficiently with JAX's automatic differentiation and compilation features. For
example, suppose you have a function that takes a PyTree as input and returns
a scalar value:

In [11]:
def sum_squares(x):
    return jnp.sum(x[0] ** 2 + x[1] ** 2 + x[2] ** 2)

sum_squares(tree)

Array(285., dtype=float32)

You can use JAX's `grad` function to automatically compute the gradient of
this function with respect to the input PyTree:

In [12]:
gradient = jax.grad(sum_squares)(tree)
print(gradient)

(Array([2., 4., 6.], dtype=float32), Array([ 8., 10., 12.], dtype=float32), Array([14., 16., 18.], dtype=float32))


This computes the gradient of the `sum_squares` function with respect to the
input PyTree, and returns a new PyTree with the same shape and structure.
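A quick way to check this claim is to compare the tree structures and leaf shapes directly. The sketch below rebuilds the same `tree` and `sum_squares` so that it runs on its own; `same_structure` and `same_shapes` are illustrative names.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

tree = (
    jnp.array([1.0, 2.0, 3.0]),
    jnp.array([4.0, 5.0, 6.0]),
    jnp.array([7.0, 8.0, 9.0]),
)


def sum_squares(x):
    return jnp.sum(x[0] ** 2 + x[1] ** 2 + x[2] ** 2)


gradient = jax.grad(sum_squares)(tree)

# The gradient is a tuple of three arrays, exactly like the input.
same_structure = jtu.tree_structure(gradient) == jtu.tree_structure(tree)
same_shapes = all(
    g.shape == t.shape
    for g, t in zip(jtu.tree_leaves(gradient), jtu.tree_leaves(tree))
)
print(same_structure, same_shapes)
```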

JAX PyTrees are also designed to be highly extensible: custom types can be
registered through a global registry, and the values of such types are then
traversed recursively (i.e., as a tree!). This means we can
define our own custom data structures and use them as PyTrees. This is the
functionality that we exploit, whereby we construct all Gaussian process
models as tree structures through our `Module` object.
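To illustrate the registration mechanism by hand, here is a minimal sketch that registers a plain kernel class with `jax.tree_util.register_pytree_node_class`. The class name `SimpleRBF` is hypothetical, and this is not how `Module` is implemented; it only shows what registration buys us.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class SimpleRBF:
    def __init__(self, lengthscale=1.0, variance=1.0):
        self.lengthscale = lengthscale
        self.variance = variance

    def __call__(self, x, y):
        return self.variance * jnp.exp(-0.5 * ((x - y) / self.lengthscale) ** 2)

    def tree_flatten(self):
        # Children are the leaves JAX should trace; there is no static data here.
        return (self.lengthscale, self.variance), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from its leaves.
        return cls(*children)


# The gradient is itself a `SimpleRBF` whose attributes hold the derivatives.
grads = jax.grad(lambda kern: kern(1.0, 2.0))(SimpleRBF())
print(grads.lengthscale, grads.variance)
```

With the default parameters and inputs $`x=1.0`$, $`y=2.0`$, both partial derivatives equal $`\exp(-1/2) \approx 0.607`$. Writing `tree_flatten` and `tree_unflatten` for every model class is the boilerplate that a base class can remove.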

### Module

The core idea is to represent all model objects via
an immutable PyTree:
- the leaves of the PyTree represent the parameters
that are to be trained;
- their domain and trainable status are described as
`dataclass` metadata.

For our RBF kernel we have two parameters: the lengthscale and the variance.
Both of these have positive domains, and by default we want to train both of
these parameters. To encode this we use a `param_field`, where we can define
the domain of both parameters via a `Softplus` bijector (that restricts them
to the positive domain), and set their trainable status to `True`.
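To see why a softplus map enforces positivity, the sketch below applies $`\log(1 + e^{u})`$ to a few unconstrained values and inverts it again. It uses `jax.nn.softplus` directly rather than the TFP bijector, and the helper names `softplus_forward` and `softplus_inverse` are hypothetical.

```python
import jax
import jax.numpy as jnp


def softplus_forward(u):
    # Maps any real number to a strictly positive one: log(1 + exp(u)).
    return jax.nn.softplus(u)


def softplus_inverse(p):
    # Inverse of softplus for p > 0: log(exp(p) - 1).
    return jnp.log(jnp.expm1(p))


unconstrained = jnp.array([-5.0, 0.0, 5.0])
constrained = softplus_forward(unconstrained)
recovered = softplus_inverse(constrained)

print(constrained)  # all entries are positive
print(recovered)    # close to the original unconstrained values
```

An optimiser can therefore take unrestricted steps in the unconstrained space, while the kernel only ever sees positive lengthscales and variances.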

In [13]:
import tensorflow_probability.substrates.jax.bijectors as tfb
from gpjax.base import Module, param_field


@dataclass
class RBF(Module):
    lengthscale: float = param_field(1.0, bijector=tfb.Softplus(), trainable=True)
    variance: float = param_field(1.0, bijector=tfb.Softplus(), trainable=True)

    def __call__(self, x: jax.Array, y: jax.Array) -> jax.Array:
        return self.variance * jnp.exp(-0.5 * ((x - y) / self.lengthscale) ** 2)

Here `param_field` is just a special type of `dataclasses.field`. As such the
following:

In [14]:
param_field(1.0, bijector=tfb.Identity(), trainable=False)

Field(name=None,type=None,default=<dataclasses._MISSING_TYPE object at 0x1037b8580>,default_factory=<function param_field.<locals>.<lambda> at 0x2c8a4d3f0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'bijector': <tfp.bijectors.Identity 'identity' batch_shape=[] forward_min_event_ndims=0 inverse_min_event_ndims=0 dtype_x=? dtype_y=?>, 'trainable': False, 'pytree_node': True}),kw_only=<dataclasses._MISSING_TYPE object at 0x1037b8580>,_field_type=None)

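is essentially equivalent to the following `dataclasses.field` (as the outputs show, `param_field` additionally records a `pytree_node` metadata entry and supplies the default through a `default_factory`)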

In [15]:
field(default=1.0, metadata={"trainable": False, "bijector": tfb.Identity()})

Field(name=None,type=None,default=1.0,default_factory=<dataclasses._MISSING_TYPE object at 0x1037b8580>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'trainable': False, 'bijector': <tfp.bijectors.Identity 'identity' batch_shape=[] forward_min_event_ndims=0 inverse_min_event_ndims=0 dtype_x=? dtype_y=?>}),kw_only=<dataclasses._MISSING_TYPE object at 0x1037b8580>,_field_type=None)
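Metadata attached this way can be read back with `dataclasses.fields`. The following is a minimal sketch using only the standard library; the `Params` class is hypothetical, and the bijectors are replaced by placeholder strings so that the snippet needs no extra dependencies.

```python
from dataclasses import dataclass, field, fields


@dataclass
class Params:
    # Placeholder strings stand in for real bijector objects.
    lengthscale: float = field(
        default=1.0, metadata={"trainable": True, "bijector": "softplus"}
    )
    variance: float = field(
        default=1.0, metadata={"trainable": False, "bijector": "identity"}
    )


# Collect the metadata of every field into an ordinary dict.
meta = {f.name: dict(f.metadata) for f in fields(Params)}
print(meta)
```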

By default, unmarked leaf attributes are given an `Identity` bijector and
have their trainability set to `True`.
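One way a trainability flag can take effect is by blocking gradients for parameters marked as non-trainable. The sketch below does this for a plain dict of parameters with `jax.lax.stop_gradient`; it is an illustration under that assumption, not a description of the internals of `Module`, and the names `params`, `trainable` and `loss` are hypothetical.

```python
import jax
import jax.numpy as jnp

params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)}
# Hypothetical trainability mask with the same structure as `params`.
trainable = {"lengthscale": True, "variance": False}


def kernel_value(p, x, y):
    return p["variance"] * jnp.exp(-0.5 * ((x - y) / p["lengthscale"]) ** 2)


def loss(p):
    # Leaves flagged as non-trainable are treated as constants by autodiff.
    p = jax.tree_util.tree_map(
        lambda leaf, t: leaf if t else jax.lax.stop_gradient(leaf), p, trainable
    )
    return kernel_value(p, 1.0, 2.0)


grads = jax.grad(loss)(params)
print(grads)
```

The gradient with respect to `variance` is exactly zero, while the gradient with respect to `lengthscale` is unaffected.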

Critically we can now take the gradient of the kernel with respect to its parameters!

In [22]:
kernel = RBF()

try:
    # The kernel is now a PyTree, so the gradient is returned as an `RBF` object.
    print(jax.grad(lambda kern: kern(1.0, 2.0))(kernel))
except TypeError as e:
    print(e)

# Efficient GPs

The white noise kernel has covariance,
$$
    k(x, y) = \sigma^2 \delta(x-y)
$$

Here $\delta(x-y)$ equals one when $x=y$ and zero otherwise. We can code this as follows:

In [16]:
@dataclass
class WhiteNoise(Module):
    variance: float = param_field(1.0, bijector=tfb.Softplus(), trainable=True)

    def __call__(self, x: jax.Array, y: jax.Array) -> jax.Array:
        K = jnp.all(jnp.equal(x, y)) * self.variance
        return K.squeeze()

kernel = WhiteNoise()

Recall the Gram matrix, which is costly to compute and store:

$$k(\mathbf{x}, \mathbf{x}) = \begin{bmatrix} k(\mathbf{x}_1, \mathbf{x}_1) & \cdots & k(\mathbf{x}_1, \mathbf{x}_n) \\ \vdots & \ddots & \vdots \\ k(\mathbf{x}_n, \mathbf{x}_1) & \cdots & k(\mathbf{x}_n, \mathbf{x}_n) \end{bmatrix}$$

Let's compute this for some data using `vmap`:

In [17]:
from jax import vmap

# datapoints:
x = jnp.linspace(-3., 3., 50)


# function to compute the gram matrix:
def gram(kernel: WhiteNoise, x: jax.Array) -> jax.Array:
    return vmap(lambda xi: vmap(lambda yi: kernel(xi, yi))(x))(x)


# compute the gram matrix:
print(gram(kernel, x))

[[1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]]


Notice that the covariance is zero everywhere except when $x=y$. This means the Gram matrix is sparse (in fact, diagonal)!

### Compute engine:

We introduce compute engines that return a structured linear operator in place of the dense Gram matrix!

In [18]:
import cola
from cola.ops import Diagonal, Dense, LinearOperator

@dataclass
class AbstractComputeEngine:
    def gram(self, kernel, x) -> LinearOperator:
        raise NotImplementedError

Naive computation like before can be represented as a `Dense` linear operator:

In [None]:
# Default:
@dataclass
class DenseComputeEngine(AbstractComputeEngine):
    def gram(self, kernel, x) -> Dense:
        return Dense(vmap(lambda xi: vmap(lambda yi: kernel(xi, yi))(x))(x))


But, given the structure, we can save memory and compute time by using a `Diagonal` linear operator:

In [19]:
# Structured:
@dataclass
class DiagonalComputeEngine(AbstractComputeEngine):
    def gram(self, kernel, x) -> Diagonal:
        return Diagonal(vmap(lambda xi: kernel(xi, xi))(x))
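To make the saving concrete, the sketch below solves a linear system with a diagonal Gram matrix in two ways using plain `jax.numpy` (not CoLA): once through a dense solve, and once by elementwise division of the stored diagonal. The variance value and right-hand side `b` are arbitrary choices for illustration.

```python
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 50)
variance = 2.0

# A white-noise Gram matrix is fully described by its n diagonal entries.
diag = variance * jnp.ones_like(x)
b = jnp.sin(x)

# Dense route: build the n x n matrix and call a general-purpose solver.
dense_solution = jnp.linalg.solve(jnp.diag(diag), b)

# Structured route: only the diagonal is stored, and the solve is a division.
diag_solution = b / diag

print(jnp.allclose(dense_solution, diag_solution))
```

Storing the diagonal needs $n$ numbers rather than $n^2$, and the solve costs $O(n)$ operations rather than the $O(n^3)$ of a general dense solve.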

We can set the default compute engine of the kernel to be the `DiagonalComputeEngine`!

In [21]:
@dataclass
class WhiteNoise(Module):
    variance: float = param_field(1.0, bijector=tfb.Softplus(), trainable=True)
    compute_engine: AbstractComputeEngine = DiagonalComputeEngine()

    def __call__(self, x: jax.Array, y: jax.Array) -> jax.Array:
        K = jnp.all(jnp.equal(x, y)) * self.variance
        return K.squeeze()

    def gram(self, x: jax.Array) -> LinearOperator:
        return self.compute_engine.gram(self, x)


kernel = WhiteNoise()
Kxx = kernel.gram(x)
print(f"LinearOperator: {Kxx}")
print(f"LinearOperator as dense matrix: {Kxx.to_dense()}")

LinearOperator: diag([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1.])
LinearOperator as dense matrix: [[1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]]


The compute engine can also be swapped for one that computes an approximation, e.g., Random Fourier Features!
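As a rough sketch of what such an approximation could compute, the snippet below builds Random Fourier Features for a one-dimensional RBF kernel in plain JAX and compares the resulting low-rank Gram matrix with the exact one. It is not wrapped in a compute engine, and the number of features, the random seed and all variable names are arbitrary assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

lengthscale, variance = 1.0, 1.0
num_features = 5000

x = jnp.linspace(-3.0, 3.0, 20)

# Frequencies are drawn from the spectral density of the RBF kernel,
# a Gaussian with standard deviation 1 / lengthscale; phases are uniform.
key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
omega = jax.random.normal(key_w, (num_features,)) / lengthscale
phase = jax.random.uniform(
    key_b, (num_features,), minval=0.0, maxval=2.0 * jnp.pi
)

# Feature matrix of shape (n, num_features).
features = jnp.sqrt(2.0 * variance / num_features) * jnp.cos(
    x[:, None] * omega[None, :] + phase[None, :]
)

# Low-rank approximation of the Gram matrix versus the exact RBF Gram matrix.
approx_gram = features @ features.T
exact_gram = variance * jnp.exp(
    -0.5 * ((x[:, None] - x[None, :]) / lengthscale) ** 2
)

max_abs_error = jnp.max(jnp.abs(approx_gram - exact_gram))
print(max_abs_error)
```

The approximation error shrinks as `num_features` grows, and downstream computations can work with the `(n, num_features)` feature matrix instead of the full `n x n` Gram matrix.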