isinstance(numpy.zeros(1), jax.numpy.ndarray) returns True #2014

Closed
honnibal opened this issue Jan 16, 2020 · 20 comments
Labels
bug Something isn't working question Questions for the JAX team

Comments

@honnibal

numpy.ndarray instances currently return True if you run an isinstance check against jax.numpy.ndarray. I guess I see how this happens: I think Jax doesn't actually use that type, so it's maybe the actual one from numpy? It's a bit of a hassle when you're checking the array provenances though.

def test_numpy_ndarray_is_not_instance_of_jax_numpy_ndarray():
    assert not isinstance(numpy.zeros(1), jax.numpy.ndarray)

Btw, what's the preferred way to convert data from Jax to numpy? I've found jax.device_get() by poking around, but I don't think it's documented.

Thanks for the great project!

@mattjj
Member

mattjj commented Jan 16, 2020

Thanks for the questions, and the positivity!

Btw, what's the preferred way to convert data from Jax to numpy?

I usually import numpy as onp and then just use onp.array(jax_array) (or asarray), i.e. the usual way you use NumPy to turn something into an ndarray. That's as efficient as possible.
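A minimal sketch of the conversion described above, assuming both NumPy and JAX are installed. The variable names are illustrative and do not come from the thread.

```python
import numpy as np
import jax.numpy as jnp

# A small JAX array to convert.
jax_array = jnp.arange(3.0)

# np.asarray converts without an extra copy where it can;
# np.array always produces a fresh NumPy array.
host_view = np.asarray(jax_array)
host_copy = np.array(jax_array)

print(type(host_view).__name__, host_copy.tolist())
```

Either call gives back a plain numpy.ndarray, so the result can be handed to code that knows nothing about JAX.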

numpy.ndarray instances currently return True if you run an isinstance check against jax.numpy.ndarray.

This was a conscious choice, since we thought we wanted jax.numpy to be as close as possible to a drop-in replacement for numpy, where you just import jax.numpy instead of numpy and everything still works. But maybe it's more common, and more in line with our philosophy of explicitness, to imagine that users will import both numpy and jax.numpy and want to keep the two straight, e.g. with the kinds of isinstance checks you mention.

I think this line, together with the fact that onp.ndarray is included in _arraylike_types, controls this behavior. (See also the comment above those lines.) Maybe we should consider changing this.

Thoughts?

@mattjj mattjj added the question Questions for the JAX team label Jan 16, 2020
@mattjj
Member

mattjj commented Jan 16, 2020

@shoyer @hawkinsp I'm keen to get your thoughts in particular!

@shoyer
Member

shoyer commented Jan 16, 2020

I agree, as I wrote over in #1081 (comment), this behavior surprised me. This sort of dynamic inheritance is rarely used; my guess is that it could lead to bugs. My vote would definitely be for encouraging separate/explicit jax.numpy and numpy imports.

@honnibal
Author

honnibal commented Jan 16, 2020

@mattjj Thanks for the quick reply!

I would've expected the following:

jax_arr = jax.numpy.zeros(1)
assert isinstance(jax_arr, numpy.ndarray)
np_arr = numpy.ndarray(1)
assert not isinstance(np_arr, jax.numpy.ndarray)

This is the normal inheritancish behaviour: jax.numpy.ndarray is the new impostor, and it can claim to be a type of numpy.ndarray (even when that's not literally true). But it's kind of weird to trick numpy.ndarray into believing it's a type of jax.numpy.ndarray, since that's really not at all true.

Finally, I think it can be convenient to do something like from jax import numpy in a quick script, and it's nice for that to work fine when the program only has to deal with jax arrays or numpy arrays, but not both. But if a user writes that import in a context where they'll have a mixture of the two types, their code will have all sorts of bugs, and I think that's not really jax's fault? So I think having isinstance(arr, numpy.ndarray) return False will be the least of their problems. Like, yeah, they might have written code expecting that to return True --- but the actual truth is False, and they're better off knowing it.

@mattjj
Member

mattjj commented Jan 16, 2020

These arguments are pretty convincing IMO. I'll ping our internal chat room to see if there's any dissent, and if not we should fix this.

@mattjj
Member

mattjj commented Jan 16, 2020

@honnibal did you mean for the last line to have a not in it?

@shoyer
Member

shoyer commented Jan 16, 2020

I think pretending that inheritance goes either way is surprising.

@mattjj
Member

mattjj commented Jan 16, 2020

One implementation detail, not necessarily relevant to the question of how things should behave but maybe useful as an explanation, is that jax.numpy.ndarray is not actually our array type; our array type is jax.interpreters.xla.DeviceArray. The jax.numpy.ndarray value is there only for isinstance checks (which we can configure to act however we want, as we should figure out in this thread).
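To illustrate how a type that exists only for isinstance checks can be configured, here is a pure-Python sketch using a metaclass with `__instancecheck__`. It is not JAX's actual implementation: `PlainArray`, `DeviceBackedArray`, and `make_array_type` are hypothetical stand-ins invented for this example.

```python
class PlainArray:
    """Hypothetical stand-in for numpy.ndarray."""


class DeviceBackedArray:
    """Hypothetical stand-in for JAX's concrete array type."""


def make_array_type(accepted_types):
    """Build a class whose isinstance() answer is decided by accepted_types."""

    class _Meta(type):
        def __instancecheck__(cls, instance):
            # isinstance(obj, ArrayType) is routed here by Python.
            return isinstance(instance, accepted_types)

    class ArrayType(metaclass=_Meta):
        """Never instantiated; serves only as an isinstance target."""

    return ArrayType


# Behaviour reported in this issue: plain arrays are accepted too.
permissive = make_array_type((DeviceBackedArray, PlainArray))
# Behaviour requested in this issue: only the device-backed type counts.
strict = make_array_type((DeviceBackedArray,))

print(isinstance(PlainArray(), permissive))  # True
print(isinstance(PlainArray(), strict))      # False
```

In this sketch, switching between the two behaviours only means changing which concrete types the check accepts.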

@mattjj mattjj added the bug Something isn't working label Jan 16, 2020
@honnibal
Author

@mattjj Oops, yes! Fixed.

@yunlongxu-artemis

sorry for reviving a closed issue, it seems the behaviour @honnibal mentioned is still the same, right?

assert isinstance(np.zeros(3), jnp.ndarray)

based on what @mattjj suggested, should we be doing the following instead if we want to check if an array is indeed a jax array instead of a plain numpy array?

assert not isinstance(jnp.zeros(3), jax.interpreters.xla._DeviceArray)

thanks!

@mattjj
Member

mattjj commented Mar 28, 2021

Yes, we still haven't fixed this :/

I think a better check is just isinstance(x, np.ndarray) to see if x is a regular numpy.ndarray. That is, while isinstance(numpy_ndarray, jnp.ndarray) is surprisingly True, it's not the case that isinstance(jax_array, np.ndarray) is True.

@jakevdp
Collaborator

jakevdp commented Mar 28, 2021

should we be doing the following instead if we want to check if an array is indeed a jax array instead of a plain numpy array?

assert not isinstance(jnp.zeros(3), jax.interpreters.xla._DeviceArray)

Can you say more about why you're interested in checking object types here? For example, checking object identity this way might cause JIT-compiled code to act unexpectedly, because device arrays are replaced with tracers at compile time.
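The tracer point can be seen in a small sketch, assuming JAX and NumPy are installed. The function name `add_one` and the `observed` list are illustrative only.

```python
import numpy as np
import jax

observed = []


@jax.jit
def add_one(x):
    # This Python body runs at trace time, where x is a tracer
    # standing in for the caller's argument, not the argument itself.
    observed.append(isinstance(x, np.ndarray))
    return x + 1


# The caller passes a plain NumPy array ...
result = add_one(np.zeros(3))

# ... but the check inside the jitted function did not see one.
print(observed)
```

So a type check that behaves one way in eager code can take a different branch once the same function is wrapped in jax.jit.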

@yunlongxu-artemis

should we be doing the following instead if we want to check if an array is indeed a jax array instead of a plain numpy array?

assert not isinstance(jnp.zeros(3), jax.interpreters.xla._DeviceArray)

Can you say more about why you're interested in checking object types here? For example, checking object identity this way might cause JIT-compiled code to act unexpectedly, because device arrays are replaced with tracers at compile time.

Sure. The insights you provided about JIT-compiled code are already very useful, thanks. Let's see if my use case makes sense.

Imagine I have a function foo:

from typing import Union

import numpy as np
import jax.numpy as jnp

def foo(array_a: Union[np.ndarray, jnp.ndarray], array_b: Union[np.ndarray, jnp.ndarray]) -> Union[np.ndarray, jnp.ndarray]:
    """When given np array only, returns np array, when given jnp array only, returns jnp array"""

    # Imagine the two functions below work with both numpy and jax.numpy arrays.
    a = do_something(array_a)
    b = do_something_else(array_b)

    # here's the type check, and for the purpose of discussion, we only check one input.
    if isinstance(array_a, np.ndarray):
        concat = np.concatenate
    else:
        concat = jnp.concatenate

    return concat([a, b])

why do I need to do the type check? I want to use the same foo for two users:

User A: gives foo numpy array only, expects numpy array back, and does an in-place array element update with the returned value (which breaks if we use jnp.concatenate in foo)

User B: gives foo jax numpy array only, expects jax numpy array back, and then does autograd on the function (which breaks if we use np.concatenate in foo)

Does this use-case make sense or is this something that I should avoid?
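One way to sketch this use case is to pick the array namespace once and use it throughout, relying only on the isinstance(x, np.ndarray) check recommended earlier in the thread. The helper `array_namespace` is hypothetical, and `sin` and `cos` are placeholders for `do_something` and `do_something_else`.

```python
import numpy as np
import jax.numpy as jnp


def array_namespace(x):
    """Hypothetical helper: np for plain NumPy arrays, jnp for anything else."""
    return np if isinstance(x, np.ndarray) else jnp


def foo(array_a, array_b):
    # Decide once, from the first input, which library to use.
    xp = array_namespace(array_a)
    a = xp.sin(array_a)  # placeholder for do_something
    b = xp.cos(array_b)  # placeholder for do_something_else
    return xp.concatenate([a, b])


numpy_result = foo(np.zeros(2), np.zeros(2))
jax_result = foo(jnp.zeros(2), jnp.zeros(2))
print(type(numpy_result).__name__, jax_result.shape)
```

Because tracers are not numpy.ndarray instances, this sketch falls through to jnp under jax.jit or jax.grad, which is probably the desired branch for the autodiff user.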

@jakevdp
Collaborator

jakevdp commented Mar 28, 2021

I see – One thing to be careful of in an approach like that is that JAX functions do not always return device arrays: sometimes they can return numpy arrays (because internally, numpy values are often used for constants). As a simple example:

type(jax.grad(lambda x: x)(1.0))
# numpy.ndarray

So if you have too much dependence on the type of value in a pipeline, you might end up with surprises in corner cases.

@shoyer
Member

shoyer commented Mar 28, 2021

I see – One thing to be careful of in an approach like that is that JAX functions do not always return device arrays: sometimes they can return numpy arrays (because internally, numpy values are often used for constants). As a simple example:

This is arguably a bug though, right?

@shoyer
Member

shoyer commented Mar 28, 2021

I think the preferred option for checking for JAX arrays would be:

import jax.numpy as jnp
import numpy as np

def is_jax_array(x):
  return isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray)

This version is also future proof: when jnp.ndarray no longer pretends to super-class numpy arrays (which hopefully will happen soon!), the second check can simply be removed.

@jakevdp
Collaborator

jakevdp commented Mar 28, 2021

This is arguably a bug though, right?

Maybe? I don't think we ever promise that returned values will be DeviceArrays, and our test scaffold doesn't have any isinstance checks on outputs. So if it's a bug, it's one we've not particularly cared about previously.

@yunlongxu-artemis

@shoyer thanks, I'll do that for now :-)

@hawkinsp
Member

@jakevdp We never wrote down that promise, but yes, I would consider it a bug to ever return a numpy array from a JAX function. I did a sequence of changes to make sure that jax.numpy never does that. The main thing that convinced me was that promotion semantics are different.
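As a small illustration of the promotion difference, assuming JAX and NumPy are installed: adding an int32 array to a float32 array gives different result dtypes in the two libraries. The variable names are illustrative.

```python
import numpy as np
import jax.numpy as jnp

# NumPy widens int32 + float32 to float64.
numpy_sum = np.ones(3, dtype=np.int32) + np.ones(3, dtype=np.float32)

# JAX keeps the result at float32.
jax_sum = jnp.ones(3, dtype=jnp.int32) + jnp.ones(3, dtype=jnp.float32)

print(numpy_sum.dtype)  # float64
print(jax_sum.dtype)    # float32
```

A JAX function that silently returned a NumPy array could therefore change the dtypes of later arithmetic, not just the container type.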

@jakevdp
Collaborator

jakevdp commented Jun 21, 2022

I think the original issue has been fixed now (See Changelog: v0.2.21):

In [1]: import numpy as np
   ...: import jax.numpy as jnp
   ...: isinstance(np.zeros(1), jnp.ndarray)
Out[1]: False

Additionally, the other issue I highlighted above has been fixed as well; I'm not sure precisely when:

In [2]: import jax
   ...: type(jax.grad(lambda x: x)(1.0))
Out[2]: jaxlib.xla_extension.DeviceArray

@jakevdp jakevdp closed this as completed Jun 21, 2022