From 2178f339bd2c224b3486b461a5099ce0d1f03d3f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 29 Jul 2022 11:37:08 -0700 Subject: [PATCH] Add `OpShardingSharding` and add a function that can calculate indices from an opsharding proto PiperOrigin-RevId: 464123009 --- jax/experimental/pjit.py | 20 ++----- jax/experimental/sharding.py | 110 ++++++++++++++++++++++++----------- jax/interpreters/pxla.py | 51 ++++++++++++++++ tests/array_test.py | 18 ++++++ 4 files changed, 148 insertions(+), 51 deletions(-) diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 69fdc6137cb7..98a9808e022e 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -584,21 +584,6 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree): return _ListWithW([jaxpr, normalized_out_shardings_flat]) -def _get_num_ways_dim_sharded(s, aval_shape) -> List[int]: - op_sharding = s._to_xla_op_sharding(len(aval_shape)) - tile_assignment_dimensions = op_sharding.tile_assignment_dimensions - - if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]: - replicate_on_last_tile_dim = True - else: - replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim - if op_sharding.last_tile_dims: - raise NotImplementedError("Unhandled OpSharding type. Please open a bug report!") - if replicate_on_last_tile_dim: - tile_assignment_dimensions = tile_assignment_dimensions[:-1] - return tile_assignment_dimensions - - def pjit_check_aval_sharding( shardings: Sequence[XLACompatibleSharding], flat_avals, what_aval: str, allow_uneven_sharding: bool): @@ -615,7 +600,10 @@ def pjit_check_aval_sharding( # Use the `OpSharding` proto to find out how many ways each dimension of # the aval is sharded. This approach will work across all # XLACompatibleSharding. - num_ways_dim_sharded = _get_num_ways_dim_sharded(s, shape) + op_sharding = s._to_xla_op_sharding(len(shape)) + assert op_sharding is not None + num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded( + cast(xc.OpSharding, op_sharding)) for i, size in enumerate(num_ways_dim_sharded): if not allow_uneven_sharding and shape[i] % size != 0: raise ValueError(f"One of {what_aval} was given the sharding " diff --git a/jax/experimental/sharding.py b/jax/experimental/sharding.py index ede05517c337..68c8d0898fdf 100644 --- a/jax/experimental/sharding.py +++ b/jax/experimental/sharding.py @@ -17,7 +17,6 @@ from collections import Counter from typing import Sequence, Tuple, Optional, Mapping, Dict, Set, Union -from jax._src.config import config from jax._src.util import safe_zip from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc @@ -51,10 +50,9 @@ def is_fully_addressable(self) -> bool: # The pytype disable is because pytype can't recognize a cached property. return len(self.device_set) == len(self.addressable_devices) # type: ignore - @abc.abstractmethod def device_indices(self, device: Device, global_shape: Shape) -> Optional[Index]: - raise NotImplementedError('Subclasses should implement this method.') + return self.devices_indices_map(global_shape)[device] @abc.abstractmethod def devices_indices_map( @@ -100,6 +98,24 @@ def _check_mesh_resource_axis(mesh, parsed_pspec): "undefined.") from None +def _hashed_index(x) -> int: + # This works for both `pjit`/`xmap` indices and `pmap` indices (which might + # have an integer instead of a slice). 
+ assert all(v.step is None for v in x if isinstance(v, slice)) + return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x)) + + +def _device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]: + index_to_replica: Dict[int, int] = Counter() + out = {} + for device, index in sharding.devices_indices_map(global_shape).items(): + h_index = _hashed_index(index) + replica_id = index_to_replica[h_index] + index_to_replica[h_index] += 1 + out[device] = replica_id + return out + + class MeshPspecSharding(XLACompatibleSharding): def __init__( @@ -161,9 +177,6 @@ def _from_parsed_pspec(cls, mesh, parsed_pspec): def device_set(self) -> Set[Device]: return set(self.mesh.devices.flat) - def device_indices(self, device: Device, global_shape: Shape) -> Optional[Index]: - return self.devices_indices_map(global_shape)[device] - def devices_indices_map( self, global_shape: Shape) -> Mapping[Device, Optional[Index]]: # TODO(yashkatariya): Remove this when utilities are moved to pxla.py. @@ -172,19 +185,9 @@ def devices_indices_map( # `get_shard_indices` is cached. return global_device_array.get_shard_indices(global_shape, self.mesh, self.spec) - def _hashed_index(self, x) -> int: - return hash(tuple((v.start, v.stop) for v in x)) - @functools.lru_cache(maxsize=4096) def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]: - index_to_replica: Dict[int, int] = Counter() - out = {} - for device, index in self.devices_indices_map(global_shape).items(): - h_index = self._hashed_index(index) - replica_id = index_to_replica[h_index] - index_to_replica[h_index] += 1 - out[device] = replica_id - return out + return _device_replica_id_map(self, global_shape) @functools.lru_cache(maxsize=4096) def _device_assignment(self) -> XLADeviceAssignment: @@ -240,9 +243,6 @@ def normalize(self): def device_set(self) -> Set[Device]: return {self._device} - def device_indices(self, device: Device, global_shape: Shape) -> Optional[Index]: - return self.devices_indices_map(global_shape)[device] - @functools.lru_cache(maxsize=4096) def devices_indices_map( self, global_shape: Shape) -> Mapping[Device, Optional[Index]]: @@ -279,9 +279,6 @@ def normalize(self): def device_set(self) -> Set[Device]: return set(self.devices.flat) - def device_indices(self, device: Device, global_shape: Shape) -> Optional[Index]: - return self.devices_indices_map(global_shape)[device] - @pxla.maybe_cached_property def sharded_dim(self): for i, s in enumerate(self.sharding_spec.sharding): @@ -295,20 +292,9 @@ def devices_indices_map( indices = pxla.spec_to_indices(global_shape, self.sharding_spec) return {d: i for d, i in safe_zip(self.devices.flat, indices)} # type: ignore - def _hashed_index(self, x) -> int: - return hash( - tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x)) - @functools.lru_cache(maxsize=4096) def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]: - index_to_replica: Dict[int, int] = Counter() - out = {} - for device, index in self.devices_indices_map(global_shape).items(): - h_index = self._hashed_index(index) - replica_id = index_to_replica[h_index] - index_to_replica[h_index] += 1 - out[device] = replica_id - return out + return _device_replica_id_map(self, global_shape) @functools.lru_cache(maxsize=4096) def _device_assignment(self) -> XLADeviceAssignment: @@ -316,3 +302,57 @@ def _device_assignment(self) -> XLADeviceAssignment: def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding: raise NotImplementedError("pmap 
doesn't use OpSharding.")
+
+
+class OpShardingSharding(XLACompatibleSharding):
+
+  def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
+    self._devices = devices
+    self._op_sharding = op_sharding
+
+  def __eq__(self, other):
+    if not isinstance(other, OpShardingSharding):
+      return False
+    return pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding)
+
+  def __hash__(self):
+    if not hasattr(self, '_hash'):
+      # TODO(yashkatariya): Write a hash function that's backwards compatible
+      # for `xla_extension_version` < 81.
+      self._hash = hash(xc.HloSharding.from_proto(self._op_sharding))
+    return self._hash
+
+  def __repr__(self):
+    return repr(self._op_sharding)
+
+  def normalize(self, *_):
+    return self
+
+  def is_compatible_aval(self, aval_shape: Shape):
+    num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
+    if len(aval_shape) < len(num_ways_dim_sharded):
+      raise ValueError(
+          f"Sharding {self} is only valid for values of rank at least "
+          f"{len(num_ways_dim_sharded)}, but was applied to a value of rank "
+          f"{len(aval_shape)}")
+
+  @pxla.maybe_cached_property
+  def device_set(self) -> Set[Device]:
+    return set(self._devices)
+
+  @functools.lru_cache(maxsize=4096)
+  def devices_indices_map(
+      self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
+    indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
+                                          len(self._devices))
+    return dict(safe_zip(self._devices, indices))
+
+  @functools.lru_cache(maxsize=4096)
+  def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
+    return _device_replica_id_map(self, global_shape)
+
+  def _device_assignment(self) -> XLADeviceAssignment:
+    return list(self._devices)
+
+  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
+    return self._op_sharding
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 4a60b15ff526..2d9d3cd8a4c0 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -205,6 +205,57 @@ def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType
     proto.tile_assignment_devices = list(proto_mesh.flat)
   return proto
 
+
+def _get_num_ways_dim_sharded(op_sharding: xc.OpSharding) -> Tuple[Sequence[int], int]:
+  partitions = op_sharding.tile_assignment_dimensions
+  if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
+    replicate_on_last_tile_dim = True
+  else:
+    replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim
+    if op_sharding.last_tile_dims:
+      raise NotImplementedError("Unhandled OpSharding type. Please open a bug report!")
+  num_replicas = 1
+  if replicate_on_last_tile_dim:
+    num_replicas = partitions[-1]
+    partitions = partitions[:-1]
+  return partitions, num_replicas
+
+
+def op_sharding_to_indices(op_sharding: xc.OpSharding, shape: Tuple[int, ...],
+                           num_devices: int):
+  # num_devices is required as an argument when op_sharding is of type
+  # REPLICATED. `jax.device_count()` cannot be used here because an
+  # OpSharding can be created with fewer devices than `jax.device_count()`.
+  if op_sharding.type == xc.OpSharding.Type.REPLICATED:
+    # xb.device_count() cannot be relied on here because fewer devices than
+    # are available may be in use.
+ return tuple((slice(None),) * len(shape) for _ in range(num_devices)) + + assert num_devices == len(op_sharding.tile_assignment_devices) + + partitions, num_replicas = _get_num_ways_dim_sharded(op_sharding) + assert len(partitions) == len(shape), (len(partitions), len(shape)) + + axis_indices: List[Sequence[Index]] = [] + for dim, n_shards in zip(shape, partitions): + if n_shards == 1: + axis_indices.append([slice(None)]) + elif n_shards > 1: + shard_size, ragged = divmod(dim, n_shards) + assert not ragged, (dim, n_shards, dim) + axis_indices.append([slice(i * shard_size, (i + 1) * shard_size) + for i in range(n_shards)]) + else: + raise AssertionError('Unrecognized number of shards. Please file a bug!') + + indices = np.empty(num_devices, dtype=np.object_) + device_it = iter(op_sharding.tile_assignment_devices) + for i, idxs in enumerate(it.product(*axis_indices)): + for _ in range(num_replicas): + indices[next(device_it)] = idxs + return tuple(indices.flat) + + def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray: """Returns NumPy-style indices corresponding to a sharding spec. diff --git a/tests/array_test.py b/tests/array_test.py index 00f97915f277..60200ac0725c 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -177,6 +177,24 @@ def test_mesh_pspec_sharding_interface(self): self.assertListEqual(op_sharding.tile_assignment_devices, [0, 2, 4, 6, 1, 3, 5, 7]) + @parameterized.named_parameters( + ("mesh_x_y", P("x", "y")), + ("mesh_x", P("x")), + ("mesh_y", P("y")), + ("mesh_none_y", P(None, "y")), + ("mesh_none_x", P(None, "x")), + ("mesh_xy", P(("x", "y"))), + ("mesh_fully_replicated", P()), + ) + def test_op_sharding_indices(self, pspec): + shape = (8, 4) + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mps = sharding.MeshPspecSharding(mesh, pspec) + ops = sharding.OpShardingSharding( + list(mesh.devices.flat), mps._to_xla_op_sharding(len(shape))) + self.assertDictEqual( + ops.devices_indices_map(shape), mps.devices_indices_map(shape)) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())
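A rough usage sketch of the new `OpShardingSharding`, mirroring the `test_op_sharding_indices` test above. It assumes a runtime with at least 8 devices (e.g. CPU with `XLA_FLAGS=--xla_force_host_platform_device_count=8`); everything else uses only the APIs touched by this patch.

    # Build a MeshPspecSharding and the equivalent OpShardingSharding from its
    # OpSharding proto; both should produce the same device -> index mapping.
    import numpy as np

    import jax
    from jax.experimental import PartitionSpec as P
    from jax.experimental import sharding
    from jax.experimental.maps import Mesh

    shape = (8, 4)
    mesh = Mesh(np.asarray(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))
    mps = sharding.MeshPspecSharding(mesh, P('x', 'y'))
    ops = sharding.OpShardingSharding(
        list(mesh.devices.flat), mps._to_xla_op_sharding(len(shape)))

    # Each device gets a (slice, slice) index into the (8, 4) value:
    # 'x' splits dim 0 four ways, 'y' splits dim 1 two ways.
    assert ops.devices_indices_map(shape) == dict(mps.devices_indices_map(shape))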