Skip to content

Commit

Permalink
Enable constructing C++ ShardedDeviceArray instead of the Python object.
Browse files Browse the repository at this point in the history
This also contains a cleanup for lax_numpy:
- Create functions to setup attributes, so we clearly see what is happening. I am keeping things as they are, but we could, if we wanted to, set all the attributes that we set on `_DeviceArray`/`_CppDeviceArray`/`pmap_lib.ShardedDeviceArray` directly on the `DeviceArray` base class (not sure if it's slower or faster, etc), or we could set them all on the leaf nodes (maybe it's faster).
- This also remove some temporary objects in the scope, and I am removing `operator_name`, which for me should not be exposed (it values "round" at HEAD).

PiperOrigin-RevId: 390072555
  • Loading branch information
jblespiau authored and jax authors committed Aug 11, 2021
1 parent 6ce4504 commit 7200a7a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 65 deletions.
122 changes: 65 additions & 57 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
from jax.config import config
from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
from jax.interpreters import pxla
from jax import lax
from jax._src.lax.lax import _device_put_raw
from jax import ops
Expand Down Expand Up @@ -5951,42 +5952,6 @@ def _operator_round(number, ndigits=None):
argpartition = _not_implemented(np.argpartition)
_NOT_IMPLEMENTED = ['argpartition']

# Set up operator, method, and property forwarding on Tracer instances containing
# ShapedArray avals by following the forwarding conventions for Tracer.
# Forward operators using a single-underscore-prefix naming convention:
for operator_name, function in _operators.items():
setattr(ShapedArray, "_{}".format(operator_name), staticmethod(function))
# Forward methods and properties using core.aval_method and core.aval_property:
for method_name in _nondiff_methods + _diff_methods:
setattr(ShapedArray, method_name, core.aval_method(globals()[method_name]))
setattr(ShapedArray, "reshape", core.aval_method(_reshape))
setattr(ShapedArray, "transpose", core.aval_method(_transpose))
setattr(ShapedArray, "flatten", core.aval_method(ravel))
setattr(ShapedArray, "T", core.aval_property(transpose))
setattr(ShapedArray, "real", core.aval_property(real))
setattr(ShapedArray, "imag", core.aval_property(imag))
setattr(ShapedArray, "astype", core.aval_method(_astype))
setattr(ShapedArray, "view", core.aval_method(_view))
setattr(ShapedArray, "nbytes", core.aval_property(_nbytes))


# Forward operators, methods, and properties on DeviceArray to lax_numpy
# functions (with no Tracers involved; this forwarding is direct)
for device_array in [DeviceArray]:
for operator_name, function in _operators.items():
setattr(device_array, "__{}__".format(operator_name), function)
for method_name in _nondiff_methods + _diff_methods:
setattr(device_array, method_name, globals()[method_name])
setattr(device_array, "reshape", _reshape)
setattr(device_array, "transpose", _transpose)
setattr(device_array, "flatten", ravel)
setattr(device_array, "T", property(transpose))
setattr(device_array, "real", property(real))
setattr(device_array, "imag", property(imag))
setattr(device_array, "astype", _astype)
setattr(device_array, "view", _view)
setattr(device_array, "nbytes", property(_nbytes))


# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
Expand All @@ -5999,26 +5964,10 @@ def __array_module__(self, types):
else:
return NotImplemented

setattr(ShapedArray, "_array_module", staticmethod(__array_module__))
setattr(_DeviceArray, "__array_module__", __array_module__)
setattr(_CppDeviceArray, "__array_module__", __array_module__)


# Extra methods that are handy
setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast))
setattr(ShapedArray, "broadcast_in_dim", core.aval_method(lax.broadcast_in_dim))
setattr(ShapedArray, "split", core.aval_method(split))
for device_array in [_DeviceArray, _CppDeviceArray]:
setattr(device_array, "broadcast", lax.broadcast)
setattr(device_array, "broadcast_in_dim", lax.broadcast_in_dim)
setattr(device_array, "split", split)

def _compress_method(a, condition, axis=None, out=None):
return compress(condition, a, axis, out)

setattr(ShapedArray, "compress", _compress_method)
setattr(_DeviceArray, "compress", _compress_method)
setattr(_CppDeviceArray, "compress", _compress_method)

@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr,
Expand All @@ -6037,8 +5986,6 @@ def _multi_slice(arr,
sliced = lax.squeeze(sliced, removed)
results.append(sliced)
return results
setattr(_DeviceArray, "_multi_slice", _multi_slice)
setattr(_CppDeviceArray, "_multi_slice", _multi_slice)


# Syntactic sugar for scatter operations.
Expand Down Expand Up @@ -6206,6 +6153,67 @@ def max(self, values, indices_are_sorted=False, unique_indices=False):
unique_indices=unique_indices)


setattr(_DeviceArray, "at", property(_IndexUpdateHelper))
setattr(_CppDeviceArray, "at", property(_IndexUpdateHelper))
setattr(ShapedArray, "at", core.aval_property(_IndexUpdateHelper))
def _set_shaped_array_attributes(shaped_array):
# Set up operator, method, and property forwarding on Tracer instances
# containing
# ShapedArray avals by following the forwarding conventions for Tracer.
# Forward operators using a single-underscore-prefix naming convention:
for operator_name, function in _operators.items():
setattr(shaped_array, "_{}".format(operator_name), staticmethod(function))
# Forward methods and properties using core.{aval_method, aval_property}:
for method_name in _nondiff_methods + _diff_methods:
setattr(shaped_array, method_name, core.aval_method(globals()[method_name]))
setattr(shaped_array, "reshape", core.aval_method(_reshape))
setattr(shaped_array, "transpose", core.aval_method(_transpose))
setattr(shaped_array, "flatten", core.aval_method(ravel))
setattr(shaped_array, "T", core.aval_property(transpose))
setattr(shaped_array, "real", core.aval_property(real))
setattr(shaped_array, "imag", core.aval_property(imag))
setattr(shaped_array, "astype", core.aval_method(_astype))
setattr(shaped_array, "view", core.aval_method(_view))
setattr(shaped_array, "nbytes", core.aval_property(_nbytes))

setattr(shaped_array, "_array_module", staticmethod(__array_module__))
setattr(shaped_array, "broadcast", core.aval_method(lax.broadcast))
setattr(shaped_array, "broadcast_in_dim", core.aval_method(lax.broadcast_in_dim))
setattr(shaped_array, "split", core.aval_method(split))
setattr(shaped_array, "compress", _compress_method)
setattr(shaped_array, "at", core.aval_property(_IndexUpdateHelper))

_set_shaped_array_attributes(ShapedArray)


def _set_device_array_base_attributes(device_array):
# Forward operators, methods, and properties on DeviceArray to lax_numpy
# functions (with no Tracers involved; this forwarding is direct)
for operator_name, function in _operators.items():
setattr(device_array, "__{}__".format(operator_name), function)
for method_name in _nondiff_methods + _diff_methods:
setattr(device_array, method_name, globals()[method_name])
setattr(device_array, "reshape", _reshape)
setattr(device_array, "transpose", _transpose)
setattr(device_array, "flatten", ravel)
setattr(device_array, "T", property(transpose))
setattr(device_array, "real", property(real))
setattr(device_array, "imag", property(imag))
setattr(device_array, "astype", _astype)
setattr(device_array, "view", _view)
setattr(device_array, "nbytes", property(_nbytes))

_set_device_array_base_attributes(DeviceArray)


def _set_device_array_attributes(device_array):
setattr(device_array, "__array_module__", __array_module__)
# Extra methods that are handy
setattr(device_array, "broadcast", lax.broadcast)
setattr(device_array, "broadcast_in_dim", lax.broadcast_in_dim)
setattr(device_array, "split", split)
setattr(device_array, "compress", _compress_method)
setattr(device_array, "_multi_slice", _multi_slice)
setattr(device_array, "at", property(_IndexUpdateHelper))

_set_device_array_attributes(_DeviceArray)
_set_device_array_attributes(_CppDeviceArray)
_set_device_array_attributes(pxla._ShardedDeviceArray)
_set_device_array_attributes(pxla.pmap_lib.ShardedDeviceArray)
13 changes: 6 additions & 7 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,8 @@ def array_result_handler(sharding_spec, indices, aval: ShapedArray):

### lazy device-memory persistence and result handling

# TODO(jblespiau): Clean all occurrences of the SDA constructor before
# switching this to True.
_USE_EXPERIMENTAL_CPP_SDA = False
# TODO(jblespiau): Remove when jaxlib 0.1.71 is the minimal version.
_USE_CPP_SDA = _xla_extension_version >= 33


def make_sharded_device_array(
Expand Down Expand Up @@ -489,7 +488,7 @@ def make_sharded_device_array(
if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)

if (_USE_EXPERIMENTAL_CPP_SDA and
if (_USE_CPP_SDA and
(not device_buffers or
isinstance(device_buffers[0], xb.xla_client.Buffer))):
return pmap_lib.ShardedDeviceArray(aval, sharding_spec, device_buffers,
Expand All @@ -498,7 +497,7 @@ def make_sharded_device_array(
return _ShardedDeviceArray(aval, sharding_spec, device_buffers, indices)


if _USE_EXPERIMENTAL_CPP_SDA:
if _USE_CPP_SDA:
ShardedDeviceArrayBase = pmap_lib.ShardedDeviceArrayBase # type: ignore
# We want the C++ SDA to extend the DeviceArrayBase. We want this both to
# benefit from its methods, and to have isinstance(x, DeviceArray) return true
Expand Down Expand Up @@ -548,7 +547,7 @@ def __init__(
# We don't use `super`, following pybind11 guidelines:
# https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtual-functions-in-python
xla.DeviceArray.__init__(self)
if _USE_EXPERIMENTAL_CPP_SDA:
if _USE_CPP_SDA:
ShardedDeviceArrayBase.__init__(self) # type: ignore

# TODO(skye): this is temporary staging while we switch users over to
Expand Down Expand Up @@ -680,7 +679,7 @@ def _sda__getitem__(self, idx):


ShardedDeviceArray: Type[object]
if _USE_EXPERIMENTAL_CPP_SDA:
if _USE_CPP_SDA:
ShardedDeviceArray = pmap_lib.ShardedDeviceArrayBase
else:
ShardedDeviceArray = _ShardedDeviceArray
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
nanmedian, nanpercentile, nanquantile,
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
negative, newaxis, nextafter, nonzero, not_equal, number,
object_, ogrid, ones, ones_like, operator_name, outer, packbits, pad, percentile,
object_, ogrid, ones, ones_like, outer, packbits, pad, percentile,
pi, piecewise, poly, polyadd, polyder, polyint, polymul, polysub, polyval, positive, power,
prod, product, promote_types, ptp, quantile,
r_, rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,
Expand Down

0 comments on commit 7200a7a

Please sign in to comment.