Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 93 additions & 22 deletions lib/emlx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,32 @@ defmodule EMLX.Backend do

t_mx = from_nx(tensor)

# Check for NaNs in the original tensor (before any reversal)
is_nan_mx = EMLX.is_nan(t_mx)

nan_index_mx =
if axis do
EMLX.argmax(is_nan_mx, axis, keep_axis)
else
EMLX.argmax(is_nan_mx, keep_axis)
end

# Check if any NaN exists along the axis
has_nan_mx =
cond do
axis ->
EMLX.any(is_nan_mx, [axis], keep_axis)

tuple_size(tensor.shape) == 0 ->
# For scalar input, is_nan_mx is already a scalar boolean
is_nan_mx

true ->
# For full reduction over non-scalar tensors
EMLX.any(is_nan_mx, Nx.axes(tensor), keep_axis)
end

# Apply reversal for tie_break after NaN check
t_mx =
if opts[:tie_break] == :high do
reverse_mlx(t_mx, tensor.shape, [axis] || Nx.axes(tensor))
Expand Down Expand Up @@ -623,6 +649,9 @@ defmodule EMLX.Backend do
result
end

# Use NaN index if any NaN exists, otherwise use regular result
result = EMLX.where(has_nan_mx, nan_index_mx, result)

result
|> EMLX.astype(to_mlx_type(out.type))
|> to_nx(out)
Expand Down Expand Up @@ -1196,36 +1225,78 @@ defmodule EMLX.Backend do
axis = opts[:axis]
asc? = opts[:direction] == :asc

t = tensor |> from_nx() |> EMLX.sort(axis)
t_mx = from_nx(tensor)

if asc? do
to_nx(t, out)
else
t
|> to_nx(out)
|> Nx.reverse(axes: [axis])
end
# Get the sorting indices
sort_mx =
if asc? do
EMLX.argsort(t_mx, axis)
else
t_mx
|> EMLX.negate()
|> EMLX.argsort(axis)
end

# Gather values at sorted positions to identify NaNs
sorted_values_mx = EMLX.take_along_axis(t_mx, sort_mx, axis)
is_nan_mx = EMLX.is_nan(sorted_values_mx)

# Partition indices to place NaNs correctly (NaNs are treated as highest):
# - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1)
# - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0)
partition_indices_mx =
if asc? do
EMLX.argsort(is_nan_mx, axis)
else
is_nan_mx
|> EMLX.logical_not()
|> EMLX.argsort(axis)
end

# Reorder the sorted values to move NaNs to the correct position
sorted_values_mx
|> EMLX.take_along_axis(partition_indices_mx, axis)
|> EMLX.astype(to_mlx_type(out.type))
|> to_nx(out)
end

@impl true
def argsort(out, tensor, opts) do
axis = opts[:axis]
asc? = opts[:direction] == :asc

if asc? do
tensor
|> from_nx()
|> EMLX.argsort(axis)
|> EMLX.astype(to_mlx_type(out.type))
|> to_nx(out)
else
tensor
|> from_nx()
|> EMLX.negate()
|> EMLX.argsort(axis)
|> EMLX.astype(to_mlx_type(out.type))
|> to_nx(out)
end
t_mx = from_nx(tensor)
# Get the initial sorting indices
sort_mx =
if asc? do
EMLX.argsort(t_mx, axis)
else
t_mx
|> EMLX.negate()
|> EMLX.argsort(axis)
end

# Gather values at sorted positions to identify NaNs
sorted_values_mx = EMLX.take_along_axis(t_mx, sort_mx, axis)
is_nan_mx = EMLX.is_nan(sorted_values_mx)

# Partition indices to place NaNs correctly (NaNs are treated as highest):
# - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1)
# - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0)
partition_indices_mx =
if asc? do
EMLX.argsort(is_nan_mx, axis)
else
is_nan_mx
|> EMLX.logical_not()
|> EMLX.argsort(axis)
end

# Reorder the sorting indices to move NaN indices to the correct position
sort_mx
|> EMLX.take_along_axis(partition_indices_mx, axis)
|> EMLX.astype(to_mlx_type(out.type))
|> to_nx(out)
end

defp maybe_upcast(%T{type: t} = left, %T{type: t} = right),
Expand Down
10 changes: 2 additions & 8 deletions test/emlx/nx_doctest_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,14 @@ defmodule EMLX.Nx.DoctestTest do
]

@to_be_fixed [
:moduledoc,
# MLX sorts NaNs lowest, Nx sorts them highest
argsort: 2
:moduledoc
]

@not_supported [
reduce: 4,
window_reduce: 5,
population_count: 1,
count_leading_zeros: 1,
sort: 2,
# We do not support the same ordering for NaNs as Nx
argmin: 2,
argmax: 2
count_leading_zeros: 1
]

doctest Nx,
Expand Down