## What is a pytree?
A pytree is a tree-like structure built from nested Python containers. Each container can hold leaves, further pytrees, or both. By default, JAX treats lists, tuples, and dicts as pytree containers. A leaf is any object that is not a registered container, such as an array or a scalar, and a single leaf on its own also counts as a pytree.

In the context of machine learning (ML), a pytree can contain:

- Model parameters

- Dataset entries

- Reinforcement learning agent observations

Datasets are a common source of pytrees: a batch of records, for example, is often a list of lists of dicts.

Below are several example pytrees. In JAX, you can use `jax.tree.leaves()` to extract the flattened leaves from each tree, as demonstrated here:

In [1]:
import jax
import jax.numpy as jnp

example_trees = [
    [1, "a", object()],
    (1, (2, 3), ()),
    [1, {"k1": 2, "k2": (3, 4)}, 5],
    {"a": 2, "b": (2, 3)},
    jnp.array([1, 2, 3]),
]

# Print how many leaves the pytrees have.
for pytree in example_trees:
    # `jax.tree.leaves()` returns the leaves of a pytree as a flat list.
    leaves = jax.tree.leaves(pytree)
    print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

[1, 'a', <object object at 0x7f8250bc3ef0>]   has 3 leaves: [1, 'a', <object object at 0x7f8250bc3ef0>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]
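The leaves are only half of a pytree; the other half is the container structure that holds them. As a minimal sketch (the variable names here are illustrative, not part of the original example), `jax.tree.flatten()` returns both parts, and `jax.tree.unflatten()` puts them back together:

```python
import jax

tree = [1, {"k1": 2, "k2": (3, 4)}, 5]

# Split the pytree into a flat list of leaves and a treedef that
# records how the containers were nested.
leaves, treedef = jax.tree.flatten(tree)
print(leaves)  # [1, 2, 3, 4, 5]

# Transform the leaves, then rebuild a pytree with the original structure.
doubled = jax.tree.unflatten(treedef, [2 * leaf for leaf in leaves])
print(doubled)  # [2, {'k1': 4, 'k2': (6, 8)}, 10]

# The rebuilt tree has the same structure as the original one.
print(jax.tree.structure(doubled) == treedef)  # True
```

Most pytree utilities can be thought of as this flatten, transform, unflatten pattern applied for you.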


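`jax.tree.map()` applies a function to every leaf of a pytree and returns a new pytree with the same structure. For example:

In [2]: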
list_of_lists = [[1, 2, 3], [1, 2], [1, 2, 3, 4]]

jax.tree.map(lambda x: x * 2, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

When using multiple arguments with `jax.tree.map()`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.

In [3]:
another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x + y, list_of_lists, another_list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
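To see what happens when the structures do not match, here is a small sketch. The helper `try_add` is hypothetical and exists only for this illustration; it also assumes that the mismatch surfaces as a `ValueError`, which is the behavior of recent JAX releases:

```python
import jax


def try_add(tree_a, tree_b):
    """Hypothetical helper: add two pytrees leaf by leaf, or report a mismatch."""
    try:
        return jax.tree.map(lambda x, y: x + y, tree_a, tree_b)
    except ValueError as err:  # Assumption: mismatches raise ValueError.
        return f"structure mismatch: {err}"


# Same structure: the leaves are added pairwise.
print(try_add([[1, 2, 3], [1, 2]], [[1, 2, 3], [1, 2]]))

# The second inner list has a different length: reported as a mismatch.
print(try_add([[1, 2, 3], [1, 2]], [[1, 2, 3], [1, 2, 3]]))

# The dicts have different keys: also reported as a mismatch.
print(try_add({"a": 1}, {"b": 1}))
```

The exact wording of the error message may differ between JAX versions, so the sketch only relies on the exception type.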

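The same approach works for ML model parameters. The following function initializes the parameters of a multi-layer perceptron (MLP) as a list of dicts, which is itself a pytree:

In [8]: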
import numpy as np


def init_mlp_params(layer_width):
    params = []
    for n_in, n_out in zip(layer_width[:-1], layer_width[1:]):
        params.append(
            dict(
                weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
                biases=np.ones(shape=(n_out,)),
            )
        )
    return params


params = init_mlp_params([1, 128, 128, 1])

Use `jax.tree.map()` to check the shapes of the initial parameters:

In [9]:
jax.tree.map(lambda x: x.shape, params)

[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]
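Because every leaf of `params` is an array, summary statistics over the whole model reduce to a loop over `jax.tree.leaves()`. The sketch below counts the parameters. To stay self-contained it rebuilds a zero-filled stand-in with the same layout as `params`; the stand-in and the name `layer_widths` are assumptions made for this illustration:

```python
import jax
import numpy as np

# Stand-in with the same layout as `params`: a list of dicts of arrays.
layer_widths = [1, 128, 128, 1]
params = [
    dict(weights=np.zeros((n_in, n_out)), biases=np.zeros((n_out,)))
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:])
]

# Each leaf is one array, so the total parameter count is the sum of
# the leaf sizes: (1*128 + 128) + (128*128 + 128) + (128*1 + 1).
num_params = sum(leaf.size for leaf in jax.tree.leaves(params))
print(num_params)  # 16897
```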

Next, define the functions needed to train the MLP:

In [None]:
# Define the forward pass
def forward(params, x):
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer["weights"] + layer["biases"])
    return x @ last["weights"] + last["biases"]


# Define the loss function
def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate
LEARNING_RATE = 1e-4

# Define the parameter update function using stochastic gradient descent (SGD).
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
    # Calculate the gradients with `jax.grad`
    grads = jax.grad(loss_fn)(params, x, y)
    # Note that `grads` is a pytree with the same structure as `params`.
    # `jax.grad` is one of many JAX functions that have
    # built-in support for pytrees.
    # This is useful: you can apply the SGD update using JAX pytree utilities.
    return jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)
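The following self-contained sketch puts these pieces together. It repeats the definitions above in condensed form, checks that the gradients mirror the structure of the parameters, and runs a few update steps. The toy data (`xs`, `ys = xs ** 2`), the seeds, the step count, and the helper name `init_params` are assumptions chosen only for illustration:

```python
import jax
import jax.numpy as jnp
import numpy as np


def init_params(widths, seed=0):
    # Hypothetical seeded variant of the initializer shown earlier.
    rng = np.random.default_rng(seed)
    return [
        dict(
            weights=rng.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
            biases=np.ones(shape=(n_out,)),
        )
        for n_in, n_out in zip(widths[:-1], widths[1:])
    ]


def forward(params, x):
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer["weights"] + layer["biases"])
    return x @ last["weights"] + last["biases"]


def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)


LEARNING_RATE = 1e-4


@jax.jit
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)


params = init_params([1, 128, 128, 1])

# Toy regression data (an assumption for this sketch): fit y = x^2.
data_rng = np.random.default_rng(1)
xs = data_rng.normal(size=(128, 1))
ys = xs ** 2

# The gradients form a pytree with the same structure as the parameters.
grads = jax.grad(loss_fn)(params, xs, ys)
same_structure = jax.tree.structure(grads) == jax.tree.structure(params)
print(same_structure)  # True

# Run a few full-batch update steps and compare the loss before and after.
initial_loss = float(loss_fn(params, xs, ys))
for _ in range(100):
    params = update(params, xs, ys)
final_loss = float(loss_fn(params, xs, ys))
print(initial_loss, final_loss)
```

Because `update` takes and returns the whole parameter pytree, the training loop never has to know how many layers the model has.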


## Custom pytree nodes
This section shows how to extend the set of Python types that JAX treats as internal pytree nodes by registering them with `jax.tree_util.register_pytree_node()`.

Why would you need this? In the previous examples, lists, tuples, and dicts acted as pytree containers, and everything else was a leaf. If you define your own container class, JAX treats each instance as a single leaf unless you register the class, even when the instance holds pytrees inside it. For example:

In [10]:
class Special(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y


jax.tree.leaves([Special(0, 1), Special(2, 4)])

[<__main__.Special at 0x7f81a0117bb0>, <__main__.Special at 0x7f81a0115990>]

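Accordingly, if you call `jax.tree.map()` expecting the leaves to be the elements inside the container, you get an error:

In [11]: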
jax.tree.map(lambda x: x + 1, [Special(0, 1), Special(2, 4)])

TypeError: unsupported operand type(s) for +: 'Special' and 'int'

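As a solution, define a subclass `RegisteredSpecial` and register it with JAX by providing a flattening function and an unflattening function:

In [12]: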
from jax.tree_util import register_pytree_node


class RegisteredSpecial(Special):
    def __repr__(self):
        return f"RegisteredSpecial(x={self.x}, y={self.y})"


def special_flatten(v):
    """Specifies a flattening recipe.

    Params:
      v: The value of the registered type to flatten.
    Returns:
      A pair of an iterable with the children to be flattened recursively,
      and some opaque auxiliary data to pass back to the unflattening recipe.
      The auxiliary data is stored in the treedef for use during unflattening.
      The auxiliary data could be used, for example, for dictionary keys.
    """
    children = (v.x, v.y)
    aux_data = None
    return (children, aux_data)


def special_unflatten(aux_data, children):
    """Specifies an unflattening recipe.

    Params:
      aux_data: The opaque data that was specified during flattening of the
        current tree definition.
      children: The unflattened children

    Returns:
      A reconstructed object of the registered type, using the specified
      children and auxiliary data.
    """
    return RegisteredSpecial(*children)


# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,  # Tells JAX what the child nodes are.
    special_unflatten,  # Tells JAX how to rebuild a `RegisteredSpecial` from its children.
)

With the registration in place, `jax.tree.map()` can traverse the custom container:

In [None]:
jax.tree.map(
    lambda x: x + 1,
    [
        RegisteredSpecial(0, 1),
        RegisteredSpecial(2, 4),
    ],
)
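The flattening recipe above always returns `None` as auxiliary data. The sketch below shows a case where the auxiliary data matters: a hypothetical `NamedPair` class (not part of the original example) whose `name` attribute is static metadata rather than a child. Storing `name` in `aux_data` keeps it out of the leaves while still letting the unflattening function restore it:

```python
import jax
from jax.tree_util import register_pytree_node


class NamedPair:
    """Hypothetical container: `name` is metadata, `x` and `y` are children."""

    def __init__(self, name, x, y):
        self.name = name
        self.x = x
        self.y = y

    def __repr__(self):
        return f"NamedPair(name={self.name!r}, x={self.x}, y={self.y})"


def named_pair_flatten(v):
    children = (v.x, v.y)  # Traversed recursively by JAX.
    aux_data = v.name  # Stored in the treedef, never passed to mapped functions.
    return children, aux_data


def named_pair_unflatten(aux_data, children):
    return NamedPair(aux_data, *children)


register_pytree_node(NamedPair, named_pair_flatten, named_pair_unflatten)

tree = [NamedPair("a", 0, 1), NamedPair("b", 2, 4)]

# Only `x` and `y` show up as leaves; the names are part of the structure.
print(jax.tree.leaves(tree))  # [0, 1, 2, 4]

# Mapping over the tree changes the children and preserves the names.
result = jax.tree.map(lambda x: x + 1, tree)
print(result)  # [NamedPair(name='a', x=1, y=2), NamedPair(name='b', x=3, y=5)]
```

Auxiliary data becomes part of the tree definition, so it should be hashable and comparable; a string works well for this purpose.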