Skip to content

Commit

Permalink
[JAX] Enable C++ device arrays by default.
Browse files Browse the repository at this point in the history
[XLA:Python] Relax constraints on .aval and ._device attributes on C++ buffer objects. The constraints cause more problems than they solve. Switch _device to be a C++ attribute rather than a Python attribute. This avoids some unnecessary Python attribute parsing in the JIT dispatch path.

Change PyBuffer objects to call themselves `DeviceArray` in Python so as not to surprise JAX users.

PiperOrigin-RevId: 362969997
  • Loading branch information
hawkinsp authored and jax authors committed Mar 15, 2021
1 parent 80966fe commit 9a2a1ad
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
15 changes: 10 additions & 5 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from jax._src.pprint_util import pp
from .._src.util import (partial, partialmethod, cache, prod, unzip2,
extend_name_stack, wrap_name, safe_zip, safe_map)
from .. import lib
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from . import partial_eval as pe
Expand Down Expand Up @@ -1026,7 +1027,7 @@ def _forward_method(attrname, self, fun, *args):

_CppDeviceArray: DeviceArrayProtocol = xc.Buffer

_EXPERIMENTAL_CPP_DEVICE_ARRAY = False
_EXPERIMENTAL_CPP_DEVICE_ARRAY = lib._xla_extension_version >= 7


def make_device_array(
Expand All @@ -1039,9 +1040,13 @@ def make_device_array(
This is to be used only within JAX. It will return either a PythonDeviceArray
or a C++ equivalent implementation.
"""
if _EXPERIMENTAL_CPP_DEVICE_ARRAY:
assert isinstance(device_buffer, _CppDeviceArray)
device_buffer._device = device # pylint: disable=protected-access
if (_EXPERIMENTAL_CPP_DEVICE_ARRAY and
isinstance(device_buffer, _CppDeviceArray)):

if device_buffer.aval == aval and device_buffer._device == device:
return device_buffer
device_buffer = device_buffer.clone()
device_buffer._device = device
device_buffer.aval = aval
return device_buffer

Expand Down Expand Up @@ -1313,7 +1318,7 @@ def _copy_device_array_to_device(x: Union[DeviceArrayProtocol, _DeviceArray], de
# buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
return _DeviceArray(x.aval, device, moved_buf)
return make_device_array(x.aval, device, moved_buf)


def _device_put_impl(x, device: Optional[Device] = None):
Expand Down
2 changes: 1 addition & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,7 +1868,7 @@ def test_device_array_repr(self):

def test_device_array_hash(self):
rep = jnp.ones(()) + 1.
self.assertIsInstance(rep, jax.interpreters.xla._DeviceArray)
self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
msg = "JAX DeviceArray, like numpy.ndarray, is not hashable."
with self.assertRaisesRegex(TypeError, msg):
hash(rep)
Expand Down

0 comments on commit 9a2a1ad

Please sign in to comment.