From 84ee045f5506059ca060f6d549fb9c97f9b500a1 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 29 Jan 2024 13:59:44 -0800 Subject: [PATCH] [key reuse] handle polymorphic shapes in slice --- jax/experimental/key_reuse/_forwarding.py | 13 +++++++++---- jax/experimental/key_reuse/_simple.py | 13 +++++++++---- tests/shape_poly_test.py | 20 +++++++++++--------- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/jax/experimental/key_reuse/_forwarding.py b/jax/experimental/key_reuse/_forwarding.py index 21f147165a2c..df3ff634f5ea 100644 --- a/jax/experimental/key_reuse/_forwarding.py +++ b/jax/experimental/key_reuse/_forwarding.py @@ -19,10 +19,10 @@ from typing import Any, Callable, NamedTuple import jax -from jax import core from jax import lax from jax import tree_util from jax._src import api_util +from jax._src import core from jax._src import linear_util as lu from jax._src import pjit from jax._src import prng @@ -195,9 +195,14 @@ def _slice_signature(eqn, args_consumed): limit_indices = eqn.params['limit_indices'] strides = eqn.params['strides'] or (1,) * len(start_indices) idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides)) - mask = np.zeros(in_aval.shape, dtype=bool) - mask[idx] = True - return KeyReuseSignatureWithForwards([Sink(0, mask)], [Source(0)]) + if any(core.is_symbolic_dim(s) for s in in_aval.shape): + sink = True + else: + # TODO(jakevdp): should we avoid constructing the mask array if the input + # does not have a key dtype? + sink = np.zeros(in_aval.shape, dtype=bool) + sink[idx] = True + return KeyReuseSignatureWithForwards([Sink(0, sink)], [Source(0)]) key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature diff --git a/jax/experimental/key_reuse/_simple.py b/jax/experimental/key_reuse/_simple.py index 5f27cae3dd71..d2f1f087d39e 100644 --- a/jax/experimental/key_reuse/_simple.py +++ b/jax/experimental/key_reuse/_simple.py @@ -19,10 +19,10 @@ from typing import Any, Callable import jax -from jax import core from jax import lax from jax import tree_util from jax._src import api_util +from jax._src import core from jax._src import linear_util as lu from jax._src import pjit from jax._src import prng @@ -166,9 +166,14 @@ def _slice_signature(eqn, args_consumed): limit_indices = eqn.params['limit_indices'] strides = eqn.params['strides'] or (1,) * len(start_indices) idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides)) - mask = np.zeros(in_aval.shape, dtype=bool) - mask[idx] = True - return KeyReuseSignature([Sink(0, mask)], [Source(0)]) + if any(core.is_symbolic_dim(s) for s in in_aval.shape): + sink = True + else: + # TODO(jakevdp): should we avoid constructing the mask array if the input + # does not have a key dtype? + sink = np.zeros(in_aval.shape, dtype=bool) + sink[idx] = True + return KeyReuseSignature([Sink(0, sink)], [Source(0)]) key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 84d8b5ca2624..806a422055a2 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -106,7 +106,7 @@ def _stop_profile(tst: jtu.JaxTestCase): p.sort_stats("cumtime").print_stats(.2) p.print_callers(.2) -@jtu.with_config(jax_enable_key_reuse_checks=False) + class DimExprTest(jtu.JaxTestCase): def setUp(self): @@ -1074,7 +1074,6 @@ def check_shape_poly(tst, f_jax: Callable, *, return h.run_test(tst) -@jtu.with_config(jax_enable_key_reuse_checks=False) class ShapePolyTest(jtu.JaxTestCase): def setUp(self): @@ -1475,6 +1474,7 @@ def f(x): polymorphic_shapes=["(b,)"]) self.assertAllClose(f(x), res_tf) + @jax.enable_key_reuse_checks(False) def test_prng(self): # The PRNG implementation uses opaque types, test shape polymorphism with config.enable_custom_prng(True): @@ -2908,7 +2908,6 @@ def _flatten_harnesses(harnesses): return res -@jtu.with_config(jax_enable_key_reuse_checks=False) class ShapePolyHarnessesTest(jtu.JaxTestCase): """This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES.""" @@ -2987,16 +2986,19 @@ def test_harness(self, harness: PolyHarness): if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]): raise unittest.SkipTest("JAX implements eig only on CPU.") - prev_jax_config_flags = { - fname: getattr(jax.config, fname) - for fname, fvalue in harness.override_jax_config_flags.items() - } + config_flags = harness.override_jax_config_flags + # Update this here rather than in harness object because vmap_random_gamma is derived + # from test_harnesses.all_harnesses, which strips override_jax_config_flags. + if "random_gamma" in harness.group_name: + config_flags = {**config_flags, "jax_enable_key_reuse_checks": False} + + prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags} try: - for fname, fvalue in harness.override_jax_config_flags.items(): + for fname, fvalue in config_flags.items(): jax.config.update(fname, fvalue) harness.run_test(self) finally: - for fname, _ in harness.override_jax_config_flags.items(): + for fname, _ in config_flags.items(): jax.config.update(fname, prev_jax_config_flags[fname]) if __name__ == "__main__":