Checking if an object is an array #2115

Closed
Jakob-Unfried opened this issue Jan 29, 2020 · 2 comments
Labels
question Questions for the JAX team

Comments

@Jakob-Unfried
Contributor

Jakob-Unfried commented Jan 29, 2020

Is there a best practice to verify that an object is an array, e.g. during input checks?
Comparing types seems infeasible, since the function will be called with abstract array types during tracing.

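Essentially, I want to do something like:

```python
import jax.numpy as np

def is_array(obj):
    # Exact type comparison: this is the check that breaks for abstract (traced) arrays
    return type(obj) == np.ndarray
```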

The check should also cover all JAX types that might occur during tracing.

If it doesn't already exist, such an `is_array` function would be nice to have, especially if the abstract types used during tracing change in the future.
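To see why an exact type comparison breaks down, the sketch below records what `type(x) == jnp.ndarray` sees for the same input, once called eagerly and once under `jax.jit`. It assumes a recent JAX release; `f` and `type_checks` are illustrative names, not part of any API. On recent versions neither call compares equal: concrete arrays are instances of an internal implementation class, and under `jit` the argument is a tracer object rather than a concrete array.

```python
import jax
import jax.numpy as jnp

# Records (class name, result of the exact type comparison) for each call.
type_checks = []

def f(x):
    # Mirrors the type(obj) == ... comparison from the question.
    type_checks.append((type(x).__name__, type(x) == jnp.ndarray))
    return x + 1

f(jnp.arange(3))           # eager call: x is a concrete JAX array
jax.jit(f)(jnp.arange(3))  # traced call: x is an abstract tracer
print(type_checks)
```

The exact class names printed are JAX internals and may differ between versions, which is precisely why depending on them is fragile.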

@jekbradbury
Contributor

I think `isinstance(x, jax.numpy.ndarray)` does what you want: it returns True for JAX arrays, including abstract ones, as well as for NumPy ndarrays and subclasses of ndarray.
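A minimal sketch of the suggested check, wrapped in a hypothetical `is_array` helper. It assumes a recent JAX release, where `jax.numpy.ndarray` is an alias of `jax.Array`. In those releases `isinstance(x, jax.Array)` still returns True for tracers, but no longer for plain NumPy arrays, so the sketch lists `np.ndarray` explicitly. The `double` function is only an illustrative caller.

```python
import jax
import jax.numpy as jnp
import numpy as np

def is_array(obj):
    """Hypothetical helper: True for JAX arrays (concrete or traced) and NumPy arrays."""
    # jax.Array covers concrete arrays and the tracers seen under jit/grad/vmap.
    # np.ndarray is listed explicitly because recent jax.Array does not include it.
    return isinstance(obj, (jax.Array, np.ndarray))

@jax.jit
def double(x):
    # This check runs once, at trace time, when x is an abstract tracer.
    if not is_array(x):
        raise TypeError(f"expected an array, got {type(x).__name__}")
    return 2 * x

print(is_array(jnp.arange(3)))  # True
print(is_array(np.arange(3)))   # True
print(is_array([1, 2, 3]))      # False
print(double(jnp.arange(3)))
```

Because the check runs during tracing, it adds nothing to the compiled computation. One caveat: under `jit`, Python scalar arguments are traced too, so they pass this check even though they are not arrays outside of `jit`.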

@mattjj
Member

mattjj commented Feb 14, 2020

@jekbradbury's answer is the right one. Though we might revise it (see #2014), right now you should check `isinstance(x, jax.numpy.ndarray)`.

Hope that answers your question!
