diff --git a/tests/test_torch.py b/tests/test_torch.py index f661a272..c8619565 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -130,3 +130,13 @@ def test_meshgrid(): assert Y.shape == Y_ij.shape assert xp.all(Y == Y_ij) + +def test_argsort_stable(): + """Verify that argsort defaults to a stable sort.""" + # Bare pytorch defaults to an unstable sort, and the array_api_compat wrapper + # enforces the stable=True default. + # cf https://github.com/data-apis/array-api-compat/pull/356 and + # https://github.com/data-apis/array-api-tests/pull/390#issuecomment-3452868329 + + t = xp.zeros(50) # should be >16 + assert xp.all(xp.argsort(t) == xp.arange(50))