New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Should JAX deprecate indexing with lists? #4564
Comments
👍 we should definitely deprecate this. We can probably be even more aggressive about removing support for this given the experimental nature of JAX. In the long term I would be inclined to make all indexing with lists in JAX an error. Even the case where lists are treated like an array is at odds with how JAX disallows lists as arguments to functions like |
+1. Let's just disallow this now. |
That messed up some things in TensorFlow... |
Oh, gcucurull/jax-gcn#1 is about that kinda |
Since numpy 1.16, indexing with a list in place of a tuple has led to a
FutureWarning
(See numpy/numpy#9686 for a discussion of the rationale for this):As mentioned in the warning, the current behavior treats the indices as identical to a tuple:
while in the future, the indices will be treated as an array:
JAX currently implements the old, deprecated behavior, without any warning:
This is setting us up for a future where numpy and JAX have different indexing semantics for lists of indices. I would propose that we follow numpy and start warning about this behavior now, so that when a numpy release finally does deprecate this indexing behavior, jax will be ready to immediately follow suit.
Thoughts?
The text was updated successfully, but these errors were encountered: