# 05 Pytrees

Original Documentation: https://docs.jax.dev/en/latest/working-with-pytrees.html


In [22]:
import jax
import jax.numpy as jnp

## What is a pytree

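A pytree is a container-like structure built out of container-like Python objects, such as lists, tuples, and dicts. A leaf is anything that is not a pytree container, such as an array; a single leaf is also a valid pytree.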

In ML, a pytree can represent model weights, dataset entries, or reinforcement learning (RL) agent observations.


In [23]:
example_trees = [
    [1, "a", object()],
    (1, (2, 3), ()),
    [1, {"k1": 2, "k2": (3, 4)}, 5],
    {"a": 2, "b": (2, 3)},
    jnp.array([1, 2, 3]),
]

for pytree in example_trees:
    leaves = jax.tree.leaves(pytree)
    print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

[1, 'a', <object object at 0x10d4ab950>]      has 3 leaves: [1, 'a', <object object at 0x10d4ab950>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]


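The leaves are the terminal nodes of the tree: values that are not themselves containers. An empty container such as `()` contributes no leaves, which is why `(1, (2, 3), ())` has only 3.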

Any tree-like structure built out of container-like Python objects can be treated as a pytree.

By default, JAX treats lists, tuples, and dicts (along with a few other built-ins such as `None` and namedtuples) as containers. Any other object, such as the `jnp.array` in the last example, is treated as a single leaf.
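Under the hood, JAX flattens a pytree into a list of leaves plus a *treedef* that records the container structure. The sketch below uses `jax.tree.flatten()`, `jax.tree.unflatten()`, and the treedef's `num_leaves` attribute to show the round trip; the tree itself is just an illustrative value.

```py
import jax
import jax.numpy as jnp

tree = [1, {"k1": 2, "k2": (3, 4)}, jnp.array([5, 6])]

# flatten returns the leaves in order plus a treedef describing the containers.
leaves, treedef = jax.tree.flatten(tree)
print(leaves)  # the 2-element array stays whole: it is a single leaf
print(treedef.num_leaves)  # 5

# unflatten rebuilds a tree with the same structure from a new list of leaves.
rebuilt = jax.tree.unflatten(treedef, [10, 20, 30, 40, 50])
print(rebuilt)  # [10, {'k1': 20, 'k2': (30, 40)}, 50]
```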

## Common pytree functions

### Pytree map

You can use `jax.tree.map()` to apply a function to every leaf of a pytree, similar to Python’s built-in `map()`:


In [24]:
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4],
]

print(jax.tree.map(lambda x: x * 2, list_of_lists))

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]


You can also do `zip()`-style operations by mapping an N-ary function over multiple pytrees:


In [25]:
lol = [[1, 2, 3], [1, 2], [1, 2, 3, 4]]
dlol = jax.tree.map(lambda x: x * 2, lol)

print(jax.tree.map(lambda x, y: x + y, lol, dlol))

[[3, 6, 9], [3, 6], [3, 6, 9, 12]]
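For the N-ary form, every extra tree must have the same structure as the first one (or have it as a prefix). When the structures disagree, `jax.tree.map()` raises an error instead of silently truncating the way `zip()` does. A small sketch that catches the resulting `ValueError`:

```py
import jax

a = [1, 2, 3]
b = [10, 20]  # one element short

# Python's zip() silently stops at the shorter list...
print(list(zip(a, b)))  # [(1, 10), (2, 20)]

# ...but jax.tree.map checks that the tree structures match.
try:
    jax.tree.map(lambda x, y: x + y, a, b)
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print("ValueError:", err)
```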


### Example with model parameters

Consider this example, which randomly initializes the weights and biases of each layer of a multi-layer perceptron (MLP):


In [26]:
def init_mlp_params(layer_widths, key):
    params = []

    # Each layer has 2 random draws (weights + biases)
    num_keys_needed = (len(layer_widths) - 1) * 2
    keys = jax.random.split(key, num_keys_needed)
    key_iter = iter(keys)

    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        w_key = next(key_iter)
        b_key = next(key_iter)

        # Generate (n_in, n_out) weights for each layer. Multiplying an input vector of
        # shape (n_in,) by this matrix produces an output of shape (n_out,). We scale
        # the weights using He initialization to help keep activation variance stable
        # across layers.
        weights = jax.random.normal(w_key, shape=(n_in, n_out)) * jnp.sqrt(2 / n_in)

        # Biases have shape (n_out,) so they broadcast across the output of the
        # weight multiplication, which has shape (batch_size, n_out).
        biases = jax.random.normal(b_key, shape=(n_out,))

        params.append({"weights": weights, "biases": biases})

    return params


key = jax.random.key(1701)
params = init_mlp_params([1, 128, 128, 1], key)

We can use `jax.tree.map()` to check the shapes of the initial parameters:


In [27]:
from pprint import pprint

pprint(jax.tree.map(lambda x: x.shape, params))

[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]
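Because every leaf of `params` is an array, pytree utilities also make it easy to compute summary statistics such as the total number of parameters. The sketch below builds a zero-filled stand-in with the same shapes as `init_mlp_params([1, 128, 128, 1], key)` so that it runs on its own:

```py
import jax
import jax.numpy as jnp

# Zero-filled stand-in with the same shapes that init_mlp_params([1, 128, 128, 1], key)
# produces, so this snippet runs on its own.
layer_widths = [1, 128, 128, 1]
params = [
    {"weights": jnp.zeros((n_in, n_out)), "biases": jnp.zeros((n_out,))}
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:])
]

# Every leaf is an array, so summing leaf sizes gives the total parameter count.
num_params = sum(x.size for x in jax.tree.leaves(params))
print(num_params)  # 16897
```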


Next, we write the forward pass and loss function, and use `jax.tree.map()` to implement the parameter update for training:


In [28]:
def forward(params, x):
    # The last layer is the output layer. The other layers before it are considered
    # "hidden" layers.
    *hidden, last = params

    # Perform a forward pass of x @ weights + biases for each layer.
    for layer in hidden:
        # ReLU(z) = max(0, z)
        # Each layer is performing a linear transformation. Stacking all of these
        # linear layers is equivalent to a single large linear transformation.
        # A ReLU activation layer adds some nonlinearity to the model, allowing
        # modeling of curved boundaries and not allowing collapsing all layers
        # into one.

        x = jax.nn.relu(x @ layer["weights"] + layer["biases"])

    return x @ last["weights"] + last["biases"]


def loss_fn(params, x, y):
    # Mean squared error (MSE) loss
    return jnp.mean((forward(params, x) - y) ** 2)


LEARNING_RATE = 0.0001


@jax.jit
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)

    # Note: grads is a pytree with the same structure as params, so we can
    # apply the SGD update leaf by leaf with jax.tree.map and return the result.
    return jax.tree.map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
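The `update` step can then be called in an ordinary Python loop. The following self-contained sketch assumes a toy regression task (fitting `y = x ** 2` on [-1, 1]), a simplified initializer with zero biases, and a larger learning rate than above so the loss drops visibly within a few hundred steps; none of these choices come from the original example.

```py
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # assumption: larger than above so this toy problem trains quickly


def init_params(layer_widths, key):
    # Simplified stand-in for init_mlp_params (assumption): He-scaled weights, zero biases.
    keys = jax.random.split(key, len(layer_widths) - 1)
    return [
        {
            "weights": jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2 / n_in),
            "biases": jnp.zeros((n_out,)),
        }
        for k, n_in, n_out in zip(keys, layer_widths[:-1], layer_widths[1:])
    ]


def forward(params, x):
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer["weights"] + layer["biases"])
    return x @ last["weights"] + last["biases"]


def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)


@jax.jit
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # grads mirrors params, so one tree.map applies the SGD step to every leaf.
    return jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)


# Hypothetical toy data: learn y = x ** 2 on [-1, 1].
x = jnp.linspace(-1.0, 1.0, 128).reshape(-1, 1)
y = x**2

params = init_params([1, 32, 32, 1], jax.random.key(0))
initial_loss = loss_fn(params, x, y)
for _ in range(500):
    params = update(params, x, y)
final_loss = loss_fn(params, x, y)
print(f"loss: {float(initial_loss):.4f} -> {float(final_loss):.4f}")
```

Since `update` returns a pytree with the same structure as its input, the loop can simply rebind `params` on every step.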

## Custom pytree nodes

We can extend the set of Python types that will be considered internal nodes in pytrees by using `jax.tree_util.register_pytree_node()`.

This is useful because, unless a custom type is registered as an internal node, JAX treats it as a leaf, even if it holds arrays or other pytrees inside it:


In [29]:
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y


print(jax.tree.leaves([Special(0, 1), Special(2, 4)]))

[<__main__.Special object at 0x10fee84a0>, <__main__.Special object at 0x10fee8bc0>]


Each `Special` object is returned as a single leaf, because we have not registered `Special` as a pytree node.

We can register `Special` as a pytree node by providing flattening and unflattening functions:


In [30]:
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self) -> str:
        return f"Special(x={self.x}, y={self.y})"


def special_flatten(v):
    # The flatten function must return the children to flatten recursively and
    # opaque auxiliary data to pass back during unflattening.
    return (v.x, v.y), None


def special_unflatten(aux_data, children):
    # The unflatten function must return a reconstructed object of the registered
    # type using the children and auxiliary data.
    return Special(*children)


jax.tree_util.register_pytree_node(
    Special,
    special_flatten,  # Function that tells JAX where the children are
    special_unflatten,  # Function that tells JAX how to recover the parent
)

Then, we can perform pytree operations on pytrees that contain `Special` nodes:


In [31]:
print(jax.tree.map(lambda x: x + 1, [Special(1, 2), Special(3, 4)]))

[Special(x=2, y=3), Special(x=4, y=5)]


JAX recurses down the pytree and applies the function to each leaf; the `Special` containers are rebuilt around the new values.
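JAX also offers a decorator form, `jax.tree_util.register_pytree_node_class`, for classes that define their own `tree_flatten` method and `tree_unflatten` classmethod. The sketch below uses a hypothetical `Scaled` class to show how auxiliary data differs from children: `value` is a child that pytree functions visit, while `name` travels through the aux data untouched. Aux data is stored in the treedef, so it should be hashable and comparable (a string is fine).

```py
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Scaled:
    """Hypothetical example type: `value` is a child, `name` is static aux data."""

    def __init__(self, value, name):
        self.value = value
        self.name = name

    def __repr__(self) -> str:
        return f"Scaled(value={self.value}, name={self.name!r})"

    def tree_flatten(self):
        # (children, aux_data): children are traversed, aux_data is carried through.
        return (self.value,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, name=aux_data)


tree = [Scaled(1.0, "a"), Scaled(2.0, "b")]
leaves = jax.tree.leaves(tree)
mapped = jax.tree.map(lambda v: v * 10, tree)
print(leaves)  # [1.0, 2.0]
print(mapped)  # [Scaled(value=10.0, name='a'), Scaled(value=20.0, name='b')]
```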

A `NamedTuple` subclass does not need to be registered explicitly: JAX recognizes namedtuples as pytree nodes and flattens and unflattens them automatically:


In [32]:
from typing import NamedTuple


class Special(NamedTuple):
    x: int
    y: int


print(jax.tree.map(lambda x: x + 1, [Special(1, 2), Special(3, 4)]))

[Special(x=2, y=3), Special(x=4, y=5)]


## Pytrees and JAX transformations

Many JAX functions (like `jax.lax.scan()`) operate on pytrees of arrays. In fact, all JAX transformations can be applied to functions that accept pytrees of arrays as inputs and produce pytrees of arrays as outputs.
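As a sketch, assuming a hypothetical `LinearParams` NamedTuple for a linear model, `jax.grad()` (here wrapped in `jax.jit()`) differentiates with respect to every leaf of the first argument and returns the gradients in a pytree of the same type and structure:

```py
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LinearParams(NamedTuple):  # hypothetical parameter container
    w: jax.Array
    b: jax.Array


def loss(params, x, y):
    pred = x @ params.w + params.b
    return jnp.mean((pred - y) ** 2)


params = LinearParams(w=jnp.ones((3,)), b=jnp.array(0.0))
x = jnp.ones((4, 3))
y = jnp.zeros((4,))

# The gradient is itself a LinearParams with one entry per parameter leaf.
grads = jax.jit(jax.grad(loss))(params, x, y)
print(type(grads).__name__, grads.w, grads.b)
```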

Some JAX transformations take optional parameters that specify how certain inputs or outputs should be treated (for example, the `in_axes` and `out_axes` arguments of `jax.vmap()`). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments.

For example, we may have some arguments like this:

```py
import jax
import jax.numpy as jnp

a1 = jnp.ones(3)
a2 = jnp.ones(3)
a3 = jnp.ones((5, 3))  # a batch of 5 rows along axis 0

# 3 leaf nodes (a1, a2, a3), passed to f as two positional arguments
args = (a1, {"k1": a2, "k2": a3})

def f(x, d):  # illustrative function
    return x * d["k1"] + d["k2"]

# To vmap f over arguments with this structure, in_axes must mirror it.
# Here we do not map over a1 or a2, but map over the leading axis (axis 0) of a3.
in_axes = (None, {"k1": None, "k2": 0})

jax.vmap(f, in_axes=in_axes)(*args)  # output shape: (5, 3)
```


However, if we want to map over axis 0 of every leaf, we can pass a single integer, provided all leaves have the same size along that axis:

```py
jax.vmap(f, in_axes=0)(*args)
```
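`in_axes` can also be a *prefix* of the argument structure: a single value at some node applies to every leaf below it. The sketch below uses an illustrative `f` and batch dict of its own; `in_axes=(None, 0)` maps over axis 0 of every array in the dict while passing `a1` unchanged to each call.

```py
import jax
import jax.numpy as jnp


def f(x, d):  # illustrative function
    return x * d["k1"] + d["k2"]


a1 = jnp.ones(3)
batch = {"k1": jnp.ones((5, 3)), "k2": jnp.zeros((5, 3))}

# The 0 in in_axes=(None, 0) covers the whole dict, so both of its leaves
# are mapped over axis 0, while a1 is not mapped.
out = jax.vmap(f, in_axes=(None, 0))(a1, batch)
print(out.shape)  # (5, 3)
```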


## Explicit key paths

Each pytree leaf has a key path. A key path is a tuple of keys, one per level of nesting, so its length equals the depth of that leaf.

A key represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type (e.g., key for a dict is a different type than the key for a list).

We can view the path of each leaf with `jax.tree_util.tree_flatten_with_path()`:


In [33]:
class Special(NamedTuple):
    name: str


tree = [1, {"k1": 2, "k2": (3, 4)}, Special("foo")]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
    print(f"Value of tree{jax.tree_util.keystr(key_path)}: {value}")

Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo
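The key objects themselves can be inspected: in this tree, list indices appear as `jax.tree_util.SequenceKey`, dict keys as `DictKey`, and namedtuple fields as `GetAttrKey`. Relatedly, `jax.tree_util.tree_map_with_path()` passes each leaf's key path to the mapped function. A small sketch that rebuilds the tree above:

```py
from typing import NamedTuple

import jax
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


class Special(NamedTuple):
    name: str


tree = [1, {"k1": 2, "k2": (3, 4)}, Special("foo")]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

# The path to the leaf 3 has one key per level of nesting: list, dict, tuple.
path_to_3, _ = flattened[2]
print(path_to_3)

# tree_map_with_path calls the function with (key_path, leaf) for every leaf.
labeled = jax.tree_util.tree_map_with_path(
    lambda path, x: f"{jax.tree_util.keystr(path)}={x}", tree
)
print(labeled[0], labeled[2].name)  # [0]=1 [2].name=foo
```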


## Common pytree gotchas

### Mistaking pytree nodes as leaves

A common mistake is accidentally introducing tree nodes instead of leaves:


In [34]:
tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another pytree with ones instead of zeros
shapes = jax.tree.map(lambda x: x.shape, tree)
print(jax.tree.map(jnp.ones, shapes))

[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)), (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]


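Here, the shape tuples are pytree nodes rather than leaves, so `jax.tree.map()` calls `jnp.ones` on each integer (2, 3, 3, 4). The new tree therefore contains four 1D arrays instead of two arrays with shapes (2, 3) and (3, 4).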

One fix is to convert each shape tuple into a `jnp.array()` (or another type that JAX treats as a leaf), so the whole shape becomes a single leaf:


In [35]:
shapes = jax.tree.map(lambda x: jnp.array(x.shape), tree)
tree_ones = jax.tree.map(jnp.ones, shapes)

assert len(tree_ones) == 2
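Two other ways around this gotcha are sketched below: skip the intermediate tree of shapes and map over the arrays directly, or keep the shape tuples but pass an `is_leaf` predicate so `jax.tree.map()` stops at each tuple.

```py
import jax
import jax.numpy as jnp

tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Option 1: map over the original arrays, so no tuples ever become pytree nodes.
ones_a = jax.tree.map(lambda x: jnp.ones(x.shape), tree)

# Option 2: keep the tree of shapes, but tell jax.tree.map that tuples are leaves.
shapes = jax.tree.map(lambda x: x.shape, tree)
ones_b = jax.tree.map(jnp.ones, shapes, is_leaf=lambda x: isinstance(x, tuple))

print([x.shape for x in ones_a])  # [(2, 3), (3, 4)]
print([x.shape for x in ones_b])  # [(2, 3), (3, 4)]
```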

### Handling of None

All `jax.tree_util` functions treat `None` as a node with no children rather than as a leaf, so it contributes no leaves:


In [36]:
print(jax.tree.leaves([None, None, None]))

[]


To treat `None` as a leaf, we can use the `is_leaf` argument:


In [37]:
print(jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None))

[None, None, None]
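The same rule affects `jax.tree.map()`: a `None` has no leaves, so the mapped function never sees it and the `None` is kept in place in the output. With `is_leaf`, the function is called on each `None` as well. A short sketch:

```py
import jax

tree = [1, None, 3]

# None is a node with no children, so it is skipped and preserved.
doubled = jax.tree.map(lambda x: x * 2, tree)
print(doubled)  # [2, None, 6]

# Treating None as a leaf lets the function replace it.
filled = jax.tree.map(lambda x: 0 if x is None else x, tree, is_leaf=lambda x: x is None)
print(filled)  # [1, 0, 3]
```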


## Transposing pytrees

To transpose a pytree from a list of trees into a tree of lists, you can use `jax.tree.map()` or `jax.tree.transpose()`:


In [38]:
def transpose(row_major_dataset):
    return jax.tree.map(lambda *xs: list(xs), *row_major_dataset)


dataset = [
    {"price": 100, "owner": "Ayush", "creation_year": 2020},
    {"price": 150, "owner": "Ishaan", "creation_year": 2025},
]
print(transpose(dataset))

{'creation_year': [2020, 2025], 'owner': ['Ayush', 'Ishaan'], 'price': [100, 150]}
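The same result can be obtained with `jax.tree.transpose()`, which takes the treedefs of the outer structure (here, a list with one entry per row) and the inner structure (one row's dict) explicitly. A sketch that rebuilds `dataset` so it runs on its own:

```py
import jax

dataset = [
    {"price": 100, "owner": "Ayush", "creation_year": 2020},
    {"price": 150, "owner": "Ishaan", "creation_year": 2025},
]

# Outer structure: a list with one slot per row. Inner structure: a single row.
outer_treedef = jax.tree.structure([0] * len(dataset))
inner_treedef = jax.tree.structure(dataset[0])

columns = jax.tree.transpose(outer_treedef, inner_treedef, dataset)
print(columns)  # {'creation_year': [2020, 2025], 'owner': ['Ayush', 'Ishaan'], 'price': [100, 150]}
```

Spelling out both treedefs makes the intended split explicit, whereas the `jax.tree.map()` version infers it from the list of rows.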
