feat: add Nx.argsort/2 #367
This isn't ready because EXLA doesn't implement the option yet. I read the XLA docs on sort (https://www.tensorflow.org/xla/operation_semantics#sort), and it seems that if we provide two arguments, they will be sorted simultaneously. However, I had some trouble working out how to build the corresponding index tensor inside the NIF. I'm leaning towards using something akin to the code below to build the second argument:
iex(7)> tensor = Nx.tensor([[[[4,3,2],[1,0,-1]], [[1,2,3],[4,5,2]]]])
#Nx.Tensor<
  s64[1][2][2][3]
  [
    [
      [
        [4, 3, 2],
        [1, 0, -1]
      ],
      [
        [1, 2, 3],
        [4, 5, 2]
      ]
    ]
  ]
>
iex(8)> Nx.iota(tensor, axis: Nx.axes(tensor) |> Enum.at(-1))
#Nx.Tensor<
  s64[1][2][2][3]
  [
    [
      [
        [0, 1, 2],
        [0, 1, 2]
      ],
      [
        [0, 1, 2],
        [0, 1, 2]
      ]
    ]
  ]
>
However, I'm not sure how it will behave, or whether it will actually need an arity-4 function for sorting (as implied by the documentation linked above).
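The idea of sorting the values side by side with an iota can be sketched in plain Elixir, without Nx or EXLA. This is only an illustration under stated assumptions: `ArgsortSketch` and its `argsort/2` are hypothetical names, the input is a flat list standing in for one row along the last axis, and the comparator is an ordinary arity-2 function.

```elixir
defmodule ArgsortSketch do
  # Hypothetical helper, not part of Nx or EXLA.
  # Pairs each value with its position (the "iota" operand), sorts the
  # pairs by value only, and then keeps the positions.
  def argsort(values, comparator \\ &<=/2) do
    values
    |> Enum.with_index()
    |> Enum.sort(fn {a, _}, {b, _} -> comparator.(a, b) end)
    |> Enum.map(fn {_value, index} -> index end)
  end
end

# The four rows of the tensor above, sorted ascending along the last axis.
IO.inspect(ArgsortSketch.argsort([4, 3, 2]))   # [2, 1, 0]
IO.inspect(ArgsortSketch.argsort([1, 0, -1]))  # [2, 1, 0]
IO.inspect(ArgsortSketch.argsort([1, 2, 3]))   # [0, 1, 2]
IO.inspect(ArgsortSketch.argsort([4, 5, 2]))   # [2, 0, 1]
```

The comparator only ever looks at the values, so the indices simply travel along with them, which is the behaviour hoped for from the two-operand sort.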
Awesome job @polvalente! Some notes:
@polvalente I looked into implementing argsort in EXLA a little while ago. What I would do is change the sort NIF to always accept a list of tensors. Then you can implement argsort in EXLA.Lib and sort the input tensor side by side with an iota operation. EDIT: Actually I like the variadic sort idea better, I would go with that :)
exla/lib/exla/defn.ex
Outdated
path = [Access.key!(:data), Access.key!(:args)]
[[arg_1, arg_2], expr, fun] = get_in(comparator, path)

comparator_4 =
  put_in(comparator, path, [
    [arg_1, arg_2, arg_1, arg_2],
    expr,
    fn x, y, _, _ -> fun.(x, y) end
  ])
I'm not sure if there's a better way to achieve this. Basically, the comparison function needed for variadic_sort always has arity 2*num_args, so I needed to extend the arity-2 comparator to an arity-4 one that only uses the first 2 args (the two elements currently being compared from the first of the args passed to sort).
I think you can wrap that function inside the existing function. Something like this:
# Grow the comparator to arity 4 because argsort uses
# variadic_sort underneath
[[arg1, arg2], _expr, fun] = comparator.data.args
comparator_4 = Expr.fun([arg1, arg2, arg1, arg2], fn x, y, _, _ -> fun.(x, y) end)
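The arity widening discussed in this thread can be shown with ordinary anonymous functions. The sketch below is not the EXLA or Nx.Defn code: `VariadicSortSketch`, `sort2/3`, and `widen/1` are hypothetical names, and a pair of flat lists stands in for the two operands. It assumes the argument order implied by the comments above, where the first two comparator arguments come from the first operand and the last two from the second.

```elixir
defmodule VariadicSortSketch do
  # Hypothetical stand-in for a two-operand variadic sort; not the EXLA API.
  # The comparator receives 2 * 2 = 4 scalars:
  # (lhs of operand 0, rhs of operand 0, lhs of operand 1, rhs of operand 1).
  def sort2(values, indices, comparator_4) when is_function(comparator_4, 4) do
    values
    |> Enum.zip(indices)
    |> Enum.sort(fn {v1, i1}, {v2, i2} -> comparator_4.(v1, v2, i1, i2) end)
    |> Enum.unzip()
  end

  # Widens an arity-2 comparator to arity 4 by ignoring the second operand.
  def widen(fun) when is_function(fun, 2) do
    fn x, y, _, _ -> fun.(x, y) end
  end
end

comparator_4 = VariadicSortSketch.widen(&<=/2)
{sorted, indices} = VariadicSortSketch.sort2([4, 5, 2], [0, 1, 2], comparator_4)
IO.inspect(sorted)   # [2, 4, 5]
IO.inspect(indices)  # [2, 0, 1]
```

Because the widened function discards its last two arguments, the ordering is decided by the values alone and the index operand is only carried along.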
%T{shape: shape, names: names} = tensor = to_tensor(tensor)
axis = Nx.Shape.normalize_axis(shape, opts[:axis], names)

impl!(tensor).argsort(%{tensor | type: {:s, 64}}, tensor,
I wasn't able to make argsort work with the {:u, 64} type on EXLA, which would be ideal, so I defaulted to {:s, 64}.
Which error do you get? Looking at the code, it feels like it should work... On the other hand, NumPy and JAX both return int64 too.
Torchx also returns int64, so let's keep this as is.
The error was something along the lines of "The comparator function expected 0-th argument to have shape s64[], got u64[] instead"
nx/lib/nx/binary_backend.ex
Outdated
      fn a, b ->
        a = binary_to_number(a, type)
        b = binary_to_number(b, type)
        a <= b
      end

    :asc ->
      &>/2
      fn a, b ->
        a = binary_to_number(a, type)
        b = binary_to_number(b, type)
        a >= b
      end
These changes are twofold. First, the original comparators didn't work with negative numbers. Second, I've used <= and >= so the sort is stable.
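Both points can be illustrated in plain Elixir. The sketch below is not the Nx.BinaryBackend code: `ComparatorSketch` and its `decode/1` are hypothetical, and it assumes big-endian signed 64-bit encoding so the result does not depend on the machine. It shows one way comparing raw binaries can misorder negative numbers, and that `Enum.sort/2` keeps equal elements in input order when the comparator returns true for them.

```elixir
defmodule ComparatorSketch do
  # Hypothetical decoder for this example only: big-endian signed 64-bit.
  def decode(<<x::signed-integer-size(64)>>), do: x

  # Ascending comparator that decodes before comparing.
  def asc(a, b), do: decode(a) <= decode(b)
end

neg = <<-1::signed-integer-size(64)>>
pos = <<1::signed-integer-size(64)>>

# Raw binaries compare byte by byte, so -1 (all 0xFF bytes) ranks above 1.
IO.inspect(neg > pos)                       # true
# Decoding first restores the numeric order.
IO.inspect(ComparatorSketch.asc(neg, pos))  # true

# With <=, elements that compare equal keep their input order.
pairs = [{1, :a}, {0, :b}, {1, :c}]
IO.inspect(Enum.sort(pairs, fn {a, _}, {b, _} -> a <= b end))
# [{0, :b}, {1, :a}, {1, :c}]
```

Stability matters for argsort in particular, since ties would otherwise return their indices in an unpredictable order.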
@@ -1634,10 +1657,20 @@ defmodule Nx.BinaryBackend do
end
end

IO.iodata_to_binary(Enum.sort(data, comparator))
sorted =
I'm not sure if the match_types call on line 1654 is needed. Perhaps we should remove it or refactor the for-loop somehow to use it correctly.
It isn't, you can remove it.
nx/lib/nx.ex
Outdated
"""
@doc type: :ndim
def argsort(tensor, opts \\ []) do
  opts = keyword!(opts, axis: 0, comparator: :desc)
The default here seems to be :desc. We probably want to change both sort and argsort to be asc.
It seems that the comparator definitions were swapped between desc and asc. My latest commit fixes this.
Co-authored-by: José Valim <jose.valim@dashbit.co>
Two minor comments and ship it! 🚢
This PR aims to add the sort_indices option so we can support "argsort" functionality
closes #342