isinstance(numpy.zeros(1), jax.numpy.ndarray) returns True #2014
Thanks for the questions, and the positivity!
I usually
This was a conscious choice, since we thought we wanted I think this line, together with the fact that Thoughts?
I agree, as I wrote over in #1081 (comment), this behavior surprised me. This sort of dynamic inheritance is rarely used; my guess is that it could lead to bugs. My vote would definitely be for encouraging separate/explicit
@mattjj Thanks for the quick reply! I would've expected the following:

```python
import jax.numpy
import numpy

jax_arr = jax.numpy.zeros(1)
assert isinstance(jax_arr, numpy.ndarray)
np_arr = numpy.ndarray(1)
assert not isinstance(np_arr, jax.numpy.ndarray)
```

This is the normal inheritance-ish behaviour: Finally, I think it can be convenient to do something like
These arguments are pretty convincing IMO. I'll ping our internal chat room to see if there's any dissent, and if not we should fix this.
@honnibal did you mean for the last line to have a
I think pretending that inheritance goes either way is surprising.
One implementation detail, not necessarily relevant to the question of how things should behave but maybe useful as an explanation, is that
@mattjj Oops, yes! Fixed.
Sorry for reviving a closed issue, but it seems the behaviour @honnibal mentioned is still the same, right?

```python
assert isinstance(np.zeros(3), jnp.ndarray)
```

Based on what @mattjj suggested, should we be doing the following instead if we want to check whether an array is indeed a JAX array rather than a plain NumPy array?

```python
assert not isinstance(jnp.zeros(3), jax.interpreters.xla._DeviceArray)
```

Thanks!
Yes, we still haven't fixed this :/ I think a better check is just
Can you say more about why you're interested in checking object types here? For example, checking object types this way might cause JIT-compiled code to act unexpectedly, because device arrays are replaced with tracers at compile time.
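To make the tracer point concrete, here is a minimal sketch, assuming a recent JAX release. The Python body of a jitted function runs once at trace time, so a type check there sees a tracer rather than the concrete array that is later passed in. The function name `double` and the `seen` list are illustrative only, and the tracer class name is an internal detail that may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = []  # records what the function body sees while it is being traced


@jax.jit
def double(x):
    # This Python body runs during tracing, where x is an abstract tracer
    # standing in for the real array, not a concrete device or NumPy array.
    seen.append((type(x).__name__, isinstance(x, np.ndarray)))
    return x * 2


out = double(jnp.arange(3.0))
print(seen)       # the tracer's class name varies across JAX versions
print(type(out))  # outside of jit, the result is a concrete array again
```

A type check inside a jitted function is therefore evaluated once, against the tracer, and its outcome is baked into the compiled code.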
Sure. The insights you provided about JIT-compiled code are already very useful, thanks. Let's see if my use case makes sense. Imagine I have a function:

```python
from typing import Union
import jax.numpy as jnp
import numpy as np

# Hypothetical placeholders: imagine these work with both numpy and jax.numpy arrays.
def do_something(x):
    return x + 1
def do_something_else(x):
    return x * 2

def foo(array_a: Union[np.ndarray, jnp.ndarray], array_b: Union[np.ndarray, jnp.ndarray]) -> Union[np.ndarray, jnp.ndarray]:
    """When given np arrays only, returns an np array; when given jnp arrays only, returns a jnp array."""
    a = do_something(array_a)
    b = do_something_else(array_b)
    # Here's the type check; for the purpose of discussion, we only check one input.
    if isinstance(array_a, np.ndarray):
        concat = np.concatenate
    else:
        concat = jnp.concatenate
    return concat([a, b])
```

Why do I need to do the type check? I want to use the same
User A: gives
User B: gives
Does this use case make sense, or is this something that I should avoid?
I see. One thing to be careful of in an approach like that is that JAX functions do not always return device arrays: sometimes they can return NumPy arrays (because internally, NumPy values are often used for constants). As a simple example:

```python
type(jax.grad(lambda x: x)(1.0))
# numpy.ndarray
```

So if you depend too heavily on the type of values in a pipeline, you might end up with surprises in corner cases.
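One way to avoid depending on the types of intermediate values is to decide the output type once, from the caller's input, and convert explicitly at the end. This is a minimal sketch, not an established JAX idiom; the helper name `concat_like_input` is hypothetical, and it assumes a recent JAX release where `jax.Array` is the public array type.

```python
import jax
import jax.numpy as jnp
import numpy as np


def concat_like_input(array_a, array_b):
    # Decide the output "flavour" once, from the caller's input,
    # instead of relying on the types of intermediate values.
    want_jax = isinstance(array_a, jax.Array)
    # Do the work with jax.numpy; it accepts NumPy arrays as input too.
    result = jnp.concatenate([jnp.asarray(array_a), jnp.asarray(array_b)])
    # Convert explicitly at the boundary so the return type is predictable.
    return result if want_jax else np.asarray(result)


print(type(concat_like_input(np.zeros(2), np.ones(2))))
print(type(concat_like_input(jnp.zeros(2), jnp.ones(2))))
```

Note that routing NumPy inputs through `jax.numpy` can change the dtype (for example, float64 becomes float32 unless 64-bit mode is enabled), so a NumPy caller gets a NumPy array back, but not necessarily at the original precision.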
This is arguably a bug though, right?
I think the preferred option for checking for JAX arrays would be:

```python
import jax.numpy as jnp
import numpy as np

def is_jax_array(x):
    return isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray)
```

This version is also future-proof: when
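A quick usage check of the helper above, as a sketch assuming a recent JAX release, where `jnp.ndarray` is an alias of `jax.Array` and NumPy arrays no longer pass `isinstance(x, jnp.ndarray)`. In such releases the second clause is redundant but harmless.

```python
import jax.numpy as jnp
import numpy as np


def is_jax_array(x):
    # True for JAX arrays; the second clause guards against the old
    # behaviour where NumPy arrays also passed the jnp.ndarray check.
    return isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray)


print(is_jax_array(jnp.zeros(3)))  # True
print(is_jax_array(np.zeros(3)))   # False
print(is_jax_array([0.0, 0.0]))    # False
```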
Maybe? I don't think we ever promise that returned values will be DeviceArrays, and our test scaffold doesn't have any isinstance checks on outputs. So if it's a bug, it's one we've not particularly cared about previously.
@shoyer thanks, I'll do that for now :-)
@jakevdp We never wrote down that promise, but yes, I would consider it a bug to ever return a numpy array from a JAX function. I did a sequence of changes to make sure that
I think the original issue has been fixed now (see the changelog for v0.2.21):

```python
In [1]: import numpy as np
   ...: import jax.numpy as jnp
   ...: isinstance(np.zeros(1), jnp.ndarray)
Out[1]: False
```

Additionally, the other issue I highlighted above has been fixed as well; I'm not sure precisely when:

```python
In [2]: import jax
   ...: type(jax.grad(lambda x: x)(1.0))
Out[2]: jaxlib.xla_extension.DeviceArray
```
`numpy.ndarray` instances currently return `True` if you run an `isinstance` check against `jax.numpy.ndarray`. I guess I see how this happens: I think JAX doesn't actually use that type, so it's maybe the actual one from numpy? It's a bit of a hassle when you're checking array provenance, though.

Btw, what's the preferred way to convert data from JAX to numpy? I've found `jax.device_get()` by poking around, but I don't think it's documented.

Thanks for the great project!
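On the conversion question: in current JAX releases `jax.device_get` is part of the documented public API, and `np.asarray` also converts a JAX array into a NumPy array. A minimal sketch, assuming a recent JAX release; the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4.0)

# Option 1: np.asarray returns a host-side numpy.ndarray holding x's values.
a = np.asarray(x)

# Option 2: jax.device_get transfers data to the host; it also accepts
# pytrees (e.g. dicts or tuples of arrays) and converts every array leaf.
b = jax.device_get(x)
tree = jax.device_get({"w": x, "b": jnp.ones(2)})

print(type(a), type(b), type(tree["w"]))
```

`jax.device_get` is the more convenient choice when you want to pull a whole nested structure of arrays (for example, model parameters) back to NumPy in one call.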