diff --git a/jax/_src/array.py b/jax/_src/array.py
index 2fe4b048c8cc..41cbbe723a2a 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 from collections import defaultdict
+import enum
 import math
 import operator as op
 import numpy as np
@@ -50,6 +51,13 @@
 PRNGKeyArrayImpl = Any  # TODO(jakevdp): fix cycles and import this.
 
 
+# Mirror of dlpack.h enum
+class DLDeviceType(enum.IntEnum):
+  kDLCPU = 1
+  kDLCUDA = 2
+  kDLROCM = 10
+
+
 class Shard:
   """A single data shard of an Array.
 
@@ -362,9 +370,40 @@ def is_fully_addressable(self) -> bool:
   def __array__(self, dtype=None, context=None):
     return np.asarray(self._value, dtype=dtype)
 
-  def __dlpack__(self):
+  def __dlpack__(self, stream: int | None = None):
+    if len(self._arrays) != 1:
+      raise ValueError("__dlpack__ only supported for unsharded arrays.")
     from jax._src.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top
-    return to_dlpack(self)
+    return to_dlpack(self, stream=stream)
+
+  def __dlpack_device__(self) -> tuple[DLDeviceType, int]:
+    if len(self._arrays) != 1:
+      raise ValueError("__dlpack__ only supported for unsharded arrays.")
+
+    if self.platform() == "cpu":
+      return DLDeviceType.kDLCPU, 0
+
+    elif self.platform() == "gpu":
+      platform_version = self.device().client.platform_version
+      if "cuda" in platform_version:
+        dl_device_type = DLDeviceType.kDLCUDA
+      elif "rocm" in platform_version:
+        dl_device_type = DLDeviceType.kDLROCM
+      else:
+        raise ValueError("Unknown GPU platform for __dlpack__: "
+                         f"{platform_version}")
+
+      local_hardware_id = self.device().local_hardware_id
+      if local_hardware_id is None:
+        raise ValueError("Couldn't get local_hardware_id for __dlpack__")
+
+      return dl_device_type, local_hardware_id
+
+    else:
+      raise ValueError(
+          "__dlpack__ device only supported for CPU and GPU, got platform: "
+          f"{self.platform()}"
+      )
 
   def __reduce__(self):
     fun, args, arr_state = self._value.__reduce__()  # type: ignore
diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py
index b5013df64fb3..70e12aae926c 100644
--- a/jax/_src/dlpack.py
+++ b/jax/_src/dlpack.py
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from jax import numpy as jnp
 from jax._src import array
-from jax._src.typing import Array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
+from jax._src.lib import xla_extension_version
+from jax._src.typing import Array
 
 
 SUPPORTED_DTYPES = frozenset({
@@ -25,7 +28,8 @@
     jnp.float64, jnp.complex64, jnp.complex128})
 
 
-def to_dlpack(x: Array, take_ownership: bool = False):
+def to_dlpack(x: Array, take_ownership: bool = False,
+              stream: int | None = None):
   """Returns a DLPack tensor that encapsulates a ``DeviceArray`` `x`.
 
   Takes ownership of the contents of ``x``; leaves `x` in an invalid/deleted
@@ -38,13 +42,25 @@
       it were deleted. If ``False``, JAX retains ownership of the buffer; it is
       undefined behavior if the DLPack consumer writes to a buffer that JAX
       owns.
+    stream: optional platform-dependent stream to wait on until the buffer is
+      ready. This corresponds to the ``stream`` argument to ``__dlpack__``
+      documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
   """
   if not isinstance(x, array.ArrayImpl):
     raise TypeError("Argument to to_dlpack must be a jax.Array, "
                     f"got {type(x)}")
   assert len(x.devices()) == 1
-  return xla_client._xla.buffer_to_dlpack_managed_tensor(
-      x.addressable_data(0), take_ownership=take_ownership)  # type: ignore
+  if xla_extension_version >= 186:
+    return xla_client._xla.buffer_to_dlpack_managed_tensor(
+        x.addressable_data(0), take_ownership=take_ownership, stream=stream
+    )  # type: ignore
+  else:
+    if stream is not None:
+      raise ValueError(
+          "passing `stream` argument to to_dlpack requires jaxlib >= 0.4.15")
+    return xla_client._xla.buffer_to_dlpack_managed_tensor(
+        x.addressable_data(0), take_ownership=take_ownership)  # type: ignore
+
 
 
 def from_dlpack(dlpack):
diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py
index ec63309ec2b9..1d170ec105d1 100644
--- a/tests/pytorch_interoperability_test.py
+++ b/tests/pytorch_interoperability_test.py
@@ -21,6 +21,7 @@
 import jax.dlpack
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
+from jax._src.lib import xla_extension_version
 import jax.numpy as jnp
 from jax._src import test_util as jtu
 
@@ -83,6 +84,31 @@ def testJaxToTorch(self, shape, dtype):
     else:
       self.assertAllClose(np, y.cpu().numpy())
 
+  @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
+  def testJaxArrayToTorch(self, shape, dtype):
+    if xla_extension_version < 186:
+      self.skipTest("Need xla_extension_version >= 186")
+
+    if not config.x64_enabled and dtype in [
+        jnp.int64,
+        jnp.float64,
+        jnp.complex128,
+    ]:
+      self.skipTest("x64 types are disabled by jax_enable_x64")
+    rng = jtu.rand_default(self.rng())
+    np = rng(shape, dtype)
+    # Test across all devices
+    for device in jax.local_devices():
+      x = jax.device_put(np, device)
+      y = torch.utils.dlpack.from_dlpack(x)
+      if dtype == jnp.bfloat16:
+        # .numpy() doesn't work on Torch bfloat16 tensors.
+        self.assertAllClose(
+            np, y.cpu().view(torch.int16).numpy().view(jnp.bfloat16)
+        )
+      else:
+        self.assertAllClose(np, y.cpu().numpy())
+
   def testTorchToJaxInt64(self):
     # See https://github.com/google/jax/issues/11895
     x = jax.dlpack.from_dlpack(
@@ -90,7 +116,6 @@ def testTorchToJaxInt64(self):
     dtype_expected = jnp.int64 if config.x64_enabled else jnp.int32
     self.assertEqual(x.dtype, dtype_expected)
 
-
   @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
   def testTorchToJax(self, shape, dtype):
     if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,