diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 96624feb1e61..c35ad24bc7cf 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1875,6 +1875,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
 
 MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]]
 
+
+class AllArgsInfo(NamedTuple):
+  """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
+  in_avals: Sequence[core.ShapedArray]
+  in_shardings: Any
+  debug_info: core.JaxprDebugInfo | None
+
+
 @profiler.annotate_function
 def lower_sharding_computation(
     closed_jaxpr: core.ClosedJaxpr,
@@ -1904,6 +1912,9 @@ def lower_sharding_computation(
       check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
       check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings])))  # type: ignore
 
+  all_args_info = AllArgsInfo(global_in_avals, in_shardings,
+                              closed_jaxpr.jaxpr.debug_info)
+
   (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
    kept_var_idx, name_stack) = _dce_jaxpr(
       closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
@@ -2004,7 +2015,8 @@ def lower_sharding_computation(
       pmap_nreps=nreps,
       jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
       shape_poly_state=shape_poly_state,
-      all_default_mem_kind=all_default_mem_kind)
+      all_default_mem_kind=all_default_mem_kind,
+      all_args_info=all_args_info)
 
 
 def _to_logical_sharding(
@@ -2090,6 +2102,8 @@ def lower_mesh_computation(
     out_jaxpr_avals = fun_or_jaxpr.out_avals
     consts = fun_or_jaxpr.consts
 
+  all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)
+
   assert len(out_shardings) == len(out_jaxpr_avals)
   if spmd_lowering:
     global_out_avals = out_jaxpr_avals
@@ -2179,7 +2193,8 @@ def lower_mesh_computation(
       in_layouts=(None,) * len(global_in_avals),
       out_layouts=(None,) * len(global_out_avals),
       jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
-      shape_poly_state=lowering_result.shape_poly_state)
+      shape_poly_state=lowering_result.shape_poly_state,
+      all_args_info=all_args_info)
 
 class MeshComputation(stages.XlaLowering):
   _hlo: ir.Module | None
@@ -2568,6 +2583,7 @@ class UnloadedMeshExecutable:
   jaxpr_debug_info: core.JaxprDebugInfo | None
   in_layouts: Sequence[SpecifiedLayout | None]
   out_layouts: Sequence[SpecifiedLayout | None]
+  all_args_info: AllArgsInfo | None
 
   def build_unsafe_call(self):
     input_indices = _get_input_indices(self.input_avals, self.input_shardings,
@@ -2590,7 +2606,7 @@ def load(self) -> MeshExecutable:
                           self.input_shardings, self.output_shardings,
                           self.auto_spmd_lowering, self.kept_var_idx,
                           self.in_layouts, self.out_layouts,
-                          self.jaxpr_debug_info, self)
+                          self.jaxpr_debug_info, self.all_args_info, self)
 
   # May return a MeshExecutable in the compile_replicated case.
   @staticmethod
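Note on the hunks above: `AllArgsInfo` is recorded before `_dce_jaxpr` runs, because after DCE the lowered computation only retains avals and shardings for the arguments listed in `kept_var_idx`. A minimal standalone sketch of that motivation, using hypothetical stand-ins (`AllArgsInfoSketch`, `dce`) rather than the real JAX internals:

```python
from typing import Any, NamedTuple, Sequence

class AllArgsInfoSketch(NamedTuple):
  """Hypothetical stand-in for pxla.AllArgsInfo: pre-DCE argument metadata."""
  in_avals: Sequence[Any]
  in_shardings: Any
  debug_info: Any

def dce(avals, used_idx):
  """Stand-in for _dce_jaxpr: keep only the arguments that are actually used."""
  kept_var_idx = set(used_idx)
  return [a for i, a in enumerate(avals) if i in kept_var_idx], kept_var_idx

global_in_avals = ["f32[8,2]", "f32[8,2]"]          # two arguments at trace time
all_args_info = AllArgsInfoSketch(global_in_avals, ["P('x','y')", "P('x')"], None)

kept_avals, kept_var_idx = dce(global_in_avals, used_idx={0})  # arg 1 is unused

# Post-DCE metadata covers only argument 0; the record captured before DCE is
# what lets the compiled executable still validate a caller's second argument.
assert kept_avals == ["f32[8,2]"]
assert len(all_args_info.in_avals) == 2
```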
@@ -2618,6 +2634,7 @@ def from_hlo(name: str,
                jaxpr_debug_info: core.JaxprDebugInfo | None = None,
                shape_poly_state: mlir.ShapePolyLoweringState | None = None,
                all_default_mem_kind: bool = True,
+               all_args_info: AllArgsInfo | None = None,
                compiler_options=None,
   ) -> MeshExecutable:
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
@@ -2710,7 +2727,8 @@ def from_hlo(name: str,
           auto_spmd_lowering=auto_spmd_lowering,
           jaxpr_debug_info=jaxpr_debug_info,
           in_layouts=in_layouts,  # type: ignore
-          out_layouts=out_layouts).load()  # type: ignore
+          out_layouts=out_layouts,  # type: ignore
+          all_args_info=all_args_info).load()  # type: ignore
 
 
 class MeshExecutableFastpathData(NamedTuple):
@@ -2735,12 +2753,14 @@ class MeshExecutable(stages.XlaExecutable):
   __slots__ = [
       "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
       "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
-      "_in_layouts", "_out_layouts", "_jaxpr_debug_info", "_unloaded_executable",
+      "_in_layouts", "_out_layouts", "_jaxpr_debug_info",
+      "_all_args_info", "_unloaded_executable",
   ]
 
   def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
                out_shardings, auto_spmd_lowering, kept_var_idx,
                in_layouts, out_layouts, jaxpr_debug_info=None,
+               all_args_info: AllArgsInfo | None = None,
                unloaded_executable=None):
     self.xla_executable = xla_executable
     self.build_unsafe_call = build_unsafe_call
@@ -2755,13 +2775,14 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
     self._in_layouts = in_layouts
     self._out_layouts = out_layouts
     self._jaxpr_debug_info = jaxpr_debug_info
+    self._all_args_info = all_args_info
     self._unloaded_executable = unloaded_executable
 
   @property
   def unsafe_call(self) -> Callable[..., Any]:
     if self._unsafe_call is None:
       self._unsafe_call = self.build_unsafe_call()
-    return self._unsafe_call
+    return self._unsafe_call  # type: ignore
 
   # -- stages.XlaExecutable overrides
 
@@ -2769,13 +2790,23 @@ def xla_extension_executable(self):
     return self.xla_executable
 
   def call(self, *args):
-    kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
+    if self._all_args_info is None:
+      kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
+      ref_avals = self.in_avals
+      in_shardings = self._in_shardings
+      debug_info = self._jaxpr_debug_info
+    else:
+      kept_args = args
+      ref_avals = self._all_args_info.in_avals
+      iter_in_shardings = iter(self._in_shardings)
+      in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
+                      for i, s in enumerate(self._all_args_info.in_shardings)]
+      debug_info = self._all_args_info.debug_info
+
     arg_avals = map(xla.abstractify, kept_args)
-    ref_avals = self.in_avals
-    check_arg_avals_for_call(ref_avals, arg_avals, self._jaxpr_debug_info)
+    check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
     # Check the GDA sharding and the input sharding.
-    check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
-                                          self._jaxpr_debug_info)
+    check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable
 
   def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
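For reference, the new branch in `MeshExecutable.call` rebuilds a full-length sharding list by walking the compiled (post-DCE) shardings in `kept_var_idx` order and keeping the pre-DCE entry for dropped arguments. A standalone sketch of that merge, with placeholder strings standing in for real sharding objects (illustrative values only):

```python
kept_var_idx = {0, 2}                                # arguments kept after DCE
compiled_in_shardings = ["P('x','y')", "P('x')"]     # shardings for kept args only
pre_dce_in_shardings = ["P('x','y')", "UNSPECIFIED", "P('x')"]  # recorded pre-DCE

iter_in_shardings = iter(compiled_in_shardings)
in_shardings = [next(iter_in_shardings) if i in kept_var_idx else s
                for i, s in enumerate(pre_dce_in_shardings)]

# Kept positions take the compiled shardings; DCE'd positions fall back to the
# pre-DCE record.
assert in_shardings == ["P('x','y')", "UNSPECIFIED", "P('x')"]
```

Unspecified or auto entries for dropped arguments are then skipped by the `is_unspecified_or_auto` early-continue added to `check_gda_or_array_xla_sharding_match` below.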
@@ -2922,7 +2953,8 @@ def _compile_replicated_mesh_executable_from_hlo(
   return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
                         in_shardings, out_shardings, auto_spmd_lowering,
                         kept_var_idx, (None,) * len(global_in_avals),
-                        (None,) * len(global_out_avals), jaxpr_debug_info, None)
+                        (None,) * len(global_out_avals), jaxpr_debug_info,
+                        None, None)
 
 
 @lru_cache
@@ -2956,6 +2988,8 @@ def check_gda_or_array_xla_sharding_match(
   for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
     if not isinstance(arg, ArrayImpl):
       continue
+    if is_unspecified_or_auto(xs):
+      continue
 
     db_xs = check_device_backend_on_shardings([xs])
     if not db_xs:
diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py
index 1c7ee4ab5479..a500a9a43ea7 100644
--- a/jax/experimental/export/export.py
+++ b/jax/experimental/export/export.py
@@ -764,7 +764,7 @@ def _check_lowering(lowering) -> None:
       "tuple_args", "ordered_effects", "unordered_effects",
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
       "device_assignment", "jaxpr_debug_info", "shape_poly_state",
-      "all_default_mem_kind", "in_layouts", "out_layouts"]
+      "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 40d0bbd2a8be..602eb6a5f784 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4018,6 +4018,52 @@ def test_pjit_with_deleted_input_at_subsequent_call(self, committed):
     x.delete()
     _ = f(x)
 
+  def test_aot_error_on_dced_avals_mismatch(self):
+    x, y1, y2 = jnp.ones(4), jnp.ones(4), jnp.ones(1)
+
+    @jax.jit
+    def f(x, y):
+      return x + 1 if y.shape[0] > 2 else x + 2
+
+    f_out1 = f(x, y1)
+    f(x, y2)
+
+    g = f.lower(x, y1).compile()
+    g_out1 = g(x, y1)
+    self.assertArraysEqual(f_out1, g_out1)
+
+    with self.assertRaisesRegex(
+        TypeError,
+        'Argument types differ from the types for which this computation was'
+        ' compiled'):
+      g(x, y2)
+
+  def test_aot_error_on_dced_shardings_mismatch(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    shape = (8, 2)
+    np_inp = np.arange(math.prod(shape)).reshape(shape)
+
+    x = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
+    y1 = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
+    y2 = jax.device_put(np_inp, NamedSharding(mesh, P('y')))
+
+    @jax.jit
+    def f(x, y):
+      return x + 1
+
+    f_out1 = f(x, y1)
+    f(x, y2)
+
+    g = f.lower(x, y1).compile()
+    g_out1 = g(x, y1)
+    self.assertArraysEqual(f_out1, g_out1)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        r"Compiled object called with input sharding.*does not match the "
+        r"sharding.*the computation was compiled with"):
+      g(x, y2)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class UtilTest(jtu.JaxTestCase):
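Roughly the user-visible effect the new pjit tests pin down, shown as a hedged user-level example (it mirrors `test_aot_error_on_dced_avals_mismatch` above; requires a working JAX install):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
  # The branch is resolved at trace time from y's static shape, so y is dead in
  # the jaxpr and gets dropped by DCE.
  return x + 1 if y.shape[0] > 2 else x + 2

x, y1, y2 = jnp.ones(4), jnp.ones(4), jnp.ones(1)

g = f.lower(x, y1).compile()
g(x, y1)    # OK: matches the pre-DCE avals recorded in AllArgsInfo
try:
  g(x, y2)  # Previously accepted silently; with this change it raises.
except TypeError as err:
  print(err)  # "Argument types differ from the types for which this computation was compiled"
```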