From 7c7c94c8ddba1737931f3f95575682f20f37f0d2 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 7 Oct 2022 14:13:20 +0300 Subject: [PATCH] Expand support for __jax_array__ in jnp.array. This relates to the long discussion in #4725 and #10065. --- jax/_src/numpy/lax_numpy.py | 2 ++ tests/api_test.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c3af0270b598..d36e3e88258a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1862,6 +1862,8 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True, if isinstance(object, (bool, int, float, complex)): _ = dtypes.coerce_to_array(object, dtype) + object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf, + object) leaves = tree_leaves(object) if dtype is None: # Use lattice_result_type rather than result_type to avoid canonicalization. diff --git a/tests/api_test.py b/tests/api_test.py index 4241428a46d4..c9f590f0e78e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3574,6 +3574,13 @@ def __jax_array__(self): for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]: self.assertEqual(f(x), f(a)) + x = AlexArray(jnp.array(1)) + a1 = jnp.array(x) + self.assertAllClose(1, a1) + + a2 = jnp.array(((x, x), [x, x])) + self.assertAllClose(np.array(((1, 1), (1, 1))), a2) + def test_constant_handler_mro(self): # https://github.com/google/jax/issues/6129