diff --git a/nx/guides/advanced/backend_comparison.livemd b/nx/guides/advanced/backend_comparison.livemd new file mode 100644 index 0000000000..4d6f62de95 --- /dev/null +++ b/nx/guides/advanced/backend_comparison.livemd @@ -0,0 +1,257 @@ +# Backend Comparison with Evaluator + +```elixir +Mix.install([ + # {:nx, "~> 0.7"} + {:nx, path: Path.join(__DIR__, "../..")}, + {:mimic, "~> 1.7"} +]) +``` + +## Introduction + +This guide demonstrates how to use `Nx.Defn.Evaluator` to compare the outputs of different backends. This is particularly useful for: + +* **Testing backend implementations** - Ensure different backends produce consistent results +* **Debugging numerical differences** - Identify where backends diverge +* **Validating optimizations** - Confirm that optimized backends match reference implementations + +The evaluator's `debug_options` feature saves each node's computation as an executable `.exs` file, making it easy to reconstruct and compare tensors across backends. + +## How It Works + +When you enable `debug_options` with a `save_path`, the evaluator: + +1. Saves each computation node as a separate `.exs` file +2. Serializes tensors as executable `Nx.from_binary()` calls +3. Preserves backend information, shape, type, and names +4. Creates files that can be directly executed to reconstruct tensors + +This allows you to: + +* Run the same computation with different backends +* Compare corresponding node outputs — in this guide we'll be using `Nx.all_close/2` +* Identify exactly where backends differ + +## Simulating Backend Differences with Mimic + +For demonstration purposes, instead of defining an new incorrect backend, we can use `Mimic.stub/3` to override individual callbacks on `Nx.BinaryBackend`. We use `Mimic.copy(Nx.BinaryBackend)` so it can be stubbed correctly. Then `add`, `multiply`, and `divide` are swapped to force a divergence in implementation. + +```elixir + +Mimic.copy(Nx.BinaryBackend) + +defmodule BackendSwaps do + def enable! do + Mimic.stub(Nx.BinaryBackend, :add, fn out, left, right -> + Nx.BinaryBackend.subtract(out, left, right) + end) + + Mimic.stub(Nx.BinaryBackend, :multiply, fn out, left, right -> + Nx.BinaryBackend.add(out, left, right) + end) + + Mimic.stub(Nx.BinaryBackend, :divide, fn out, left, right -> + Nx.BinaryBackend.add(out, left, right) + end) + end + + def restore! do + for fun <- [:add, :multiply, :divide] do + Mimic.stub(Nx.BinaryBackend, fun, fn out, left, right -> + Mimic.call_original(Nx.BinaryBackend, fun, [out, left, right]) + end) + end + end +end +``` + +## Example: Simple Computation + +Let's define a simple computation to compare across backends: + +```elixir +defmodule SimpleComputation do + import Nx.Defn + + defn compute(x, y) do + a = Nx.add(x, y) + b = Nx.multiply(a, 2) + Nx.divide(b, 3) + end +end +``` + +### Prepare Test Data + +```elixir +# Create some test input +x = Nx.tensor([1.0, 2.0, 3.0, 4.0]) +y = Nx.tensor([0.5, 1.5, 2.5, 3.5]) + +IO.puts("Input tensors:") +IO.inspect(x, label: "x") +IO.inspect(y, label: "y") +``` + +## Preparing the function for comparing + +In order to ensure the same `id` for each node in the graph while our function traverses it on both backends, we need to use `Nx.Defn.debug_expr/1` to pre-compile `SimpleComputation.compute/2`. + +This is a trick to make sure the same expression is passed on both `Nx.Defn.jit/2` calls and should not be used liberally elsewhere. + +```elixir +expr = Nx.Defn.debug_expr(&SimpleComputation.compute/2).(x, y) + +precompiled = fn _x, _y -> expr end +``` + +## Running with Backend A + +Let's run our computation with the first backend (BinaryBackend in this example, but could be any backend): + +```elixir +# Clean up and create output directory +File.rm_rf!("/tmp/backend_a") +File.mkdir_p!("/tmp/backend_a") + +# Run computation with debug output enabled +result_a = Nx.Defn.jit( + precompiled, + compiler: Nx.Defn.Evaluator, + debug_options: [save_path: "/tmp/backend_a"] +).(x, y) + +IO.puts("\n✅ Backend A completed") +IO.inspect(result_a, label: "Result A") +IO.puts("Backend: #{inspect(result_a.data.__struct__)}") + +# Show what files were generated +files_a = File.ls!("/tmp/backend_a") +IO.puts("\nGenerated #{length(files_a)} node files:") +Enum.each(files_a, &IO.puts(" - #{&1}")) +``` + +## Examining the Output Files + +Let's look at what the `.exs` files contain: + +```elixir +# Read and display one of the generated files +example_file = File.ls!("/tmp/backend_a") |> List.last() +content = File.read!(Path.join("/tmp/backend_a", example_file)) + +IO.puts("=== Content of #{example_file} ===\n") +IO.puts(content) +``` + +Notice the format: + +* **Node ID** - Unique identifier for this computation node +* **Operation** - The operation being performed (e.g., `:add`, `:multiply`, `:parameter`) +* **Arguments** - List containing parameters and tensors as strings +* **Result** - Executable code that reconstructs the output tensor from binary + +### Verifying Executability + +Each `.exs` file is a self-contained Elixir script, so you can execute it directly: + +```elixir +example_path = Path.join("/tmp/backend_a", example_file) +Code.eval_file(example_path) + +``` + +## Running with Backend B + +Now let's run the same computation with the swapped operations. We leave `Nx` on its default backend, but temporarily enable the Mimic stubs so the evaluator will capture the modified behaviour. + +In practice, this would be where we call a totally different backend to compare against the reference. + +```elixir +# Clean up and create output directory for backend B +File.rm_rf!("/tmp/backend_b") +File.mkdir_p!("/tmp/backend_b") + +BackendSwaps.enable!() + +result_b = + Nx.Defn.jit( + precompiled, + compiler: Nx.Defn.Evaluator, + debug_options: [save_path: "/tmp/backend_b"] + ).(x, y) + +IO.puts("✅ Backend B completed") +IO.inspect(result_b, label: "Result B") +IO.puts("Backend: #{inspect(result_b.data.__struct__)}") + +files_b = File.ls!("/tmp/backend_b") +IO.puts("\nGenerated #{length(files_b)} node files") + +BackendSwaps.restore!() +``` + +## Comparing the Outputs + +Now we inspect the generated `.exs` files, compare every node, and then summarise matches and mismatches. + +```elixir +IO.puts("Comparing outputs from .exs files") +IO.puts(String.duplicate("-", 60)) + +files_a = File.ls!("/tmp/backend_a") |> Enum.sort() +files_b = File.ls!("/tmp/backend_b") |> Enum.sort() + +IO.puts("Backend A generated #{length(files_a)} files") +IO.puts("Backend B generated #{length(files_b)} files") + +comparison = + Enum.zip_with(files_a, files_b, fn file_a, file_b -> + {tensor_a, bindings_a} = Code.eval_file(Path.join("/tmp/backend_a", file_a)) + {tensor_b, _bindings_b} = Code.eval_file(Path.join("/tmp/backend_b", file_b)) + + op = Keyword.get(bindings_a, :operation) + match? = Nx.all_close(tensor_a, tensor_b, atol: 1.0e-6) |> Nx.to_number() == 1 + + %{ + operation: op, + tensor_a: tensor_a, + tensor_b: tensor_b, + match?: match?, + file_a: file_a, + file_b: file_b + } + end) + +{matches, mismatches} = Enum.split_with(comparison, & &1.match?) + +IO.puts("\n Summary:") +IO.puts(String.duplicate("-", 60)) + +if Enum.any?(matches) do + IO.puts("✅ Matched nodes (#{length(matches)}):") + + Enum.each(matches, fn match -> + IO.puts(" - #{match.operation} (#{match.file_a})") + end) +else + IO.puts("\n❌ No nodes match!") +end + +if Enum.any?(mismatches) do + IO.puts("\n❌ Mismatched nodes (#{length(mismatches)}):") + + Enum.each(mismatches, fn mismatch -> + IO.puts("- #{mismatch.operation} (#{mismatch.file_a})") + IO.puts("Backend A") + IO.inspect(mismatch.tensor_a) + IO.puts("Backend B") + IO.inspect(mismatch.tensor_b) + end) +else + IO.puts("\n✅ All nodes match!") +end +``` + +With Mimic stubs in place, the evaluator’s debug artifacts clearly show where the divergence starts, making it straightforward to pinpoint inconsistent nodes between implementations, while the summary highlights both the matching and mismatching nodes. diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index d028ec6a63..936ec1f4b0 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -535,36 +535,112 @@ defmodule Nx.Defn.Evaluator do end defp format_node_info(%Expr{id: id, op: op, args: args}, res, inspect_limit) do - args = - Enum.map( - args, - &inspect(&1, custom_options: [print_id: true], limit: inspect_limit) - ) + id_str = ref_to_string(id) - result_str = inspect(res, limit: inspect_limit) + inspect_opts = + case inspect_limit do + nil -> [custom_options: [print_id: true]] + limit -> [custom_options: [print_id: true], limit: limit] + end + + args_code = + args + |> Enum.map(fn arg -> + inspected = + arg + |> inspect(inspect_opts) + |> String.trim() + + " #{inspect(inspected)}" + end) + |> Enum.join(",\n") - import Inspect.Algebra + # Format result as serialized tensor + result_code = "result = #{serialize_tensor(res)}" - id = :erlang.ref_to_list(id) |> List.to_string() |> String.replace(["#Ref<", ">"], "") + """ + node_id = "#{id_str}" + operation = #{inspect(op)} + + args = [ + #{args_code} + ] + + # Result: + #{result_code} + """ + end - concat([ - "Node ID: #{id}", - line(), - "Operation: #{inspect(op)}", - line(), - concat(Enum.intersperse(["Args:" | args], line())), - line(), - "Result:", - line(), - result_str - ]) - |> format(100) - |> IO.iodata_to_binary() + defp serialize_tensor(%Nx.Tensor{data: %Expr{id: id}} = _tensor) when is_reference(id) do + # This is an unevaluated expression, not a concrete tensor + # Show the Node ID so users can find which file contains this tensor + id_str = :erlang.ref_to_list(id) |> List.to_string() |> String.replace(["#Ref<", ">"], "") + "# See Node ID: #{id_str}" + end + + defp serialize_tensor(%Nx.Tensor{data: %Expr{}} = _tensor) do + # Expression without a valid reference ID + "# " + end + + defp serialize_tensor(%Nx.Tensor{} = tensor) do + # Get the backend information from the tensor's data + {backend, backend_opts} = + case tensor.data do + %backend_mod{} -> {backend_mod, []} + _ -> Nx.default_backend() + end + + # Convert tensor to binary and get metadata + binary = Nx.to_binary(tensor) + type = tensor.type + shape = tensor.shape + names = tensor.names + + # Format the binary as a binary literal + binary_str = inspect(binary, limit: :infinity) + + # Build the executable Nx code + backend_str = "{#{inspect(backend)}, #{inspect(backend_opts)}}" + + code = "Nx.from_binary(#{binary_str}, #{inspect(type)}, backend: #{backend_str})" + + # Add reshape if needed (non-scalar) + code = + if shape != {} do + shape_str = inspect(shape) + code <> " |> Nx.reshape(#{shape_str})" + else + code + end + + # Add rename if there are non-nil names + code = + if Enum.any?(names, fn name -> not is_nil(name) end) do + names_list = inspect(names) + code <> " |> Nx.rename(#{names_list})" + else + code + end + + code + end + + defp serialize_tensor(other) do + # For non-tensor values (numbers, tuples, etc.) + inspect(other) end defp save_node_info_to_file(save_path, id, node_info) do - sanitized_id = inspect(id) |> String.replace(~r/[^a-zA-Z0-9_]/, "_") - file = Path.join(save_path, "node_#{sanitized_id}.txt") + sanitized_id = id |> ref_to_string() |> String.replace(".", "_") + file = Path.join(save_path, "node_#{sanitized_id}.exs") File.write!(file, node_info) end + + defp ref_to_string(id) when is_reference(id) do + id + |> :erlang.ref_to_list() + |> List.to_string() + |> String.replace(["#Ref<", ">"], "") + end end diff --git a/nx/mix.exs b/nx/mix.exs index 4bfd81217b..95b4b395e7 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -68,6 +68,8 @@ defmodule Nx.MixProject do "guides/advanced/vectorization.livemd", "guides/advanced/aggregation.livemd", "guides/advanced/automatic_differentiation.livemd", + "guides/advanced/backend_comparison.livemd", + "guides/advanced/complex_fft.livemd", "guides/exercises/exercises-1-20.livemd" ], skip_undefined_reference_warnings_on: ["CHANGELOG.md"], diff --git a/nx/test/nx/defn/evaluator_test.exs b/nx/test/nx/defn/evaluator_test.exs index 97e6313d19..511d0d3a64 100644 --- a/nx/test/nx/defn/evaluator_test.exs +++ b/nx/test/nx/defn/evaluator_test.exs @@ -775,7 +775,7 @@ defmodule Nx.Defn.EvaluatorTest do opts = [compiler: Nx.Defn.Evaluator, debug_options: [inspect_limit: 5]] output = ExUnit.CaptureIO.capture_io(fn -> debug_test_fun(x, y, opts) end) - node_id_regex = ~r/Node ID: (.*)/ + node_id_regex = ~r/node_id = \"(.*)\"/ assert [id0, id1, id2, id3, id4] = Regex.scan(node_id_regex, output, capture: :all_but_first) @@ -788,90 +788,60 @@ defmodule Nx.Defn.EvaluatorTest do |> String.replace(id3, "::id3::") |> String.replace(id4, "::id4::") - assert output == """ - Node ID: ::id0:: - Operation: :parameter - Args: - 0 - Result: - #Nx.Tensor< - s32[2] - [1, 2] - > - Node ID: ::id1:: - Operation: :parameter - Args: - 1 - Result: - #Nx.Tensor< - s32[2] - [3, 4] - > - Node ID: ::id2:: - Operation: :add - Args: - #Nx.Tensor< - s32[2] - \s\s - Nx.Defn.Expr<::id0::> - parameter a:0 s32[2] - > - #Nx.Tensor< - s32[2] - \s\s - Nx.Defn.Expr<::id1::> - parameter a:1 s32[2] - > - Result: - #Nx.Tensor< - s32[2] - [4, 6] - > - Node ID: ::id3:: - Operation: :multiply - Args: - #Nx.Tensor< - s32 - \s\s - Nx.Defn.Expr - 2 - > - #Nx.Tensor< - s32[2] - \s\s - Nx.Defn.Expr<::id2::> - parameter a:0 s32[2] - parameter b:1 s32[2] - c = add a, b s32[2] - > - Result: - #Nx.Tensor< - s32[2] - [8, 12] - > - Node ID: ::id4:: - Operation: :subtract - Args: - #Nx.Tensor< - s32[2] - \s\s - Nx.Defn.Expr<::id3::> - parameter a:0 s32[2] - parameter b:1 s32[2] - c = add a, b s32[2] - d = multiply 2, c s32[2] - > - #Nx.Tensor< - s32 - \s\s - Nx.Defn.Expr - 1 - > - Result: - #Nx.Tensor< - s32[2] - [7, 11] - > + assert output == ~S""" + node_id = "::id0::" + operation = :parameter + + args = [ + "0" + ] + + # Result: + result = Nx.from_binary(<<1, 0, 0, 0, 2, 0, 0, 0>>, {:s, 32}, backend: {Nx.BinaryBackend, []}) |> Nx.reshape({2}) + + node_id = "::id1::" + operation = :parameter + + args = [ + "1" + ] + + # Result: + result = Nx.from_binary(<<3, 0, 0, 0, 4, 0, 0, 0>>, {:s, 32}, backend: {Nx.BinaryBackend, []}) |> Nx.reshape({2}) + + node_id = "::id2::" + operation = :add + + args = [ + "#Nx.Tensor<\n s32[2]\n \n Nx.Defn.Expr<::id0::>\n parameter a:0 s32[2]\n>", + "#Nx.Tensor<\n s32[2]\n \n Nx.Defn.Expr<::id1::>\n parameter a:1 s32[2]\n>" + ] + + # Result: + result = Nx.from_binary(<<4, 0, 0, 0, 6, 0, 0, 0>>, {:s, 32}, backend: {Nx.BinaryBackend, []}) |> Nx.reshape({2}) + + node_id = "::id3::" + operation = :multiply + + args = [ + "#Nx.Tensor<\n s32\n \n Nx.Defn.Expr\n 2\n>", + "#Nx.Tensor<\n s32[2]\n \n Nx.Defn.Expr<::id2::>\n parameter a:0 s32[2]\n parameter b:1 s32[2]\n c = add a, b s32[2]\n>" + ] + + # Result: + result = Nx.from_binary(<<8, 0, 0, 0, 12, 0, 0, 0>>, {:s, 32}, backend: {Nx.BinaryBackend, []}) |> Nx.reshape({2}) + + node_id = "::id4::" + operation = :subtract + + args = [ + "#Nx.Tensor<\n s32[2]\n \n Nx.Defn.Expr<::id3::>\n parameter a:0 s32[2]\n parameter b:1 s32[2]\n c = add a, b s32[2]\n d = multiply 2, c s32[2]\n>", + "#Nx.Tensor<\n s32\n \n Nx.Defn.Expr\n 1\n>" + ] + + # Result: + result = Nx.from_binary(<<7, 0, 0, 0, 11, 0, 0, 0>>, {:s, 32}, backend: {Nx.BinaryBackend, []}) |> Nx.reshape({2}) + """ end @@ -886,9 +856,9 @@ defmodule Nx.Defn.EvaluatorTest do files = File.ls!(tmp_dir) assert Enum.any?(files, &String.starts_with?(&1, "node_")) contents = Enum.map(files, &File.read!(Path.join(tmp_dir, &1))) - assert {[_], rest} = Enum.split_with(contents, &(&1 =~ "Operation: :add")) - assert {[_], rest} = Enum.split_with(rest, &(&1 =~ "Operation: :multiply")) - assert {[_], rest} = Enum.split_with(rest, &(&1 =~ "Operation: :subtract")) + assert {[_], rest} = Enum.split_with(contents, &(&1 =~ "operation = :add")) + assert {[_], rest} = Enum.split_with(rest, &(&1 =~ "operation = :multiply")) + assert {[_], rest} = Enum.split_with(rest, &(&1 =~ "operation = :subtract")) assert length(rest) == 2 end @@ -897,7 +867,7 @@ defmodule Nx.Defn.EvaluatorTest do opts = [compiler: Nx.Defn.Evaluator, debug_options: [inspect_limit: 5]] output = ExUnit.CaptureIO.capture_io(fn -> reuse_fun(x, opts) end) - node_id_regex = ~r/Node ID: (.*)/ + node_id_regex = ~r/node_id = (.*)/ assert [id0, id1, id2, id3] = Regex.scan(node_id_regex, output, capture: :all_but_first) @@ -910,10 +880,10 @@ defmodule Nx.Defn.EvaluatorTest do |> String.replace(id3, "::id3::") # ensure that each node id is printed exactly once - assert output =~ ~r/.*(?:Node ID: ::id0::){1}.*/ - assert output =~ ~r/.*(?:Node ID: ::id1::){1}.*/ - assert output =~ ~r/.*(?:Node ID: ::id2::){1}.*/ - assert output =~ ~r/.*(?:Node ID: ::id3::){1}.*/ + assert output =~ ~r/.*(?:node_id = ::id0::){1}.*/ + assert output =~ ~r/.*(?:node_id = ::id1::){1}.*/ + assert output =~ ~r/.*(?:node_id = ::id2::){1}.*/ + assert output =~ ~r/.*(?:node_id = ::id3::){1}.*/ end test "respects inspect_limit" do