Skip to content

Commit

Permalink
Expand support for __jax_array__ in jnp.array.
Browse files Browse the repository at this point in the history
This relates to the long discussion in google#4725 and google#10065.
  • Loading branch information
gnecula committed Oct 7, 2022
1 parent 6c70e4d commit 7c7c94c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -1862,6 +1862,8 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
if isinstance(object, (bool, int, float, complex)):
_ = dtypes.coerce_to_array(object, dtype)

object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf,
object)
leaves = tree_leaves(object)
if dtype is None:
# Use lattice_result_type rather than result_type to avoid canonicalization.
Expand Down
7 changes: 7 additions & 0 deletions tests/api_test.py
Expand Up @@ -3574,6 +3574,13 @@ def __jax_array__(self):
for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]:
self.assertEqual(f(x), f(a))

x = AlexArray(jnp.array(1))
a1 = jnp.array(x)
self.assertAllClose(1, a1)

a2 = jnp.array(((x, x), [x, x]))
self.assertAllClose(np.array(((1, 1), (1, 1))), a2)

def test_constant_handler_mro(self):
# https://github.com/google/jax/issues/6129

Expand Down

0 comments on commit 7c7c94c

Please sign in to comment.