From b8098b1782a7de30f0daab9cb865dbecaf5e646b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 5 Jan 2024 14:16:32 -0800 Subject: [PATCH] Remove indices and devices from shard_arg_handlers and shard_args. This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing). If your code ends up taking the python dispatch, then something is going wrong anyways. PiperOrigin-RevId: 596081987 --- jax/_src/api.py | 4 +- jax/_src/array.py | 11 ++- jax/_src/dispatch.py | 9 +- jax/_src/interpreters/pxla.py | 156 ++++++++++++++++------------------ jax/_src/pjit.py | 11 ++- jax/_src/prng.py | 15 +--- jax/_src/sharding.py | 8 +- jax/_src/sharding_impls.py | 4 + tests/lax_test.py | 4 +- tests/pjit_test.py | 50 +---------- tests/pmap_test.py | 4 +- 11 files changed, 108 insertions(+), 168 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index ea5bfd8ea86c..a05c0f231900 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -63,6 +63,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib import pmap_lib from jax._src.sharding import Sharding from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind, @@ -1845,7 +1846,8 @@ def cache_miss(*args, **kwargs): return out, fastpath_data cpp_mapped_f = pmap_lib.pmap( - fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg, + fun, cache_miss, static_broadcasted_tuple, + pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg, # type: ignore pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) diff --git a/jax/_src/array.py b/jax/_src/array.py index dca74b2de9a1..fb82ffe92ad1 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -834,8 +834,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): if not candidates_list: # This array isn't sharded correctly. Reshard it via host roundtrip. # TODO(skye): more efficient reshard? - return pxla.shard_arg(x._value, devices, indices, sharding, - canonicalize=False) + return pxla.shard_arg(x._value, sharding, canonicalize=False) # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. 
for buf in candidates_list: @@ -848,10 +847,11 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): return pxla.batched_device_put(x.aval, sharding, bufs, devices) -def _array_shard_arg(x, devices, indices, sharding): +def _array_shard_arg(x, sharding): x._check_if_deleted() x_indices = x.sharding.addressable_devices_indices_map(x.shape).values() + indices = sharding.addressable_devices_indices_map(x.shape).values() if not x.is_fully_addressable: if tuple(x_indices) == tuple(indices): return x @@ -859,16 +859,15 @@ def _array_shard_arg(x, devices, indices, sharding): raise NotImplementedError( "Cannot reshard an input that is not fully addressable") else: + devices = pxla.get_addressable_devices_for_shard_arg(sharding) if tuple(x_indices) == tuple(indices): - return xc.copy_array_to_devices_with_sharding( - x, list(devices), sharding) + return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding) # Resharding starts here: if dispatch.is_single_device_sharding(x.sharding): return shard_device_array(x, devices, indices, sharding) else: return shard_sharded_device_array_slow_path(x, devices, indices, sharding) - pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index f39664f0e508..af3be13794ae 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -124,12 +124,8 @@ def __init__(self): def get_token_input(self, eff: core.Effect, devices: list[Device]) -> jax.Array: tok = self.current_tokens.get(eff, np.zeros(0, np.bool_)) - s = NamedSharding(pxla.Mesh(devices, axis_names=["dev"]), - PartitionSpec([])) s = jax.sharding.GSPMDSharding.get_replicated(devices) - indices = tuple( - s.addressable_devices_indices_map(tok.shape).values()) - sharded_tok = pxla.shard_args(devices, [indices], [s], [tok])[0] + sharded_tok = pxla.shard_args([s], [tok])[0] self.current_tokens[eff] = sharded_tok return sharded_tok @@ -331,8 +327,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool): result_handler = pxla.global_aval_to_result_handler(aval, s, committed, False) - map_ = s.devices_indices_map(aval.shape) # type: ignore - return result_handler(pxla.shard_arg(x, list(map_), list(map_.values()), s)) + return result_handler(pxla.shard_arg(x, s)) def _override_get_device_assignment(sharding, *args, **kwargs): da = sharding._device_assignment diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index d5dbc23ccabe..6d29dd8c641b 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -106,60 +106,46 @@ class WeakRefList(list): def identity(x): return x -def shard_arg(arg, devices, arg_indices, sharding, canonicalize=True): - """Returns a list of size len(devices) containing per-device buffers. - - For the C++ pmap path, we fallback to Python (this function) to shard - arguments that are not supported by the C++ `ShardArg`. - - Args: - arg: The Python argument. - devices: The list of devices to shard over. - arg_indices: A list of `len(devices)` indices to use to shard the argument. 
- """ +def shard_arg(arg, sharding, canonicalize=True): if canonicalize: arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding) + return shard_arg_handlers[type(arg)](arg, sharding) @profiler.annotate_function def shard_args( - devices: Sequence[xb.xla_client.Device], - indices: Sequence[Sequence[Index]], - shardings: Sequence[sharding_impls.XLACompatibleSharding], - args, + shardings: Sequence[sharding_impls.XLACompatibleSharding], args, ) -> Sequence[jax.Array]: - """Shard each argument data array along its leading axis. + return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)] - Args: - devices: sequence of Devices mapping replica index to a physical device. - indices: sequence of the same length as `args` describing how each arg - should be sharded/replicated across `devices`. Each element in `indices` - is the same length as `devices`. - args: a sequence of JaxTypes representing arguments to be sharded according - to `indices` and placed on `devices`. +shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {} - Returns: - A list of length matching args, containing lists of per-device buffers - for each argument. - """ - return [shard_arg(arg, devices, indices[i], shardings[i]) - for i, arg in enumerate(args)] -shard_arg_handlers: dict[Any, Callable[[Any, Any, Any, Any], Any]] = {} +@lru_cache(maxsize=1024) +def get_addressable_devices_for_shard_arg( + s: sharding_impls.XLACompatibleSharding) -> tuple[xc.Device, ...]: + return s._addressable_device_assignment + +@lru_cache(maxsize=1024) +def _get_replicated_slices(num_addressable_devices: int): + return ((slice(None),),) * num_addressable_devices -def _shard_token(x, devices, indices, sharding): +def _shard_token(x, sharding): + devices = get_addressable_devices_for_shard_arg(sharding) + indices = _get_replicated_slices(len(devices)) zeros = np.zeros((), dtype=np.dtype(np.bool_)) aval = api_util.shaped_abstractify(zeros) - return batched_device_put(aval, sharding, [zeros for i in indices], devices) + return batched_device_put(aval, sharding, [zeros for _ in indices], devices) shard_arg_handlers[core.Token] = _shard_token -def _masked_array_error(x, devices, indices, sharding): +def _masked_array_error(x, sharding): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. 
" "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_array(x, devices, indices, sharding): +def _shard_array(x, sharding): + indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) + devices = get_addressable_devices_for_shard_arg(sharding) if x.dtype == dtypes.float0: x = np.zeros(x.shape, dtype=np.dtype(bool)) aval = api_util.shaped_abstractify(x) @@ -167,8 +153,8 @@ def _shard_array(x, devices, indices, sharding): for _t in array_types: shard_arg_handlers[_t] = _shard_array -def _shard_darray(x, devices, indices, sharding): - return shard_arg(x._data, devices, indices, sharding) +def _shard_darray(x, sharding): + return shard_arg(x._data, sharding) shard_arg_handlers[core.DArray] = _shard_darray def batched_device_put(aval: core.ShapedArray, @@ -183,7 +169,7 @@ def batched_device_put(aval: core.ShapedArray, if len(bufs) == len(xs): return array.ArrayImpl( aval, sharding, bufs, committed=committed, _skip_checks=True) - return xc.batched_device_put(aval, sharding, xs, devices, committed) # type: ignore + return xc.batched_device_put(aval, sharding, xs, list(devices), committed) # type: ignore def shard_aval(size, axis: int, aval): try: @@ -849,8 +835,8 @@ def build_execute_fun(self): if spec.sharding_spec is not None else None) handle_outs = local_avals_to_results_handler(self.local_output_avals, self.output_shardings) - handle_args = InputsHandler(self.compiled.local_devices(), - self.input_shardings, input_indices) + handle_args = InputsHandler(self.input_shardings, + self.compiled.local_devices(), input_indices) execute_fun = ExecuteReplicated(self.compiled, "parallel computation", self.backend, handle_args, handle_outs, self.unordered_effects, @@ -1054,9 +1040,8 @@ def _get_pmap_sharding(devices, specs): class InputsHandler: __slots__ = ("handler", "local_devices", "in_shardings", "input_indices") - def __init__(self, local_devices, in_shardings, input_indices): - self.handler = partial( - shard_args, local_devices, input_indices, in_shardings) + def __init__(self, in_shardings, local_devices=None, input_indices=None): + self.handler = partial(shard_args, in_shardings) self.local_devices = local_devices self.in_shardings = in_shardings self.input_indices = input_indices @@ -2248,37 +2233,36 @@ def cost_analysis(self) -> dict[str, float]: return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module()) -@lru_cache(maxsize=1024) -def _get_replicated_slices(num_addressable_devices: int, ndim: int | None): - if ndim is None: - return ((slice(None),),) * num_addressable_devices - else: - return ((slice(None),) * ndim,) * num_addressable_devices +if xla_extension_version < 229: + def _get_input_indices( + avals: Sequence[ShapedArray], + shardings: Sequence[sharding_impls.XLACompatibleSharding], + da_object: _DeviceAssignment | Sequence[xc.Device], # type: ignore + ) -> Sequence[tuple[Index | None, ...]]: + input_indices = [] + if not isinstance(da_object, _DeviceAssignment): + da_object = _create_da_object(tuple(da_object)) + num_addressable_devices = len(da_object.addressable_device_list) -def _get_input_indices( - avals: Sequence[ShapedArray], - shardings: Sequence[sharding_impls.XLACompatibleSharding], - da_object: _DeviceAssignment | Sequence[xc.Device], # type: ignore -) -> Sequence[tuple[Index | None, ...]]: - - input_indices = [] - if not isinstance(da_object, _DeviceAssignment): - da_object = _create_da_object(tuple(da_object)) - num_addressable_devices = 
len(da_object.addressable_device_list) + def _get_replicated_slices(num_addressable_devices: int, ndim: int | None): + if ndim is None: + return ((slice(None),),) * num_addressable_devices + else: + return ((slice(None),) * ndim,) * num_addressable_devices - for aval, sharding in zip(avals, shardings): - if aval is core.abstract_token: - index = _get_replicated_slices(num_addressable_devices, None) - else: - if sharding.is_fully_replicated: - index = _get_replicated_slices(num_addressable_devices, aval.ndim) + for aval, sharding in zip(avals, shardings): + if aval is core.abstract_token: + index = _get_replicated_slices(num_addressable_devices, None) else: - index = tuple( - sharding.addressable_devices_indices_map(aval.shape).values()) # type: ignore - input_indices.append(index) + if sharding.is_fully_replicated: + index = _get_replicated_slices(num_addressable_devices, aval.ndim) + else: + index = tuple( + sharding.addressable_devices_indices_map(aval.shape).values()) # type: ignore + input_indices.append(index) - return input_indices + return input_indices def get_gspmd_shardings_from_executable( @@ -2604,10 +2588,13 @@ class UnloadedMeshExecutable: all_args_info: AllArgsInfo | None def build_unsafe_call(self): - input_indices = _get_input_indices(self.input_avals, self.input_shardings, - self.device_assignment) - handle_args = InputsHandler(self.xla_executable.local_devices(), - self.input_shardings, input_indices) + if xla_extension_version >= 229: + handle_args = InputsHandler(self.input_shardings) + else: + input_indices = _get_input_indices(self.input_avals, self.input_shardings, + self.device_assignment) + handle_args = InputsHandler( + self.input_shardings, self.xla_executable.local_devices(), input_indices) handle_outs = global_avals_to_results_handler( self.output_avals, self.output_shardings, self.committed, self.are_out_shardings_from_xla) # type: ignore # arg-type @@ -2755,6 +2742,7 @@ class MeshExecutableFastpathData(NamedTuple): out_avals: Sequence[ShapedArray] out_committed: Sequence[bool] kept_var_bitvec: Iterable[bool] + # TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24 arg_handler_devices: Sequence[xc.Device] arg_handler_indices: Sequence[tuple[Index | None, ...]] @@ -2865,13 +2853,20 @@ def aot_cache_miss(*args, **kwargs): return outs, fastpath_data if xla_extension_version >= 226: - return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, shard_arg) + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], [], + tree_util.dispatch_registry, + shard_arg if xla_extension_version >= 229 else temp_shard_arg) # type: ignore else: return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [], # type: ignore tree_util.dispatch_registry) +# TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24 +def temp_shard_arg(arg, devices, arg_indices, sharding, canonicalize=True): + return shard_arg(arg, sharding) + + def check_arg_avals_for_call(ref_avals, arg_avals, jaxpr_debug_info: core.JaxprDebugInfo | None = None): if len(ref_avals) != len(arg_avals): @@ -2926,20 +2921,15 @@ def _compile_replicated_mesh_executable_from_hlo( in_shardings = semantics_in_shardings.shardings out_shardings = semantics_out_shardings.shardings - input_indices = _get_input_indices(global_in_avals, in_shardings, da) # type: ignore - if pmap_nreps > 1: - # For a jit wrapping a pmap, replicate each input index to match the - # devices of the replicated jit computation. 
- input_indices = [index * pmap_nreps for index in input_indices] kept_var_idx = set(kept_var_idx) # Will compute out_handler with executable information. unsafe_call = backend.compile_replicated( is_trivial=False, name=name, computation=computation, compile_options=compile_options, host_callbacks=host_callbacks, has_unordered_effects=has_unordered_effects, - ordered_effects=ordered_effects, in_avals=global_in_avals, - in_indices=input_indices, in_shardings=in_shardings, - kept_var_idx=kept_var_idx, + device_assignment=da, ordered_effects=ordered_effects, + in_avals=global_in_avals, + in_shardings=in_shardings, kept_var_idx=kept_var_idx, out_avals=global_out_avals, out_shardings=out_shardings, committed=committed, pmap_nreps=pmap_nreps) xla_executable = None diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 30d9617940ee..fda1d7d26666 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -234,7 +234,8 @@ def cache_miss(*args, **kwargs): getattr(fun, "__name__", ""), fun, cache_miss, static_argnums, static_argnames, donate_argnums, tree_util.dispatch_registry, - pxla.shard_arg, _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore + pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg, # type: ignore + _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore else: cpp_pjit_f = xc._xla.pjit( # type: ignore getattr(fun, "__name__", ""), @@ -1348,9 +1349,11 @@ def call_impl_cache_miss(*args_, **kwargs_): has_explicit_sharding = _pjit_explicit_sharding( in_shardings, out_shardings, None, None) if xla_extension_version >= 226: - return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, pxla.shard_arg, - _get_cpp_global_cache(has_explicit_sharding))(*args) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], donated_argnums, + tree_util.dispatch_registry, + pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg, # type: ignore + _get_cpp_global_cache(has_explicit_sharding))(*args) else: return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore tree_util.dispatch_registry, diff --git a/jax/_src/prng.py b/jax/_src/prng.py index df76b9d5f062..8cce8953e4ea 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -636,20 +636,11 @@ def __hash__(self) -> int: xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x -def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding): - aval = x.aval - key_shape = aval.dtype._impl.key_shape +def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, sharding): arr = x._base_array - - # TODO(yashkatariya,frostig): This assumes that the last dimensions are not - # sharded. This is only true when enable_custom_prng is True. 
- trailing_inds = [slice(None)] * len(key_shape) - phys_indices = [(*inds, *trailing_inds) for inds in indices] phys_sharding = make_key_array_phys_sharding( - aval, sharding, is_sharding_from_xla=False) - return pxla.shard_arg_handlers[type(arr)]( - arr, devices, phys_indices, phys_sharding - ) + x.aval, sharding, is_sharding_from_xla=False) + return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index 7980d5f35284..7c294d66fb6e 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -30,9 +30,13 @@ @functools.lru_cache(maxsize=4096) def _addressable_devices_indices_map( sharding: Sharding, global_shape: Shape) -> Mapping[Device, Index | None]: + global_map = sharding.devices_indices_map(global_shape) if sharding.is_fully_addressable: - return sharding.devices_indices_map(global_shape) - return {d: ind for d, ind in sharding.devices_indices_map(global_shape).items() + return global_map + if hasattr(sharding, '_internal_device_list'): + return {d: global_map[d] + for d in sharding._internal_device_list.addressable_device_list} + return {d: ind for d, ind in global_map.items() if d.process_index == d.client.process_index()} diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 88a72c1a7b83..5e87f4ca3544 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -110,6 +110,10 @@ def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: @functools.cached_property def _addressable_device_assignment(self) -> XLADeviceAssignment: + if self.is_fully_addressable: + return self._device_assignment + if hasattr(self, '_internal_device_list'): + return tuple(self._internal_device_list.addressable_device_list) return tuple(d for d in self._device_assignment if d.process_index == d.client.process_index()) diff --git a/tests/lax_test.py b/tests/lax_test.py index c3762ffee587..67e58b875e9f 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3042,8 +3042,8 @@ def __repr__(self) -> str: size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) -def shard_foo_array_handler(x, devices, indices, sharding): - device, = devices +def shard_foo_array_handler(x, sharding): + device, = sharding._addressable_device_assignment aval = core.raise_to_shaped(core.get_aval(x.data)) return pxla.batched_device_put( aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device]) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index a3f974f69441..bbe35c312055 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -45,7 +45,7 @@ from jax.experimental import multihost_utils from jax.experimental.custom_partitioning import custom_partitioning from jax._src import array -from jax._src.sharding import Sharding, _addressable_devices_indices_map +from jax._src.sharding import Sharding from jax._src import op_shardings from jax._src import sharding_impls from jax._src.sharding_impls import ( @@ -60,7 +60,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension from jax._src.lib import xla_extension_version -from jax._src.util import curry, unzip2, safe_zip +from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -3546,32 +3546,6 @@ def test_sharding_preserved_apply_primitive(self): self.assertIsInstance(out4.sharding, SingleDeviceSharding) self.assertEqual(out4.devices(), {jax.devices()[1]}) - def test_get_indices_cache(self): 
- mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - ns = NamedSharding(mesh, P('x')) - ns2 = NamedSharding(mesh, P('x', 'y')) - - np_inp = np.arange(16).reshape(8, 2) - arr1 = jax.device_put(np_inp, ns) - arr2 = jax.device_put(np_inp, ns2) - arr3 = jax.device_put(np_inp, ns) - - _addressable_devices_indices_map.cache_clear() - - cache_info1 = _addressable_devices_indices_map.cache_info() - out = pjit(lambda x, y, z: x + y + z)(arr1, arr2, arr3) - cache_info2 = _addressable_devices_indices_map.cache_info() - self.assertArraysEqual(out, np_inp * 3) - - # arr3 and arr1 should have the same GSPMDSharding objects internally. - # So there will be 2 hits in _addressable_devices_indices_map, - # One in `pxla._get_input_indices` and second in `_array_shard_arg`. - self.assertEqual(cache_info2.hits, cache_info1.hits + 2) - # There will double the amount of misses as hits because arr1 and arr2's - # sharding are not the same. So 2 misses in _addressable_devices_indices_map - # and 2 in _array_shard_arg. - self.assertEqual(cache_info2.misses, cache_info1.misses + 4) - def test_same_named_sharding_pspec_on_eager_ops(self): mesh = jtu.create_global_mesh((1, 8, 1), ('x', 'y', 'z')) sharding = jax.sharding.NamedSharding(mesh, P('x', 'y', 'z')) @@ -4261,26 +4235,6 @@ def test_array_mapping_to_axis_resources(self, inp, expected_out): sharding_impls.array_mapping_to_axis_resources(inp), expected_out ) - def test_get_input_indices_fully_replicated(self): - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - global_in_aval1 = core.ShapedArray((4, 4), jnp.int32) - global_in_aval2 = core.ShapedArray((4, 4, 4), jnp.int32) - global_in_aval3 = core.ShapedArray((), jnp.int32) - in_avals = [global_in_aval1, global_in_aval2, global_in_aval3] - - mp = NamedSharding(global_mesh, P(None)) - - out_indices = pxla._get_input_indices(in_avals, [mp, mp, mp], - list(global_mesh.devices.flat)) - - self.assertLen(out_indices, len(in_avals)) - self.assertTrue(all(len(out) == len(global_mesh.local_devices) - for out in out_indices)) - self.assertTrue(all(len(i) == aval.ndim - for out, aval in safe_zip(out_indices, in_avals) for i in out)) - self.assertTrue(all(i == (slice(None),) * aval.ndim - for out, aval in safe_zip(out_indices, in_avals) for i in out)) - @parameterized.named_parameters( ("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError), ("only_unspecified", UNSPECIFIED), diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 81040590f096..9c7f8d6b029b 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3006,9 +3006,7 @@ def testShardArgs(self, shape, spec, make_arg): x = np.arange(math.prod(shape)).reshape(shape) arg = make_arg(x) sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec) - results = pxla.shard_args( - jax.devices()[:nshards], [indices], [sharding], [arg] - ) + results = pxla.shard_args([sharding], [arg]) self.assertEqual(len(results), 1) if isinstance(results[0], array.ArrayImpl): bufs = results[0]._arrays
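
For downstream code that registers its own entry in pxla.shard_arg_handlers, the migration implied by this change looks roughly like the sketch below. FooArray and _shard_foo are hypothetical names used only for illustration, and the handler body simply mirrors the updated pxla._shard_array above; treat it as an approximation against the post-change internals (xla_extension_version >= 229), not a prescribed recipe.

    import numpy as np

    import jax
    from jax._src import api_util
    from jax._src.interpreters import pxla


    class FooArray:
      """Hypothetical host-side container whose payload is a numpy array."""

      def __init__(self, data: np.ndarray):
        self.data = data


    # Old convention: def _shard_foo(x, devices, indices, sharding), with the
    # caller precomputing `devices` and the per-device `indices`.
    # New convention: only the sharding is passed, and the handler derives
    # both from it, much as the updated pxla._shard_array does.
    def _shard_foo(x, sharding):
      indices = tuple(
          sharding.addressable_devices_indices_map(x.data.shape).values())
      devices = pxla.get_addressable_devices_for_shard_arg(sharding)
      aval = api_util.shaped_abstractify(x.data)
      return pxla.batched_device_put(
          aval, sharding, [x.data[idx] for idx in indices], devices)


    pxla.shard_arg_handlers[FooArray] = _shard_foo

    # Single-device example; canonicalize=False skips dtype canonicalization,
    # for which FooArray has no handler registered.
    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    out = pxla.shard_arg(FooArray(np.arange(8, dtype=np.float32)), s,
                         canonicalize=False)
    print(out.sharding, out.shape)

Because a handler now starts from the sharding alone, per-sharding work such as computing the addressable device list or the replicated index tuple can be memoized (see the new lru_cache'd get_addressable_devices_for_shard_arg and _get_replicated_slices helpers) instead of being recomputed and threaded through every shard_args call.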