Skip to content

Commit

Permalink
Make __getitem__ work for PmapSharding just like SDA works. DA is already covered with the current implementation.
Browse files Browse the repository at this point in the history

Added TODOs to take fast path for indices wherever it is possible to do that. If a correct index is passed during getitem and if that index exists on `Array`, then the fast path is taken (see the test in this CL).

PiperOrigin-RevId: 473342504
  • Loading branch information
yashk2810 authored and jax authors committed Sep 9, 2022
1 parent 635eebf commit d7726e7
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 23 deletions.
6 changes: 4 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -5107,10 +5107,12 @@ def _set_shaped_array_attributes(shaped_array):
_set_shaped_array_attributes(DShapedArray)


def _set_device_array_base_attributes(device_array, include=None):
def _set_device_array_base_attributes(device_array, include=None, exclude=None):
# Forward operators, methods, and properties on DeviceArray to lax_numpy
# functions (with no Tracers involved; this forwarding is direct)
def maybe_setattr(attr_name, target):
if exclude is not None and attr_name in exclude:
return
if not include or attr_name in include:
setattr(device_array, attr_name, target)

Expand All @@ -5132,7 +5134,7 @@ def maybe_setattr(attr_name, target):
maybe_setattr("clip", _clip)

_set_device_array_base_attributes(device_array.DeviceArray)
_set_device_array_base_attributes(Array)
_set_device_array_base_attributes(Array, exclude={'__getitem__'})


def _set_device_array_attributes(device_array):
Expand Down
46 changes: 34 additions & 12 deletions jax/experimental/array.py
Expand Up @@ -308,25 +308,47 @@ def __format__(self, format_spec):
else:
return format(self._value, format_spec)

@_use_python_method
def __getitem__(self, idx):
  # Index this array. When no host value is cached (`_npy_value` is None) and
  # the rank-padded index equals one of the per-device indices, the matching
  # buffer from `self._arrays` is wrapped in a new uncommitted `Array` placed
  # on that buffer's device. Every other index is forwarded to
  # `lax_numpy._rewriting_take`.
  from jax._src.numpy import lax_numpy
  self._check_if_deleted()

  if self._npy_value is None:
    # Pad the index with full slices up to the array's rank so it can be
    # compared against the per-device indices. This index canonicalization
    # only works for PmapSharding currently.
    # TODO(yashkatariya): Make it work for other Shardings too wherever it's
    # possible to not do data movement.
    given = idx if isinstance(idx, tuple) else (idx,)
    canonical_idx = given + (slice(None),) * (len(self.shape) - len(given))

    fast_path_args = self._fast_path_args
    if fast_path_args is None:
      index_map = self.sharding.devices_indices_map(self.shape)
    else:
      index_map = fast_path_args.devices_indices_map
    per_device_indices = tuple(index_map.values())

    try:
      position = per_device_indices.index(canonical_idx)
    except ValueError:
      # No match: fall through to the general indexing path below.
      pass
    else:
      shard_buf = self._arrays[position]
      shard_aval = core.ShapedArray(
          shard_buf.xla_shape().dimensions(), self.dtype)
      return Array(shard_aval, SingleDeviceSharding(shard_buf.device()),
                   [shard_buf], committed=False, _skip_checks=True)

  return lax_numpy._rewriting_take(self, idx)

@_use_python_method
def __iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
assert self.is_fully_replicated() or self.is_fully_addressable()
# TODO(yashkatariya,mattjj): Handle more cases that can hit the 1 device
# sharding case. This will be possible when __getitem__ is implemented.
# Use DA like __iter__ when sharding is SingleDeviceSharding.
if len(self.sharding.device_set) == 1:
# chunk_iter is added to Array in lax_numpy.py similar to DA.
return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore
elif isinstance(self.sharding, PmapSharding):
return (s.data for s in self.addressable_shards)
# TODO(yashkatariya): Let all shardings take this route? The else path
# is taken by DA (in the old path) so keeping it here until we generalize
# it.
if isinstance(self.sharding, PmapSharding):
return (self[i] for i in range(self.shape[0])) # type: ignore
else:
# TODO(yashkatariya,mattjj): Avoid this round trip and figure out a better
# way to handle this when you have arbitrary sharding.
val = self._value
return (val[i] for i in range(val.shape[0]))
return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore

@_use_python_method
def item(self):
Expand Down
32 changes: 29 additions & 3 deletions tests/array_test.py
Expand Up @@ -14,6 +14,7 @@
"""Tests for GlobalDeviceArray."""

import os
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
Expand Down Expand Up @@ -354,14 +355,24 @@ def test_array_iter_pmap_sharding(self):
y = jax.pmap(jnp.sin)(x)
self.assertArraysEqual([a.device() for a in y],
[a.device() for a in y._arrays])
for a in y:
self.assertIsInstance(a, array.Array)
self.assertEqual(a._committed, y._committed)

sin_x = iter(np.sin(x))
for i, j in zip(iter(y), sin_x):
self.assertIsInstance(i, array.Array)
self.assertArraysAllClose(i, j)

@jax_config.jax_array(True)
def test_array_iter_pmap_sharding_last_dim_sharded(self):
  # Iterating over a pmap result whose mapped axis was placed at position 1
  # (`out_axes=1`) must match iterating over the transposed NumPy reference.
  if jax.device_count() < 2:
    self.skipTest('Test requires >= 2 devices.')

  inputs = jnp.array([[1., 0., 0.], [0., 2., 3.]])
  sharded_out = jax.pmap(jnp.sin, out_axes=1)(inputs)
  reference = np.sin(inputs).T

  for actual, expected in zip(iter(sharded_out), iter(reference)):
    self.assertArraysAllClose(actual, expected)

@unittest.skip('After b/245667823 is fixed, this test can be enabled.')
@jax_config.jax_array(True)
def test_array_iter_mesh_pspec_sharding_multi_device(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
Expand All @@ -372,6 +383,21 @@ def test_array_iter_mesh_pspec_sharding_multi_device(self):
for i, j in zip(iter(arr), iter(input_data)):
self.assertArraysEqual(i, j)

@unittest.skip('After b/245667823 is fixed, this test can be enabled.')
@jax_config.jax_array(True)
def test_array_getitem_mesh_pspec_sharding_multi_device(self):
  # Indexing an (8, 2) array sharded over a (4, 2) mesh along both axes.
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  mesh_sharding = sharding.MeshPspecSharding(mesh, P('x', 'y'))
  arr, _ = create_array((8, 2), mesh_sharding)

  # `__getitem__` with a specific index takes the fast path.
  shard = arr[2:4, 0:1]
  self.assertArraysEqual(shard, np.array([[4], [6]]))

  # TODO(yashkatariya): Add assert equal for this when the test is enabled.
  arr[:2]  # doesn't crash

@jax_config.jax_array(True)
def test_array_iter_mesh_pspec_sharding_single_device(self):
if jax.device_count() < 2:
Expand Down
13 changes: 7 additions & 6 deletions tests/pmap_test.py
Expand Up @@ -2697,11 +2697,6 @@ def testThreadsafeIndexing(self):
self.assertAllClose(actual, expected, check_dtypes=False)

def testNoCopyIndexing1D(self):
# TODO(https://github.com/google/jax/issues/12016): Implement no copy
# indexing similar to SDA.
if config.jax_array:
self.skipTest('No copy indexing is not implemented for Array yet.')

shape = (8, 4)

if jax.device_count() < shape[0]:
Expand All @@ -2710,8 +2705,14 @@ def testNoCopyIndexing1D(self):
x = jnp.arange(prod(shape)).reshape(shape)
sharded_x = pmap(lambda x: x)(x)
self.assertIsNone(sharded_x._npy_value)

if config.jax_array:
arr_type = array.Array
else:
arr_type = device_array.DeviceArray

for i in range(8):
self.assertIsInstance(sharded_x[i], device_array.DeviceArray)
self.assertIsInstance(sharded_x[i], arr_type)
self.assertIsNone(sharded_x._npy_value)

@parameterized.named_parameters(
Expand Down

0 comments on commit d7726e7

Please sign in to comment.