diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 16f5559986c7..3623ea9b0830 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5107,10 +5107,12 @@ def _set_shaped_array_attributes(shaped_array): _set_shaped_array_attributes(DShapedArray) -def _set_device_array_base_attributes(device_array, include=None): +def _set_device_array_base_attributes(device_array, include=None, exclude=None): # Forward operators, methods, and properties on DeviceArray to lax_numpy # functions (with no Tracers involved; this forwarding is direct) def maybe_setattr(attr_name, target): + if exclude is not None and attr_name in exclude: + return if not include or attr_name in include: setattr(device_array, attr_name, target) @@ -5132,7 +5134,7 @@ def maybe_setattr(attr_name, target): maybe_setattr("clip", _clip) _set_device_array_base_attributes(device_array.DeviceArray) -_set_device_array_base_attributes(Array) +_set_device_array_base_attributes(Array, exclude={'__getitem__'}) def _set_device_array_attributes(device_array): diff --git a/jax/experimental/array.py b/jax/experimental/array.py index c6b19817fc81..cec2c544e77f 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -308,25 +308,47 @@ def __format__(self, format_spec): else: return format(self._value, format_spec) + @_use_python_method + def __getitem__(self, idx): + from jax._src.numpy import lax_numpy + self._check_if_deleted() + + # This index canonicalization only works for PmapSharding currently. + # TODO(yashkatariya): Make it work for other Shardings too wherever its + # possible to not do data movement. + if not isinstance(idx, tuple): + cidx = (idx,) + (slice(None),) * (len(self.shape) - 1) + else: + cidx = idx + (slice(None),) * (len(self.shape) - len(idx)) + if self._npy_value is None: + if self._fast_path_args is None: + indices = tuple(self.sharding.devices_indices_map(self.shape).values()) + else: + indices = tuple(self._fast_path_args.devices_indices_map.values()) + try: + buf_idx = indices.index(cidx) + except ValueError: + buf_idx = None + if buf_idx is not None: + buf = self._arrays[buf_idx] + aval = core.ShapedArray(buf.xla_shape().dimensions(), self.dtype) + return Array(aval, SingleDeviceSharding(buf.device()), [buf], + committed=False, _skip_checks=True) + return lax_numpy._rewriting_take(self, idx) + @_use_python_method def __iter__(self): if self.ndim == 0: raise TypeError("iteration over a 0-d array") # same as numpy error else: assert self.is_fully_replicated() or self.is_fully_addressable() - # TODO(yashkatariya,mattjj): Handle more cases that can hit the 1 device - # sharding case. This will be possible when __getitem__ is implemented. - # Use DA like __iter__ when sharding is SingleDeviceSharding. - if len(self.sharding.device_set) == 1: - # chunk_iter is added to Array in lax_numpy.py similar to DA. - return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore - elif isinstance(self.sharding, PmapSharding): - return (s.data for s in self.addressable_shards) + # TODO(yashkatariya): Let all shardings take this route? The else path + # is taken by DA (in the old path) so keeping it here until we generalize + # it. + if isinstance(self.sharding, PmapSharding): + return (self[i] for i in range(self.shape[0])) # type: ignore else: - # TODO(yashkatariya,mattjj): Avoid this round trip and figure out a better - # way to handle this when you have arbitrary sharding. - val = self._value - return (val[i] for i in range(val.shape[0])) + return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore @_use_python_method def item(self): diff --git a/tests/array_test.py b/tests/array_test.py index 60cf3d390291..a30af361d942 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -14,6 +14,7 @@ """Tests for GlobalDeviceArray.""" import os +import unittest from absl.testing import absltest from absl.testing import parameterized import numpy as np @@ -354,14 +355,24 @@ def test_array_iter_pmap_sharding(self): y = jax.pmap(jnp.sin)(x) self.assertArraysEqual([a.device() for a in y], [a.device() for a in y._arrays]) - for a in y: - self.assertIsInstance(a, array.Array) - self.assertEqual(a._committed, y._committed) sin_x = iter(np.sin(x)) for i, j in zip(iter(y), sin_x): + self.assertIsInstance(i, array.Array) self.assertArraysAllClose(i, j) + @jax_config.jax_array(True) + def test_array_iter_pmap_sharding_last_dim_sharded(self): + if jax.device_count() < 2: + self.skipTest('Test requires >= 2 devices.') + + x = jnp.array([[1., 0., 0.], [0., 2., 3.]]) + y = jax.pmap(jnp.sin, out_axes=1)(x) + + for i, j in zip(iter(y), iter(np.sin(x).T)): + self.assertArraysAllClose(i, j) + + @unittest.skip('After b/245667823 is fixed, this test can be enabled.') @jax_config.jax_array(True) def test_array_iter_mesh_pspec_sharding_multi_device(self): global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) @@ -372,6 +383,21 @@ def test_array_iter_mesh_pspec_sharding_multi_device(self): for i, j in zip(iter(arr), iter(input_data)): self.assertArraysEqual(i, j) + @unittest.skip('After b/245667823 is fixed, this test can be enabled.') + @jax_config.jax_array(True) + def test_array_getitem_mesh_pspec_sharding_multi_device(self): + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + input_shape = (8, 2) + arr, _ = create_array( + input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y'))) + + # `__getitem__` with a specific index takes the fast path. + s = arr[2:4, 0:1] + self.assertArraysEqual(s, np.array([[4], [6]])) + + # TODO(yashkatariya): Add assert equal for this when the test is enabled. + arr[:2] # doesn't crash + @jax_config.jax_array(True) def test_array_iter_mesh_pspec_sharding_single_device(self): if jax.device_count() < 2: diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 1fa0fcc8f2aa..2ea8d48cfbcb 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2697,11 +2697,6 @@ def testThreadsafeIndexing(self): self.assertAllClose(actual, expected, check_dtypes=False) def testNoCopyIndexing1D(self): - # TODO(https://github.com/google/jax/issues/12016): Implement no copy - # indexing similar to SDA. - if config.jax_array: - self.skipTest('No copy indexing is not implemented for Array yet.') - shape = (8, 4) if jax.device_count() < shape[0]: @@ -2710,8 +2705,14 @@ def testNoCopyIndexing1D(self): x = jnp.arange(prod(shape)).reshape(shape) sharded_x = pmap(lambda x: x)(x) self.assertIsNone(sharded_x._npy_value) + + if config.jax_array: + arr_type = array.Array + else: + arr_type = device_array.DeviceArray + for i in range(8): - self.assertIsInstance(sharded_x[i], device_array.DeviceArray) + self.assertIsInstance(sharded_x[i], arr_type) self.assertIsNone(sharded_x._npy_value) @parameterized.named_parameters(