What is the correct way to check that something is array-like? #8701

Closed
NeilGirdhar opened this issue Nov 25, 2021 · 4 comments
Labels
bug Something isn't working

Comments

@NeilGirdhar
Contributor

It looks like the new version of JAX changed things so that jax.numpy.ndarray no longer inherits from np.ndarray, which is great. Unfortunately, other packages like Seaborn and pandas rely on isinstance(x, np.ndarray) in a few places. This changes their behavior and forces me to cast to np.ndarray.

I was just wondering, what should they be doing instead? Is it looking for the array interface? But jax.numpy.ndarray doesn't expose that. If it's something else, I'll post the appropriate issues on their trackers. If there's no way to check yet, that would be good to know.

Thanks.

@NeilGirdhar NeilGirdhar added the bug Something isn't working label Nov 25, 2021
@jakevdp
Collaborator

jakevdp commented Nov 25, 2021

I don't believe a JAX DeviceArray object would ever have satisfied isinstance(x, np.ndarray). (There was a recent change that made it so that numpy array objects no longer satisfy isinstance(x, jnp.ndarray), but that seems tangential to your question.)

That said, the most permissive way for libraries to handle this, I think, would be to call np.asarray(x) on inputs, and let numpy decide if it can convert the result.
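As a minimal sketch of that permissive approach: the helper name `as_float_array` below is hypothetical and not part of NumPy, JAX, or pandas. It only assumes that `np.asarray` accepts whatever NumPy knows how to convert and raises `TypeError` or `ValueError` otherwise.

```python
import array

import numpy as np


def as_float_array(x):
    """Coerce x to a float ndarray, letting NumPy decide what is convertible.

    Hypothetical helper for illustration only.
    """
    try:
        # np.asarray passes an existing float64 ndarray through without copying
        # and converts lists, buffers, and objects that define __array__.
        return np.asarray(x, dtype=float)
    except (TypeError, ValueError) as err:
        # Re-raise with a message that names the offending type.
        raise TypeError(
            f"cannot interpret {type(x).__name__} as a numeric array"
        ) from err


from_list = as_float_array([1, 2, 3])
# The stdlib array type is converted through the buffer protocol.
from_buffer = as_float_array(array.array("d", [0.5, 1.5]))
```

The upside of this design is that the library never has to enumerate array types. The cost is that plain lists and scalars are accepted too, which may be more permissive than some call sites want.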

If you want something a bit more restrictive but that would still accept explicitly array-like objects that aren't instances of np.ndarray, perhaps something like this would be sufficient:

def is_arraylike(x):
  return hasattr(x, '__array__') or hasattr(x, '__array_interface__')
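A possible way to use such a check as a guard before conversion is sketched below. `Wrapped` and `to_ndarray` are hypothetical names introduced for illustration; `Wrapped` stands in for any non-NumPy container that opts in through `__array__`.

```python
import numpy as np


def is_arraylike(x):
    return hasattr(x, "__array__") or hasattr(x, "__array_interface__")


class Wrapped:
    """Hypothetical non-NumPy container that advertises array support."""

    def __init__(self, values):
        self._values = list(values)

    def __array__(self, dtype=None, copy=None):
        # Build a fresh ndarray from the stored values on every request.
        return np.array(self._values, dtype=dtype)


def to_ndarray(x):
    """Convert only objects that explicitly declare themselves array-like."""
    if not is_arraylike(x):
        raise TypeError(f"{type(x).__name__} is not array-like")
    return np.asarray(x)


converted = to_ndarray(Wrapped([1, 2, 3]))
```

Note that with this check a plain Python list is rejected, because lists define neither attribute, while `np.ndarray` instances pass.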

@NeilGirdhar
Contributor Author

NeilGirdhar commented Nov 25, 2021

@jakevdp Thank you. I only spent about 15 minutes on this, but when I updated to JAX master, I started getting crashes in pandas/seaborn that were fixed by converting to a numpy array. Is it possible that certain operations that previously forced a conversion to a numpy array are now supported by JAX, so arrays remain JAX arrays longer than they used to?

If I run into this again, I'll submit an issue to them with a link here.

@jakevdp
Collaborator

jakevdp commented Nov 25, 2021

One recent change that may be relevant is that iterating over a JAX array now yields JAX arrays, where it used to yield numpy arrays. So if you had code like this:

import jax.numpy as jnp
import numpy as np
x = list(jnp.ones((2, 3)))
isinstance(x[0], np.ndarray)

it would previously have returned True, and now returns False (see #8043).
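The effect on downstream code that iterates and then checks element types can be sketched without JAX. `FakeDeviceArray` below is a hypothetical stand-in, not a JAX class; it only mimics the relevant behavior, namely that iteration yields its own type rather than `np.ndarray`.

```python
import numpy as np


class FakeDeviceArray:
    """Hypothetical stand-in for an array type that is not an np.ndarray."""

    def __init__(self, values):
        self._values = np.asarray(values)

    def __array__(self, dtype=None, copy=None):
        return np.asarray(self._values, dtype=dtype)

    def __iter__(self):
        # Rows come back wrapped in the same type, not as np.ndarray.
        return (FakeDeviceArray(row) for row in self._values)


rows = list(FakeDeviceArray(np.ones((2, 3))))

# Fragile: an isinstance check on each element now fails for every row.
fragile = [isinstance(row, np.ndarray) for row in rows]

# More robust: convert each element and let NumPy decide.
robust = [np.asarray(row) for row in rows]
```

Under these assumptions, every entry of `fragile` is `False`, while every entry of `robust` is a genuine `np.ndarray` with the row's values.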

@NeilGirdhar
Contributor Author

NeilGirdhar commented Nov 25, 2021

Good thinking. That could be it. When I looked at their code, they were doing things like iterating and checking types. That's probably their fallback when the object is not an ndarray.
