In [17]:
import jax
from jax.tree_util import register_pytree_node

In [28]:
import collections

# `namedtuple` field names must be a sequence of strings (or a single
# whitespace/comma separated string). Beware: `('name')` is just the string
# 'name' -- parentheses alone do not make a tuple -- so spell out a real
# one-element tuple to make the intent unambiguous.
ATuple = collections.namedtuple("ATuple", ("name",))

# A nested pytree mixing a list, a dict, a tuple and a namedtuple, so that
# SequenceKey, DictKey and GetAttrKey all appear in the flattened key paths.
tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]

# `tree_flatten_with_path` returns a list of (key_path, leaf) pairs plus the treedef.
flattened, pytreedef = jax.tree_util.tree_flatten_with_path(tree)

print(pytreedef)
for key_path, value in flattened:
  print('-'*30)
  print(f"key_path | {key_path}")
  # keystr renders the key path as the indexing expression that reaches the leaf.
  print(f'key_str  | {jax.tree_util.keystr(key_path)}')
  print(f"value    | {value}")

PyTreeDef([*, {'k1': *, 'k2': (*, *)}, CustomNode(namedtuple[ATuple], [*])])
------------------------------
key_path | (SequenceKey(idx=0),)
key_str  | [0]
value    | 1
------------------------------
key_path | (SequenceKey(idx=1), DictKey(key='k1'))
key_str  | [1]['k1']
value    | 2
------------------------------
key_path | (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
key_str  | [1]['k2'][0]
value    | 3
------------------------------
key_path | (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
key_str  | [1]['k2'][1]
value    | 4
------------------------------
key_path | (SequenceKey(idx=2), GetAttrKey(name='name'))
key_str  | [2].name
value    | foo


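Each entry of a key path is a key object describing one step into the tree:

* `SequenceKey(idx: int)`: For lists and tuples.
* `DictKey(key: Hashable)`: For dictionaries.
* `GetAttrKey(name: str)`: For `namedtuple`s and, preferably, custom pytree nodes (more in the next section).
* `FlattenedIndexKey(key: int)`: Fallback for custom pytree nodes registered without key information (e.g. via `register_pytree_node`); `key` is the child's position among the flattened children.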

## Custom pytree

In [23]:
class Special1(object):
  """Minimal container with two attributes, ``x`` and ``y``.

  Used to demonstrate registering a user-defined class as a pytree node.
  """

  def __init__(self, x, y):
    self.x, self.y = x, y

  def __repr__(self):
    return f"Special1(x={self.x}, y={self.y})"

def special_flatten1(v):
  """Flatten a ``Special1`` into the ``(children, aux_data)`` pair JAX expects.

  The children are the two attributes in ``(x, y)`` order. There is no static
  metadata to carry, so ``aux_data`` is ``None``.
  """
  return (v.x, v.y), None

def special_unflatten1(aux_data, children):
  """Rebuild a ``Special1`` from the children produced by ``special_flatten1``.

  ``aux_data`` is ignored (``special_flatten1`` always emits ``None``). The
  children are forwarded positionally, i.e. as ``(x, y)``.
  """
  del aux_data  # no static metadata to restore
  return Special1(*children)

# Global registration: from here on, JAX tree utilities flatten a Special1 via
# special_flatten1 and rebuild it via special_unflatten1.
register_pytree_node(Special1, special_flatten1, special_unflatten1)

In [24]:
# Special1 is registered via `register_pytree_node`, which supplies no key
# names, so each child's key path entry is a positional FlattenedIndexKey.
special1 = Special1(1, 2)
flattened, pytreedef = jax.tree_util.tree_flatten_with_path(special1)

print(pytreedef)
for key_path, value in flattened:
  print('-' * 30,
        f"key_path | {key_path}",
        f'key_str  | {jax.tree_util.keystr(key_path)}',
        f"value    | {value}",
        sep="\n")

PyTreeDef(CustomNode(Special1[None], [*, *]))
------------------------------
key_path | (FlattenedIndexKey(key=0),)
key_str  | [<flat index 0>]
value    | 1
------------------------------
key_path | (FlattenedIndexKey(key=1),)
key_str  | [<flat index 1>]
value    | 2


In [26]:
class Special2(object):
  """Same two-attribute container as ``Special1``, but not registered as a
  pytree node, so JAX flattens an instance to a single opaque leaf."""

  def __init__(self, x, y):
    self.x, self.y = x, y

  def __repr__(self):
    # NOTE(review): the repr prints "Special", not "Special2" -- likely a
    # copy-paste slip; kept as-is to match the recorded cell output.
    return f"Special(x={self.x}, y={self.y})"

In [27]:
# Special2 is not registered as a pytree node, so JAX treats the whole object
# as one leaf: the key path is empty and the value is the object itself.
# Use a distinct name so `special1` keeps referring to the Special1 instance
# created in the previous cell instead of being silently rebound.
special2 = Special2(1, 2)

flattened, pytreedef = jax.tree_util.tree_flatten_with_path(special2)

print(pytreedef)
for key_path, value in flattened:
  print('-'*30)
  print(f"key_path | {key_path}")
  print(f'key_str  | {jax.tree_util.keystr(key_path)}')
  print(f"value    | {value}")

PyTreeDef(*)
------------------------------
key_path | ()
key_str  | 
value    | Special(x=1, y=2)
