Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Should JAX deprecate indexing with lists? #4564

Closed
jakevdp opened this issue Oct 13, 2020 · 5 comments · Fixed by #4641
Closed

Should JAX deprecate indexing with lists? #4564

jakevdp opened this issue Oct 13, 2020 · 5 comments · Fixed by #4641
Assignees
Labels
enhancement New feature or request

Comments

@jakevdp
Copy link
Collaborator

jakevdp commented Oct 13, 2020

Since numpy 1.16, indexing with a list in place of a tuple has led to a FutureWarning (See numpy/numpy#9686 for a discussion of the rationale for this):

>>> import numpy as np
>>> x = np.arange(6).reshape(2, 3)
>>> idx = [[0], [1]]
>>> x[idx]
FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
array([1])

As mentioned in the warning, the current behavior treats the indices as identical to a tuple:

>>> x[tuple(idx)]
array([1])

while in the future, the indices will be treated as an array:

>>> x[np.array(idx)]
array([[[0, 1, 2]],
       [[3, 4, 5]]])

JAX currently implements the old, deprecated behavior, without any warning:

>>> import jax.numpy as jnp
>>> jnp.array(x)[idx]
DeviceArray([1], dtype=int32)

This is setting us up for a future where numpy and JAX have different indexing semantics for lists of indices. I would propose that we follow numpy and start warning about this behavior now, so that when a numpy release finally does deprecate this indexing behavior, jax will be ready to immediately follow suit.

Thoughts?

@shoyer
Copy link
Member

shoyer commented Oct 14, 2020

👍 we should definitely deprecate this. We can probably be even more aggressive about removing support for this given the experimental nature of JAX.

In the long term I would be inclined to make all indexing with lists in JAX an error. Even the case where lists are treated like an array is at odds with how JAX disallows lists as arguments to functions like jnp.sum(). It's not much more painful to require inserting array() in expressions like x[:, jnp.array([1, 2])] and has the advantage of much more explicit conversion.

@jakevdp jakevdp self-assigned this Oct 14, 2020
@hawkinsp
Copy link
Member

+1. Let's just disallow this now.

@jakevdp
Copy link
Collaborator Author

jakevdp commented Nov 19, 2020

Update: #4641 added a warning for this behavior in JAX, mirroring the warning in numpy (part of jax v0.2.4)

#4867 turns this warning into a TypeError. This will most likely be part of jax v0.2.7.

@donno2048
Copy link

That messed up some things in TensorFlow...

@donno2048
Copy link

Oh, gcucurull/jax-gcn#1 is about that kinda

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging a pull request may close this issue.

5 participants