# Why JAX?

!!! Info ""
    If you are interested in Feedbax but unfamiliar with JAX—or new to Python—then read on, for an overview of some of the tools on which Feedbax is based.

JAX isn't a machine learning framework like [PyTorch](https://pytorch.org/). It's a more general-purpose tool. 

??? Info "What does JAX provide?"

    - A [NumPy](https://numpy.org/)-like API: Many of the things you can write in NumPy, you [can also write](https://jax.readthedocs.io/en/latest/jax.numpy.html) in JAX—you just have to `import jax.numpy as jnp` instead of `import numpy as np`. 
    - [Just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html): In many cases, this makes JAX much faster than NumPy. 
    - [Automatic differentiation](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html#jax-first-transformation-grad): We use this to get derivatives of functions—usually, to train models through gradient descent.
    - [Automatic vectorization](https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html): We can easily transform a function that works on single examples, to a function that processes entire batches of data.
    - [Parallelism](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html): This makes it easy to split up a large model across multiple devices (e.g. GPUs).

    Automatic differentiation and JIT compilation are features normally found working in the background in ML frameworks, but JAX lets you use them in explicit, arbitrary, powerful ways.
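
As a rough illustration of the features listed above, here is a minimal sketch that combines automatic differentiation, JIT compilation, and automatic vectorization. The function `loss` and its inputs are made up for this example, and are not part of Feedbax.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # A toy scalar function: the sum of squares of w * x.
    return jnp.sum((w * x) ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# Automatic differentiation: a new function that returns d(loss)/dw.
grad_loss = jax.grad(loss)

# JIT compilation: a compiled version of the gradient function.
fast_grad_loss = jax.jit(grad_loss)

# Automatic vectorization: evaluate `loss` for a batch of `w`, with `x` shared.
batched_loss = jax.vmap(loss, in_axes=(0, None))

g = fast_grad_loss(2.0, x)  # 2 * w * sum(x ** 2) = 56.0
losses = batched_loss(jnp.array([0.0, 1.0, 2.0]), x)
```

Each transformation takes a function and returns a new function, which is why they can be stacked on top of each other.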


That's why Feedbax is not just built on JAX, but also:

- [Equinox](https://github.com/patrick-kidger/equinox), which allows us to define PyTorch-like modules, making it easier to organize our models;
- [Optax](https://github.com/google-deepmind/optax), which provides optimizers (like Adam) which you'd normally find in ML frameworks;
- [Diffrax](https://github.com/patrick-kidger/diffrax), which provides numerical solvers for differential equations.

My favourite part about working with JAX is how nicely it plays with nested containers of data, or [*PyTrees*](https://jax.readthedocs.io/en/latest/pytrees.html). 

## Pytrees



Let's start with a list and a dict that contain similar values.

In [1]:
some_list = [1, 2, 3]

some_dict = {'a': 1, 'b': 2, 'c': 3}

In standard Python, a [*comprehension*](https://docs.python.org/3/tutorial/datastructures.html#list-comprehensions) is a typical way of applying some computation to every value in a list or dict.

In [2]:
[x ** 2 for x in some_list]

[1, 4, 9]

In [3]:
{k: x ** 2 for k, x in some_dict.items()}

{'a': 1, 'b': 4, 'c': 9}

While these solutions are similar, they're not interchangeable. If our data is a list we can use the first method and get a list in return. But as soon as we start using some data that's stored in a dict, we need to change our code.

Conveniently, JAX provides a function [`tree_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html) that behaves the same way for both lists and dicts.

In [4]:
from jax.tree_util import tree_map

tree_map(lambda x: x ** 2, some_list)

[1, 4, 9]

In [5]:
tree_map(lambda x: x ** 2, some_dict)

{'a': 1, 'b': 4, 'c': 9}

??? Note "Python's built-in `map`"
    Python includes a built-in function [`map`](https://docs.python.org/3/library/functions.html#map) which is similar in principle to `tree_map`. For example, we can do `list(map(lambda x: x**2, some_list))` to get the same result as `tree_map(lambda x: x**2, some_list)`. 
    
    However, `map` doesn't return the same data structure it is given: to square all the values in a dict, and return a dict like we did with `tree_map`, we'd have to do something like `dict(zip(some_dict.keys(), map(lambda x: x**2, some_dict.values())))`. This is harder to read, and write.


Even better, `tree_map` works on nested containers. 

In [7]:
import jax.numpy as jnp

some_data = [{'p': [1, 2], 'x': 1.0},
             [5, 6, 7, 8, {'y': jnp.array([2, 2, 2])}]]

tree_map(lambda x: x ** 2, some_data)

[{'p': [1, 4], 'x': 1.0},
 [25, 36, 49, 64, {'y': Array([4, 4, 4], dtype=int32)}]]

How does this work? JAX treats both lists and dicts—*and any nested structures of lists and dicts*—as PyTrees. 

A PyTree's *leaves* are the data it ultimately contains.

In [8]:
from jax.tree_util import tree_leaves, tree_structure

tree_leaves(some_data)

[1, 2, 1.0, 5, 6, 7, 8, Array([2, 2, 2], dtype=int32)]

Those leaves are arranged in a tree with a certain structure. 

In [9]:
tree_structure(some_data)

PyTreeDef([{'p': [*, *], 'x': *}, [*, *, *, *, {'y': *}]])

Note that the JAX array counts as a leaf, not as part of the tree structure. What does JAX treat as a leaf, and what does it treat as the structure in which all the leaves are contained? 

By default:

- *leaves*—AKA *leaf nodes*—include NumPy and JAX arrays, as well as basic data types like `int`, `float`, `str`, and `bool`;
- *internal nodes* are lists, dicts, and tuples, which JAX recognizes as PyTrees themselves. When JAX encounters these, its default stance is that "nesting continues here"—so it looks inside the node for leaves, or even deeper layers of nodes.

Importantly, we can change what counts as a leaf. Many functions that operate on PyTrees can take an argument `is_leaf`. 

In [10]:
tree_leaves(some_data, is_leaf=lambda x: isinstance(x, dict))

[{'p': [1, 2], 'x': 1.0}, 5, 6, 7, 8, {'y': Array([2, 2, 2], dtype=int32)}]

In [11]:
tree_structure(some_data, is_leaf=lambda x: isinstance(x, dict))

PyTreeDef([*, [*, *, *, *, *]])

Here, we've told JAX to treat dicts as leaves, rather than as containers. Now, the dicts appear whole and unflattened in the list of leaves, and the PyTree structure reflects this. 
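
One case where `is_leaf` often comes in handy is `None`. The sketch below assumes a small made-up list, `maybe_values`, and shows that JAX by default treats `None` as an empty container rather than as a leaf.

```python
from jax.tree_util import tree_leaves

maybe_values = [1, None, 2]

# By default, None is treated as a container with no children,
# so it contributes nothing to the list of leaves.
default_leaves = tree_leaves(maybe_values)

# Telling JAX to treat None as a leaf keeps it in the flattened list.
leaves_with_none = tree_leaves(maybe_values, is_leaf=lambda x: x is None)
```

Here `default_leaves` is `[1, 2]`, while `leaves_with_none` is `[1, None, 2]`.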

To get the leaves and the tree structure in one call, use [`tree_flatten`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_flatten.html#jax.tree_util.tree_flatten):

In [None]:
from jax.tree_util import tree_flatten

leaves, structure = tree_flatten(some_data)

Given both leaves and the structure, [`tree_unflatten`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_unflatten.html#jax.tree_util.tree_unflatten) builds a PyTree. Let's reconstruct the original `some_data`:

In [None]:
from jax.tree_util import tree_unflatten

tree_unflatten(structure, leaves)
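
Flattening and unflattening are also useful together, when we want to work on a plain list of leaves and then restore the original nesting. A minimal sketch, using a made-up PyTree called `nested`:

```python
from jax.tree_util import tree_flatten, tree_unflatten

nested = {'a': [1, 2], 'b': (3, {'c': 4})}

# Split the PyTree into a flat list of leaves, and a description of its structure.
leaves, structure = tree_flatten(nested)

# Work on the flat list with ordinary Python.
doubled_leaves = [2 * leaf for leaf in leaves]

# Put the new leaves back into the original structure.
rebuilt = tree_unflatten(structure, doubled_leaves)
```

The result `rebuilt` is `{'a': [2, 4], 'b': (6, {'c': 8})}`, which is the same as what `tree_map(lambda x: 2 * x, nested)` would return.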

Because JAX understands the structure of PyTrees, we can apply operations to multiple PyTrees when their structures match.

In [14]:
some_arrays = [
    (jnp.array([1, 2]), jnp.array([3, 4])),
    jnp.array([5, 6])
]

some_other_arrays = [
    (jnp.array([7, 8]), jnp.array([3, 4])),
    jnp.array([1, 1])
]

tree_structure(some_arrays) == tree_structure(some_other_arrays)

True

In [15]:
tree_map(
    lambda x, y: x + y,
    some_arrays,
    some_other_arrays
)

[(Array([ 8, 10], dtype=int32), Array([6, 8], dtype=int32)),
 Array([6, 7], dtype=int32)]

Here, `tree_map` works "leafwise" to pick out the arguments to a function: the `x` values are the leaves from `some_arrays`, and the `y` values are the matching leaves from `some_other_arrays`. 

In this example, the result will be different if we tell JAX to treat tuples as leaves. The first two JAX arrays in each PyTree are contained in a tuple, so the first `x` passed to the function will be a tuple of two arrays, as will the first `y`. When we apply `+` to two tuples, we concatenate them.

In [38]:
tree_map(
    lambda x, y: x + y,
    some_arrays,
    some_other_arrays,
    is_leaf=lambda x: isinstance(x, tuple)
)

[(Array([1, 2], dtype=int32),
  Array([3, 4], dtype=int32),
  Array([7, 8], dtype=int32),
  Array([3, 4], dtype=int32)),
 Array([6, 7], dtype=int32)]

The array that's not inside a tuple gets added the same way it did before, because it still counts as a leaf—it's just that now, tuples *also* count as leaves.
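
Leafwise mapping over matching PyTrees is what makes a gradient descent update short to write. The following is only a sketch under assumed names: `loss`, `params`, and `learning_rate` are invented for this example. Since `jax.grad` returns gradients with the same tree structure as the parameters, `tree_map` can pair each parameter with its gradient.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_structure


def loss(params, x):
    # A toy squared prediction, using a dict of parameters.
    prediction = jnp.dot(params['w'], x) + params['b']
    return prediction ** 2


params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
x = jnp.array([1.0, 1.0])

# The gradients form a dict with the same keys and shapes as `params`.
grads = jax.grad(loss)(params, x)

# One step of gradient descent, applied leaf by leaf.
learning_rate = 0.1
new_params = tree_map(lambda p, g: p - learning_rate * g, params, grads)
```

The same line of code would work unchanged if `params` were a list, a tuple, or a deeply nested mixture of containers.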

### A pytree of your own

A PyTree is any kind of container that JAX knows how to flatten and unflatten. By default, this includes lists, dicts, and tuples. 

When the default containers aren't enough for us, we can define our own types of containers, and tell JAX how to flatten and unflatten them. After that, JAX will treat them as PyTrees!

In [39]:
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class TwoValues:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        return (self.a, self.b), None  # leaves, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, leaves):
        return cls(*leaves)



Here, the method `tree_flatten` tells JAX how to flatten a `TwoValues` object into its leaves, and `tree_unflatten` tells how to construct `TwoValues` given the leaves.

As its name suggests, the decorator [`register_pytree_node_class`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.register_pytree_node_class.html#jax.tree_util.register_pytree_node_class) registers our new PyTree type with JAX.

Now we can use `TwoValues` as part of any PyTree:

In [40]:
tree = (TwoValues(12, 45), 3, {'a': TwoValues(4, 5)})

tree_leaves(tree)

[12, 45, 3, 4, 5]

In [43]:
tree_structure(tree)

PyTreeDef((CustomNode(TwoValues[None], [*, *]), *, {'a': CustomNode(TwoValues[None], [*, *])}))
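
The `aux_data` returned by `tree_flatten` is where we can keep information that belongs to the container but should not be treated as a leaf. As a sketch, here is a hypothetical class `LabeledPair` (not part of JAX or Feedbax) that stores a text label this way.

```python
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map


@register_pytree_node_class
class LabeledPair:
    def __init__(self, a, b, label):
        self.a = a
        self.b = b
        self.label = label

    def tree_flatten(self):
        # The label is static information, so it goes in aux_data.
        return (self.a, self.b), self.label

    @classmethod
    def tree_unflatten(cls, aux_data, leaves):
        # Rebuild the object from its leaves, restoring the label.
        return cls(*leaves, label=aux_data)


pair = LabeledPair(1, 2, label="position")

# Only `a` and `b` are leaves, so only they are squared.
squared = tree_map(lambda x: x ** 2, pair)
```

After the call to `tree_map`, `squared` is still a `LabeledPair` with the same label, and `tree_leaves(pair)` contains only the two numbers.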

## Equinox

!!! Note inline end ""
    Most objects in Feedbax are derived from `equinox.Module`.

[Equinox](https://docs.kidger.site/equinox/) adds some useful tools to JAX. 

In particular, [`equinox.Module`](https://docs.kidger.site/equinox/api/module/module/) allows us to easily define classes that are PyTrees, and that combine model parameters with model computations.

In [45]:
import equinox as eqx
import jax


class SomeModel(eqx.Module):
    param1: int
    param2: jax.Array

    def __call__(self, x: float):
        return self.param1 + x * self.param2


# Construct an example model.
model = SomeModel(3, jnp.array([1, 2, 3]))

In our class definition, the method `__call__` tells Python how a `SomeModel` object should behave, when we call it like a function:

In [None]:
model(2.5)

This is a nice way to define and execute our model computation. 


Another convenient thing about Equinox `Module` is that it's a [`dataclass`](https://docs.python.org/3/library/dataclasses.html). In a normal Python class, to assign `param1` and `param2` as instance attributes we'd have to do this:

In [None]:
class SomeModel:
    def __init__(self, param1: int, param2: jax.Array):
        self.param1 = param1
        self.param2 = param2

    def __call__(self, x: float):
        return self.param1 + x * self.param2

When our class is a dataclass, it automatically defines a default `__init__` method like the one above. We just have to define the list of parameters (that is, dataclass *fields*):

In [None]:
from dataclasses import dataclass

@dataclass
class SomeModel:
    param1: int
    param2: jax.Array

    def __call__(self, x: float):
        return self.param1 + x * self.param2

Any class or subclass we define from `eqx.Module` will automatically work this way, without needing to add the `@dataclass` decorator. 

!!! Note 
    We can still add our own `__init__` method to a dataclass if we need to do something fancier than just assigning values to fields. 
    
    In case only small modifications to `__init__` are needed, it may be convenient to define [`__post_init__`](https://docs.python.org/3/library/dataclasses.html#dataclasses.__post_init__) instead.
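
To illustrate `__post_init__`, here is a small sketch with a plain Python dataclass. The class `Line` and its fields are hypothetical, and the example uses the standard library only, not Equinox.

```python
import math
from dataclasses import dataclass, field


@dataclass
class Line:
    slope: float
    intercept: float
    # Derived from `slope`, so it is excluded from the generated __init__.
    angle_deg: float = field(init=False)

    def __post_init__(self):
        # Runs right after the generated __init__ has assigned the other fields.
        self.angle_deg = math.degrees(math.atan(self.slope))


line = Line(slope=1.0, intercept=0.0)
```

We still construct `Line` from just `slope` and `intercept`, and the derived field is filled in for us.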

The best thing about Equinox modules is that they are PyTrees:

In [50]:
# Get a flattened list of model parameters.
tree_leaves(model)

[3, Array([1, 2, 3], dtype=int32)]

It turns out this is very useful for structuring models, but that's beyond the scope of this example.


??? Note "Similarity of Equinox and PyTorch modules"
    Equinox's [`Module`](https://docs.kidger.site/equinox/api/module/module/#equinox.Module) is kind of like PyTorch's [`nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). However, PyTorch modules:
    
    - are not PyTrees, because PyTorch has no general, built-in concept of PyTrees;
    - are not automatically dataclasses, and it can be kind of [problematic](https://discuss.pytorch.org/t/how-to-use-dataclass-with-pytorch/53444/9) to convert them;
    - define the model computation in the `forward` method, rather than `__call__`. Technically though, PyTorch still has to define `__call__` in the background to have its module objects behave like functions.

## Vectorisation and `vmap`

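The power of PyTrees goes much deeper than we've seen here. The core JAX transformations, such as `jax.vmap` and `jax.grad`, accept PyTrees as the inputs and outputs of the functions they transform.

!!! Note
    If you run into problems with `jax.vmap`, try using Equinox's `filter_vmap`. It does the same thing, but a little more intelligently: by default, it vectorizes over the arrays in its arguments and leaves everything else alone.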
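
As a minimal sketch of what `jax.vmap` does, suppose we have a function that is written for a single example. The function `predict` and its inputs are made up for illustration.

```python
import jax
import jax.numpy as jnp


def predict(weights, x):
    # Written for a single example: `x` has shape (3,), and the result is a scalar.
    return jnp.dot(weights, x)


weights = jnp.array([1.0, 0.5, -1.0])
batch = jnp.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])

# Vectorize over the first axis of `x`, while sharing `weights` across the batch.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

outputs = batched_predict(weights, batch)  # shape (2,)
```

The argument `in_axes` says which axis of each input holds the batch. Using `None` means that input is not batched at all.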

## Functions and states

JAX [plays best](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions) with [pure functions](https://en.wikipedia.org/wiki/Purely_functional_programming). Let's see what that means.

Perhaps you are familiar with object-oriented programming, where *classes* define how objects possess and manipulate their *internal states*. For example, let's define a type of object that 1) possesses two attributes, and 2) when it's called, returns a result, but also internally updates one of its attributes.

In [83]:
class StatefulFoo:
    smee: int
    a: int

    def __init__(self, a: int):
        self.smee = 0
        self.a = a

    def __call__(self, x: int):

        if x > 3:
            self.smee = 2

        return self.a * x


a = 2
foo = StatefulFoo(a)
x = 1

print("\t\tx\tsmee")

for i in range(7):
    x = foo(x)

    print(f"Step {i}:\t\t{x}\t{foo.smee}")

		x	smee
Step 0:		2	0
Step 1:		4	0
Step 2:		8	2
Step 3:		16	2
Step 4:		32	2
Step 5:		64	2
Step 6:		128	2


Importantly, the internal state—the value of `foo.smee`—changes once a certain value is passed to `foo`. This is obvious in this case, since we're printing `foo.smee` on every step. But under different circumstances, we might not even know it had changed.

Seen as a function, the main thing that `foo` does is to return `self.a * x`. But it also has the *side effect* of altering `foo.smee`.

On the other hand, *a pure function does not have side effects*. Everything that the function does, is how its input gets turned into its return value. 
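
One reason side effects matter in JAX is that transformations like `jax.jit` only run our Python code while tracing the function. The sketch below uses a made-up function `impure_double` whose side effect is to append to a list.

```python
import jax

calls = []


@jax.jit
def impure_double(x):
    # Side effect: this line only runs while JAX traces the function.
    calls.append("called")
    return 2 * x


impure_double(1.0)
impure_double(2.0)
result = impure_double(3.0)

n_calls = len(calls)  # 1, not 3
```

The return values are correct every time, but the side effect happens only once, on the first call. The later calls reuse the compiled computation, which contains no trace of the `append`.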

We can still do what we did with `foo.smee`, except that `smee` can no longer be hidden. It just needs to be part of the input and output of the function.

In [86]:
class PureFoo:
    a: int

    def __init__(self, a: int):
        self.a = a

    def __call__(self, x: int, smee: int):

        if x > 3:
            smee = 2

        return self.a * x, smee

a = 2
foo = PureFoo(a)
smee = 0
x = 1

print("\t\tx\tsmee")

for i in range(7):
    x, smee = foo(x, smee)

    print(f"Step {i}:\t\t{x}\t{smee}")

		x	smee
Step 0:		2	0
Step 1:		4	0
Step 2:		8	2
Step 3:		16	2
Step 4:		32	2
Step 5:		64	2
Step 6:		128	2


Maybe this doesn't seem as nice as `StatefulFoo`, but it is totally transparent. And if we keep building up our programs in this way, it forces us to start adding more structure to the inputs and outputs of our functions. 

In [89]:
@dataclass
class Data:
    x: int
    smee: int


class PureFoo:
    a: int

    def __init__(self, a: int):
        self.a = a

    def __call__(self, data: Data):

        if data.x > 3:
            smee = 2
        else:
            smee = data.smee

        return Data(self.a * data.x, smee)


a = 2
foo = PureFoo(a)
data = Data(x=1, smee=0)

print("\t\tx\tsmee")

for i in range(7):
    data = foo(data)

    print(f"Step {i}:\t\t{data.x}\t{data.smee}")

		x	smee
Step 0:		2	0
Step 1:		4	0
Step 2:		8	2
Step 3:		16	2
Step 4:		32	2
Step 5:		64	2
Step 6:		128	2

It turns out that as our programs grow complex, this style will work at least as well as the stateful style ever did—and without hiding anything.

!!! Note ""
    In Feedbax, the relationship between a model and its state is like the relationship between `PureFoo` and `Data`, in this example. A model does not *possess* state, it *operates* on it. 
    
    Similarly, we never change a state object by directly reassigning its values. For example, in the above example we would never do this:
    
    ```python
    data = Data(x=1, smee=0)
    data.smee = 2
    ```
    
    As we'll see shortly, this won't be a problem. We'll just need to define the alteration to `data` as some function that takes `data` as its input, and constructs the altered version as its output.
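
In plain Python, one way to write such a function is with `dataclasses.replace`. This is only a sketch, assuming we make `Data` a frozen dataclass. The helper `set_smee` is hypothetical.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Data:
    x: int
    smee: int


def set_smee(data: Data, value: int) -> Data:
    # Build a new Data object, rather than changing the one we were given.
    return replace(data, smee=value)


data = Data(x=1, smee=0)
new_data = set_smee(data, 2)
```

The original `data` is untouched, and `new_data` differs from it only in `smee`.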

### Equinox and pure functions

It might seem a little odd that we contrasted object-oriented programming with purely functional programming, and then kept defining our "pure function" as a `class`!

It's not really odd, though. What matters is that our classes *behave* like pure functions because of the way we define `__call__`. And classes do one very convenient thing for us: they let us keep fixed model parameters (like `a`) in the same place as the function that defines the model's computation.

This is essentially what `eqx.Module` is for. What's more, it forces us to code in a functional style. Watch what happens if we try to change one of the attributes of an Equinox module:

In [90]:
class Bar(eqx.Module):
    a: int

my_bar = Bar(a=3)

my_bar.a = 4

FrozenInstanceError: cannot assign to field 'a'

Things are no different if the object tries to mutate itself:

In [93]:
class Baz(eqx.Module):
    a: int

    def __call__(self, x: int):
        self.a = x


my_baz = Baz(a=3)

my_baz(4)

FrozenInstanceError: cannot assign to field 'a'

In other words, Equinox modules are *immutable*. Immutability goes hand in hand with pure functions, because it ensures that the internal state of our objects cannot be altered in the background.

### Random number generation

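Random number generation is one way that JAX differs from the NumPy API, and from PyTorch. Rather than relying on a hidden global random state, JAX's random functions take an explicit key as an argument, which fits the pure-function style described above.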
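
In JAX, the state of the random number generator is not hidden. It is a *key* that we pass to each random function. A minimal sketch:

```python
import jax

key = jax.random.PRNGKey(0)

# Split one key into two independent keys.
key_a, key_b = jax.random.split(key)

sample_1 = jax.random.normal(key_a, (3,))
sample_2 = jax.random.normal(key_a, (3,))  # same key, so the same numbers
sample_3 = jax.random.normal(key_b, (3,))  # different key, so different numbers
```

Because a random function is a pure function of its key, reusing a key reproduces the same values. To get fresh random numbers, we split the key and use a new subkey each time.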

### Performing surgery

We have a model. It's an immutable PyTree. We want to change just one part of it. How can we do that?

Let's start with a [pre-built model](/feedbax/examples/0_train_simple).

In [None]:
import jax

from feedbax.xabdeef import point_mass_nn_simple_reaches


context = point_mass_nn_simple_reaches(key=jax.random.PRNGKey(0))
model = context.model  # Shorthand

This model has a point mass of mass $1.0$ as its skeleton.

In [None]:
model.step.mechanics.plant.skeleton

PointMass(mass=1.0)

If we try to directly alter the model to use a point mass of mass $5.0$, an error is raised.

In [None]:
from feedbax.mechanics.skeleton import PointMass

# Try to replace the entire point mass
model.step.mechanics.plant.skeleton = PointMass(5.0)

FrozenInstanceError: cannot assign to field 'skeleton'

In [None]:
# Or just try to change the mass
model.step.mechanics.plant.skeleton.mass = 5.0

FrozenInstanceError: cannot assign to field 'mass'

This kind of direct re-assignment is common in Python. It might seem inconvenient to have it outlawed! 

Well, it is still possible to alter our model. But if we want to switch out just the point mass, we have to do something slightly more complex.

In [None]:
import equinox as eqx

model_heavy = eqx.tree_at(
    lambda m: m.step.mechanics.plant.skeleton,
    model,
    PointMass(5.0)
)

To replace a part of our model tree, we use the `tree_at` function provided by Equinox. 

The use of `lambda` in the first argument to `tree_at` is similar to the `lambda` we used in Example 1. There, we defined a function `where_train` that picked out which parts of the model should be trainable. Here, our function picks out which part of our model will be replaced.

The second argument is just `model`, which is the model we want to alter. 

The third argument is the part we want to replace it with.

Why the added complexity? *It forces us never to alter our models in-place*. The function `eqx.tree_at` does not modify `model` directly, but *returns a copy* of `model` which possesses the alterations. Here, we assign the new object to `model_heavy`, since maybe we want to be able to refer to both the original model (still called `model`!) and the altered one. But we could just as easily have written

In [None]:
model = eqx.tree_at(
    lambda m: m.step.mechanics.plant.skeleton,
    model,
    PointMass(5.0)
)

This means "create an altered model, and make it so that `model` now refers to the new model---I don't need to refer to the original one by that name anymore".

There are downstream advantages to our ban. The downside to in-place changes is that it can be hard to keep track of their hidden consequences. Some parts of my code may depend on a given object, and if other parts of my code can reach into that object and make implicit alterations, the relationships within my code may be altered without consent of all the stakeholders, so to speak.

When we are forced to *return* an altered object rather than *mutate* an existing object, the consequences are always out in the open. We have to be explicit about what-refers-to-what, with respect to what-has-changed. 

We also can't alter JAX arrays in-place. Instead, an array's `at` property, together with methods like `set`, returns an altered copy of the array.
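
A short sketch of the `at` syntax for JAX arrays, using a made-up array `arr`:

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3])

# Returns a new array with element 0 replaced. `arr` itself is unchanged.
new_arr = arr.at[0].set(10)

# Other updates follow the same pattern, for example adding to an element.
added = arr.at[1].add(5)
```

Writing `arr[0] = 10` would raise an error instead, because JAX arrays are immutable.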