Skip to content

Commit

Permalink
Remove the cached check in aot compiled call in MeshExecutable becaus…
Browse files Browse the repository at this point in the history
…e a fast C++ dispatch path exists. This leads to a better error message which contains the shape and arg value.

PiperOrigin-RevId: 494815311
  • Loading branch information
yashk2810 authored and jax authors committed Dec 12, 2022
1 parent 23001ae commit d491d9f
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions jax/interpreters/pxla.py
Expand Up @@ -3597,15 +3597,12 @@ def aot_cache_miss(*args, **kwargs):
return outs, fastpath_data

if xc._version < 108:

def dummy():
pass

dummy.__name__ = self.unsafe_call.name
return xc._xla.pjit(dummy, aot_cache_miss, []) # type: ignore
return xc._xla.pjit(dummy, aot_cache_miss, []) # type: ignore
else:
return xc._xla.pjit( # type: ignore
self.unsafe_call.name, aot_cache_miss, [])
return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, []) # type: ignore


def _out_shardings_for_trivial(
Expand Down Expand Up @@ -3723,24 +3720,27 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
from jax.experimental.global_device_array import GlobalDeviceArray
from jax._src.array import ArrayImpl

@lru_cache(maxsize=4096)
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):
if committed and not are_op_shardings_equal(
arg_sharding._to_xla_op_sharding(ndim),
in_xla_sharding._to_xla_op_sharding(ndim)):
raise ValueError(
f"{arg_type} sharding does not match the input sharding. "
f"Got {arg_type} sharding: {arg_sharding} and "
f"xla sharding: {in_xla_sharding}")

for arg, xs in safe_zip(args, in_xla_shardings):
if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
continue
if isinstance(arg, GlobalDeviceArray):
_cached_check(_create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes), xs,
'GDA', arg.ndim, True)
arg_sharding = _create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
arg_type = 'GDA'
committed = True
else:
_cached_check(arg.sharding, xs, 'Array', arg.ndim, arg._committed)
arg_sharding = arg.sharding
arg_type = 'Array'
committed = arg._committed

# No need to cache this check since MeshExecutable has a C++ fast path
# for AOT compiled call.
if committed and not are_op_shardings_equal(
arg_sharding._to_xla_op_sharding(arg.ndim),
xs._to_xla_op_sharding(arg.ndim)):
raise ValueError(
f"{arg_type} sharding does not match the input sharding. "
f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for "
f"arg shape: {arg.shape}, arg value: {arg}")


def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
Expand Down

0 comments on commit d491d9f

Please sign in to comment.