Make eager pmap tests pass with Array. Also add a slow path for Array in `pmap`, similar to what SDA has.

This is required for eager pmap. Adding the slow path removes the need for the sharding checks in api.py: SDA doesn't do those checks, and when an input's sharding does not match the pmap sharding, it simply falls back to the slow path (exactly like SDA).

PiperOrigin-RevId: 468843310
yashk2810 authored and jax authors committed Aug 20, 2022
1 parent 3a2f25f commit f905d98
Showing 6 changed files with 70 additions and 115 deletions.
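The commit message above describes the user-visible effect of dropping the api.py checks in favor of a slow-path fallback. Below is a minimal sketch (not part of this commit) of that behavior; it assumes `jax.device_put_replicated` yields a pmap-sharded `Array` when the `jax_array` flag is enabled, as the updated `test_repr` further down suggests, and uses only public `jax` APIs.

import jax
import numpy as np

# Assumption: toggling the flag globally, the same way the updated tests do
# via config.update('jax_array', True).
jax.config.update('jax_array', True)

devices = jax.devices()
# Replicate a small array across all devices; the leading axis has one entry
# per device.
x = jax.device_put_replicated(np.arange(2.0), devices)

# Pass the devices in reverse order so the input's placement does not line up
# with what pmap expects.
f = jax.pmap(lambda v: v * 2, devices=devices[::-1])

# Before this commit: ValueError from the check removed in jax/_src/api.py
# ("Devices passed to pmap and Array should be equal. ...").
# After this commit: the input is rerouted through the SDA-style slow path in
# jax/experimental/array.py and the call succeeds.
y = f(x)
print(y.shape)  # (len(devices), 2)

With a single device the input's placement already matches, so the sketch only exercises the fast path in that case.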
39 changes: 0 additions & 39 deletions jax/_src/api.py
@@ -1904,42 +1904,6 @@ class PmapCallInfo(NamedTuple):
devices: Optional[Sequence[xc.Device]]


def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices):
from jax.experimental.sharding import PmapSharding, SingleDeviceSharding
from jax.experimental.array import Array

if not args:
return

first_device_assignment = None
for a, i in safe_zip(args, in_axes_flat):
if not isinstance(a, Array):
continue
if isinstance(a.sharding, SingleDeviceSharding):
continue
if not isinstance(a.sharding, PmapSharding):
raise NotImplementedError('pmap only works with PmapSharding.')
if first_device_assignment is None:
first_device_assignment = a.sharding._device_assignment
arr_sharding = a.sharding.sharded_dim
arr_device_assignment = a.sharding._device_assignment
if arr_sharding != i:
raise ValueError('Array and pmap sharding does not match. Got pmap '
f'sharding: {i}, Array sharding: {arr_sharding} for '
f'arg: {a}')
if (in_devices is not None and
arr_device_assignment is not None and
arr_device_assignment != in_devices):
raise ValueError('Devices passed to pmap and Array should be equal. '
f'Got pmap devices: {in_devices}, Array devices: '
f'{arr_device_assignment} for arg: {a}')
if (in_devices is None and
arr_device_assignment != first_device_assignment):
raise ValueError('Devices of all `Array` inputs should be the same. '
f'Got array device: {arr_device_assignment}, '
f'another array device: {first_device_assignment}')


def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, global_arg_shapes, in_devices, args, kwargs):
f = lu.wrap_init(fun)
@@ -1982,9 +1946,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,

flat_fun, out_tree = flatten_fun(f, in_tree)

if config.jax_array:
_check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices)

if any(out_axis is None for out_axis in tree_flatten(out_axes)):
raise NotImplementedError("None out_axes in pmap are not supported yet")
# NOTE: We don't put out_tree() in the closure, because it's (1) non-hashable,
25 changes: 16 additions & 9 deletions jax/experimental/array.py
@@ -120,7 +120,7 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding,
if config.jax_enable_checks:
assert all(db.dtype == self.dtype for db in self._arrays), (
"Input arrays to `Array` must have matching dtypes, "
f"got: {[db.dtype for db in self._arrays]}")
f"got: {[db.dtype for db in self._arrays]}, aval type: {self.dtype}")

# Rearrange arrays based on the device assignment.
if isinstance(sharding, XLACompatibleSharding):
@@ -378,16 +378,23 @@ def _device_put_array(x, device: Optional[Device]):
dispatch.device_put_handlers[Array] = _device_put_array


def _array_shard_arg(x, devices, indices, mode):
# TODO(yashkatariya): Remove the `mode` handling and try to consolidate the
# code paths.
if mode == pxla.InputsHandlerMode.pmap:
# sharding mismatch between `Array` and pmap sharding is checked in api.py's
# `_check_in_pmap_sharding_with_arrays` function.
if isinstance(x.sharding, SingleDeviceSharding):
return pxla._shard_device_array(x, devices, indices, mode)
def _array_pmap_shard_arg(x, devices, indices, mode):
if isinstance(x.sharding, SingleDeviceSharding):
return pxla._shard_device_array(x, devices, indices, mode)

# If the sharding of Array does not match pmap's sharding then take the slow
# path which is similar to what SDA does. This slow path reroute only happens
# for `pmap`.
if indices == tuple(x.sharding.devices_indices_map(x.shape).values()):
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
else:
return pxla._shard_sharded_device_array_slow_path(x, devices, indices, mode)


def _array_shard_arg(x, devices, indices, mode):
if mode == pxla.InputsHandlerMode.pmap:
return _array_pmap_shard_arg(x, devices, indices, mode)
else:
return x._arrays
pxla.shard_arg_handlers[Array] = _array_shard_arg
7 changes: 0 additions & 7 deletions jax/experimental/sharding.py
@@ -258,13 +258,6 @@ def __init__(self, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
def device_set(self) -> Set[Device]:
return set(self.devices.flat)

@pxla.maybe_cached_property
def sharded_dim(self):
for i, s in enumerate(self.sharding_spec.sharding):
if isinstance(s, pxla.Unstacked):
return i
return None

@functools.lru_cache(maxsize=4096)
def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
30 changes: 23 additions & 7 deletions jax/interpreters/pxla.py
@@ -853,8 +853,16 @@ def _hashable_index(idx):
# The fast path is handled directly in shard_args().
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
from jax.experimental.array import Array

candidates = defaultdict(list)
for buf, idx in safe_zip(x.device_buffers, x.indices):
if isinstance(x, Array):
bufs = x._arrays
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
else:
bufs = x.device_buffers
arr_indices = x.indices
for buf, idx in safe_zip(bufs, arr_indices):
candidates[_hashable_index(idx)].append(buf)

bufs = []
@@ -977,10 +985,13 @@ def _emap_impl(fun: lu.WrappedFun, *args,
if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.Array)):
# We don't want to donate if it's already sharded.
donate_argnums_ = ()
out = jax.pmap(lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)),
out_axes=out_axis, devices=devices, backend=backend,
donate_argnums=donate_argnums_)(
np.arange(axis_size), outval)
out = jax.pmap(
lambda _, x: x,
in_axes=(0, out_axis_src.get(axis_name)),
out_axes=out_axis,
devices=(None if devices is None else list(devices)),
backend=backend,
donate_argnums=donate_argnums_)(np.arange(axis_size), outval)
new_outvals.append(out)
return new_outvals

@@ -1000,8 +1011,13 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
for i, name in reversed(list(enumerate(names))):
in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
if any(in_axis is not None for in_axis in in_axes):
f = jax.pmap(f, in_axes=in_axes, axis_name=name, out_axes=0,
backend=info.backend, devices=info.devices)
f = jax.pmap(
f,
in_axes=in_axes,
axis_name=name,
out_axes=0,
backend=info.backend,
devices=(None if info.devices is None else list(info.devices)))
used_names.append(name)
out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
return f, out_shard_axes
2 changes: 0 additions & 2 deletions tests/BUILD
@@ -528,8 +528,6 @@ jax_test(
jax_test(
name = "pmap_test",
srcs = ["pmap_test.py"],
# pmap already has array tests inside it.
disable_configs = ["cpu_jax_array"],
shard_count = {
"cpu": 15,
"gpu": 30,
82 changes: 31 additions & 51 deletions tests/pmap_test.py
@@ -2707,6 +2707,11 @@ def testThreadsafeIndexing(self):
self.assertAllClose(actual, expected, check_dtypes=False)

def testNoCopyIndexing1D(self):
# TODO(https://github.com/google/jax/issues/12016): Implement no copy
# indexing similar to SDA.
if config.jax_array:
self.skipTest('No copy indexing is not implemented for Array yet.')

shape = (8, 4)

if jax.device_count() < shape[0]:
@@ -2798,16 +2803,24 @@ def test_device_put_replicated_pytree(self, is_jax_array, array_type, buffer_att

def test_repr(self):
x = jax.device_put_replicated(1, jax.devices())
self.assertStartsWith(repr(x), 'ShardedDeviceArray')
if config.jax_array:
arr = 'Array'
else:
arr = 'ShardedDeviceArray'
self.assertStartsWith(repr(x), arr)

def test_delete_is_idempotent(self):
x = jax.device_put_replicated(1, jax.devices())
x.delete()
x.delete()

with self.assertRaisesRegex(ValueError,
'ShardedDeviceArray has been deleted.'):
_ = x[0]
if config.jax_array:
with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
_ = x[0]
else:
with self.assertRaisesRegex(ValueError,
'ShardedDeviceArray has been deleted.'):
_ = x[0]


class SpecToIndicesTest(jtu.JaxTestCase):
@@ -3075,66 +3088,33 @@ def test_pmap_array_sharding_mismatch(self):

f = jax.pmap(lambda x: x, in_axes=0, out_axes=0)
with jax._src.config.jax_array(True):
with self.assertRaisesRegex(
ValueError,
("Array and pmap sharding does not match. Got pmap sharding: 0, "
"Array sharding: None")):
f(a1)
out_array = f(a1)

def test_pmap_array_devices_mismatch(self):
if jax.device_count() <= 1:
raise unittest.SkipTest('Skipping because this test needs more than '
'1 device.')
input_shape = (jax.device_count(), 2)
a1, _ = create_input_array_for_pmap(input_shape)
with jax._src.config.jax_array(False):
out_sda = f(a1)

f = jax.pmap(lambda x: x, devices=jax.devices()[::-1])
with jax._src.config.jax_array(True):
with self.assertRaisesRegex(
ValueError, "Devices passed to pmap and Array should be equal."):
f(a1)
self.assertEqual(out_array.sharding.sharding_spec, out_sda.sharding_spec)
self.assertArraysEqual(out_array.sharding.devices,
[d.device() for d in out_sda.device_buffers])

def test_pmap_array_devices_mismatch_between_arrays(self):
def test_pmap_array_devices_mismatch(self):
if jax.device_count() <= 1:
raise unittest.SkipTest('Skipping because this test needs more than '
'1 device.')
input_shape = (jax.device_count(), 2)
a1, _ = create_input_array_for_pmap(input_shape)
a2, _ = create_input_array_for_pmap(input_shape, devices=jax.devices()[::-1])

f = jax.pmap(lambda x, y: (x, y))
f = jax.pmap(lambda x: x, devices=jax.devices()[::-1])
with jax._src.config.jax_array(True):
with self.assertRaisesRegex(
ValueError, "Devices of all `Array` inputs should be the same."):
f(a1, a2)
out_array = f(a1)

with jax._src.config.jax_array(False):
out_sda = f(a1)

class ArrayPmapMixin:
self.assertEqual(out_array.sharding.sharding_spec, out_sda.sharding_spec)
self.assertArraysEqual(out_array.sharding.devices,
[d.device() for d in out_sda.device_buffers])

def setUp(self):
super().setUp()
self.array_enabled = config.jax_array
config.update('jax_array', True)

def tearDown(self):
config.update('jax_array', self.array_enabled)
super().tearDown()


class ArrayPythonPmapTest(ArrayPmapMixin, PythonPmapTest):
pass

class ArrayCppPmapTest(ArrayPmapMixin, CppPmapTest):
pass

class ArrayVmapOfPmapTest(ArrayPmapMixin, VmapOfPmapTest):
pass

class ArrayVmapPmapCollectivesTest(ArrayPmapMixin, VmapPmapCollectivesTest):
pass

class ArrayPmapWithDevicesTest(ArrayPmapMixin, PmapWithDevicesTest):
pass

class EagerPmapMixin:

