unravel_index fails to restore dimensions of indices arguments #28439
Closed
Description
Describe the issue:
I have an (N, M, M) array on which I would like to call argmax(array, axis=(1, 2)). Since np.argmax accepts only a single axis, I work around this by flattening the last two dimensions, calling argmax on the resulting (N, M*M) array, and unravelling the result to recover indices into the original dimensions.
The resulting unravelled indices do not match what I expect (see the code below).
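One possible reading of the mismatch, offered as a hedged explanation rather than a confirmed diagnosis: np.unravel_index decodes a flat index against whatever shape it is given. The values returned by argmax over each flattened (M*M) row are offsets within a single (M, M) slice, so decoding them against the full (N, M, M) shape yields a leading index of 0 whenever the offset is below M*M. The sketch below, with N = 2 and M = 3 chosen for illustration, decodes the same offset both ways.

```python
import numpy as np

# An offset within one flattened (3, 3) slice, i.e. the kind of value
# argmax(..., axis=-1) returns for a 9-element row.
offset = 7

# Decoding against the full (2, 3, 3) shape treats 7 as an offset into the
# whole array, so the leading (slice) index is 0 for any offset below 9.
full = np.unravel_index(offset, (2, 3, 3))

# Decoding against the trailing (3, 3) shape gives the (row, column)
# position inside the slice, which is what argmax actually located.
trailing = np.unravel_index(offset, (3, 3))

print([int(i) for i in full])      # [0, 2, 1]
print([int(i) for i in trailing])  # [2, 1]
```

Under this reading, decoding against the trailing shape gives only the in-slice position; the slice index itself is the row of the flattened array that argmax was applied to.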
Reproduce the code example:
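import numpy as np
a = np.array([[[1, 1, 1], [1, 1, 1], [1, 4, 1]], [[5, 1, 1], [1, 1, 1], [1, 1, 1]]])
# argmax over each flattened (3, 3) slice of the (2, 3, 3) array
flat_indices = np.argmax(a.reshape(2, -1), axis=-1)
# unravel the flat indices against the full shape to recover (i, j, k)
indices = np.array(np.unravel_index(flat_indices, a.shape))
print(indices)
# expected: maximum 4 at a[0, 2, 1] and maximum 5 at a[1, 0, 0]
assert not np.any(indices - np.array([[0, 1], [2, 0], [1, 0]])), "Wrong unravelled indices"
Error message: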
Python and NumPy Versions:
NumPy 2.0.2
Python 3.9.21 (main, Dec 11 2024, 16:24:11) [GCC 11.2.0]
Runtime Environment:
No response
Context for the issue:
No response
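A minimal sketch of a workaround, assuming the goal is the position of each slice's maximum in the original array: unravel the flat indices against a.shape[1:] instead of a.shape, and take the leading index from np.arange(N). The function name argmax_trailing_axes is hypothetical and not part of NumPy.

```python
import numpy as np


def argmax_trailing_axes(a: np.ndarray) -> np.ndarray:
    """Hypothetical helper: index of the maximum of each a[i].

    Emulates argmax over all axes except the first and returns an array of
    shape (a.ndim, N) whose columns are full indices into `a`.
    """
    n = a.shape[0]
    # Flatten each slice into one row; argmax gives offsets within a slice.
    flat = np.argmax(a.reshape(n, -1), axis=-1)
    # Decode against the trailing shape only, then prepend the slice index.
    return np.stack([np.arange(n), *np.unravel_index(flat, a.shape[1:])])


a = np.array([[[1, 1, 1], [1, 1, 1], [1, 4, 1]],
              [[5, 1, 1], [1, 1, 1], [1, 1, 1]]])
indices = argmax_trailing_axes(a)
print(indices)
# [[0 1]
#  [2 0]
#  [1 0]]
print(a[tuple(indices)])  # [4 5]
```

Because the trailing shape is passed as a.shape[1:], the same sketch also applies to slices with more than two dimensions; ties resolve to the first occurrence in row-major order, as np.argmax does.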