Make it possible to return a C++ ShardedDeviceArray.
This **will** be a **breaking** change, as the pxla.ShardedDeviceArray constructor won't be valid anymore:
- for the next JAX release
- once _USE_EXPERIMENTAL_CPP_SDA is switched to `_xla_extension_version > xx`, with the associated jaxlib release.

I am already documenting the impact for users in the CHANGELOG; we can still move it to the next version depending on when it ships.

Similar to jax.jit, for which we have a C++ `DeviceArray` and a Python `_DeviceArray`, we introduce two objects for ShardedDeviceArray, with the Python object reserved for JAX extensions that are not compatible with the C++ object (e.g. Cloud TPU).
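A minimal sketch (not part of this change; assumes a jaxlib of this era, where `jax.pmap` outputs are ShardedDeviceArrays) of what the shared base class buys downstream code: both implementations pass the same `isinstance` checks because the C++ base class is rewired to extend `xla.DeviceArray`.

```python
# Hedged sketch: expected behavior after this change, for either the C++ or
# the Python ShardedDeviceArray implementation.
import jax
import jax.numpy as jnp
from jax.interpreters import pxla, xla

out = jax.pmap(lambda x: x + 1)(jnp.zeros((jax.local_device_count(), 3)))
assert isinstance(out, pxla.ShardedDeviceArray)  # alias for whichever class is active
assert isinstance(out, xla.DeviceArray)          # via DeviceArrayBase in the MRO
```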

- Add `make_sharded_device_array`, to be used within JAX and by users who need to construct SDA objects directly (see the sketch after this list).
- Make sure the C++ object is valid by
  (a) extending `DeviceArrayBase` (done in Python), as it brings a number of methods and enables `isinstance(x, DeviceArray)`, and
  (b) adding the same methods as the Python SDA.
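A rough usage sketch of the new factory, mirroring the `device_put_sharded` change in this diff; the shard shape `(2, 3)` and the use of `np.ones` are illustrative placeholders, and `xla.device_put` is the internal helper the code of this era uses to produce per-device buffers:

```python
import numpy as np
import jax
from jax import core
from jax.interpreters import pxla, xla

devices = jax.local_devices()
# One host shard per device; all shards must share shape and dtype.
xs = [np.ones((2, 3), np.float32) for _ in devices]
buffers = [buf for x, d in zip(xs, devices) for buf in xla.device_put(x, d)]
stacked_aval = core.ShapedArray((len(devices), 2, 3), np.float32)

# Previously: pxla.ShardedDeviceArray(stacked_aval, buffers).
# The factory now picks the C++ or Python implementation; sharding_spec=None
# assumes pmap-style sharding over the leading axis, and indices are computed
# when omitted.
sda = pxla.make_sharded_device_array(stacked_aval, None, buffers)
```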

NOTE: mypy has trouble with the `-> pxla.ShardedDeviceArray` function return type annotation, so I had to remove two of them.
PiperOrigin-RevId: 389876734
jblespiau authored and jax authors committed Aug 10, 2021
1 parent beddf59 commit 45aaf8a
Showing 3 changed files with 174 additions and 72 deletions.
8 changes: 4 additions & 4 deletions jax/_src/api.py
@@ -2400,7 +2400,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
raise ValueError(f"len(shards) = {len(shards)} must equal "
f"len(devices) = {len(devices)}.")

def _device_put_sharded(*xs) -> pxla.ShardedDeviceArray:
def _device_put_sharded(*xs):
avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs]
if not all(a1 == a2 for a1, a2 in zip(avals[:-1], avals[1:])):
a1, a2 = next((a1, a2) for a1, a2 in zip(avals[:-1], avals[1:])
@@ -2409,7 +2409,7 @@ def _device_put_sharded(*xs) -> pxla.ShardedDeviceArray:
f"consistent shape and dtype, but got {a1} and {a2}.")
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
buffers = [buf for x, d in zip(xs, devices) for buf in xla.device_put(x, d)]
return pxla.ShardedDeviceArray(stacked_aval, buffers)
return pxla.make_sharded_device_array(stacked_aval, None, buffers)

return tree_multimap(_device_put_sharded, *shards)

@@ -2446,13 +2446,13 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
if not isinstance(devices, Sequence) or not devices:
raise ValueError("`devices` argument to `device_put_replicated must be "
"a non-empty sequence.")
def _device_put_replicated(x) -> pxla.ShardedDeviceArray:
def _device_put_replicated(x):
aval = core.unmapped_aval(len(devices), 0,
core.raise_to_shaped(core.get_aval(x)))
assert isinstance(aval, core.ShapedArray) and aval._num_buffers == 1
buf, = xla.device_put(x, devices[0])
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
return pxla.ShardedDeviceArray(aval, [buf, *rest_bufs])
return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
return tree_map(_device_put_replicated, x)


233 changes: 167 additions & 66 deletions jax/interpreters/pxla.py
@@ -36,7 +36,7 @@
import threading
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
Type, Union, Iterable, NamedTuple, TYPE_CHECKING)

import warnings
from absl import logging
import numpy as np

@@ -445,14 +445,71 @@ def aval_to_result_handler(sharding_spec: Optional[ShardingSpec],
pxla_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
pxla_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
def array_result_handler(sharding_spec, indices, aval: ShapedArray):
return lambda bufs: ShardedDeviceArray(aval, sharding_spec, bufs, indices)
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)


pxla_result_handlers[ShapedArray] = array_result_handler
pxla_result_handlers[ConcreteArray] = array_result_handler


### lazy device-memory persistence and result handling

class ShardedDeviceArray(xla.DeviceArray): # type: ignore
# TODO(jblespiau): Clean all occurrences of the SDA constructor before
# switching this to True.
_USE_EXPERIMENTAL_CPP_SDA = False


def make_sharded_device_array(
aval: ShapedArray,
sharding_spec: Optional[ShardingSpec],
# Any is for JAX extensions implementing their own buffer.
device_buffers: List[Union[Any, xb.xla_client.Buffer]],
indices: Optional[Tuple[Index, ...]] = None,
):
"""Returns a ShardedDeviceArray implementation based on arguments.
Returns either a C++ SDA or a Python DeviceArray when the buffers are not
JAX buffers.
Args:
aval: The `ShapedArray` for this array.
sharding_spec: If `None`, assumes a pmap-style ShardedDeviceArrays over the
first dimension.
device_buffers: If a list of Jax `Buffer` objects, a C++ SDA will be
returned (if the version is high enough). Otherwise, a Python object will
be returned, for JAX extensions not implementing the C++ API.
indices: For caching purposes, will be computed if `None`.
"""
if sharding_spec is None:
sharded_aval = aval.update(shape=aval.shape[1:])
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0],
1, None, sharded_aval, 0)

if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)

if (_USE_EXPERIMENTAL_CPP_SDA and
(not device_buffers or
isinstance(device_buffers[0], xb.xla_client.Buffer))):
return pmap_lib.ShardedDeviceArray(aval, sharding_spec, device_buffers,
indices)

return _ShardedDeviceArray(aval, sharding_spec, device_buffers, indices)


if _USE_EXPERIMENTAL_CPP_SDA:
ShardedDeviceArrayBase = pmap_lib.ShardedDeviceArrayBase # type: ignore
# We want the C++ SDA to extend the DeviceArrayBase. We want this both to
# benefit from its methods, and to have isinstance(x, DeviceArray) return true
ShardedDeviceArrayBase.__bases__ = ((xla.DeviceArray,) + # type: ignore
ShardedDeviceArrayBase.__bases__)
base_class = pmap_lib.ShardedDeviceArrayBase # type: ignore
else:
base_class: Type[xla.DeviceArray] = xla.DeviceArray # type: ignore


class _ShardedDeviceArray(base_class): # type: ignore
"""A ShardedDeviceArray is an ndarray sharded across devices.
The purpose of a ShardedDeviceArray is to reduce the number of transfers when
@@ -482,27 +539,37 @@ class ShardedDeviceArray(xla.DeviceArray):  # type: ignore
"_one_replica_buffer_indices", "_npy_value"
]


def __init__(
self,
aval: ShapedArray,
sharding_spec, # TODO(skye): add type annotation back, see below
device_buffers: Optional[List[xb.xla_client.Buffer]] = None,
indices: Optional[Tuple[Index, ...]] = None):
# We don't use `super`, following pybind11 guidelines:
# https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtual-functions-in-python
xla.DeviceArray.__init__(self)
if _USE_EXPERIMENTAL_CPP_SDA:
ShardedDeviceArrayBase.__init__(self) # type: ignore

# TODO(skye): this is temporary staging while we switch users over to
# providing sharding_spec. It assumes that any pre-existing callers are
# creating pmap-style ShardedDeviceArrays over the first dimension.
if device_buffers is None:
warnings.warn(
"The constructor of ShardedDeviceArray has changed and expects a "
"ShardingSpec object as the second argument. For a no-op fix, "
"replace SDA(aval, buffers) with SDA(aval, None, Buffers).")
device_buffers = sharding_spec
sharding_spec = None
if sharding_spec is None:
sharded_aval = aval.update(shape=aval.shape[1:])
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0],
1, None, sharded_aval, 0)

# TODO(skye): assert invariants. Keep performance in mind though.
if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)

self.aval = aval
self.device_buffers = device_buffers
self.sharding_spec = sharding_spec
@@ -512,20 +579,6 @@ def __init__(
if config.jax_enable_checks:
assert type(aval) is ShapedArray

@property
def one_replica_buffer_indices(self):
"""Indices of buffers containing one complete copy of the array data."""
if self._one_replica_buffer_indices is None:
one_replica_indices = []
seen_index_hashes = set()
for i, index in enumerate(self.indices):
hashed_index = _hashable_index(index)
if hashed_index not in seen_index_hashes:
one_replica_indices.append(i)
seen_index_hashes.add(hashed_index)
self._one_replica_buffer_indices = one_replica_indices
return self._one_replica_buffer_indices

@property
def shape(self):
return self.aval.shape
@@ -542,51 +595,91 @@ def size(self):
def ndim(self):
return len(self.aval.shape)

def copy_to_host_async(self):
for buffer_index in self.one_replica_buffer_indices:
self.device_buffers[buffer_index].copy_to_host_async()

def delete(self):
for buf in self.device_buffers:
buf.delete()
self.device_buffers = None
self._npy_value = None
def _sda_one_replica_buffer_indices(self):
"""Indices of buffers containing one complete copy of the array data."""
if self._one_replica_buffer_indices is None:
one_replica_indices = []
seen_index_hashes = set()
for i, index in enumerate(self.indices):
hashed_index = _hashable_index(index)
if hashed_index not in seen_index_hashes:
one_replica_indices.append(i)
seen_index_hashes.add(hashed_index)
self._one_replica_buffer_indices = one_replica_indices
return self._one_replica_buffer_indices

def _check_if_deleted(self):
if self.device_buffers is None:
raise ValueError("ShardedDeviceArray has been deleted.")

def block_until_ready(self):
self._check_if_deleted()
for buf in self.device_buffers:
buf.block_host_until_ready()
return self
def _sda_copy_to_host_async(self):
for buffer_index in self.one_replica_buffer_indices:
self.device_buffers[buffer_index].copy_to_host_async()

@property
def _value(self):
if self._npy_value is None:
self.copy_to_host_async()
npy_value = np.empty(self.aval.shape, self.aval.dtype)
for i in self.one_replica_buffer_indices:
npy_value[self.indices[i]] = self.device_buffers[i].to_py()
self._npy_value = npy_value
return self._npy_value

def __getitem__(self, idx):
if not isinstance(idx, tuple):
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
else:
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
try:
buf_idx = self.indices.index(cidx)
except ValueError:
# NOTE: Slow path, this will materialize the sharded array on a single
# device and use XLA's Gather to index into the resulting array.
return xla.DeviceArray.__getitem__(self, idx)
else:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)

def _sda_delete(self):
for buf in self.device_buffers:
buf.delete()
self.device_buffers = None
self._npy_value = None


def _sda_check_if_deleted(self):
if self.device_buffers is None:
raise ValueError("ShardedDeviceArray has been deleted.")


def _sda_block_until_ready(self):
self._check_if_deleted()
for buf in self.device_buffers:
buf.block_host_until_ready()
return self


def _sda_value(self):
if self._npy_value is None:
self.copy_to_host_async()
npy_value = np.empty(self.aval.shape, self.aval.dtype)
for i in self.one_replica_buffer_indices:
npy_value[self.indices[i]] = self.device_buffers[i].to_py()
self._npy_value = npy_value
return self._npy_value


def _sda__getitem__(self, idx):
if not isinstance(idx, tuple):
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
else:
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
try:
buf_idx = self.indices.index(cidx)
except ValueError:
# NOTE: Slow path, this will materialize the sharded array on a single
# device and use XLA's Gather to index into the resulting array.
return xla.DeviceArray.__getitem__(self, idx)
else:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)


for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
setattr(sda, "one_replica_buffer_indices",
property(_sda_one_replica_buffer_indices))
setattr(sda, "copy_to_host_async", _sda_copy_to_host_async)
setattr(sda, "delete", _sda_delete)
setattr(sda, "_check_if_deleted", _sda_check_if_deleted)
setattr(sda, "block_until_ready", _sda_block_until_ready)
setattr(sda, "_value", property(_sda_value))
setattr(sda, "__getitem__", _sda__getitem__)

del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async, _sda_delete,
_sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__)


ShardedDeviceArray: Type[object]
if _USE_EXPERIMENTAL_CPP_SDA:
ShardedDeviceArray = pmap_lib.ShardedDeviceArrayBase
else:
ShardedDeviceArray = _ShardedDeviceArray


def _hashable_index(idx):
@@ -617,17 +710,23 @@ def _shard_sharded_device_array_slow_path(x, devices, indices):
else:
bufs.append(buf.copy_to_device(device))
return bufs
shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array_slow_path


def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
return xb.constant_general(c, np.asarray(val), canonicalize_types=canonicalize_types)
xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler)

core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
xla.device_put_handlers[ShardedDeviceArray] = xla._device_put_array
xla.pytype_aval_mappings[ShardedDeviceArray] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity

def _register_handlers_for_sharded_device_array(sda):
shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
xb.register_constant_handler(sda, _sharded_device_array_constant_handler)

core.pytype_aval_mappings[sda] = ConcreteArray
xla.device_put_handlers[sda] = xla._device_put_array
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
xla.canonicalize_dtype_handlers[sda] = identity

_register_handlers_for_sharded_device_array(_ShardedDeviceArray)
_register_handlers_for_sharded_device_array(pmap_lib.ShardedDeviceArray)

### the xla_pmap primitive and its rules are comparable to xla_call in xla.py

@@ -1114,7 +1213,9 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
# TODO(skye): figure out how partitioning should work here
sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, in_axis)
device_buffers = device_put(val, devices, replicate=True)
return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers)
return make_sharded_device_array(replicated_aval, sharding_spec,
device_buffers)


def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, map_axis: Optional[int]):
"""Sharding spec for arguments or results of a pmap.
5 changes: 3 additions & 2 deletions tests/pmap_test.py
@@ -541,7 +541,8 @@ def testShardedDeviceArrays(self):
self.assertAllClose(z, 2 * 2 * x, check_dtypes=False)

# test that we can handle device movement on dispatch
y = pxla.ShardedDeviceArray(y.aval, y.sharding_spec, y.device_buffers[::-1])
y = pxla.make_sharded_device_array(y.aval, y.sharding_spec,
y.device_buffers[::-1])
z = f(y)
self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False)

@@ -1263,7 +1264,7 @@ def testReshardInput(self):
sharding_spec = pxla.ShardingSpec(
sharding=map(pxla.Chunked, ([2], [2])),
mesh_mapping=map(pxla.ShardedAxis, (0, 1)))
arr = pxla.ShardedDeviceArray(aval, sharding_spec, bufs)
arr = pxla.make_sharded_device_array(aval, sharding_spec, bufs)

r = self.pmap(lambda x: x + 1)(arr)
self.assertAllClose(r, arr + 1)
