ENH : np.argmax is unusually time consuming for multidimensional array #25846

Open
penzaijun opened this issue Feb 18, 2024 · 3 comments
@penzaijun

Describe the issue:

Both np.amax and np.argmax are expected to have a time complexity of O(n), so they should take a similar amount of time.
However, their performance is only comparable on 1D arrays.
For 2D or higher-dimensional arrays, np.amax consistently outperforms np.argmax by a factor of 8x or more, which is unexpected.

Reproduce the code example:

import timeit

stmt1 = "np.argmax(a, axis=0)"
stmt2 = "np.amax(a, axis=0)"
setup_1d = "import numpy as np; a = np.random.rand(3*768*768)"
setup_2d = "import numpy as np; a = np.random.rand(3,768*768)"

execution_time1 = timeit.timeit(stmt1, setup=setup_2d, number=1000)
print(f"Execution time for np.argmax on 2d array: {execution_time1} seconds")

execution_time2 = timeit.timeit(stmt2, setup=setup_2d, number=1000)
print(f"Execution time for np.amax on 2d array: {execution_time2} seconds")

execution_time1 = timeit.timeit(stmt1, setup=setup_1d, number=1000)
print(f"Execution time for np.argmax on 1d array: {execution_time1} seconds")

execution_time2 = timeit.timeit(stmt2, setup=setup_1d, number=1000)
print(f"Execution time for np.amax on 1d array: {execution_time2} seconds")

Error message:

Execution time for np.argmax on 2d array: 16.13085489999503 seconds
Execution time for np.amax on 2d array: 2.400201399810612 seconds
Execution time for np.argmax on 1d array: 0.6763406000100076 seconds
Execution time for np.amax on 1d array: 0.4886799002997577 seconds

Python and NumPy Versions:

Python: 3.10.13
Numpy: 1.26.4

Runtime Environment:

No response

Context for the issue:

No response

@amentee

amentee commented Feb 18, 2024

@penzaijun - This may be because argmax first flattens the array if it has more than one dimension and then returns the index of the maximum value. The relevant documentation: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html

@penzaijun
Author

@amentee

Thank you for the reply!

According to the documentation, argmax will flatten the array only when axis is None. But for all tests above, I have set axis=0.

Additionally, I timed argmax on the 2D array without setting axis=0. The result is 0.680 seconds, very close to the time for argmax on a 1D array of the same size. It seems that flattening the array does not cost much time.

I should correct my statement: the time consumption becomes abnormally high only when argmax is performed on multidimensional arrays along a specified axis.

Additional Test:
execution_time3 = timeit.timeit("np.argmax(a)", setup=setup_2d, number=1000)
print(f"Execution time for np.argmax on 2d array with axis==None: {execution_time3} seconds")

Test Result:
Execution time for np.argmax on 2d array with axis==None: 0.6801762999966741 seconds
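
For anyone who wants to re-run the three cases side by side, here is a minimal sketch of a timing helper. The function name `time_reductions` and its parameters are hypothetical, not part of NumPy, and it takes the best of several repeats rather than a single total, so the numbers will not match the ones above exactly. It uses np.max, which is the same function as np.amax.

```python
import timeit

import numpy as np


def time_reductions(shape=(3, 64 * 64), number=20, repeat=3):
    """Return the best per-call time in seconds for each reduction.

    Hypothetical helper for this thread; not a NumPy API.
    """
    rng = np.random.default_rng(0)
    a = rng.random(shape)

    # The three cases discussed above, all on the same array.
    cases = {
        "argmax(axis=0)": lambda: np.argmax(a, axis=0),
        "max(axis=0)": lambda: np.max(a, axis=0),
        "argmax(axis=None)": lambda: np.argmax(a),
    }

    results = {}
    for name, fn in cases.items():
        # Take the minimum over repeats to reduce noise from other processes.
        best = min(timeit.repeat(fn, number=number, repeat=repeat))
        results[name] = best / number
    return results


if __name__ == "__main__":
    # Same shape as the 2D setup in the original report.
    for name, seconds in time_reductions(shape=(3, 768 * 768)).items():
        print(f"{name:>18}: {seconds * 1e3:.3f} ms per call")
```

The absolute times depend on the machine and NumPy version, so only the ratio between the cases is meaningful.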

@bencwallace

bencwallace commented Feb 22, 2024

It seems that they have quite different implementations. The list of array method implementations in C can be found in methods.c. From this, you can see that argmax uses a custom implementation (the main part of which can be found in calculation.c), whereas amax (which is the same as max) is defined as the reduce method of the maximum ufunc. This method has a general definition for the ufunc base class in ufunc_object.c.
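
The relationship between max and the maximum ufunc can be checked from Python. This is only a small illustration of the point above; the array shape and variable names are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 1000))

# Three spellings of the same reduction along axis 0.
via_function = np.max(a, axis=0)
via_method = a.max(axis=0)
via_ufunc = np.maximum.reduce(a, axis=0)

assert np.array_equal(via_function, via_ufunc)
assert np.array_equal(via_method, via_ufunc)

# maximum is a ufunc, so it has a .reduce method; argmax is not a ufunc.
print(isinstance(np.maximum, np.ufunc))
print(isinstance(np.argmax, np.ufunc))
```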

In either case you have to loop over all elements along the given axis, so I'd guess the discrepancy is due to the fact that ufuncs like maximum have highly optimized implementations.
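
As an experiment along those lines, argmax along axis 0 can be built from elementwise comparisons when the reduced axis is short, as in the (3, 768*768) case. The sketch below is not how NumPy implements argmax, and `argmax_axis0_running` is a hypothetical name. It assumes the input has at least two dimensions and contains no NaN values. Whether it is actually faster than np.argmax would need to be measured on the target machine and NumPy version.

```python
import numpy as np


def argmax_axis0_running(a):
    """Index of the first maximum along axis 0, via a running maximum.

    Hypothetical helper; assumes a.ndim >= 2 and no NaN values.
    """
    a = np.asarray(a)
    best = a[0].copy()                          # running maximum so far
    idx = np.zeros(a.shape[1:], dtype=np.intp)  # index where it was found

    for i in range(1, a.shape[0]):
        # Strict '>' keeps the first occurrence on ties, like np.argmax.
        mask = a[i] > best
        np.copyto(best, a[i], where=mask)
        idx[mask] = i
    return idx


rng = np.random.default_rng(0)
a = rng.random((3, 1000))
assert np.array_equal(argmax_axis0_running(a), np.argmax(a, axis=0))
```

The Python loop runs once per row, so this only makes sense when axis 0 is much shorter than the remaining axes.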
