From fffc9f1da0a3b78730efb4047d650c1870fe01e9 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 22 Oct 2025 15:00:00 -0300 Subject: [PATCH 1/5] feat: argsort nans --- example.exs | 26 ++++++++++++++++++++ example_debug.exs | 37 ++++++++++++++++++++++++++++ example_no_stream.exs | 33 +++++++++++++++++++++++++ lib/emlx/backend.ex | 46 ++++++++++++++++++++++++----------- scripts/test_nifcall.exs | 28 +++++++++++++++++++++ scripts/test_simple.exs | 20 +++++++++++++++ test/emlx/nx_doctest_test.exs | 4 +-- 7 files changed, 177 insertions(+), 17 deletions(-) create mode 100644 example.exs create mode 100644 example_debug.exs create mode 100644 example_no_stream.exs create mode 100644 scripts/test_nifcall.exs create mode 100644 scripts/test_simple.exs diff --git a/example.exs b/example.exs new file mode 100644 index 0000000..bc47a8d --- /dev/null +++ b/example.exs @@ -0,0 +1,26 @@ +Mix.install([ + {:bumblebee, github: "elixir-nx/bumblebee", override: true}, + {:emlx, path: __DIR__} +], system_env: %{"LIBMLX_ENABLE_DEBUG" => "true"}, force: true) + + Nx.global_default_backend({EMLX.Backend, device: :gpu}) + + Nx.Defn.default_options(compiler: EMLX) + + {:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"}) + + {:ok, generation_config} = + Bumblebee.load_generation_config({:hf, "openai-community/gpt2"}) + + generation_config = Bumblebee.configure(generation_config, max_new_tokens: 20) + + serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 100], + stream: true + ) + +Nx.Serving.run(serving, "What is the capital of queensland?") +|> Enum.to_list() +|> IO.puts() diff --git a/example_debug.exs b/example_debug.exs new file mode 100644 index 0000000..4b7f594 --- /dev/null +++ b/example_debug.exs @@ -0,0 +1,37 @@ +Mix.install([ + {:bumblebee, github: "elixir-nx/bumblebee", override: true}, + {:emlx, path: __DIR__} +], system_env: %{"LIBMLX_ENABLE_DEBUG" => "true"}) + +IO.puts("1. Setting backend...") +Nx.global_default_backend({EMLX.Backend, device: :gpu}) +Nx.Defn.default_options(compiler: EMLX) + +IO.puts("2. Loading model...") +{:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"}) + +IO.puts("3. Loading tokenizer...") +{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"}) + +IO.puts("4. Loading generation config...") +{:ok, generation_config} = + Bumblebee.load_generation_config({:hf, "openai-community/gpt2"}) + +generation_config = Bumblebee.configure(generation_config, max_new_tokens: 20) + +IO.puts("5. Creating serving...") +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 100], + stream: true + ) + +IO.puts("6. Running serving...") +result = Nx.Serving.run(serving, "What is the capital of queensland?") + +dbg(result) + +IO.puts("7. Collecting results...") +result +|> Enum.to_list() +|> IO.inspect(label: "Final result") diff --git a/example_no_stream.exs b/example_no_stream.exs new file mode 100644 index 0000000..7074a17 --- /dev/null +++ b/example_no_stream.exs @@ -0,0 +1,33 @@ +Mix.install([ + {:bumblebee, github: "elixir-nx/bumblebee", override: true}, + {:emlx, path: __DIR__} +], system_env: %{"LIBMLX_ENABLE_DEBUG" => "true"}, force: true) + +IO.puts("1. Setting backend...") +Nx.global_default_backend({EMLX.Backend, device: :gpu}) +Nx.Defn.default_options(compiler: EMLX) + +IO.puts("2. Loading model...") +{:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"}) + +IO.puts("3. Loading tokenizer...") +{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"}) + +IO.puts("4. Loading generation config...") +{:ok, generation_config} = + Bumblebee.load_generation_config({:hf, "openai-community/gpt2"}) + +generation_config = Bumblebee.configure(generation_config, max_new_tokens: 20) + +IO.puts("5. Creating serving (NO STREAM)...") +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 100], + stream: false + ) + +IO.puts("6. Running serving...") +result = Nx.Serving.run(serving, "What is the capital of queensland?") + +IO.puts("7. Got result!") +IO.inspect(result, label: "Final result") diff --git a/lib/emlx/backend.ex b/lib/emlx/backend.ex index 5566ed2..e33f569 100644 --- a/lib/emlx/backend.ex +++ b/lib/emlx/backend.ex @@ -1212,20 +1212,38 @@ defmodule EMLX.Backend do axis = opts[:axis] asc? = opts[:direction] == :asc - if asc? do - tensor - |> from_nx() - |> EMLX.argsort(axis) - |> EMLX.astype(to_mlx_type(out.type)) - |> to_nx(out) - else - tensor - |> from_nx() - |> EMLX.negate() - |> EMLX.argsort(axis) - |> EMLX.astype(to_mlx_type(out.type)) - |> to_nx(out) - end + t_mx = from_nx(tensor) + # Get the initial sorting indices + sort_mx = + if asc? do + EMLX.argsort(t_mx, axis) + else + t_mx + |> EMLX.negate() + |> EMLX.argsort(axis) + end + + # Gather values at sorted positions to identify NaNs + sorted_values_mx = EMLX.take_along_axis(t_mx, sort_mx, axis) + is_nan_mx = EMLX.is_nan(sorted_values_mx) + + # Partition indices to place NaNs correctly (NaNs are treated as highest): + # - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1) + # - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0) + partition_indices_mx = + if asc? do + EMLX.argsort(is_nan_mx, axis) + else + is_nan_mx + |> EMLX.logical_not() + |> EMLX.argsort(axis) + end + + # Reorder the sorting indices to move NaN indices to the correct position + sort_mx + |> EMLX.take_along_axis(partition_indices_mx, axis) + |> EMLX.astype(to_mlx_type(out.type)) + |> to_nx(out) end defp maybe_upcast(%T{type: t} = left, %T{type: t} = right), diff --git a/scripts/test_nifcall.exs b/scripts/test_nifcall.exs new file mode 100644 index 0000000..648809d --- /dev/null +++ b/scripts/test_nifcall.exs @@ -0,0 +1,28 @@ +IO.puts("1. Setting backend...") +Nx.global_default_backend({EMLX.Backend, device: :cpu}) +Nx.Defn.default_options(compiler: EMLX) + +IO.puts("2. Loading model...") +{:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"}) + +IO.puts("3. Loading tokenizer...") +{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"}) + +IO.puts("4. Loading generation config...") +{:ok, generation_config} = + Bumblebee.load_generation_config({:hf, "openai-community/gpt2"}) + +generation_config = Bumblebee.configure(generation_config, max_new_tokens: 20) + +IO.puts("5. Creating serving (NO STREAM)...") +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 100], + stream: false + ) + +IO.puts("6. Running serving...") +result = Nx.Serving.run(serving, "What is the capital of queensland?") + +IO.puts("7. Got result!") +IO.inspect(result, label: "Final result") diff --git a/scripts/test_simple.exs b/scripts/test_simple.exs new file mode 100644 index 0000000..6ba9429 --- /dev/null +++ b/scripts/test_simple.exs @@ -0,0 +1,20 @@ +IO.puts("Setting backend...") +Nx.global_default_backend({EMLX.Backend, device: :cpu}) +Nx.Defn.default_options(compiler: EMLX) + +IO.puts("Defining simple defn...") +defmodule SimpleTest do + import Nx.Defn + + defn add_one(x) do + Nx.add(x, 1) + end +end + +IO.puts("Creating tensor...") +x = Nx.tensor([1, 2, 3]) + +IO.puts("Calling defn (this will trigger compilation)...") +result = SimpleTest.add_one(x) + +IO.puts("Success! Result: #{inspect(result)}") diff --git a/test/emlx/nx_doctest_test.exs b/test/emlx/nx_doctest_test.exs index f46f058..0dcc473 100644 --- a/test/emlx/nx_doctest_test.exs +++ b/test/emlx/nx_doctest_test.exs @@ -46,9 +46,7 @@ defmodule EMLX.Nx.DoctestTest do ] @to_be_fixed [ - :moduledoc, - # MLX sorts NaNs lowest, Nx sorts them highest - argsort: 2 + :moduledoc ] @not_supported [ From ea0478ea4c109763a2ea99858bd33e7407ccaeb7 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 22 Oct 2025 15:03:19 -0300 Subject: [PATCH 2/5] feat: sort nans --- lib/emlx/backend.ex | 40 ++++++++++++++++++++++++++++------- test/emlx/nx_doctest_test.exs | 2 -- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/lib/emlx/backend.ex b/lib/emlx/backend.ex index e33f569..618d2e9 100644 --- a/lib/emlx/backend.ex +++ b/lib/emlx/backend.ex @@ -1196,15 +1196,39 @@ defmodule EMLX.Backend do axis = opts[:axis] asc? = opts[:direction] == :asc - t = tensor |> from_nx() |> EMLX.sort(axis) + t_mx = from_nx(tensor) - if asc? do - to_nx(t, out) - else - t - |> to_nx(out) - |> Nx.reverse(axes: [axis]) - end + # Get the sorting indices + sort_mx = + if asc? do + EMLX.argsort(t_mx, axis) + else + t_mx + |> EMLX.negate() + |> EMLX.argsort(axis) + end + + # Gather values at sorted positions to identify NaNs + sorted_values_mx = EMLX.take_along_axis(t_mx, sort_mx, axis) + is_nan_mx = EMLX.is_nan(sorted_values_mx) + + # Partition indices to place NaNs correctly (NaNs are treated as highest): + # - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1) + # - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0) + partition_indices_mx = + if asc? do + EMLX.argsort(is_nan_mx, axis) + else + is_nan_mx + |> EMLX.logical_not() + |> EMLX.argsort(axis) + end + + # Reorder the sorted values to move NaNs to the correct position + sorted_values_mx + |> EMLX.take_along_axis(partition_indices_mx, axis) + |> EMLX.astype(to_mlx_type(out.type)) + |> to_nx(out) end @impl true diff --git a/test/emlx/nx_doctest_test.exs b/test/emlx/nx_doctest_test.exs index 0dcc473..c3dfa32 100644 --- a/test/emlx/nx_doctest_test.exs +++ b/test/emlx/nx_doctest_test.exs @@ -54,8 +54,6 @@ defmodule EMLX.Nx.DoctestTest do window_reduce: 5, population_count: 1, count_leading_zeros: 1, - sort: 2, - # We do not support the same ordering for NaNs as Nx argmin: 2, argmax: 2 ] From 0674e7c318bf9487ba9127bf60d71b509480462e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 22 Oct 2025 15:16:13 -0300 Subject: [PATCH 3/5] feat: support nans in argmin argmax --- lib/emlx/backend.ex | 29 +++++++++++++++++++++++++++++ test/emlx/nx_doctest_test.exs | 4 +--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/lib/emlx/backend.ex b/lib/emlx/backend.ex index 618d2e9..3d7f39b 100644 --- a/lib/emlx/backend.ex +++ b/lib/emlx/backend.ex @@ -590,6 +590,32 @@ defmodule EMLX.Backend do t_mx = from_nx(tensor) + # Check for NaNs in the original tensor (before any reversal) + is_nan_mx = EMLX.is_nan(t_mx) + + nan_index_mx = + if axis do + EMLX.argmax(is_nan_mx, axis, keep_axis) + else + EMLX.argmax(is_nan_mx, keep_axis) + end + + # Check if any NaN exists along the axis + has_nan_mx = + cond do + axis -> + EMLX.any(is_nan_mx, [axis], keep_axis) + + tuple_size(tensor.shape) == 0 -> + # For scalar input, is_nan_mx is already a scalar boolean + is_nan_mx + + true -> + # For full reduction over non-scalar tensors + EMLX.any(is_nan_mx, Nx.axes(tensor), keep_axis) + end + + # Apply reversal for tie_break after NaN check t_mx = if opts[:tie_break] == :high do reverse_mlx(t_mx, tensor.shape, [axis] || Nx.axes(tensor)) @@ -623,6 +649,9 @@ defmodule EMLX.Backend do result end + # Use NaN index if any NaN exists, otherwise use regular result + result = EMLX.where(has_nan_mx, nan_index_mx, result) + result |> EMLX.astype(to_mlx_type(out.type)) |> to_nx(out) diff --git a/test/emlx/nx_doctest_test.exs b/test/emlx/nx_doctest_test.exs index c3dfa32..bd7ea38 100644 --- a/test/emlx/nx_doctest_test.exs +++ b/test/emlx/nx_doctest_test.exs @@ -53,9 +53,7 @@ defmodule EMLX.Nx.DoctestTest do reduce: 4, window_reduce: 5, population_count: 1, - count_leading_zeros: 1, - argmin: 2, - argmax: 2 + count_leading_zeros: 1 ] doctest Nx, From d6669c6459d47dbf8ca192eaea41bb7eeecddb7d Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 22 Oct 2025 15:22:56 -0300 Subject: [PATCH 4/5] chore: revert stray files --- example.exs | 26 -------------------------- example_debug.exs | 37 ------------------------------------- example_no_stream.exs | 33 --------------------------------- 3 files changed, 96 deletions(-) delete mode 100644 example.exs delete mode 100644 example_debug.exs delete mode 100644 example_no_stream.exs diff --git a/example.exs b/example.exs deleted file mode 100644 index bc47a8d..0000000 --- a/example.exs +++ /dev/null @@ -1,26 +0,0 @@ -Mix.install([ - {:bumblebee, github: "elixir-nx/bumblebee", override: true}, - {:emlx, path: __DIR__} -], system_env: %{"LIBMLX_ENABLE_DEBUG" => "true"}, force: true) - - Nx.global_default_backend({EMLX.Backend, device: :gpu}) - - Nx.Defn.default_options(compiler: EMLX) - - {:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"}) - {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"}) - - {:ok, generation_config} = - Bumblebee.load_generation_config({:hf, "openai-community/gpt2"}) - - generation_config = Bumblebee.configure(generation_config, max_new_tokens: 20) - - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, - compile: [batch_size: 1, sequence_length: 100], - stream: true - ) - -Nx.Serving.run(serving, "What is the capital of queensland?") -|> Enum.to_list() -|> IO.puts() diff --git a/example_debug.exs b/example_debug.exs deleted file mode 100644 index 4b7f594..0000000 --- a/example_debug.exs +++ /dev/null @@ -1,37 +0,0 @@ -Mix.install([ - {:bumblebee, github: "elixir-nx/bumblebee", override: true}, - {:emlx, path: __DIR__} -], system_env: %{"LIBMLX_ENABLE_DEBUG" => "true"}) - -IO.puts("1. Setting backend...") -Nx.global_default_backend({EMLX.Backend, device: :gpu}) -Nx.Defn.default_options(compiler: EMLX) - -IO.puts("2. Loading model...") -{:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"}) - -IO.puts("3. Loading tokenizer...") -{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"}) - -IO.puts("4. Loading generation config...") -{:ok, generation_config} = - Bumblebee.load_generation_config({:hf, "openai-community/gpt2"}) - -generation_config = Bumblebee.configure(generation_config, max_new_tokens: 20) - -IO.puts("5. Creating serving...") -serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, - compile: [batch_size: 1, sequence_length: 100], - stream: true - ) - -IO.puts("6. Running serving...") -result = Nx.Serving.run(serving, "What is the capital of queensland?") - -dbg(result) - -IO.puts("7. Collecting results...") -result -|> Enum.to_list() -|> IO.inspect(label: "Final result") diff --git a/example_no_stream.exs b/example_no_stream.exs deleted file mode 100644 index 7074a17..0000000 --- a/example_no_stream.exs +++ /dev/null @@ -1,33 +0,0 @@ -Mix.install([ - {:bumblebee, github: "elixir-nx/bumblebee", override: true}, - {:emlx, path: __DIR__} -], system_env: %{"LIBMLX_ENABLE_DEBUG" => "true"}, force: true) - -IO.puts("1. Setting backend...") -Nx.global_default_backend({EMLX.Backend, device: :gpu}) -Nx.Defn.default_options(compiler: EMLX) - -IO.puts("2. Loading model...") -{:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"}) - -IO.puts("3. Loading tokenizer...") -{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"}) - -IO.puts("4. Loading generation config...") -{:ok, generation_config} = - Bumblebee.load_generation_config({:hf, "openai-community/gpt2"}) - -generation_config = Bumblebee.configure(generation_config, max_new_tokens: 20) - -IO.puts("5. Creating serving (NO STREAM)...") -serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, - compile: [batch_size: 1, sequence_length: 100], - stream: false - ) - -IO.puts("6. Running serving...") -result = Nx.Serving.run(serving, "What is the capital of queensland?") - -IO.puts("7. Got result!") -IO.inspect(result, label: "Final result") From ad606ead8677059cc442e6d02efd94190335f2e0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 22 Oct 2025 15:23:11 -0300 Subject: [PATCH 5/5] --version --- scripts/test_nifcall.exs | 28 ---------------------------- scripts/test_simple.exs | 20 -------------------- 2 files changed, 48 deletions(-) delete mode 100644 scripts/test_nifcall.exs delete mode 100644 scripts/test_simple.exs diff --git a/scripts/test_nifcall.exs b/scripts/test_nifcall.exs deleted file mode 100644 index 648809d..0000000 --- a/scripts/test_nifcall.exs +++ /dev/null @@ -1,28 +0,0 @@ -IO.puts("1. Setting backend...") -Nx.global_default_backend({EMLX.Backend, device: :cpu}) -Nx.Defn.default_options(compiler: EMLX) - -IO.puts("2. Loading model...") -{:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"}) - -IO.puts("3. Loading tokenizer...") -{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"}) - -IO.puts("4. Loading generation config...") -{:ok, generation_config} = - Bumblebee.load_generation_config({:hf, "openai-community/gpt2"}) - -generation_config = Bumblebee.configure(generation_config, max_new_tokens: 20) - -IO.puts("5. Creating serving (NO STREAM)...") -serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, - compile: [batch_size: 1, sequence_length: 100], - stream: false - ) - -IO.puts("6. Running serving...") -result = Nx.Serving.run(serving, "What is the capital of queensland?") - -IO.puts("7. Got result!") -IO.inspect(result, label: "Final result") diff --git a/scripts/test_simple.exs b/scripts/test_simple.exs deleted file mode 100644 index 6ba9429..0000000 --- a/scripts/test_simple.exs +++ /dev/null @@ -1,20 +0,0 @@ -IO.puts("Setting backend...") -Nx.global_default_backend({EMLX.Backend, device: :cpu}) -Nx.Defn.default_options(compiler: EMLX) - -IO.puts("Defining simple defn...") -defmodule SimpleTest do - import Nx.Defn - - defn add_one(x) do - Nx.add(x, 1) - end -end - -IO.puts("Creating tensor...") -x = Nx.tensor([1, 2, 3]) - -IO.puts("Calling defn (this will trigger compilation)...") -result = SimpleTest.add_one(x) - -IO.puts("Success! Result: #{inspect(result)}")