Skip to content

Conversation

@makslevental
Copy link
Contributor

@makslevental makslevental commented Jan 1, 2026

This PR ports all in-tree dialect extensions to use the PyConcreteType, PyConcreteAttribute CRTPs instead of mlir_pure_subclass. After this PR we can soft deprecate mlir_pure_subclass. Also API signatures are updated to use Py* instead of Mlir* so that type "inference" and hints are improved.

depends on #174118

@makslevental makslevental changed the base branch from main to users/makslevental/irtypesirattributesv2 January 1, 2026 19:14
@makslevental makslevental force-pushed the users/makslevental/irtypesirattributesv2 branch from 8de0bcb to a7ab82c Compare January 2, 2026 20:04
@makslevental makslevental force-pushed the users/makslevental/portdialectextensionsv2 branch from 23e3e31 to a428025 Compare January 2, 2026 20:10
@makslevental makslevental marked this pull request as ready for review January 2, 2026 20:23
@llvmbot
Copy link
Member

llvmbot commented Jan 2, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-backend-amdgpu

Author: Maksim Levental (makslevental)

Changes

This PR ports all in-tree dialect extensions to use the PyConcreteType, PyConcreteAttribute CRTPs instead of mlir_pure_subclass. After this PR we can soft deprecate mlir_pure_subclass.

depends on #174118


Patch is 111.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174156.diff

11 Files Affected:

  • (modified) mlir/lib/Bindings/Python/DialectAMDGPU.cpp (+74-36)
  • (modified) mlir/lib/Bindings/Python/DialectGPU.cpp (+87-65)
  • (modified) mlir/lib/Bindings/Python/DialectLLVM.cpp (+164-133)
  • (modified) mlir/lib/Bindings/Python/DialectNVGPU.cpp (+31-18)
  • (modified) mlir/lib/Bindings/Python/DialectPDL.cpp (+145-83)
  • (modified) mlir/lib/Bindings/Python/DialectQuant.cpp (+454-355)
  • (modified) mlir/lib/Bindings/Python/DialectSMT.cpp (+63-26)
  • (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+125-109)
  • (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+150-98)
  • (modified) mlir/python/mlir/dialects/transform/extras/init.py (+6-5)
  • (modified) mlir/test/python/dialects/pdl_types.py (+107-104)
diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
index 26ffc0e427e41..26115c3635b7b 100644
--- a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -8,58 +8,96 @@
 
 #include "mlir-c/Dialect/AMDGPU.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "nanobind/nanobind.h"
 
 namespace nb = nanobind;
 using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
-  auto amdgpuTDMBaseType =
-      mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType,
-                         mlirAMDGPUTDMBaseTypeGetTypeID);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace amdgpu {
+struct TDMBaseType : PyConcreteType<TDMBaseType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAMDGPUTDMBaseType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMBaseTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMBaseType";
+  using PyConcreteType::PyConcreteType;
 
-  amdgpuTDMBaseType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
-        return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
-      },
-      "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
-      nb::arg("element_type"), nb::arg("ctx") = nb::none());
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &elementType, DefaultingPyMlirContext context) {
+          return TDMBaseType(
+              context->getRef(),
+              mlirAMDGPUTDMBaseTypeGet(context.get()->get(), elementType));
+        },
+        "Gets an instance of TDMBaseType in the same context",
+        nb::arg("element_type"), nb::arg("context").none() = nb::none());
+  }
+};
 
-  auto amdgpuTDMDescriptorType = mlir_type_subclass(
-      m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType,
-      mlirAMDGPUTDMDescriptorTypeGetTypeID);
+struct TDMDescriptorType : PyConcreteType<TDMDescriptorType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAAMDGPUTDMDescriptorType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMDescriptorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMDescriptorType";
+  using PyConcreteType::PyConcreteType;
 
-  amdgpuTDMDescriptorType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
-      },
-      "Gets an instance of TDMDescriptorType in the same context",
-      nb::arg("cls"), nb::arg("ctx") = nb::none());
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return TDMDescriptorType(
+              context->getRef(),
+              mlirAMDGPUTDMDescriptorTypeGet(context.get()->get()));
+        },
+        "Gets an instance of TDMDescriptorType in the same context",
+        nb::arg("context").none() = nb::none());
+  }
+};
 
-  auto amdgpuTDMGatherBaseType = mlir_type_subclass(
-      m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType,
-      mlirAMDGPUTDMGatherBaseTypeGetTypeID);
+struct TDMGatherBaseType : PyConcreteType<TDMGatherBaseType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAAMDGPUTDMGatherBaseType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMGatherBaseTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMGatherBaseType";
+  using PyConcreteType::PyConcreteType;
 
-  amdgpuTDMGatherBaseType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType elementType, MlirType indexType,
-         MlirContext ctx) {
-        return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
-      },
-      "Gets an instance of TDMGatherBaseType in the same context",
-      nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
-      nb::arg("ctx") = nb::none());
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &elementType, const PyType &indexType,
+           DefaultingPyMlirContext context) {
+          return TDMGatherBaseType(
+              context->getRef(),
+              mlirAMDGPUTDMGatherBaseTypeGet(context.get()->get(), elementType,
+                                             indexType));
+        },
+        "Gets an instance of TDMGatherBaseType in the same context",
+        nb::arg("element_type"), nb::arg("index_type"),
+        nb::arg("context").none() = nb::none());
+  }
 };
 
+static void populateDialectAMDGPUSubmodule(nb::module_ &m) {
+  TDMBaseType::bind(m);
+  TDMDescriptorType::bind(m);
+  TDMGatherBaseType::bind(m);
+}
+} // namespace amdgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
 NB_MODULE(_mlirDialectsAMDGPU, m) {
   m.doc() = "MLIR AMDGPU dialect.";
 
-  populateDialectAMDGPUSubmodule(m);
+  mlir::python::mlir::amdgpu::populateDialectAMDGPUSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 2568d535edb5a..ea3748cc88b85 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,83 +9,105 @@
 #include "mlir-c/Dialect/GPU.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace nanobind::literals;
-
-using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace gpu {
 // -----------------------------------------------------------------------------
-// Module initialization.
+// AsyncTokenType
 // -----------------------------------------------------------------------------
 
-NB_MODULE(_mlirDialectsGPU, m) {
-  m.doc() = "MLIR GPU Dialect";
-  //===-------------------------------------------------------------------===//
-  // AsyncTokenType
-  //===-------------------------------------------------------------------===//
+struct AsyncTokenType : PyConcreteType<AsyncTokenType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAGPUAsyncTokenType;
+  static constexpr const char *pyClassName = "AsyncTokenType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return AsyncTokenType(context->getRef(),
+                                mlirGPUAsyncTokenTypeGet(context.get()->get()));
+        },
+        "Gets an instance of AsyncTokenType in the same context",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// ObjectAttr
+//===-------------------------------------------------------------------===//
+
+struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAGPUObjectAttr;
+  static constexpr const char *pyClassName = "ObjectAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
 
-  auto mlirGPUAsyncTokenType =
-      mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](MlirAttribute target, uint32_t format, const nb::bytes &object,
+           std::optional<MlirAttribute> mlirObjectProps,
+           std::optional<MlirAttribute> mlirKernelsAttr,
+           DefaultingPyMlirContext context) {
+          MlirStringRef objectStrRef = mlirStringRefCreate(
+              static_cast<char *>(const_cast<void *>(object.data())),
+              object.size());
+          return ObjectAttr(
+              context->getRef(),
+              mlirGPUObjectAttrGetWithKernels(
+                  mlirAttributeGetContext(target), target, format, objectStrRef,
+                  mlirObjectProps.has_value() ? *mlirObjectProps
+                                              : MlirAttribute{nullptr},
+                  mlirKernelsAttr.has_value() ? *mlirKernelsAttr
+                                              : MlirAttribute{nullptr}));
+        },
+        "target"_a, "format"_a, "object"_a, "properties"_a = nb::none(),
+        "kernels"_a = nb::none(), "context"_a = nb::none(),
+        "Gets a gpu.object from parameters.");
 
-  mlirGPUAsyncTokenType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirGPUAsyncTokenTypeGet(ctx));
-      },
-      "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
-      nb::arg("ctx") = nb::none());
+    c.def_prop_ro("target", [](MlirAttribute self) {
+      return mlirGPUObjectAttrGetTarget(self);
+    });
+    c.def_prop_ro("format", [](MlirAttribute self) {
+      return mlirGPUObjectAttrGetFormat(self);
+    });
+    c.def_prop_ro("object", [](MlirAttribute self) {
+      MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
+      return nb::bytes(stringRef.data, stringRef.length);
+    });
+    c.def_prop_ro("properties", [](MlirAttribute self) -> nb::object {
+      if (mlirGPUObjectAttrHasProperties(self))
+        return nb::cast(mlirGPUObjectAttrGetProperties(self));
+      return nb::none();
+    });
+    c.def_prop_ro("kernels", [](MlirAttribute self) -> nb::object {
+      if (mlirGPUObjectAttrHasKernels(self))
+        return nb::cast(mlirGPUObjectAttrGetKernels(self));
+      return nb::none();
+    });
+  }
+};
+} // namespace gpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
-  //===-------------------------------------------------------------------===//
-  // ObjectAttr
-  //===-------------------------------------------------------------------===//
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+NB_MODULE(_mlirDialectsGPU, m) {
+  m.doc() = "MLIR GPU Dialect";
 
-  mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
-      .def_classmethod(
-          "get",
-          [](const nb::object &cls, MlirAttribute target, uint32_t format,
-             const nb::bytes &object,
-             std::optional<MlirAttribute> mlirObjectProps,
-             std::optional<MlirAttribute> mlirKernelsAttr) {
-            MlirStringRef objectStrRef = mlirStringRefCreate(
-                static_cast<char *>(const_cast<void *>(object.data())),
-                object.size());
-            return cls(mlirGPUObjectAttrGetWithKernels(
-                mlirAttributeGetContext(target), target, format, objectStrRef,
-                mlirObjectProps.has_value() ? *mlirObjectProps
-                                            : MlirAttribute{nullptr},
-                mlirKernelsAttr.has_value() ? *mlirKernelsAttr
-                                            : MlirAttribute{nullptr}));
-          },
-          "cls"_a, "target"_a, "format"_a, "object"_a,
-          "properties"_a = nb::none(), "kernels"_a = nb::none(),
-          "Gets a gpu.object from parameters.")
-      .def_property_readonly(
-          "target",
-          [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
-      .def_property_readonly(
-          "format",
-          [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
-      .def_property_readonly(
-          "object",
-          [](MlirAttribute self) {
-            MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
-            return nb::bytes(stringRef.data, stringRef.length);
-          })
-      .def_property_readonly("properties",
-                             [](MlirAttribute self) -> nb::object {
-                               if (mlirGPUObjectAttrHasProperties(self))
-                                 return nb::cast(
-                                     mlirGPUObjectAttrGetProperties(self));
-                               return nb::none();
-                             })
-      .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
-        if (mlirGPUObjectAttrHasKernels(self))
-          return nb::cast(mlirGPUObjectAttrGetKernels(self));
-        return nb::none();
-      });
+  mlir::python::mlir::gpu::AsyncTokenType::bind(m);
+  mlir::python::mlir::gpu::ObjectAttr::bind(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 05681cecf82b3..d4eb078c0f55c 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -13,149 +13,176 @@
 #include "mlir-c/Support.h"
 #include "mlir-c/Target/LLVMIR.h"
 #include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 
 using namespace nanobind::literals;
-
 using namespace llvm;
 using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
-
-  //===--------------------------------------------------------------------===//
-  // StructType
-  //===--------------------------------------------------------------------===//
-
-  auto llvmStructType = mlir_type_subclass(
-      m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
-
-  llvmStructType
-      .def_classmethod(
-          "get_literal",
-          [](const nb::object &cls, const std::vector<MlirType> &elements,
-             bool packed, MlirLocation loc) {
-            CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
-            MlirType type = mlirLLVMStructTypeLiteralGetChecked(
-                loc, elements.size(), elements.data(), packed);
-            if (mlirTypeIsNull(type)) {
-              throw nb::value_error(scope.takeMessage().c_str());
-            }
-            return cls(type);
-          },
-          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-          "loc"_a = nb::none())
-      .def_classmethod(
-          "get_literal_unchecked",
-          [](const nb::object &cls, const std::vector<MlirType> &elements,
-             bool packed, MlirContext context) {
-            CollectDiagnosticsToStringScope scope(context);
-
-            MlirType type = mlirLLVMStructTypeLiteralGet(
-                context, elements.size(), elements.data(), packed);
-            if (mlirTypeIsNull(type)) {
-              throw nb::value_error(scope.takeMessage().c_str());
-            }
-            return cls(type);
-          },
-          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-          "context"_a = nb::none());
-
-  llvmStructType.def_classmethod(
-      "get_identified",
-      [](const nb::object &cls, const std::string &name, MlirContext context) {
-        return cls(mlirLLVMStructTypeIdentifiedGet(
-            context, mlirStringRefCreate(name.data(), name.size())));
-      },
-      "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace llvm {
+//===--------------------------------------------------------------------===//
+// StructType
+//===--------------------------------------------------------------------===//
+
+struct StructType : PyConcreteType<StructType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMStructType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirLLVMStructTypeGetTypeID;
+  static constexpr const char *pyClassName = "StructType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_literal",
+        [](const std::vector<MlirType> &elements, bool packed, MlirLocation loc,
+           DefaultingPyMlirContext context) {
+          python::CollectDiagnosticsToStringScope scope(
+              mlirLocationGetContext(loc));
+
+          MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+              loc, elements.size(), elements.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(context->getRef(), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false, "loc"_a = nb::none(),
+        "context"_a = nb::none());
+
+    c.def_static(
+        "get_literal_unchecked",
+        [](const std::vector<MlirType> &elements, bool packed,
+           DefaultingPyMlirContext context) {
+          python::CollectDiagnosticsToStringScope scope(context.get()->get());
+
+          MlirType type = mlirLLVMStructTypeLiteralGet(
+              context.get()->get(), elements.size(), elements.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(context->getRef(), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_static(
+        "get_identified",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, nb::kw_only(), "context"_a = nb::none());
+
+    c.def_static(
+        "get_opaque",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeOpaqueGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, "context"_a = nb::none());
+
+    c.def(
+        "set_body",
+        [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+          MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+              self, elements.size(), elements.data(), packed);
+          if (!mlirLogicalResultIsSuccess(result)) {
+            throw nb::value_error(
+                "Struct body already set to different content.");
+          }
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false);
+
+    c.def_static(
+        "new_identified",
+        [](const std::string &name, const std::vector<MlirType> &elements,
+           bool packed, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedNewGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.length()),
+                                elements.size(), elements.data(), packed));
+        },
+        "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_prop_ro("name", [](PyType type) -> std::optional<std::string> {
+      if (mlirLLVMStructTypeIsLiteral(type))
+        return std::nullopt;
+
+      MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+      return StringRef(stringRef.data, stringRef.length).str();
+    });
+
+    c.def_prop_ro("body", [](PyType type) -> nb::object {
+      // Don't crash in absence of a body.
+      if (mlirLLVMStructTypeIsOpaque(type))
+        return nb::none();
+
+      nb::list body;
+      for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type);
+           i < e; ++i) {
+        body.append(mlirLLVMStructTypeGetElementType(type, i));
+      }
+      return body;
+    });
+
+    c.def_prop_ro("packed",
+                  [](PyType type) { return mlirLLVMStructTypeIsPacked(type); })...
[truncated]

@makslevental makslevental force-pushed the users/makslevental/portdialectextensionsv2 branch 2 times, most recently from 7fd173c to b06bef3 Compare January 2, 2026 20:40
@makslevental makslevental force-pushed the users/makslevental/irtypesirattributesv2 branch from a7ab82c to 8bc6e79 Compare January 2, 2026 21:02
@makslevental makslevental force-pushed the users/makslevental/portdialectextensionsv2 branch from b06bef3 to 0140394 Compare January 2, 2026 21:02
@makslevental makslevental changed the title [mlir][Python] port dialect extensions to use core PyConcreteType, PyConcreteAttribute [mlir][Python] port in-tree dialect extensions to use core PyConcreteType, PyConcreteAttribute Jan 2, 2026
@makslevental makslevental removed the request for review from aaronmondal January 2, 2026 22:00
@makslevental makslevental changed the title [mlir][Python] port in-tree dialect extensions to use core PyConcreteType, PyConcreteAttribute [mlir][Python] port in-tree dialect extensions to use core MLIRPythonSupport Jan 2, 2026
@makslevental makslevental changed the title [mlir][Python] port in-tree dialect extensions to use core MLIRPythonSupport [mlir][Python] port in-tree dialect extensions to use MLIRPythonSupport Jan 2, 2026
@makslevental makslevental force-pushed the users/makslevental/portdialectextensionsv2 branch 3 times, most recently from ba7f935 to 9481f0e Compare January 4, 2026 03:07
@makslevental makslevental force-pushed the users/makslevental/irtypesirattributesv2 branch from 8bc6e79 to 0bef604 Compare January 4, 2026 03:28
@makslevental makslevental force-pushed the users/makslevental/portdialectextensionsv2 branch from 9481f0e to b456039 Compare January 4, 2026 03:28
@makslevental makslevental force-pushed the users/makslevental/irtypesirattributesv2 branch from 0bef604 to 31105c7 Compare January 4, 2026 04:03
@makslevental makslevental force-pushed the users/makslevental/portdialectextensionsv2 branch 2 times, most recently from dd70894 to 85ffbd0 Compare January 4, 2026 04:04
@makslevental makslevental force-pushed the users/makslevental/irtypesirattributesv2 branch from 31105c7 to 0a50db5 Compare January 5, 2026 17:11
@makslevental makslevental force-pushed the users/makslevental/portdialectextensionsv2 branch from 85ffbd0 to 8f8fa37 Compare January 5, 2026 17:12
makslevental added a commit that referenced this pull request Jan 5, 2026
…74118)

This PR continues the work of
#171775 by moving more useful
types/attributes into MLIRPythonSupport.

You can now do 

```c++
struct PyTestIntegerRankedTensorType
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
          PyTestIntegerRankedTensorType,
          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType>
struct PyTestTensorValue
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
          PyTestTensorValue>
```
instead of `mlir_type_subclass` and `mlir_value_subclass`;
**specifically manual registration of the "value caster" via indirection
through the Python interpreter is no longer necessary** . You can also
now freely use all such types at the nanobind API level (e.g., overload
based on `FP*`):

```c++
using mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
standaloneM.def("print_fp_type", [](PyF16Type &) { nb::print("this is a fp16 type"); });
standaloneM.def("print_fp_type", [](PyF32Type &) { nb::print("this is a fp32 type"); });
standaloneM.def("print_fp_type", [](PyF64Type &) { nb::print("this is a fp64 type"); });
```

Note, here we only port `PythonTestModuleNanobind` but there is a
follow-up PR that ports **all** in-tree dialect extensions
#174156 to use these. After
that one we can soft deprecate `mlir_pure_subclass`.

Note, depends on #171775
Base automatically changed from users/makslevental/irtypesirattributesv2 to main January 5, 2026 17:35
@makslevental makslevental force-pushed the users/makslevental/portdialectextensionsv2 branch from 8f8fa37 to bb97a54 Compare January 5, 2026 17:37
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 5, 2026
…Support (#174118)

This PR continues the work of
llvm/llvm-project#171775 by moving more useful
types/attributes into MLIRPythonSupport.

You can now do

```c++
struct PyTestIntegerRankedTensorType
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
          PyTestIntegerRankedTensorType,
          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType>
struct PyTestTensorValue
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
          PyTestTensorValue>
```
instead of `mlir_type_subclass` and `mlir_value_subclass`;
**specifically manual registration of the "value caster" via indirection
through the Python interpreter is no longer necessary** . You can also
now freely use all such types at the nanobind API level (e.g., overload
based on `FP*`):

```c++
using mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
standaloneM.def("print_fp_type", [](PyF16Type &) { nb::print("this is a fp16 type"); });
standaloneM.def("print_fp_type", [](PyF32Type &) { nb::print("this is a fp32 type"); });
standaloneM.def("print_fp_type", [](PyF64Type &) { nb::print("this is a fp64 type"); });
```

Note, here we only port `PythonTestModuleNanobind` but there is a
follow-up PR that ports **all** in-tree dialect extensions
llvm/llvm-project#174156 to use these. After
that one we can soft deprecate `mlir_pure_subclass`.

Note, depends on llvm/llvm-project#171775
@makslevental makslevental merged commit ee3338d into main Jan 5, 2026
10 checks passed
@makslevental makslevental deleted the users/makslevental/portdialectextensionsv2 branch January 5, 2026 18:23
Copy link
Contributor

@boomanaiden154 boomanaiden154 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copying my comments here from the commit to keep discussion centralized and because I can't find my comments 3/4 of the time...

"cast is not valid.",
nb::arg("candidate"));
c.def_static(
"cast_to_storage_type",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're seeing new errors with the changes here. Before we could pass in mlir._mlir_libs._mlir.ir.Type, but now we cannot. Is this an intentional change?

Copy link
Contributor Author

@makslevental makslevental Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah this is indeed a mistake - the .def are gettors on QuantizedType but the def_static aren't necessarily. Let me fix this and add some tests.

[](QuantizedType &type) {
MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
if (!mlirTypeIsNull(castResult))
return PyType(type.getContext(), castResult);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type here is also different. We have a test that takes the return value of this and then tries to access a .width property that now no longer exists.

Copy link
Contributor Author

@makslevental makslevental Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one on the other hand isn't quite a mistake as a result of y'all depending on an impl detail:

Type QuantizedType::castToStorageType(Type quantizedType) {
if (llvm::isa<QuantizedType>(quantizedType)) {
// i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
return llvm::cast<QuantizedType>(quantizedType).getStorageType();
}
if (llvm::isa<ShapedType>(quantizedType)) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
ShapedType sType = llvm::cast<ShapedType>(quantizedType);
if (!llvm::isa<QuantizedType>(sType.getElementType())) {
return nullptr;
}
Type storageType =
llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
if (llvm::isa<RankedTensorType>(quantizedType)) {
return RankedTensorType::get(sType.getShape(), storageType);
}
if (llvm::isa<UnrankedTensorType>(quantizedType)) {
return UnrankedTensorType::get(storageType);
}
if (llvm::isa<VectorType>(quantizedType)) {
return VectorType::get(sType.getShape(), storageType);
}
}
return nullptr;

as you can see the return is Type but some of these returns would be "downcasted" (in the bindings themselves) to a type that would have .width field. What I should've done here is PyType.maybeDowncast(). Let me do that and it should fix y'all.

makslevental added a commit that referenced this pull request Jan 5, 2026
…174156 (#174489)

#174156 made all gettors return `Py*` but skipped downcasting where
possible. So restore it by calling `.maybeDowncast`.
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jan 6, 2026
…vm#174118)

This PR continues the work of
llvm#171775 by moving more useful
types/attributes into MLIRPythonSupport.

You can now do 

```c++
struct PyTestIntegerRankedTensorType
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
          PyTestIntegerRankedTensorType,
          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType>
struct PyTestTensorValue
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
          PyTestTensorValue>
```
instead of `mlir_type_subclass` and `mlir_value_subclass`;
**specifically manual registration of the "value caster" via indirection
through the Python interpreter is no longer necessary** . You can also
now freely use all such types at the nanobind API level (e.g., overload
based on `FP*`):

```c++
using mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
standaloneM.def("print_fp_type", [](PyF16Type &) { nb::print("this is a fp16 type"); });
standaloneM.def("print_fp_type", [](PyF32Type &) { nb::print("this is a fp32 type"); });
standaloneM.def("print_fp_type", [](PyF64Type &) { nb::print("this is a fp64 type"); });
```

Note, here we only port `PythonTestModuleNanobind` but there is a
follow-up PR that ports **all** in-tree dialect extensions
llvm#174156 to use these. After
that one we can soft deprecate `mlir_pure_subclass`.

Note, depends on llvm#171775
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jan 6, 2026
…rt (llvm#174156)

This PR ports all in-tree dialect extensions to use the
`PyConcreteType`, `PyConcreteAttribute` CRTPs instead of
`mlir_pure_subclass`. After this PR we can soft deprecate
`mlir_pure_subclass`. Also API signatures are updated to use `Py*`
instead of `Mlir*` so that type "inference" and hints are improved.
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jan 6, 2026
…lvm#174156 (llvm#174489)

llvm#174156 made all gettors return `Py*` but skipped downcasting where
possible. So restore it by calling `.maybeDowncast`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants