ENH: Vectorize argsort and argselect with AVX2 #25610

Merged
merged 3 commits into from
Jan 24, 2024


r-devulap
Member

@r-devulap r-devulap commented Jan 17, 2024

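Add AVX2 versions of argsort and argselect. Benchmark numbers (Ratio is After/Before, so rows marked `-` with a ratio below 1 are speedups and rows marked `+` are slowdowns):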

| Change   | Before [174ac7bc] <main>   | After [680b6823] <avx2_arg>   |   Ratio | Benchmark (Parameter)                                                                    |
|----------|----------------------------|-------------------------------|---------|------------------------------------------------------------------------------------------|
| +        | 63.8±0.5μs                 | 260±1μs                       |    4.07 | bench_function_base.Sort.time_argsort('quick', 'int64', ('ordered',))                    |
| +        | 69.6±0.07μs                | 210±3μs                       |    3.03 | bench_function_base.Sort.time_argsort('quick', 'int32', ('ordered',))                    |
| +        | 73.7±0.3μs                 | 220±3μs                       |    2.99 | bench_function_base.Sort.time_argsort('quick', 'float64', ('ordered',))                  |
| +        | 72.2±0.03μs                | 214±0.2μs                     |    2.96 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('ordered',))                   |
| +        | 79.1±0.07μs                | 227±0.1μs                     |    2.87 | bench_function_base.Sort.time_argsort('quick', 'float32', ('ordered',))                  |
| +        | 122±0.09μs                 | 250±5μs                       |    2.05 | bench_function_base.Sort.time_argsort('quick', 'int64', ('reversed',))                   |
| +        | 232±0.3μs                  | 468±0.4μs                     |    2.01 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 1000), 10)   |
| +        | 245±1μs                    | 492±2μs                       |    2.01 | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 1000), 100)    |
| +        | 234±0.3μs                  | 467±0.4μs                     |    2    | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 1000), 100)  |
| +        | 247±0.6μs                  | 492±1μs                       |    2    | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 1000), 10)     |
| +        | 105±0.2μs                  | 206±0.3μs                     |    1.95 | bench_function_base.Sort.time_argsort('quick', 'int32', ('reversed',))                   |
| +        | 163±1μs                    | 303±2μs                       |    1.86 | bench_function_base.Partition.time_argpartition('float32', ('ordered',), 100)            |
| +        | 163±1μs                    | 302±2μs                       |    1.85 | bench_function_base.Partition.time_argpartition('float32', ('ordered',), 1000)           |
| +        | 163±1μs                    | 300±3μs                       |    1.84 | bench_function_base.Partition.time_argpartition('float32', ('ordered',), 10)             |
| +        | 255±0.6μs                  | 469±1μs                       |    1.84 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 1000), 10)   |
| +        | 113±0.08μs                 | 206±3μs                       |    1.83 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('reversed',))                  |
| +        | 256±0.9μs                  | 467±1μs                       |    1.82 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 1000), 100)  |
| +        | 172±0.8μs                  | 301±0.2μs                     |    1.75 | bench_function_base.Partition.time_argpartition('int64', ('ordered',), 100)              |
| +        | 172±0.5μs                  | 300±0.4μs                     |    1.75 | bench_function_base.Partition.time_argpartition('int64', ('ordered',), 1000)             |
| +        | 123±0.08μs                 | 214±0.4μs                     |    1.75 | bench_function_base.Sort.time_argsort('quick', 'float64', ('reversed',))                 |
| +        | 261±0.1μs                  | 452±0.3μs                     |    1.73 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 1000), 10)     |
| +        | 261±0.4μs                  | 452±0.5μs                     |    1.73 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 1000), 100)    |
| +        | 172±0.2μs                  | 297±0.6μs                     |    1.73 | bench_function_base.Partition.time_argpartition('int64', ('ordered',), 10)               |
| +        | 127±0.2μs                  | 214±0.1μs                     |    1.69 | bench_function_base.Sort.time_argsort('quick', 'float32', ('reversed',))                 |
| +        | 184±1μs                    | 298±0.2μs                     |    1.63 | bench_function_base.Partition.time_argpartition('float64', ('ordered',), 100)            |
| +        | 183±0.7μs                  | 298±0.5μs                     |    1.63 | bench_function_base.Partition.time_argpartition('float64', ('ordered',), 1000)           |
| +        | 183±0.2μs                  | 296±0.4μs                     |    1.62 | bench_function_base.Partition.time_argpartition('float64', ('ordered',), 10)             |
| +        | 261±2μs                    | 416±6μs                       |    1.6  | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 100), 10)      |
| +        | 259±2μs                    | 415±6μs                       |    1.6  | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 100), 1000)    |
| +        | 262±2μs                    | 416±6μs                       |    1.59 | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 100), 100)     |
| +        | 264±2μs                    | 400±0.7μs                     |    1.51 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 100), 1000)  |
| +        | 271±0.8μs                  | 401±0.5μs                     |    1.48 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 100), 10)    |
| +        | 271±0.8μs                  | 400±0.8μs                     |    1.48 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 100), 100)   |
| +        | 279±3μs                    | 405±3μs                       |    1.45 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 100), 1000)  |
| +        | 281±3μs                    | 404±7μs                       |    1.44 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 100), 10)    |
| +        | 281±2μs                    | 403±4μs                       |    1.44 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 100), 100)   |
| +        | 194±2μs                    | 273±2μs                       |    1.41 | bench_function_base.Partition.time_argpartition('int32', ('ordered',), 100)              |
| +        | 193±1μs                    | 272±2μs                       |    1.41 | bench_function_base.Partition.time_argpartition('int32', ('ordered',), 1000)             |
| +        | 193±1μs                    | 270±2μs                       |    1.4  | bench_function_base.Partition.time_argpartition('int32', ('ordered',), 10)               |
| +        | 280±1μs                    | 383±0.4μs                     |    1.37 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 100), 100)     |
| +        | 279±0.9μs                  | 383±0.2μs                     |    1.37 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 100), 1000)    |
| +        | 280±1μs                    | 383±0.3μs                     |    1.36 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 100), 10)      |
| +        | 199±2μs                    | 212±0.8μs                     |    1.06 | bench_function_base.Sort.time_argsort('merge', 'float32', ('sorted_block', 10))          |
| -        | 496±0.9μs                  | 462±0.5μs                     |    0.93 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 1000), 1000) |
| -        | 170±0.7μs                  | 149±0.1μs                     |    0.88 | bench_function_base.Sort.time_argsort('merge', 'uint32', ('sorted_block', 10))           |
| -        | 533±0.6μs                  | 463±2μs                       |    0.87 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 1000), 1000) |
| -        | 532±0.3μs                  | 445±0.5μs                     |    0.84 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 1000), 1000)   |
| -        | 360±0.3μs                  | 282±0.4μs                     |    0.78 | bench_function_base.Sort.time_argsort('quick', 'int32', ('sorted_block', 1000))          |
| -        | 364±0.6μs                  | 275±0.7μs                     |    0.76 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('sorted_block', 1000))         |
| -        | 389±0.2μs                  | 293±0.6μs                     |    0.75 | bench_function_base.Sort.time_argsort('quick', 'float32', ('sorted_block', 1000))        |
| -        | 476±2μs                    | 341±3μs                       |    0.72 | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 10), 100)      |
| -        | 391±0.4μs                  | 280±0.8μs                     |    0.72 | bench_function_base.Sort.time_argsort('quick', 'float64', ('sorted_block', 1000))        |
| -        | 476±3μs                    | 340±2μs                       |    0.71 | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 10), 10)       |
| -        | 474±3μs                    | 334±3μs                       |    0.7  | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 10), 10)     |
| -        | 477±3μs                    | 336±2μs                       |    0.7  | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 10), 100)    |
| -        | 482±2μs                    | 335±3μs                       |    0.7  | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 10), 1000)     |
| -        | 474±4μs                    | 326±0.3μs                     |    0.69 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 10), 100)    |
| -        | 477±7μs                    | 326±0.2μs                     |    0.68 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 10), 10)     |
| -        | 485±3μs                    | 329±5μs                       |    0.68 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 10), 1000)   |
| -        | 490±2μs                    | 323±0.2μs                     |    0.66 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 10), 1000)   |
| -        | 476±2μs                    | 302±0.2μs                     |    0.64 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 10), 10)       |
| -        | 478±2μs                    | 301±0.4μs                     |    0.63 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 10), 100)      |
| -        | 483±2μs                    | 301±0.4μs                     |    0.62 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 10), 1000)     |
| -        | 461±0.3μs                  | 284±5μs                       |    0.62 | bench_function_base.Sort.time_argsort('quick', 'int64', ('sorted_block', 10))            |
| -        | 445±0.5μs                  | 273±0.1μs                     |    0.61 | bench_function_base.Sort.time_argsort('quick', 'int64', ('sorted_block', 100))           |
| -        | 315±0.2μs                  | 169±1μs                       |    0.54 | bench_function_base.Partition.time_argpartition('float64', ('uniform',), 10)             |
| -        | 315±0.4μs                  | 169±1μs                       |    0.54 | bench_function_base.Partition.time_argpartition('float64', ('uniform',), 100)            |
| -        | 315±0.4μs                  | 169±2μs                       |    0.54 | bench_function_base.Partition.time_argpartition('float64', ('uniform',), 1000)           |
| -        | 284±0.2μs                  | 149±2μs                       |    0.53 | bench_function_base.Partition.time_argpartition('int64', ('uniform',), 1000)             |
| -        | 328±6μs                    | 169±1μs                       |    0.52 | bench_function_base.Partition.time_argpartition('float32', ('uniform',), 10)             |
| -        | 285±0.3μs                  | 149±2μs                       |    0.52 | bench_function_base.Partition.time_argpartition('int64', ('uniform',), 10)               |
| -        | 284±0.2μs                  | 149±0.2μs                     |    0.52 | bench_function_base.Partition.time_argpartition('int64', ('uniform',), 100)              |
| -        | 329±6μs                    | 169±1μs                       |    0.51 | bench_function_base.Partition.time_argpartition('float32', ('uniform',), 100)            |
| -        | 332±4μs                    | 169±1μs                       |    0.51 | bench_function_base.Partition.time_argpartition('float32', ('uniform',), 1000)           |
| -        | 471±0.1μs                  | 242±0.4μs                     |    0.51 | bench_function_base.Sort.time_argsort('quick', 'int32', ('sorted_block', 100))           |
| -        | 547±0.5μs                  | 276±5μs                       |    0.5  | bench_function_base.Sort.time_argsort('quick', 'int64', ('random',))                     |
| -        | 488±0.5μs                  | 238±2μs                       |    0.49 | bench_function_base.Sort.time_argsort('quick', 'int32', ('sorted_block', 10))            |
| -        | 479±0.4μs                  | 234±2μs                       |    0.49 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('sorted_block', 100))          |
| -        | 486±0.4μs                  | 233±3μs                       |    0.48 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('sorted_block', 10))           |
| -        | 283±2μs                    | 134±1μs                       |    0.47 | bench_function_base.Partition.time_argpartition('int32', ('uniform',), 10)               |
| -        | 286±2μs                    | 133±1μs                       |    0.47 | bench_function_base.Partition.time_argpartition('int32', ('uniform',), 100)              |
| -        | 284±2μs                    | 134±1μs                       |    0.47 | bench_function_base.Partition.time_argpartition('int32', ('uniform',), 1000)             |
| -        | 531±0.2μs                  | 252±4μs                       |    0.47 | bench_function_base.Sort.time_argsort('quick', 'float32', ('sorted_block', 10))          |
| -        | 514±0.2μs                  | 240±0.2μs                     |    0.47 | bench_function_base.Sort.time_argsort('quick', 'float32', ('sorted_block', 100))         |
| -        | 524±0.3μs                  | 232±5μs                       |    0.44 | bench_function_base.Sort.time_argsort('quick', 'float64', ('sorted_block', 100))         |
| -        | 542±0.2μs                  | 232±0.7μs                     |    0.43 | bench_function_base.Sort.time_argsort('quick', 'float64', ('sorted_block', 10))          |
| -        | 576±0.3μs                  | 235±0.2μs                     |    0.41 | bench_function_base.Sort.time_argsort('quick', 'int32', ('random',))                     |
| -        | 572±0.4μs                  | 226±2μs                       |    0.4  | bench_function_base.Sort.time_argsort('quick', 'uint32', ('random',))                    |
| -        | 631±0.4μs                  | 240±6μs                       |    0.38 | bench_function_base.Sort.time_argsort('quick', 'float32', ('random',))                   |
| -        | 1.06±0ms                   | 390±2μs                       |    0.37 | bench_function_base.Partition.time_argpartition('int64', ('random',), 1000)              |
| -        | 1.06±0ms                   | 378±1μs                       |    0.36 | bench_function_base.Partition.time_argpartition('int64', ('random',), 10)                |
| -        | 1.06±0ms                   | 378±2μs                       |    0.36 | bench_function_base.Partition.time_argpartition('int64', ('random',), 100)               |
| -        | 655±0.2μs                  | 232±6μs                       |    0.35 | bench_function_base.Sort.time_argsort('quick', 'float64', ('random',))                   |
| -        | 1.12±0.01ms                | 379±3μs                       |    0.34 | bench_function_base.Partition.time_argpartition('float32', ('random',), 1000)            |
| -        | 1.12±0.01ms                | 368±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('float32', ('random',), 10)              |
| -        | 1.12±0.01ms                | 368±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('float32', ('random',), 100)             |
| -        | 1.18±0ms                   | 388±1μs                       |    0.33 | bench_function_base.Partition.time_argpartition('float64', ('random',), 1000)            |
| -        | 1.06±0.01ms                | 347±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('int32', ('random',), 10)                |
| -        | 1.06±0.01ms                | 346±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('int32', ('random',), 100)               |
| -        | 1.07±0.01ms                | 356±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('int32', ('random',), 1000)              |
| -        | 1.17±0ms                   | 377±1μs                       |    0.32 | bench_function_base.Partition.time_argpartition('float64', ('random',), 10)              |
| -        | 1.17±0ms                   | 377±2μs                       |    0.32 | bench_function_base.Partition.time_argpartition('float64', ('random',), 100)             |
| -        | 75.4±0.05μs                | 15.7±0.4μs                    |    0.21 | bench_function_base.Sort.time_argsort('quick', 'int64', ('uniform',))                    |
| -        | 1.37±0ms                   | 258±0.9μs                     |    0.19 | bench_function_base.Partition.time_argpartition('float64', ('reversed',), 10)            |
| -        | 1.38±0.01ms                | 258±0.8μs                     |    0.19 | bench_function_base.Partition.time_argpartition('float64', ('reversed',), 100)           |
| -        | 1.37±0.01ms                | 258±3μs                       |    0.19 | bench_function_base.Partition.time_argpartition('float64', ('reversed',), 1000)          |
| -        | 1.56±0.01ms                | 261±2μs                       |    0.17 | bench_function_base.Partition.time_argpartition('float32', ('reversed',), 10)            |
| -        | 1.57±0.01ms                | 260±2μs                       |    0.17 | bench_function_base.Partition.time_argpartition('float32', ('reversed',), 100)           |
| -        | 1.57±0.01ms                | 260±2μs                       |    0.17 | bench_function_base.Partition.time_argpartition('float32', ('reversed',), 1000)          |
| -        | 1.48±0ms                   | 253±0.5μs                     |    0.17 | bench_function_base.Partition.time_argpartition('int64', ('reversed',), 10)              |
| -        | 1.48±0.01ms                | 253±0.6μs                     |    0.17 | bench_function_base.Partition.time_argpartition('int64', ('reversed',), 100)             |
| -        | 1.48±0.01ms                | 253±1μs                       |    0.17 | bench_function_base.Partition.time_argpartition('int64', ('reversed',), 1000)            |
| -        | 1.49±0.01ms                | 228±2μs                       |    0.15 | bench_function_base.Partition.time_argpartition('int32', ('reversed',), 10)              |
| -        | 1.49±0.01ms                | 228±2μs                       |    0.15 | bench_function_base.Partition.time_argpartition('int32', ('reversed',), 100)             |
| -        | 1.49±0.01ms                | 228±2μs                       |    0.15 | bench_function_base.Partition.time_argpartition('int32', ('reversed',), 1000)            |
| -        | 94.4±0.6μs                 | 14.3±0.08μs                   |    0.15 | bench_function_base.Sort.time_argsort('quick', 'int32', ('uniform',))                    |
| -        | 94.5±0.3μs                 | 14.1±0.1μs                    |    0.15 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('uniform',))                   |
| -        | 131±0.1μs                  | 17.8±0.09μs                   |    0.14 | bench_function_base.Sort.time_argsort('quick', 'float32', ('uniform',))                  |
| -        | 143±0.08μs                 | 18.0±0.4μs                    |    0.13 | bench_function_base.Sort.time_argsort('quick', 'float64', ('uniform',))                  |
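For readers who want a rough local feel for the input patterns in the table, here is a minimal timing sketch. It is not the asv benchmark suite that produced the numbers above: the helper names `make_inputs` and `time_argsort`, the array size, and the exact construction of each pattern (in particular, treating "uniform" as an all-equal array) are assumptions made for illustration.

```python
import timeit

import numpy as np


def make_inputs(n=10_000, seed=0):
    """Build input patterns loosely modeled on the benchmark parameter names."""
    rng = np.random.RandomState(seed)
    ordered = np.arange(n, dtype=np.int64)
    return {
        "ordered": ordered,
        "reversed": ordered[::-1].copy(),
        "uniform": np.ones(n, dtype=np.int64),  # assumption: all elements equal
        "random": rng.randint(n, size=n).astype(np.int64),
    }


def time_argsort(inputs, repeat=5, number=20):
    """Return the best per-call time in microseconds for each pattern."""
    results = {}
    for name, arr in inputs.items():
        timings = timeit.repeat(
            lambda: np.argsort(arr, kind="quicksort"),
            repeat=repeat,
            number=number,
        )
        results[name] = min(timings) / number * 1e6
    return results


if __name__ == "__main__":
    for name, usec in time_argsort(make_inputs()).items():
        print(f"{name:>8}: {usec:8.1f} us")
```

Absolute timings depend on the CPU, the NumPy build, and which SIMD paths are dispatched at runtime, so only the relative pattern between inputs is worth comparing.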

@seiko2plus seiko2plus added the component: SIMD Issues in SIMD (fast instruction sets) code or machinery label Jan 18, 2024
@seiko2plus seiko2plus merged commit 221427b into numpy:main Jan 24, 2024
63 checks passed
@seiko2plus
Member

Thank you @r-devulap!

@lesteve
Contributor

lesteve commented Feb 1, 2024

There was a new failure in scikit-learn when testing against numpy dev because of this change, see scikit-learn/scikit-learn#28326 for more details. The output of np.argsort changes enough that HDBSCAN finds 3 clusters with numpy dev and 2 with numpy 1.26.3.

The scikit-learn tests pass in 0a4b2b8 (previous merge commit in main) and fail in 221427b (the merge commit for this PR).

I could not find any mention of this change in the changelog, but maybe I missed it?

I think it would be worth adding a changelog entry about this change to indicate that np.argsort and np.argselect results may change in numpy 2.0.
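A small illustration of why the index output can legitimately change: when the input contains ties, more than one index array sorts it, and a non-stable algorithm may return any of them. The sketch below uses a hand-written array and two hand-written index arrays; the helper `sorts_array` is hypothetical and not part of NumPy.

```python
import numpy as np


def sorts_array(x, idx):
    """True if idx is a permutation of range(len(x)) that puts x in ascending order."""
    is_permutation = np.array_equal(np.sort(idx), np.arange(len(x)))
    xs = x[idx]
    return bool(is_permutation and np.all(xs[:-1] <= xs[1:]))


x = np.array([2, 1, 2, 1, 0])

# Two different index arrays; they differ only in how the tied values are ordered.
idx_a = np.array([4, 1, 3, 0, 2])
idx_b = np.array([4, 3, 1, 2, 0])

print(sorts_array(x, idx_a), sorts_array(x, idx_b))  # True True
print(np.array_equal(idx_a, idx_b))                  # False
```

Downstream tests that compare `x[np.argsort(x)]` rather than the raw indices are not affected by which of the valid permutations is returned.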

@rgommers
Member

rgommers commented Mar 5, 2024

> I think it would be worth adding a changelog entry about this change to indicate that np.argsort and np.argselect results may change in numpy 2.0.

I verified that this is indeed common when sorting integers.

>>> import hashlib
>>> import numpy as np

>>> rng = np.random.RandomState(seed=123098)  # minor: reproducers across versions should use RandomState
>>> x = rng.randint(100, size=10_000)
>>> hashlib.sha256(np.argsort(x).tobytes()).hexdigest()

I'll include this note in my next release notes PR:

Minor changes in behavior of sorting functions
----------------------------------------------

Due to algorithmic changes and use of SIMD code, sorting functions with methods
that aren't stable may return slightly different results in 2.0.0 compared to
1.26.x. This includes the default method of `~numpy.sort` and `~numpy.argsort`.
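For code that needs the same index order regardless of NumPy version, one option is to request a stable sort explicitly. The sketch below reuses the reproducer's seed and hashing idea; the helper `argsort_fingerprint` is hypothetical, and the cross-check against `np.lexsort` is an illustration added here, not something from the thread. A stable argsort has exactly one correct answer (ties keep their original order), so its fingerprint should not depend on the sorting implementation, while the default fingerprint may.

```python
import hashlib

import numpy as np


def argsort_fingerprint(x, kind=None):
    """SHA-256 of the argsort indices, cast to int64 for a platform-independent byte size."""
    idx = np.argsort(x, kind=kind).astype(np.int64)
    return hashlib.sha256(idx.tobytes()).hexdigest()


rng = np.random.RandomState(seed=123098)
x = rng.randint(100, size=10_000)  # only 100 distinct values, so ties are everywhere

# Stable order is fully determined: sort by value, break ties by original index.
stable_idx = np.argsort(x, kind="stable")
reference = np.lexsort((np.arange(x.size), x))
assert np.array_equal(stable_idx, reference)

print("default:", argsort_fingerprint(x))
print("stable: ", argsort_fingerprint(x, kind="stable"))
```

The trade-off is speed: the stable path does not necessarily benefit from the same SIMD kernels as the default method, so it is worth measuring before switching in hot code.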

rgommers added a commit to rgommers/numpy that referenced this pull request Mar 5, 2024
As asked for in gh-25610

[skip actions] [skip azp] [skip cirrus]
rgommers added a commit to rgommers/numpy that referenced this pull request Mar 6, 2024
As asked for in gh-25610

[skip actions] [skip azp] [skip cirrus]
Labels
01 - Enhancement component: SIMD Issues in SIMD (fast instruction sets) code or machinery