From d7726e7b26a5002e7f7403237cc8a152de71671e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 9 Sep 2022 14:24:39 -0700 Subject: [PATCH] Make `__getitem__` work for PmapSharding just like SDA works. DA is already covered with the current implementation. Added TODOs to take fast path for indices wherever it is possible to do that. If a correct index is passed during getitem and if that index exists on `Array`, then the fast path is taken (see the test in this CL). PiperOrigin-RevId: 473342504 --- jax/_src/numpy/lax_numpy.py | 6 +++-- jax/experimental/array.py | 46 +++++++++++++++++++++++++++---------- tests/array_test.py | 32 +++++++++++++++++++++++--- tests/pmap_test.py | 13 ++++++----- 4 files changed, 74 insertions(+), 23 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 16f5559986c7..3623ea9b0830 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5107,10 +5107,12 @@ def _set_shaped_array_attributes(shaped_array): _set_shaped_array_attributes(DShapedArray) -def _set_device_array_base_attributes(device_array, include=None): +def _set_device_array_base_attributes(device_array, include=None, exclude=None): # Forward operators, methods, and properties on DeviceArray to lax_numpy # functions (with no Tracers involved; this forwarding is direct) def maybe_setattr(attr_name, target): + if exclude is not None and attr_name in exclude: + return if not include or attr_name in include: setattr(device_array, attr_name, target) @@ -5132,7 +5134,7 @@ def maybe_setattr(attr_name, target): maybe_setattr("clip", _clip) _set_device_array_base_attributes(device_array.DeviceArray) -_set_device_array_base_attributes(Array) +_set_device_array_base_attributes(Array, exclude={'__getitem__'}) def _set_device_array_attributes(device_array): diff --git a/jax/experimental/array.py b/jax/experimental/array.py index c6b19817fc81..cec2c544e77f 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -308,25 +308,47 @@ def __format__(self, format_spec): else: return format(self._value, format_spec) + @_use_python_method + def __getitem__(self, idx): + from jax._src.numpy import lax_numpy + self._check_if_deleted() + + # This index canonicalization only works for PmapSharding currently. + # TODO(yashkatariya): Make it work for other Shardings too wherever its + # possible to not do data movement. + if not isinstance(idx, tuple): + cidx = (idx,) + (slice(None),) * (len(self.shape) - 1) + else: + cidx = idx + (slice(None),) * (len(self.shape) - len(idx)) + if self._npy_value is None: + if self._fast_path_args is None: + indices = tuple(self.sharding.devices_indices_map(self.shape).values()) + else: + indices = tuple(self._fast_path_args.devices_indices_map.values()) + try: + buf_idx = indices.index(cidx) + except ValueError: + buf_idx = None + if buf_idx is not None: + buf = self._arrays[buf_idx] + aval = core.ShapedArray(buf.xla_shape().dimensions(), self.dtype) + return Array(aval, SingleDeviceSharding(buf.device()), [buf], + committed=False, _skip_checks=True) + return lax_numpy._rewriting_take(self, idx) + @_use_python_method def __iter__(self): if self.ndim == 0: raise TypeError("iteration over a 0-d array") # same as numpy error else: assert self.is_fully_replicated() or self.is_fully_addressable() - # TODO(yashkatariya,mattjj): Handle more cases that can hit the 1 device - # sharding case. This will be possible when __getitem__ is implemented. - # Use DA like __iter__ when sharding is SingleDeviceSharding. - if len(self.sharding.device_set) == 1: - # chunk_iter is added to Array in lax_numpy.py similar to DA. - return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore - elif isinstance(self.sharding, PmapSharding): - return (s.data for s in self.addressable_shards) + # TODO(yashkatariya): Let all shardings take this route? The else path + # is taken by DA (in the old path) so keeping it here until we generalize + # it. + if isinstance(self.sharding, PmapSharding): + return (self[i] for i in range(self.shape[0])) # type: ignore else: - # TODO(yashkatariya,mattjj): Avoid this round trip and figure out a better - # way to handle this when you have arbitrary sharding. - val = self._value - return (val[i] for i in range(val.shape[0])) + return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore @_use_python_method def item(self): diff --git a/tests/array_test.py b/tests/array_test.py index 60cf3d390291..a30af361d942 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -14,6 +14,7 @@ """Tests for GlobalDeviceArray.""" import os +import unittest from absl.testing import absltest from absl.testing import parameterized import numpy as np @@ -354,14 +355,24 @@ def test_array_iter_pmap_sharding(self): y = jax.pmap(jnp.sin)(x) self.assertArraysEqual([a.device() for a in y], [a.device() for a in y._arrays]) - for a in y: - self.assertIsInstance(a, array.Array) - self.assertEqual(a._committed, y._committed) sin_x = iter(np.sin(x)) for i, j in zip(iter(y), sin_x): + self.assertIsInstance(i, array.Array) self.assertArraysAllClose(i, j) + @jax_config.jax_array(True) + def test_array_iter_pmap_sharding_last_dim_sharded(self): + if jax.device_count() < 2: + self.skipTest('Test requires >= 2 devices.') + + x = jnp.array([[1., 0., 0.], [0., 2., 3.]]) + y = jax.pmap(jnp.sin, out_axes=1)(x) + + for i, j in zip(iter(y), iter(np.sin(x).T)): + self.assertArraysAllClose(i, j) + + @unittest.skip('After b/245667823 is fixed, this test can be enabled.') @jax_config.jax_array(True) def test_array_iter_mesh_pspec_sharding_multi_device(self): global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) @@ -372,6 +383,21 @@ def test_array_iter_mesh_pspec_sharding_multi_device(self): for i, j in zip(iter(arr), iter(input_data)): self.assertArraysEqual(i, j) + @unittest.skip('After b/245667823 is fixed, this test can be enabled.') + @jax_config.jax_array(True) + def test_array_getitem_mesh_pspec_sharding_multi_device(self): + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + input_shape = (8, 2) + arr, _ = create_array( + input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y'))) + + # `__getitem__` with a specific index takes the fast path. + s = arr[2:4, 0:1] + self.assertArraysEqual(s, np.array([[4], [6]])) + + # TODO(yashkatariya): Add assert equal for this when the test is enabled. + arr[:2] # doesn't crash + @jax_config.jax_array(True) def test_array_iter_mesh_pspec_sharding_single_device(self): if jax.device_count() < 2: diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 1fa0fcc8f2aa..2ea8d48cfbcb 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2697,11 +2697,6 @@ def testThreadsafeIndexing(self): self.assertAllClose(actual, expected, check_dtypes=False) def testNoCopyIndexing1D(self): - # TODO(https://github.com/google/jax/issues/12016): Implement no copy - # indexing similar to SDA. - if config.jax_array: - self.skipTest('No copy indexing is not implemented for Array yet.') - shape = (8, 4) if jax.device_count() < shape[0]: @@ -2710,8 +2705,14 @@ def testNoCopyIndexing1D(self): x = jnp.arange(prod(shape)).reshape(shape) sharded_x = pmap(lambda x: x)(x) self.assertIsNone(sharded_x._npy_value) + + if config.jax_array: + arr_type = array.Array + else: + arr_type = device_array.DeviceArray + for i in range(8): - self.assertIsInstance(sharded_x[i], device_array.DeviceArray) + self.assertIsInstance(sharded_x[i], arr_type) self.assertIsNone(sharded_x._npy_value) @parameterized.named_parameters(