diff --git a/jax/_src/api.py b/jax/_src/api.py
index c7afadb2fb02..ba10b5d67669 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -1904,42 +1904,6 @@ class PmapCallInfo(NamedTuple):
   devices: Optional[Sequence[xc.Device]]
 
 
-def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices):
-  from jax.experimental.sharding import PmapSharding, SingleDeviceSharding
-  from jax.experimental.array import Array
-
-  if not args:
-    return
-
-  first_device_assignment = None
-  for a, i in safe_zip(args, in_axes_flat):
-    if not isinstance(a, Array):
-      continue
-    if isinstance(a.sharding, SingleDeviceSharding):
-      continue
-    if not isinstance(a.sharding, PmapSharding):
-      raise NotImplementedError('pmap only works with PmapSharding.')
-    if first_device_assignment is None:
-      first_device_assignment = a.sharding._device_assignment
-    arr_sharding = a.sharding.sharded_dim
-    arr_device_assignment = a.sharding._device_assignment
-    if arr_sharding != i:
-      raise ValueError('Array and pmap sharding does not match. Got pmap '
-                       f'sharding: {i}, Array sharding: {arr_sharding} for '
-                       f'arg: {a}')
-    if (in_devices is not None and
-        arr_device_assignment is not None and
-        arr_device_assignment != in_devices):
-      raise ValueError('Devices passed to pmap and Array should be equal. '
-                       f'Got pmap devices: {in_devices}, Array devices: '
-                       f'{arr_device_assignment} for arg: {a}')
-    if (in_devices is None and
-        arr_device_assignment != first_device_assignment):
-      raise ValueError('Devices of all `Array` inputs should be the same. '
-                       f'Got array device: {arr_device_assignment}, '
-                       f'another array device: {first_device_assignment}')
-
-
 def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
                   donate_tuple, global_arg_shapes, in_devices, args, kwargs):
   f = lu.wrap_init(fun)
@@ -1982,9 +1946,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
 
   flat_fun, out_tree = flatten_fun(f, in_tree)
 
-  if config.jax_array:
-    _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices)
-
   if any(out_axis is None for out_axis in tree_flatten(out_axes)):
     raise NotImplementedError("None out_axes in pmap are not supported yet")
   # NOTE: We don't put out_tree() in the closure, because it's (1) non-hashable,
diff --git a/jax/experimental/array.py b/jax/experimental/array.py
index ce652f930e97..1c9b8b31a90d 100644
--- a/jax/experimental/array.py
+++ b/jax/experimental/array.py
@@ -120,7 +120,7 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding,
     if config.jax_enable_checks:
       assert all(db.dtype == self.dtype for db in self._arrays), (
           "Input arrays to `Array` must have matching dtypes, "
-          f"got: {[db.dtype for db in self._arrays]}")
+          f"got: {[db.dtype for db in self._arrays]}, aval dtype: {self.dtype}")
 
     # Rearrange arrays based on the device assignment.
     if isinstance(sharding, XLACompatibleSharding):
@@ -378,16 +378,23 @@ def _device_put_array(x, device: Optional[Device]):
 dispatch.device_put_handlers[Array] = _device_put_array
 
 
-def _array_shard_arg(x, devices, indices, mode):
-  # TODO(yashkatariya): Remove the `mode` handling and try to consolidate the
-  # code paths.
-  if mode == pxla.InputsHandlerMode.pmap:
-    # sharding mismatch between `Array` and pmap sharding is checked in api.py's
-    # `_check_in_pmap_sharding_with_arrays` function.
-  if isinstance(x.sharding, SingleDeviceSharding):
-    return pxla._shard_device_array(x, devices, indices, mode)
+def _array_pmap_shard_arg(x, devices, indices, mode):
+  if isinstance(x.sharding, SingleDeviceSharding):
+    return pxla._shard_device_array(x, devices, indices, mode)
+
+  # If the Array's sharding matches what pmap expects, reuse its buffers;
+  # otherwise fall back to a slow path similar to the one SDA takes. This
+  # slow-path reroute only happens for `pmap`.
+  if indices == tuple(x.sharding.devices_indices_map(x.shape).values()):
     return [buf if buf.device() == d else buf.copy_to_device(d)
             for buf, d in safe_zip(x._arrays, devices)]
+  else:
+    return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)
+
+
+def _array_shard_arg(x, devices, indices, mode):
+  if mode == pxla.InputsHandlerMode.pmap:
+    return _array_pmap_shard_arg(x, devices, indices, mode)
   else:
     return x._arrays
 pxla.shard_arg_handlers[Array] = _array_shard_arg
diff --git a/jax/experimental/sharding.py b/jax/experimental/sharding.py
index 66c7aed8c58d..fff17a9bbd08 100644
--- a/jax/experimental/sharding.py
+++ b/jax/experimental/sharding.py
@@ -258,13 +258,6 @@ def __init__(self, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
   def device_set(self) -> Set[Device]:
     return set(self.devices.flat)
 
-  @pxla.maybe_cached_property
-  def sharded_dim(self):
-    for i, s in enumerate(self.sharding_spec.sharding):
-      if isinstance(s, pxla.Unstacked):
-        return i
-    return None
-
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 0cef4ffa784e..715b3a24bdd9 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -853,8 +853,16 @@ def _hashable_index(idx):
 # The fast path is handled directly in shard_args().
 # TODO(skye): is there a simpler way to rewrite this using sharding_spec?
 def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
+  from jax.experimental.array import Array
+
   candidates = defaultdict(list)
-  for buf, idx in safe_zip(x.device_buffers, x.indices):
+  if isinstance(x, Array):
+    bufs = x._arrays
+    arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
+  else:
+    bufs = x.device_buffers
+    arr_indices = x.indices
+  for buf, idx in safe_zip(bufs, arr_indices):
     candidates[_hashable_index(idx)].append(buf)
 
   bufs = []
@@ -977,10 +985,13 @@ def _emap_impl(fun: lu.WrappedFun, *args,
     if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.Array)):
       # We don't want to donate if it's already sharded.
       donate_argnums_ = ()
-      out = jax.pmap(lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)),
-                     out_axes=out_axis, devices=devices, backend=backend,
-                     donate_argnums=donate_argnums_)(
-                         np.arange(axis_size), outval)
+      out = jax.pmap(
+          lambda _, x: x,
+          in_axes=(0, out_axis_src.get(axis_name)),
+          out_axes=out_axis,
+          devices=(None if devices is None else list(devices)),
+          backend=backend,
+          donate_argnums=donate_argnums_)(np.arange(axis_size), outval)
     new_outvals.append(out)
   return new_outvals
 
@@ -1000,8 +1011,13 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
   for i, name in reversed(list(enumerate(names))):
     in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
     if any(in_axis is not None for in_axis in in_axes):
-      f = jax.pmap(f, in_axes=in_axes, axis_name=name, out_axes=0,
-                   backend=info.backend, devices=info.devices)
+      f = jax.pmap(
+          f,
+          in_axes=in_axes,
+          axis_name=name,
+          out_axes=0,
+          backend=info.backend,
+          devices=(None if info.devices is None else list(info.devices)))
       used_names.append(name)
   out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
   return f, out_shard_axes
diff --git a/tests/BUILD b/tests/BUILD
index 2c6bde79df00..88a8de6dcabf 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -528,8 +528,6 @@ jax_test(
 jax_test(
     name = "pmap_test",
     srcs = ["pmap_test.py"],
-    # pmap already has array tests inside it.
-    disable_configs = ["cpu_jax_array"],
     shard_count = {
         "cpu": 15,
         "gpu": 30,
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index 1f8988d15d67..65b7266bf3f3 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -2707,6 +2707,11 @@ def testThreadsafeIndexing(self):
     self.assertAllClose(actual, expected, check_dtypes=False)
 
   def testNoCopyIndexing1D(self):
+    # TODO(https://github.com/google/jax/issues/12016): Implement no-copy
+    # indexing similar to SDA.
+    if config.jax_array:
+      self.skipTest('No-copy indexing is not implemented for Array yet.')
+
     shape = (8, 4)
 
     if jax.device_count() < shape[0]:
@@ -2798,16 +2803,24 @@ def test_device_put_replicated_pytree(self, is_jax_array, array_type, buffer_att
 
   def test_repr(self):
     x = jax.device_put_replicated(1, jax.devices())
-    self.assertStartsWith(repr(x), 'ShardedDeviceArray')
+    if config.jax_array:
+      arr = 'Array'
+    else:
+      arr = 'ShardedDeviceArray'
+    self.assertStartsWith(repr(x), arr)
 
   def test_delete_is_idempotent(self):
     x = jax.device_put_replicated(1, jax.devices())
     x.delete()
     x.delete()
 
-    with self.assertRaisesRegex(ValueError,
-                                'ShardedDeviceArray has been deleted.'):
-      _ = x[0]
+    if config.jax_array:
+      with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
+        _ = x[0]
+    else:
+      with self.assertRaisesRegex(ValueError,
+                                  'ShardedDeviceArray has been deleted.'):
+        _ = x[0]
 
 
 class SpecToIndicesTest(jtu.JaxTestCase):
@@ -3075,66 +3088,33 @@ def test_pmap_array_sharding_mismatch(self):
     f = jax.pmap(lambda x: x, in_axes=0, out_axes=0)
     with jax._src.config.jax_array(True):
-      with self.assertRaisesRegex(
-          ValueError,
-          ("Array and pmap sharding does not match. Got pmap sharding: 0, "
-           "Array sharding: None")):
-        f(a1)
+      out_array = f(a1)
 
-  def test_pmap_array_devices_mismatch(self):
-    if jax.device_count() <= 1:
-      raise unittest.SkipTest('Skipping because this test needs more than '
-                              '1 device.')
-    input_shape = (jax.device_count(), 2)
-    a1, _ = create_input_array_for_pmap(input_shape)
+    with jax._src.config.jax_array(False):
+      out_sda = f(a1)
 
-    f = jax.pmap(lambda x: x, devices=jax.devices()[::-1])
-    with jax._src.config.jax_array(True):
-      with self.assertRaisesRegex(
-          ValueError, "Devices passed to pmap and Array should be equal."):
-        f(a1)
+    self.assertEqual(out_array.sharding.sharding_spec, out_sda.sharding_spec)
+    self.assertArraysEqual(out_array.sharding.devices,
+                           [d.device() for d in out_sda.device_buffers])
 
-  def test_pmap_array_devices_mismatch_between_arrays(self):
+  def test_pmap_array_devices_mismatch(self):
     if jax.device_count() <= 1:
       raise unittest.SkipTest('Skipping because this test needs more than '
                               '1 device.')
     input_shape = (jax.device_count(), 2)
     a1, _ = create_input_array_for_pmap(input_shape)
-    a2, _ = create_input_array_for_pmap(input_shape, devices=jax.devices()[::-1])
 
-    f = jax.pmap(lambda x, y: (x, y))
+    f = jax.pmap(lambda x: x, devices=jax.devices()[::-1])
     with jax._src.config.jax_array(True):
-      with self.assertRaisesRegex(
-          ValueError, "Devices of all `Array` inputs should be the same."):
-        f(a1, a2)
+      out_array = f(a1)
 
+    with jax._src.config.jax_array(False):
+      out_sda = f(a1)
 
-class ArrayPmapMixin:
+    self.assertEqual(out_array.sharding.sharding_spec, out_sda.sharding_spec)
+    self.assertArraysEqual(out_array.sharding.devices,
+                           [d.device() for d in out_sda.device_buffers])
 
-  def setUp(self):
-    super().setUp()
-    self.array_enabled = config.jax_array
-    config.update('jax_array', True)
-
-  def tearDown(self):
-    config.update('jax_array', self.array_enabled)
-    super().tearDown()
-
-
-class ArrayPythonPmapTest(ArrayPmapMixin, PythonPmapTest):
-  pass
-
-class ArrayCppPmapTest(ArrayPmapMixin, CppPmapTest):
-  pass
-
-class ArrayVmapOfPmapTest(ArrayPmapMixin, VmapOfPmapTest):
-  pass
-
-class ArrayVmapPmapCollectivesTest(ArrayPmapMixin, VmapPmapCollectivesTest):
-  pass
-
-class ArrayPmapWithDevicesTest(ArrayPmapMixin, PmapWithDevicesTest):
-  pass
 
 class EagerPmapMixin:
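
Reviewer note, not part of the patch: a minimal sketch of the behavior this change enables. With `jax_array` enabled, `pmap` no longer rejects an `Array` whose devices differ from the ones passed to `pmap`; the argument is rerouted through the SDA-style slow path in `_shard_sharded_device_array_slow_path` and resharded. The sketch assumes more than one device is available; `device_put_replicated` is the same helper the tests above use.

import jax

# Replicate a scalar across all devices. Under `jax_array` this yields a
# jax.experimental.array.Array backed by a PmapSharding; otherwise an SDA.
x = jax.device_put_replicated(1, jax.devices())

# pmap over the reversed device order. Before this change, the device
# mismatch between the Array and pmap raised ValueError ("Devices passed to
# pmap and Array should be equal."); now the input takes the slow path and
# is resharded onto pmap's devices.
f = jax.pmap(lambda v: v + 1, devices=jax.devices()[::-1])
out = f(x)

The design choice: instead of validating `Array` shardings up front in `_prepare_pmap`, mismatches are handled at argument-sharding time in `_array_pmap_shard_arg`, which compares `indices` against `x.sharding.devices_indices_map(x.shape)`. That is what allows `_check_in_pmap_sharding_with_arrays` and `PmapSharding.sharded_dim` to be deleted, and it is why the two rewritten tests assert that the `Array` and SDA code paths produce identically sharded outputs instead of expecting errors.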