Skip to content

Commit

Permalink
Pmap should output SDA like Arrays to maintain the current behavior…
Browse files Browse the repository at this point in the history
… exactly. Split the shard_arg_handler for `Array` based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.

PiperOrigin-RevId: 466849614
  • Loading branch information
yashk2810 authored and jax authors committed Aug 11, 2022
1 parent 0a783ca commit 33c4fc4
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 94 deletions.
49 changes: 23 additions & 26 deletions jax/_src/api.py
Expand Up @@ -1897,39 +1897,41 @@ class PmapCallInfo(NamedTuple):


def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices):
from jax.experimental import sharding
from jax.experimental.sharding import PmapSharding
from jax.experimental.array import Array

if not args:
return

if in_devices is not None:
in_devices = np.array(in_devices)

first_arr_devices = args[0].sharding.devices
first_device_assignment = None
for a, i in safe_zip(args, in_axes_flat):
assert isinstance(a.sharding, sharding.PmapSharding)
if not isinstance(a, Array):
continue
if not isinstance(a.sharding, PmapSharding):
raise NotImplementedError('pmap only works with PmapSharding.')
if first_device_assignment is None:
first_device_assignment = a.sharding._device_assignment
arr_sharding = a.sharding.sharded_dim
arr_devices = a.sharding.devices
arr_device_assignment = a.sharding._device_assignment
if arr_sharding != i:
raise ValueError('Array and pmap sharding does not match. Got pmap '
f'sharding: {i}, Array sharding: {arr_sharding} for '
f'arg: {a}')
if (in_devices is not None and
arr_devices is not None and
not np.array_equal(arr_devices, in_devices)):
arr_device_assignment is not None and
arr_device_assignment != in_devices):
raise ValueError('Devices passed to pmap and Array should be equal. '
f'Got pmap devices: {devices}, Array devices: '
f'{arr_devices} for arg: {a}')
f'Got pmap devices: {in_devices}, Array devices: '
f'{arr_device_assignment} for arg: {a}')
if (in_devices is None and
not np.array_equal(arr_devices, first_arr_devices)):
arr_device_assignment != first_device_assignment):
raise ValueError('Devices of all `Array` inputs should be the same. '
f'Got array device: {arr_devices}, '
f'another array device: {first_arr_devices}')
return first_arr_devices
f'Got array device: {arr_device_assignment}, '
f'another array device: {first_device_assignment}')


def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, global_arg_shapes, devices, args, kwargs):
donate_tuple, global_arg_shapes, in_devices, args, kwargs):
f = lu.wrap_init(fun)
if static_broadcasted_tuple:
if max(static_broadcasted_tuple) >= len(args):
Expand Down Expand Up @@ -1971,13 +1973,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
flat_fun, out_tree = flatten_fun(f, in_tree)

if config.jax_array:
from jax.experimental.array import Array
if any(not isinstance(a, Array) for a in args):
raise ValueError('All arguments to pmap when `config.jax_array` is '
'enabled should be `Array`s.')
arr_devices = _check_in_pmap_sharding_with_arrays(args, in_axes_flat, devices)
if devices is None and arr_devices is not None:
devices = arr_devices
_check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices)

if any(out_axis is None for out_axis in tree_flatten(out_axes)):
raise NotImplementedError("None out_axes in pmap are not supported yet")
Expand Down Expand Up @@ -2011,7 +2007,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
local_axis_size=local_axis_size,
global_arg_shapes_flat=global_arg_shapes_flat,
out_axes_thunk=out_axes_thunk,
devices=None if devices is None else tuple(devices))
devices=None if in_devices is None else tuple(in_devices))


def _get_f_mapped(
Expand Down Expand Up @@ -2199,8 +2195,9 @@ def cache_miss(*args, **kwargs):

return out, fastpath_data

cpp_mapped_f = pmap_lib.pmap(fun, cache_miss,
static_broadcasted_tuple, pxla._shard_arg)
cpp_mapped_f = pmap_lib.pmap(
fun, cache_miss, static_broadcasted_tuple,
partial(pxla._shard_arg, mode=pxla.InputsHandlerMode.pmap))

pmap_f = wraps(fun)(cpp_mapped_f)

Expand Down
26 changes: 20 additions & 6 deletions jax/experimental/array.py
Expand Up @@ -21,7 +21,7 @@
from jax._src import api_util
from jax._src import dispatch
from jax._src.config import config
from jax._src.util import prod
from jax._src.util import prod, safe_zip
from jax._src.lib import xla_client as xc
from jax._src.api import device_put
from jax.interpreters import pxla, xla
Expand Down Expand Up @@ -261,12 +261,26 @@ def _device_put_array(x, device: Optional[Device]):
dispatch.device_put_handlers[Array] = _device_put_array


def _array_shard_arg(x, devices, indices):
return x._arrays
def _array_shard_arg(x, devices, indices, mode):
# TODO(yashkatariya): Remove the `mode` handling and try to consolidate the
# code paths.
if mode == pxla.InputsHandlerMode.pmap:
# sharding mismatch between `Array` and pmap sharding is checked in api.py's
# `_check_in_pmap_sharding_with_arrays` function.
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
else:
return x._arrays
pxla.shard_arg_handlers[Array] = _array_shard_arg


def _array_result_handler(global_aval, out_sharding):
def _array_global_result_handler(global_aval, out_sharding):
return lambda bufs: Array(global_aval.shape, out_sharding, bufs, committed=True)
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_result_handler
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler


def _array_local_result_handler(aval, sharding, indices):
return lambda bufs: Array(aval.shape, sharding, bufs, committed=True)
pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler
4 changes: 3 additions & 1 deletion jax/experimental/global_device_array.py
Expand Up @@ -561,7 +561,9 @@ def from_batched_callback_with_devices(
api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \
lambda x: core.ShapedArray(x.shape, x.dtype)

def _gda_shard_arg(x, devices, indices):
def _gda_shard_arg(x, devices, indices, mode):
if mode == pxla.InputsHandlerMode.pmap:
raise RuntimeError('GDA is not supported with pmap.')
return x._device_buffers
pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg

Expand Down

0 comments on commit 33c4fc4

Please sign in to comment.