Skip to content

Commit

Permalink
Add test_assert_tree_shape_prefix assertion.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 358408201
  • Loading branch information
hbq1 authored and ChexDev chex-dev@google.com committed Feb 19, 2021
1 parent 82f19d9 commit 97bcf3a
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
2 changes: 2 additions & 0 deletions chex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from chex._src.asserts import assert_tree_all_equal_shapes
from chex._src.asserts import assert_tree_all_equal_structs
from chex._src.asserts import assert_tree_all_finite
from chex._src.asserts import assert_tree_shape_prefix
from chex._src.asserts import assert_type
from chex._src.asserts import clear_trace_counter
from chex._src.asserts import if_args_not_none
Expand Down Expand Up @@ -105,6 +106,7 @@
"assert_tree_all_equal_shapes",
"assert_tree_all_equal_structs",
"assert_tree_all_finite",
"assert_tree_shape_prefix",
"assert_type",
"ChexVariantType",
"clear_trace_counter",
Expand Down
24 changes: 22 additions & 2 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import jax.numpy as jnp
import jax.test_util as jax_test
import numpy as np
import tree
import tree as dm_tree

Scalar = pytypes.Scalar
Array = pytypes.Array
Expand Down Expand Up @@ -587,6 +587,26 @@ def assert_numerical_grads(
jax_test.check_grads(f, f_args, order=order, atol=atol, **check_kwargs)


def assert_tree_shape_prefix(tree: ArrayTree, shape_prefix: Sequence[int]):
"""Assert all tree leaves shapes' have the same prefix.
Args:
tree: tree to assert.
shape_prefix: expected shapes' prefix.
Raise:
AssertionError: if some leaf's shape doesn't start with the expected prefix.
"""

def _assert_fn(path, leaf):
prefix = leaf.shape[:len(shape_prefix)]
if prefix != shape_prefix:
raise AssertionError(
f"Tree leaf '{'/'.join(path)}' has a shape prefix "
f"diffent from expected: {prefix} != {shape_prefix}.")

dm_tree.map_structure_with_path(_assert_fn, tree)


def assert_tree_all_equal_structs(*trees: Sequence[ArrayTree]):
"""Assert trees have the same structure.
Expand Down Expand Up @@ -620,7 +640,7 @@ def _tree_error_msg_fn(l_1: TLeaf, l_2: TLeaf, path: str, i_1: int, i_2: int):

cmp = functools.partial(_assert_leaves_all_eq_comparator,
equality_comparator, _tree_error_msg_fn)
tree.map_structure_with_path(cmp, *trees)
dm_tree.map_structure_with_path(cmp, *trees)


def assert_tree_all_close(*trees: Sequence[ArrayTree],
Expand Down
12 changes: 12 additions & 0 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,18 @@ def test_assert_tree_all_equal_structs(self):
asserts.assert_tree_all_equal_structs(tree3, tree3)
self._assert_tree_structs_validation(asserts.assert_tree_all_equal_structs)

def test_assert_tree_shape_prefix(self):
tree1 = {'x': {'y': np.zeros([3, 2])}, 'z': np.zeros([3, 2, 1])}
asserts.assert_tree_shape_prefix(tree1, ())
asserts.assert_tree_shape_prefix(tree1, (3,))
asserts.assert_tree_shape_prefix(tree1, (3, 2))

with self.assertRaisesRegex(
AssertionError,
r'Tree leaf \'x/y\' .* diffent from expected: \(3, 2\) != \(3, 2, 1\)'
):
asserts.assert_tree_shape_prefix(tree1, (3, 2, 1))


class NumDevicesAssertTest(parameterized.TestCase):

Expand Down

0 comments on commit 97bcf3a

Please sign in to comment.