diff --git a/jax/_src/config.py b/jax/_src/config.py index 6e1789cd288c..460f79a3e3c4 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -21,7 +21,7 @@ import os import sys import threading -from typing import Any, List, Callable, NamedTuple, Iterator, Optional +from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional import warnings from absl import logging @@ -374,7 +374,12 @@ def _trace_context(self): Values included in this set should also most likely be included in the C++ JIT state, which is handled separately.""" - return (self.x64_enabled, self.jax_numpy_rank_promotion, + tls = jax_jit.thread_local_state() + axis_env_state = () + context = tls.extra_jit_context + if context and context.axis_env_state is not None: + axis_env_state = context.axis_env_state + return (axis_env_state, self.x64_enabled, self.jax_numpy_rank_promotion, self.jax_default_matmul_precision, self.jax_dynamic_shapes, self.jax_numpy_dtype_promotion, self.jax_default_device) @@ -483,6 +488,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): `_update_thread_local_jit_state` in core.py to prevent circular imports. """ dynamic_trace_state: Optional[Any] = None + axis_env_state: Optional[Hashable] = None numpy_rank_promotion: Optional[str] = None numpy_dtype_promotion: Optional[str] = None default_matmul_precision: Optional[Any] = None diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ac27fe3f41c4..313ae2340d00 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4695,6 +4695,7 @@ def _compress_method(a, condition, axis=None, out=None): return compress(condition, a, axis, out) +@core.stash_axis_env() @partial(jit, static_argnums=(1,2,3)) def _multi_slice(arr, start_indices: Tuple[Tuple[int, ...]], diff --git a/jax/core.py b/jax/core.py index 4744147e2b29..e077a7545b7b 100644 --- a/jax/core.py +++ b/jax/core.py @@ -2051,21 +2051,51 @@ def _unmap_dshaped_array( @contextmanager def extend_axis_env(axis_name: AxisName, size: int, tag: Any): frame = AxisEnvFrame(axis_name, size, tag) - thread_local_state.trace_state.axis_env.append(frame) + ts = thread_local_state.trace_state + ts.axis_env.append(frame) + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(f for f in ts.axis_env + if f.name is not no_axis_name)) try: yield finally: - thread_local_state.trace_state.axis_env.pop() + ts.axis_env.pop() + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(f for f in ts.axis_env + if f.name is not no_axis_name)) @contextmanager def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]): frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes] - thread_local_state.trace_state.axis_env.extend(frames) + ts = thread_local_state.trace_state + ts.axis_env.extend(frames) + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(f for f in ts.axis_env + if f.name is not no_axis_name)) try: yield finally: - for _ in frames: - thread_local_state.trace_state.axis_env.pop() + for _ in frames: ts.axis_env.pop() + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(f for f in ts.axis_env + if f.name is not no_axis_name)) + + +@contextmanager +def stash_axis_env(): + "Promise that a function or with-suite does not depend implicitly on axis env" + # If the promise is broken, then a NameError about an unbound axis name will + # be raised. + ts = thread_local_state.trace_state + prev_axis_env, ts.axis_env = ts.axis_env, [] + jax_config.update_thread_local_jit_state(axis_env_state=()) + try: + yield + finally: + ts.axis_env = prev_axis_env + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(f for f in ts.axis_env + if f.name is not no_axis_name)) # When a mapped function is given no axis name, we generate a name object based @@ -2601,7 +2631,8 @@ def _compact_eqn_should_include(k: str, v: Any) -> bool: if k == 'branches': return False if isinstance(v, (Jaxpr, ClosedJaxpr)): return False if (isinstance(v, tuple) and - any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)): return False + any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)): + return False return True def str_eqn_compact(primitive_name: str, params: Dict) -> str: diff --git a/tests/api_test.py b/tests/api_test.py index c88907121b31..71c352f6b21f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1106,6 +1106,28 @@ def jit_impl_and_count(*args, **kwargs): finally: xla.xla_call_p.def_impl(jit_impl) + def test_caches_depend_on_axis_env(self): + # https://github.com/google/jax/issues/9187 + f = lambda: lax.psum(1, "i") + g = jax.jit(f) + expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)() + ans = jax.vmap(g, axis_name="i", axis_size=2, out_axes=None)() + self.assertEqual(ans, expected) + + # This second call to g could erroneously get a cache hit. + expected = jax.vmap(f, axis_name="i", axis_size=3, out_axes=None)() + ans = jax.vmap(g, axis_name="i", axis_size=3, out_axes=None)() + self.assertEqual(ans, expected) + + def test_caches_dont_depend_on_unnamed_axis_env(self): + # https://github.com/google/jax/issues/9187 + f = jax.jit(lambda: jnp.sin(1)) + expected = f() + with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + ans = jax.vmap(f, axis_size=2, out_axes=None)() + self.assertEqual(count[0], 0) # no compiles + self.assertArraysAllClose(ans, expected, check_dtypes=True) + class PythonJitTest(CPPJitTest): diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index e2adebcde632..13d0810d28ec 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2974,6 +2974,15 @@ def test_for_jvp(self, jit_for, f, ref, body_shapes, n): self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol) jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"]) + def test_caches_depend_on_axis_env(self): + # https://github.com/google/jax/issues/9187 + scanned_f = lambda _, __: (lax.psum(1, 'i'), None) + f = lambda: lax.scan(scanned_f, 0, None, length=1)[0] + ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() + self.assertEqual(ans, 2) + ans = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)() + self.assertEqual(ans, 3) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 894e24a2c1f2..ffa2ebae3d73 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1949,8 +1949,9 @@ def f(x): self.assertEqual(count[0], 2) # one for fwd, one for bwd with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 - _ = jax.vjp(f, x) + _, f_bwd2 = jax.vjp(f, x) _ = f_bwd(x) + _ = f_bwd2(x) self.assertEqual(count[0], 0) # cache hits on fwd and bwd def testSizeOverflow(self):