diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h index ca942c83d3e2f..8dc8a0d063d70 100644 --- a/mlir/include/mlir/Bindings/Python/Nanobind.h +++ b/mlir/include/mlir/Bindings/Python/Nanobind.h @@ -30,6 +30,7 @@ #include #include #include +#include #if defined(__clang__) || defined(__GNUC__) #pragma GCC diagnostic pop #endif diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index a14f09f77d2c3..ba767ad6692cf 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -24,6 +24,8 @@ using namespace mlir::python; NB_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; + m.attr("T") = nb::type_var("T"); + m.attr("U") = nb::type_var("U"); nb::class_(m, "_Globals") .def_prop_rw("dialect_search_modules", @@ -102,6 +104,10 @@ NB_MODULE(_mlir, m) { return opClass; }); }, + // clang-format off + nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) " + "-> typing.Callable[[type[T]], type[T]]"), + // clang-format on "dialect_class"_a, nb::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); @@ -114,6 +120,10 @@ NB_MODULE(_mlir, m) { return typeCaster; }); }, + // clang-format off + nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( @@ -126,6 +136,10 @@ NB_MODULE(_mlir, m) { return valueCaster; }); }, + // clang-format off + nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values.");