Check that the buffer shape matches the expected shard shape in `Array`.
PiperOrigin-RevId: 470732792
yashk2810 authored and jax authors committed Aug 29, 2022
1 parent 6674b14 commit 70a7ee2
Showing 3 changed files with 132 additions and 17 deletions.
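In short: `Array.__init__` now validates that every per-device buffer has the shard shape implied by the sharding (and a matching dtype), and `Sharding` gains a `shard_shape` method to compute that shape. Below is a minimal sketch of the new failure mode, mirroring the tests added further down; the mesh setup and `jax.experimental` import paths are assumptions tied to this era of the experimental API, and the snippet needs at least two local devices.

import numpy as np
import jax
from jax.experimental import array, maps, sharding
from jax.experimental.pjit import PartitionSpec as P

# Hypothetical 2x1 mesh over the first two local devices.
devices = np.array(jax.local_devices()[:2]).reshape(2, 1)
mesh = maps.Mesh(devices, ('x', 'y'))
s = sharding.MeshPspecSharding(mesh, P('x', 'y'))

shape = (8, 2)
data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

# Deliberately give every device the full array instead of its (4, 2) shard.
bufs = [jax.device_put(data, d) for d in mesh.local_devices]

# With this change the constructor raises something like:
#   ValueError: Expected shard shape (4, 2) doesn't match the buffer shape (8, 2) ...
array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)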
24 changes: 17 additions & 7 deletions jax/experimental/array.py
@@ -118,22 +118,32 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding,
self._committed = committed
self._npy_value = None

# TODO(yashkatariya): Add a check here which checks if the expected shard
# shape matches the shape of _arrays. A similar check exists for GDA.
if not _skip_checks or config.jax_enable_checks:
ss = self.sharding.shard_shape(self.shape)
for db in self._arrays:
if db.shape != ss:
raise ValueError(
f"Expected shard shape {ss} doesn't match the buffer "
f"shape {db.shape} for buffer: {db}")

if not _skip_checks or config.jax_enable_checks:
assert all(db.dtype == self.dtype for db in self._arrays), (
"Input arrays to `Array` must have matching dtypes, "
f"got: {[db.dtype for db in self._arrays]}, aval type: {self.dtype}")
for db in self._arrays:
if db.dtype != self.dtype:
raise ValueError(
"Input buffers to `Array` must have matching dtypes. "
f"Got {db.dtype}, expected {self.dtype} for buffer: {db}")

# Don't rearrange if skip_checks is enabled because this assumes that the
# input buffers are already arranged properly. This usually happens when
    # Arrays are created as output of a JAX transformation
# (like pjit, xmap, etc).
if not _skip_checks:
addressable_device_assignment = self.sharding._addressable_device_assignment
# Rearrange arrays based on the device assignment.
# TODO(yashkatariya): Add a similar check for shardings that are not
    # XLACompatibleSharding. But leave the rearrangement to XLACompatibleSharding
# only.
if isinstance(sharding, XLACompatibleSharding):
addressable_device_assignment = self.sharding._addressable_device_assignment
if len(self._arrays) != len(addressable_device_assignment):
raise ValueError(
f"Expected {len(addressable_device_assignment)} per-device arrays "
@@ -271,7 +281,7 @@ def __dlpack__(self):
return to_dlpack(self)

def __reduce__(self):
fun, args, arr_state = self._value.__reduce__()
fun, args, arr_state = self._value.__reduce__() # type: ignore
aval_state = {'weak_type': self.aval.weak_type,
'named_shape': self.aval.named_shape}
return (_reconstruct_array, (fun, args, arr_state, aval_state))
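Note that both new checks above are gated on `not _skip_checks or config.jax_enable_checks`: buffers produced internally by JAX transformations (pjit, xmap, etc.) skip them by default, but they can still be forced on via the existing global debug flag. A hedged sketch, assuming the standard config API:

import jax

# Opt in to the per-buffer shape/dtype validation even for Arrays that JAX
# builds internally as transformation outputs (slower, but catches bugs early).
jax.config.update("jax_enable_checks", True)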
63 changes: 54 additions & 9 deletions jax/experimental/sharding.py
@@ -15,7 +15,7 @@
import abc
import functools
from collections import Counter
from typing import Sequence, Tuple, Optional, Mapping, Dict, Set, Union
from typing import Sequence, Tuple, Optional, Mapping, Dict, Set, Union, cast

from jax._src.util import safe_zip
from jax._src.lib import xla_bridge as xb
@@ -33,6 +33,8 @@

class Sharding(metaclass=abc.ABCMeta):

# Abstract methods below that subclasses should implement.

@abc.abstractproperty
def device_set(self) -> Set[Device]:
"""A unique set of devices that this sharding represents.
@@ -41,6 +43,18 @@ def device_set(self) -> Set[Device]:
"""
raise NotImplementedError('Subclasses should implement this method.')

@abc.abstractmethod
def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
raise NotImplementedError('Subclasses should implement this method.')

@abc.abstractmethod
def shard_shape(self, global_shape: Shape) -> Shape:
raise NotImplementedError('Subclasses should implement this method.')

#############################################################################
# Default implementations below that all subclasses will inherit.

@pxla.maybe_cached_property
def addressable_devices(self) -> Set[Device]:
"""A set of addressable devices by the current process"""
@@ -55,14 +69,11 @@ def device_indices(self, device: Device,
global_shape: Shape) -> Optional[Index]:
return self.devices_indices_map(global_shape)[device]

@abc.abstractmethod
def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
raise NotImplementedError('Subclasses should implement this method.')


class XLACompatibleSharding(Sharding):

# Abstract methods below that subclasses should implement.

@abc.abstractproperty
def _device_assignment(self) -> XLADeviceAssignment:
raise NotImplementedError('Subclasses should implement this method.')
@@ -71,14 +82,37 @@ def _device_assignment(self) -> XLADeviceAssignment:
def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
raise NotImplementedError('Subclasses should implement this method.')

@abc.abstractmethod
def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]:
raise NotImplementedError('Subclasses should implement this method.')

#############################################################################
# Default implementations below that all subclasses will inherit.

@pxla.maybe_cached_property
def _addressable_device_assignment(self) -> XLADeviceAssignment:
process_index = xb.process_index()
return [d for d in self._device_assignment if d.process_index == process_index]

@abc.abstractmethod
def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]:
raise NotImplementedError('Subclasses should implement this method.')
@functools.lru_cache(maxsize=4096)
def shard_shape(self, global_shape: Shape) -> Shape:
op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
if pxla.is_op_sharding_replicated(op_sharding):
return global_shape
partitions, _ = pxla._get_num_ways_dim_sharded(op_sharding)
assert len(partitions) == len(global_shape), (len(partitions), len(global_shape))
out = []
for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
quotient, remainder = divmod(s, p)
if remainder != 0:
raise ValueError(
f"Sharding {self} implies that array axis {dim} is partitioned "
f"{p} times, but the dimension size is {s} "
f"(full shape: {global_shape}, "
f"per-dimension tiling factors: {partitions} should evenly divide "
"the shape)")
out.append(quotient)
return tuple(out)


@functools.lru_cache()
@@ -274,6 +308,17 @@ def _device_assignment(self) -> XLADeviceAssignment:
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
raise NotImplementedError("pmap doesn't use OpSharding.")

@functools.lru_cache(maxsize=4096)
def shard_shape(self, global_shape: Shape) -> Shape:
sharded_dim = None
for i, s in enumerate(self.sharding_spec.sharding):
if isinstance(s, pxla.Unstacked):
sharded_dim = i
break
if sharded_dim is None:
return global_shape
return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]


# TODO(yashkatariya): Remove this when minimum_jaxlib version is 0.3.17
def _hash_op_sharding(op: xc.OpSharding):
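The default `XLACompatibleSharding.shard_shape` above divides each global dimension by its tiling factor and rejects uneven splits, while `PmapSharding.shard_shape` instead drops the `Unstacked` (pmapped) dimension entirely. A standalone sketch of the tiling arithmetic follows; the `partitions` tuple is a hypothetical stand-in for the per-dimension factors that `pxla._get_num_ways_dim_sharded` derives from the `OpSharding`.

from typing import Sequence, Tuple

def shard_shape_sketch(global_shape: Sequence[int],
                       partitions: Sequence[int]) -> Tuple[int, ...]:
  """Divide each dimension by its tiling factor, rejecting uneven splits."""
  assert len(partitions) == len(global_shape)
  out = []
  for dim, (size, parts) in enumerate(zip(global_shape, partitions)):
    quotient, remainder = divmod(size, parts)
    if remainder:
      raise ValueError(f"axis {dim} of size {size} is not evenly divisible "
                       f"into {parts} parts")
    out.append(quotient)
  return tuple(out)

# An (8, 4) array on a 4x2 ('x', 'y') mesh: P('x', 'y') tiles it (4, 2) times,
# so each shard is (2, 2); P('x') tiles it (4, 1) times, so each shard is (2, 4).
assert shard_shape_sketch((8, 4), (4, 2)) == (2, 2)
assert shard_shape_sketch((8, 4), (4, 1)) == (2, 4)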
62 changes: 61 additions & 1 deletion tests/array_test.py
@@ -224,7 +224,7 @@ def test_arrays_not_in_device_assignment(self):
self.skipTest('Requires more than 4 devices')
shape = (8, 2)
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
s = sharding.MeshPspecSharding(mesh, P('x', 'y'))
s = sharding.MeshPspecSharding(mesh, P('x'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
bufs = [jax.device_put(inp_data, d) for d in jax.devices()[2:4]]
with self.assertRaisesRegex(
@@ -233,6 +233,42 @@ def test_arrays_not_in_device_assignment(self):
"not used in the specified sharding"):
array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (2, 2)),
("mesh_x", P("x"), (2, 4)),
("mesh_y", P("y"), (4, 4)),
("mesh_none_y", P(None, "y"), (8, 2)),
("mesh_none_x", P(None, "x"), (8, 1)),
("mesh_xy", P(("x", "y")), (1, 4)),
)
def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shape):
shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mps = sharding.MeshPspecSharding(mesh, pspec)
inp_data = np.arange(prod(shape)).reshape(shape)

str_expected_shard_shape = str(expected_shard_shape).replace(
r"(", r"\(").replace(r")", r"\)")
with self.assertRaisesRegex(
ValueError,
f"Expected shard shape {str_expected_shard_shape} doesn't match the "
"buffer shape"):
array.make_array_from_callback(shape, mps, lambda idx: inp_data)

@jax_config.jax_array(True)
def test_mismatch_dtype(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
s = sharding.MeshPspecSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape), dtype=np.int32).reshape(shape)
indices = s.devices_indices_map(shape)
bufs = [jax.device_put(inp_data[indices[d]], d) for d in mesh.local_devices]
with self.assertRaisesRegex(
ValueError,
"Input buffers to `Array` must have matching dtypes. "
"Got int32, expected float32"):
array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)


class ShardingTest(jtu.JaxTestCase):

@@ -270,6 +306,30 @@ def test_op_sharding_indices(self, pspec):
self.assertDictEqual(
ops.devices_indices_map(shape), mps.devices_indices_map(shape))

@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (2, 2)),
("mesh_x", P("x"), (2, 4)),
("mesh_y", P("y"), (4, 4)),
("mesh_none_y", P(None, "y"), (8, 2)),
("mesh_none_x", P(None, "x"), (8, 1)),
("mesh_xy", P(("x", "y")), (1, 4)),
("mesh_fully_replicated", P(), (8, 4)),
)
def test_shard_shape(self, pspec, expected_shard_shape):
shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mps = sharding.MeshPspecSharding(mesh, pspec)
self.assertEqual(mps.shard_shape(shape), expected_shard_shape)

def test_uneven_shard_error(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mps = sharding.MeshPspecSharding(mesh, P('x', 'y'))
with self.assertRaisesRegex(
ValueError,
r"Sharding.*implies that array axis 1 is partitioned 2 times, but the "
r"dimension size is 3 \(full shape: \(8, 3\), per-dimension tiling "
r"factors: \[4, 2\] should evenly divide the shape\)"):
mps.shard_shape((8, 3))

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
