[JAX] Improve support for DLPack tensors on CPU when a GPU is available.
#5581

Previously the user had to provide the target backend explicitly. Now we supply both CPU and GPU backends to the C++ code so it can choose based on the metadata of the DLPack tensor.

PiperOrigin-RevId: 380795192
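
A minimal usage sketch (not part of the commit) of what this enables: a round trip through DLPack no longer needs an explicit backend argument, because from_dlpack reads the device from the capsule's metadata. The array and dtype below are arbitrary choices for illustration.

import jax
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32)            # lives on the default device (CPU or GPU)
capsule = jax.dlpack.to_dlpack(x, take_ownership=False)

# Before this change: from_dlpack(capsule, backend=...) had to name the target backend.
# After this change: the backend argument is deprecated and the device is inferred.
y = jax.dlpack.from_dlpack(capsule)
print(y.device())                                # same device the capsule came from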
hawkinsp authored and jax authors committed Jun 22, 2021
1 parent 8e14d53 commit f885366
Showing 3 changed files with 33 additions and 15 deletions.
jax/_src/dlpack.py (16 additions & 6 deletions)

@@ -15,6 +15,7 @@
 from jax import core
 from jax import numpy as jnp
 from jax.interpreters import xla
+import jax.lib
 from jax.lib import xla_client
 from jax.lib import xla_bridge

@@ -50,13 +51,22 @@ def from_dlpack(dlpack, backend=None):
   Args:
     dlpack: a DLPack tensor, on either CPU or GPU.
-    backend: experimental, optional: the platform on which `dlpack` lives.
+    backend: deprecated, do not use.
   """
-  # TODO(phawkins): ideally the user wouldn't need to provide a backend and we
-  # would be able to figure it out from the DLPack.
-  backend = backend or xla_bridge.get_backend()
-  client = getattr(backend, "client", backend)
-  buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
+  if jax.lib._xla_extension_version >= 25:
+    cpu_backend = xla_bridge.get_backend("cpu")
+    try:
+      gpu_backend = xla_bridge.get_backend("gpu")
+    except RuntimeError:
+      gpu_backend = None
+    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
+        dlpack, cpu_backend, gpu_backend)
+  else:
+    # TODO(phawkins): drop the backend argument after deleting this case.
+    backend = backend or xla_bridge.get_backend()
+    client = getattr(backend, "client", backend)
+    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)

   xla_shape = buf.xla_shape()
   assert not xla_shape.is_tuple()
   aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
jax/experimental/jax2tf/call_tf.py (1 addition & 3 deletions)

@@ -33,7 +33,6 @@
 from jax import tree_util
 from jax._src import util
 from jax.interpreters import xla
-from jax.lib import xla_bridge
 from jax.lib import xla_client
 from . import jax2tf as jax2tf_internal

@@ -200,8 +199,7 @@ def _res_tf_to_jax(res_tf: TfVal, out_aval: core.AbstractValue):
   res_jax_platform = res_tf_platform.lower()
   if res_jax_platform in _DLPACK_PLATFORMS:
     res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
-    return jax.dlpack.from_dlpack(
-        res_dlpack, backend=xla_bridge.get_backend(res_jax_platform))
+    return jax.dlpack.from_dlpack(res_dlpack)

   return jnp.asarray(np.asarray(res_tf))

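For context, a rough sketch (assuming TensorFlow is installed) of the hand-off pattern call_tf relies on; with this commit the JAX side no longer needs to be told which backend the TF result lives on. The tensor value here is only illustrative.

import jax.dlpack
import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])        # a TF result, on CPU or GPU
capsule = tf.experimental.dlpack.to_dlpack(t)

# The capsule's metadata records the device, so no backend= argument is needed.
x = jax.dlpack.from_dlpack(capsule)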
tests/array_interoperability_test.py (16 additions & 6 deletions)

@@ -66,20 +66,30 @@ def setUp(self):
       self.skipTest("DLPack not supported on TPU")

   @parameterized.named_parameters(jtu.cases_from_list(
-     {"testcase_name": "_{}_take_ownership={}".format(
+     {"testcase_name": "_{}_take_ownership={}_gpu={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
-         take_ownership),
-      "shape": shape, "dtype": dtype, "take_ownership": take_ownership}
+         take_ownership, gpu),
+      "shape": shape, "dtype": dtype, "take_ownership": take_ownership,
+      "gpu": gpu}
      for shape in all_shapes
      for dtype in dlpack_dtypes
-     for take_ownership in [False, True]))
-  def testJaxRoundTrip(self, shape, dtype, take_ownership):
+     for take_ownership in [False, True]
+     for gpu in [False, True]))
+  def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
     rng = jtu.rand_default(self.rng())
     np = rng(shape, dtype)
-    x = jnp.array(np)
+    if gpu and jax.default_backend() == "cpu":
+      raise unittest.SkipTest("Skipping GPU test case on CPU")
+    if (not gpu and jax.default_backend() == "gpu" and
+        jax.lib._xla_extension_version < 25):
+      raise unittest.SkipTest("Mixed CPU/GPU dlpack support requires jaxlib "
+                              "0.1.68 or newer")
+    device = jax.devices("gpu" if gpu else "cpu")[0]
+    x = jax.device_put(np, device)
     dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
     self.assertEqual(take_ownership, x.device_buffer.is_deleted())
     y = jax.dlpack.from_dlpack(dlpack)
+    self.assertEqual(y.device(), device)
     self.assertAllClose(np.astype(x.dtype), y)

     self.assertRaisesRegex(RuntimeError,
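A condensed sketch of the mixed-device case this test now exercises (assuming a GPU-backed install with a jaxlib new enough for this change, i.e. _xla_extension_version >= 25): a CPU-resident array round-trips through DLPack even though the default backend is GPU. Shapes and dtypes are illustrative only.

import jax
import jax.dlpack
import numpy as np

cpu = jax.devices("cpu")[0]
x = jax.device_put(np.arange(8, dtype=np.float32), cpu)   # explicitly placed on CPU

capsule = jax.dlpack.to_dlpack(x, take_ownership=False)
y = jax.dlpack.from_dlpack(capsule)

assert y.device() == cpu    # stays on CPU rather than landing on the default GPU backend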
