-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][py] partially use mlir_type_subclass for IRTypes.cpp #171143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Following a series of refactorings, MLIR Python bindings would crash if a dialect object requiring a context defined using mlir_attribute/type_subclass was constructed outside of the `ir.Context` context manager. The type caster for `MlirContext` would try using `ir.Context.current` when the default `None` value was provided to the `get`, which would also just return `None`. The caster would then attempt to obtain the MLIR capsule for that `None`, fail, but access it anyway without checking, leading to a C++ assertion failure or segfault. Guard against this case in nanobind adaptors. Also emit a warning to the user to clarify expectations, as the default message confusingly says that `None` is accepted as context and then fails with a type error. Using Python C API is currently recommended by nanobind in this case since the surrounding function must be marked `noexcept`. The corresponding test is in the PDL dialect since it is where I first observed the behavior. Core types are not using the `mlir_type_subclass` mechanism and are immune to the problem, so cannot be used for checking.
Port the bindings for non-shaped builtin types in IRTypes.cpp to use the `mlir_type_subclass` mechanism used by non-builtin types. This is part of a longer-term cleanup to only support one subclassing mechanism. Eventually, the `PyConcreteType` mechanism will be removed. This required a surgery in the type casters and the `mlir_type_subclass` logic to avoid circular imports of the `_mlir.ir` module that would otherwise when using `mlir_type_subclass` to define classes in the `_mlir.ir` module. Tests are updated to use the `.get_static_typeid()` function instead of the `.static_typeid` property that was specific to builtin types due to the `PyConcreteType` mechanism. The change should be NFC otherwise.
|
@llvm/pr-subscribers-mlir Author: Oleksandr "Alex" Zinenko (ftynse) ChangesPort the bindings for non-shaped builtin types in IRTypes.cpp to use the This required a surgery in the type casters and the Tests are updated to use the Patch is 49.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171143.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 6594670abaaa7..f678f57527e97 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -371,16 +371,22 @@ struct type_caster<MlirTypeID> {
}
return false;
}
- static handle from_cpp(MlirTypeID v, rv_policy,
- cleanup_list *cleanup) noexcept {
+
+ static handle
+ from_cpp_given_module(MlirTypeID v,
+ const nanobind::module_ &module) noexcept {
if (v.ptr == nullptr)
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
- return mlir::python::irModule()
- .attr("TypeID")
+ return module.attr("TypeID")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
+ }
+
+ static handle from_cpp(MlirTypeID v, rv_policy,
+ cleanup_list *cleanup) noexcept {
+ return from_cpp_given_module(v, mlir::python::irModule());
};
};
@@ -602,9 +608,12 @@ class mlir_type_subclass : public pure_subclass {
/// Subclasses by looking up the super-class dynamically.
mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
IsAFunctionTy isaFunction,
- GetTypeIDFunctionTy getTypeIDFunction = nullptr)
- : mlir_type_subclass(scope, typeClassName, isaFunction,
- irModule().attr("Type"), getTypeIDFunction) {}
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr,
+ const nanobind::module_ *mlirIrModule = nullptr)
+ : mlir_type_subclass(
+ scope, typeClassName, isaFunction,
+ (mlirIrModule != nullptr ? *mlirIrModule : irModule()).attr("Type"),
+ getTypeIDFunction, mlirIrModule) {}
/// Subclasses with a provided mlir.ir.Type super-class. This must
/// be used if the subclass is being defined in the same extension module
@@ -613,7 +622,8 @@ class mlir_type_subclass : public pure_subclass {
mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
IsAFunctionTy isaFunction,
const nanobind::object &superCls,
- GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr,
+ const nanobind::module_ *mlirIrModule = nullptr)
: pure_subclass(scope, typeClassName, superCls) {
// Casting constructor. Note that it is hard, if not impossible, to properly
// call chain to parent `__init__` in nanobind due to its special handling
@@ -672,9 +682,18 @@ class mlir_type_subclass : public pure_subclass {
nanobind::sig("def get_static_typeid() -> " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID"))
// clang-format on
);
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- getTypeIDFunction())(nanobind::cpp_function(
+
+ // Directly call the caster implementation given the "ir" module,
+ // otherwise it may trigger recursive import as the default caster
+ // attempts to import the "ir" module.
+ MlirTypeID typeID = getTypeIDFunction();
+ mlirIrModule = mlirIrModule ? mlirIrModule : &irModule();
+ nanobind::handle pyTypeID =
+ nanobind::detail::type_caster<MlirTypeID>::from_cpp_given_module(
+ typeID, *mlirIrModule);
+
+ mlirIrModule->attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(pyTypeID)(
+ nanobind::cpp_function(
[thisClass = thisClass](const nanobind::object &mlirType) {
return thisClass(mlirType);
}));
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 34c5b8dd86a66..2e4090c358c47 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -18,13 +18,13 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
using llvm::SmallVector;
-using llvm::Twine;
namespace {
@@ -34,480 +34,368 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
}
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerTypeGetTypeID;
- static constexpr const char *pyClassName = "IntegerType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_signless",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signless integer type");
- c.def_static(
- "get_signed",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signed integer type");
- c.def_static(
- "get_unsigned",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create an unsigned integer type");
- c.def_prop_ro(
- "width",
- [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
- "Returns the width of the integer type");
- c.def_prop_ro(
- "is_signless",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSignless(self);
- },
- "Returns whether this is a signless integer");
- c.def_prop_ro(
- "is_signed",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSigned(self);
- },
- "Returns whether this is a signed integer");
- c.def_prop_ro(
- "is_unsigned",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsUnsigned(self);
- },
- "Returns whether this is an unsigned integer");
- }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIndexTypeGetTypeID;
- static constexpr const char *pyClassName = "IndexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirIndexTypeGet(context->get());
- return PyIndexType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a index type.");
- }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
- static constexpr const char *pyClassName = "FloatType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
- "Returns the width of the floating-point type");
- }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
- : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat4E2M1FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float4E2M1FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
- return PyFloat4E2M1FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
- : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E2M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E2M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
- return PyFloat6E2M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
- : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E3M2FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E3M2FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
- return PyFloat6E3M2FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
- : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
- return PyFloat8E4M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2TypeGet(context->get());
- return PyFloat8E5M2Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3TypeGet(context->get());
- return PyFloat8E4M3Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
- : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
- return PyFloat8E4M3FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
- : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3B11FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
- return PyFloat8E4M3B11FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
- : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
- return PyFloat8E5M2FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E3M4TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E3M4Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E3M4TypeGet(context->get());
- return PyFloat8E3M4Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
- : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E8M0FNUTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E8M0FNUType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
- return PyFloat8E8M0FNUType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
- }
-};
-
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirBFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "BF16Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirBF16TypeGet(context->get());
- return PyBF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a bf16 type.");
- }
-};
-
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "F16Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF16TypeGet(context->get());
- return PyF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f16 type.");
- }
-};
-
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloatTF32TypeGetTypeID;
- static constexpr const char *pyClassName = "FloatTF32Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirTF32TypeGet(context->get());
- return PyTF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a tf32 type.");
- }
-};
-
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat32TypeGetTypeID;
- static constexpr const char *pyClassName = "F32Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF32TypeGet(context->get());
- return PyF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f32 type.");
- }
-};
+static void populateIRTypesModule(const nanobind::module_ &m) {
+ using namespace nanobind_adaptors;
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat64TypeGetTypeID;
- static constexpr const char *pyClassName = "F64Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF64TypeGet(context->get());
- return PyF64Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f64 type.");
- }
-};
-
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirNoneTypeGetTypeID;
- static constexpr const char *pyClassName = "NoneType";
- using PyConcreteType::PyConcreteType;
-
- ...
[truncated]
|
You can test this locally with the following command:darker --check --diff -r origin/main...HEAD mlir/test/python/dialects/arith_dialect.py mlir/test/python/ir/builtin_types.py mlir/test/python/ir/value.py
View the diff from darker here.--- ir/builtin_types.py 2025-12-08 14:50:41.000000 +0000
+++ ir/builtin_types.py 2025-12-08 14:58:38.795311 +0000
@@ -713,11 +713,16 @@
# Test getTypeIdFunction agrees with
# mlirTypeGetTypeID(self) for an instance.
# CHECK: all equal
for t1, t2 in types:
# TODO: remove the alternative once mlir_type_subclass transition is complete.
- tid1, tid2 = t1.static_typeid if hasattr(t1, "static_typeid") else t1.get_static_typeid(), Type(t2).typeid
+ tid1, tid2 = (
+ t1.static_typeid
+ if hasattr(t1, "static_typeid")
+ else t1.get_static_typeid(),
+ Type(t2).typeid,
+ )
assert tid1 == tid2 and hash(tid1) == hash(
tid2
), f"expected hash and value equality {t1} {t2}"
else:
print("all equal")
@@ -728,11 +733,15 @@
assert len(typeid_dict)
# CHECK: all equal
for t1, t2 in typeid_dict.items():
# TODO: remove the alternative once mlir_type_subclass transition is complete.
- tid1 = t1.static_typeid if hasattr(t1, "static_typeid") else t1.get_static_typeid()
+ tid1 = (
+ t1.static_typeid
+ if hasattr(t1, "static_typeid")
+ else t1.get_static_typeid()
+ )
assert tid1 == t2.typeid and hash(tid1) == hash(
t2.typeid
), f"expected hash and value equality {t1} {t2}"
else:
print("all equal")
|
makslevental
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't a hard block - just a request for discussion.
This is part of a longer-term cleanup to only support one subclassing mechanism.
I had the same idea but I think we should go in the opposite direction - remove all of the mlir_*_subclasses and unify on PyConcrete*. I started that here #156575 but de-prioritized. If this unification is a priority, I can finish that PR this week.
I don't think mlir_*_subclass is the way to go - they're not "real" in the sense that you're giving up everything useful about nanobind by using them - i.e., all of the convenience def*, type signature generation, etc.
Why do you think mlir_type_subclass should be it instead of PyConcreteType?
e6047d9 to
68db47a
Compare
|
Just to summarize an offline discussion here for public knowledge: I went with |
Port the bindings for non-shaped builtin types in IRTypes.cpp to use the
mlir_type_subclassmechanism used by non-builtin types. This is part of a longer-term cleanup to only support one subclassing mechanism. Eventually, thePyConcreteTypemechanism will be removed.This required a surgery in the type casters and the
mlir_type_subclasslogic to avoid circular imports of the_mlir.irmodule that would otherwise when usingmlir_type_subclassto define classes in the_mlir.irmodule.Tests are updated to use the
.get_static_typeid()function instead of the.static_typeidproperty that was specific to builtin types due to thePyConcreteTypemechanism. The change should be NFC otherwise.