From e6047d9d3cd85c7d6005563aaa4cfbf2e96edb75 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 8 Dec 2025 15:41:03 +0100 Subject: [PATCH 1/2] [mlir][py] avoid crashing on None contexts in custom `get`s Following a series of refactorings, MLIR Python bindings would crash if a dialect object requiring a context defined using mlir_attribute/type_subclass was constructed outside of the `ir.Context` context manager. The type caster for `MlirContext` would try using `ir.Context.current` when the default `None` value was provided to the `get`, which would also just return `None`. The caster would then attempt to obtain the MLIR capsule for that `None`, fail, but access it anyway without checking, leading to a C++ assertion failure or segfault. Guard against this case in nanobind adaptors. Also emit a warning to the user to clarify expectations, as the default message confusingly says that `None` is accepted as context and then fails with a type error. Using Python C API is currently recommended by nanobind in this case since the surrounding function must be marked `noexcept`. The corresponding test is in the PDL dialect since it is where I first observed the behavior. Core types are not using the `mlir_type_subclass` mechanism and are immune to the problem, so cannot be used for checking. --- .../mlir/Bindings/Python/NanobindAdaptors.h | 20 +++++++++++++------ mlir/test/python/dialects/pdl_types.py | 13 ++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 847951ab5fd46..6594670abaaa7 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -181,14 +181,22 @@ struct type_caster { bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (src.is_none()) { // Gets the current thread-bound context. - // TODO: This raises an error of "No current context" currently. - // Update the implementation to pretty-print the helpful error that the - // core implementations print in this case. src = mlir::python::irModule().attr("Context").attr("current"); } - std::optional capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToContext(capsule->ptr()); - return !mlirContextIsNull(value); + // If there is no context, including thread-bound, emit a warning (since + // this function is not allowed to throw) and fail to cast. + if (src.is_none()) { + PyErr_Warn( + PyExc_RuntimeWarning, + "Passing None as MLIR Context is only allowed inside " + "the " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " context manager."); + return false; + } + if (std::optional capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToContext(capsule->ptr()); + return !mlirContextIsNull(value); + } + return false; } }; diff --git a/mlir/test/python/dialects/pdl_types.py b/mlir/test/python/dialects/pdl_types.py index 16a41e2a4c1ce..bbc6007511721 100644 --- a/mlir/test/python/dialects/pdl_types.py +++ b/mlir/test/python/dialects/pdl_types.py @@ -148,3 +148,16 @@ def test_value_type(): print(parsedType) # CHECK: !pdl.value print(constructedType) + + +# CHECK-LABEL: TEST: test_type_without_context +@run +def test_type_without_context(): + # Constructing a type without the surrounding ir.Context context manager + # should raise an exception but not crash. + try: + constructedType = pdl.ValueType.get() + except TypeError: + pass + else: + assert False, "Expected TypeError to be raised." From deac26450350ba40b9f9357f68ec3a5e458b43d6 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 8 Dec 2025 15:50:41 +0100 Subject: [PATCH 2/2] [mlir][py] partially use mlir_type_subclass for IRTypes.cpp Port the bindings for non-shaped builtin types in IRTypes.cpp to use the `mlir_type_subclass` mechanism used by non-builtin types. This is part of a longer-term cleanup to only support one subclassing mechanism. Eventually, the `PyConcreteType` mechanism will be removed. This required a surgery in the type casters and the `mlir_type_subclass` logic to avoid circular imports of the `_mlir.ir` module that would otherwise when using `mlir_type_subclass` to define classes in the `_mlir.ir` module. Tests are updated to use the `.get_static_typeid()` function instead of the `.static_typeid` property that was specific to builtin types due to the `PyConcreteType` mechanism. The change should be NFC otherwise. --- .../mlir/Bindings/Python/NanobindAdaptors.h | 41 +- mlir/lib/Bindings/Python/IRTypes.cpp | 1029 ++++++----------- mlir/lib/Bindings/Python/MainModule.cpp | 15 + mlir/test/python/dialects/arith_dialect.py | 8 +- mlir/test/python/ir/builtin_types.py | 11 +- mlir/test/python/ir/value.py | 6 +- 6 files changed, 425 insertions(+), 685 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 6594670abaaa7..f678f57527e97 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -371,16 +371,22 @@ struct type_caster { } return false; } - static handle from_cpp(MlirTypeID v, rv_policy, - cleanup_list *cleanup) noexcept { + + static handle + from_cpp_given_module(MlirTypeID v, + const nanobind::module_ &module) noexcept { if (v.ptr == nullptr) return nanobind::none(); nanobind::object capsule = nanobind::steal(mlirPythonTypeIDToCapsule(v)); - return mlir::python::irModule() - .attr("TypeID") + return module.attr("TypeID") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); + } + + static handle from_cpp(MlirTypeID v, rv_policy, + cleanup_list *cleanup) noexcept { + return from_cpp_given_module(v, mlir::python::irModule()); }; }; @@ -602,9 +608,12 @@ class mlir_type_subclass : public pure_subclass { /// Subclasses by looking up the super-class dynamically. mlir_type_subclass(nanobind::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, - GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : mlir_type_subclass(scope, typeClassName, isaFunction, - irModule().attr("Type"), getTypeIDFunction) {} + GetTypeIDFunctionTy getTypeIDFunction = nullptr, + const nanobind::module_ *mlirIrModule = nullptr) + : mlir_type_subclass( + scope, typeClassName, isaFunction, + (mlirIrModule != nullptr ? *mlirIrModule : irModule()).attr("Type"), + getTypeIDFunction, mlirIrModule) {} /// Subclasses with a provided mlir.ir.Type super-class. This must /// be used if the subclass is being defined in the same extension module @@ -613,7 +622,8 @@ class mlir_type_subclass : public pure_subclass { mlir_type_subclass(nanobind::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, const nanobind::object &superCls, - GetTypeIDFunctionTy getTypeIDFunction = nullptr) + GetTypeIDFunctionTy getTypeIDFunction = nullptr, + const nanobind::module_ *mlirIrModule = nullptr) : pure_subclass(scope, typeClassName, superCls) { // Casting constructor. Note that it is hard, if not impossible, to properly // call chain to parent `__init__` in nanobind due to its special handling @@ -672,9 +682,18 @@ class mlir_type_subclass : public pure_subclass { nanobind::sig("def get_static_typeid() -> " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID")) // clang-format on ); - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - getTypeIDFunction())(nanobind::cpp_function( + + // Directly call the caster implementation given the "ir" module, + // otherwise it may trigger recursive import as the default caster + // attempts to import the "ir" module. + MlirTypeID typeID = getTypeIDFunction(); + mlirIrModule = mlirIrModule ? mlirIrModule : &irModule(); + nanobind::handle pyTypeID = + nanobind::detail::type_caster::from_cpp_given_module( + typeID, *mlirIrModule); + + mlirIrModule->attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(pyTypeID)( + nanobind::cpp_function( [thisClass = thisClass](const nanobind::object &mlirType) { return thisClass(mlirType); })); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 34c5b8dd86a66..2e4090c358c47 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -18,13 +18,13 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace mlir; using namespace mlir::python; using llvm::SmallVector; -using llvm::Twine; namespace { @@ -34,480 +34,368 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) { mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); } -class PyIntegerType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirIntegerTypeGetTypeID; - static constexpr const char *pyClassName = "IntegerType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_signless", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context") = nb::none(), - "Create a signless integer type"); - c.def_static( - "get_signed", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeSignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context") = nb::none(), - "Create a signed integer type"); - c.def_static( - "get_unsigned", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context") = nb::none(), - "Create an unsigned integer type"); - c.def_prop_ro( - "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, - "Returns the width of the integer type"); - c.def_prop_ro( - "is_signless", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self); - }, - "Returns whether this is a signless integer"); - c.def_prop_ro( - "is_signed", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self); - }, - "Returns whether this is a signed integer"); - c.def_prop_ro( - "is_unsigned", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsUnsigned(self); - }, - "Returns whether this is an unsigned integer"); - } -}; - -/// Index Type subclass - IndexType. -class PyIndexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirIndexTypeGetTypeID; - static constexpr const char *pyClassName = "IndexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirIndexTypeGet(context->get()); - return PyIndexType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a index type."); - } -}; - -class PyFloatType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; - static constexpr const char *pyClassName = "FloatType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_prop_ro( - "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, - "Returns the width of the floating-point type"); - } -}; - -/// Floating Point Type subclass - Float4E2M1FNType. -class PyFloat4E2M1FNType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat4E2M1FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float4E2M1FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); - return PyFloat4E2M1FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float4_e2m1fn type."); - } -}; - -/// Floating Point Type subclass - Float6E2M3FNType. -class PyFloat6E2M3FNType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat6E2M3FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float6E2M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); - return PyFloat6E2M3FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float6_e2m3fn type."); - } -}; - -/// Floating Point Type subclass - Float6E3M2FNType. -class PyFloat6E3M2FNType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat6E3M2FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float6E3M2FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); - return PyFloat6E3M2FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float6_e3m2fn type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNType. -class PyFloat8E4M3FNType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); - return PyFloat8E4M3FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3fn type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2Type. -class PyFloat8E5M2Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E5M2TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E5M2Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2TypeGet(context->get()); - return PyFloat8E5M2Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e5m2 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3Type. -class PyFloat8E4M3Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3TypeGet(context->get()); - return PyFloat8E4M3Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNUZ. -class PyFloat8E4M3FNUZType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); - return PyFloat8E4M3FNUZType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3B11FNUZ. -class PyFloat8E4M3B11FNUZType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3B11FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); - return PyFloat8E4M3B11FNUZType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2FNUZ. -class PyFloat8E5M2FNUZType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E5M2FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E5M2FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); - return PyFloat8E5M2FNUZType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E3M4Type. -class PyFloat8E3M4Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E3M4TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E3M4Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E3M4TypeGet(context->get()); - return PyFloat8E3M4Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e3m4 type."); - } -}; - -/// Floating Point Type subclass - Float8E8M0FNUType. -class PyFloat8E8M0FNUType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E8M0FNUTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E8M0FNUType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); - return PyFloat8E8M0FNUType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type."); - } -}; - -/// Floating Point Type subclass - BF16Type. -class PyBF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirBFloat16TypeGetTypeID; - static constexpr const char *pyClassName = "BF16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirBF16TypeGet(context->get()); - return PyBF16Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a bf16 type."); - } -}; - -/// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat16TypeGetTypeID; - static constexpr const char *pyClassName = "F16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF16TypeGet(context->get()); - return PyF16Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a f16 type."); - } -}; - -/// Floating Point Type subclass - TF32Type. -class PyTF32Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloatTF32TypeGetTypeID; - static constexpr const char *pyClassName = "FloatTF32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirTF32TypeGet(context->get()); - return PyTF32Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a tf32 type."); - } -}; - -/// Floating Point Type subclass - F32Type. -class PyF32Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat32TypeGetTypeID; - static constexpr const char *pyClassName = "F32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF32TypeGet(context->get()); - return PyF32Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a f32 type."); - } -}; +static void populateIRTypesModule(const nanobind::module_ &m) { + using namespace nanobind_adaptors; -/// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat64TypeGetTypeID; - static constexpr const char *pyClassName = "F64Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF64TypeGet(context->get()); - return PyF64Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a f64 type."); - } -}; - -/// None Type subclass - NoneType. -class PyNoneType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirNoneTypeGetTypeID; - static constexpr const char *pyClassName = "NoneType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirNoneTypeGet(context->get()); - return PyNoneType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a none type."); - } -}; - -/// Complex Type subclass - ComplexType. -class PyComplexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirComplexTypeGetTypeID; - static constexpr const char *pyClassName = "ComplexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType) { - // The element must be a floating point or integer scalar type. - if (mlirTypeIsAIntegerOrFloat(elementType)) { - MlirType t = mlirComplexTypeGet(elementType); - return PyComplexType(elementType.getContext(), t); - } - throw nb::value_error( - (Twine("invalid '") + - nb::cast(nb::repr(nb::cast(elementType))) + - "' and expected floating point or integer type.") - .str() - .c_str()); - }, - "Create a complex type"); - c.def_prop_ro( - "element_type", - [](PyComplexType &self) -> nb::typed { - return PyType(self.getContext(), mlirComplexTypeGetElementType(self)) - .maybeDownCast(); - }, - "Returns element type."); - } -}; + mlir_type_subclass integerType(m, "IntegerType", mlirTypeIsAInteger, + mlirIntegerTypeGetTypeID, &m); + integerType.def_classmethod( + "get_signless", + [](const nb::object &cls, unsigned width, MlirContext ctx) { + return cls(mlirIntegerTypeGet(ctx, width)); + }, + nb::arg("cls"), nb::arg("width"), nb::arg("context") = nb::none(), + "Create a signless integer type"); + integerType.def_classmethod( + "get_signed", + [](const nb::object &cls, unsigned width, MlirContext ctx) { + return cls(mlirIntegerTypeSignedGet(ctx, width)); + }, + nb::arg("cls"), nb::arg("width"), nb::arg("context") = nb::none(), + "Create a signed integer type"); + integerType.def_classmethod( + "get_unsigned", + [](const nb::object &cls, unsigned width, MlirContext ctx) { + return cls(mlirIntegerTypeUnsignedGet(ctx, width)); + }, + nb::arg("cls"), nb::arg("width"), nb::arg("context") = nb::none(), + "Create an unsigned integer type"); + integerType.def_property_readonly( + "width", [](MlirType self) { return mlirIntegerTypeGetWidth(self); }, + "Returns the width of the integer type"); + integerType.def_property_readonly( + "is_signless", + [](MlirType self) { return mlirIntegerTypeIsSignless(self); }, + "Returns whether this is a signless integer"); + integerType.def_property_readonly( + "is_signed", [](MlirType self) { return mlirIntegerTypeIsSigned(self); }, + "Returns whether this is a signed integer"); + integerType.def_property_readonly( + "is_unsigned", + [](MlirType self) { return mlirIntegerTypeIsUnsigned(self); }, + "Returns whether this is an unsigned integer"); + + // IndexType + mlir_type_subclass indexType(m, "IndexType", mlirTypeIsAIndex, + mlirIndexTypeGetTypeID, &m); + + indexType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirIndexTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), "Create a index type."); + + // FloatType (base class for specific float types) + mlir_type_subclass floatType(m, "FloatType", mlirTypeIsAFloat, nullptr, &m); + floatType.def_property_readonly( + "width", [](MlirType self) { return mlirFloatTypeGetWidth(self); }, + "Returns the width of the floating-point type"); + + // Float4E2M1FNType + mlir_type_subclass float4E2M1FNType( + m, "Float4E2M1FNType", mlirTypeIsAFloat4E2M1FN, floatType.get_class(), + mlirFloat4E2M1FNTypeGetTypeID, &m); + float4E2M1FNType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat4E2M1FNTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float4_e2m1fn type."); + + // Float6E2M3FNType + mlir_type_subclass float6E2M3FNType( + m, "Float6E2M3FNType", mlirTypeIsAFloat6E2M3FN, floatType.get_class(), + mlirFloat6E2M3FNTypeGetTypeID, &m); + float6E2M3FNType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat6E2M3FNTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float6_e2m3fn type."); + + // Float6E3M2FNType + mlir_type_subclass float6E3M2FNType( + m, "Float6E3M2FNType", mlirTypeIsAFloat6E3M2FN, floatType.get_class(), + mlirFloat6E3M2FNTypeGetTypeID, &m); + float6E3M2FNType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat6E3M2FNTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float6_e3m2fn type."); + + // Float8E4M3FNType + mlir_type_subclass float8E4M3FNType( + m, "Float8E4M3FNType", mlirTypeIsAFloat8E4M3FN, floatType.get_class(), + mlirFloat8E4M3FNTypeGetTypeID, &m); + float8E4M3FNType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat8E4M3FNTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float8_e4m3fn type."); + + // Float8E5M2Type + mlir_type_subclass float8E5M2Type(m, "Float8E5M2Type", mlirTypeIsAFloat8E5M2, + floatType.get_class(), + mlirFloat8E5M2TypeGetTypeID, &m); + float8E5M2Type.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat8E5M2TypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float8_e5m2 type."); + + // Float8E4M3Type + mlir_type_subclass float8E4M3Type(m, "Float8E4M3Type", mlirTypeIsAFloat8E4M3, + floatType.get_class(), + mlirFloat8E4M3TypeGetTypeID, &m); + float8E4M3Type.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat8E4M3TypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float8_e4m3 type."); + + // Float8E4M3FNUZType + mlir_type_subclass float8E4M3FNUZType( + m, "Float8E4M3FNUZType", mlirTypeIsAFloat8E4M3FNUZ, floatType.get_class(), + mlirFloat8E4M3FNUZTypeGetTypeID, &m); + float8E4M3FNUZType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat8E4M3FNUZTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float8_e4m3fnuz type."); + + // Float8E4M3B11FNUZType + mlir_type_subclass float8E4M3B11FNUZType( + m, "Float8E4M3B11FNUZType", mlirTypeIsAFloat8E4M3B11FNUZ, + floatType.get_class(), mlirFloat8E4M3B11FNUZTypeGetTypeID, &m); + float8E4M3B11FNUZType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat8E4M3B11FNUZTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float8_e4m3b11fnuz type."); + + // Float8E5M2FNUZType + mlir_type_subclass float8E5M2FNUZType( + m, "Float8E5M2FNUZType", mlirTypeIsAFloat8E5M2FNUZ, floatType.get_class(), + mlirFloat8E5M2FNUZTypeGetTypeID, &m); + float8E5M2FNUZType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat8E5M2FNUZTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float8_e5m2fnuz type."); + + // Float8E3M4Type + mlir_type_subclass float8E3M4Type(m, "Float8E3M4Type", mlirTypeIsAFloat8E3M4, + floatType.get_class(), + mlirFloat8E3M4TypeGetTypeID, &m); + float8E3M4Type.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat8E3M4TypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float8_e3m4 type."); + + // Float8E8M0FNUType + mlir_type_subclass float8E8M0FNUType( + m, "Float8E8M0FNUType", mlirTypeIsAFloat8E8M0FNU, floatType.get_class(), + mlirFloat8E8M0FNUTypeGetTypeID, &m); + float8E8M0FNUType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirFloat8E8M0FNUTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), + "Create a float8_e8m0fnu type."); + + // BF16Type + mlir_type_subclass bf16Type(m, "BF16Type", mlirTypeIsABF16, + floatType.get_class(), mlirBFloat16TypeGetTypeID, + &m); + bf16Type.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirBF16TypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), "Create a bf16 type."); + + // F16Type + mlir_type_subclass f16Type(m, "F16Type", mlirTypeIsAF16, + floatType.get_class(), mlirFloat16TypeGetTypeID, + &m); + f16Type.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirF16TypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), "Create a f16 type."); + + // FloatTF32Type + mlir_type_subclass tf32Type(m, "FloatTF32Type", mlirTypeIsATF32, + floatType.get_class(), mlirFloatTF32TypeGetTypeID, + &m); + tf32Type.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirTF32TypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), "Create a tf32 type."); + + // F32Type + mlir_type_subclass f32Type(m, "F32Type", mlirTypeIsAF32, + floatType.get_class(), mlirFloat32TypeGetTypeID, + &m); + f32Type.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirF32TypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), "Create a f32 type."); + + // F64Type + mlir_type_subclass f64Type(m, "F64Type", mlirTypeIsAF64, + floatType.get_class(), mlirFloat64TypeGetTypeID, + &m); + f64Type.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirF64TypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), "Create a f64 type."); + + // NoneType + mlir_type_subclass noneType(m, "NoneType", mlirTypeIsANone, + mlirNoneTypeGetTypeID, &m); + noneType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirNoneTypeGet(ctx)); + }, + nb::arg("cls"), nb::arg("context") = nb::none(), "Create a none type."); + + // ComplexType + mlir_type_subclass complexType(m, "ComplexType", mlirTypeIsAComplex, + mlirComplexTypeGetTypeID, &m); + complexType.def_classmethod( + "get", + [](const nb::object &cls, MlirType elementType) { + // The element must be a floating point or integer scalar type. + if (mlirTypeIsAIntegerOrFloat(elementType)) { + return cls(mlirComplexTypeGet(elementType)); + } + throw nb::value_error("Invalid element type for ComplexType: expected " + "floating point or integer type."); + }, + "Create a complex type"); + complexType.def_property_readonly( + "element_type", + [](MlirType self) { return mlirComplexTypeGetElementType(self); }, + "Returns element type."); + + // TupleType + mlir_type_subclass tupleType(m, "TupleType", mlirTypeIsATuple, + mlirTupleTypeGetTypeID, &m); + tupleType.def_classmethod( + "get_tuple", + [](const nb::object &cls, std::vector elements, + MlirContext ctx) { + return cls(mlirTupleTypeGet(ctx, elements.size(), elements.data())); + }, + nb::arg("cls"), nb::arg("elements"), nb::arg("context") = nb::none(), + "Create a tuple type"); + tupleType.def( + "get_type", + [](MlirType self, intptr_t pos) { + return mlirTupleTypeGetType(self, pos); + }, + nb::arg("pos"), "Returns the pos-th type in the tuple type."); + tupleType.def_property_readonly( + "num_types", [](MlirType self) { return mlirTupleTypeGetNumTypes(self); }, + "Returns the number of types contained in a tuple."); + + // FunctionType + mlir_type_subclass functionType(m, "FunctionType", mlirTypeIsAFunction, + mlirFunctionTypeGetTypeID, &m); + functionType.def_classmethod( + "get", + [](const nb::object &cls, std::vector inputs, + std::vector results, MlirContext ctx) { + return cls(mlirFunctionTypeGet(ctx, inputs.size(), inputs.data(), + results.size(), results.data())); + }, + nb::arg("cls"), nb::arg("inputs"), nb::arg("results"), + nb::arg("context") = nb::none(), + "Gets a FunctionType from a list of input and result types"); + functionType.def_property_readonly( + "inputs", + [](MlirType self) { + nb::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetInput(self, i)); + } + return types; + }, + "Returns the list of input types in the FunctionType."); + functionType.def_property_readonly( + "results", + [](MlirType self) { + nb::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetResult(self, i)); + } + return types; + }, + "Returns the list of result types in the FunctionType."); + + // OpaqueType + mlir_type_subclass opaqueType(m, "OpaqueType", mlirTypeIsAOpaque, + mlirOpaqueTypeGetTypeID, &m); + opaqueType.def_classmethod( + "get", + [](const nb::object &cls, const std::string &dialectNamespace, + const std::string &typeData, MlirContext ctx) { + MlirStringRef dialectNs = mlirStringRefCreate(dialectNamespace.data(), + dialectNamespace.size()); + MlirStringRef data = + mlirStringRefCreate(typeData.data(), typeData.size()); + return cls(mlirOpaqueTypeGet(ctx, dialectNs, data)); + }, + nb::arg("cls"), nb::arg("dialect_namespace"), nb::arg("buffer"), + nb::arg("context") = nb::none(), + "Create an unregistered (opaque) dialect type."); + opaqueType.def_property_readonly( + "dialect_namespace", + [](MlirType self) { + MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque type as a string."); + opaqueType.def_property_readonly( + "data", + [](MlirType self) { + MlirStringRef stringRef = mlirOpaqueTypeGetData(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaque type as a string."); +} } // namespace @@ -977,202 +865,17 @@ class PyUnrankedMemRefType } }; -/// Tuple Type subclass - TupleType. -class PyTupleType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirTupleTypeGetTypeID; - static constexpr const char *pyClassName = "TupleType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_tuple", - [](const std::vector &elements, - DefaultingPyMlirContext context) { - std::vector mlirElements; - mlirElements.reserve(elements.size()); - for (const auto &element : elements) - mlirElements.push_back(element.get()); - MlirType t = mlirTupleTypeGet(context->get(), elements.size(), - mlirElements.data()); - return PyTupleType(context->getRef(), t); - }, - nb::arg("elements"), nb::arg("context") = nb::none(), - "Create a tuple type"); - c.def_static( - "get_tuple", - [](std::vector elements, DefaultingPyMlirContext context) { - MlirType t = mlirTupleTypeGet(context->get(), elements.size(), - elements.data()); - return PyTupleType(context->getRef(), t); - }, - nb::arg("elements"), nb::arg("context") = nb::none(), - // clang-format off - nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"), - // clang-format on - "Create a tuple type"); - c.def( - "get_type", - [](PyTupleType &self, intptr_t pos) -> nb::typed { - return PyType(self.getContext(), mlirTupleTypeGetType(self, pos)) - .maybeDownCast(); - }, - nb::arg("pos"), "Returns the pos-th type in the tuple type."); - c.def_prop_ro( - "num_types", - [](PyTupleType &self) -> intptr_t { - return mlirTupleTypeGetNumTypes(self); - }, - "Returns the number of types contained in a tuple."); - } -}; - -/// Function type. -class PyFunctionType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFunctionTypeGetTypeID; - static constexpr const char *pyClassName = "FunctionType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector inputs, std::vector results, - DefaultingPyMlirContext context) { - std::vector mlirInputs; - mlirInputs.reserve(inputs.size()); - for (const auto &input : inputs) - mlirInputs.push_back(input.get()); - std::vector mlirResults; - mlirResults.reserve(results.size()); - for (const auto &result : results) - mlirResults.push_back(result.get()); - - MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(), - mlirInputs.data(), results.size(), - mlirResults.data()); - return PyFunctionType(context->getRef(), t); - }, - nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(), - "Gets a FunctionType from a list of input and result types"); - c.def_static( - "get", - [](std::vector inputs, std::vector results, - DefaultingPyMlirContext context) { - MlirType t = - mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), - results.size(), results.data()); - return PyFunctionType(context->getRef(), t); - }, - nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(), - // clang-format off - nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"), - // clang-format on - "Gets a FunctionType from a list of input and result types"); - c.def_prop_ro( - "inputs", - [](PyFunctionType &self) { - MlirType t = self; - nb::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; - ++i) { - types.append(mlirFunctionTypeGetInput(t, i)); - } - return types; - }, - "Returns the list of input types in the FunctionType."); - c.def_prop_ro( - "results", - [](PyFunctionType &self) { - nb::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; - ++i) { - types.append(mlirFunctionTypeGetResult(self, i)); - } - return types; - }, - "Returns the list of result types in the FunctionType."); - } -}; - -static MlirStringRef toMlirStringRef(const std::string &s) { - return mlirStringRefCreate(s.data(), s.size()); -} - -/// Opaque Type subclass - OpaqueType. -class PyOpaqueType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirOpaqueTypeGetTypeID; - static constexpr const char *pyClassName = "OpaqueType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](const std::string &dialectNamespace, const std::string &typeData, - DefaultingPyMlirContext context) { - MlirType type = mlirOpaqueTypeGet(context->get(), - toMlirStringRef(dialectNamespace), - toMlirStringRef(typeData)); - return PyOpaqueType(context->getRef(), type); - }, - nb::arg("dialect_namespace"), nb::arg("buffer"), - nb::arg("context") = nb::none(), - "Create an unregistered (opaque) dialect type."); - c.def_prop_ro( - "dialect_namespace", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the dialect namespace for the Opaque type as a string."); - c.def_prop_ro( - "data", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetData(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the data for the Opaque type as a string."); - } -}; - } // namespace void mlir::python::populateIRTypes(nb::module_ &m) { - PyIntegerType::bind(m); - PyFloatType::bind(m); - PyIndexType::bind(m); - PyFloat4E2M1FNType::bind(m); - PyFloat6E2M3FNType::bind(m); - PyFloat6E3M2FNType::bind(m); - PyFloat8E4M3FNType::bind(m); - PyFloat8E5M2Type::bind(m); - PyFloat8E4M3Type::bind(m); - PyFloat8E4M3FNUZType::bind(m); - PyFloat8E4M3B11FNUZType::bind(m); - PyFloat8E5M2FNUZType::bind(m); - PyFloat8E3M4Type::bind(m); - PyFloat8E8M0FNUType::bind(m); - PyBF16Type::bind(m); - PyF16Type::bind(m); - PyTF32Type::bind(m); - PyF32Type::bind(m); - PyF64Type::bind(m); - PyNoneType::bind(m); - PyComplexType::bind(m); + // Populate types using mlir_type_subclass + populateIRTypesModule(m); + + // Keep PyShapedType and its subclasses that weren't replaced PyShapedType::bind(m); PyVectorType::bind(m); PyRankedTensorType::bind(m); PyUnrankedTensorType::bind(m); PyMemRefType::bind(m); PyUnrankedMemRefType::bind(m); - PyTupleType::bind(m); - PyFunctionType::bind(m); - PyOpaqueType::bind(m); } diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index ba767ad6692cf..fb73beda4cf88 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -145,6 +145,21 @@ NB_MODULE(_mlir, m) { // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); + irModule.def( + MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function([mlirTypeID, replace]( + nb::callable typeCaster) -> nb::object { + PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); + return typeCaster; + }); + }, + // clang-format off + nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on + "typeid"_a, nb::kw_only(), "replace"_a = false, + "Register a type caster for casting MLIR types to custom user types."); populateIRCore(irModule); populateIRAffine(irModule); populateIRAttributes(irModule); diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index c9af5e7b46db8..ad318238b77c6 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -54,10 +54,10 @@ def _binary_op(lhs, rhs, op: str) -> "ArithValue": op = getattr(arith, f"{op}Op") return op(lhs, rhs).result - @register_value_caster(F16Type.static_typeid) - @register_value_caster(F32Type.static_typeid) - @register_value_caster(F64Type.static_typeid) - @register_value_caster(IntegerType.static_typeid) + @register_value_caster(F16Type.get_static_typeid()) + @register_value_caster(F32Type.get_static_typeid()) + @register_value_caster(F64Type.get_static_typeid()) + @register_value_caster(IntegerType.get_static_typeid()) class ArithValue(Value): def __init__(self, v): super().__init__(v) diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py index 54863253fc770..20509050eda9f 100644 --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -185,7 +185,7 @@ def testStandardTypeCasts(): try: tillegal = IntegerType(Type.parse("f32", ctx)) except ValueError as e: - # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32)) + # CHECK: ValueError: Cannot cast type to IntegerType (from F32Type(f32)) print("ValueError:", e) else: print("Exception not produced") @@ -302,7 +302,7 @@ def testComplexType(): try: complex_invalid = ComplexType.get(index) except ValueError as e: - # CHECK: invalid 'Type(index)' and expected floating point or integer type. + # CHECK: Invalid element type for ComplexType: expected floating point or integer type. print(e) else: print("Exception not produced") @@ -714,7 +714,8 @@ def testTypeIDs(): # mlirTypeGetTypeID(self) for an instance. # CHECK: all equal for t1, t2 in types: - tid1, tid2 = t1.static_typeid, Type(t2).typeid + # TODO: remove the alternative once mlir_type_subclass transition is complete. + tid1, tid2 = t1.static_typeid if hasattr(t1, "static_typeid") else t1.get_static_typeid(), Type(t2).typeid assert tid1 == tid2 and hash(tid1) == hash( tid2 ), f"expected hash and value equality {t1} {t2}" @@ -728,7 +729,9 @@ def testTypeIDs(): # CHECK: all equal for t1, t2 in typeid_dict.items(): - assert t1.static_typeid == t2.typeid and hash(t1.static_typeid) == hash( + # TODO: remove the alternative once mlir_type_subclass transition is complete. + tid1 = t1.static_typeid if hasattr(t1, "static_typeid") else t1.get_static_typeid() + assert tid1 == t2.typeid and hash(tid1) == hash( t2.typeid ), f"expected hash and value equality {t1} {t2}" else: diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 4a241afb8e89d..9d9b6c2090974 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -361,7 +361,7 @@ def __init__(self, v): def __str__(self): return super().__str__().replace(Value.__name__, NOPBlockArg.__name__) - @register_value_caster(IntegerType.static_typeid) + @register_value_caster(IntegerType.get_static_typeid()) def cast_int(v) -> Value: print("in caster", v.__class__.__name__) if isinstance(v, OpResult): @@ -425,7 +425,7 @@ def reduction(arg0, arg1): try: - @register_value_caster(IntegerType.static_typeid) + @register_value_caster(IntegerType.get_static_typeid()) def dont_cast_int_shouldnt_register(v): ... @@ -433,7 +433,7 @@ def dont_cast_int_shouldnt_register(v): # CHECK: Value caster is already registered: {{.*}}cast_int print(e) - @register_value_caster(IntegerType.static_typeid, replace=True) + @register_value_caster(IntegerType.get_static_typeid(), replace=True) def dont_cast_int(v) -> OpResult: assert isinstance(v, OpResult) print("don't cast", v.result_number, v)