Skip to content

Commit

Permalink
[jax_export] Add backwards compatibility tests for shape_assertion.
Browse files Browse the repository at this point in the history
  • Loading branch information
gnecula committed Aug 12, 2023
1 parent cf4e1d4 commit deefdbe
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 37 deletions.
35 changes: 21 additions & 14 deletions jax/experimental/jax2tf/tests/back_compat_test.py
Expand Up @@ -90,7 +90,7 @@ def test_detect_different_custom_calls(self):
dummy_data,
platform=self.default_jax_backend(),
custom_call_targets=["missing"])
with self.assertRaisesRegex(AssertionError, "Lists differ"):
with self.assertRaisesRegex(AssertionError, "Element counts were not equal"):
self.run_one_test(jnp.sin, platform_dummy_data)

def test_custom_call_coverage(self):
Expand All @@ -117,6 +117,7 @@ def test_custom_call_coverage(self):
tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17,
stablehlo_dynamic_rng_bit_generator.data_2023_06_17,
stablehlo_dynamic_top_k.data_2023_07_16,
stablehlo_dynamic_top_k.data_2023_08_11, # with shape_assertion
]
# Some of the above are nested structures.
covering_testdatas = itertools.chain(
Expand All @@ -128,8 +129,6 @@ def test_custom_call_coverage(self):

covered_targets = covered_targets.union({
"tpu_custom_call", # tested separately
# TODO(necula): add tests for shape_assertion
"shape_assertion",
})
not_covered = targets_to_cover.difference(covered_targets)
self.assertEmpty(not_covered,
Expand All @@ -143,8 +142,9 @@ def func(x):

# An old lowering, with ducc_fft. We keep it for 6 months.
data = self.load_testdata(cpu_ducc_fft.data_2023_03_17)
# We have changed the lowering for fft, do not compare with current.
self.run_one_test(func, data, compare_with_current=False)
# We have changed the lowering for fft since we saved this data.
self.run_one_test(func, data,
expect_current_custom_calls=["dynamic_ducc_fft"])

# A newer lowering, with dynamic_ducc_fft.
data = self.load_testdata(cpu_ducc_fft.data_2023_06_14)
Expand Down Expand Up @@ -617,8 +617,8 @@ def func(x):
self.run_one_test(
func, data,
polymorphic_shapes=("b, ...",),
# TODO(necula): now also includes shape_assertion
compare_with_current=False)
# Recent serializations also include shape_assertion, tested with dynamic_top_k
expect_current_custom_calls=["stablehlo.dynamic_reduce_window", "shape_assertion"])

def test_tpu_stablehlo_dynamic_reduce_window_variadic(self):
# stablehlo.dynamic_reduce_window is used temporarily on TPU for a
Expand All @@ -640,8 +640,8 @@ def func(x, y): # x: f32[b, 2] y: i32[b, 2]
self.run_one_test(
func, data,
polymorphic_shapes=("b, ...", "b, ..."),
# TODO(necula): now also includes shape_assertion
compare_with_current=False)
# Recent serializations also include shape_assertion, tested with dynamic_top_k
expect_current_custom_calls=["stablehlo.dynamic_reduce_window", "shape_assertion"])

def test_stablehlo_dynamic_rbg_bit_generator(self):
# stablehlo.dynamic_rbg_bit_generator is used temporarily for a
Expand Down Expand Up @@ -672,9 +672,10 @@ def func(key, a): # a is only used for its shape
try:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")

self.run_one_test(func, data, polymorphic_shapes=(None, "b0, b1"),
# TODO(necula): now also includes shape_assertion
compare_with_current=False)
self.run_one_test(
func, data, polymorphic_shapes=(None, "b0, b1"),
# Recent serializations also include shape_assertion, tested with dynamic_top_k
expect_current_custom_calls=["stablehlo.dynamic_rng_bit_generator", "shape_assertion"])
finally:
jax.config.update("jax_default_prng_impl", prev_default_prng_impl)

Expand All @@ -700,8 +701,14 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol):
self.run_one_test(func, data,
polymorphic_shapes=("_, b",),
check_results=check_top_k_results,
# TODO(necula): now also includes shape_assertion
compare_with_current=False)
# Recent serializations also include shape_assertion
expect_current_custom_calls=["stablehlo.dynamic_top_k", "shape_assertion"])

# Now a test with serialization version 7, including shape_assertion
data_2 = self.load_testdata(stablehlo_dynamic_top_k.data_2023_08_11)
self.run_one_test(func, data_2,
polymorphic_shapes=("_, b",),
check_results=check_top_k_results)


if __name__ == "__main__":
Expand Down
43 changes: 21 additions & 22 deletions jax/experimental/jax2tf/tests/back_compat_test_util.py
Expand Up @@ -163,9 +163,9 @@ def run_one_test(self, func: Callable[..., jax.Array],
polymorphic_shapes: Optional[Sequence[str]] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
allow_additional_custom_call_targets: Sequence[str] = (),
allow_unstable_custom_call_targets: Sequence[str] = (),
check_results: Optional[Callable[..., None]] = None,
compare_with_current: bool = True):
expect_current_custom_calls: Optional[Sequence[str]] = None):
"""Run one compatibility test.
Args:
Expand All @@ -178,14 +178,12 @@ def run_one_test(self, func: Callable[..., jax.Array],
check_results: invoked with the results obtained from running the
serialized code, and those stored in the test data, and the kwargs rtol
and atol.
allow_additional_custom_call_targets: additional custom call targets to allow.
compare_with_current: whether to compare the current behavior for
`func` with the one stored in `data`. If `True` (default) uses the
current version of JAX and XLA to lower and serialize `func` and check
its results compared to the stored ones; it also dumps the current
test data. If `False`, no current serialization are comparisons are
done, tests only the saved serialization. Use this option for a test
data for which we have changed the serialization.
allow_unstable_custom_call_targets: additional custom call targets to allow.
expect_current_custom_calls: if `None` checks that the current serialization
has the same custom calls as the saved one. This is the default, and
will fail when the serialization changes. Otherwise, when checking old
serializations you can specify what custom calls are expected in the
current serialization.
"""
if not isinstance(data, CompatTestData):
raise ValueError(f"Expecting data: CompatTestData but got {data}. "
Expand All @@ -204,10 +202,10 @@ def run_one_test(self, func: Callable[..., jax.Array],
serialized, module_str, module_version = self.serialize(
func, data,
polymorphic_shapes=polymorphic_shapes,
allow_additional_custom_call_targets=allow_additional_custom_call_targets)
allow_unstable_custom_call_targets=allow_unstable_custom_call_targets)

custom_call_re = r"stablehlo.custom_call\s*@([^\(]+)\("
custom_call_targets = sorted(
current_custom_call_targets = sorted(
list(set(re.findall(custom_call_re, module_str))))

np.set_printoptions(threshold=sys.maxsize, floatmode="unique")
Expand All @@ -217,7 +215,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
testdata_version={CURRENT_TESTDATA_VERSION},
platform={repr(self.default_jax_backend())},
custom_call_targets={repr(custom_call_targets)},
custom_call_targets={repr(current_custom_call_targets)},
serialized_date={repr(datetime.date.today())},
inputs={repr(data.inputs)},
expected_outputs={repr(res_run_current)},
Expand Down Expand Up @@ -258,25 +256,26 @@ def run_one_test(self, func: Callable[..., jax.Array],
else:
self.assertAllClose(res_run_serialized, data.expected_outputs,
rtol=rtol, atol=atol)
if compare_with_current:
self.assertListEqual(custom_call_targets, data.custom_call_targets)
if expect_current_custom_calls is None:
expect_current_custom_calls = data.custom_call_targets
self.assertItemsEqual(expect_current_custom_calls, current_custom_call_targets)

def run_current(self, func: Callable, data: CompatTestData):
"""Lowers and runs the test function at the current JAX version."""
return jax.jit(func)(*data.inputs)

def serialize(self,
func: Callable, data: CompatTestData, *,
polymorphic_shapes: Optional[Sequence[str]] = None,
allow_additional_custom_call_targets: Sequence[str] = ()
) -> tuple[bytes, str, int]:
func: Callable, data: CompatTestData, *,
polymorphic_shapes: Optional[Sequence[str]] = None,
allow_unstable_custom_call_targets: Sequence[str] = ()
) -> tuple[bytes, str, int]:
"""Serializes the test function.
Args:
func: the function to serialize
polymorphic_shapes: the polymorphic_shapes to use for serialization
allow_additional_custom_call_targets: whether to allow additional
custom call targets besides the standard ones.
allow_unstable_custom_call_targets: whether to allow additional
custom call targets besides those known as stable.
Returns: a tuple with the (a) serialization, (b) the module contents as
a string (for debugging), and (c) the module serialization version.
Expand All @@ -288,7 +287,7 @@ def serialize(self,
lowering_platform=self.default_jax_backend(),
disabled_checks=tuple(
jax_export.DisabledSafetyCheck.custom_call(target)
for target in allow_additional_custom_call_targets)
for target in allow_unstable_custom_call_targets)
)(*args_specs)

module_str = str(exported.mlir_module())
Expand Down

0 comments on commit deefdbe

Please sign in to comment.