A __pytree__ is a flexible, tree-like data structure in JAX that represents nested collections such as lists, tuples, dictionaries, or registered custom objects.

Each node can hold other pytrees or leaves (non-container objects such as arrays or scalars).

``jax.tree_util.tree_leaves()`` (also available as ``jax.tree.leaves()``) extracts the flattened list of leaves from a pytree.

In [1]:
import jax
import jax.numpy as jnp

In [3]:
example_trees = [
    [1, 'a', object()],
    (1, (2,3),()),
    [1, {'k1':2, 'k2':(3, 4)},5],
    {'a':2, 'b':(2,3)},
    jnp.array([1,2,3]),
]

# print how many leaves each pytree has
for pytree in example_trees:
    leaves = jax.tree_util.tree_leaves(pytree)
    print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

[1, 'a', <object object at 0x00000199837C3040>] has 3 leaves: [1, 'a', <object object at 0x00000199837C3040>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]


Any object whose type is not in the pytree container registry will be treated as a leaf node in the tree
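To make this rule concrete, here is a small sketch that flattens a dict holding an instance of an unregistered class. `Point` is a hypothetical class introduced only for illustration; the calls used are ``jax.tree.flatten()`` and ``jax.tree.unflatten()`` (assuming a JAX version that ships the ``jax.tree`` module, as the rest of this notebook does).

```python
import jax

# Hypothetical class for illustration; it is NOT registered as a pytree node.
class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

tree = {"p": Point(1.0, 2.0), "scale": 3.0}

# Flatten into a list of leaves plus a treedef describing the container structure.
leaves, treedef = jax.tree.flatten(tree)
print(len(leaves))                   # 2: the whole Point object, and 3.0
print(isinstance(leaves[0], Point))  # True: JAX did not look inside the Point

# Unflattening with the same treedef rebuilds the original container.
rebuilt = jax.tree.unflatten(treedef, leaves)
print(rebuilt["p"] is tree["p"])     # True: the leaf object is passed through as-is
```

Because `Point` is opaque to JAX, any function later mapped over this tree would receive the `Point` object itself, not its `x` and `y` fields.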

``jax.tree.map`` works analogously to Python's built-in ``map``, but transparently operates over an entire pytree.

In [4]:
list_of_lists = [
    [1,2,3],
    [1,2],
    [1,2,3,4],
]

jax.tree.map(lambda x:x*2, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

``jax.tree.map()`` also allows mapping an __N-ary__ function over multiple pytree arguments.

The structures of the inputs must match: e.g., lists must have the same number of elements, dictionaries must have the same keys, etc.

In [None]:
another_list = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
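The cell above adds a tree to itself, so the structures match trivially. The sketch below pairs two different dicts that share one structure, and then shows what happens when the structures differ. The trees `a` and `b` are made-up values; the mismatch case raises a ``ValueError`` describing the arity mismatch.

```python
import jax

a = {"w": 1.0, "b": [2.0, 3.0]}
b = {"w": 10.0, "b": [20.0, 30.0]}

# Same structure: the function is applied to corresponding leaves pairwise.
summed = jax.tree.map(lambda x, y: x + y, a, b)
print(summed)  # values: w -> 11.0, b -> [22.0, 33.0]

# Different structure (list lengths 2 vs 3): JAX refuses to pair the leaves.
try:
    jax.tree.map(lambda x, y: x + y, [1, 2], [1, 2, 3])
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print("ValueError:", err)
```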

##### Example: __jax.tree.map__ with the parameters of a basic MLP

In [8]:
import numpy as np 
def init_mlp_params(layer_widths):
    params = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        params.append(
            dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
                 bias=np.ones(shape=(n_out,)))
        )
    return params
params = init_mlp_params([1, 128, 128, 1])

In [10]:
#using jax.tree.map to check the shape of the initial parameters
jax.tree.map(lambda x: x.shape, params)

[{'bias': (128,), 'weights': (1, 128)},
 {'bias': (128,), 'weights': (128, 128)},
 {'bias': (1,), 'weights': (128, 1)}]
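The same map-then-summarize idea can count the model's parameters. This sketch rebuilds arrays with the same shapes as ``init_mlp_params([1, 128, 128, 1])`` (filled with zeros, only so it runs on its own) and folds the per-array sizes with ``jax.tree.reduce()``.

```python
import jax
import numpy as np

# Arrays with the same shapes as init_mlp_params([1, 128, 128, 1]) produces
# (zeros here, only so the snippet is self-contained).
params = [
    {"weights": np.zeros((1, 128)), "bias": np.zeros(128)},
    {"weights": np.zeros((128, 128)), "bias": np.zeros(128)},
    {"weights": np.zeros((128, 1)), "bias": np.zeros(1)},
]

# Map every array to its element count, then fold the resulting tree of ints.
sizes = jax.tree.map(lambda x: x.size, params)
total = jax.tree.reduce(lambda acc, n: acc + n, sizes)
print(total)  # 16897 = (128 + 128) + (128*128 + 128) + (128 + 1)
```

An equivalent one-liner is ``sum(x.size for x in jax.tree.leaves(params))``.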

In [11]:
#NEXT, define the function for training the MLP Model

#defining the forward pass

def forward(params, x):

    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer['weights'] + layer['bias'])
    return x @ last['weights'] + last['bias']

#define the loss function
def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

#set the learning rate
lr = 0.0001

#using stochastic gradient descent, define the parameter
#update function. Apply @jax.jit for JIT compilation

@jax.jit
def update(params, x, y):
    #calculate the gradient with jax.grad
    grads = jax.grad(loss_fn)(params, x, y)
    #Note that grads is a pytree with the same structure as params,
    #so jax.tree.map can pair each parameter with its gradient
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)
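The cell above defines ``update`` but never calls it. Below is a minimal, self-contained training-loop sketch that repeats those definitions and runs full-batch gradient descent. The toy target y = x² with small Gaussian noise, the seeds, and the 1000-step count are our assumptions for illustration, not part of the original notebook.

```python
import jax
import jax.numpy as jnp
import numpy as np

def init_mlp_params(layer_widths, seed=0):
    # Same initialisation as above, but with a seeded generator for repeatability.
    rng = np.random.default_rng(seed)
    params = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        params.append(dict(weights=rng.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
                           bias=np.ones(shape=(n_out,))))
    return params

def forward(params, x):
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer["weights"] + layer["bias"])
    return x @ last["weights"] + last["bias"]

def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

lr = 0.0001

@jax.jit
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

# Toy data (an assumption, not from the original): y = x**2 plus a little noise.
rng = np.random.default_rng(42)
xs = rng.normal(size=(128, 1))
ys = xs ** 2 + 0.1 * rng.normal(size=(128, 1))

params = init_mlp_params([1, 128, 128, 1])
initial_loss = float(loss_fn(params, xs, ys))
for step in range(1000):
    params = update(params, xs, ys)  # each call returns a new params pytree
final_loss = float(loss_fn(params, xs, ys))
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

Because ``update`` returns a pytree with the same structure as its input, the loop can simply rebind ``params`` on each step; nothing about the list-of-dicts layout has to be handled by hand.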


#### Custom pytree nodes

In [12]:
class Special(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y 

jax.tree.leaves([
    Special(0,1),
    Special(2,4),
])

[<__main__.Special at 0x19984f6a310>, <__main__.Special at 0x19984f5f8b0>]

In [14]:
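# Special is not registered, so each instance is a single leaf. jax.tree.map
# therefore passes the whole Special object to the lambda (not x and y),
# and Special + 1 raises a TypeError.
try:
    jax.tree.map(lambda x: x + 1, [
        Special(0, 1),
        Special(2, 4),
    ])
except TypeError as err:
    print("TypeError:", err)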

With ``jax.tree_util.register_pytree_node()`` you can extend the set of types treated as __internal nodes__ (rather than leaves) through a global type registry.

In [17]:
from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: The value of the registered type to flatten.
  Returns:
    A pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, for example, for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: The opaque data that was specified during flattening of the
      current tree definition.
    children: The unflattened children

  Returns:
    A reconstructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Instruct JAX what are the children nodes.
    special_unflatten   # Instruct JAX how to pack back into a `RegisteredSpecial`.
)

In [18]:
# Now we can use `jax.tree.map` on `RegisteredSpecial` instances.
jax.tree.map(lambda x: x+1, [
    RegisteredSpecial(0,1),
    RegisteredSpecial(2,4),
])

[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]
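The flatten recipe above passes ``aux_data = None``. As a sketch of when aux data is useful, the hypothetical `Labeled` class below keeps a static string label in aux data, so it travels in the treedef instead of being treated as a leaf. It uses ``jax.tree_util.register_pytree_node_class``, a decorator form of registration that expects a ``tree_flatten`` method and a ``tree_unflatten`` classmethod.

```python
import jax
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Labeled:
    """Hypothetical container: `value` is a child, `label` is static aux data."""
    def __init__(self, label, value):
        self.label = label
        self.value = value

    def tree_flatten(self):
        children = (self.value,)  # flattened recursively by JAX
        aux_data = self.label     # stored in the treedef, never mapped over
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data, *children)

    def __repr__(self):
        return f"Labeled(label={self.label!r}, value={self.value!r})"

tree = [Labeled("a", 1.0), Labeled("b", 2.0)]
print(jax.tree.leaves(tree))  # [1.0, 2.0]: the labels are not leaves
doubled = jax.tree.map(lambda v: v * 2, tree)
print(doubled)  # [Labeled(label='a', value=2.0), Labeled(label='b', value=4.0)]
```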

Subclasses of Python's ``NamedTuple`` do not need to be registered: JAX already treats them as pytree nodes.

In [19]:
from typing import NamedTuple, Any 

class MyOtherContainer(NamedTuple):
    name: str
    a: Any
    b: Any
    c: Any
jax.tree.leaves([
    MyOtherContainer("foo", 1, 2, 3),
    MyOtherContainer("bar", 4, 5, 6),
])

['foo', 1, 2, 3, 'bar', 4, 5, 6]

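Classes decorated with ``@dataclass`` are not automatically pytrees. However, they can be registered as pytrees using the ``jax.tree_util.register_dataclass`` decorator, which takes the fields to treat as children (``data_fields``) and the fields to treat as static metadata (``meta_fields``).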

In [20]:
from dataclasses import dataclass
import functools 

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=['a', 'b', 'c'],
                   meta_fields=['name'])
@dataclass
class MyDataclassContainer(object):
    name: str
    a: Any
    b: Any
    c: Any

jax.tree.leaves([
    MyDataclassContainer("apple", 5.2, 2.4, jnp.zeros([4])),
    MyDataclassContainer("banana", 3.1, 1.2, jnp.ones([4])),
])

[5.2,
 2.4,
 Array([0., 0., 0., 0.], dtype=float32),
 3.1,
 1.2,
 Array([1., 1., 1., 1.], dtype=float32)]
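One practical difference between the two containers shows up with ``jax.tree.map``: every ``NamedTuple`` field is a child, so a string field like ``name`` is mapped over too, while a ``meta_fields`` entry of a registered dataclass is left alone. The small classes below are hypothetical stand-ins for ``MyOtherContainer`` and ``MyDataclassContainer``.

```python
import functools
from dataclasses import dataclass
from typing import Any, NamedTuple

import jax

class NTContainer(NamedTuple):  # every field is a child (leaf)
    name: str
    a: Any

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=["a"], meta_fields=["name"])
@dataclass
class DCContainer:              # `name` is static metadata
    name: str
    a: Any

def double(x):
    return x * 2

nt_out = jax.tree.map(double, NTContainer("foo", 1.5))
dc_out = jax.tree.map(double, DCContainer("foo", 1.5))
print(nt_out)  # NTContainer(name='foofoo', a=3.0): the string was "doubled" too
print(dc_out)  # DCContainer(name='foo', a=3.0): the meta field is untouched
```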

#### Pytrees and JAX transformations

Many JAX functions, like ``jax.lax.scan()``, operate over pytrees of arrays. JAX function transformations can be applied to functions that accept and return pytrees of arrays.

Some transformations take optional parameters that specify how each input or output should be treated (for example, the ``in_axes`` and ``out_axes`` arguments of ``jax.vmap()``). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments.

In [25]:
from jax import vmap
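# Suppose f is called as f(a1, {"k1": a2, "k2": a3}).
# in_axes must mirror that structure, e.g. to batch only a3 along axis 0:
#vmap(f, in_axes=(None, {"k1": None, "k2": 0}))

# A leaf of in_axes can also stand for every leaf in the subtree below it:
#vmap(f, in_axes=(None, 0))  # equivalent to (None, {"k1": 0, "k2": 0})
#vmap(f, in_axes=0)          # equivalent to (0, {"k1": 0, "k2": 0})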
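A runnable version of the commented cell above, as a sketch: ``f``, ``scale``, and the arrays are hypothetical stand-ins for ``f``, ``a1``, ``a2``, and ``a3``.

```python
import jax
import jax.numpy as jnp

def f(scale, inputs):
    # Hypothetical function: scale is a scalar, inputs is a dict of arrays.
    return scale * inputs["k1"] + inputs["k2"]

scale = 2.0
k2 = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 rows

# Batch only over "k2"; scale and "k1" are shared by every batch element.
inputs = {"k1": jnp.ones(3), "k2": k2}
out = jax.vmap(f, in_axes=(None, {"k1": None, "k2": 0}))(scale, inputs)
print(out.shape)  # (4, 3)

# A single 0 for the dict stands for {"k1": 0, "k2": 0}: both leaves are batched.
inputs_both = {"k1": jnp.ones((4, 3)), "k2": k2}
out_both = jax.vmap(f, in_axes=(None, 0))(scale, inputs_both)
print(out_both.shape)  # (4, 3)
```

In the first call ``k1`` has no batch axis at all, which is only valid because its ``in_axes`` entry is ``None``.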

#### Explicit Key Paths

In a pytree, each leaf has a _key path_: the sequence of keys that leads from the root to that leaf. The type of each key depends on the pytree node type; e.g., the key for a ``dict`` entry differs from the key for a ``tuple`` element.

JAX has the following ``jax.tree_util.*`` methods for working with key paths:

- ``jax.tree_util.tree_flatten_with_path()``: works similarly to ``jax.tree.flatten()`` but returns key paths.
- ``jax.tree_util.tree_map_with_path()``: works similarly to ``jax.tree.map()``, but the function also takes key paths as arguments.

In [28]:
import collections 

ATuple = collections.namedtuple("ATuple", ("name",))
tree = [1, {'k1':2, 'k2':(3,4)}, ATuple("foo")]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
    print(f"Value of tree {jax.tree_util.keystr(key_path)}: {value}")

Value of tree [0]: 1
Value of tree [1]['k1']: 2
Value of tree [1]['k2'][0]: 3
Value of tree [1]['k2'][1]: 4
Value of tree [2].name: foo


- ``jax.tree_util.keystr()``: given a general key path, returns a reader-friendly string expression.

In [29]:
for key_path, _ in flattened:
    print(f"key path of tree {jax.tree_util.keystr(key_path)}:{repr(key_path)}")

key path of tree [0]:(SequenceKey(idx=0),)
key path of tree [1]['k1']:(SequenceKey(idx=1), DictKey(key='k1'))
key path of tree [1]['k2'][0]:(SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
key path of tree [1]['k2'][1]:(SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
key path of tree [2].name:(SequenceKey(idx=2), GetAttrKey(name='name'))
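The cells above only use ``tree_flatten_with_path()``. A short sketch of ``jax.tree_util.tree_map_with_path()``: the mapped function receives each leaf's key path first, so it can treat leaves differently by position. Here a hypothetical parameter dict has its biases zeroed while weights pass through unchanged.

```python
import jax
import jax.numpy as jnp

params = {
    "layer1": {"weights": jnp.ones((2, 2)), "bias": jnp.ones(2)},
    "layer2": {"weights": jnp.ones((2, 1)), "bias": jnp.ones(1)},
}

def zero_biases(path, leaf):
    # `path` is a tuple of key objects, e.g. (DictKey(key='layer1'), DictKey(key='bias'))
    last = path[-1]
    if isinstance(last, jax.tree_util.DictKey) and last.key == "bias":
        return jnp.zeros_like(leaf)
    return leaf

new_params = jax.tree_util.tree_map_with_path(zero_biases, params)

# Print each leaf's readable path and the sum of its entries.
flat, _ = jax.tree_util.tree_flatten_with_path(new_params)
for key_path, leaf in flat:
    print(jax.tree_util.keystr(key_path), float(leaf.sum()))
```

The bias entries now sum to 0.0, while the weight entries keep their original sums.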


The default key types for the built-in pytree nodes are:
 - ``SequenceKey(idx: int)`` for lists and tuples
 - ``DictKey(key: Hashable)`` for dictionaries
 - ``GetAttrKey(name: str)`` for ``namedtuple``s
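Each key type exposes the field named in its signature (``idx``, ``key``, ``name``), which is enough to walk a key path back down to its leaf. ``get_by_path`` below is a hypothetical helper written for illustration, not a JAX API; it handles only the three built-in key types listed above.

```python
import collections

import jax
from jax.tree_util import DictKey, GetAttrKey, SequenceKey

def get_by_path(tree, key_path):
    """Hypothetical helper: follow a key path back down to its leaf."""
    node = tree
    for key in key_path:
        if isinstance(key, SequenceKey):
            node = node[key.idx]            # lists and tuples
        elif isinstance(key, DictKey):
            node = node[key.key]            # dictionaries
        elif isinstance(key, GetAttrKey):
            node = getattr(node, key.name)  # namedtuples
        else:
            raise TypeError(f"unhandled key type: {type(key).__name__}")
    return node

ATuple = collections.namedtuple("ATuple", ("name",))
tree = [1, {"k1": 2, "k2": (3, 4)}, ATuple("foo")]

# Every key path produced by tree_flatten_with_path leads back to its own leaf.
flat, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flat:
    assert get_by_path(tree, key_path) == value
print("every key path leads back to its leaf")
```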

To transpose a pytree (turn a list of trees into a tree of lists), JAX offers two functions: ``jax.tree.map()`` (more basic) and ``jax.tree.transpose()`` (more flexible, but more complex and verbose).

In [30]:
def tree_transpose(list_of_trees):

    return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

#convert a dataset from row-major to column-major
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)

{'obs': [3, 4], 't': [1, 2]}

In [32]:
jax.tree.transpose(
    outer_treedef=jax.tree.structure([0 for e in episode_steps]),
    inner_treedef=jax.tree.structure(episode_steps[0]),
    pytree_to_transpose=episode_steps
)

{'obs': [3, 4], 't': [1, 2]}
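When the leaves are arrays, a common variant of the ``jax.tree.map`` approach stacks them instead of collecting Python lists, producing one array per leaf with a new leading axis. This is a sketch with hypothetical per-step records; indexing that leading axis reverses the operation.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-step records that all share one structure.
records = [{"t": jnp.array(i), "obs": jnp.full((2,), float(i))} for i in range(3)]

# Same pattern as tree_transpose above, but jnp.stack turns each tuple of
# leaves into one array with a new leading axis instead of a Python list.
stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *records)
print(stacked["obs"].shape)  # (3, 2)
print(stacked["t"])          # [0 1 2]

# Going back: index the leading axis of every leaf to recover one record per step.
num_steps = stacked["t"].shape[0]
unstacked = [jax.tree.map(lambda x: x[i], stacked) for i in range(num_steps)]
```

The stacked form is what functions such as ``jax.vmap`` or ``jax.lax.scan`` typically expect, since they iterate over a leading axis of every leaf.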