Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Nx.argsort/2 #367

Merged
merged 19 commits into from
Apr 6, 2021
Merged

feat: add Nx.argsort/2 #367

merged 19 commits into from
Apr 6, 2021

Conversation

polvalente
Copy link
Contributor

This PR aims to add the sort_indices option so we can support "argsort" functionality

closes #342

@polvalente
Copy link
Contributor Author

This isn't ready because EXLA doesn't implement the option yet.
@seanmor5 any pointers on how we can achieve this?

I read the XLA docs on sort (https://www.tensorflow.org/xla/operation_semantics#sort) and it seems that if we provide 2 arguments, they will be sorted simultaneously. However, I had some trouble on how to build the corresponding index tensor inside the NIF.

I'm leaning towards using something akin to the code below for build the second argument:

iex(7)> tensor = Nx.tensor([[[[4,3,2],[1,0,-1]], [[1,2,3],[4,5,2]]]])
#Nx.Tensor<
  s64[1][2][2][3]
  [
    [
      [
        [4, 3, 2],
        [1, 0, -1]
      ],
      [
        [1, 2, 3],
        [4, 5, 2]
      ]
    ]
  ]
>
iex(8)> Nx.iota(tensor, axis: Nx.axes(tensor) |> Enum.at(-1))        
#Nx.Tensor<
  s64[1][2][2][3]
  [
    [
      [
        [0, 1, 2],
        [0, 1, 2]
      ],
      [
        [0, 1, 2],
        [0, 1, 2]
      ]
    ]
  ]
>

However, I'm not sure how it will behave, if it'll actually need a arity-4 function for sorting (as implied by the documentation linked above).

@polvalente polvalente requested a review from seanmor5 April 5, 2021 04:00
@josevalim
Copy link
Collaborator

Awesome job @polvalente!

Some notes:

  1. Generally speaking, we prefer to expose the EXLA operations as close to "native" as possible. I.e. EXLA.Op should expose the C API instead of providing conveniences on top of it. For this purpose, it is probably better to implement variadic_sort, pretty much like we implemented variadic_reduce.

  2. At the NX level, I would rather introduce an argsort function instead of having options inside sort. :)

@seanmor5
Copy link
Collaborator

seanmor5 commented Apr 5, 2021

@polvalente I looked into implementing argsort in EXLA a little bit ago. What I would do is change the sort NIF to always accept a list of Tensors. Then you can implement argsort in EXLA.Lib and sort the input tensor side by side with an iota operation.

EDIT: Actually I like the variadic sort idea better, I would go with that :)

@polvalente polvalente marked this pull request as ready for review April 6, 2021 07:36
Comment on lines 710 to 718
path = [Access.key!(:data), Access.key!(:args)]
[[arg_1, arg_2], expr, fun] = get_in(comparator, path)

comparator_4 =
put_in(comparator, path, [
[arg_1, arg_2, arg_1, arg_2],
expr,
fn x, y, _, _ -> fun.(x, y) end
])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if there's a better way to achieve this. Basically, the comparison function needed for variadic_sort always has 2*num_args arity, so I needed to extend the arity-2 comparator to an arity-4 one which only uses the first 2 args (which are the current comparison elements for the first of the args passed to sort)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can wrap that function inside the existing function. Something like this:

    # Grow the comparator to arity 4 because argsort uses
    # variadic_sort underneath
    [[arg1, arg2], _expr, fun] = comparator.data.args

    comparator_4 = Expr.fun([arg1, arg2, arg1, arg2], fn x, y, _, _ -> fun.(x, y) end)

nx/lib/nx.ex Outdated Show resolved Hide resolved
%T{shape: shape, names: names} = tensor = to_tensor(tensor)
axis = Nx.Shape.normalize_axis(shape, opts[:axis], names)

impl!(tensor).argsort(%{tensor | type: {:s, 64}}, tensor,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't able to make argsort work with {:u, 64} type on EXLA, which would be ideal, so I defaulted to {:s, 64}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which error do you get? Looking at the code it feels like it should work... on the other, numpy and jax both return int64 too.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Torchx also return int64, so let's keep this as is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error was something along the lines of "The comparator function expected 0-th argument to have shape s64[], got u64[] instead"

Comment on lines 1590 to 1601
fn a, b ->
a = binary_to_number(a, type)
b = binary_to_number(b, type)
a <= b
end

:asc ->
&>/2
fn a, b ->
a = binary_to_number(a, type)
b = binary_to_number(b, type)
a >= b
end
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes are two-fold. First because the original comparators didn't work with negative numbers, and secondly I've used<= and >= so the sort is stable

@@ -1634,10 +1657,20 @@ defmodule Nx.BinaryBackend do
end
end

IO.iodata_to_binary(Enum.sort(data, comparator))
sorted =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the match_types call on line 1654 is needed. Perhaps we should remove it or refactor the for-loop somehow to use it correctly

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't, you can remove it.

nx/lib/nx.ex Outdated Show resolved Hide resolved
nx/lib/nx.ex Outdated
"""
@doc type: :ndim
def argsort(tensor, opts \\ []) do
opts = keyword!(opts, axis: 0, comparator: :desc)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default here seems to be :desc. We probably want to change both sort and argsort to be asc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the comparator definitions were swapped between desc and asc. My latest commit changed this

nx/lib/nx.ex Outdated Show resolved Hide resolved
@polvalente polvalente changed the title feat: add sort_indices option to Nx.sort/2 feat: add Nx.argsort/2 Apr 6, 2021
nx/lib/nx.ex Outdated Show resolved Hide resolved
Copy link
Collaborator

@josevalim josevalim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two minor comments and ship it! 🚢

@polvalente polvalente merged commit b350578 into main Apr 6, 2021
@polvalente polvalente deleted the feat/add-argsort branch April 6, 2021 16:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Returns the indices after sorting a tensor
3 participants