# Why JAX?

NumPy.

Functional. 
Stateless. State is something *returned* by a model, not something a model "possesses".

Random number generation.
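
As a minimal sketch of what random number generation looks like in JAX, assuming the standard `jax.random` API: there is no hidden global random state, so every random function is passed an explicit key. The variable names here are only illustrative.

```python
import jax

# Every random function takes an explicit key instead of using global state.
key = jax.random.PRNGKey(0)

# Using the same key twice gives the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# To get fresh numbers, split the key and use the new subkey.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
```

Here `a` and `b` are identical, while `c` is a new sample. Randomness is reproducible by construction, and the key is just another value that gets passed in and returned.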

## Pytrees

One thing that makes JAX nice and powerful is the consistent way that its tools interact with data that is organized in arbitrarily nested Python containers.

Feedbax relies on these capabilities, so if you are unfamiliar with *pytrees*, please read this short tutorial and check out the JAX [documentation](https://jax.readthedocs.io/en/latest/pytrees.html) as well.

Say we have a list and a dict that contain similar values.

In [1]:
some_list = [1, 2, 3]

some_dict = {'a': 1, 'b': 2, 'c': 3}

{'a': 1, 'b': 2, 'c': 3}


In standard Python, a *comprehension* is a typical way of applying some computation to every value in a list or dict.

In [6]:
[x ** 2 for x in some_list]

[1, 4, 9]

In [7]:
{k: x ** 2 for k, x in some_dict.items()}

{'a': 1, 'b': 4, 'c': 9}

While these solutions are similar, they are not interchangeable. If our data is a list we can use the first method and get a list in return. But as soon as we start using some data that's stored in a dict, we need to change our code.

Conveniently, JAX provides a function `tree_map` that behaves exactly the same way for both lists and dicts.

In [16]:
from jax.tree_util import tree_map

tree_map(lambda x: x ** 2, some_list)

[1, 4, 9]

In [17]:
tree_map(lambda x: x ** 2, some_dict)

{'a': 1, 'b': 4, 'c': 9}

!!! Note    
    Python includes a built-in function `map` which is similar in principle to `tree_map`. For example, we can do `list(map(lambda x: x**2, some_list))` to get the same result as `tree_map(lambda x: x**2, some_list)`. However, `map` doesn't return the same data structure it is given: to square all the values in a dict, and return a dict like we did with `tree_map`, we'd have to do something like `dict(zip(some_dict.keys(), map(lambda x: x**2, some_dict.values())))`. Obviously, this is ugly and harder to read.

Even better, `tree_map` works on nested containers. 

In [39]:
import jax.numpy as jnp

some_data = [{'p': [1, 2], 'x': 1.0}, 
             [5, 6, 7, 8, {'y': jnp.array([2, 2, 2])}]]

tree_map(lambda x: x ** 2, some_data)

[{'p': [1, 4], 'x': 1.0},
 [25, 36, 49, 64, {'y': Array([4, 4, 4], dtype=int32)}]]

How does this work? JAX considers both lists and dicts—and any nested structures of lists and dicts—to be *pytrees*.

A pytree's *leaves* are the data it ultimately contains.

In [40]:
from jax.tree_util import tree_leaves, tree_structure

tree_leaves(some_data)

[1, 2, 1.0, 5, 6, 7, 8, Array([2, 2, 2], dtype=int32)]

Those leaves are arranged in a tree with a certain structure. 

In [24]:
tree_structure(some_data)

PyTreeDef([{'p': [*, *], 'x': *}, [*, *, *, *, {'y': *}]])
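
The leaves and the structure can also be obtained together and recombined. The sketch below uses `tree_flatten` and `tree_unflatten` from `jax.tree_util`; the variable `data` is made up for illustration.

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

data = [{'p': [1, 2], 'x': 1.0}, [3, jnp.array([2, 2, 2])]]

# Split the pytree into a flat list of leaves plus a description of its structure.
leaves, treedef = tree_flatten(data)

# Work on the flat list, then rebuild a pytree with the original structure.
doubled = tree_unflatten(treedef, [2 * leaf for leaf in leaves])
```

This is roughly what `tree_map` does for us: flatten, apply the function to each leaf, and unflatten.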

What does JAX treat as a leaf, and what does it treat as the structure in which all the leaves are contained? 

By default:

- *leaves* include arrays (NumPy or JAX) and basic data types (like `int`, `float`, `str`, and `bool`)
- *nodes* are lists, dicts, and tuples, which are pytrees themselves. When JAX encounters these, its default stance is that "nesting continues here"—so it looks inside the node for leaves, or even deeper layers of nodes.
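
A small sketch of these defaults, using a made-up container `mixed` that combines all three node types with several basic data types:

```python
from jax.tree_util import tree_leaves

# A tuple containing an int, a str, a list, and a dict that holds another tuple.
mixed = (1, 'two', [3.0, True], {'k': (4, 5)})

# The tuple, list, and dict are nodes; everything else is a leaf.
mixed_leaves = tree_leaves(mixed)  # [1, 'two', 3.0, True, 4, 5]
```

Note that the string `'two'` is returned whole: JAX does not look inside it, even though Python can iterate over strings.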

Importantly, we can change what counts as a leaf. Many functions that operate on pytrees accept an argument `is_leaf`: a function that returns `True` for any object that should be treated as a leaf.

In [41]:
tree_leaves(some_data, is_leaf=lambda x: isinstance(x, dict))

[{'p': [1, 2], 'x': 1.0}, 5, 6, 7, 8, {'y': Array([2, 2, 2], dtype=int32)}]

In [42]:
tree_structure(some_data, is_leaf=lambda x: isinstance(x, dict))

PyTreeDef([*, [*, *, *, *, *]])

Here, we've told JAX to treat dicts as leaves, rather than as containers. Now, the dicts in `some_data` appear whole and unflattened in the list of its leaves, and the pytree structure reflects this. 

Because JAX understands the structure of pytrees, we can apply an operation across several pytrees at once, as long as their structures match.

In [36]:
some_arrays = [
    (jnp.array([1, 2]), jnp.array([3, 4])),
    jnp.array([5, 6])
]

some_other_arrays = [
    (jnp.array([7, 8]), jnp.array([3, 4])),
    jnp.array([1, 1])
]

tree_structure(some_arrays) == tree_structure(some_other_arrays)

True

In [37]:
tree_map(
    lambda x, y: x + y, 
    some_arrays, 
    some_other_arrays
)

[(Array([ 8, 10], dtype=int32), Array([6, 8], dtype=int32)),
 Array([6, 7], dtype=int32)]

We've relied on the fact that `tree_map` can work "leafwise" to pick out the arguments to a function: the `x` values are the leaves from `some_arrays`, and the `y` values are the matching leaves from `some_other_arrays`.
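
Leafwise matching only works when the structures agree. As a sketch, assuming the behaviour of current JAX releases, passing pytrees with different structures to `tree_map` raises a `ValueError`. The lists `pair` and `triple` are made up for illustration.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

pair = [jnp.array([1, 2]), jnp.array([3, 4])]
triple = [jnp.array([1, 2]), jnp.array([3, 4]), jnp.array([5, 6])]

try:
    # The two lists have different lengths, so their leaves cannot be paired up.
    tree_map(lambda x, y: x + y, pair, triple)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```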

However, if we tell JAX to treat tuples as leaves, then the result is different: the first `x`, `y` passed to the function will be a pair of tuples, and we'll end up concatenating them instead of adding the arrays inside them. 

In [38]:
tree_map(
    lambda x, y: x + y, 
    some_arrays, 
    some_other_arrays,
    is_leaf=lambda x: isinstance(x, tuple)
)

[(Array([1, 2], dtype=int32),
  Array([3, 4], dtype=int32),
  Array([7, 8], dtype=int32),
  Array([3, 4], dtype=int32)),
 Array([6, 7], dtype=int32)]

The array that's not inside a tuple gets added the same way it did before, because it still counts as a leaf—it's just that now, tuples *also* count as leaves.

### A pytree of your own

If lists, dicts, and tuples aren't enough for us, we can define our own types of containers, tell JAX how to treat them like pytrees, and it will do so! This isn't very hard to do.
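
One way to do this is with `register_pytree_node` from `jax.tree_util`. The sketch below is a minimal illustration; the class `Pair` and its two helper functions are hypothetical names, not part of JAX or Feedbax.

```python
from jax.tree_util import register_pytree_node, tree_leaves, tree_map


class Pair:
    """A hypothetical container that holds two values."""

    def __init__(self, first, second):
        self.first = first
        self.second = second


def flatten_pair(pair):
    # Return the children (which JAX searches for leaves) and any static
    # auxiliary data needed to rebuild the container (none here).
    return (pair.first, pair.second), None


def unflatten_pair(aux_data, children):
    # Rebuild a Pair from its (possibly transformed) children.
    return Pair(*children)


register_pytree_node(Pair, flatten_pair, unflatten_pair)

nested = [Pair(1, [2, 3]), 4]
nested_leaves = tree_leaves(nested)  # [1, 2, 3, 4]
squared = tree_map(lambda x: x ** 2, nested)
```

After registration, `Pair` is a node like any other: `tree_map` looks inside it and returns a new `Pair` with the same layout.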

Equinox.

### Feedbax and pytrees

Most objects in Feedbax are derived from `eqx.Module`. That means they are automatically treated as pytrees. 

This is why we can print the structure of a model so easily.

## Vectorisation and `vmap`

The power of pytrees goes much deeper than we've seen here. The core JAX transformations, such as `jax.vmap` and `jax.grad`, also accept and return pytrees.

!!! Note    
    If you run into problems with `jax.vmap`, try using Equinox's `filter_vmap`. It does the same thing, but by default it only vectorises over the arrays among a function's arguments and leaves any non-array values alone.
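
Here is a minimal sketch of `jax.vmap` applied to a function that takes a pytree argument. The function `weighted_sum` and the data are made up for illustration.

```python
import jax
import jax.numpy as jnp


def weighted_sum(params, x):
    # `params` is a pytree (a dict of arrays); `x` is a single 1D input.
    return jnp.dot(params['w'], x) + params['b']


params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
batch = jnp.array([[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]])

# in_axes=(None, 0): do not map over anything in `params`,
# but map over axis 0 of `batch`.
batched_sum = jax.vmap(weighted_sum, in_axes=(None, 0))
result = batched_sum(params, batch)  # one output per row of `batch`
```

The `None` in `in_axes` applies to the whole `params` pytree, so we never need to spell out how each array inside it should be handled.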

## Equinox

Equinox allows us to keep model parameters together with model code, similar to PyTorch's `nn.Module`.

However, unlike PyTorch modules, Equinox modules are immutable.

!!! Note
    Equinox's [`Module`](https://docs.kidger.site/equinox/api/module/module/#equinox.Module) is kind of like PyTorch's [`nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). However, an Equinox module typically defines its computation in `__call__`, whereas in PyTorch we define the model's computation in the `forward` method.
    
    Note that PyTorch must still define `__call__` in the background to have its module objects behave like functions. 

### Performing surgery

We have a model, and want to change just one part of it. How can we do that?

In Feedbax, and across the JAX ecosystem, objects are often *immutable* once they are constructed. That means we can't simply alter them once they've been created. This won't slow us down in the end, and it has advantages. It does require a slightly different programming style than you might be used to, though.

!!! Note     
    Feedbax objects are immutable because they are Equinox modules. It's Equinox that enforces the rules we'll discuss in this example.

Let's start with a pre-built model.

In [None]:
import jax

from feedbax.xabdeef import point_mass_nn_simple_reaches


context = point_mass_nn_simple_reaches(key=jax.random.PRNGKey(0))
model = context.model  # Shorthand



This model has a point mass of mass $1.0$ as its skeleton.

In [None]:
model.step.mechanics.plant.skeleton

PointMass(mass=1.0)

If we try to directly alter the model to use a point mass of mass $5.0$, an error is raised.

In [None]:
from feedbax.mechanics.skeleton import PointMass 

# Try to replace the entire point mass
model.step.mechanics.plant.skeleton = PointMass(5.0)

FrozenInstanceError: cannot assign to field 'skeleton'

In [None]:
# Or just try to change the mass
model.step.mechanics.plant.skeleton.mass = 5.0

FrozenInstanceError: cannot assign to field 'mass'

This kind of direct re-assignment is common in Python. It might seem inconvenient to have it outlawed! 

Well, it is still possible to alter our model. But if we want to switch out just the point mass, we have to do something slightly more complex.

In [None]:
import equinox as eqx

model_heavy = eqx.tree_at(
    lambda m: m.step.mechanics.plant.skeleton, 
    model, 
    PointMass(5.0)
)

To replace a part of our model tree, we use the `tree_at` function provided by Equinox. 

The use of `lambda` in the first argument to `tree_at` is similar to the `lambda` we used in Example 1. There, we defined a function `where_train` that picked out which parts of the model should be trainable. Here, our function picks out which part of our model will be replaced.

The second argument is just `model`, which is the model we want to alter. 

The third argument is the replacement: the new object that takes the place of the part picked out by the first argument.

Why the added complexity? *It forces us never to alter our models in-place*. The function `eqx.tree_at` does not modify `model` directly, but *returns a copy* of `model` which possesses the alterations. Here, we assign the new object to `model_heavy`, since maybe we want to be able to refer to both the original model (still called `model`!) and the altered one. But we could just as easily have written

In [None]:
model = eqx.tree_at(
    lambda m: m.step.mechanics.plant.skeleton, 
    model, 
    PointMass(5.0)
)

This means "create an altered model, and make it so that `model` now refers to the new model---I don't need to refer to the original one by that name anymore".

There are downstream advantages to our ban. The downside to in-place changes is that it can be hard to keep track of their hidden consequences. Some parts of my code may depend on a given object, and if other parts of my code can reach into that object and make implicit alterations, the relationships within my code may be altered without consent of all the stakeholders, so to speak.

When we are forced to *return* an altered object rather than *mutate* an existing object, the consequences are always out in the open. We have to be explicit about what-refers-to-what, with respect to what-has-changed. 
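
The `FrozenInstanceError` we saw earlier is the same exception that Python's frozen dataclasses raise, so the pattern can be sketched with the standard library alone. This is only an analogy: the class `Skeleton` is a hypothetical stand-in rather than a Feedbax type, and `dataclasses.replace` plays the role that `eqx.tree_at` plays for nested models.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Skeleton:
    """Hypothetical stand-in for an immutable model component."""
    mass: float


light = Skeleton(mass=1.0)

try:
    light.mass = 5.0  # In-place mutation is rejected...
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# ...so we build an altered copy instead, and the original is untouched.
heavy = dataclasses.replace(light, mass=5.0)
```

Both `light` and `heavy` remain available afterwards, and nothing that referred to `light` has changed behind our backs.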

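- We also can't alter JAX arrays in-place. Instead, we use an array's `at` property together with a method such as `set`, as in `x.at[0].set(5.0)`, which returns an altered copy of the array.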
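
A minimal sketch of the array case, using made-up variable names:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

try:
    x[0] = 5.0  # NumPy-style in-place assignment is not allowed
    assigned = True
except TypeError:
    assigned = False

# `at[...].set(...)` returns an altered copy and leaves `x` untouched.
y = x.at[0].set(5.0)
```

As with `eqx.tree_at`, the original array `x` is unchanged, and the altered array is a new object that we have to name explicitly.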