Skip to content

Conversation

hawkinsp
Copy link
Contributor

In a JAX benchmark that traces a large language model, this change reduces the time spent in nanobind::module::import_ from 1.2s to 10ms.

@llvmbot llvmbot added the mlir label Sep 21, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 21, 2025

@llvm/pr-subscribers-mlir

Author: Peter Hawkins (hawkinsp)

Changes

In a JAX benchmark that traces a large language model, this change reduces the time spent in nanobind::module::import_ from 1.2s to 10ms.


Full diff: https://github.com/llvm/llvm-project/pull/160000.diff

1 Files Affected:

  • (modified) mlir/include/mlir/Bindings/Python/NanobindAdaptors.h (+70-29)
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 8744d8d0e4bca..aeb51542f9b6d 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -19,7 +19,9 @@
 #ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
 #define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
 
+#include <atomic>
 #include <cstdint>
+#include <memory>
 #include <optional>
 
 #include "mlir-c/Diagnostics.h"
@@ -30,6 +32,56 @@
 // clang-format on
 #include "llvm/ADT/Twine.h"
 
+namespace mlir {
+namespace python {
+namespace {
+
+// Safely calls Python initialization code on first use, avoiding deadlocks.
+template <typename T> class SafeInit {
+public:
+  typedef std::unique_ptr<T> (*F)();
+
+  explicit SafeInit(F init_fn) : init_fn_(init_fn) {}
+
+  T &Get() {
+    if (T *result = output_.load()) {
+      return *result;
+    }
+
+    // Note: init_fn() may be called multiple times if, for example, the GIL is
+    // released during its execution. The intended use case is for module
+    // imports which are safe to perform multiple times. We are careful not to
+    // hold a lock across init_fn() to avoid lock ordering problems.
+    std::unique_ptr<T> m = init_fn_();
+    {
+      nanobind::ft_lock_guard lock(mu_);
+      if (T *result = output_.load()) {
+        return *result;
+      }
+      T *p = m.release();
+      output_.store(p);
+      return *p;
+    }
+  }
+
+private:
+  nanobind::ft_mutex mu_;
+  std::atomic<T *> output_{nullptr};
+  F init_fn_;
+};
+
+nanobind::module_ &IrModule() {
+  static SafeInit<nanobind::module_> init([]() {
+    return std::make_unique<nanobind::module_>(
+        nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")));
+  });
+  return init.Get();
+}
+
+} // namespace
+} // namespace python
+} // namespace mlir
+
 // Raw CAPI type casters need to be declared before use, so always include them
 // first.
 namespace nanobind {
@@ -75,7 +127,7 @@ struct type_caster<MlirAffineMap> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonAffineMapToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("AffineMap")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -97,7 +149,7 @@ struct type_caster<MlirAttribute> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonAttributeToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Attribute")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -128,9 +180,7 @@ struct type_caster<MlirContext> {
       // TODO: This raises an error of "No current context" currently.
       // Update the implementation to pretty-print the helpful error that the
       // core implementations print in this case.
-      src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Context")
-                .attr("current");
+      src = mlir::python::IrModule().attr("Context").attr("current");
     }
     std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToContext(capsule->ptr());
@@ -153,7 +203,7 @@ struct type_caster<MlirDialectRegistry> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule = nanobind::steal<nanobind::object>(
         mlirPythonDialectRegistryToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("DialectRegistry")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -167,9 +217,7 @@ struct type_caster<MlirLocation> {
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
     if (src.is_none()) {
       // Gets the current thread-bound context.
-      src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Location")
-                .attr("current");
+      src = mlir::python::IrModule().attr("Location").attr("current");
     }
     if (auto capsule = mlirApiObjectToCapsule(src)) {
       value = mlirPythonCapsuleToLocation(capsule->ptr());
@@ -181,7 +229,7 @@ struct type_caster<MlirLocation> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonLocationToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Location")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -203,7 +251,7 @@ struct type_caster<MlirModule> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonModuleToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Module")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -250,7 +298,7 @@ struct type_caster<MlirOperation> {
       return nanobind::none();
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonOperationToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Operation")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -274,7 +322,7 @@ struct type_caster<MlirValue> {
       return nanobind::none();
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonValueToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Value")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -312,7 +360,7 @@ struct type_caster<MlirTypeID> {
       return nanobind::none();
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("TypeID")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -334,7 +382,7 @@ struct type_caster<MlirType> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonTypeToCapsule(t));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Type")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -453,11 +501,9 @@ class mlir_attribute_subclass : public pure_subclass {
   mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName,
                           IsAFunctionTy isaFunction,
                           GetTypeIDFunctionTy getTypeIDFunction = nullptr)
-      : mlir_attribute_subclass(
-            scope, attrClassName, isaFunction,
-            nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Attribute"),
-            getTypeIDFunction) {}
+      : mlir_attribute_subclass(scope, attrClassName, isaFunction,
+                                IrModule().attr("Attribute"),
+                                getTypeIDFunction) {}
 
   /// Subclasses with a provided mlir.ir.Attribute super-class. This must
   /// be used if the subclass is being defined in the same extension module
@@ -540,11 +586,8 @@ class mlir_type_subclass : public pure_subclass {
   mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
                      IsAFunctionTy isaFunction,
                      GetTypeIDFunctionTy getTypeIDFunction = nullptr)
-      : mlir_type_subclass(
-            scope, typeClassName, isaFunction,
-            nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Type"),
-            getTypeIDFunction) {}
+      : mlir_type_subclass(scope, typeClassName, isaFunction,
+                           IrModule().attr("Type"), getTypeIDFunction) {}
 
   /// Subclasses with a provided mlir.ir.Type super-class. This must
   /// be used if the subclass is being defined in the same extension module
@@ -631,10 +674,8 @@ class mlir_value_subclass : public pure_subclass {
   /// Subclasses by looking up the super-class dynamically.
   mlir_value_subclass(nanobind::handle scope, const char *valueClassName,
                       IsAFunctionTy isaFunction)
-      : mlir_value_subclass(
-            scope, valueClassName, isaFunction,
-            nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Value")) {}
+      : mlir_value_subclass(scope, valueClassName, isaFunction,
+                            IrModule().attr("Value")) {}
 
   /// Subclasses with a provided mlir.ir.Value super-class. This must
   /// be used if the subclass is being defined in the same extension module

Copy link

github-actions bot commented Sep 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo LLVM style nits

In a JAX benchmark that traces a large language model, this change
reduces the time spent in nanobind::module::import_ from 1.2s to 10ms.
@hawkinsp
Copy link
Contributor Author

All done. Would you please merge? Thanks!

@jpienaar jpienaar merged commit b1e00f6 into llvm:main Sep 24, 2025
9 checks passed
@wjakob
Copy link
Contributor

wjakob commented Oct 13, 2025

@hawkinsp I am trying to understand the issue that required this commit. Shouldn't importing a module basically be a no-op once it's imported? Are there things that should be changed in nanobind?

@hawkinsp
Copy link
Contributor Author

hawkinsp commented Oct 13, 2025

@wjakob All that is going on here is these type casters are called many times, perhaps 10^6 times or more in the benchmark.

Here's a CPU flame graph from a sampling profiler under CPython 3.12 that shows the problem (not the exact same benchmark as the original, hence the different timing):
image

It might make sense for nanobind to cache imports perhaps, or to simply make the CPython import logic faster in the case that we repeatedly make the same import.

@wjakob
Copy link
Contributor

wjakob commented Oct 14, 2025

I would like to find out two things:

  1. Who is making those expensive nb::module_::import_() calls? Is it a specific type caster in nanobind that is calling nb::module::import_? Is it ndarray? Anything else? Or are the import from MLIR code?

  2. If nanobind was to offer a cached nb::import variant, how should it be implemented. Just keep a nb::dict in the internals to potentially avoid the call to the CPython module import function? Actually, CPython does not really import a module if it is already imported. I would assume that CPython likewise does a dictionary lookup and turns this into a no-op, so we would be duplicating existing functionality. So how can it take so long?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants