Is there a best practice to verify that an object is an array, e.g. during input checks?
Comparing types seems infeasible, since the function will be called with abstract array types during tracing.
Essentially, I want an array check that also accepts all JAX types that might occur during tracing.
If it doesn't already exist, such an is_array function would be nice to have, especially in case the abstract types used during tracing change in the future.
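To illustrate why comparing types breaks under tracing, here is a minimal sketch. The function `f` and the list `seen_types` are hypothetical, for illustration only. Inside a `jax.jit`-compiled function the argument is a tracer, so `type(x)` observed during tracing is not the concrete array type the caller passed in.

```python
import jax
import jax.numpy as jnp

seen_types = []  # hypothetical: records the argument type observed at trace time


@jax.jit
def f(x):
    # This Python body runs during tracing, where x is an abstract tracer.
    seen_types.append(type(x))
    return x * 2


x = jnp.arange(3.0)
f(x)

print(type(x))        # the concrete array type passed by the caller
print(seen_types[0])  # a tracer type, which differs from type(x)
```

The exact class names printed depend on the JAX version, which is part of why hard-coding them in an input check is fragile.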
I think `isinstance(x, jax.numpy.ndarray)` does what you want: it returns `True` for JAX arrays, including abstract ones, as well as for NumPy ndarrays and subtypes of ndarrays.
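A minimal sketch of the requested is_array helper, built on the `isinstance` check from this answer. The names `is_array`, `double`, and `checks` are hypothetical. The explicit `np.ndarray` entry is an assumption made for robustness: it keeps NumPy inputs accepted regardless of whether a given JAX version's `jnp.ndarray` also matches them. The `checks` list only records what the helper returned while the jitted function was being traced.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_array(x) -> bool:
    """Hypothetical helper: True for NumPy arrays and JAX arrays, including tracers."""
    return isinstance(x, (np.ndarray, jnp.ndarray))


checks = []  # records the helper's result at trace time


@jax.jit
def double(x):
    # Input check: runs at trace time, when x is an abstract tracer.
    ok = is_array(x)
    checks.append(ok)
    if not ok:
        raise TypeError(f"expected an array, got {type(x).__name__}")
    return 2 * x


print(is_array(np.zeros(3)))   # NumPy array
print(is_array(jnp.zeros(3)))  # concrete JAX array
print(is_array([1, 2, 3]))     # plain list
print(double(jnp.arange(3.0)))
print(checks)                  # result of the check on the tracer
```

One design consequence to be aware of: under `jax.jit`, Python scalar arguments are also converted to tracers, so a scalar passed to `double` passes the check inside the jitted function even though `is_array(3.0)` is `False` outside it.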