diff --git a/lib/emlx/backend.ex b/lib/emlx/backend.ex index 5566ed2..3d7f39b 100644 --- a/lib/emlx/backend.ex +++ b/lib/emlx/backend.ex @@ -590,6 +590,32 @@ defmodule EMLX.Backend do t_mx = from_nx(tensor) + # Check for NaNs in the original tensor (before any reversal) + is_nan_mx = EMLX.is_nan(t_mx) + + nan_index_mx = + if axis do + EMLX.argmax(is_nan_mx, axis, keep_axis) + else + EMLX.argmax(is_nan_mx, keep_axis) + end + + # Check if any NaN exists along the axis + has_nan_mx = + cond do + axis -> + EMLX.any(is_nan_mx, [axis], keep_axis) + + tuple_size(tensor.shape) == 0 -> + # For scalar input, is_nan_mx is already a scalar boolean + is_nan_mx + + true -> + # For full reduction over non-scalar tensors + EMLX.any(is_nan_mx, Nx.axes(tensor), keep_axis) + end + + # Apply reversal for tie_break after NaN check t_mx = if opts[:tie_break] == :high do reverse_mlx(t_mx, tensor.shape, [axis] || Nx.axes(tensor)) @@ -623,6 +649,9 @@ defmodule EMLX.Backend do result end + # Use NaN index if any NaN exists, otherwise use regular result + result = EMLX.where(has_nan_mx, nan_index_mx, result) + result |> EMLX.astype(to_mlx_type(out.type)) |> to_nx(out) @@ -1196,15 +1225,39 @@ defmodule EMLX.Backend do axis = opts[:axis] asc? = opts[:direction] == :asc - t = tensor |> from_nx() |> EMLX.sort(axis) + t_mx = from_nx(tensor) - if asc? do - to_nx(t, out) - else - t - |> to_nx(out) - |> Nx.reverse(axes: [axis]) - end + # Get the sorting indices + sort_mx = + if asc? do + EMLX.argsort(t_mx, axis) + else + t_mx + |> EMLX.negate() + |> EMLX.argsort(axis) + end + + # Gather values at sorted positions to identify NaNs + sorted_values_mx = EMLX.take_along_axis(t_mx, sort_mx, axis) + is_nan_mx = EMLX.is_nan(sorted_values_mx) + + # Partition indices to place NaNs correctly (NaNs are treated as highest): + # - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1) + # - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0) + partition_indices_mx = + if asc? do + EMLX.argsort(is_nan_mx, axis) + else + is_nan_mx + |> EMLX.logical_not() + |> EMLX.argsort(axis) + end + + # Reorder the sorted values to move NaNs to the correct position + sorted_values_mx + |> EMLX.take_along_axis(partition_indices_mx, axis) + |> EMLX.astype(to_mlx_type(out.type)) + |> to_nx(out) end @impl true @@ -1212,20 +1265,38 @@ defmodule EMLX.Backend do axis = opts[:axis] asc? = opts[:direction] == :asc - if asc? do - tensor - |> from_nx() - |> EMLX.argsort(axis) - |> EMLX.astype(to_mlx_type(out.type)) - |> to_nx(out) - else - tensor - |> from_nx() - |> EMLX.negate() - |> EMLX.argsort(axis) - |> EMLX.astype(to_mlx_type(out.type)) - |> to_nx(out) - end + t_mx = from_nx(tensor) + # Get the initial sorting indices + sort_mx = + if asc? do + EMLX.argsort(t_mx, axis) + else + t_mx + |> EMLX.negate() + |> EMLX.argsort(axis) + end + + # Gather values at sorted positions to identify NaNs + sorted_values_mx = EMLX.take_along_axis(t_mx, sort_mx, axis) + is_nan_mx = EMLX.is_nan(sorted_values_mx) + + # Partition indices to place NaNs correctly (NaNs are treated as highest): + # - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1) + # - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0) + partition_indices_mx = + if asc? do + EMLX.argsort(is_nan_mx, axis) + else + is_nan_mx + |> EMLX.logical_not() + |> EMLX.argsort(axis) + end + + # Reorder the sorting indices to move NaN indices to the correct position + sort_mx + |> EMLX.take_along_axis(partition_indices_mx, axis) + |> EMLX.astype(to_mlx_type(out.type)) + |> to_nx(out) end defp maybe_upcast(%T{type: t} = left, %T{type: t} = right), diff --git a/test/emlx/nx_doctest_test.exs b/test/emlx/nx_doctest_test.exs index f46f058..bd7ea38 100644 --- a/test/emlx/nx_doctest_test.exs +++ b/test/emlx/nx_doctest_test.exs @@ -46,20 +46,14 @@ defmodule EMLX.Nx.DoctestTest do ] @to_be_fixed [ - :moduledoc, - # MLX sorts NaNs lowest, Nx sorts them highest - argsort: 2 + :moduledoc ] @not_supported [ reduce: 4, window_reduce: 5, population_count: 1, - count_leading_zeros: 1, - sort: 2, - # We do not support the same ordering for NaNs as Nx - argmin: 2, - argmax: 2 + count_leading_zeros: 1 ] doctest Nx,