Importing required libraries and modules

In [3]:
import jax
import collections
import numpy as np
import jax.numpy as jnp

from jax.tree_util import tree_structure
from jax.tree_util import tree_flatten, tree_unflatten

# Pytrees

[docs](https://jax.readthedocs.io/en/latest/pytrees.html)

A *pytree* refers to a tree-like structure built out of container-like Python objects. 

Classes are considered container-like if they are in the pytree registry, which by default includes **list**, **tuple**, **dict**, **namedtuple**, **OrderedDict**, and **None**.

Pytrees have a nested recursive structure. Every node in the Pytree is one of the following:

- a leaf
- a pytree

In [4]:
# A flat list is the simplest pytree: each element is a leaf.
list_pytree = ['a', 'b', 'c']

# Use the `jax.tree_util` function: the top-level `jax.tree_leaves` alias was
# deprecated and is not available in recent JAX releases.
leaves = jax.tree_util.tree_leaves(list_pytree)
leaves

['a', 'b', 'c']

In [5]:
# A list whose last element is a tuple: the tuple is an inner node, so its
# elements show up as separate leaves.
list_pytree = ['a', 'b', ('Alice', 'Bob')]

# Use the `jax.tree_util` function: the top-level `jax.tree_leaves` alias was
# deprecated and is not available in recent JAX releases.
leaves = jax.tree_util.tree_leaves(list_pytree)
leaves

['a', 'b', 'Alice', 'Bob']

Now let's look at the tree structure. Each `*` marks the position of a leaf in the pytree.

In [6]:
# Call the function through `jax.tree_util` so this cell still works when it is
# re-run after the name `tree_structure` has been rebound to a PyTreeDef further
# down the notebook.
jax.tree_util.tree_structure(list_pytree)

PyTreeDef([*, *, (*, *)])

In [7]:
# A dict is a pytree node; its values are the leaves.
dict_pytree = {'x': 1, 'y': 33, 'z': 3343.33}

# Use the `jax.tree_util` function: the top-level `jax.tree_leaves` alias was
# deprecated and is not available in recent JAX releases.
leaves = jax.tree_util.tree_leaves(dict_pytree)

leaves

[1, 33, 3343.33]

In [8]:
# Call the function through `jax.tree_util` so this cell still works when it is
# re-run after the name `tree_structure` has been rebound to a PyTreeDef further
# down the notebook.
jax.tree_util.tree_structure(dict_pytree)

PyTreeDef({'x': *, 'y': *, 'z': *})

In [9]:
# A tuple containing a dict: both containers are nodes, and only the strings
# and the dict values are leaves.
tuple_pytree = ('a', 'b', {'x': 1, 'y': 33})

# Use the `jax.tree_util` function: the top-level `jax.tree_leaves` alias was
# deprecated and is not available in recent JAX releases.
leaves = jax.tree_util.tree_leaves(tuple_pytree)
leaves

['a', 'b', 1, 33]

In [10]:
# Call the function through `jax.tree_util` so this cell still works when it is
# re-run after the name `tree_structure` has been rebound to a PyTreeDef further
# down the notebook.
jax.tree_util.tree_structure(tuple_pytree)

PyTreeDef((*, *, {'x': *, 'y': *}))

A list that represents a pytree can contain objects of any type. Note that dictionary keys are shown as-is in the tree structure, while the dictionary values are the leaves.

In [11]:
# A list mixing plain leaves with nested list, tuple and dict nodes.
complex_pytree = ['a', 'b', 'c', [1, 2], (3., 4.), {'x': 2, 'y': (3, 4)}]

# `jax.tree_util.*` is used for both calls: the top-level `jax.tree_leaves` alias
# is not available in recent JAX releases, and the qualified `tree_structure`
# keeps working even if the bare name is rebound later in the notebook.
leaves = jax.tree_util.tree_leaves(complex_pytree)

print('Number of Leaves:', len(leaves))
print('Leaves:', leaves)
print('Tree structure:', jax.tree_util.tree_structure(complex_pytree))

Number of Leaves: 10
Leaves: ['a', 'b', 'c', 1, 2, 3.0, 4.0, 2, 3, 4]
Tree structure: PyTreeDef([*, *, *, [*, *], (*, *), {'x': *, 'y': (*, *)}])


Similarly, we can define a pytree using a tuple. Note that an empty tuple `()` is not considered a leaf: it is a node with no children.

In [12]:
# A tuple-rooted pytree; the trailing empty tuple is a node without children,
# so it contributes no leaves.
complex_pytree = ('a', 'b', 'c', [1., 2.], (3., 4.), ())

# `jax.tree_util.*` is used for both calls: the top-level `jax.tree_leaves` alias
# is not available in recent JAX releases, and the qualified `tree_structure`
# keeps working even if the bare name is rebound later in the notebook.
leaves = jax.tree_util.tree_leaves(complex_pytree)

print('Number of Leaves:', len(leaves))
print('Leaves:', leaves)
print('Tree structure:', jax.tree_util.tree_structure(complex_pytree))

Number of Leaves: 7
Leaves: ['a', 'b', 'c', 1.0, 2.0, 3.0, 4.0]
Tree structure: PyTreeDef((*, *, *, [*, *], (*, *), ()))


In [13]:
# A dict-rooted pytree whose values are a float, a tuple and a list.
complex_pytree = {'x': 1., 'y': (2., 3.), 'z' : [4., 5., 6.]}

# `jax.tree_util.*` is used for both calls: the top-level `jax.tree_leaves` alias
# is not available in recent JAX releases, and the qualified `tree_structure`
# keeps working even if the bare name is rebound later in the notebook.
leaves = jax.tree_util.tree_leaves(complex_pytree)

print('Number of Leaves:', len(leaves))
print('Leaves:', leaves)
print('Tree structure:', jax.tree_util.tree_structure(complex_pytree))

Number of Leaves: 6
Leaves: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
Tree structure: PyTreeDef({'x': *, 'y': (*, *), 'z': [*, *, *]})


In [14]:
# Same keys as before, but 'y' and 'z' now hold JAX arrays. An array is a single
# leaf, so the tree has three leaves rather than six.
complex_pytree = {'x': 1., 'y': jnp.array((2., 3.)), 'z' : jnp.array([4., 5., 6.])}

# `jax.tree_util.*` is used for both calls: the top-level `jax.tree_leaves` alias
# is not available in recent JAX releases, and the qualified `tree_structure`
# keeps working even if the bare name is rebound later in the notebook.
leaves = jax.tree_util.tree_leaves(complex_pytree)

print('Number of Leaves:', len(leaves))
print('Leaves:', leaves)
print('Tree structure:', jax.tree_util.tree_structure(complex_pytree))



Number of Leaves: 3
Leaves: [1.0, DeviceArray([2., 3.], dtype=float32), DeviceArray([4., 5., 6.], dtype=float32)]
Tree structure: PyTreeDef({'x': *, 'y': *, 'z': *})


So, in short, a pytree is just a composition of **nodes** (container-like Python objects) and **leaves** (all other Python objects). JAX also lets you register custom types as pytree nodes (we will look at an example of this later on).
We can *flatten* the tree at each level, get the leaves, and the original tree structure as well. Let's see it in action

In [15]:
# Pytree to flatten: a dict holding a float, a tuple and a list.
complex_pytree = {'x': 1., 'y': (2., 3.), 'z' : [4., 5., 6.]}

# tree_flatten returns the flat list of leaves together with the tree definition.
# NOTE: binding the definition to `tree_structure` shadows the function of the
# same name imported from jax.tree_util. The name is kept because the unflatten
# cell below reads it, but any cell that calls the bare `tree_structure(...)`
# function will fail if it is re-run after this one.
tree_leaves, tree_structure = jax.tree_util.tree_flatten(complex_pytree)
print('Leaves:', tree_leaves)
print('Tree structure:', tree_structure)

Leaves: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
Tree structure: PyTreeDef({'x': *, 'y': (*, *), 'z': [*, *, *]})


Next, we transform the flat list of leaves with an element-wise numeric function (here, squaring each value).

In [16]:
# Square every leaf of the flat list produced by tree_flatten.
transformed_leaves = [leaf ** 2 for leaf in tree_leaves]

transformed_leaves

[1.0, 4.0, 9.0, 16.0, 25.0, 36.0]

In [17]:
# Rebuild a pytree with the original structure from the squared leaves.
# Here `tree_structure` is the tree definition returned by tree_flatten above.
reconstructed_complex_tree = tree_unflatten(tree_structure, transformed_leaves)

print('Original Pytree:   ',  complex_pytree)
print('Transformed Pytree:', reconstructed_complex_tree)

Original Pytree:    {'x': 1.0, 'y': (2.0, 3.0), 'z': [4.0, 5.0, 6.0]}
Transformed Pytree: {'x': 1.0, 'y': (4.0, 9.0), 'z': [16.0, 25.0, 36.0]}


In practice, we can (and should) use `tree_map(...)` to apply an operation to every leaf, as it is much simpler. The flatten/unflatten example above is useful when you want finer control over the operations applied to different leaves of the tree.

In [18]:
# tree_map applies the function to every leaf and returns a new pytree with the
# same structure as the input.
# - `jax.tree_util.tree_map` is used because the top-level `jax.tree_map` alias is
#   deprecated in newer JAX releases.
# - The result is a whole pytree, not a list of leaves, so it gets its own name
#   instead of overwriting the flat `transformed_leaves` list built earlier.
transformed_pytree = jax.tree_util.tree_map(lambda x: x**2, complex_pytree)

transformed_pytree

{'x': 1.0, 'y': (4.0, 9.0), 'z': [16.0, 25.0, 36.0]}

In [19]:
# Build a structurally independent copy of the pytree. A plain assignment would
# only create a second name for the same object, not a copy.
copy_of_complex_pytree = jax.tree_util.tree_map(lambda leaf: leaf, complex_pytree)

print(complex_pytree)
print(copy_of_complex_pytree)

print('*' * 50)

# tree_map with several pytrees of identical structure calls the function on the
# corresponding leaves of each tree. `jax.tree_util.tree_map` is used because the
# top-level `jax.tree_map` alias is deprecated in newer JAX releases.
print(jax.tree_util.tree_map(lambda x, y: x + y, complex_pytree, copy_of_complex_pytree))

{'x': 1.0, 'y': (2.0, 3.0), 'z': [4.0, 5.0, 6.0]}
{'x': 1.0, 'y': (2.0, 3.0), 'z': [4.0, 5.0, 6.0]}
**************************************************
{'x': 2.0, 'y': (4.0, 6.0), 'z': [8.0, 10.0, 12.0]}


Reconstructing the transformed tree, using the original tree's structure.

By default, pytree containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. 

Other types of values, including numeric and ndarray values, are treated as leaves. A plain `object()` instance is also treated as a leaf, as is a DeviceArray.

`jax.tree_util` treats `None` as a node without children, not as a leaf.

In [21]:
# Values to probe with the flatten/unflatten round trip: None, a Python float,
# a bare object() instance, and a JAX array.
containers_or_not = [None, 1., object(), jnp.ones(3)]

def show_example(container):
  """Flatten `container`, rebuild it, and print every stage of the round trip.

  Prints four things: the original value, the flat list of leaves, the tree
  definition returned by `tree_flatten`, and the value reconstructed by
  `tree_unflatten` from that definition and those leaves.
  """
  flat, treedef = tree_flatten(container)
  rebuilt = tree_unflatten(treedef, flat)

  report = 'Original={}\n  flat={}\n  tree={}\n  unflattened={}'
  print(report.format(container, flat, treedef, rebuilt))

# Run the flatten/unflatten round trip on each probe value in turn.
for candidate in containers_or_not:
  show_example(candidate)

Original=None
  flat=[]
  tree=PyTreeDef(None)
  unflattened=None
Original=1.0
  flat=[1.0]
  tree=PyTreeDef(*)
  unflattened=1.0
Original=<object object at 0x7f4653268a10>
  flat=[<object object at 0x7f4653268a10>]
  tree=PyTreeDef(*)
  unflattened=<object object at 0x7f4653268a10>
Original=[1. 1. 1.]
  flat=[DeviceArray([1., 1., 1.], dtype=float32)]
  tree=PyTreeDef(*)
  unflattened=[1. 1. 1.]


We saw that pytrees are built from container-like Python objects such as lists, tuples, and dicts. But what if you want to extend the set of Python objects that are treated as pytree nodes? For example, what if you want your own class to be treated as a pytree node?

Well, if you think about it, to treat a class as a pytree node, we need to:
1. Tell JAX that you want to treat it as a node and not a leaf by registering it in the internal registry.
2. Because this is a custom object, JAX doesn't know how to `flatten` and `unflatten` it, we need to tell JAX this as well
3. There are cases when we need to compare two `treedef` structures for equality. Hence we need to make sure that adding a custom object doesn't break the equality check.

Let's look at an example.
We create a special container, `Triplet`, that stores a customer's name together with three customer-specific attributes `x`, `y`, and `z`.
By default, any part of a structured value that is not recognized as an internal pytree node (i.e. container-like) is treated as a leaf:



In [22]:
class Triplet:
  """A named record that holds three values: x, y and z."""

  def __init__(self, name, x, y, z):
    # Store the label and the three values as plain attributes.
    self.name, self.x, self.y, self.z = name, x, y, z

  def __repr__(self):
    template = 'Triplet(name={},x={}, y={}, z={})'
    return template.format(self.name, self.x, self.y, self.z)


# Triplet has not been registered as a pytree node at this point in the notebook,
# so the whole instance comes back as a single leaf.
show_example(Triplet('John', 10., 20., 30.))

Original=Triplet(name=John,x=10.0, y=20.0, z=30.0)
  flat=[Triplet(name=John,x=10.0, y=20.0, z=30.0)]
  tree=PyTreeDef(*)
  unflattened=Triplet(name=John,x=10.0, y=20.0, z=30.0)


The set of Python types that are considered internal pytree nodes is extensible through a global registry of types, and values of registered types are traversed recursively. To register a new type, use `register_pytree_node()`. For details, see:
https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees

In [23]:
from jax.tree_util import register_pytree_node

In [26]:
def triplet_flatten(v):
  """Split a Triplet into pytree children and auxiliary data.

  Returns a pair `(children, aux_data)`: the children are the tuple
  `(v.x, v.y, v.z)`, and the auxiliary data is `v.name`.
  """
  return (v.x, v.y, v.z), v.name

def triplet_unflatten(aux_data, children):
  """Rebuild a Triplet from its auxiliary data (the name) and its children.

  `children` is unpacked positionally into the remaining constructor arguments.
  """
  name = aux_data
  return Triplet(name, *children)

# Register Triplet as a pytree node, passing the function that flattens an
# instance and the function that rebuilds one.
# NOTE(review): re-running this cell for an already-registered class is expected
# to raise a duplicate-registration error — TODO confirm.
register_pytree_node(Triplet, triplet_flatten, triplet_unflatten)

In [28]:
# Same Triplet as before; the output below now lists three numeric leaves
# instead of a single Triplet leaf.
show_example(Triplet('John', 10., 20., 30.))

Original=Triplet(name=John,x=10.0, y=20.0, z=30.0)
  flat=[10.0, 20.0, 30.0]
  tree=PyTreeDef(CustomNode(<class '__main__.Triplet'>[John], [*, *, *]))
  unflattened=Triplet(name=John,x=10.0, y=20.0, z=30.0)


We can instantiate a container for another customer and display its details in the same way.

In [29]:
# A second Triplet with a different name and different values.
show_example(Triplet('Jane', 25., 40., 70.))

Original=Triplet(name=Jane,x=25.0, y=40.0, z=70.0)
  flat=[25.0, 40.0, 70.0]
  tree=PyTreeDef(CustomNode(<class '__main__.Triplet'>[Jane], [*, *, *]))
  unflattened=Triplet(name=Jane,x=25.0, y=40.0, z=70.0)


Finally, we add a constant (10) to every attribute of both customers.

In [32]:
# Add 10 to every leaf of a list of Triplets. The result shown below is a list of
# new Triplets with the same names and with x, y and z each increased by 10.
# `jax.tree_util.tree_map` is used because the top-level `jax.tree_map` alias is
# deprecated in newer JAX releases.
jax.tree_util.tree_map(lambda x: x + 10, [
    Triplet('John', 10., 20., 30.),
    Triplet('Jane', 25., 40., 70.)
])

[Triplet(name=John,x=20.0, y=30.0, z=40.0),
 Triplet(name=Jane,x=35.0, y=50.0, z=80.0)]


# References

1. https://jax.readthedocs.io/en/latest/pytrees.html
2. https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html