-
Couldn't load subscription status.
- Fork 38
FIX: Wrap torch.argsort to set stable=True by default #356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR wraps torch.argsort to set stable=True by default, aligning it with the array API specification and matching the behavior of the existing sort wrapper.
- Adds a new
argsortfunction wrapper that defaultsstableparameter toTrue
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
Remove the empty line with trailing whitespace inside the function body. This line serves no purpose and should be deleted. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
|
In pytorch, both Which looks correct and wanted indeed. |
|
Need to also add $ git diff
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 715182a..fc1688a 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -851,7 +851,8 @@ __all__ = ['asarray', 'result_type', 'can_cast',
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
- 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
+ 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod',
+ 'argsort', 'sort', 'prod', 'sum',
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',To verify: run data-apis/array-api-tests#390 with |
|
Also cross-ref data-apis/array-api-tests#390 (comment). |
Thanks for the tip 👍 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's merge this.
Not sure if we need to wait for the -tests PR to be merged first.
|
Okay, I'll send a quick follow-up PR with $ git diff
diff --git a/tests/test_torch.py b/tests/test_torch.py
index 7adb4ab..a367c7b 100644
--- a/tests/test_torch.py
+++ b/tests/test_torch.py
@@ -117,3 +117,14 @@ def test_meshgrid():
assert Y.shape == Y_xy.shape
assert xp.all(Y == Y_xy)
+
+
+def test_argsort_stable():
+ """Verify that argsort defaults to a stable sort."""
+ # Bare pytorch defaults to an unstable sort, and the array_api_compat wrapper
+ # enforces the stable=True default.
+ # cf https://github.com/data-apis/array-api-compat/pull/356 and
+ # https://github.com/data-apis/array-api-tests/pull/390#issuecomment-3452868329
+
+ t = xp.zeros(50) # should be >16
+ assert xp.all(xp.argsort(t) == xp.arange(50)) |
cross-ref data-apis#356 which wrapped torch.argsort to fix the default, and data-apis/array-api-tests#390 which made a matching change in the array-api-test suite.
cross-ref data-apis#356 which wrapped torch.argsort to fix the default, and data-apis/array-api-tests#390 which made a matching change in the array-api-test suite.
|
A follow-up in #358 |
Everything is in the title.
Fixes #354