In [24]:
import jax
import collections
import numpy as np
import jax.numpy as jnp

from jax.tree_util import tree_structure
from jax.tree_util import tree_flatten, tree_unflatten
from jax.tree_util import tree_leaves

In [5]:
list_pytree = ['a', 'b', 'c']

leaves = tree_leaves(list_pytree)
leaves

['a', 'b', 'c']

In [6]:
list_pytree = ['a', 'b', ('Alice', 'Bob')]

leaves = tree_leaves(list_pytree)
leaves

['a', 'b', 'Alice', 'Bob']

In [7]:
tree_structure(list_pytree)

PyTreeDef([*, *, (*, *)])

In [9]:
dict_pytree = {'x': 1, 'y': 33, 'z': 3343.33}

leaves = tree_leaves(dict_pytree) # dict-keys are metadata, dict-values are leaves
leaves

[1, 33, 3343.33]

In [11]:
tree_structure(dict_pytree) # *'s are dict-values and dict-keys (metadata) are visible here

PyTreeDef({'x': *, 'y': *, 'z': *})

In [13]:
tuple_pytree = ('a', 'b', {'x': 1, 'y': 33})

leaves = tree_leaves(tuple_pytree) # dict-keys are metadata, dict-values are leaves
leaves

['a', 'b', 1, 33]

In [15]:
tree_structure(tuple_pytree)

PyTreeDef((*, *, {'x': *, 'y': *}))

In [16]:
complex_pytree = ['a', 'b', 'c', [1, 2], (3., 4.), {'x':2, 'y': (3,4)}]
leaves = tree_leaves(complex_pytree)
leaves

['a', 'b', 'c', 1, 2, 3.0, 4.0, 2, 3, 4]

In [17]:
print('Number of leaves:', len(leaves))
print('Leaves:', leaves)
print('Tree structure:', tree_structure(complex_pytree))

Number of leaves: 10
Leaves: ['a', 'b', 'c', 1, 2, 3.0, 4.0, 2, 3, 4]
Tree structure: PyTreeDef([*, *, *, [*, *], (*, *), {'x': *, 'y': (*, *)}])


In [18]:
complex_pytree = ['a', 'b', 'c', [1, 2], (3., 4.), ()]
leaves = tree_leaves(complex_pytree)

print('Number of leaves:', len(leaves))
print('Leaves:', leaves)
print('Tree structure:', tree_structure(complex_pytree))

Number of leaves: 7
Leaves: ['a', 'b', 'c', 1, 2, 3.0, 4.0]
Tree structure: PyTreeDef([*, *, *, [*, *], (*, *), ()])


In [20]:
complex_pytree = {'x': 1., 'y': (2., 3.), 'z': [4., 5., 6.]}
leaves = tree_leaves(complex_pytree)

print('Number of leaves:', len(leaves))
print('Leaves:', leaves)
print('Tree structure:', tree_structure(complex_pytree))

Number of leaves: 6
Leaves: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
Tree structure: PyTreeDef({'x': *, 'y': (*, *), 'z': [*, *, *]})


In [21]:
complex_pytree = {'x': 1., 'y': jnp.array((2., 3.)), 'z': jnp.array([4., 5., 6.])} # jax arrays are primitives (not pytree objects) like int, float, etc.
leaves = tree_leaves(complex_pytree)

print('Number of leaves:', len(leaves))
print('Leaves:', leaves)
print('Tree structure:', tree_structure(complex_pytree))

Number of leaves: 3
Leaves: [1.0, Array([2., 3.], dtype=float32), Array([4., 5., 6.], dtype=float32)]
Tree structure: PyTreeDef({'x': *, 'y': *, 'z': *})


# Operations on pytrees

In [25]:
complex_pytree = {'x': 1., 'y': (2., 3.), 'z': [4., 5., 6.]}

t_leaves, t_structure = tree_flatten(complex_pytree)
print('Leaves:', t_leaves)
print('Tree structure:', tree_structure(t_structure))

Leaves: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
Tree structure: PyTreeDef(*)


In [26]:
transformed_leaves = list(map(lambda x: x**2, t_leaves))
transformed_leaves

[1.0, 4.0, 9.0, 16.0, 25.0, 36.0]

In [27]:
reconstructed_complex_tree = tree_unflatten(treedef= t_structure, leaves= transformed_leaves)
print('Original Pytree: ', complex_pytree)
print('Transformed Pytree: ', reconstructed_complex_tree)

Original Pytree:  {'x': 1.0, 'y': (2.0, 3.0), 'z': [4.0, 5.0, 6.0]}
Transformed Pytree:  {'x': 1.0, 'y': (4.0, 9.0), 'z': [16.0, 25.0, 36.0]}


In [29]:
transformed_leaves = jax.tree.map(lambda x: x**2, complex_pytree)
transformed_leaves

{'x': 1.0, 'y': (4.0, 9.0), 'z': [16.0, 25.0, 36.0]}

In [30]:
copy_complex_pytree = complex_pytree.copy()
print(complex_pytree)
print(copy_complex_pytree)
print('*' * 50)
print(jax.tree.map(lambda x, y: x + y, complex_pytree, copy_complex_pytree))

{'x': 1.0, 'y': (2.0, 3.0), 'z': [4.0, 5.0, 6.0]}
{'x': 1.0, 'y': (2.0, 3.0), 'z': [4.0, 5.0, 6.0]}
**************************************************
{'x': 2.0, 'y': (4.0, 6.0), 'z': [8.0, 10.0, 12.0]}


# Pytree containers

In [31]:
containers_or_not = [
    None,
    1.,
    object(),
    jnp.ones(3)
]

def show_example(container):
    t_leaves, t_structure = tree_flatten(container)
    unflattened = tree_unflatten(t_structure, t_leaves)

    print('Original={}\n flat={}\n tree={}\n unflattened={}'.format(container, t_leaves, t_structure, unflattened))

for not_container in containers_or_not:
    show_example(not_container)

Original=None
 flat=[]
 tree=PyTreeDef(None)
 unflattened=None
Original=1.0
 flat=[1.0]
 tree=PyTreeDef(*)
 unflattened=1.0
Original=<object object at 0x7ef9aabeae80>
 flat=[<object object at 0x7ef9aabeae80>]
 tree=PyTreeDef(*)
 unflattened=<object object at 0x7ef9aabeae80>
Original=[1. 1. 1.]
 flat=[Array([1., 1., 1.], dtype=float32)]
 tree=PyTreeDef(*)
 unflattened=[1. 1. 1.]
