From 14f1a345a147da244db46f2165709736eed27ad0 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 8 Sep 2022 03:59:30 -0700
Subject: [PATCH] roll back breakage

PiperOrigin-RevId: 472949225
---
 jax/_src/ad_checkpoint.py           |   2 -
 jax/_src/custom_derivatives.py      |   5 --
 jax/_src/dispatch.py                |  36 ++++----
 jax/_src/lax/control_flow/common.py |   3 -
 jax/_src/lax/lax.py                 |  16 ++--
 jax/_src/prng.py                    |   4 +-
 jax/core.py                         |   2 -
 jax/experimental/array.py           |   6 +-
 jax/interpreters/partial_eval.py    |  22 -----
 jax/interpreters/pxla.py            | 134 ++++------------------------
 tests/api_test.py                   |   5 +-
 tests/array_test.py                 |   5 +-
 tests/debug_nans_test.py            |   4 +-
 tests/multi_device_test.py          |   7 +-
 14 files changed, 53 insertions(+), 198 deletions(-)

diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index 54cf7246b59c..9ce4b8a2249b 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -439,8 +439,6 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
 ad.primitive_jvps[remat_p] = remat_jvp

 remat_allowed_effects: Set[core.Effect] = set()
-remat_allowed_effects.add(lax.lax.InOutFeedEffect.Infeed)
-remat_allowed_effects.add(lax.lax.InOutFeedEffect.Outfeed)

 def remat_partial_eval(trace, *tracers, jaxpr, **params):
   assert not jaxpr.constvars
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 59d84c23698e..1bebc4675fbe 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -26,7 +26,6 @@
                            register_pytree_node_class, tree_leaves)
 from jax._src import custom_api_util
 from jax._src import dtypes
-from jax._src.lax import lax
 from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
 from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
 from jax.core import raise_to_shaped
@@ -340,10 +339,6 @@ def _apply_todos(todos, outs):


 allowed_effects: Set[core.Effect] = set()
-allowed_effects.add(lax.InOutFeedEffect.Infeed)
-allowed_effects.add(lax.InOutFeedEffect.Outfeed)
-
-
 custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

 def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts):
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 972fcbf3069b..1ceb03dd4027 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -316,11 +316,6 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
         "programming model, please read "
         "https://jax.readthedocs.io/en/latest/multi_process.html.")

-  if not in_shardings:
-    inp_device_assignment = da
-  else:
-    inp_device_assignment = None
-
   # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
   # the number of output avals at this stage. lower_sharding_computation will
   # apply it to all out_avals.
@@ -328,13 +323,15 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
       fun, 'jit', name, in_shardings, pjit._UNSPECIFIED, donated_invars,
       in_avals, in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
-      committed=committed, inp_device_assignment=inp_device_assignment).compile(
+      committed=committed).compile(
           _allow_propagation_to_outputs=True).unsafe_call


 def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
                            donated_invars, keep_unused, *arg_specs):
-  if config.jax_array:
+  # TODO(yashkatariya): Remove the `and arg_specs` from here once
+  # lower_sharding_computation supports no avals as input.
+  if config.jax_array and arg_specs:
     return sharded_lowering(fun, device, backend, name, donated_invars,
                             keep_unused, *arg_specs)
   else:
@@ -344,13 +341,6 @@ def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
 xla_callable = lu.cache(_xla_callable_uncached)


-def is_single_device_sharding(sharding) -> bool:
-  from jax.experimental.sharding import PmapSharding
-  # Special case PmapSharding here because PmapSharding maps away an axis
-  # and needs to be handled separately.
-  return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
-
-
 @contextlib.contextmanager
 def log_elapsed_time(fmt: str):
   if _on_exit:
@@ -541,11 +531,19 @@ def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:

 def _prune_unused_inputs(
     jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
-  used_outputs = [True] * len(jaxpr.outvars)
-  new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
-  kept_const_idx = {i for i, b in enumerate(used_consts) if b}
-  kept_var_idx = {i for i, b in enumerate(used_inputs) if b}
-  return new_jaxpr, kept_const_idx, kept_var_idx
+  used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
+  # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
+  # applications that do not produce used outputs. Must handle side-effecting
+  # primitives and nested jaxpr.
+  used.update(
+      v for eqn in jaxpr.eqns for v in eqn.invars if isinstance(v, core.Var))
+  kept_const_idx, new_constvars = util.unzip2(
+      (i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
+  kept_var_idx, new_invars = util.unzip2(
+      (i, v) for i, v in enumerate(jaxpr.invars) if v in used)
+  new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns,
+                         jaxpr.effects)
+  return new_jaxpr, set(kept_const_idx), set(kept_var_idx)


 # We can optionally set a Jaxpr rewriter that can be applied just before
diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py
index 05b6d705b9d9..2a482bca94b7 100644
--- a/jax/_src/lax/control_flow/common.py
+++ b/jax/_src/lax/control_flow/common.py
@@ -21,7 +21,6 @@
 from jax import linear_util as lu
 from jax.api_util import flatten_fun_nokwargs
 from jax.interpreters import partial_eval as pe
-from jax._src.lax import lax
 from jax._src import ad_util
 from jax._src import util
 from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
@@ -30,8 +29,6 @@
 map, unsafe_map = safe_map, map

 allowed_effects: Set[core.Effect] = set()
-allowed_effects.add(lax.InOutFeedEffect.Infeed)
-allowed_effects.add(lax.InOutFeedEffect.Outfeed)


 def _abstractify(x):
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 6e9def67758a..c965a3c37214 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -30,7 +30,6 @@
 from jax._src import api
 from jax._src import api_util
 from jax._src import device_array
-from jax._src import dispatch
 from jax import linear_util as lu
 from jax._src import dtypes
 from jax import tree_util
@@ -1327,7 +1326,7 @@ def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
   # (so it works in staged-out code as well as 'eager' code). Related to
   # equi-sharding.
   if (config.jax_array and hasattr(x, 'sharding') and
-      not dispatch.is_single_device_sharding(x.sharding)):
+      not isinstance(x.sharding, sharding.SingleDeviceSharding)):
     return array.make_array_from_callback(
         fill_shape, x.sharding, lambda idx: val[idx])  # type: ignore[arg-type]
   return val
@@ -4109,9 +4108,6 @@ def _after_all_lowering(ctx, *operands):
 mlir.register_lowering(after_all_p, _after_all_lowering)


-InOutFeedEffect = enum.Enum('InOutFeedEffect', ['Infeed', 'Outfeed'])
-
-
 def infeed(token, shape=None, partitions=None):
   """Consumes an infeed value of `shape` from the host. Experimental.
@@ -4137,14 +4133,13 @@ def _infeed_abstract_eval(token, *, shapes, partitions):
   if token is not abstract_token:
     raise TypeError("First argument to infeed must be a token")
-  return (*shapes, abstract_token), {InOutFeedEffect.Infeed}
+  return shapes + (abstract_token,)


 infeed_p = Primitive("infeed")
 infeed_p.multiple_results = True
 infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
-infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
-mlir.lowerable_effects.add(InOutFeedEffect.Infeed)
+infeed_p.def_abstract_eval(_infeed_abstract_eval)


 def _infeed_lowering(ctx, token, *, shapes, partitions):
@@ -4190,12 +4185,11 @@ def outfeed(token, xs, partitions = None):
 def _outfeed_abstract_eval(token, *xs, partitions):
   if token is not abstract_token:
     raise TypeError("First argument to outfeed must be a token")
-  return abstract_token, {InOutFeedEffect.Outfeed}
+  return abstract_token


 outfeed_p = Primitive("outfeed")
 outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
-outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
-mlir.lowerable_effects.add(InOutFeedEffect.Outfeed)
+outfeed_p.def_abstract_eval(_outfeed_abstract_eval)


 def _outfeed_lowering(ctx, token, *xs, partitions):
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 72f69425c981..99a543292a45 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -31,7 +31,7 @@
 from jax.interpreters import pxla
 from jax.interpreters import xla
 from jax.experimental.sharding import (
-    MeshPspecSharding, PmapSharding, OpShardingSharding)
+    MeshPspecSharding, SingleDeviceSharding, PmapSharding, OpShardingSharding)

 from jax._src import dispatch
 from jax._src import dtypes
@@ -364,7 +364,7 @@ def global_sharded_result_handler(aval, out_sharding, committed,
     phys_handler_maker = pxla.global_result_handlers[
         (core.ShapedArray, output_type)]

-    if dispatch.is_single_device_sharding(out_sharding):
+    if isinstance(out_sharding, SingleDeviceSharding):
      phys_sharding = out_sharding
    elif isinstance(out_sharding, MeshPspecSharding):
      trailing_spec = [None] * len(key_shape)
diff --git a/jax/core.py b/jax/core.py
index 2741fbb9dd4f..15a2567c0044 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -208,7 +208,6 @@ def __repr__(self):
   def replace(self, *args, **kwargs):
     return self._replace(*args, **kwargs)

-# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
 def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
   source_info = source_info or source_info_util.new_source_info()
   if config.jax_enable_checks:
@@ -2163,7 +2162,6 @@ def _unmap_dshaped_array(
     DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
     ShapedArray: (_map_shaped_array, _unmap_shaped_array),
     ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
-    AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
 }

 @contextmanager
diff --git a/jax/experimental/array.py b/jax/experimental/array.py
index bf3ca7f266d4..f421f9f71d4c 100644
--- a/jax/experimental/array.py
+++ b/jax/experimental/array.py
@@ -451,7 +451,7 @@ def _device_put_array(x, device: Optional[Device]):
   # TODO(yashkatariya): Remove this restriction and the round trip via host
   # once lowering to XLA goes through `lower_mesh_computation`.
   assert x.is_fully_addressable()
-  if dispatch.is_single_device_sharding(x.sharding):
+  if isinstance(x.sharding, SingleDeviceSharding):
     x = dispatch._copy_device_array_to_device(pxla._set_aval(x._arrays[0]), device)
     return (x,)
   else:
@@ -462,7 +462,7 @@ def _device_put_array(x, device: Optional[Device]):


 def _array_pmap_shard_arg(x, devices, indices, mode):
-  if dispatch.is_single_device_sharding(x.sharding):
+  if isinstance(x.sharding, SingleDeviceSharding):
     return pxla._shard_device_array(x, devices, indices, mode)

   if x._fast_path_args is None:
@@ -484,7 +484,7 @@ def _array_shard_arg(x, devices, indices, mode):
   if mode == pxla.InputsHandlerMode.pmap:
     return _array_pmap_shard_arg(x, devices, indices, mode)
   else:
-    if dispatch.is_single_device_sharding(x.sharding):
+    if isinstance(x.sharding, SingleDeviceSharding):
       return [buf if buf.device() == d else buf.copy_to_device(d)
               for buf, d in safe_zip(x._arrays, devices)]
     # If PmapSharding exists, then do a round trip via host. This will happen
diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py
index cf4cd8a8bbed..63d75b0f7c83 100644
--- a/jax/interpreters/partial_eval.py
+++ b/jax/interpreters/partial_eval.py
@@ -961,17 +961,6 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
   config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
   return lifted_jaxpr

-@weakref_lru_cache
-def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
-  """Move n invars to constvars. Like an inverse of convert_constvars_Jaxpr."""
-  config.jax_enable_checks and core.check_jaxpr(jaxpr)
-  constvars, invars = split_list(jaxpr.invars, [n])
-  lifted_jaxpr = Jaxpr(constvars=tuple(constvars), invars=invars,
-                       outvars=jaxpr.outvars, eqns=jaxpr.eqns,
-                       effects=jaxpr.effects)
-  config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
-  return lifted_jaxpr
-
 def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
   config.jax_enable_checks and core.check_jaxpr(jaxpr)
   env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
@@ -1317,17 +1306,6 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
     instantiate = (instantiate,) * len(jaxpr.invars)
   return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate))

-
-def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
-                     instantiate: Union[bool, Sequence[bool]] = False,
-                     ) -> Tuple[Jaxpr, List[bool], List[bool]]:
-  jaxpr_ = convert_constvars_jaxpr(jaxpr)
-  new_jaxpr_, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs)
-  used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)])
-  new_jaxpr = convert_invars_to_constvars(new_jaxpr_, sum(used_consts))
-  return new_jaxpr, used_consts, used_inputs
-
-
 @weakref_lru_cache
 def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
                instantiate: Tuple[bool, ...]
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 1ee11b19fdf8..177fae377137 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -2081,13 +2081,10 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval):
 def _pmap_dce_rule(used_outputs, eqn):
   # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
   new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
-  _, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
-  # TODO(yashkatariya,mattjj): Handle global_arg_shapes here too.
   _, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
   _, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
-  new_params = dict(eqn.params, call_jaxpr=new_jaxpr,
-                    donated_invars=tuple(donated_invars),
-                    in_axes=tuple(in_axes), out_axes=tuple(out_axes))
+  new_params = dict(eqn.params, call_jaxpr=new_jaxpr, in_axes=tuple(in_axes),
+                    out_axes=tuple(out_axes))
   if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
     return used_inputs, None
   else:
@@ -2675,8 +2672,7 @@ def lower_sharding_computation(
     global_in_avals: Sequence[core.ShapedArray],
     in_is_global: Sequence[bool],
     keep_unused: bool,
-    committed: bool,
-    inp_device_assignment: Optional[Sequence[xc.Device]] = None):
+    committed: bool):
   """Lowers a computation to XLA. It can take arbitrary shardings as input.

   The caller of this code can pass in a singleton _UNSPECIFIED because the
@@ -2686,23 +2682,14 @@ def lower_sharding_computation(
   """
   # Device assignment across all inputs and outputs should be the same. This
   # is checked in pjit.
-  if inp_device_assignment is not None:
-    from jax.experimental.sharding import SingleDeviceSharding
-    assert not in_shardings, "if device_assignment given, no in_shardings"
-    # TODO(yashkatariya): Look into allowing more than 1 device here.
-    assert len(inp_device_assignment) == 1
-    device_assignment = inp_device_assignment
-    backend = xb.get_device_backend(device_assignment[0])
-    first_sharding = SingleDeviceSharding(device_assignment[0])
+  if _is_unspecified(out_shardings):
+    backend, first_sharding = _get_backend_from_shardings(in_shardings)
   else:
-    if _is_unspecified(out_shardings):
-      backend, first_sharding = _get_backend_from_shardings(in_shardings)  # type: ignore
-    else:
-      # type ignore because mypy can't understand that out_shardings that are
-      # UNSPECIFIED singleton are filtered above.
-      backend, first_sharding = _get_backend_from_shardings(  # type: ignore
-          it.chain(in_shardings, out_shardings))  # type: ignore
-    device_assignment = first_sharding._device_assignment
+    # type ignore because mypy can't understand that out_shardings that are
+    # UNSPECIFIED singleton are filtered above.
+    backend, first_sharding = _get_backend_from_shardings(
+        it.chain(in_shardings, out_shardings))  # type: ignore
+  device_assignment = first_sharding._device_assignment

   name_stack = new_name_stack(wrap_name(fun_name, api_name))

@@ -2711,7 +2698,6 @@ def lower_sharding_computation(
                          "in {elapsed_time} sec"):
     jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
         fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))
-  kept_outputs = [True] * len(global_out_avals)

   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
   logging.log(log_priority,
@@ -2737,29 +2723,10 @@ def lower_sharding_computation(
     donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
   del kept_const_idx

-  process_index = xb.process_index()
-  local_device_assignment = [d for d in device_assignment
-                             if d.process_index == process_index]
-  if len(device_assignment) != len(local_device_assignment):
+  if not first_sharding.is_fully_addressable():
     check_multihost_collective_allowlist(jaxpr)
-
-  has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
   jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

-  # Computations that only produce constants and/or only rearrange their inputs,
-  # which are often produced from partial evaluation, don't need compilation,
-  # and don't need to evaluate their arguments.
-  if (not (jaxpr.effects or has_outfeed) and
-      (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
-      all(_is_unspecified(o) for o in out_shardings) and  # type: ignore
-      not hasattr(backend, "compile_replicated")):  # this means 'not pathways'
-    return MeshComputation(
-        str(name_stack), None, True, donated_invars, jaxpr=jaxpr, consts=consts,
-        global_in_avals=global_in_avals, global_out_avals=global_out_avals,
-        in_shardings=in_shardings,
-        device_assignment=device_assignment, committed=committed,
-        kept_var_idx=kept_var_idx, keepalive=None)
-
   # Look at the number of replcas present in the jaxpr. In
   # lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
   # handled here so as to deprecate the lower_xla_callable codepath when
@@ -2842,7 +2809,6 @@
   return MeshComputation(
       str(name_stack),
       module,
-      False,
       donated_invars,
       mesh=None,
       global_in_avals=global_in_avals,
@@ -3005,7 +2971,6 @@ def lower_mesh_computation(
   return MeshComputation(
       str(name_stack),
       module,
-      False,
       donated_invars,
       mesh=mesh,
       global_in_avals=global_in_avals,
@@ -3031,10 +2996,9 @@ class MeshComputation(stages.XlaLowering):
   _executable: Optional[MeshExecutable]

   def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
-               is_trivial: bool, donated_invars: Sequence[bool], **compile_args):
+               donated_invars: Sequence[bool], **compile_args):
     self._name = name
     self._hlo = hlo
-    self.is_trivial = is_trivial
     self._donated_invars = donated_invars
     self.compile_args = compile_args
     self._executable = None
@@ -3042,8 +3006,6 @@ def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
   # -- stages.XlaLowering overrides

   def hlo(self) -> xc.XlaComputation:
-    if self.is_trivial:
-      raise ValueError("A trivial computation has no HLO")
     # this is a method for api consistency with dispatch.XlaComputation
     if isinstance(self._hlo, xc.XlaComputation):
       return self._hlo
@@ -3052,8 +3014,6 @@ def hlo(self) -> xc.XlaComputation:
         use_tuple_args=self.compile_args["tuple_args"])

   def mhlo(self) -> ir.Module:
-    if self.is_trivial:
-      raise ValueError("A trivial computation has no MHLO")
     if isinstance(self._hlo, xc.XlaComputation):
       module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
       with mlir.make_ir_context():
@@ -3064,13 +3024,10 @@ def compile(self,
               _allow_propagation_to_outputs : bool = False,
               _allow_compile_replicated : bool = True) -> MeshExecutable:
     if self._executable is None:
-      if self.is_trivial:
-        self._executable = MeshExecutable.from_trivial_jaxpr(**self.compile_args)
-      else:
-        self._executable = MeshExecutable.from_hlo(
-            self._name, self._hlo, **self.compile_args,
-            _allow_propagation_to_outputs=_allow_propagation_to_outputs,
-            _allow_compile_replicated=_allow_compile_replicated)  # type: ignore
+      self._executable = MeshExecutable.from_hlo(
+          self._name, self._hlo, **self.compile_args,
+          _allow_propagation_to_outputs=_allow_propagation_to_outputs,
+          _allow_compile_replicated=_allow_compile_replicated)  # type: ignore
     return self._executable
@@ -3290,32 +3247,6 @@ def from_hlo(name: str,
     return MeshExecutable(xla_executable, unsafe_call, input_avals,
                           in_shardings, out_shardings, auto_spmd_lowering)

-  @staticmethod
-  def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
-                         in_shardings, device_assignment,
-                         committed, kept_var_idx, keepalive) -> MeshExecutable:
-    assert keepalive is None
-    out_shardings = _out_shardings_for_trivial(
-        jaxpr, consts, in_shardings, device_assignment)
-    if config.jax_array or config.jax_parallel_functions_output_gda:
-      are_global = [True] * len(global_out_avals)
-    else:
-      are_global = [False] * len(global_out_avals)
-    _, indices, _ = _get_input_metadata(global_out_avals, out_shardings,
-                                        are_global)
-    process_index = xb.process_index()
-    local_device_assignment = [d for d in device_assignment
-                               if d.process_index == process_index]
-    handle_ins = InputsHandler(local_device_assignment, out_shardings, indices,
-                               InputsHandlerMode.pjit_or_xmap)
-    handle_outs = global_avals_to_results_handler(
-        global_out_avals, out_shardings, committed,
-        [False] * len(global_out_avals))
-    unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
-                          handle_outs, kept_var_idx)
-    return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings,
-                          out_shardings, False)
-
   # -- stages.XlaExecutable overrides

   def xla_extension_executable(self):
@@ -3330,39 +3261,6 @@ def call(self, *args):
     return self.unsafe_call(*args)


-def _out_shardings_for_trivial(
-    jaxpr: core.Jaxpr, consts: Sequence[Any],
-    in_shardings: Sequence[XLACompatibleSharding],
-    device_assignment: Sequence[xc.Device],
-  ) -> List[XLACompatibleSharding]:
-  # For each jaxpr output, compute a Sharding by:
-  #   * if the output is a forwarded input, get the corresponding in_sharding;
-  #   * if the output is a constant Array, get its .sharding attribute;
-  #   * otherwise, the output is a literal or numpy.ndarray constant, so give it
-  #     a replicated sharding
-  from jax.experimental import array
-  from jax.experimental import sharding
-  rep = sharding.OpShardingSharding(
-      device_assignment, sharding._get_replicated_op_sharding())
-  shardings: Dict[core.Var, sharding.XLACompatibleSharding] = {}
-  for constvar, constval in zip(jaxpr.constvars, consts):
-    if isinstance(constval, array.Array):
-      shardings[constvar] = constval.sharding
-  map(shardings.setdefault, jaxpr.invars, in_shardings)
-  return [rep if isinstance(x, core.Literal) else shardings.get(x, rep)
-          for x in jaxpr.outvars]
-
-
-def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args):
-  env: Dict[core.Var, Any] = {}
-  pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
-  map(env.setdefault, jaxpr.invars, pruned_args)
-  map(env.setdefault, jaxpr.constvars, consts)
-  outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
-          for v in jaxpr.outvars]
-  return out_handler(in_handler(outs))
-
-
 @lru_cache()
 def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
   from jax.experimental.sharding import MeshPspecSharding
diff --git a/tests/api_test.py b/tests/api_test.py
index 062a8c8ba05e..2a8680f62fe6 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -1070,7 +1070,10 @@ def test_jit_lower_no_prunning(self):
     jitted_f = self.jit(lambda x, y: x, keep_unused=True)
     with jtu.count_device_put() as count:
       _ = jitted_f(1, 2)
-    self.assertEqual(count[0], 1)
+    if config.jax_array:
+      self.assertEqual(count[0], 2)
+    else:
+      self.assertEqual(count[0], 1)

   @jtu.ignore_warning(category=DeprecationWarning)
   def test_jit_lower_compile_compiler_ir(self):
diff --git a/tests/array_test.py b/tests/array_test.py
index a4b99bc51e83..8936ddf8b02d 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -20,7 +20,6 @@
 import jax
 import jax.numpy as jnp
-from jax._src import dispatch
 from jax._src import config as jax_config
 from jax._src import test_util as jtu
 from jax._src.lib import xla_client as xc
@@ -190,7 +189,7 @@ def test_jnp_array(self):
     with jax_config.jax_array(True):
       arr = jnp.array([1, 2, 3])
     self.assertIsInstance(arr, array.Array)
-    self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
+    self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
     self.assertEqual(arr._committed, False)

   def test_jnp_array_jit_add(self):
@@ -272,7 +271,7 @@ def test_zeros_like(self):
     out = jnp.zeros_like(a)
     expected = np.zeros(a.shape, dtype=np.int32)
     self.assertArraysEqual(out, expected)
-    self.assertTrue(dispatch.is_single_device_sharding(out.sharding))
+    self.assertIsInstance(out.sharding, sharding.SingleDeviceSharding)

   @jax_config.jax_array(True)
   def test_wrong_num_arrays(self):
diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py
index 70d04c4e52a1..2082f3db35af 100644
--- a/tests/debug_nans_test.py
+++ b/tests/debug_nans_test.py
@@ -225,8 +225,8 @@ def f(x):
   def testDebugNansDoesntReturnDeoptimizedResult(self):
     @jax.jit
     def f(x):
-      y = x + 2  # avoid trivial dispatch path by adding some eqn
-      return jnp.nan, y
+      x + 2  # avoid trivial dispatch path by adding some eqn
+      return jnp.nan

     with self.assertRaisesRegex(FloatingPointError, "de-optimized"):
       with jax.debug_nans(True):
diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py
index 98b6a340165e..c254e7b34389 100644
--- a/tests/multi_device_test.py
+++ b/tests/multi_device_test.py
@@ -202,11 +202,8 @@ def f(): return lax.add(3., 4.)
     self.assertIsInstance(f(), jnp.DeviceArray)
     self.assert_uncommitted_to_device(f(), devices[0])
     self.assert_uncommitted_to_device(jax.jit(f)(), devices[0])
-    # Skip for jax.Array because it doesn't work with the device argument of
-    # jit as it is deprecated.
-    if not config.jax_array:
-      self.assert_committed_to_device(jax.jit(f, device=devices[1])(),
-                                      devices[1])
+    self.assert_committed_to_device(jax.jit(f, device=devices[1])(),
+                                    devices[1])

   def test_reshape(self):
     devices = self.get_devices()
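
Reviewer note (illustration only, not part of the patch): several hunks above replace calls to the deleted helper dispatch.is_single_device_sharding with a direct isinstance(..., SingleDeviceSharding) check. The two predicates are not interchangeable: the deleted helper treated any sharding over exactly one device as single-device except PmapSharding, while the isinstance form matches only SingleDeviceSharding itself. The sketch below uses hypothetical stand-in classes (the Fake* names are not JAX APIs) purely to illustrate that behavioral difference.

# Illustration only: the Fake* classes are stand-ins, not JAX APIs.
from dataclasses import dataclass, field
from typing import FrozenSet

@dataclass
class FakeSharding:
  device_set: FrozenSet[int] = field(default_factory=lambda: frozenset({0}))

class FakeSingleDeviceSharding(FakeSharding): ...
class FakeMeshSharding(FakeSharding): ...
class FakePmapSharding(FakeSharding): ...

def is_single_device_sharding(s: FakeSharding) -> bool:
  # Mirrors the helper removed from jax/_src/dispatch.py in this patch: any
  # sharding over exactly one device counts, except PmapSharding, which maps
  # away an axis and needs separate handling.
  return len(s.device_set) == 1 and not isinstance(s, FakePmapSharding)

for s in (FakeSingleDeviceSharding(), FakeMeshSharding(), FakePmapSharding()):
  helper_says = is_single_device_sharding(s)
  isinstance_says = isinstance(s, FakeSingleDeviceSharding)  # the check restored above
  print(f"{type(s).__name__}: helper={helper_says} isinstance={isinstance_says}")

A one-device FakeMeshSharding satisfies the helper but not the isinstance check, which is the semantic gap the call sites above reintroduce by reverting to the isinstance form.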