Skip to content

Commit

Permalink
[key reuse] handle polymorphic shapes in slice
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 29, 2024
1 parent a043325 commit 84ee045
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 17 deletions.
13 changes: 9 additions & 4 deletions jax/experimental/key_reuse/_forwarding.py
Expand Up @@ -19,10 +19,10 @@
from typing import Any, Callable, NamedTuple

import jax
from jax import core
from jax import lax
from jax import tree_util
from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import prng
Expand Down Expand Up @@ -195,9 +195,14 @@ def _slice_signature(eqn, args_consumed):
limit_indices = eqn.params['limit_indices']
strides = eqn.params['strides'] or (1,) * len(start_indices)
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
mask = np.zeros(in_aval.shape, dtype=bool)
mask[idx] = True
return KeyReuseSignatureWithForwards([Sink(0, mask)], [Source(0)])
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
sink = True
else:
# TODO(jakevdp): should we avoid constructing the mask array if the input
# does not have a key dtype?
sink = np.zeros(in_aval.shape, dtype=bool)
sink[idx] = True
return KeyReuseSignatureWithForwards([Sink(0, sink)], [Source(0)])

key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature

Expand Down
13 changes: 9 additions & 4 deletions jax/experimental/key_reuse/_simple.py
Expand Up @@ -19,10 +19,10 @@
from typing import Any, Callable

import jax
from jax import core
from jax import lax
from jax import tree_util
from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import prng
Expand Down Expand Up @@ -166,9 +166,14 @@ def _slice_signature(eqn, args_consumed):
limit_indices = eqn.params['limit_indices']
strides = eqn.params['strides'] or (1,) * len(start_indices)
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
mask = np.zeros(in_aval.shape, dtype=bool)
mask[idx] = True
return KeyReuseSignature([Sink(0, mask)], [Source(0)])
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
sink = True
else:
# TODO(jakevdp): should we avoid constructing the mask array if the input
# does not have a key dtype?
sink = np.zeros(in_aval.shape, dtype=bool)
sink[idx] = True
return KeyReuseSignature([Sink(0, sink)], [Source(0)])

key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature

Expand Down
20 changes: 11 additions & 9 deletions tests/shape_poly_test.py
Expand Up @@ -106,7 +106,7 @@ def _stop_profile(tst: jtu.JaxTestCase):
p.sort_stats("cumtime").print_stats(.2)
p.print_callers(.2)

@jtu.with_config(jax_enable_key_reuse_checks=False)

class DimExprTest(jtu.JaxTestCase):

def setUp(self):
Expand Down Expand Up @@ -1074,7 +1074,6 @@ def check_shape_poly(tst, f_jax: Callable, *,
return h.run_test(tst)


@jtu.with_config(jax_enable_key_reuse_checks=False)
class ShapePolyTest(jtu.JaxTestCase):

def setUp(self):
Expand Down Expand Up @@ -1475,6 +1474,7 @@ def f(x):
polymorphic_shapes=["(b,)"])
self.assertAllClose(f(x), res_tf)

@jax.enable_key_reuse_checks(False)
def test_prng(self):
# The PRNG implementation uses opaque types, test shape polymorphism
with config.enable_custom_prng(True):
Expand Down Expand Up @@ -2908,7 +2908,6 @@ def _flatten_harnesses(harnesses):
return res


@jtu.with_config(jax_enable_key_reuse_checks=False)
class ShapePolyHarnessesTest(jtu.JaxTestCase):
"""This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES."""

Expand Down Expand Up @@ -2987,16 +2986,19 @@ def test_harness(self, harness: PolyHarness):
if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("JAX implements eig only on CPU.")

prev_jax_config_flags = {
fname: getattr(jax.config, fname)
for fname, fvalue in harness.override_jax_config_flags.items()
}
config_flags = harness.override_jax_config_flags
# Update this here rather than in harness object because vmap_random_gamma is derived
# from test_harnesses.all_harnesses, which strips override_jax_config_flags.
if "random_gamma" in harness.group_name:
config_flags = {**config_flags, "jax_enable_key_reuse_checks": False}

prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags}
try:
for fname, fvalue in harness.override_jax_config_flags.items():
for fname, fvalue in config_flags.items():
jax.config.update(fname, fvalue)
harness.run_test(self)
finally:
for fname, _ in harness.override_jax_config_flags.items():
for fname, _ in config_flags.items():
jax.config.update(fname, prev_jax_config_flags[fname])

if __name__ == "__main__":
Expand Down

0 comments on commit 84ee045

Please sign in to comment.