Skip to content

Commit

Permalink
Jax caches should depend on axis env.
Browse files Browse the repository at this point in the history
  • Loading branch information
pschuh committed Jun 29, 2022
1 parent 90af8e8 commit 6c5d204
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 9 deletions.
10 changes: 8 additions & 2 deletions jax/_src/config.py
Expand Up @@ -21,7 +21,7 @@
import os
import sys
import threading
from typing import Any, List, Callable, NamedTuple, Iterator, Optional
from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
import warnings

from absl import logging
Expand Down Expand Up @@ -374,7 +374,12 @@ def _trace_context(self):
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately."""
return (self.x64_enabled, self.jax_numpy_rank_promotion,
tls = jax_jit.thread_local_state()
axis_env_state = ()
context = tls.extra_jit_context
if context and context.axis_env_state is not None:
axis_env_state = context.axis_env_state
return (axis_env_state, self.x64_enabled, self.jax_numpy_rank_promotion,
self.jax_default_matmul_precision, self.jax_dynamic_shapes,
self.jax_numpy_dtype_promotion, self.jax_default_device)

Expand Down Expand Up @@ -483,6 +488,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
`_update_thread_local_jit_state` in core.py to prevent circular imports.
"""
dynamic_trace_state: Optional[Any] = None
axis_env_state: Optional[Hashable] = None
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
Expand Down
1 change: 1 addition & 0 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -4695,6 +4695,7 @@ def _compress_method(a, condition, axis=None, out=None):
return compress(condition, a, axis, out)


@core.stash_axis_env()
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr,
start_indices: Tuple[Tuple[int, ...]],
Expand Down
43 changes: 37 additions & 6 deletions jax/core.py
Expand Up @@ -2051,21 +2051,51 @@ def _unmap_dshaped_array(
@contextmanager
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
  """Push a single axis frame onto the axis env for the duration of a block.

  A frame is built from the arguments and appended to the current trace
  state's ``axis_env``. The thread-local JIT state is then refreshed with the
  tuple of frames whose name is not ``no_axis_name``. On exit (including on
  error) the frame is popped and the thread-local JIT state is refreshed again.

  Args:
    axis_name: name stored in the new ``AxisEnvFrame``.
    size: size stored in the new ``AxisEnvFrame``.
    tag: tag stored in the new ``AxisEnvFrame``.
  """
  frame = AxisEnvFrame(axis_name, size, tag)
  ts = thread_local_state.trace_state
  # Push exactly once; the matching single pop is in the ``finally`` below.
  ts.axis_env.append(frame)
  jax_config.update_thread_local_jit_state(
      axis_env_state=tuple(f for f in ts.axis_env
                           if f.name is not no_axis_name))
  try:
    yield
  finally:
    ts.axis_env.pop()
    # Re-publish the axis env state now that the frame is gone.
    jax_config.update_thread_local_jit_state(
        axis_env_state=tuple(f for f in ts.axis_env
                             if f.name is not no_axis_name))

@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
  """Push one axis frame per ``(axis_name, size)`` pair for the duration of a block.

  Each pair becomes an ``AxisEnvFrame`` with a ``None`` tag. The frames are
  appended to the current trace state's ``axis_env`` in iteration order, and
  the thread-local JIT state is refreshed with the tuple of frames whose name
  is not ``no_axis_name``. On exit (including on error) the same number of
  frames is popped and the thread-local JIT state is refreshed again.

  Args:
    axes: iterable of ``(axis_name, size)`` pairs.
  """
  frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
  ts = thread_local_state.trace_state
  # Extend exactly once; the ``finally`` below pops one entry per frame.
  ts.axis_env.extend(frames)
  jax_config.update_thread_local_jit_state(
      axis_env_state=tuple(f for f in ts.axis_env
                           if f.name is not no_axis_name))
  try:
    yield
  finally:
    for _ in frames:
      ts.axis_env.pop()
    # Re-publish the axis env state now that the frames are gone.
    jax_config.update_thread_local_jit_state(
        axis_env_state=tuple(f for f in ts.axis_env
                             if f.name is not no_axis_name))


@contextmanager
def stash_axis_env():
  """Promise that a function or with-suite does not depend implicitly on axis env.

  While the context is active the trace state's ``axis_env`` is replaced by an
  empty list and the thread-local JIT state is given an empty
  ``axis_env_state``. On exit the previous ``axis_env`` is put back and the
  thread-local JIT state is refreshed from it.
  """
  # If the promise is broken, then a NameError about an unbound axis name will
  # be raised.
  trace_state = thread_local_state.trace_state
  saved_axis_env = trace_state.axis_env
  trace_state.axis_env = []
  jax_config.update_thread_local_jit_state(axis_env_state=())
  try:
    yield
  finally:
    trace_state.axis_env = saved_axis_env
    named_frames = tuple(
        frame for frame in trace_state.axis_env
        if frame.name is not no_axis_name)
    jax_config.update_thread_local_jit_state(axis_env_state=named_frames)


# When a mapped function is given no axis name, we generate a name object based
Expand Down Expand Up @@ -2601,7 +2631,8 @@ def _compact_eqn_should_include(k: str, v: Any) -> bool:
if k == 'branches': return False
if isinstance(v, (Jaxpr, ClosedJaxpr)): return False
if (isinstance(v, tuple) and
any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)): return False
any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)):
return False
return True

def str_eqn_compact(primitive_name: str, params: Dict) -> str:
Expand Down
22 changes: 22 additions & 0 deletions tests/api_test.py
Expand Up @@ -1106,6 +1106,28 @@ def jit_impl_and_count(*args, **kwargs):
finally:
xla.xla_call_p.def_impl(jit_impl)

def test_caches_depend_on_axis_env(self):
  """Jitted and un-jitted psum over axis "i" must agree for each axis size."""
  # https://github.com/google/jax/issues/9187
  unjitted = lambda: lax.psum(1, "i")
  jitted = jax.jit(unjitted)

  def run_under_vmap(fn, axis_size):
    return jax.vmap(fn, axis_name="i", axis_size=axis_size, out_axes=None)()

  want = run_under_vmap(unjitted, 2)
  got = run_under_vmap(jitted, 2)
  self.assertEqual(got, want)

  # This second call to the jitted function could erroneously get a cache hit.
  want = run_under_vmap(unjitted, 3)
  got = run_under_vmap(jitted, 3)
  self.assertEqual(got, want)

def test_caches_dont_depend_on_unnamed_axis_env(self):
  """A vmap without an axis name must not trigger a compile of a jitted function that was already called."""
  # https://github.com/google/jax/issues/9187
  jitted_sin = jax.jit(lambda: jnp.sin(1))
  # Call once outside the counting context so the result is available as the
  # reference value before compiles are counted.
  reference = jitted_sin()
  with jtu.count_jit_and_pmap_compiles() as compile_count:  # noqa: F841
    batched = jax.vmap(jitted_sin, axis_size=2, out_axes=None)()
  self.assertEqual(compile_count[0], 0)  # no compiles
  self.assertArraysAllClose(batched, reference, check_dtypes=True)


class PythonJitTest(CPPJitTest):

Expand Down
9 changes: 9 additions & 0 deletions tests/lax_control_flow_test.py
Expand Up @@ -2974,6 +2974,15 @@ def test_for_jvp(self, jit_for, f, ref, body_shapes, n):
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])

def test_caches_depend_on_axis_env(self):
  """A scan whose body psums over axis 'i' must return the current axis size."""
  # https://github.com/google/jax/issues/9187
  scan_body = lambda _, __: (lax.psum(1, 'i'), None)
  scan_once = lambda: lax.scan(scan_body, 0, None, length=1)[0]
  # Same function object under two different sizes of axis 'i', in order.
  for axis_size in (2, 3):
    result = jax.vmap(
        scan_once, axis_name='i', axis_size=axis_size, out_axes=None)()
    self.assertEqual(result, axis_size)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
3 changes: 2 additions & 1 deletion tests/pmap_test.py
Expand Up @@ -1949,8 +1949,9 @@ def f(x):
self.assertEqual(count[0], 2) # one for fwd, one for bwd

with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_ = jax.vjp(f, x)
_, f_bwd2 = jax.vjp(f, x)
_ = f_bwd(x)
_ = f_bwd2(x)
self.assertEqual(count[0], 0) # cache hits on fwd and bwd

def testSizeOverflow(self):
Expand Down

0 comments on commit 6c5d204

Please sign in to comment.