Skip to content

Commit

Permalink
TST: Add tests for np.argsort (#23846)
Browse files Browse the repository at this point in the history
Contributing a few tests I had used when developing AVX-512 based argsort.

Co-authored-by: Robert Kern <robert.kern@gmail.com>
  • Loading branch information
2 people authored and charris committed Jun 5, 2023
1 parent eaffdf3 commit 35d23ba
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions numpy/core/tests/test_multiarray.py
Expand Up @@ -41,6 +41,12 @@
from datetime import timedelta, datetime


def assert_arg_sorted(arr, arg):
# resulting array should be sorted and arg values should be unique
assert_equal(arr[arg], np.sort(arr))
assert_equal(np.sort(arg), np.arange(len(arg)))


def _aligned_zeros(shape, dtype=float, order="C", align=None):
"""
Allocate a new ndarray with aligned memory.
Expand Down Expand Up @@ -9989,3 +9995,39 @@ def test_sort_uint():

def test_private_get_ndarray_c_version():
assert isinstance(_get_ndarray_c_version(), int)


@pytest.mark.parametrize("N", np.arange(1, 512))
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_argsort_float(N, dtype):
rnd = np.random.RandomState(116112)
# (1) Regular data with a few nan: doesn't use vectorized sort
arr = -0.5 + rnd.random(N).astype(dtype)
arr[rnd.choice(arr.shape[0], 3)] = np.nan
assert_arg_sorted(arr, np.argsort(arr, kind='quick'))

# (2) Random data with inf at the end of array
# See: https://github.com/intel/x86-simd-sort/pull/39
arr = -0.5 + rnd.rand(N).astype(dtype)
arr[N-1] = np.inf
assert_arg_sorted(arr, np.argsort(arr, kind='quick'))


@pytest.mark.parametrize("N", np.arange(2, 512))
@pytest.mark.parametrize("dtype", [np.int32, np.uint32, np.int64, np.uint64])
def test_argsort_int(N, dtype):
rnd = np.random.RandomState(1100710816)
# (1) random data with min and max values
minv = np.iinfo(dtype).min
maxv = np.iinfo(dtype).max
arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype)
i, j = rnd.choice(N, 2, replace=False)
arr[i] = minv
arr[j] = maxv
assert_arg_sorted(arr, np.argsort(arr, kind='quick'))

# (2) random data with max value at the end of array
# See: https://github.com/intel/x86-simd-sort/pull/39
arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype)
arr[N-1] = maxv
assert_arg_sorted(arr, np.argsort(arr, kind='quick'))

0 comments on commit 35d23ba

Please sign in to comment.