From f493b65ba5ede9a4a24d618e6fe791c66bcce7f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 4 Dec 2025 08:51:38 +0000 Subject: [PATCH 1/2] [mlir:python] Add manual typing annotations to `mlir.register_*` functions. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a manual typing annotations to the `register_operation` and `register_(type|value)_caster` functions in the main `mlir` module. Since those functions return the result `nb::cpp_function`, which is of type `nb::object`, the automatic typing annocations are of the form `def f() -> object`. This isn't particularly precise and leads to type checking errors when the functions are used. Manually defining the annotation with `nb::sig` solves the problem. Signed-off-by: Ingo Müller --- mlir/include/mlir/Bindings/Python/Nanobind.h | 1 + mlir/lib/Bindings/Python/MainModule.cpp | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h index ca942c83d3e2f..8dc8a0d063d70 100644 --- a/mlir/include/mlir/Bindings/Python/Nanobind.h +++ b/mlir/include/mlir/Bindings/Python/Nanobind.h @@ -30,6 +30,7 @@ #include #include #include +#include #if defined(__clang__) || defined(__GNUC__) #pragma GCC diagnostic pop #endif diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index a14f09f77d2c3..915d52fb7bfde 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -24,6 +24,8 @@ using namespace mlir::python; NB_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; + m.attr("T") = nb::type_var("T"); + m.attr("U") = nb::type_var("U"); nb::class_(m, "_Globals") .def_prop_rw("dialect_search_modules", @@ -102,6 +104,8 @@ NB_MODULE(_mlir, m) { return opClass; }); }, + nb::sig("def register_operation(dialect_class: type, *, " + "replace: bool = False) -> typing.Callable[[type[T]], type[T]]"), "dialect_class"_a, nb::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); @@ -114,6 +118,10 @@ NB_MODULE(_mlir, m) { return typeCaster; }); }, + nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, " + "replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], " + "typing.Callable[[T], U]]"), "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( @@ -126,6 +134,10 @@ NB_MODULE(_mlir, m) { return valueCaster; }); }, + nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, " + "replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], " + "typing.Callable[[T], U]]"), "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values."); From b3ffc4859a56d55db82078dc71a4b894648f5dd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 5 Dec 2025 08:19:47 +0000 Subject: [PATCH 2/2] Reformat as suggested by @makslevental. (NFC) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ingo Müller --- mlir/lib/Bindings/Python/MainModule.cpp | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 915d52fb7bfde..ba767ad6692cf 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -104,8 +104,10 @@ NB_MODULE(_mlir, m) { return opClass; }); }, - nb::sig("def register_operation(dialect_class: type, *, " - "replace: bool = False) -> typing.Callable[[type[T]], type[T]]"), + // clang-format off + nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) " + "-> typing.Callable[[type[T]], type[T]]"), + // clang-format on "dialect_class"_a, nb::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); @@ -118,10 +120,10 @@ NB_MODULE(_mlir, m) { return typeCaster; }); }, - nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, " - "replace: bool = False) " - "-> typing.Callable[[typing.Callable[[T], U]], " - "typing.Callable[[T], U]]"), + // clang-format off + nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( @@ -134,10 +136,10 @@ NB_MODULE(_mlir, m) { return valueCaster; }); }, - nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, " - "replace: bool = False) " - "-> typing.Callable[[typing.Callable[[T], U]], " - "typing.Callable[[T], U]]"), + // clang-format off + nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values.");