From 500adbef0b52e2a10dc853226af6d140429ec050 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 28 Jan 2025 15:51:44 -0300 Subject: [PATCH 1/3] feat: add better errors for invalid compiler configurations --- nx/lib/nx/defn/compiler.ex | 21 +++++++++++++ nx/lib/nx/serving.ex | 11 +++++++ nx/test/nx/defn/compiler_test.exs | 49 +++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 nx/test/nx/defn/compiler_test.exs diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex index bbbd3ddcd4..ff5f7215a7 100644 --- a/nx/lib/nx/defn/compiler.ex +++ b/nx/lib/nx/defn/compiler.ex @@ -112,6 +112,9 @@ defmodule Nx.Defn.Compiler do def __to_backend__(opts) do {compiler, opts} = Keyword.pop(opts, :compiler, Nx.Defn.Evaluator) compiler.__to_backend__(opts) + rescue + e in [UndefinedFunctionError] -> + raise_missing_callback(e, :__to_backend__, 1, __STACKTRACE__) end ## JIT/Stream @@ -120,12 +123,30 @@ defmodule Nx.Defn.Compiler do def __compile__(fun, params, opts) do {compiler, runtime_fun, opts} = prepare_options(fun, opts) compiler.__compile__(fun, params, runtime_fun, opts) + rescue + e in [UndefinedFunctionError] -> + raise_missing_callback(e, :__compile__, 4, __STACKTRACE__) end @doc false def __jit__(fun, params, args_list, opts) do {compiler, runtime_fun, opts} = prepare_options(fun, opts) compiler.__jit__(fun, params, runtime_fun, args_list, opts) + rescue + e in [UndefinedFunctionError] -> + raise_missing_callback(e, :__jit__, 5, __STACKTRACE__) + end + + defp raise_missing_callback(exception, name, arity, stacktrace) do + case exception do + %UndefinedFunctionError{module: compiler, function: ^name, arity: ^arity} -> + raise ArgumentError, + "The expected compiler callback #{name}/#{arity} is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." + + _ -> + # This is not an error that should've been caught by this function, so we pass the exception along + reraise exception, stacktrace + end end defp prepare_options(fun, opts) do diff --git a/nx/lib/nx/serving.ex b/nx/lib/nx/serving.ex index ab0a524bd7..06f2f6cd9e 100644 --- a/nx/lib/nx/serving.ex +++ b/nx/lib/nx/serving.ex @@ -1359,6 +1359,17 @@ defmodule Nx.Serving do defp serving_partitions(%Nx.Serving{defn_options: defn_options}, true) do compiler = Keyword.get(defn_options, :compiler, Nx.Defn.Evaluator) compiler.__partitions_options__(defn_options) + rescue + e in [UndefinedFunctionError] -> + case e do + %UndefinedFunctionError{module: compiler, function: :__partitions_options__, arity: 1} -> + raise ArgumentError, + "The expected compiler callback __partitions_options__/1 is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." + + _ -> + # This is not an error that should've been caught by this function, so we pass the exception along + reraise e, __STACKTRACE__ + end end defp serving_partitions(%Nx.Serving{defn_options: defn_options}, false) do diff --git a/nx/test/nx/defn/compiler_test.exs b/nx/test/nx/defn/compiler_test.exs new file mode 100644 index 0000000000..7da2640de3 --- /dev/null +++ b/nx/test/nx/defn/compiler_test.exs @@ -0,0 +1,49 @@ +defmodule Nx.Defn.CompilerTest do + use ExUnit.Case, async: true + + defmodule SomeInvalidServing do + def init(_, _, _) do + :ok + end + end + + test "it raises an error if the __compile__ callback is missing" do + msg = + "The expected compiler callback __compile__/4 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + + assert_raise ArgumentError, msg, fn -> + Nx.Defn.compile(&Function.identity/1, [Nx.template({}, :f32)], + compiler: SomeInvalidCompiler + ) + end + end + + test "it raises an error if the __jit__ callback is missing" do + msg = + "The expected compiler callback __jit__/5 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + + assert_raise ArgumentError, msg, fn -> + Nx.Defn.jit(&Function.identity/1, compiler: SomeInvalidCompiler).(1) + end + end + + test "it raises an error if the __partitions_options__ callback is missing" do + msg = + "The expected compiler callback __partitions_options__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + + serving = Nx.Serving.new(SomeInvalidServing, [], compiler: SomeInvalidCompiler) + + assert_raise ArgumentError, msg, fn -> + Nx.Serving.init({MyName, serving, true, [1], 10, 1000, nil, 1}) + end + end + + test "it raises an error if the __to_backend__ callback is missing" do + msg = + "The expected compiler callback __to_backend__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + + assert_raise ArgumentError, msg, fn -> + Nx.Defn.to_backend(compiler: SomeInvalidCompiler) + end + end +end From f9f35f73ab8679e945ac8dd453ffdae5e05f05bf Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 30 Jan 2025 13:25:54 -0300 Subject: [PATCH 2/3] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- nx/lib/nx/defn/compiler.ex | 2 +- nx/lib/nx/serving.ex | 2 +- nx/test/nx/defn/compiler_test.exs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex index ff5f7215a7..c25e691470 100644 --- a/nx/lib/nx/defn/compiler.ex +++ b/nx/lib/nx/defn/compiler.ex @@ -141,7 +141,7 @@ defmodule Nx.Defn.Compiler do case exception do %UndefinedFunctionError{module: compiler, function: ^name, arity: ^arity} -> raise ArgumentError, - "The expected compiler callback #{name}/#{arity} is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." + "the expected compiler callback #{name}/#{arity} is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." _ -> # This is not an error that should've been caught by this function, so we pass the exception along diff --git a/nx/lib/nx/serving.ex b/nx/lib/nx/serving.ex index 06f2f6cd9e..6a4d41ac99 100644 --- a/nx/lib/nx/serving.ex +++ b/nx/lib/nx/serving.ex @@ -1364,7 +1364,7 @@ defmodule Nx.Serving do case e do %UndefinedFunctionError{module: compiler, function: :__partitions_options__, arity: 1} -> raise ArgumentError, - "The expected compiler callback __partitions_options__/1 is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." + "the expected compiler callback __partitions_options__/1 is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." _ -> # This is not an error that should've been caught by this function, so we pass the exception along diff --git a/nx/test/nx/defn/compiler_test.exs b/nx/test/nx/defn/compiler_test.exs index 7da2640de3..12304a4205 100644 --- a/nx/test/nx/defn/compiler_test.exs +++ b/nx/test/nx/defn/compiler_test.exs @@ -7,7 +7,7 @@ defmodule Nx.Defn.CompilerTest do end end - test "it raises an error if the __compile__ callback is missing" do + test "raises an error if the __compile__ callback is missing" do msg = "The expected compiler callback __compile__/4 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." From 8e0b464eb65d7f26c6c246d9c7f95e32554ae522 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 30 Jan 2025 13:28:06 -0300 Subject: [PATCH 3/3] chore: comply with code review --- nx/test/nx/defn/compiler_test.exs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/nx/test/nx/defn/compiler_test.exs b/nx/test/nx/defn/compiler_test.exs index 12304a4205..c35f69b075 100644 --- a/nx/test/nx/defn/compiler_test.exs +++ b/nx/test/nx/defn/compiler_test.exs @@ -9,7 +9,7 @@ defmodule Nx.Defn.CompilerTest do test "raises an error if the __compile__ callback is missing" do msg = - "The expected compiler callback __compile__/4 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + "the expected compiler callback __compile__/4 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." assert_raise ArgumentError, msg, fn -> Nx.Defn.compile(&Function.identity/1, [Nx.template({}, :f32)], @@ -18,18 +18,18 @@ defmodule Nx.Defn.CompilerTest do end end - test "it raises an error if the __jit__ callback is missing" do + test "raises an error if the __jit__ callback is missing" do msg = - "The expected compiler callback __jit__/5 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + "the expected compiler callback __jit__/5 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." assert_raise ArgumentError, msg, fn -> Nx.Defn.jit(&Function.identity/1, compiler: SomeInvalidCompiler).(1) end end - test "it raises an error if the __partitions_options__ callback is missing" do + test "raises an error if the __partitions_options__ callback is missing" do msg = - "The expected compiler callback __partitions_options__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + "the expected compiler callback __partitions_options__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." serving = Nx.Serving.new(SomeInvalidServing, [], compiler: SomeInvalidCompiler) @@ -38,9 +38,9 @@ defmodule Nx.Defn.CompilerTest do end end - test "it raises an error if the __to_backend__ callback is missing" do + test "raises an error if the __to_backend__ callback is missing" do msg = - "The expected compiler callback __to_backend__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + "the expected compiler callback __to_backend__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." assert_raise ArgumentError, msg, fn -> Nx.Defn.to_backend(compiler: SomeInvalidCompiler)