diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 557a47acbbf2..163a5da9b415 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -437,6 +437,9 @@ def convert_element_type(operand: Array, new_dtype: DType = None, msg = "Casting complex values to real discards the imaginary part" warnings.warn(msg, np.ComplexWarning, stacklevel=2) + if hasattr(operand, '__jax_array__'): + operand = operand.__jax_array__() + if not isinstance(operand, (core.Tracer, xla.DeviceArray)): return _device_put_raw(np.asarray(operand, dtype=new_dtype), weak_type=new_weak_type) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 05876db30497..cf8fba4a3263 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -300,9 +300,11 @@ def _result_dtype(op, *args): return _dtype(op(*args)) -def _arraylike(x): return isinstance(x, ndarray) or isscalar(x) +def _arraylike(x): + return isinstance(x, ndarray) or isscalar(x) or hasattr(x, '__jax_array__') + def _check_arraylike(fun_name, *args): - """Check if all args fit JAX's definition of arraylike (ndarray or scalar).""" + """Check if all args fit JAX's definition of arraylike.""" assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}" if _any(not _arraylike(arg) for arg in args): pos, arg = next((i, arg) for i, arg in enumerate(args) diff --git a/jax/core.py b/jax/core.py index 903ad5ad228c..4ae5040ec320 100644 --- a/jax/core.py +++ b/jax/core.py @@ -859,6 +859,8 @@ def concrete_aval(x): for typ in type(x).mro(): handler = pytype_aval_mappings.get(typ) if handler: return handler(x) + if hasattr(x, '__jax_array__'): + return concrete_aval(x.__jax_array__()) raise TypeError(f"{type(x)} is not a valid JAX type") diff --git a/jax/dtypes.py b/jax/dtypes.py index 11aff4d3dab6..8fc0fdafa9f2 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -280,7 +280,10 @@ def is_python_scalar(x): try: return x.aval.weak_type and np.ndim(x) == 0 except AttributeError: - return type(x) in python_scalar_dtypes + if hasattr(x, '__jax_array__'): + return is_python_scalar(x.__jax_array__()) + else: + return type(x) in python_scalar_dtypes def dtype(x): if type(x) in python_scalar_dtypes: diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 4be8b3c223ca..e829f6134926 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -122,11 +122,14 @@ def array_result_handler(device: Optional[Device], aval: core.ShapedArray): } def device_put(x, device: Optional[Device] = None) -> Tuple[Any]: - x = canonicalize_dtype(x) - try: - return device_put_handlers[type(x)](x, device) - except KeyError as err: - raise TypeError(f"No device_put handler for type: {type(x)}") from err + handler = device_put_handlers.get(type(x)) + if handler: + x = canonicalize_dtype(x) + return handler(x, device) + elif hasattr(x, '__jax_array__'): + return device_put(x.__jax_array__(), device) + else: + raise TypeError(f"No device_put handler for type: {type(x)}") def _device_put_array(x, device: Optional[Device]): backend = xb.get_device_backend(device) @@ -151,6 +154,8 @@ def canonicalize_dtype(x): for typ in typ.mro(): handler = canonicalize_dtype_handlers.get(typ) if handler: return handler(x) + if hasattr(x, '__jax_array__'): + return canonicalize_dtype(x.__jax_array__()) raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}") def _canonicalize_ndarray_dtype(x): @@ -173,6 +178,8 @@ def abstractify(x) -> core.AbstractValue: for typ in typ.mro(): aval_fn = pytype_aval_mappings.get(typ) if aval_fn: return aval_fn(x) + if hasattr(x, '__jax_array__'): + return abstractify(x.__jax_array__()) raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") def _make_abstract_python_scalar(typ, _): diff --git a/tests/api_test.py b/tests/api_test.py index d92dd96a9583..1dc59755068d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2068,6 +2068,32 @@ def test_linearize_aval_error(self): with self.assertRaisesRegex(ValueError, "tangent values inconsistent"): f_jvp(np.ones(2, np.int32)) + def test_dunder_jax_array(self): + # https://github.com/google/jax/pull/4725 + + class AlexArray: + def __init__(self, jax_val): + self.jax_val = jax_val + def __jax_array__(self): + return self.jax_val + dtype = property(lambda self: self.jax_val.dtype) + shape = property(lambda self: self.jax_val.shape) + + x = AlexArray(jnp.array([1., 2., 3.])) + y = jnp.sin(x) + self.assertAllClose(y, jnp.sin(jnp.array([1., 2., 3.]))) + y = api.grad(api.jit(lambda x: jnp.sin(x).sum()))(x) + self.assertAllClose(y, jnp.cos(jnp.array([1., 2., 3.]))) + + x = AlexArray(jnp.array([[1., 2., 3.]])) + y = api.pmap(jnp.sin)(x) + self.assertAllClose(y, jnp.sin(jnp.array([[1., 2., 3.]]))) + + x = jnp.array(1) + a = AlexArray(x) + for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]: + self.assertEqual(f(x), f(a)) + class RematTest(jtu.JaxTestCase):