Skip to content

Commit

Permalink
Fix the test_sharding_on_output_with_vmap failure in Pathways which…
Browse files Browse the repository at this point in the history
… was getting a cache miss in pjit_call_impl.

There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.

In the test, check for re-compilation which is what that test was doing before cl/535630905

PiperOrigin-RevId: 536575987
  • Loading branch information
yashk2810 authored and jax authors committed May 31, 2023
1 parent 3ad756f commit f884b4d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
17 changes: 11 additions & 6 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,13 @@ def _cpp_pjit_evict_fn(self):
_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192)


def _get_cpp_global_cache(pjit_has_explicit_sharding):
if pjit_has_explicit_sharding:
return xc._xla.PjitFunctionCache()
else:
return _cpp_pjit_cache


def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
donate_argnums, pjit_has_explicit_sharding):

Expand All @@ -245,14 +252,10 @@ def cache_miss(*args, **kwargs):
fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
return outs, fastpath_data

if pjit_has_explicit_sharding:
global_cache = xc._xla.PjitFunctionCache()
else:
global_cache = _cpp_pjit_cache
cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
fun, cache_miss, static_argnums, static_argnames, # type: ignore
donate_argnums, global_cache) # type: ignore
donate_argnums, _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore

cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
Expand Down Expand Up @@ -1185,8 +1188,10 @@ def call_impl_cache_miss(*args_, **kwargs_):
tuple(getattr(o, '_original_sharding', o) for o in out_shardings),
resource_env, donated_invars, name, keep_unused, inline)
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, None, None)
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
_cpp_pjit_cache)(*args)
_get_cpp_global_cache(has_explicit_sharding))(*args)

pjit_p.def_impl(_pjit_call_impl)

Expand Down
2 changes: 1 addition & 1 deletion tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3234,7 +3234,7 @@ def test_sharding_on_output_with_vmap(self):
arr = jax.device_put(
np.arange(16).reshape(8, 2), NamedSharding(mesh, P(None, 'x')))

with jtu.count_pjit_cpp_cache_miss() as count:
with jtu.count_jit_and_pmap_compiles() as count:
vf = jax.vmap(pjit(lambda x: x * 2, in_shardings=ns))
out = vf(arr)
self.assertIsInstance(out.sharding, NamedSharding)
Expand Down

0 comments on commit f884b4d

Please sign in to comment.