From 60ec5414ce4c4e0867ba3ce4f10e23833b8beedd Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 15 Sep 2022 10:33:31 -0700
Subject: [PATCH] Enable the debugging_primitives pjit(xmap) case. Also, don't
 check for a sharding mismatch when the array is not committed, and check the
 device assignment only for committed arrays.

PiperOrigin-RevId: 474598597
---
 jax/experimental/array.py          | 45 ++++++++++++++++++++++--------
 jax/experimental/pjit.py           | 40 ++++++++++++++++----------
 tests/debugging_primitives_test.py | 18 ++++++------
 tests/pjit_test.py                 | 41 +++++++++++++++++++++++++++
 4 files changed, 110 insertions(+), 34 deletions(-)

diff --git a/jax/experimental/array.py b/jax/experimental/array.py
index f4f3f572a80f..ac87a020c92d 100644
--- a/jax/experimental/array.py
+++ b/jax/experimental/array.py
@@ -524,21 +524,44 @@ def _array_pmap_shard_arg(x, devices, indices, mode):
     return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
 
 
+def _array_rest_shard_arg(x, devices, indices, mode):
+  if not x._committed:
+    if dispatch.is_single_device_sharding(x.sharding):
+      # This check breaks the recursion that would otherwise happen when
+      # only `pxla._shard_device_array` is used: its implementation calls
+      # `_multi_slice`, which is jitted and therefore eventually dispatches
+      # back into this handler.
+      if len(devices) == 1:
+        return [buf if buf.device() == d else buf.copy_to_device(d)
+                for buf, d in safe_zip(x._arrays, devices)]
+      return pxla._shard_device_array(x, devices, indices, mode)
+    else:
+      raise NotImplementedError('Resharding uncommitted arrays sharded over '
+                                'multiple devices is not supported.')
+  # TODO(yashkatariya): Remove the special case here and don't move to another
+  # device if it's already committed. There is a TODO in dispatch.py already
+  # for this.
+  if dispatch.is_single_device_sharding(x.sharding):
+    return [buf if buf.device() == d else buf.copy_to_device(d)
+            for buf, d in safe_zip(x._arrays, devices)]
+  # If PmapSharding exists, then do a round trip via host. This will happen
+  # if the input Array containing PmapSharding takes the jit path,
+  # i.e. `apply_primitive` or `xla_callable_uncached`. `jit(pmap)` is the most
+  # common case where this will happen.
+  # TODO(yashkatariya): Remove the special case here and don't move to another
+  # device if it's already committed. There is a TODO in dispatch.py already
+  # for this.
+  elif isinstance(x.sharding, PmapSharding):
+    return pxla.device_put(x._value, devices, replicate=True)
+  else:
+    return x._arrays
+
+
 def _array_shard_arg(x, devices, indices, mode):
   if mode == pxla.InputsHandlerMode.pmap:
     return _array_pmap_shard_arg(x, devices, indices, mode)
   else:
-    if dispatch.is_single_device_sharding(x.sharding):
-      return [buf if buf.device() == d else buf.copy_to_device(d)
-              for buf, d in safe_zip(x._arrays, devices)]
-    # If PmapSharding exists, then do a round trip via host. This will happen
-    # if the input Array containing PmapSharding takes the jit path
-    # i.e. `apply_primitive` or `xla_callable_uncached`. `jit(pmap)` is the most
-    # common case where this will happen.
-    elif isinstance(x.sharding, PmapSharding):
-      return pxla.device_put(x._value, devices, replicate=True)
-    else:
-      return x._arrays
+    return _array_rest_shard_arg(x, devices, indices, mode)
 
 
 pxla.shard_arg_handlers[Array] = _array_shard_arg
 
diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index 7182f99d442c..e042a424ca91 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -581,11 +581,13 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
 
 
 def pjit_check_aval_sharding(
-    shardings: Sequence[XLACompatibleSharding], flat_avals, what_aval: str,
-    allow_uneven_sharding: bool):
+    shardings, flat_avals, what_aval: str, allow_uneven_sharding: bool):
   for aval, s in zip(flat_avals, shardings):
     if _is_unspecified_or_from_gda_or_auto(s):
       continue
+    if not isinstance(s, XLACompatibleSharding):
+      raise ValueError(f'One of {what_aval} got sharding {s} which is not a '
+                       'subclass of XLACompatibleSharding.')
     global_str = "" if s.is_fully_addressable else " global"
     shape = aval.shape
     try:
@@ -781,24 +783,35 @@ def _check_unique_resources(axis_resources, arg_name):
 
 
 def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
-  arg_shardings = tuple(a.sharding if hasattr(a, 'sharding') else _UNSPECIFIED
-                        for a in args)
-  arg_ndims = tuple(a.ndim if hasattr(a, 'ndim') else 0 for a in args)
+  committed_arg_shardings = []
+  for a in args:
+    if hasattr(a, 'sharding'):
+      arg_s = a.sharding
+      if not isinstance(arg_s, XLACompatibleSharding):
+        raise ValueError(f'One of the arguments to pjit got sharding {arg_s} '
+                         'which is not a subclass of XLACompatibleSharding.')
+      if a._committed:
+        committed_arg_shardings.append(arg_s)
+
   da = _get_and_check_device_assignment(
-      it.chain(arg_shardings, pjit_in_shardings, out_shardings), pjit_mesh)
+      it.chain(
+          committed_arg_shardings, pjit_in_shardings, out_shardings),
+      pjit_mesh)
 
   resolved_in_shardings = []
-  for arg_s, pjit_in_s, ndim in safe_zip(arg_shardings, pjit_in_shardings, arg_ndims):
+  for arg, pjit_in_s in safe_zip(args, pjit_in_shardings):
+    arg_s, committed = ((arg.sharding, arg._committed)
+                        if hasattr(arg, 'sharding') else (_UNSPECIFIED, False))
     if _is_unspecified(pjit_in_s):
       if _is_unspecified(arg_s):
         resolved_in_shardings.append(OpShardingSharding.get_replicated(da))
       else:
-        resolved_in_shardings.append(to_op_sharding_sharding(arg_s, ndim))
+        resolved_in_shardings.append(to_op_sharding_sharding(arg_s, arg.ndim))
     else:
       if not _is_unspecified(arg_s):
-        if not pxla.are_op_shardings_equal(
-            pjit_in_s._to_xla_op_sharding(ndim),
-            arg_s._to_xla_op_sharding(ndim)):
+        if committed and not pxla.are_op_shardings_equal(
+            pjit_in_s._to_xla_op_sharding(arg.ndim),
+            arg_s._to_xla_op_sharding(arg.ndim)):
           raise ValueError('Sharding passed to pjit does not match the sharding '
                            'on the respective arg. '
                           f'Got pjit sharding: {pjit_in_s},\n'
@@ -1400,8 +1413,8 @@ def get_array_mapping(
 def to_op_sharding_sharding(s, ndim):
   if isinstance(s, OpShardingSharding):
     return s
-  op_sharding_sharding = OpShardingSharding(s._device_assignment,
-                                            s._to_xla_op_sharding(ndim))
+  op_sharding_sharding = OpShardingSharding(
+      s._device_assignment, s._to_xla_op_sharding(ndim))  # type: ignore
   op_sharding_sharding._original_sharding = s
   return op_sharding_sharding
 
@@ -1501,7 +1514,6 @@ def _gda_check_and_get_sharding(
   return tuple(out)
 
 
-@lru_cache(maxsize=4096)
 def _get_and_check_device_assignment(shardings, pjit_mesh):
   first_device_assignment = None
   mesh_devices = list(pjit_mesh.devices.flat)
diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py
index 2893dd069fba..edbd9cf51a12 100644
--- a/tests/debugging_primitives_test.py
+++ b/tests/debugging_primitives_test.py
@@ -738,12 +738,6 @@ def body(carry):
 
   @jtu.skip_on_devices(*disabled_backends)
   def test_unordered_print_of_pjit_of_xmap(self):
-    # TODO(https://github.com/google/jax/issues/12016): Make xmap work properly
-    # with Arrays of different
-    # sharding.
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with Array.')
-
     if (jax.default_backend() in {"cpu", "gpu"} and
         jaxlib.xla_extension_version < 81):
       raise unittest.SkipTest("`pjit` of callback not supported.")
@@ -756,9 +750,15 @@ def foo(x):
       out = maps.xmap(foo, in_axes=['foo'], out_axes=[...])(x)
       debug_print("Out: {}", out)
       return out
-    f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('dev'),
-                  out_axis_resources=pjit.PartitionSpec())
-    with maps.Mesh(np.array(jax.devices()), ['dev']):
+    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+    if config.jax_array:
+      in_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))
+      out_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec())
+    else:
+      in_spec = pjit.PartitionSpec('dev')
+      out_spec = pjit.PartitionSpec()
+    f = pjit.pjit(f, in_axis_resources=in_spec, out_axis_resources=out_spec)
+    with mesh:
       with jtu.capture_stdout() as output:
         f(jnp.arange(8, dtype=jnp.int32) * 2)
       lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12",
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 5bd2defe92cf..2acc0b42e7e1 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -2103,6 +2103,47 @@ def test_fast_path_array(self):
     self.assertArraysEqual([o.device() for o in out2._arrays],
                            list(mesh.devices.flat))
     self.assertArraysEqual(out2, inp_data)
 
+  @jax_array(True)
+  def test_not_xlacompatible_sharding_error(self):
+    shape = (8, 2)
+    inp_data = np.arange(prod(shape)).reshape(shape)
+    ts = TempSharding(jax.devices())
+    arr = array.make_array_from_callback(
+        shape, ts, lambda idx: inp_data[idx])
+    with self.assertRaisesRegex(
+        ValueError,
+        'One of the arguments to pjit got sharding.*which is not a subclass of '
+        'XLACompatibleSharding.'):
+      pjit(lambda x: x)(arr)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        'One of pjit arguments got sharding.*which is not a subclass of '
+        'XLACompatibleSharding.'):
+      pjit(lambda x: x, in_axis_resources=ts)(arr)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        'One of pjit outputs got sharding.*which is not a subclass of '
+        'XLACompatibleSharding.'):
+      pjit(lambda x: x, out_axis_resources=ts)(arr)
+
+
+class TempSharding(Sharding):
+
+  def __init__(self, devices):
+    self._devices = devices
+
+  @property
+  def device_set(self):
+    return set(self._devices)
+
+  def devices_indices_map(self, global_shape):
+    return {d: (slice(None),) * len(global_shape) for d in self.device_set}
+
+  def shard_shape(self, global_shape):
+    return global_shape
+
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
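
A minimal sketch (not part of the patch) of the user-visible effect of the
committed/uncommitted distinction, placed after the diff so the patch itself
stays intact. It assumes `config.jax_array` is enabled, at least two devices
whose count divides 8, and this era's `jax.experimental` APIs (`maps.Mesh`,
`sharding.MeshPspecSharding`, `pjit.pjit`); the names `mesh`, `spec`, and `f`
are illustrative only, and the committed case may surface as either a
device-assignment or a sharding-mismatch ValueError.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental import maps, pjit, sharding

    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
    spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))
    f = pjit.pjit(lambda x: x + 1, in_axis_resources=spec,
                  out_axis_resources=spec)

    with mesh:
      # Uncommitted: lives on the default device only. Its sharding is now
      # excluded from the device-assignment check and from the pjit-vs-arg
      # mismatch check, so pjit may freely reshard it to `spec`.
      f(jnp.arange(8, dtype=jnp.float32))  # OK

      # Committed: device_put pins the array to a single device, so its
      # sharding participates in both checks and conflicts with `spec`.
      committed = jax.device_put(np.arange(8, dtype=np.float32),
                                 jax.devices()[0])
      f(committed)  # expected to raise a ValueError

The essential change is that only committed argument shardings feed
`_get_and_check_device_assignment`, so an uncommitted input no longer forces
its single-device assignment onto the whole computation.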