Skip to content

Commit

Permalink
Raise a runtime error when trying to convert the jax.Array wrapped …
Browse files Browse the repository at this point in the history
…by `jax.core.Token` to a numpy array, as it is an internal implementation detail and the buffer has XLA token shape.

PiperOrigin-RevId: 632682906
  • Loading branch information
yueshengys authored and jax authors committed May 11, 2024
1 parent 20646eb commit 3b03e54
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
Expand Down Expand Up @@ -702,6 +703,14 @@ def test_trivial_computations(self):
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertEqual(z2, 1)

@unittest.skipIf(xla_extension_version < 264, "jaxlib version too old")
def test_print_token_buffer_error(self):
token = jax.lax.create_token()
with self.assertRaisesRegex(
RuntimeError, "Cannot convert a token-shape buffer to a numpy array."
):
token._buf._value

def test_trivial_computations_with_tokens(self):
@jit
def noop(arr, token):
Expand Down

0 comments on commit 3b03e54

Please sign in to comment.