Skip to content

Commit

Permalink
Add jax.tree_util.tree_leaves_with_path(tree).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 539609052
  • Loading branch information
tomhennigan authored and jax authors committed Jun 12, 2023
1 parent 9b10384 commit ed073aa
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/jax.tree_util.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ List of Functions
tree_flatten
tree_flatten_with_path
tree_leaves
tree_leaves_with_path
tree_map
tree_map_with_path
tree_reduce
Expand Down
14 changes: 14 additions & 0 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,20 @@ def tree_flatten_with_path(
return _generate_key_paths(tree, is_leaf), tree_def


def tree_leaves_with_path(
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[KeyPath, Any]]:
"""Flattens a pytree like ``tree_leaves``, but also returns each leaf's key path.
Args:
tree: a pytree to flatten. If it contains a custom type, it must be
registered with ``register_pytree_with_keys``.
Returns:
A list of key-leaf pairs, each of which contains a leaf and its key path.
"""
return _generate_key_paths(tree, is_leaf)


def generate_key_paths(
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[KeyPath, Any]]:
Expand Down
1 change: 1 addition & 0 deletions jax/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
register_pytree_with_keys_class as register_pytree_with_keys_class,
tree_map_with_path as tree_map_with_path,
tree_flatten_with_path as tree_flatten_with_path,
tree_leaves_with_path as tree_leaves_with_path,
keystr as keystr,
SequenceKey as SequenceKey,
DictKey as DictKey,
Expand Down
7 changes: 7 additions & 0 deletions tests/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,13 @@ def testTreeMapWithPathMultipleTrees(self):
from_one_tree = tree_util.tree_map(lambda a: a + 2, tree1)
self.assertEqual(from_two_trees, from_one_tree)

def testTreeLeavesWithPath(self):
tree = [{i: i for i in range(10)}]
actual = tree_util.tree_leaves_with_path(tree)
expected = [((tree_util.SequenceKey(0), tree_util.DictKey(i)), i)
for i in range(10)]
self.assertEqual(actual, expected)

def testKeyStr(self):
tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)]
flattened, _ = tree_util.tree_flatten_with_path(tree1)
Expand Down

0 comments on commit ed073aa

Please sign in to comment.