# PyTrees

One thing that makes JAX especially nice and powerful is that it lets us treat different kinds of tree-structured data in the same way. This is central to Feedbax as well, so if you are unfamiliar with PyTrees, please read this short tutorial.

Say we have a list and a dict that contain similar values.

```python
some_list = [1, 2, 3]
some_dict = {'a': 1, 'b': 2, 'c': 3}
```


We want to calculate something from these values, such as their squares.

A typical way of doing this in standard Python is a list (or dict) comprehension.

```python
[x ** 2 for x in some_list]
```

```
[1, 4, 9]
```

```python
{k: x ** 2 for k, x in some_dict.items()}
```

```
{'a': 1, 'b': 4, 'c': 9}
```

JAX, on the other hand, provides a function `tree_map` that applies a function to every value in a container and returns a container with the same structure. It works exactly the same way for both lists and dicts.

```python
from jax.tree_util import tree_map

tree_map(lambda x: x ** 2, some_list)
```

```
[1, 4, 9]
```

```python
tree_map(lambda x: x ** 2, some_dict)
```

```
{'a': 1, 'b': 4, 'c': 9}
```

Even better, `tree_map` works on nested containers. 

```python
import jax.numpy as jnp

some_data = [{'p': [1, 2], 'x': 1.0}, [5, 6, 7, 8, {'y': jnp.array([2, 2, 2])}]]

tree_map(lambda x: x ** 2, some_data)
```

```
[{'p': [1, 4], 'x': 1.0},
 [25, 36, 49, 64, {'y': Array([4, 4, 4], dtype=int32)}]]
```

!!! Note
    Python includes a built-in function `map` which is similar in principle to `tree_map`. For example, `list(map(lambda x: x**2, some_list))` gives the same result as `tree_map(lambda x: x**2, some_list)`. However, `map` returns an iterator rather than the data structure it was given, so we have to rebuild the container ourselves. To apply the same transformation to a dict and get back a dict with the same keys, we'd have to write something like `dict(zip(some_dict.keys(), map(lambda x: x**2, some_dict.values())))`, which is much harder to read. And `map` does not descend into nested containers at all.

How does this work? JAX considers both lists and dicts to be *PyTrees*. A PyTree is a structure built from containers that can be nested inside one another, forming arbitrarily complex trees.
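To make the idea concrete, here is a simplified, hand-written version of what `tree_map` does for lists, tuples, and dicts. `naive_tree_map` is a hypothetical helper for illustration only. It shows the idea of recursing through containers and applying a function at the leaves, not how JAX actually implements `tree_map`, which also handles any other registered container type.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map


def naive_tree_map(f, tree):
    """Apply `f` to every leaf of a nested list/tuple/dict, keeping the structure.

    Hypothetical illustration: only plain lists, plain tuples, and dicts are
    treated as containers; everything else is treated as a leaf.
    """
    if isinstance(tree, dict):
        # Recurse into each value, keeping the same keys
        return {k: naive_tree_map(f, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Recurse into each element, rebuilding a list or tuple of the same type
        return type(tree)(naive_tree_map(f, v) for v in tree)
    # Anything else is a leaf: apply the function to it
    return f(tree)


some_data = [{'p': [1, 2], 'x': 1.0}, [5, 6, 7, 8, {'y': jnp.array([2, 2, 2])}]]

ours = naive_tree_map(lambda x: x ** 2, some_data)
theirs = tree_map(lambda x: x ** 2, some_data)
```

On this data, both versions produce the same nested result. The hand-written one would need a new branch for every new kind of container, which is the work that JAX's PyTree machinery does for us.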

A PyTree's *leaves* are the data it ultimately contains.

```python
from jax.tree_util import tree_leaves, tree_structure

tree_leaves(some_data)
```

```
[1, 2, 1.0, 5, 6, 7, 8, Array([2, 2, 2], dtype=int32)]
```

Those leaves are arranged in a tree with a certain structure. 

```python
tree_structure(some_data)
```

```
PyTreeDef([{'p': [*, *], 'x': *}, [*, *, *, *, {'y': *}]])
```
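The leaves and the structure together contain everything needed to rebuild the tree. The sketch below uses `jax.tree_util.tree_flatten`, which returns both at once, and `tree_unflatten`, which puts a list of leaves back into a given structure. Multiplying each leaf by 10 is just an arbitrary example of modifying the leaves in between.

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

some_data = [{'p': [1, 2], 'x': 1.0}, [5, 6, 7, 8, {'y': jnp.array([2, 2, 2])}]]

# Split the tree into a flat list of leaves and a description of its structure
leaves, treedef = tree_flatten(some_data)

# Transform the leaves; we must keep the same number of them
new_leaves = [x * 10 for x in leaves]

# Put the new leaves back into the original structure
rebuilt = tree_unflatten(treedef, new_leaves)
print(rebuilt)
```

Flattening, working on the leaves one by one, and unflattening is a useful mental model for what `tree_map` does.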

What does JAX treat as a leaf, and what does it treat as the structure in which the leaves are contained? 

By default,

- *leaves* include arrays (NumPy or JAX) and basic data types (like `int`, `float`, `str`, and `bool`). More generally, any object whose type JAX does not recognize as a node is treated as a leaf.
- *nodes* are containers like lists, dicts, and tuples, which are PyTrees themselves. When JAX encounters one of these, its default stance is that "nesting continues here", so it looks inside the node for leaves (or for even deeper layers of nodes).
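A quick way to check these defaults is to call `tree_leaves` on a few small examples. `Opaque` below is a hypothetical class with no PyTree registration, so JAX does not look inside it. The example also shows one special case: `None` is treated as a node with no children, so it contributes no leaves at all.

```python
from jax.tree_util import tree_leaves


class Opaque:
    """A hypothetical plain class that JAX knows nothing about."""
    pass


obj = Opaque()

# Unregistered types are leaves: the object appears whole in the leaf list
leaves_obj = tree_leaves([obj, 1])

# Strings and bools are leaves too; a string is not split into characters
leaves_str = tree_leaves(['abc', True])

# `None` is an empty node, so it disappears from the leaves
leaves_none = tree_leaves([None, 1, {'a': None}])

print(leaves_obj, leaves_str, leaves_none)
```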

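Importantly, we can override this behaviour. Most functions that operate on PyTrees accept an `is_leaf` argument: a function that is called on each node as JAX traverses the tree, and that returns `True` when that node (and everything inside it) should be treated as a single leaf.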

```python
tree_leaves(some_data, is_leaf=lambda x: isinstance(x, dict))
```

```
[{'p': [1, 2], 'x': 1.0}, 5, 6, 7, 8, {'y': Array([2, 2, 2], dtype=int32)}]
```

```python
tree_structure(some_data, is_leaf=lambda x: isinstance(x, dict))
```

```
PyTreeDef([*, [*, *, *, *, *]])
```

Here, we've told JAX to treat dicts as leaves, rather than as containers. Now, the dicts in `some_data` appear whole and unflattened in the list of its leaves, and the PyTree structure reflects this. 

Because JAX understands the structure of PyTrees, we can apply a function leafwise across several PyTrees at once, as long as their structures match.

```python
some_arrays = [
    (jnp.array([1, 2]), jnp.array([3, 4])),
    jnp.array([5, 6])
]

some_other_arrays = [
    (jnp.array([7, 8]), jnp.array([3, 4])),
    jnp.array([1, 1])
]

tree_structure(some_arrays) == tree_structure(some_other_arrays)
```

```
True
```

```python
tree_map(
    lambda x, y: x + y,
    some_arrays,
    some_other_arrays
)
```

```
[(Array([ 8, 10], dtype=int32), Array([6, 8], dtype=int32)),
 Array([6, 7], dtype=int32)]
```

We've relied on the fact that `tree_map` can work "leafwise" to pick out the arguments to a function: the `x` values are the leaves from `some_arrays`, and the `y` values are the matching leaves from `some_other_arrays`.
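Matching structures are a real requirement, not just a convenience. In the sketch below, `mismatched` is a hypothetical variant of `some_other_arrays` whose first element is a list rather than a tuple. Both trees have three array leaves, but their structures differ, so `tree_map` refuses to pair them up and raises an error (typically a `ValueError`; the sketch catches `TypeError` as well, to be safe).

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map, tree_structure

some_arrays = [
    (jnp.array([1, 2]), jnp.array([3, 4])),
    jnp.array([5, 6]),
]

# Same number of leaves, but the first element is a list instead of a tuple
mismatched = [
    [jnp.array([7, 8]), jnp.array([3, 4])],
    jnp.array([1, 1]),
]

same_leaf_count = len(tree_leaves(some_arrays)) == len(tree_leaves(mismatched))
same_structure = tree_structure(some_arrays) == tree_structure(mismatched)

raised = False
try:
    tree_map(lambda x, y: x + y, some_arrays, mismatched)
except (ValueError, TypeError) as err:
    raised = True
    print("Structures don't match:", err)
```

The type of each container is part of the structure, so a list and a tuple holding the same leaves are not interchangeable here.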

However, if we tell JAX to treat tuples as leaves, then the result is different: the first `x`, `y` passed to the function will be a pair of tuples, and we'll end up concatenating them instead of adding the arrays inside them. 

```python
tree_map(
    lambda x, y: x + y,
    some_arrays,
    some_other_arrays,
    is_leaf=lambda x: isinstance(x, tuple)
)
```

```
[(Array([1, 2], dtype=int32),
  Array([3, 4], dtype=int32),
  Array([7, 8], dtype=int32),
  Array([3, 4], dtype=int32)),
 Array([6, 7], dtype=int32)]
```

The array that isn't inside a tuple is added the same way as before, because it still counts as a leaf. The difference is that tuples now count as leaves too.

## The wider world of PyTrees

If lists, dicts, and tuples aren't enough for us, we can define our own container types and tell JAX how to treat them as PyTrees. This takes only a little work, and the Equinox library reduces it further: classes that subclass `eqx.Module` are treated as PyTrees automatically.
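For a plain Python class, one way to register it is the `jax.tree_util.register_pytree_node_class` decorator. The class must provide a `tree_flatten` method, which returns its children plus any static auxiliary data, and a `tree_unflatten` classmethod, which rebuilds an instance from them. `Point` is a hypothetical example class, not part of JAX or Feedbax.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map


@register_pytree_node_class
class Point:
    """A hypothetical 2D point container, registered as a PyTree node."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # Return (children, aux_data). JAX traverses the children;
        # aux_data is static metadata (none is needed here).
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild a Point from the (possibly transformed) children
        return cls(*children)

    def __repr__(self):
        return f"Point(x={self.x!r}, y={self.y!r})"


p = Point(jnp.array([1.0, 2.0]), 3.0)

print(tree_leaves(p))                   # the two fields are the leaves
squared = tree_map(lambda v: v ** 2, p)  # returns a new Point
print(squared)
```

Once registered, `Point` works with `tree_map`, `tree_leaves`, and the rest of `jax.tree_util` just like a list or dict does.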

## Feedbax and PyTrees

Most objects in Feedbax are derived from `eqx.Module`. That means they are automatically treated as PyTrees. 

This is why we can print the structure of a model so easily.

## But wait, there's more

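The power of PyTrees goes much deeper than we've seen here. The core JAX transformations, such as `jax.vmap` and `jax.grad`, accept PyTrees as inputs and return PyTrees as outputs. For example, `jax.grad` returns gradients with the same PyTree structure as the parameters they were taken with respect to.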