diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index d4e5495e6359..167c19651c8d 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -3597,15 +3597,12 @@ def aot_cache_miss(*args, **kwargs):
       return outs, fastpath_data
 
     if xc._version < 108:
       def dummy(): pass
       dummy.__name__ = self.unsafe_call.name
       return xc._xla.pjit(dummy, aot_cache_miss, [])  # type: ignore
     else:
-      return xc._xla.pjit(  # type: ignore
-          self.unsafe_call.name, aot_cache_miss, [])
+      return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, [])  # type: ignore
 
 
 def _out_shardings_for_trivial(
@@ -3723,24 +3720,27 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
   from jax.experimental.global_device_array import GlobalDeviceArray
   from jax._src.array import ArrayImpl
 
-  @lru_cache(maxsize=4096)
-  def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):
-    if committed and not are_op_shardings_equal(
-        arg_sharding._to_xla_op_sharding(ndim),
-        in_xla_sharding._to_xla_op_sharding(ndim)):
-      raise ValueError(
-          f"{arg_type} sharding does not match the input sharding. "
-          f"Got {arg_type} sharding: {arg_sharding} and "
-          f"xla sharding: {in_xla_sharding}")
-
   for arg, xs in safe_zip(args, in_xla_shardings):
     if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
       continue
     if isinstance(arg, GlobalDeviceArray):
-      _cached_check(_create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes), xs,
-                    'GDA', arg.ndim, True)
+      arg_sharding = _create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
+      arg_type = 'GDA'
+      committed = True
     else:
-      _cached_check(arg.sharding, xs, 'Array', arg.ndim, arg._committed)
+      arg_sharding = arg.sharding
+      arg_type = 'Array'
+      committed = arg._committed
+
+    # No need to cache this check since MeshExecutable has a C++ fast path
+    # for AOT compiled call.
+    if committed and not are_op_shardings_equal(
+        arg_sharding._to_xla_op_sharding(arg.ndim),
+        xs._to_xla_op_sharding(arg.ndim)):
+      raise ValueError(
+          f"{arg_type} sharding does not match the input sharding. "
+          f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for "
+          f"arg shape: {arg.shape}, arg value: {arg}")
 
 
 def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
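
Note on the second hunk: the lru_cache wrapper is dropped because this check now runs only on the Python slow path (the AOT compiled call is served by the C++ fast path), and because the improved error message includes the offending argument's shape and value, which are value-dependent and so could not have served as cache keys anyway. Below is a minimal, self-contained sketch of the same inline-check pattern; the names ShardedArg and check_shardings_match are hypothetical stand-ins for illustration, not JAX's real classes or APIs.

# Sketch only: stand-ins for the committed-array sharding check pattern,
# not the pxla.py implementation.
from dataclasses import dataclass

@dataclass(frozen=True)
class ShardedArg:
  shape: tuple        # stand-in for arg.shape
  sharding: str       # stand-in for a sharding / OpSharding comparison key
  committed: bool = True

def check_shardings_match(args, expected_shardings):
  # Inline, uncached check: raising here can mention the specific
  # argument's shape, which a cached helper keyed on shardings alone
  # could not do.
  for arg, expected in zip(args, expected_shardings):
    if arg.committed and arg.sharding != expected:
      raise ValueError(
          f"Array sharding does not match the input sharding. "
          f"Got sharding: {arg.sharding} and expected sharding: "
          f"{expected} for arg shape: {arg.shape}")

check_shardings_match([ShardedArg((8, 2), "replicated")], ["replicated"])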