Skip to content

Commit

Permalink
Add support for max_version, dl_device, copy kwargs in __dlpack__
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 authored and rajasekharporeddy committed Apr 12, 2024
1 parent ec81efe commit e094a8b
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 42 deletions.
24 changes: 19 additions & 5 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map, hashed_index)
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.typing import ArrayLike
from jax._src.typing import ArrayLike, DLDeviceType
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method


Expand Down Expand Up @@ -404,11 +404,25 @@ def __array__(self, dtype=None, context=None, copy=None):
kwds = {} if copy is None else {'copy': copy}
return np.asarray(self._value, dtype=dtype, **kwds)

def __dlpack__(self, *, stream: int | Any | None = None):
if len(self._arrays) != 1:
raise BufferError("__dlpack__ only supported for unsharded arrays.")
def __dlpack__(self, *, stream: int | Any | None = None,
max_version: tuple[int, int] | None = None,
dl_device: tuple[DLDeviceType, int] | None = None,
copy: bool | None = None):
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self, stream=stream)

device_set = self.sharding.device_set
if len(device_set) > 1:
raise BufferError(
"to_dlpack can only pack a dlpack tensor from an array on a singular "
f"device, but an array with a Sharding over {len(device_set)} devices "
"was provided."
)
device, = device_set
return to_dlpack(self, stream=stream,
max_version=max_version,
src_device=device,
dl_device=dl_device,
copy=copy)

def __dlpack_device__(self) -> tuple[enum.Enum, int]:
if len(self._arrays) != 1:
Expand Down
117 changes: 93 additions & 24 deletions jax/_src/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,21 @@

from __future__ import annotations

import enum
from typing import Any
import warnings

from jax._src.api import device_put
from jax import numpy as jnp
from jax._src import array
from jax._src import xla_bridge
from jax._src.lax.lax import _array_copy
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array
from jax._src.typing import Array, DLDeviceType
from jax._src.sharding import Sharding

DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)

# A set of dtypes that dlpack supports.
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
# because their hashes are different.
Expand All @@ -43,45 +45,112 @@
SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})


# Mirror of dlpack.h enum
class DLDeviceType(enum.IntEnum):
kDLCPU = 1
kDLCUDA = 2
kDLROCM = 10
def _to_dlpack(x: Array, stream: int | Any | None,
src_device: xla_client.Device | None = None,
device: xla_client.Device | None = None,
copy: bool | None = None):

if src_device is None:
src_device, = x.devices()
if device and (src_device is None or device != src_device):
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(src_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
arr = device_put(x, device)
else:
arr = _array_copy(x) if copy else x
return xla_client._xla.buffer_to_dlpack_managed_tensor(
arr.addressable_data(0), stream=stream
)

def to_dlpack(x: Array, take_ownership: bool = False,
stream: int | Any | None = None):
def to_dlpack(x: Array, stream: int | Any | None = None,
src_device: xla_client.Device | None = None,
dl_device: tuple[DLDeviceType, int] | None = None,
max_version: tuple[int, int] | None = None,
copy : bool | None = None):
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
Args:
x: a :class:`~jax.Array`, on either CPU or GPU.
take_ownership: Deprecated. It is a no-op to set take_ownership. Will be
deleted in 01/2024.
stream: optional platform-dependent stream to wait on until the buffer is
ready. This corresponds to the `stream` argument to ``__dlpack__``
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
src_device: either a CPU or GPU :class:`~jax.Device`.
dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
format e.g. as produced by ``__dlpack_device__``.
max_version: the maximum DLPack version that the consumer (i.e. caller of
``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
This function is not guaranteed to return a capsule of version
``max_version``.
copy: a boolean indicating whether or not to copy the input. If
``copy=True`` then the function must always copy. When
``copy=False`` then the function must never copy, and must raise an error
when a copy is deemed necessary. If ``copy=None`` then the function must
avoid a copy if possible but may copy if needed.
Returns:
A dlpack PyCapsule object.
A DLPack PyCapsule object.
Note:
While JAX arrays are always immutable, dlpack buffers cannot be marked as
immutable, and it is possible for processes external to JAX to mutate them
in-place. If a dlpack buffer derived from a JAX array is mutated, it may
lead to undefined behavior when using the associated JAX array.
While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
cannot be marked as immutable, and it is possible for processes external
to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
is mutated, it may lead to undefined behavior when using the associated JAX
array. When JAX eventually supports ``DLManagedTensorVersioned``
(DLPack 1.0), it will be possible to specify that a buffer is read-only.
"""
if not isinstance(x, array.ArrayImpl):
raise TypeError("Argument to to_dlpack must be a jax.Array, "
f"got {type(x)}")
assert len(x.devices()) == 1
if take_ownership:
warnings.warn(
"take_ownership in to_dlpack is deprecated and it is a no-op."

device = None
dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
if dl_device_type:
try:
dl_device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
backend = xla_bridge.get_backend(dl_device_platform)
device = backend.device_from_local_hardware_id(local_hardware_id)
except TypeError:
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
# recommends using BufferError.
raise BufferError(
"The device specification passed to to_dlpack contains an unsupported "
f"device type (DLDeviceType: {dl_device_type})")

# As new versions are adopted over time, we can maintain some legacy paths
# for compatability mediated through the max_version parameter.
# TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
# supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
# current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
if max_version is None or max_version >= DLPACK_VERSION:
# Latest
return _to_dlpack(
x, stream=stream,
src_device=src_device,
device=device,
copy=copy
)
elif max_version >= MIN_DLPACK_VERSION:
# Oldest supported
return _to_dlpack(
x, stream=stream,
src_device=src_device,
device=device,
copy=copy
)
else:
raise BufferError(
f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
f"version ({max_version}) was requested."
)
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.addressable_data(0), stream=stream
) # type: ignore

def _place_array(_arr, device, dlpack_device, copy):
if device and dlpack_device != device:
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from collections.abc import Sequence
from typing import Any, Protocol, Union
import numpy as np
import enum

from jax._src.basearray import (
Array as Array,
Expand Down Expand Up @@ -83,3 +84,9 @@ def shape(self) -> Shape: ...
class DeprecatedArg:
def __repr__(self):
return "Deprecated"

# Mirror of dlpack.h enum
class DLDeviceType(enum.IntEnum):
kDLCPU = 1
kDLCUDA = 2
kDLROCM = 10
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/call_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _arg_jax_to_tf(arg_jax):
if (isinstance(arg_jax, jax.Array) and
list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES):
arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
arg_dlpack = jax.dlpack.to_dlpack(arg_jax)
return tf.experimental.dlpack.from_dlpack(arg_dlpack)
# The following avoids copies to the host on CPU, always for Array
# and even for ndarray if they are sufficiently aligned.
Expand Down
49 changes: 37 additions & 12 deletions tests/array_interoperability_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,48 @@ def setUp(self):
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
gpu=[False, True],
copy=[False, True, None]
)
def testJaxRoundTrip(self, shape, dtype, gpu):
@jtu.run_on_devices("gpu")
def testJaxRoundTrip(self, shape, dtype, copy):
if xb.using_pjrt_c_api():
self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm)
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
if gpu and jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("Skipping GPU test case on CPU")
device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device)
dlpack = jax.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack)
self.assertEqual(y.devices(), {device})
self.assertAllClose(np.astype(x.dtype), y)

def _check_copy(x: jax.Array, y: jax.Array, expect_copy):
copied = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()
assert copied == expect_copy, f"Expected {'a' if expect_copy else 'no'} copy"

# Check if the source device is preserved
x = jax.device_put(np, jax.devices("cpu")[0])
device = jax.devices("gpu")[0]
y = jax.device_put(x, device)
dl_device = y.__dlpack_device__()
dlpack = jax.dlpack.to_dlpack(y, copy=copy)
z = jax.dlpack.from_dlpack(dlpack)

self.assertEqual(z.devices(), {device})
self.assertAllClose(np.astype(x.dtype), z)
self.assertRaisesRegex(RuntimeError,
"DLPack tensor may be consumed at most once",
lambda: jax.dlpack.from_dlpack(dlpack))
"DLPack tensor may be consumed at most once",
lambda: jax.dlpack.from_dlpack(dlpack))

if shape in nonempty_array_shapes:
_check_copy(y, z, bool(copy))

# Check if the destination device can be specified
make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy)
if copy == False:
self.assertRaisesRegex(ValueError, "copy=False", make_dlpack)
return

z = jax.dlpack.from_dlpack(make_dlpack())
self.assertEqual(z.devices(), {device})
self.assertAllClose(x, z)

if shape in nonempty_array_shapes:
_check_copy(x, z, True)

@jtu.sample_product(
shape=all_shapes,
Expand Down

0 comments on commit e094a8b

Please sign in to comment.