diff --git a/jax/_src/api.py b/jax/_src/api.py
index 8a5558c77e64..363641946a60 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -669,10 +669,19 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
   # all the other arguments stored as attributes.
 
   def arg_spec(x):
+    from jax.experimental.sharding import PmapSharding
     # like xla.arg_spec but duck-types on x.shape and x.dtype
     aval = None if jax.config.jax_dynamic_shapes else shaped_abstractify(x)
-    device = getattr(x, '_device', None)
-    return aval, device
+    if jax.config.jax_array:
+      if hasattr(x, 'sharding'):
+        if isinstance(x.sharding, PmapSharding):
+          return aval, None
+        return aval, (x.sharding if x._committed else None)
+      else:
+        return aval, None
+    else:
+      device = getattr(x, '_device', None)
+      return aval, device
 
   @api_boundary
   def lower(*args, **kwargs) -> stages.Lowered:
@@ -699,11 +708,18 @@ def lower(*args, **kwargs) -> stages.Lowered:
     if abstracted_axes:
       raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
     in_avals, _ = unzip2(arg_specs_and_devices)
-    computation = dispatch.lower_xla_callable(
-        flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
-        keep_unused, *arg_specs_and_devices)
-    return stages.Lowered.from_flat_info(
-        computation, in_tree, in_avals, donate_argnums, out_tree())
+    if jax.config.jax_array:
+      computation = dispatch.sharded_lowering(
+          flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
+          keep_unused, *arg_specs_and_devices)
+      return stages.Lowered.from_flat_info(
+          computation, in_tree, in_avals, donate_argnums, out_tree())
+    else:
+      computation = dispatch.lower_xla_callable(
+          flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
+          keep_unused, *arg_specs_and_devices)
+      return stages.Lowered.from_flat_info(
+          computation, in_tree, in_avals, donate_argnums, out_tree())
 
   return lower
 
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index e6244a3b2311..4bcbebd66539 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -309,8 +309,8 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):
   return committed, da, in_shardings
 
 
-def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
-                     *arg_specs):
+def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
+                     keep_unused, *arg_specs):
   # TODO(yashkatariya): Remove the local imports from here when the functions
   # in pxla.py move to dispatch.py or a utils file.
   from jax.interpreters import pxla
@@ -353,15 +353,16 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
       fun, 'jit', name, in_shardings, pjit._UNSPECIFIED, donated_invars,
       in_avals, in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
-      committed=committed, inp_device_assignment=inp_device_assignment).compile(
-          _allow_propagation_to_outputs=True).unsafe_call
+      committed=committed, always_lower=always_lower,
+      inp_device_assignment=inp_device_assignment)
 
 
 def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
                            donated_invars, keep_unused, *arg_specs):
   if config.jax_array:
-    return sharded_lowering(fun, device, backend, name,
-                            donated_invars, keep_unused, *arg_specs)
+    computation = sharded_lowering(fun, device, backend, name, donated_invars,
+                                   False, keep_unused, *arg_specs)
+    return computation.compile(_allow_propagation_to_outputs=True).unsafe_call
   else:
     return lower_xla_callable(fun, device, backend, name, donated_invars,
                               False, keep_unused, *arg_specs).compile().unsafe_call
 
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index c29ff5c71743..4021a5492e43 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -632,7 +632,10 @@ def _lower_native_and_run(fun_jax: Callable,
     logging.vlog(2, f"XlaCallModule {mhlo_module_text}")
 
   # Figure out the result types and shapes
-  out_avals = lowered.compile_args["out_avals"]
+  if config.jax_array:
+    out_avals = lowered.compile_args["global_out_avals"]
+  else:
+    out_avals = lowered.compile_args["out_avals"]
   # TODO(necula): handle d being InDBIdx
   out_shapes = tuple(
       tuple(d if type(d) is int else None
diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index e180ca98ddc5..34a7a6a3605a 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -1031,7 +1031,7 @@ def _pjit_lower_cached(
   return pxla.lower_sharding_computation(
       fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
       jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True,
-      committed=True)
+      committed=True, always_lower=False)
 
 
 def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 9939d6c15070..6176d860f856 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -2672,6 +2672,7 @@ def lower_sharding_computation(
     in_is_global: Sequence[bool],
     keep_unused: bool,
     committed: bool,
+    always_lower: bool,
     inp_device_assignment: Optional[Sequence[xc.Device]] = None):
   """Lowers a computation to XLA. It can take arbitrary shardings as input.
 
@@ -2742,7 +2743,7 @@ def lower_sharding_computation(
   # Computations that only produce constants and/or only rearrange their inputs,
   # which are often produced from partial evaluation, don't need compilation,
   # and don't need to evaluate their arguments.
-  if (not (jaxpr.effects or has_outfeed) and
+  if (not always_lower and not (jaxpr.effects or has_outfeed) and
       (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
       all(_is_unspecified(o) for o in out_shardings) and  # type: ignore
       not hasattr(backend, "compile_replicated")):  # this means 'not pathways'
@@ -3135,19 +3136,21 @@ def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
 
 
 class MeshExecutable(stages.XlaExecutable):
-  __slots__ = ['xla_executable', 'unsafe_call', '_input_avals',
-               '_in_shardings', '_out_shardings', '_auto_spmd_lowering']
+  __slots__ = ['xla_executable', 'unsafe_call', 'in_avals',
+               '_in_shardings', '_out_shardings', '_auto_spmd_lowering',
+               '_kept_var_idx']
 
-  def __init__(self, xla_executable, unsafe_call, input_avals,
-               in_shardings, out_shardings, auto_spmd_lowering):
+  def __init__(self, xla_executable, unsafe_call, in_avals,
+               in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx):
     self.xla_executable = xla_executable
     self.unsafe_call = unsafe_call
-    # input_avals is a list of global and local avals. Aval is global if input
-    # is a GDA else local.
-    self._input_avals = input_avals
+    # in_avals is a list of global and local avals. Aval is global if input
+    # is a GDA or jax.Array else local.
+    self.in_avals = in_avals
     self._in_shardings = in_shardings
     self._out_shardings = out_shardings
     self._auto_spmd_lowering = auto_spmd_lowering
+    self._kept_var_idx = kept_var_idx
 
   @staticmethod
   def from_hlo(name: str,
@@ -3281,7 +3284,8 @@ def from_hlo(name: str,
                               bool(host_callbacks), kept_var_idx)
 
     return MeshExecutable(xla_executable, unsafe_call, input_avals,
-                          in_shardings, out_shardings, auto_spmd_lowering)
+                          in_shardings, out_shardings, auto_spmd_lowering,
+                          kept_var_idx)
 
   @staticmethod
   def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
@@ -3307,7 +3311,7 @@ def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
     unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
                           handle_outs, kept_var_idx)
     return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings,
-                          out_shardings, False)
+                          out_shardings, False, kept_var_idx)
 
   # -- stages.XlaExecutable overrides
 
@@ -3315,11 +3319,12 @@ def xla_extension_executable(self):
     return self.xla_executable
 
   def call(self, *args):
-    arg_avals = map(xla.abstractify, args)
-    ref_avals = self._input_avals
+    kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
+    arg_avals = map(xla.abstractify, kept_args)
+    ref_avals = self.in_avals
     dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
     # Check the GDA sharding and the input sharding.
-    _check_gda_or_array_xla_sharding_match(args, self._in_shardings)
+    _check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings)
     return self.unsafe_call(*args)
 
 
diff --git a/tests/api_test.py b/tests/api_test.py
index 9413e01a1411..74e3402670d9 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -55,6 +55,7 @@
 from jax.interpreters import partial_eval as pe
 from jax.interpreters.pxla import PartitionSpec as P
 from jax.experimental import array, sharding
+from jax.experimental import pjit
 from jax._src import config as jax_config
 from jax._src import custom_derivatives
 from jax._src import device_array
@@ -973,6 +974,15 @@ def f(x): return x
     out = self.jit(f).lower(1.).compile()(4.)
     self.assertAllClose(out, 4.)
 
+  def test_jit_lower_compile_sharding_computation(self):
+    if not config.jax_array:
+      self.skipTest('with_sharding_constraint only works with the Array path '
+                    'in jit.')
+    s = sharding.SingleDeviceSharding(jax.devices()[0])
+    def f(x): return pjit.with_sharding_constraint(x, s)
+    out = self.jit(f).lower(1.).compile()(4.)
+    self.assertAllClose(out, 4.)
+
   def test_jit_lower_compile_trivial_in_tree_mismatch(self):
     def f(x): return x
     f_exe = self.jit(f).lower(1.).compile()
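
A minimal usage sketch of what this change enables, mirroring the new test rather than any public API; it is not part of the patch and assumes `jax.config.jax_array` is enabled and at least one local device is available:

    import jax
    from jax.experimental import pjit, sharding

    # The new lowering path is gated on this flag.
    jax.config.update('jax_array', True)

    # Commit the output of f to a concrete single-device sharding.
    s = sharding.SingleDeviceSharding(jax.devices()[0])

    def f(x):
      return pjit.with_sharding_constraint(x, s)

    # jit(...).lower(...) now goes through dispatch.sharded_lowering /
    # pxla.lower_sharding_computation (with always_lower=True), so a function
    # containing with_sharding_constraint can be lowered and compiled ahead of time.
    compiled = jax.jit(f).lower(1.).compile()
    print(compiled(4.))  # -> 4.0

Before this change, `jit(f).lower()` always went through `dispatch.lower_xla_callable`, and trivial computations could be constant-folded away at lowering time; `always_lower=True` forces a real lowering so that `compile()` returns a `MeshExecutable`.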