Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 71 additions & 29 deletions mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H

#include <atomic>
#include <cstdint>
#include <memory>
#include <optional>

#include "mlir-c/Diagnostics.h"
Expand All @@ -30,6 +32,57 @@
// clang-format on
#include "llvm/ADT/Twine.h"

namespace mlir {
namespace python {
namespace {

// Safely calls Python initialization code on first use, avoiding deadlocks.
template <typename T>
class SafeInit {
public:
typedef std::unique_ptr<T> (*F)();

explicit SafeInit(F init_fn) : initFn(init_fn) {}

T &get() {
if (T *result = output.load()) {
return *result;
}

// Note: init_fn() may be called multiple times if, for example, the GIL is
// released during its execution. The intended use case is for module
// imports which are safe to perform multiple times. We are careful not to
// hold a lock across init_fn() to avoid lock ordering problems.
std::unique_ptr<T> m = initFn();
{
nanobind::ft_lock_guard lock(mu);
if (T *result = output.load()) {
return *result;
}
T *p = m.release();
output.store(p);
return *p;
}
}

private:
nanobind::ft_mutex mu;
std::atomic<T *> output{nullptr};
F initFn;
};

nanobind::module_ &irModule() {
static SafeInit<nanobind::module_> init([]() {
return std::make_unique<nanobind::module_>(
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")));
});
return init.get();
}

} // namespace
} // namespace python
} // namespace mlir

// Raw CAPI type casters need to be declared before use, so always include them
// first.
namespace nanobind {
Expand Down Expand Up @@ -75,7 +128,7 @@ struct type_caster<MlirAffineMap> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonAffineMapToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
return mlir::python::irModule()
.attr("AffineMap")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
Expand All @@ -97,7 +150,7 @@ struct type_caster<MlirAttribute> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonAttributeToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
return mlir::python::irModule()
.attr("Attribute")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
Expand Down Expand Up @@ -128,9 +181,7 @@ struct type_caster<MlirContext> {
// TODO: This raises an error of "No current context" currently.
// Update the implementation to pretty-print the helpful error that the
// core implementations print in this case.
src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Context")
.attr("current");
src = mlir::python::irModule().attr("Context").attr("current");
}
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToContext(capsule->ptr());
Expand All @@ -153,7 +204,7 @@ struct type_caster<MlirDialectRegistry> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule = nanobind::steal<nanobind::object>(
mlirPythonDialectRegistryToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
return mlir::python::irModule()
.attr("DialectRegistry")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
Expand All @@ -167,9 +218,7 @@ struct type_caster<MlirLocation> {
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (src.is_none()) {
// Gets the current thread-bound context.
src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Location")
.attr("current");
src = mlir::python::irModule().attr("Location").attr("current");
}
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToLocation(capsule->ptr());
Expand All @@ -181,7 +230,7 @@ struct type_caster<MlirLocation> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonLocationToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
return mlir::python::irModule()
.attr("Location")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
Expand All @@ -203,7 +252,7 @@ struct type_caster<MlirModule> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonModuleToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
return mlir::python::irModule()
.attr("Module")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
Expand Down Expand Up @@ -250,7 +299,7 @@ struct type_caster<MlirOperation> {
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonOperationToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
return mlir::python::irModule()
.attr("Operation")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
Expand All @@ -274,7 +323,7 @@ struct type_caster<MlirValue> {
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonValueToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
return mlir::python::irModule()
.attr("Value")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
Expand Down Expand Up @@ -312,7 +361,7 @@ struct type_caster<MlirTypeID> {
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
return mlir::python::irModule()
.attr("TypeID")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
Expand All @@ -334,7 +383,7 @@ struct type_caster<MlirType> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonTypeToCapsule(t));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
return mlir::python::irModule()
.attr("Type")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
Expand Down Expand Up @@ -453,11 +502,9 @@ class mlir_attribute_subclass : public pure_subclass {
mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName,
IsAFunctionTy isaFunction,
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: mlir_attribute_subclass(
scope, attrClassName, isaFunction,
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Attribute"),
getTypeIDFunction) {}
: mlir_attribute_subclass(scope, attrClassName, isaFunction,
irModule().attr("Attribute"),
getTypeIDFunction) {}

/// Subclasses with a provided mlir.ir.Attribute super-class. This must
/// be used if the subclass is being defined in the same extension module
Expand Down Expand Up @@ -540,11 +587,8 @@ class mlir_type_subclass : public pure_subclass {
mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
IsAFunctionTy isaFunction,
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: mlir_type_subclass(
scope, typeClassName, isaFunction,
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Type"),
getTypeIDFunction) {}
: mlir_type_subclass(scope, typeClassName, isaFunction,
irModule().attr("Type"), getTypeIDFunction) {}

/// Subclasses with a provided mlir.ir.Type super-class. This must
/// be used if the subclass is being defined in the same extension module
Expand Down Expand Up @@ -631,10 +675,8 @@ class mlir_value_subclass : public pure_subclass {
/// Subclasses by looking up the super-class dynamically.
mlir_value_subclass(nanobind::handle scope, const char *valueClassName,
IsAFunctionTy isaFunction)
: mlir_value_subclass(
scope, valueClassName, isaFunction,
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Value")) {}
: mlir_value_subclass(scope, valueClassName, isaFunction,
irModule().attr("Value")) {}

/// Subclasses with a provided mlir.ir.Value super-class. This must
/// be used if the subclass is being defined in the same extension module
Expand Down