Skip to content

Commit

Permalink
[XLA:Python] Fix __cuda_array_interface__.
Browse files Browse the repository at this point in the history
Adds a test for __cuda_array_interface__ that does not depend on cupy.

Fixes #16440

PiperOrigin-RevId: 541965361
  • Loading branch information
hawkinsp authored and jax authors committed Jun 20, 2023
1 parent afcd1a7 commit 0ec03db
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -37,6 +37,10 @@ Remember to align the itemized text with the first line of an item within a list

## jaxlib 0.4.13

* Bug fixes
* `__cuda_array_interface__` was broken in previous jaxlib versions and is now
fixed ({jax-issue}`16440`).

## jax 0.4.12 (June 8, 2023)

* Changes
Expand Down
27 changes: 24 additions & 3 deletions tests/array_interoperability_test.py
Expand Up @@ -21,6 +21,7 @@
import jax.dlpack
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version

import numpy as np

Expand Down Expand Up @@ -48,6 +49,8 @@
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
key=lambda x: x.__name__)

cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16]

nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
empty_array_shapes = []
empty_array_shapes += [(0,), (0, 4), (3, 0),]
Expand Down Expand Up @@ -162,6 +165,7 @@ def testJaxToNumpy(self, shape, dtype):
self.assertAllClose(x_np, x_jax)


@unittest.skipIf(xla_extension_version < 163, "Test requires jaxlib 0.4.13")
class CudaArrayInterfaceTest(jtu.JaxTestCase):

def setUp(self):
Expand All @@ -171,12 +175,29 @@ def setUp(self):

@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
dtype=cuda_array_interface_dtypes,
)
def testCudaArrayInterfaceWorks(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
y = jnp.array(x)
a = y.__cuda_array_interface__
self.assertEqual(shape, a["shape"])
self.assertEqual(x.__array_interface__["typestr"], a["typestr"])

def testCudaArrayInterfaceBfloat16Fails(self):
rng = jtu.rand_default(self.rng())
x = rng((2, 2), jnp.bfloat16)
y = jnp.array(x)
with self.assertRaisesRegex(RuntimeError, ".*not supported for bfloat16.*"):
_ = y.__cuda_array_interface__

@jtu.sample_product(
shape=all_shapes,
dtype=cuda_array_interface_dtypes,
)
@unittest.skipIf(not cupy, "Test requires CuPy")
def testJaxToCuPy(self, shape, dtype):
if dtype == jnp.bfloat16:
raise unittest.SkipTest("cupy does not support bfloat16")
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
y = jnp.array(x)
Expand Down

0 comments on commit 0ec03db

Please sign in to comment.