From a80cbc5626b6ca4be64fc2806ff9263c93035316 Mon Sep 17 00:00:00 2001
From: Skye Wanderman-Milne
Date: Fri, 18 Aug 2023 14:19:49 -0700
Subject: [PATCH] [JAX] Implement the `stream` argument to jax.Array.__dlpack__ for CUDA GPU

Also implements jax.Array.__dlpack_device__. See
https://dmlc.github.io/dlpack/latest/python_spec.html

This requires plumbing the raw CUDA stream pointer through PJRT and
StreamExecutor (since the GPU PJRT implementation is still based on SE).
This is done via the new PJRT method
ExternalReference::WaitUntilBufferReadyOnStream.

I haven't plumbed this through the PJRT C API yet, because I'm still
debating whether this should be part of the main API or a GPU-specific
extension (plus either way it should probably be its own change).

PiperOrigin-RevId: 558245360
---
 jax/_src/array.py                      | 43 ++++++++++++++++++++++++--
 jax/_src/dlpack.py                     | 24 +++++++++++---
 tests/pytorch_interoperability_test.py | 27 +++++++++++++++-
 3 files changed, 87 insertions(+), 7 deletions(-)

diff --git a/jax/_src/array.py b/jax/_src/array.py
index 2fe4b048c8cc..41cbbe723a2a 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 from collections import defaultdict
+import enum
 import math
 import operator as op
 import numpy as np
@@ -50,6 +51,13 @@
 PRNGKeyArrayImpl = Any  # TODO(jakevdp): fix cycles and import this.
 
 
+# Mirror of dlpack.h enum
+class DLDeviceType(enum.IntEnum):
+  kDLCPU = 1
+  kDLCUDA = 2
+  kDLROCM = 10
+
+
 class Shard:
   """A single data shard of an Array.
 
@@ -362,9 +370,40 @@ def is_fully_addressable(self) -> bool:
   def __array__(self, dtype=None, context=None):
     return np.asarray(self._value, dtype=dtype)
 
-  def __dlpack__(self):
+  def __dlpack__(self, stream: int | None = None):
+    if len(self._arrays) != 1:
+      raise ValueError("__dlpack__ only supported for unsharded arrays.")
     from jax._src.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top
-    return to_dlpack(self)
+    return to_dlpack(self, stream=stream)
+
+  def __dlpack_device__(self) -> tuple[DLDeviceType, int]:
+    if len(self._arrays) != 1:
+      raise ValueError("__dlpack__ only supported for unsharded arrays.")
+
+    if self.platform() == "cpu":
+      return DLDeviceType.kDLCPU, 0
+
+    elif self.platform() == "gpu":
+      platform_version = self.device().client.platform_version
+      if "cuda" in platform_version:
+        dl_device_type = DLDeviceType.kDLCUDA
+      elif "rocm" in platform_version:
+        dl_device_type = DLDeviceType.kDLROCM
+      else:
+        raise ValueError("Unknown GPU platform for __dlpack__: "
+                         f"{platform_version}")
+
+      local_hardware_id = self.device().local_hardware_id
+      if local_hardware_id is None:
+        raise ValueError("Couldn't get local_hardware_id for __dlpack__")
+
+      return dl_device_type, local_hardware_id
+
+    else:
+      raise ValueError(
+          "__dlpack__ device only supported for CPU and GPU, got platform: "
+          f"{self.platform()}"
+      )
 
   def __reduce__(self):
     fun, args, arr_state = self._value.__reduce__()  # type: ignore
diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py
index b5013df64fb3..70e12aae926c 100644
--- a/jax/_src/dlpack.py
+++ b/jax/_src/dlpack.py
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from jax import numpy as jnp
 from jax._src import array
-from jax._src.typing import Array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
+from jax._src.lib import xla_extension_version
+from jax._src.typing import Array
 
 
 SUPPORTED_DTYPES = frozenset({
@@ -25,7 +28,8 @@
     jnp.float64, jnp.complex64, jnp.complex128})
 
 
-def to_dlpack(x: Array, take_ownership: bool = False):
+def to_dlpack(x: Array, take_ownership: bool = False,
+              stream: int | None = None):
   """Returns a DLPack tensor that encapsulates a ``DeviceArray`` `x`.
 
   Takes ownership of the contents of ``x``; leaves `x` in an invalid/deleted
@@ -38,13 +42,25 @@ def to_dlpack(x: Array, take_ownership: bool = False,
      it were deleted. If ``False``, JAX retains ownership of the buffer; it is
      undefined behavior if the DLPack consumer writes to a buffer that JAX
      owns.
+    stream: optional platform-dependent stream to wait on until the buffer is
+      ready. This corresponds to the `stream` argument to __dlpack__ documented
+      in https://dmlc.github.io/dlpack/latest/python_spec.html.
   """
   if not isinstance(x, array.ArrayImpl):
     raise TypeError("Argument to to_dlpack must be a jax.Array, "
                     f"got {type(x)}")
   assert len(x.devices()) == 1
-  return xla_client._xla.buffer_to_dlpack_managed_tensor(
-      x.addressable_data(0), take_ownership=take_ownership)  # type: ignore
+  if xla_extension_version >= 186:
+    return xla_client._xla.buffer_to_dlpack_managed_tensor(
+        x.addressable_data(0), take_ownership=take_ownership, stream=stream
+    )  # type: ignore
+  else:
+    if stream is not None:
+      raise ValueError(
+          "passing `stream` argument to to_dlpack requires jaxlib >= 0.4.15")
+    return xla_client._xla.buffer_to_dlpack_managed_tensor(
+        x.addressable_data(0), take_ownership=take_ownership)  # type: ignore
+
 
 
 def from_dlpack(dlpack):
diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py
index ec63309ec2b9..1d170ec105d1 100644
--- a/tests/pytorch_interoperability_test.py
+++ b/tests/pytorch_interoperability_test.py
@@ -21,6 +21,7 @@
 import jax.dlpack
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
+from jax._src.lib import xla_extension_version
 import jax.numpy as jnp
 
 from jax._src import test_util as jtu
@@ -83,6 +84,31 @@ def testJaxToTorch(self, shape, dtype):
     else:
       self.assertAllClose(np, y.cpu().numpy())
 
+  @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
+  def testJaxArrayToTorch(self, shape, dtype):
+    if xla_extension_version < 186:
+      self.skipTest("Need xla_extension_version >= 186")
+
+    if not config.x64_enabled and dtype in [
+        jnp.int64,
+        jnp.float64,
+        jnp.complex128,
+    ]:
+      self.skipTest("x64 types are disabled by jax_enable_x64")
+    rng = jtu.rand_default(self.rng())
+    np = rng(shape, dtype)
+    # Test across all devices
+    for device in jax.local_devices():
+      x = jax.device_put(np, device)
+      y = torch.utils.dlpack.from_dlpack(x)
+      if dtype == jnp.bfloat16:
+        # .numpy() doesn't work on Torch bfloat16 tensors.
+        self.assertAllClose(
+            np, y.cpu().view(torch.int16).numpy().view(jnp.bfloat16)
+        )
+      else:
+        self.assertAllClose(np, y.cpu().numpy())
+
   def testTorchToJaxInt64(self):
     # See https://github.com/google/jax/issues/11895
     x = jax.dlpack.from_dlpack(
@@ -90,7 +116,6 @@ def testTorchToJaxInt64(self):
     dtype_expected = jnp.int64 if config.x64_enabled else jnp.int32
     self.assertEqual(x.dtype, dtype_expected)
 
-
   @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
   def testTorchToJax(self, shape, dtype):
     if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
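
Usage example (editor's sketch, not part of the patch): the snippet below shows how the new protocol methods might be exercised from PyTorch. It assumes a CUDA build of JAX with jaxlib >= 0.4.15 and a CUDA-enabled PyTorch; `jax.devices("gpu")` and `torch.cuda.current_stream().cuda_stream` are illustrative choices, not anything this change prescribes.

  # Minimal sketch, assuming a CUDA build of JAX (jaxlib >= 0.4.15) and a
  # CUDA-enabled PyTorch install.
  import jax
  import jax.numpy as jnp
  import torch

  x = jax.device_put(jnp.arange(8, dtype=jnp.float32), jax.devices("gpu")[0])

  # New in this change: reports (device_type, device_id), e.g. (kDLCUDA, 0).
  print(x.__dlpack_device__())

  # Recent PyTorch versions consume the producer object directly; per the
  # DLPack spec they pass their current stream to x.__dlpack__(stream=...),
  # so the buffer is ready on that stream before Torch touches it.
  y = torch.utils.dlpack.from_dlpack(x)

  # Explicit form: hand JAX the consumer's raw CUDA stream pointer ourselves.
  stream = torch.cuda.current_stream().cuda_stream  # raw cudaStream_t as int
  capsule = x.__dlpack__(stream=stream)
  z = torch.utils.dlpack.from_dlpack(capsule)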