[JAX] Implement the stream argument to jax.Array.__dlpack__ for CUDA GPU

Also implements jax.Array.__dlpack_device__. See
https://dmlc.github.io/dlpack/latest/python_spec.html

This requires plumbing the raw CUDA stream pointer through PJRT and
StreamExecutor (since the GPU PJRT implementation is still based on
SE). This is done via the new PJRT method
ExternalReference::WaitUntilBufferReadyOnStream.

I haven't plumbed this through the PJRT C API yet, because I'm still
debating whether this should be part of the main API or a GPU-specific
extension (plus either way it should probably be its own change).

PiperOrigin-RevId: 558245360
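For reference, a minimal consumer-side sketch of the protocol these methods implement, following the DLPack Python spec linked above. It is illustrative only, not part of the diff, and assumes CUDA builds of both jax and torch that include this change; shapes and device indices are arbitrary.

import jax
import jax.numpy as jnp
import torch

x = jax.device_put(jnp.arange(4.0), jax.devices("gpu")[0])

# Consumers call __dlpack_device__ first to learn where the buffer lives:
# (DLDeviceType, device id), e.g. (2, 0) for kDLCUDA device 0.
device_type, device_id = x.__dlpack_device__()

# They then pass the raw pointer of the CUDA stream they will compute on; the
# producer makes the buffer ready on that stream before the capsule is
# consumed (the WaitUntilBufferReadyOnStream plumbing described above).
s = torch.cuda.Stream()
capsule = x.__dlpack__(stream=s.cuda_stream)
with torch.cuda.stream(s):
  y = torch.utils.dlpack.from_dlpack(capsule)

# Recent torch releases drive the same handshake implicitly via
# torch.from_dlpack(x).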
skye authored and jax authors committed Aug 18, 2023
1 parent 3119b43 commit a80cbc5
Showing 3 changed files with 87 additions and 7 deletions.
43 changes: 41 additions & 2 deletions jax/_src/array.py
@@ -15,6 +15,7 @@
from __future__ import annotations

from collections import defaultdict
import enum
import math
import operator as op
import numpy as np
@@ -50,6 +51,13 @@
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.


# Mirror of dlpack.h enum
class DLDeviceType(enum.IntEnum):
kDLCPU = 1
kDLCUDA = 2
kDLROCM = 10


class Shard:
"""A single data shard of an Array.
@@ -362,9 +370,40 @@ def is_fully_addressable(self) -> bool:
def __array__(self, dtype=None, context=None):
return np.asarray(self._value, dtype=dtype)

def __dlpack__(self):
def __dlpack__(self, stream: int | None = None):
if len(self._arrays) != 1:
raise ValueError("__dlpack__ only supported for unsharded arrays.")
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self)
return to_dlpack(self, stream=stream)

def __dlpack_device__(self) -> tuple[DLDeviceType, int]:
if len(self._arrays) != 1:
raise ValueError("__dlpack__ only supported for unsharded arrays.")

if self.platform() == "cpu":
return DLDeviceType.kDLCPU, 0

elif self.platform() == "gpu":
platform_version = self.device().client.platform_version
if "cuda" in platform_version:
dl_device_type = DLDeviceType.kDLCUDA
elif "rocm" in platform_version:
dl_device_type = DLDeviceType.kDLROCM
else:
raise ValueError("Unknown GPU platform for __dlpack__: "
f"{platform_version}")

local_hardware_id = self.device().local_hardware_id
if local_hardware_id is None:
raise ValueError("Couldn't get local_hardware_id for __dlpack__")

return dl_device_type, local_hardware_id

else:
raise ValueError(
"__dlpack__ device only supported for CPU and GPU, got platform: "
f"{self.platform()}"
)

def __reduce__(self):
fun, args, arr_state = self._value.__reduce__() # type: ignore
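Because DLDeviceType is an IntEnum, the tuple returned by __dlpack_device__ compares equal to plain integers. A tiny sketch of the mapping above, illustrative only and assuming a jax build that includes this change:

import jax
import jax.numpy as jnp

cpu_arr = jax.device_put(jnp.zeros(3), jax.devices("cpu")[0])
assert cpu_arr.__dlpack_device__() == (1, 0)  # (kDLCPU, device 0)

# On a CUDA build the tuple would instead be (2, local_hardware_id),
# and on ROCm (10, local_hardware_id).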
24 changes: 20 additions & 4 deletions jax/_src/dlpack.py
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from jax import numpy as jnp
from jax._src import array
from jax._src.typing import Array
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array


SUPPORTED_DTYPES = frozenset({
@@ -25,7 +28,8 @@
jnp.float64, jnp.complex64, jnp.complex128})


def to_dlpack(x: Array, take_ownership: bool = False):
def to_dlpack(x: Array, take_ownership: bool = False,
stream: int | None = None):
"""Returns a DLPack tensor that encapsulates a ``DeviceArray`` `x`.
Takes ownership of the contents of ``x``; leaves `x` in an invalid/deleted
@@ -38,13 +42,25 @@ def to_dlpack(x: Array, take_ownership: bool = False):
it were deleted. If ``False``, JAX retains ownership of the buffer; it is
undefined behavior if the DLPack consumer writes to a buffer that JAX
owns.
stream: optional platform-dependent stream to wait on until the buffer is
ready. This corresponds to the `stream` argument to __dlpack__ documented
in https://dmlc.github.io/dlpack/latest/python_spec.html.
"""
if not isinstance(x, array.ArrayImpl):
raise TypeError("Argument to to_dlpack must be a jax.Array, "
f"got {type(x)}")
assert len(x.devices()) == 1
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.addressable_data(0), take_ownership=take_ownership) # type: ignore
if xla_extension_version >= 186:
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.addressable_data(0), take_ownership=take_ownership, stream=stream
) # type: ignore
else:
if stream is not None:
raise ValueError(
"passing `stream` argument to to_dlpack requires jaxlib >= 0.4.15")
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.addressable_data(0), take_ownership=take_ownership) # type: ignore



def from_dlpack(dlpack):
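A usage sketch for the extended wrapper. It is illustrative only; it assumes a CUDA build and a jaxlib new enough to satisfy the xla_extension_version >= 186 gate above, and it borrows a torch stream only to obtain a raw CUDA stream pointer.

import jax
import jax.dlpack
import jax.numpy as jnp
import torch  # used here only as a source of a raw CUDA stream pointer

x = jax.device_put(jnp.ones((2, 2)), jax.devices("gpu")[0])
s = torch.cuda.Stream()
stream_ptr = s.cuda_stream  # cudaStream_t as a Python int

# The producer waits until the buffer is ready on `stream_ptr` before the
# capsule is handed out, so the consumer can safely enqueue work on it.
capsule = jax.dlpack.to_dlpack(x, stream=stream_ptr)

# On an older jaxlib (xla_extension_version < 186) the same call raises:
#   ValueError: passing `stream` argument to to_dlpack requires jaxlib >= 0.4.15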
27 changes: 26 additions & 1 deletion tests/pytorch_interoperability_test.py
@@ -21,6 +21,7 @@
import jax.dlpack
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
import jax.numpy as jnp
from jax._src import test_util as jtu

@@ -83,14 +84,38 @@ def testJaxToTorch(self, shape, dtype):
else:
self.assertAllClose(np, y.cpu().numpy())

@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
def testJaxArrayToTorch(self, shape, dtype):
if xla_extension_version < 186:
self.SkipTest("Need xla_extension_version >= 186")

if not config.x64_enabled and dtype in [
jnp.int64,
jnp.float64,
jnp.complex128,
]:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
# Test across all devices
for device in jax.local_devices():
x = jax.device_put(np, device)
y = torch.utils.dlpack.from_dlpack(x)
if dtype == jnp.bfloat16:
# .numpy() doesn't work on Torch bfloat16 tensors.
self.assertAllClose(
np, y.cpu().view(torch.int16).numpy().view(jnp.bfloat16)
)
else:
self.assertAllClose(np, y.cpu().numpy())

def testTorchToJaxInt64(self):
# See https://github.com/google/jax/issues/11895
x = jax.dlpack.from_dlpack(
torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64)))
dtype_expected = jnp.int64 if config.x64_enabled else jnp.int32
self.assertEqual(x.dtype, dtype_expected)


@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
def testTorchToJax(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
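A round-trip sketch in the spirit of the tests above. It is illustrative only and assumes CUDA builds of jax and torch, with a torch version whose from_dlpack accepts arbitrary objects implementing __dlpack__.

import jax
import jax.dlpack
import jax.numpy as jnp
import torch

x = jax.device_put(jnp.arange(6, dtype=jnp.float32).reshape(2, 3),
                   jax.devices("gpu")[0])

# jax -> torch: torch drives __dlpack_device__/__dlpack__ on the jax.Array,
# as exercised by testJaxArrayToTorch.
y = torch.utils.dlpack.from_dlpack(x)

# torch -> jax: back through an explicit capsule, as in testTorchToJax.
z = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(y))
assert bool((x == z).all())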
