-
Notifications
You must be signed in to change notification settings - Fork 15.7k
[mlir][Python] port in-tree dialect extensions to use MLIRPythonSupport #174156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
8de0bcb to
a7ab82c
Compare
23e3e31 to
a428025
Compare
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-backend-amdgpu Author: Maksim Levental (makslevental) Changes: This PR ports all in-tree dialect extensions to use the `PyConcreteType`/`PyConcreteAttribute` CRTPs instead of `mlir_pure_subclass`; depends on #174118 Patch is 111.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174156.diff 11 Files Affected:
diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
index 26ffc0e427e41..26115c3635b7b 100644
--- a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -8,58 +8,96 @@
#include "mlir-c/Dialect/AMDGPU.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
- auto amdgpuTDMBaseType =
- mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType,
- mlirAMDGPUTDMBaseTypeGetTypeID);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace amdgpu {
+struct TDMBaseType : PyConcreteType<TDMBaseType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAMDGPUTDMBaseType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAMDGPUTDMBaseTypeGetTypeID;
+ static constexpr const char *pyClassName = "TDMBaseType";
+ using PyConcreteType::PyConcreteType;
- amdgpuTDMBaseType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
- return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
- },
- "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
- nb::arg("element_type"), nb::arg("ctx") = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &elementType, DefaultingPyMlirContext context) {
+ return TDMBaseType(
+ context->getRef(),
+ mlirAMDGPUTDMBaseTypeGet(context.get()->get(), elementType));
+ },
+ "Gets an instance of TDMBaseType in the same context",
+ nb::arg("element_type"), nb::arg("context").none() = nb::none());
+ }
+};
- auto amdgpuTDMDescriptorType = mlir_type_subclass(
- m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType,
- mlirAMDGPUTDMDescriptorTypeGetTypeID);
+struct TDMDescriptorType : PyConcreteType<TDMDescriptorType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAAMDGPUTDMDescriptorType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAMDGPUTDMDescriptorTypeGetTypeID;
+ static constexpr const char *pyClassName = "TDMDescriptorType";
+ using PyConcreteType::PyConcreteType;
- amdgpuTDMDescriptorType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
- },
- "Gets an instance of TDMDescriptorType in the same context",
- nb::arg("cls"), nb::arg("ctx") = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return TDMDescriptorType(
+ context->getRef(),
+ mlirAMDGPUTDMDescriptorTypeGet(context.get()->get()));
+ },
+ "Gets an instance of TDMDescriptorType in the same context",
+ nb::arg("context").none() = nb::none());
+ }
+};
- auto amdgpuTDMGatherBaseType = mlir_type_subclass(
- m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType,
- mlirAMDGPUTDMGatherBaseTypeGetTypeID);
+struct TDMGatherBaseType : PyConcreteType<TDMGatherBaseType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAAMDGPUTDMGatherBaseType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAMDGPUTDMGatherBaseTypeGetTypeID;
+ static constexpr const char *pyClassName = "TDMGatherBaseType";
+ using PyConcreteType::PyConcreteType;
- amdgpuTDMGatherBaseType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType, MlirType indexType,
- MlirContext ctx) {
- return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
- },
- "Gets an instance of TDMGatherBaseType in the same context",
- nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
- nb::arg("ctx") = nb::none());
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &elementType, const PyType &indexType,
+ DefaultingPyMlirContext context) {
+ return TDMGatherBaseType(
+ context->getRef(),
+ mlirAMDGPUTDMGatherBaseTypeGet(context.get()->get(), elementType,
+ indexType));
+ },
+ "Gets an instance of TDMGatherBaseType in the same context",
+ nb::arg("element_type"), nb::arg("index_type"),
+ nb::arg("context").none() = nb::none());
+ }
};
+static void populateDialectAMDGPUSubmodule(nb::module_ &m) {
+ TDMBaseType::bind(m);
+ TDMDescriptorType::bind(m);
+ TDMGatherBaseType::bind(m);
+}
+} // namespace amdgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
NB_MODULE(_mlirDialectsAMDGPU, m) {
m.doc() = "MLIR AMDGPU dialect.";
- populateDialectAMDGPUSubmodule(m);
+ mlir::python::mlir::amdgpu::populateDialectAMDGPUSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 2568d535edb5a..ea3748cc88b85 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,83 +9,105 @@
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace gpu {
// -----------------------------------------------------------------------------
-// Module initialization.
+// AsyncTokenType
// -----------------------------------------------------------------------------
-NB_MODULE(_mlirDialectsGPU, m) {
- m.doc() = "MLIR GPU Dialect";
- //===-------------------------------------------------------------------===//
- // AsyncTokenType
- //===-------------------------------------------------------------------===//
+struct AsyncTokenType : PyConcreteType<AsyncTokenType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAGPUAsyncTokenType;
+ static constexpr const char *pyClassName = "AsyncTokenType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return AsyncTokenType(context->getRef(),
+ mlirGPUAsyncTokenTypeGet(context.get()->get()));
+ },
+ "Gets an instance of AsyncTokenType in the same context",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// ObjectAttr
+//===-------------------------------------------------------------------===//
+
+struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAGPUObjectAttr;
+ static constexpr const char *pyClassName = "ObjectAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
- auto mlirGPUAsyncTokenType =
- mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](MlirAttribute target, uint32_t format, const nb::bytes &object,
+ std::optional<MlirAttribute> mlirObjectProps,
+ std::optional<MlirAttribute> mlirKernelsAttr,
+ DefaultingPyMlirContext context) {
+ MlirStringRef objectStrRef = mlirStringRefCreate(
+ static_cast<char *>(const_cast<void *>(object.data())),
+ object.size());
+ return ObjectAttr(
+ context->getRef(),
+ mlirGPUObjectAttrGetWithKernels(
+ mlirAttributeGetContext(target), target, format, objectStrRef,
+ mlirObjectProps.has_value() ? *mlirObjectProps
+ : MlirAttribute{nullptr},
+ mlirKernelsAttr.has_value() ? *mlirKernelsAttr
+ : MlirAttribute{nullptr}));
+ },
+ "target"_a, "format"_a, "object"_a, "properties"_a = nb::none(),
+ "kernels"_a = nb::none(), "context"_a = nb::none(),
+ "Gets a gpu.object from parameters.");
- mlirGPUAsyncTokenType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirGPUAsyncTokenTypeGet(ctx));
- },
- "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
- nb::arg("ctx") = nb::none());
+ c.def_prop_ro("target", [](MlirAttribute self) {
+ return mlirGPUObjectAttrGetTarget(self);
+ });
+ c.def_prop_ro("format", [](MlirAttribute self) {
+ return mlirGPUObjectAttrGetFormat(self);
+ });
+ c.def_prop_ro("object", [](MlirAttribute self) {
+ MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
+ return nb::bytes(stringRef.data, stringRef.length);
+ });
+ c.def_prop_ro("properties", [](MlirAttribute self) -> nb::object {
+ if (mlirGPUObjectAttrHasProperties(self))
+ return nb::cast(mlirGPUObjectAttrGetProperties(self));
+ return nb::none();
+ });
+ c.def_prop_ro("kernels", [](MlirAttribute self) -> nb::object {
+ if (mlirGPUObjectAttrHasKernels(self))
+ return nb::cast(mlirGPUObjectAttrGetKernels(self));
+ return nb::none();
+ });
+ }
+};
+} // namespace gpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
- //===-------------------------------------------------------------------===//
- // ObjectAttr
- //===-------------------------------------------------------------------===//
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+NB_MODULE(_mlirDialectsGPU, m) {
+ m.doc() = "MLIR GPU Dialect";
- mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirAttribute target, uint32_t format,
- const nb::bytes &object,
- std::optional<MlirAttribute> mlirObjectProps,
- std::optional<MlirAttribute> mlirKernelsAttr) {
- MlirStringRef objectStrRef = mlirStringRefCreate(
- static_cast<char *>(const_cast<void *>(object.data())),
- object.size());
- return cls(mlirGPUObjectAttrGetWithKernels(
- mlirAttributeGetContext(target), target, format, objectStrRef,
- mlirObjectProps.has_value() ? *mlirObjectProps
- : MlirAttribute{nullptr},
- mlirKernelsAttr.has_value() ? *mlirKernelsAttr
- : MlirAttribute{nullptr}));
- },
- "cls"_a, "target"_a, "format"_a, "object"_a,
- "properties"_a = nb::none(), "kernels"_a = nb::none(),
- "Gets a gpu.object from parameters.")
- .def_property_readonly(
- "target",
- [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
- .def_property_readonly(
- "format",
- [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
- .def_property_readonly(
- "object",
- [](MlirAttribute self) {
- MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
- return nb::bytes(stringRef.data, stringRef.length);
- })
- .def_property_readonly("properties",
- [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasProperties(self))
- return nb::cast(
- mlirGPUObjectAttrGetProperties(self));
- return nb::none();
- })
- .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasKernels(self))
- return nb::cast(mlirGPUObjectAttrGetKernels(self));
- return nb::none();
- });
+ mlir::python::mlir::gpu::AsyncTokenType::bind(m);
+ mlir::python::mlir::gpu::ObjectAttr::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 05681cecf82b3..d4eb078c0f55c 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -13,149 +13,176 @@
#include "mlir-c/Support.h"
#include "mlir-c/Target/LLVMIR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
using namespace llvm;
using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
-
- //===--------------------------------------------------------------------===//
- // StructType
- //===--------------------------------------------------------------------===//
-
- auto llvmStructType = mlir_type_subclass(
- m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
-
- llvmStructType
- .def_classmethod(
- "get_literal",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirLocation loc) {
- CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
- MlirType type = mlirLLVMStructTypeLiteralGetChecked(
- loc, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "loc"_a = nb::none())
- .def_classmethod(
- "get_literal_unchecked",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirContext context) {
- CollectDiagnosticsToStringScope scope(context);
-
- MlirType type = mlirLLVMStructTypeLiteralGet(
- context, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "context"_a = nb::none());
-
- llvmStructType.def_classmethod(
- "get_identified",
- [](const nb::object &cls, const std::string &name, MlirContext context) {
- return cls(mlirLLVMStructTypeIdentifiedGet(
- context, mlirStringRefCreate(name.data(), name.size())));
- },
- "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace llvm {
+//===--------------------------------------------------------------------===//
+// StructType
+//===--------------------------------------------------------------------===//
+
+struct StructType : PyConcreteType<StructType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMStructType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirLLVMStructTypeGetTypeID;
+ static constexpr const char *pyClassName = "StructType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_literal",
+ [](const std::vector<MlirType> &elements, bool packed, MlirLocation loc,
+ DefaultingPyMlirContext context) {
+ python::CollectDiagnosticsToStringScope scope(
+ mlirLocationGetContext(loc));
+
+ MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+ loc, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return StructType(context->getRef(), type);
+ },
+ "elements"_a, nb::kw_only(), "packed"_a = false, "loc"_a = nb::none(),
+ "context"_a = nb::none());
+
+ c.def_static(
+ "get_literal_unchecked",
+ [](const std::vector<MlirType> &elements, bool packed,
+ DefaultingPyMlirContext context) {
+ python::CollectDiagnosticsToStringScope scope(context.get()->get());
+
+ MlirType type = mlirLLVMStructTypeLiteralGet(
+ context.get()->get(), elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return StructType(context->getRef(), type);
+ },
+ "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none());
+
+ c.def_static(
+ "get_identified",
+ [](const std::string &name, DefaultingPyMlirContext context) {
+ return StructType(context->getRef(),
+ mlirLLVMStructTypeIdentifiedGet(
+ context.get()->get(),
+ mlirStringRefCreate(name.data(), name.size())));
+ },
+ "name"_a, nb::kw_only(), "context"_a = nb::none());
+
+ c.def_static(
+ "get_opaque",
+ [](const std::string &name, DefaultingPyMlirContext context) {
+ return StructType(context->getRef(),
+ mlirLLVMStructTypeOpaqueGet(
+ context.get()->get(),
+ mlirStringRefCreate(name.data(), name.size())));
+ },
+ "name"_a, "context"_a = nb::none());
+
+ c.def(
+ "set_body",
+ [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+ MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+ self, elements.size(), elements.data(), packed);
+ if (!mlirLogicalResultIsSuccess(result)) {
+ throw nb::value_error(
+ "Struct body already set to different content.");
+ }
+ },
+ "elements"_a, nb::kw_only(), "packed"_a = false);
+
+ c.def_static(
+ "new_identified",
+ [](const std::string &name, const std::vector<MlirType> &elements,
+ bool packed, DefaultingPyMlirContext context) {
+ return StructType(context->getRef(),
+ mlirLLVMStructTypeIdentifiedNewGet(
+ context.get()->get(),
+ mlirStringRefCreate(name.data(), name.length()),
+ elements.size(), elements.data(), packed));
+ },
+ "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none());
+
+ c.def_prop_ro("name", [](PyType type) -> std::optional<std::string> {
+ if (mlirLLVMStructTypeIsLiteral(type))
+ return std::nullopt;
+
+ MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+ return StringRef(stringRef.data, stringRef.length).str();
+ });
+
+ c.def_prop_ro("body", [](PyType type) -> nb::object {
+ // Don't crash in absence of a body.
+ if (mlirLLVMStructTypeIsOpaque(type))
+ return nb::none();
+
+ nb::list body;
+ for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type);
+ i < e; ++i) {
+ body.append(mlirLLVMStructTypeGetElementType(type, i));
+ }
+ return body;
+ });
+
+ c.def_prop_ro("packed",
+ [](PyType type) { return mlirLLVMStructTypeIsPacked(type); })...
[truncated]
|
7fd173c to
b06bef3
Compare
a7ab82c to
8bc6e79
Compare
b06bef3 to
0140394
Compare
ba7f935 to
9481f0e
Compare
8bc6e79 to
0bef604
Compare
9481f0e to
b456039
Compare
0bef604 to
31105c7
Compare
dd70894 to
85ffbd0
Compare
31105c7 to
0a50db5
Compare
85ffbd0 to
8f8fa37
Compare
…74118) This PR continues the work of #171775 by moving more useful types/attributes into MLIRPythonSupport. You can now do ```c++ struct PyTestIntegerRankedTensorType : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType< PyTestIntegerRankedTensorType, mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> struct PyTestTensorValue : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue< PyTestTensorValue> ``` instead of `mlir_type_subclass` and `mlir_value_subclass`; **specifically manual registration of the "value caster" via indirection through the Python interpreter is no longer necessary** . You can also now freely use all such types at the nanobind API level (e.g., overload based on `FP*`): ```c++ using mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; standaloneM.def("print_fp_type", [](PyF16Type &) { nb::print("this is a fp16 type"); }); standaloneM.def("print_fp_type", [](PyF32Type &) { nb::print("this is a fp32 type"); }); standaloneM.def("print_fp_type", [](PyF64Type &) { nb::print("this is a fp64 type"); }); ``` Note, here we only port `PythonTestModuleNanobind` but there is a follow-up PR that ports **all** in-tree dialect extensions #174156 to use these. After that one we can soft deprecate `mlir_pure_subclass`. Note, depends on #171775
8f8fa37 to
bb97a54
Compare
…Support (#174118) This PR continues the work of llvm/llvm-project#171775 by moving more useful types/attributes into MLIRPythonSupport. You can now do ```c++ struct PyTestIntegerRankedTensorType : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType< PyTestIntegerRankedTensorType, mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> struct PyTestTensorValue : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue< PyTestTensorValue> ``` instead of `mlir_type_subclass` and `mlir_value_subclass`; **specifically manual registration of the "value caster" via indirection through the Python interpreter is no longer necessary** . You can also now freely use all such types at the nanobind API level (e.g., overload based on `FP*`): ```c++ using mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; standaloneM.def("print_fp_type", [](PyF16Type &) { nb::print("this is a fp16 type"); }); standaloneM.def("print_fp_type", [](PyF32Type &) { nb::print("this is a fp32 type"); }); standaloneM.def("print_fp_type", [](PyF64Type &) { nb::print("this is a fp64 type"); }); ``` Note, here we only port `PythonTestModuleNanobind` but there is a follow-up PR that ports **all** in-tree dialect extensions llvm/llvm-project#174156 to use these. After that one we can soft deprecate `mlir_pure_subclass`. Note, depends on llvm/llvm-project#171775
boomanaiden154
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copying my comments here from the commit to keep discussion centralized and because I can't find my comments 3/4 of the time...
| "cast is not valid.", | ||
| nb::arg("candidate")); | ||
| c.def_static( | ||
| "cast_to_storage_type", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're seeing new errors with the changes here. Before we could pass in mlir._mlir_libs._mlir.ir.Type, but now we cannot. Is this an intentional change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah this is indeed a mistake - the .def bindings are getters on QuantizedType, but the def_static ones aren't necessarily. Let me fix this and add some tests.
| [](QuantizedType &type) { | ||
| MlirType castResult = mlirQuantizedTypeCastToStorageType(type); | ||
| if (!mlirTypeIsNull(castResult)) | ||
| return PyType(type.getContext(), castResult); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type here is also different. We have a test that takes the return value of this and then tries to access a .width property that now no longer exists.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this one on the other hand isn't quite a mistake; rather, it's a result of y'all depending on an implementation detail:
llvm-project/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Lines 152 to 176 in 69aa6a0
| Type QuantizedType::castToStorageType(Type quantizedType) { | |
| if (llvm::isa<QuantizedType>(quantizedType)) { | |
| // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 | |
| return llvm::cast<QuantizedType>(quantizedType).getStorageType(); | |
| } | |
| if (llvm::isa<ShapedType>(quantizedType)) { | |
| // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> | |
| ShapedType sType = llvm::cast<ShapedType>(quantizedType); | |
| if (!llvm::isa<QuantizedType>(sType.getElementType())) { | |
| return nullptr; | |
| } | |
| Type storageType = | |
| llvm::cast<QuantizedType>(sType.getElementType()).getStorageType(); | |
| if (llvm::isa<RankedTensorType>(quantizedType)) { | |
| return RankedTensorType::get(sType.getShape(), storageType); | |
| } | |
| if (llvm::isa<UnrankedTensorType>(quantizedType)) { | |
| return UnrankedTensorType::get(storageType); | |
| } | |
| if (llvm::isa<VectorType>(quantizedType)) { | |
| return VectorType::get(sType.getShape(), storageType); | |
| } | |
| } | |
| return nullptr; |
as you can see the return is Type but some of these returns would be "downcasted" (in the bindings themselves) to a type that would have .width field. What I should've done here is PyType.maybeDowncast(). Let me do that and it should fix y'all.
…vm#174118) This PR continues the work of llvm#171775 by moving more useful types/attributes into MLIRPythonSupport. You can now do ```c++ struct PyTestIntegerRankedTensorType : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType< PyTestIntegerRankedTensorType, mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> struct PyTestTensorValue : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue< PyTestTensorValue> ``` instead of `mlir_type_subclass` and `mlir_value_subclass`; **specifically manual registration of the "value caster" via indirection through the Python interpreter is no longer necessary** . You can also now freely use all such types at the nanobind API level (e.g., overload based on `FP*`): ```c++ using mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; standaloneM.def("print_fp_type", [](PyF16Type &) { nb::print("this is a fp16 type"); }); standaloneM.def("print_fp_type", [](PyF32Type &) { nb::print("this is a fp32 type"); }); standaloneM.def("print_fp_type", [](PyF64Type &) { nb::print("this is a fp64 type"); }); ``` Note, here we only port `PythonTestModuleNanobind` but there is a follow-up PR that ports **all** in-tree dialect extensions llvm#174156 to use these. After that one we can soft deprecate `mlir_pure_subclass`. Note, depends on llvm#171775
…rt (llvm#174156) This PR ports all in-tree dialect extensions to use the `PyConcreteType`, `PyConcreteAttribute` CRTPs instead of `mlir_pure_subclass`. After this PR we can soft deprecate `mlir_pure_subclass`. Also API signatures are updated to use `Py*` instead of `Mlir*` so that type "inference" and hints are improved.
…lvm#174156 (llvm#174489) llvm#174156 made all getters return `Py*` but skipped downcasting where possible. So restore it by calling `.maybeDowncast`.
This PR ports all in-tree dialect extensions to use the
PyConcreteType, PyConcreteAttribute CRTPs instead of mlir_pure_subclass. After this PR we can soft deprecate mlir_pure_subclass. Also API signatures are updated to use Py* instead of Mlir* so that type "inference" and hints are improved. Depends on #174118