Add OpShardingSharding and add a function that can calculate indices from an opsharding proto

PiperOrigin-RevId: 464123009
yashk2810 authored and jax authors committed Jul 29, 2022
1 parent cc19c94 commit 2178f33
Showing 4 changed files with 148 additions and 51 deletions.
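For context, the sketch below shows how the new OpShardingSharding introduced in this commit can be constructed from an existing MeshPspecSharding and queried for per-device indices. It mirrors the test_op_sharding_indices test added at the bottom of this diff; the 4x2 device reshape assumes at least 8 local devices (e.g. a CPU run with --xla_force_host_platform_device_count=8), so treat it as an illustrative sketch rather than part of the change.

import numpy as np
import jax
from jax.experimental import maps, sharding, PartitionSpec as P

# Build a 4x2 mesh over the first 8 devices (assumes 8 devices are available).
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = maps.Mesh(devices, ('x', 'y'))
mps = sharding.MeshPspecSharding(mesh, P('x', 'y'))

shape = (8, 4)
# Lower the mesh/pspec sharding to an OpSharding proto and wrap it in the new
# OpShardingSharding class added in this commit.
op_sharding = mps._to_xla_op_sharding(len(shape))
ops = sharding.OpShardingSharding(list(mesh.devices.flat), op_sharding)

# Both shardings describe the same placement, so their index maps agree.
assert dict(ops.devices_indices_map(shape)) == dict(mps.devices_indices_map(shape))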
20 changes: 4 additions & 16 deletions jax/experimental/pjit.py
@@ -584,21 +584,6 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
return _ListWithW([jaxpr, normalized_out_shardings_flat])


def _get_num_ways_dim_sharded(s, aval_shape) -> List[int]:
op_sharding = s._to_xla_op_sharding(len(aval_shape))
tile_assignment_dimensions = op_sharding.tile_assignment_dimensions

if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
replicate_on_last_tile_dim = True
else:
replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim
if op_sharding.last_tile_dims:
raise NotImplementedError("Unhandled OpSharding type. Please open a bug report!")
if replicate_on_last_tile_dim:
tile_assignment_dimensions = tile_assignment_dimensions[:-1]
return tile_assignment_dimensions


def pjit_check_aval_sharding(
shardings: Sequence[XLACompatibleSharding], flat_avals, what_aval: str,
allow_uneven_sharding: bool):
@@ -615,7 +600,10 @@ def pjit_check_aval_sharding(
# Use the `OpSharding` proto to find out how many ways each dimension of
# the aval is sharded. This approach will work across all
# XLACompatibleSharding.
num_ways_dim_sharded = _get_num_ways_dim_sharded(s, shape)
op_sharding = s._to_xla_op_sharding(len(shape))
assert op_sharding is not None
num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(
cast(xc.OpSharding, op_sharding))
for i, size in enumerate(num_ways_dim_sharded):
if not allow_uneven_sharding and shape[i] % size != 0:
raise ValueError(f"One of {what_aval} was given the sharding "
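To make the intent of the new check concrete, here is a small, self-contained sketch of the divisibility validation performed in the hunk above (check_even_sharding is a hypothetical stand-in; the real code obtains the per-dimension shard counts by calling pxla._get_num_ways_dim_sharded on the OpSharding proto):

def check_even_sharding(shape, num_ways_dim_sharded, allow_uneven_sharding=False):
  # Each dimension must be divisible by the number of shards along it.
  for i, size in enumerate(num_ways_dim_sharded):
    if not allow_uneven_sharding and shape[i] % size != 0:
      raise ValueError(f"dimension {i} of size {shape[i]} cannot be split into "
                       f"{size} equal shards")

check_even_sharding((8, 4), (4, 2))    # OK: 8 % 4 == 0 and 4 % 2 == 0
# check_even_sharding((9, 4), (4, 2))  # would raise: 9 % 4 != 0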
110 changes: 75 additions & 35 deletions jax/experimental/sharding.py
@@ -17,7 +17,6 @@
from collections import Counter
from typing import Sequence, Tuple, Optional, Mapping, Dict, Set, Union

from jax._src.config import config
from jax._src.util import safe_zip
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
@@ -51,10 +50,9 @@ def is_fully_addressable(self) -> bool:
# The pytype disable is because pytype can't recognize a cached property.
return len(self.device_set) == len(self.addressable_devices) # type: ignore

@abc.abstractmethod
def device_indices(self, device: Device,
global_shape: Shape) -> Optional[Index]:
raise NotImplementedError('Subclasses should implement this method.')
return self.devices_indices_map(global_shape)[device]

@abc.abstractmethod
def devices_indices_map(
@@ -100,6 +98,24 @@ def _check_mesh_resource_axis(mesh, parsed_pspec):
"undefined.") from None


def _hashed_index(x) -> int:
# This works for both `pjit`/`xmap` indices and `pmap` indices (which might
# have an integer instead of a slice).
assert all(v.step is None for v in x if isinstance(v, slice))
return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))


def _device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
index_to_replica: Dict[int, int] = Counter()
out = {}
for device, index in sharding.devices_indices_map(global_shape).items():
h_index = _hashed_index(index)
replica_id = index_to_replica[h_index]
index_to_replica[h_index] += 1
out[device] = replica_id
return out


class MeshPspecSharding(XLACompatibleSharding):

def __init__(
@@ -161,9 +177,6 @@ def _from_parsed_pspec(cls, mesh, parsed_pspec):
def device_set(self) -> Set[Device]:
return set(self.mesh.devices.flat)

def device_indices(self, device: Device, global_shape: Shape) -> Optional[Index]:
return self.devices_indices_map(global_shape)[device]

def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
# TODO(yashkatariya): Remove this when utilities are moved to pxla.py.
@@ -172,19 +185,9 @@ def devices_indices_map(
# `get_shard_indices` is cached.
return global_device_array.get_shard_indices(global_shape, self.mesh, self.spec)

def _hashed_index(self, x) -> int:
return hash(tuple((v.start, v.stop) for v in x))

@functools.lru_cache(maxsize=4096)
def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
index_to_replica: Dict[int, int] = Counter()
out = {}
for device, index in self.devices_indices_map(global_shape).items():
h_index = self._hashed_index(index)
replica_id = index_to_replica[h_index]
index_to_replica[h_index] += 1
out[device] = replica_id
return out
return _device_replica_id_map(self, global_shape)

@functools.lru_cache(maxsize=4096)
def _device_assignment(self) -> XLADeviceAssignment:
@@ -240,9 +243,6 @@ def normalize(self):
def device_set(self) -> Set[Device]:
return {self._device}

def device_indices(self, device: Device, global_shape: Shape) -> Optional[Index]:
return self.devices_indices_map(global_shape)[device]

@functools.lru_cache(maxsize=4096)
def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
@@ -279,9 +279,6 @@ def normalize(self):
def device_set(self) -> Set[Device]:
return set(self.devices.flat)

def device_indices(self, device: Device, global_shape: Shape) -> Optional[Index]:
return self.devices_indices_map(global_shape)[device]

@pxla.maybe_cached_property
def sharded_dim(self):
for i, s in enumerate(self.sharding_spec.sharding):
@@ -295,24 +292,67 @@ def devices_indices_map(
indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
return {d: i for d, i in safe_zip(self.devices.flat, indices)} # type: ignore

def _hashed_index(self, x) -> int:
return hash(
tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))

@functools.lru_cache(maxsize=4096)
def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
index_to_replica: Dict[int, int] = Counter()
out = {}
for device, index in self.devices_indices_map(global_shape).items():
h_index = self._hashed_index(index)
replica_id = index_to_replica[h_index]
index_to_replica[h_index] += 1
out[device] = replica_id
return out
return _device_replica_id_map(self, global_shape)

@functools.lru_cache(maxsize=4096)
def _device_assignment(self) -> XLADeviceAssignment:
return list(self.devices.flat)

def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
raise NotImplementedError("pmap doesn't use OpSharding.")


class OpShardingSharding(XLACompatibleSharding):

def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
self._devices = devices
self._op_sharding = op_sharding

def __eq__(self, other):
if not isinstance(other, OpShardingSharding):
return False
return pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding)

def __hash__(self):
if not hasattr(self, '_hash'):
# TODO(yashkatariya): Write a hash function that's backwards compatible
# for `xla_extension_version` < 81.
self._hash = hash(xc.HloSharding.from_proto(self._op_sharding))
return self._hash

def __repr__(self):
return repr(self._op_sharding)

def normalize(self, *_):
return self

def is_compatible_aval(self, aval_shape: Shape):
num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
if len(aval_shape) < len(num_ways_dim_sharded):
raise ValueError(
f"Sharding {self} is only valid for values of rank at least "
f"{len(num_ways_dim_sharded)}, but was applied to a value of rank "
f"{len(aval_shape)}")

@pxla.maybe_cached_property
def device_set(self) -> Set[Device]:
return set(self._devices)

@functools.lru_cache(maxsize=4096)
def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
len(self._devices))
return dict(safe_zip(self._devices, indices))

@functools.lru_cache(maxsize=4096)
def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
return _device_replica_id_map(self, global_shape)

def _device_assignment(self) -> XLADeviceAssignment:
return list(self._devices)

def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
return self._op_sharding
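For intuition about the replica-id bookkeeping that MeshPspecSharding, PmapSharding, and the new OpShardingSharding now share through _device_replica_id_map, here is a self-contained sketch that uses string stand-ins for devices (the helper name replica_ids is hypothetical):

from collections import Counter

def replica_ids(device_to_index):
  # Devices that hold an identical shard get consecutive replica ids,
  # mirroring the hashed-index counting in _device_replica_id_map above.
  index_to_replica = Counter()
  out = {}
  for device, index in device_to_index.items():
    key = tuple((v.start, v.stop) if isinstance(v, slice) else v for v in index)
    replica_id = index_to_replica[key]
    index_to_replica[key] += 1
    out[device] = replica_id
  return out

print(replica_ids({'d0': (slice(0, 4),), 'd1': (slice(0, 4),), 'd2': (slice(4, 8),)}))
# {'d0': 0, 'd1': 1, 'd2': 0}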
51 changes: 51 additions & 0 deletions jax/interpreters/pxla.py
@@ -205,6 +205,57 @@ def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType
proto.tile_assignment_devices = list(proto_mesh.flat)
return proto


def _get_num_ways_dim_sharded(op_sharding: xc.OpSharding) -> Tuple[Sequence[int], int]:
partitions = op_sharding.tile_assignment_dimensions
if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
replicate_on_last_tile_dim = True
else:
replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim
if op_sharding.last_tile_dims:
raise NotImplementedError("Unhandled OpSharding type. Please open a bug report!")
num_replicas = 1
if replicate_on_last_tile_dim:
num_replicas = partitions[-1]
partitions = partitions[:-1]
return partitions, num_replicas


def op_sharding_to_indices(op_sharding: xc.OpSharding, shape: Tuple[int, ...],
num_devices: int):
# num_devices is required as an argument when op_sharding is of type
# REPLICATED. `jax.device_count()` cannot be used because you can create an
# OpSharding with fewer devices than `jax.device_count()`.
if op_sharding.type == xc.OpSharding.Type.REPLICATED:
# xb.device_count may not always be right because you can use fewer devices
# than what's available.
return tuple((slice(None),) * len(shape) for _ in range(num_devices))

assert num_devices == len(op_sharding.tile_assignment_devices)

partitions, num_replicas = _get_num_ways_dim_sharded(op_sharding)
assert len(partitions) == len(shape), (len(partitions), len(shape))

axis_indices: List[Sequence[Index]] = []
for dim, n_shards in zip(shape, partitions):
if n_shards == 1:
axis_indices.append([slice(None)])
elif n_shards > 1:
shard_size, ragged = divmod(dim, n_shards)
assert not ragged, (dim, n_shards, dim)
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
for i in range(n_shards)])
else:
raise AssertionError('Unrecognized number of shards. Please file a bug!')

indices = np.empty(num_devices, dtype=np.object_)
device_it = iter(op_sharding.tile_assignment_devices)
for i, idxs in enumerate(it.product(*axis_indices)):
for _ in range(num_replicas):
indices[next(device_it)] = idxs
return tuple(indices.flat)


def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
"""Returns NumPy-style indices corresponding to a sharding spec.
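A worked example of the index computation in op_sharding_to_indices above, with the relevant OpSharding fields written out as plain Python values (a sketch of the algorithm for a 4x2 tiling of an (8, 4) array with no replication, not a call into the real proto API):

import itertools as it
import numpy as np

shape = (8, 4)
partitions = (4, 2)                       # shards per dimension (tile_assignment_dimensions)
tile_assignment_devices = list(range(8))  # row-major device order

# One list of slices per dimension, exactly as in op_sharding_to_indices.
axis_indices = []
for dim, n_shards in zip(shape, partitions):
  shard_size = dim // n_shards
  axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
                       for i in range(n_shards)])

indices = np.empty(len(tile_assignment_devices), dtype=np.object_)
device_it = iter(tile_assignment_devices)
for idxs in it.product(*axis_indices):
  indices[next(device_it)] = idxs

print(indices[0])  # (slice(0, 2, None), slice(0, 2, None))
print(indices[7])  # (slice(6, 8, None), slice(2, 4, None))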
18 changes: 18 additions & 0 deletions tests/array_test.py
@@ -177,6 +177,24 @@ def test_mesh_pspec_sharding_interface(self):
self.assertListEqual(op_sharding.tile_assignment_devices,
[0, 2, 4, 6, 1, 3, 5, 7])

@parameterized.named_parameters(
("mesh_x_y", P("x", "y")),
("mesh_x", P("x")),
("mesh_y", P("y")),
("mesh_none_y", P(None, "y")),
("mesh_none_x", P(None, "x")),
("mesh_xy", P(("x", "y"))),
("mesh_fully_replicated", P()),
)
def test_op_sharding_indices(self, pspec):
shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mps = sharding.MeshPspecSharding(mesh, pspec)
ops = sharding.OpShardingSharding(
list(mesh.devices.flat), mps._to_xla_op_sharding(len(shape)))
self.assertDictEqual(
ops.devices_indices_map(shape), mps.devices_indices_map(shape))


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
