Enable the debugging_primitives pjit(xmap) case. Also don't check for sharding mismatch when the array is not committed. Check the device assignment only for committed arrays.

PiperOrigin-RevId: 474598597
yashk2810 authored and jax authors committed Sep 15, 2022
1 parent 9791199 commit 60ec541
Showing 4 changed files with 110 additions and 34 deletions.
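As a hedged illustration of the commit message (hypothetical usage, not part of the diff): only arrays that have been explicitly committed to devices take part in pjit's device-assignment and sharding-mismatch checks.

import jax
import jax.numpy as jnp

# An array produced by an ordinary computation is uncommitted: it has a
# sharding, but pjit is free to move or reshard it, so it is skipped when
# device assignments are collected (see committed_arg_shardings below).
uncommitted = jnp.arange(8.0)

# jax.device_put with an explicit device commits the array; its sharding now
# participates in _get_and_check_device_assignment and the mismatch check in
# _resolve_in_shardings.
committed = jax.device_put(jnp.arange(8.0), jax.devices()[0])

# Internally the distinction is the private `_committed` flag checked in the
# pjit.py changes below.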
45 changes: 34 additions & 11 deletions jax/experimental/array.py
@@ -524,21 +524,44 @@ def _array_pmap_shard_arg(x, devices, indices, mode):
return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)


def _array_rest_shard_arg(x, devices, indices, mode):
if not x._committed:
if dispatch.is_single_device_sharding(x.sharding):
# This condition is to break the recursion that happens when only
# `pxla._shard_device_array` is used since it has `_multi_slice` in the
# implementation which is jitted. Eventually it calls back here and the
# recursion happens.
if len(devices) == 1:
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
return pxla._shard_device_array(x, devices, indices, mode)
else:
raise NotImplementedError('Resharding uncommitted arrays sharded over '
'multiple devices is not supported.')
# TODO(yashkatariya): Remove the special case here and don't move to another
# device if its already committed. There is a TODO in dispatch.py already
# for this.
if dispatch.is_single_device_sharding(x.sharding):
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
# If PmapSharding exists, then do a round trip via host. This will happen
# if the input Array containing PmapSharding takes the jit path
# i.e. `apply_primitive` or `xla_callable_uncached`. `jit(pmap)` is the most
# common case where this will happen.
# TODO(yashkatariya): Remove the special case here and don't move to another
# device if its already committed. There is a TODO in dispatch.py already
# for this.
elif isinstance(x.sharding, PmapSharding):
return pxla.device_put(x._value, devices, replicate=True)
else:
return x._arrays


def _array_shard_arg(x, devices, indices, mode):
if mode == pxla.InputsHandlerMode.pmap:
return _array_pmap_shard_arg(x, devices, indices, mode)
else:
if dispatch.is_single_device_sharding(x.sharding):
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
# If PmapSharding exists, then do a round trip via host. This will happen
# if the input Array containing PmapSharding takes the jit path
# i.e. `apply_primitive` or `xla_callable_uncached`. `jit(pmap)` is the most
# common case where this will happen.
elif isinstance(x.sharding, PmapSharding):
return pxla.device_put(x._value, devices, replicate=True)
else:
return x._arrays
return _array_rest_shard_arg(x, devices, indices, mode)
pxla.shard_arg_handlers[Array] = _array_shard_arg
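
The PmapSharding branch above covers the situation described in the comment: an Array produced by pmap that is then consumed on the jit path. A minimal sketch of that situation (hypothetical usage, not part of this commit; whether the host round trip is actually taken depends on the backend and device count):

import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 2.0).reshape(n, 2)

# The output of pmap carries a PmapSharding across the local devices.
y = jax.pmap(lambda v: v * 2)(x)

# Consuming it in a jitted computation takes the jit dispatch path referred to
# in the comment; _array_rest_shard_arg then falls back to a host round trip
# (pxla.device_put(x._value, devices, replicate=True)).
z = jax.jit(jnp.sum)(y)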


40 changes: 26 additions & 14 deletions jax/experimental/pjit.py
@@ -581,11 +581,13 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):


def pjit_check_aval_sharding(
shardings: Sequence[XLACompatibleSharding], flat_avals, what_aval: str,
allow_uneven_sharding: bool):
shardings, flat_avals, what_aval: str, allow_uneven_sharding: bool):
for aval, s in zip(flat_avals, shardings):
if _is_unspecified_or_from_gda_or_auto(s):
continue
if not isinstance(s, XLACompatibleSharding):
raise ValueError(f'One of {what_aval} got sharding {s} which is not a '
'subclass of XLACompatibleSharding.')
global_str = "" if s.is_fully_addressable else " global"
shape = aval.shape
try:
@@ -781,24 +783,35 @@ def _check_unique_resources(axis_resources, arg_name):


def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
arg_shardings = tuple(a.sharding if hasattr(a, 'sharding') else _UNSPECIFIED
for a in args)
arg_ndims = tuple(a.ndim if hasattr(a, 'ndim') else 0 for a in args)
committed_arg_shardings = []
for a in args:
if hasattr(a, 'sharding'):
arg_s = a.sharding
if not isinstance(arg_s, XLACompatibleSharding):
raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
'which is not a subclass of XLACompatibleSharding.')
if a._committed:
committed_arg_shardings.append(arg_s)

da = _get_and_check_device_assignment(
it.chain(arg_shardings, pjit_in_shardings, out_shardings), pjit_mesh)
it.chain(
committed_arg_shardings, pjit_in_shardings, out_shardings),
pjit_mesh)

resolved_in_shardings = []
for arg_s, pjit_in_s, ndim in safe_zip(arg_shardings, pjit_in_shardings, arg_ndims):
for arg, pjit_in_s in safe_zip(args, pjit_in_shardings):
arg_s, committed = ((arg.sharding, arg._committed)
if hasattr(arg, 'sharding') else (_UNSPECIFIED, False))
if _is_unspecified(pjit_in_s):
if _is_unspecified(arg_s):
resolved_in_shardings.append(OpShardingSharding.get_replicated(da))
else:
resolved_in_shardings.append(to_op_sharding_sharding(arg_s, ndim))
resolved_in_shardings.append(to_op_sharding_sharding(arg_s, arg.ndim))
else:
if not _is_unspecified(arg_s):
if not pxla.are_op_shardings_equal(
pjit_in_s._to_xla_op_sharding(ndim),
arg_s._to_xla_op_sharding(ndim)):
if committed and not pxla.are_op_shardings_equal(
pjit_in_s._to_xla_op_sharding(arg.ndim),
arg_s._to_xla_op_sharding(arg.ndim)):
raise ValueError('Sharding passed to pjit does not match the sharding '
'on the respective arg. '
f'Got pjit sharding: {pjit_in_s},\n'
@@ -1400,8 +1413,8 @@ def get_array_mapping(
def to_op_sharding_sharding(s, ndim):
if isinstance(s, OpShardingSharding):
return s
op_sharding_sharding = OpShardingSharding(s._device_assignment,
s._to_xla_op_sharding(ndim))
op_sharding_sharding = OpShardingSharding(
s._device_assignment, s._to_xla_op_sharding(ndim)) # type: ignore
op_sharding_sharding._original_sharding = s
return op_sharding_sharding

@@ -1501,7 +1514,6 @@ def _gda_check_and_get_sharding(
return tuple(out)


@lru_cache(maxsize=4096)
def _get_and_check_device_assignment(shardings, pjit_mesh):
first_device_assignment = None
mesh_devices = list(pjit_mesh.devices.flat)
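
A hedged sketch of what the `committed and not pxla.are_op_shardings_equal(...)` change in _resolve_in_shardings means in practice (illustrative only; assumes the 2022-era jax.experimental.pjit API, the jax.experimental.sharding module used by the tests below, and the jax_array flag being enabled):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import maps, pjit, sharding

mesh = maps.Mesh(np.array(jax.devices()), ('x',))
spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('x'))
f = pjit.pjit(lambda a: a + 1, in_axis_resources=spec, out_axis_resources=spec)

n = jax.device_count()

# Uncommitted argument: its current sharding may disagree with `spec`, but the
# op-sharding equality check is now skipped and the input is simply resharded.
f(jnp.arange(n * 2.0))

# Committed argument: the check still applies, so on a multi-device mesh this
# may raise "Sharding passed to pjit does not match the sharding on the
# respective arg."
f(jax.device_put(jnp.arange(n * 2.0), jax.devices()[0]))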
18 changes: 9 additions & 9 deletions tests/debugging_primitives_test.py
@@ -738,12 +738,6 @@ def body(carry):

@jtu.skip_on_devices(*disabled_backends)
def test_unordered_print_of_pjit_of_xmap(self):
# TODO(https://github.com/google/jax/issues/12016): Make xmap work properly
# with Arrays of different
# sharding.
if config.jax_array:
raise unittest.SkipTest('Does not work with Array.')

if (jax.default_backend() in {"cpu", "gpu"}
and jaxlib.xla_extension_version < 81):
raise unittest.SkipTest("`pjit` of callback not supported.")
@@ -756,9 +750,15 @@ def foo(x):
out = maps.xmap(foo, in_axes=['foo'], out_axes=[...])(x)
debug_print("Out: {}", out)
return out
f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('dev'),
out_axis_resources=pjit.PartitionSpec())
with maps.Mesh(np.array(jax.devices()), ['dev']):
mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
if config.jax_array:
in_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))
out_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec())
else:
in_spec = pjit.PartitionSpec('dev')
out_spec = pjit.PartitionSpec()
f = pjit.pjit(f, in_axis_resources=in_spec, out_axis_resources=out_spec)
with mesh:
with jtu.capture_stdout() as output:
f(jnp.arange(8, dtype=jnp.int32) * 2)
lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12",
41 changes: 41 additions & 0 deletions tests/pjit_test.py
@@ -2103,6 +2103,47 @@ def test_fast_path_array(self):
self.assertArraysEqual([o.device() for o in out2._arrays], list(mesh.devices.flat))
self.assertArraysEqual(out2, inp_data)

@jax_array(True)
def test_not_xlacompatible_sharding_error(self):
shape = (8, 2)
inp_data = np.arange(prod(shape)).reshape(shape)
ts = TempSharding(jax.devices())
arr = array.make_array_from_callback(
shape, ts, lambda idx: inp_data[idx])
with self.assertRaisesRegex(
ValueError,
'One of the argument to pjit got sharding.*which is not a subclass of '
'XLACompatibleSharding.'):
pjit(lambda x: x)(arr)

with self.assertRaisesRegex(
ValueError,
'One of pjit arguments got sharding.*which is not a subclass of '
'XLACompatibleSharding.'):
pjit(lambda x: x, in_axis_resources=ts)(arr)

with self.assertRaisesRegex(
ValueError,
'One of pjit outputs got sharding.*which is not a subclass of '
'XLACompatibleSharding.'):
pjit(lambda x: x, out_axis_resources=ts)(arr)


class TempSharding(Sharding):

def __init__(self, devices):
self._devices = devices

@property
def device_set(self):
return set(self._devices)

def devices_indices_map(self, global_shape):
return {d: (slice(None),) * len(global_shape) for d in self.device_set}

def shard_shape(self, global_shape):
return global_shape


def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
