Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jnp.moveaxis: fix bug when axes are integer dtype #4371

Merged
merged 1 commit into from Sep 22, 2020

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Sep 21, 2020

Before:

In [1]: import jax.numpy as jnp                                                                                                          

In [2]: x = jnp.arange(6).reshape(2, 3) 

In [3]: jnp.moveaxis(x, jnp.int32(0), jnp.int32(1))                                                                                      
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-53fba72727a1> in <module>
----> 1 jnp.moveaxis(x, jnp.int32(0), jnp.int32(1))

~/github/google/jax/jax/numpy/lax_numpy.py in moveaxis(a, source, destination)
   1261   if isinstance(destination, int):
   1262     destination = (destination,)
-> 1263   source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
   1264   destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
   1265   if len(source) != len(destination):

~/github/google/jax/jax/interpreters/xla.py in __iter__(self)
   1060   def __iter__(self):
   1061     if self.ndim == 0:
-> 1062       raise TypeError("iteration over a 0-d array")  # same as numpy error
   1063     else:
   1064       return self._value.__iter__()

TypeError: iteration over a 0-d array

After:

In [1]: import jax.numpy as jnp                                                                                                          

In [2]: x = jnp.arange(6).reshape(2, 3)

In [3]: jnp.moveaxis(x, jnp.int32(0), jnp.int32(1))                                                                                      
Out[3]: 
DeviceArray([[0, 3],
             [1, 4],
             [2, 5]], dtype=int32)

@google-cla google-cla bot added the cla: yes label Sep 21, 2020
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Sep 22, 2020
@copybara-service copybara-service bot merged commit d43d5d9 into google:master Sep 22, 2020
@jakevdp jakevdp deleted the moveaxis-fix branch September 22, 2020 03:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants