diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 196c2dc13783..514d271f0f6b 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -41,6 +41,12 @@ from datetime import timedelta, datetime +def assert_arg_sorted(arr, arg): + # resulting array should be sorted and arg values should be unique + assert_equal(arr[arg], np.sort(arr)) + assert_equal(np.sort(arg), np.arange(len(arg))) + + def _aligned_zeros(shape, dtype=float, order="C", align=None): """ Allocate a new ndarray with aligned memory. @@ -9989,3 +9995,39 @@ def test_sort_uint(): def test_private_get_ndarray_c_version(): assert isinstance(_get_ndarray_c_version(), int) + + +@pytest.mark.parametrize("N", np.arange(1, 512)) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_argsort_float(N, dtype): + rnd = np.random.RandomState(116112) + # (1) Regular data with a few nan: doesn't use vectorized sort + arr = -0.5 + rnd.random(N).astype(dtype) + arr[rnd.choice(arr.shape[0], 3)] = np.nan + assert_arg_sorted(arr, np.argsort(arr, kind='quick')) + + # (2) Random data with inf at the end of array + # See: https://github.com/intel/x86-simd-sort/pull/39 + arr = -0.5 + rnd.rand(N).astype(dtype) + arr[N-1] = np.inf + assert_arg_sorted(arr, np.argsort(arr, kind='quick')) + + +@pytest.mark.parametrize("N", np.arange(2, 512)) +@pytest.mark.parametrize("dtype", [np.int32, np.uint32, np.int64, np.uint64]) +def test_argsort_int(N, dtype): + rnd = np.random.RandomState(1100710816) + # (1) random data with min and max values + minv = np.iinfo(dtype).min + maxv = np.iinfo(dtype).max + arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype) + i, j = rnd.choice(N, 2, replace=False) + arr[i] = minv + arr[j] = maxv + assert_arg_sorted(arr, np.argsort(arr, kind='quick')) + + # (2) random data with max value at the end of array + # See: https://github.com/intel/x86-simd-sort/pull/39 + arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype) + arr[N-1] = maxv + assert_arg_sorted(arr, np.argsort(arr, kind='quick'))