From 84812a54b88475da9d19f0aec15cffe4694be781 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 21 Sep 2025 19:17:23 +0000 Subject: [PATCH] [mlir][python] Cache import of ir module in type casters. In a JAX benchmark that traces a large language model, this change reduces the time spent in nanobind::module::import_ from 1.2s to 10ms. --- .../mlir/Bindings/Python/NanobindAdaptors.h | 100 +++++++++++++----- 1 file changed, 71 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 8744d8d0e4bca..b5f985f803de6 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -19,7 +19,9 @@ #ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H #define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H +#include #include +#include #include #include "mlir-c/Diagnostics.h" @@ -30,6 +32,57 @@ // clang-format on #include "llvm/ADT/Twine.h" +namespace mlir { +namespace python { +namespace { + +// Safely calls Python initialization code on first use, avoiding deadlocks. +template +class SafeInit { +public: + typedef std::unique_ptr (*F)(); + + explicit SafeInit(F init_fn) : initFn(init_fn) {} + + T &get() { + if (T *result = output.load()) { + return *result; + } + + // Note: init_fn() may be called multiple times if, for example, the GIL is + // released during its execution. The intended use case is for module + // imports which are safe to perform multiple times. We are careful not to + // hold a lock across init_fn() to avoid lock ordering problems. + std::unique_ptr m = initFn(); + { + nanobind::ft_lock_guard lock(mu); + if (T *result = output.load()) { + return *result; + } + T *p = m.release(); + output.store(p); + return *p; + } + } + +private: + nanobind::ft_mutex mu; + std::atomic output{nullptr}; + F initFn; +}; + +nanobind::module_ &irModule() { + static SafeInit init([]() { + return std::make_unique( + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))); + }); + return init.get(); +} + +} // namespace +} // namespace python +} // namespace mlir + // Raw CAPI type casters need to be declared before use, so always include them // first. namespace nanobind { @@ -75,7 +128,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonAffineMapToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("AffineMap") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -97,7 +150,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonAttributeToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Attribute") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() @@ -128,9 +181,7 @@ struct type_caster { // TODO: This raises an error of "No current context" currently. // Update the implementation to pretty-print the helpful error that the // core implementations print in this case. - src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Context") - .attr("current"); + src = mlir::python::irModule().attr("Context").attr("current"); } std::optional capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToContext(capsule->ptr()); @@ -153,7 +204,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal( mlirPythonDialectRegistryToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("DialectRegistry") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -167,9 +218,7 @@ struct type_caster { bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (src.is_none()) { // Gets the current thread-bound context. - src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Location") - .attr("current"); + src = mlir::python::irModule().attr("Location").attr("current"); } if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToLocation(capsule->ptr()); @@ -181,7 +230,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonLocationToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Location") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -203,7 +252,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonModuleToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Module") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -250,7 +299,7 @@ struct type_caster { return nanobind::none(); nanobind::object capsule = nanobind::steal(mlirPythonOperationToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Operation") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -274,7 +323,7 @@ struct type_caster { return nanobind::none(); nanobind::object capsule = nanobind::steal(mlirPythonValueToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Value") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() @@ -312,7 +361,7 @@ struct type_caster { return nanobind::none(); nanobind::object capsule = nanobind::steal(mlirPythonTypeIDToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("TypeID") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -334,7 +383,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonTypeToCapsule(t)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Type") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() @@ -453,11 +502,9 @@ class mlir_attribute_subclass : public pure_subclass { mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : mlir_attribute_subclass( - scope, attrClassName, isaFunction, - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Attribute"), - getTypeIDFunction) {} + : mlir_attribute_subclass(scope, attrClassName, isaFunction, + irModule().attr("Attribute"), + getTypeIDFunction) {} /// Subclasses with a provided mlir.ir.Attribute super-class. This must /// be used if the subclass is being defined in the same extension module @@ -540,11 +587,8 @@ class mlir_type_subclass : public pure_subclass { mlir_type_subclass(nanobind::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : mlir_type_subclass( - scope, typeClassName, isaFunction, - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Type"), - getTypeIDFunction) {} + : mlir_type_subclass(scope, typeClassName, isaFunction, + irModule().attr("Type"), getTypeIDFunction) {} /// Subclasses with a provided mlir.ir.Type super-class. This must /// be used if the subclass is being defined in the same extension module @@ -631,10 +675,8 @@ class mlir_value_subclass : public pure_subclass { /// Subclasses by looking up the super-class dynamically. mlir_value_subclass(nanobind::handle scope, const char *valueClassName, IsAFunctionTy isaFunction) - : mlir_value_subclass( - scope, valueClassName, isaFunction, - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Value")) {} + : mlir_value_subclass(scope, valueClassName, isaFunction, + irModule().attr("Value")) {} /// Subclasses with a provided mlir.ir.Value super-class. This must /// be used if the subclass is being defined in the same extension module