-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][python] Cache import of ir module in type casters. #160000
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Peter Hawkins (hawkinsp) Changes: In a JAX benchmark that traces a large language model, this change reduces the time spent in nanobind::module_::import_ from 1.2s to 10ms. Full diff: https://github.com/llvm/llvm-project/pull/160000.diff 1 File Affected:
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 8744d8d0e4bca..aeb51542f9b6d 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -19,7 +19,9 @@
#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
+#include <atomic>
#include <cstdint>
+#include <memory>
#include <optional>
#include "mlir-c/Diagnostics.h"
@@ -30,6 +32,56 @@
// clang-format on
#include "llvm/ADT/Twine.h"
+namespace mlir {
+namespace python {
+namespace {
+
+/// Lazily creates a heap-allocated `T` the first time it is requested and
+/// hands back the same object on every later request.
+///
+/// The initializer is invoked while `mu_` is NOT held, so that running Python
+/// initialization code cannot introduce a lock-ordering problem with this
+/// class's own mutex. The price is that the initializer may run more than once
+/// (for example if the GIL is released while it executes); only one result is
+/// published and any extra result is discarded. The intended use is module
+/// imports, which are safe to repeat. The published object is never deleted by
+/// this class.
+template <typename T> class SafeInit {
+public:
+  /// Factory signature: returns a freshly allocated, owned `T`.
+  using F = std::unique_ptr<T> (*)();
+
+  explicit SafeInit(F init_fn) : init_fn_(init_fn) {}
+
+  /// Returns the cached object, creating and publishing it on first use.
+  T &Get() {
+    // Fast path: something has already been published.
+    T *cached = output_.load();
+    if (cached != nullptr)
+      return *cached;
+
+    // Slow path: build a candidate before taking the lock. `fresh` is declared
+    // ahead of the guard on purpose, so that a losing candidate is destroyed
+    // only after the lock has been dropped.
+    std::unique_ptr<T> fresh = init_fn_();
+
+    nanobind::ft_lock_guard guard(mu_);
+    // Re-check under the lock: another caller may have published meanwhile.
+    cached = output_.load();
+    if (cached == nullptr) {
+      // We are first: give up ownership and publish the raw pointer.
+      cached = fresh.release();
+      output_.store(cached);
+    }
+    return *cached;
+  }
+
+private:
+  // Serializes publication into `output_`; never held across `init_fn_`.
+  nanobind::ft_mutex mu_;
+  // The published object, or nullptr until the first successful `Get()`.
+  std::atomic<T *> output_{nullptr};
+  // Factory used to build candidates.
+  F init_fn_;
+};
+
+/// Returns a reference to the imported `ir` Python module (the module named by
+/// MAKE_MLIR_PYTHON_QUALNAME("ir")). The import result is cached in a
+/// function-local static, so repeated calls reuse it instead of importing
+/// again.
+nanobind::module_ &IrModule() {
+  static SafeInit<nanobind::module_> irModuleCache(
+      []() -> std::unique_ptr<nanobind::module_> {
+        return std::make_unique<nanobind::module_>(
+            nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")));
+      });
+  return irModuleCache.Get();
+}
+
+} // namespace
+} // namespace python
+} // namespace mlir
+
// Raw CAPI type casters need to be declared before use, so always include them
// first.
namespace nanobind {
@@ -75,7 +127,7 @@ struct type_caster<MlirAffineMap> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonAffineMapToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("AffineMap")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -97,7 +149,7 @@ struct type_caster<MlirAttribute> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonAttributeToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Attribute")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -128,9 +180,7 @@ struct type_caster<MlirContext> {
// TODO: This raises an error of "No current context" currently.
// Update the implementation to pretty-print the helpful error that the
// core implementations print in this case.
- src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Context")
- .attr("current");
+ src = mlir::python::IrModule().attr("Context").attr("current");
}
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToContext(capsule->ptr());
@@ -153,7 +203,7 @@ struct type_caster<MlirDialectRegistry> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule = nanobind::steal<nanobind::object>(
mlirPythonDialectRegistryToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("DialectRegistry")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -167,9 +217,7 @@ struct type_caster<MlirLocation> {
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (src.is_none()) {
// Gets the current thread-bound context.
- src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Location")
- .attr("current");
+ src = mlir::python::IrModule().attr("Location").attr("current");
}
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToLocation(capsule->ptr());
@@ -181,7 +229,7 @@ struct type_caster<MlirLocation> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonLocationToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Location")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -203,7 +251,7 @@ struct type_caster<MlirModule> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonModuleToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Module")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -250,7 +298,7 @@ struct type_caster<MlirOperation> {
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonOperationToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Operation")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -274,7 +322,7 @@ struct type_caster<MlirValue> {
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonValueToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Value")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -312,7 +360,7 @@ struct type_caster<MlirTypeID> {
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("TypeID")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -334,7 +382,7 @@ struct type_caster<MlirType> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonTypeToCapsule(t));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Type")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -453,11 +501,9 @@ class mlir_attribute_subclass : public pure_subclass {
mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName,
IsAFunctionTy isaFunction,
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
- : mlir_attribute_subclass(
- scope, attrClassName, isaFunction,
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Attribute"),
- getTypeIDFunction) {}
+ : mlir_attribute_subclass(scope, attrClassName, isaFunction,
+ IrModule().attr("Attribute"),
+ getTypeIDFunction) {}
/// Subclasses with a provided mlir.ir.Attribute super-class. This must
/// be used if the subclass is being defined in the same extension module
@@ -540,11 +586,8 @@ class mlir_type_subclass : public pure_subclass {
mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
IsAFunctionTy isaFunction,
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
- : mlir_type_subclass(
- scope, typeClassName, isaFunction,
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Type"),
- getTypeIDFunction) {}
+ : mlir_type_subclass(scope, typeClassName, isaFunction,
+ IrModule().attr("Type"), getTypeIDFunction) {}
/// Subclasses with a provided mlir.ir.Type super-class. This must
/// be used if the subclass is being defined in the same extension module
@@ -631,10 +674,8 @@ class mlir_value_subclass : public pure_subclass {
/// Subclasses by looking up the super-class dynamically.
mlir_value_subclass(nanobind::handle scope, const char *valueClassName,
IsAFunctionTy isaFunction)
- : mlir_value_subclass(
- scope, valueClassName, isaFunction,
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Value")) {}
+ : mlir_value_subclass(scope, valueClassName, isaFunction,
+ IrModule().attr("Value")) {}
/// Subclasses with a provided mlir.ir.Value super-class. This must
/// be used if the subclass is being defined in the same extension module
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo LLVM style nits
In a JAX benchmark that traces a large language model, this change reduces the time spent in nanobind::module_::import_ from 1.2s to 10ms.
All done. Would you please merge? Thanks! |
@hawkinsp I am trying to understand the issue that required this commit. Shouldn't importing a module basically be a no-op once it's imported? Are there things that should be changed in nanobind? |
@wjakob All that is going on here is that these type casters are called many times, perhaps 10^6 times or more in the benchmark. Here's a CPU flame graph from a sampling profiler under CPython 3.12 that shows the problem (not the exact same benchmark as the original, hence the different timing): It might make sense for nanobind to cache imports, or simply to make the CPython import logic faster in the case where we repeatedly make the same import. |
I would like to find out two things:
|
In a JAX benchmark that traces a large language model, this change reduces the time spent in nanobind::module_::import_ from 1.2s to 10ms.