@PragmaTwice PragmaTwice commented Sep 20, 2025

In the MLIR Python bindings, we can currently use PDL to define simple patterns and then execute them with the greedy rewrite driver. However, when dealing with more complex patterns—such as constant folding for integer addition—we find that we need apply_native_rewrite to actually perform arithmetic (i.e., compute the sum of two constants). For example, consider the following PDL pseudocode:

pdl.pattern : benefit(1) {
  %a0 = pdl.attribute
  %a1 = pdl.attribute
  %c0 = pdl.operation "arith.constant" {value = %a0}
  %c1 = pdl.operation "arith.constant" {value = %a1}

  %op = pdl.operation "arith.addi"(%c0, %c1)

  %sum = pdl.apply_native_rewrite "addIntegers"(%a0, %a1)
  %new_cst = pdl.operation "arith.constant" {value = %sum}

  pdl.replace %op with %new_cst
}

Here, addIntegers cannot be expressed in PDL alone—it requires a native rewrite function. This PR introduces a mechanism to support exactly that, allowing complex rewrite patterns to be expressed in Python and enabling many passes to be implemented directly in Python as well.
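To make the idea of a native rewrite function concrete, here is a minimal, self-contained sketch of a name-to-function registry. It is a toy model and not the MLIR API: `Attr`, `ResultList`, `RewriteRegistry`, and `foldAdd` are hypothetical names introduced only for illustration, and attributes are modeled as plain integers.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins (assumptions): attributes are plain ints, and the result
// list is a vector that the native function appends to.
using Attr = int;
using ResultList = std::vector<Attr>;
using NativeRewriteFn =
    std::function<bool(ResultList &results, const std::vector<Attr> &values)>;

// Hypothetical registry: maps the name used by a pattern to native code.
class RewriteRegistry {
public:
  void registerRewriteFunction(const std::string &name, NativeRewriteFn fn) {
    functions[name] = std::move(fn);
  }

  // Invoked where a pattern says: apply_native_rewrite "name"(values...).
  // Returns false if the name is unknown or the function reports failure.
  bool apply(const std::string &name, const std::vector<Attr> &values,
             ResultList &results) const {
    auto it = functions.find(name);
    if (it == functions.end())
      return false;
    return it->second(results, values);
  }

private:
  std::map<std::string, NativeRewriteFn> functions;
};

// Registers an "addIntegers" function and folds two constants with it.
inline Attr foldAdd(Attr lhs, Attr rhs) {
  RewriteRegistry registry;
  registry.registerRewriteFunction(
      "addIntegers", [](ResultList &results, const std::vector<Attr> &values) {
        if (values.size() != 2)
          return false;
        results.push_back(values[0] + values[1]);
        return true;
      });
  ResultList results;
  bool ok = registry.apply("addIntegers", {lhs, rhs}, results);
  (void)ok; // Silences unused-variable warnings when asserts are disabled.
  assert(ok && results.size() == 1);
  return results.front();
}
```

In this model the pattern only knows the string name; the arithmetic itself lives in ordinary host-language code, which is the role the PR gives to Python callables.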

As a test case, we defined two new operations (myint.constant and myint.add) in Python and implemented a constant-folding rewrite pattern for them. The core code looks like this:

m = Module.create()
with InsertionPoint(m.body):

    @pdl.pattern(benefit=1, sym_name="myint_add_fold")
    def pat():
        ...
        op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])

        @pdl.rewrite()
        def rew():
            sum = pdl.apply_native_rewrite(
                [pdl.AttributeType.get()], "add_fold", [a0, a1]
            )
            newOp = pdl.OperationOp(
                name="myint.constant", attributes={"value": sum}, types=[t]
            )
            pdl.ReplaceOp(op0, with_op=newOp)

def add_fold(rewriter, results, values):
    a0, a1 = values
    results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
    # The binding converts the return value to bool to signal success or failure.
    return True

pdl_module = PDLModule(m)
pdl_module.register_rewrite_function("add_fold", add_fold)

The idea was previously discussed in the Discord #mlir-python channel with @makslevental.
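The registration flow in the description above passes a host-language callable through a C-style callback that carries an opaque `userData` pointer. The following is a rough sketch of that trampoline technique under simplified assumptions: values are plain `int`s, and `CRewriteFn`, `wrapCallback`, `bindUserCallable`, and `runAddFold` are hypothetical names, not part of the MLIR C API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// C-style callback: a plain function pointer plus an opaque user-data pointer.
typedef bool (*CRewriteFn)(std::vector<int> *results, std::size_t nValues,
                           const int *values, void *userData);

using RewriteFn =
    std::function<bool(std::vector<int> &, const std::vector<int> &)>;
using UserCallable = RewriteFn;

// Library side: wraps the (function pointer, userData) pair in a closure so
// it can be stored like any other rewrite function.
inline RewriteFn wrapCallback(CRewriteFn fn, void *userData) {
  return [fn, userData](std::vector<int> &results,
                        const std::vector<int> &values) {
    return fn(&results, values.size(), values.data(), userData);
  };
}

// Binding side: a captureless lambda serves as the C callback and recovers
// the user's callable from userData. The callable must outlive the closure.
inline RewriteFn bindUserCallable(UserCallable *callable) {
  return wrapCallback(
      [](std::vector<int> *results, std::size_t nValues, const int *values,
         void *userData) -> bool {
        auto *f = static_cast<UserCallable *>(userData);
        std::vector<int> args(values, values + nValues);
        return (*f)(*results, args);
      },
      callable);
}

// Usage: fold two integers through the trampoline; returns -1 on failure.
inline int runAddFold(int a, int b) {
  UserCallable addFold = [](std::vector<int> &results,
                            const std::vector<int> &values) {
    results.push_back(values[0] + values[1]);
    return true;
  };
  RewriteFn fn = bindUserCallable(&addFold);
  std::vector<int> results;
  return fn(results, {a, b}) ? results.front() : -1;
}
```

Because the closure stores only a raw pointer, lifetime is the caller's responsibility in this sketch; the Python binding in the diff addresses the analogous concern with `nb::keep_alive`.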

github-actions bot commented Sep 20, 2025

✅ With the latest revision this PR passed the Python code formatter.

@PragmaTwice PragmaTwice marked this pull request as ready for review September 20, 2025 15:08
@llvmbot llvmbot added the mlir label Sep 20, 2025
llvmbot commented Sep 20, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/159926.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/Rewrite.h (+32)
  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+74-5)
  • (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+99)
  • (modified) mlir/test/python/integration/dialects/pdl.py (+96)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 374d2fb78de88..c20558fc8f9d9 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void);
 DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+DEFINE_C_API_STRUCT(MlirPDLValue, const void);
+DEFINE_C_API_STRUCT(MlirPDLResultList, void);
 
 MLIR_CAPI_EXPORTED MlirPDLPatternModule
 mlirPDLPatternModuleFromModule(MlirModule op);
@@ -323,6 +326,35 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
 
 MLIR_CAPI_EXPORTED MlirRewritePatternSet
 mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
+
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsAttribute(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value);
+
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value);
+MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results,
+                                                      MlirType value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+                                   MlirOperation value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+                                   MlirAttribute value);
+
+typedef MlirLogicalResult (*MlirPDLRewriteFunction)(
+    MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
+    MlirPDLValue *values, void *userData);
+
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
+    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLRewriteFunction rewriteFn, void *userData);
+
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
 #undef DEFINE_C_API_STRUCT
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5b7de50f02e6a..eceb5895fd901 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -9,10 +9,15 @@
 #include "Rewrite.h"
 
 #include "IRModule.h"
+#include "mlir-c/IR.h"
 #include "mlir-c/Rewrite.h"
+#include "mlir-c/Support.h"
+// clang-format off
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
 #include "mlir/Config/mlir-config.h"
+#include "nanobind/nanobind.h"
 
 namespace nb = nanobind;
 using namespace mlir;
@@ -36,6 +41,43 @@ class PyPDLPatternModule {
   }
   MlirPDLPatternModule get() { return module; }
 
+  static nb::object fromPDLValue(MlirPDLValue value) {
+    if (mlirPDLValueIsValue(value)) {
+      return nb::cast(mlirPDLValueAsValue(value));
+    }
+    if (mlirPDLValueIsOperation(value)) {
+      return nb::cast(mlirPDLValueAsOperation(value));
+    }
+    if (mlirPDLValueIsAttribute(value)) {
+      return nb::cast(mlirPDLValueAsAttribute(value));
+    }
+    if (mlirPDLValueIsType(value)) {
+      return nb::cast(mlirPDLValueAsType(value));
+    }
+
+    throw std::runtime_error("unsupported PDL value type");
+  }
+
+  void registerRewriteFunction(const std::string &name,
+                               const nb::callable &fn) {
+    mlirPDLPatternModuleRegisterRewriteFunction(
+        get(), mlirStringRefCreate(name.data(), name.size()),
+        [](MlirPatternRewriter rewriter, MlirPDLResultList results,
+           size_t nValues, MlirPDLValue *values,
+           void *userData) -> MlirLogicalResult {
+          auto f = nb::handle(static_cast<PyObject *>(userData));
+          std::vector<nb::object> args;
+          args.reserve(nValues);
+          for (size_t i = 0; i < nValues; ++i) {
+            args.push_back(fromPDLValue(values[i]));
+          }
+          return nb::cast<bool>(f(rewriter, results, args))
+                     ? mlirLogicalResultSuccess()
+                     : mlirLogicalResultFailure();
+        },
+        fn.ptr());
+  }
+
 private:
   MlirPDLPatternModule module;
 };
@@ -76,10 +118,27 @@ class PyFrozenRewritePatternSet {
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
   //----------------------------------------------------------------------------
-  // Mapping of the top-level PassManager
+  // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  nb::class_<MlirPDLResultList>(m, "PDLResultList")
+      .def("push_back",
+           [](MlirPDLResultList results, const PyValue &value) {
+             mlirPDLResultListPushBackValue(results, value);
+           })
+      .def("push_back",
+           [](MlirPDLResultList results, const PyOperation &op) {
+             mlirPDLResultListPushBackOperation(results, op);
+           })
+      .def("push_back",
+           [](MlirPDLResultList results, const PyType &type) {
+             mlirPDLResultListPushBackType(results, type);
+           })
+      .def("push_back", [](MlirPDLResultList results, const PyAttribute &attr) {
+        mlirPDLResultListPushBackAttribute(results, attr);
+      });
   nb::class_<PyPDLPatternModule>(m, "PDLModule")
       .def(
           "__init__",
@@ -88,10 +147,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
                 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
           },
           "module"_a, "Create a PDL module from the given module.")
-      .def("freeze", [](PyPDLPatternModule &self) {
-        return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
-            mlirRewritePatternSetFromPDLPatternModule(self.get())));
-      });
+      .def(
+          "freeze",
+          [](PyPDLPatternModule &self) {
+            return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+                mlirRewritePatternSetFromPDLPatternModule(self.get())));
+          },
+          nb::keep_alive<0, 1>())
+      .def(
+          "register_rewrite_function",
+          [](PyPDLPatternModule &self, const std::string &name,
+             const nb::callable &fn) {
+            self.registerRewriteFunction(name, fn);
+          },
+          nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 6f85357a14a18..0033abde986ea 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -13,6 +13,8 @@
 #include "mlir/CAPI/Rewrite.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
+  assert(rewriter.ptr && "unexpected null rewriter");
+  return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+}
+
+inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
+  return {rewriter};
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
@@ -331,4 +346,88 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
   op.ptr = nullptr;
   return wrap(m);
 }
+
+inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
+  assert(value.ptr && "unexpected null PDL value");
+  return static_cast<const mlir::PDLValue *>(value.ptr);
+}
+
+inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
+
+inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
+  assert(results.ptr && "unexpected null PDL results");
+  return static_cast<mlir::PDLResultList *>(results.ptr);
+}
+
+inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
+  return {results};
+}
+
+bool mlirPDLValueIsValue(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Value>();
+}
+
+MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Value>());
+}
+
+bool mlirPDLValueIsType(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Type>();
+}
+
+MlirType mlirPDLValueAsType(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Type>());
+}
+
+bool mlirPDLValueIsOperation(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Operation *>();
+}
+
+MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Operation *>());
+}
+
+bool mlirPDLValueIsAttribute(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Attribute>();
+}
+
+MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Attribute>());
+}
+
+void mlirPDLResultListPushBackValue(MlirPDLResultList results,
+                                    MlirValue value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+                                        MlirOperation value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+                                        MlirAttribute value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLPatternModuleRegisterRewriteFunction(
+    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLRewriteFunction rewriteFn, void *userData) {
+  unwrap(module)->registerRewriteFunction(
+      unwrap(name),
+      [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
+                            ArrayRef<PDLValue> values) -> LogicalResult {
+        std::vector<MlirPDLValue> mlirValues;
+        for (auto &value : values) {
+          mlirValues.push_back(wrap(&value));
+        }
+        return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
+                                mlirValues.size(), mlirValues.data(),
+                                userData));
+      });
+}
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index dd6c74ce622c8..c78f2d4f9a0dc 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -86,3 +86,99 @@ def add_func(a, b):
     frozen = get_pdl_patterns()
     apply_patterns_and_fold_greedily(module_.operation, frozen)
     return module_
+
+
+# If we use arith.constant and arith.addi here,
+# these C++-defined folding/canonicalization will be applied
+# implicitly in the greedy pattern rewrite driver to
+# make our Python-defined folding useless,
+# so here we define a new dialect to workaround this.
+def load_myint_dialect():
+    from mlir.dialects import irdl
+
+    m = Module.create()
+    with InsertionPoint(m.body):
+        myint = irdl.dialect("myint")
+        with InsertionPoint(myint.body):
+            constant = irdl.operation_("constant")
+            with InsertionPoint(constant.body):
+                iattr = irdl.base(base_name="#builtin.integer")
+                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+                irdl.attributes_([iattr], ["value"])
+                irdl.results_([i32], ["cst"], [irdl.Variadicity.single])
+            add = irdl.operation_("add")
+            with InsertionPoint(add.body):
+                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+                irdl.operands_(
+                    [i32, i32],
+                    ["lhs", "rhs"],
+                    [irdl.Variadicity.single, irdl.Variadicity.single],
+                )
+                irdl.results_([i32], ["res"], [irdl.Variadicity.single])
+
+    m.operation.verify()
+    irdl.load_dialects(m)
+
+
+# This PDL pattern is to fold constant additions,
+# i.e. add(constant0, constant1) -> constant2
+# where constant2 = constant0 + constant1.
+def get_pdl_pattern_fold():
+    m = Module.create()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(m.body):
+
+        @pdl.pattern(benefit=1, sym_name="myint_add_fold")
+        def pat():
+            t = pdl.TypeOp(i32)
+            a0 = pdl.AttributeOp()
+            a1 = pdl.AttributeOp()
+            c0 = pdl.OperationOp(
+                name="myint.constant", attributes={"value": a0}, types=[t]
+            )
+            c1 = pdl.OperationOp(
+                name="myint.constant", attributes={"value": a1}, types=[t]
+            )
+            v0 = pdl.ResultOp(c0, 0)
+            v1 = pdl.ResultOp(c1, 0)
+            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+            @pdl.rewrite()
+            def rew():
+                sum = pdl.apply_native_rewrite(
+                    [pdl.AttributeType.get()], "add_fold", [a0, a1]
+                )
+                newOp = pdl.OperationOp(
+                    name="myint.constant", attributes={"value": sum}, types=[t]
+                )
+                pdl.ReplaceOp(op0, with_op=newOp)
+
+    def add_fold(rewriter, results, values):
+        a0, a1 = values
+        results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
+        return True
+
+    pdl_module = PDLModule(m)
+    pdl_module.register_rewrite_function("add_fold", add_fold)
+    return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function
+# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32
+@construct_and_print_in_module
+def test_pdl_register_function(module_):
+    load_myint_dialect()
+
+    module_ = Module.parse(
+        """
+        %c0 = "myint.constant"() { value = 2 }: () -> (i32)
+        %c1 = "myint.constant"() { value = 3 }: () -> (i32)
+        %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
+        "myint.add"(%x, %c1): (i32, i32) -> (i32)
+        """
+    )
+
+    frozen = get_pdl_pattern_fold()
+    apply_patterns_and_fold_greedily(module_, frozen)
+
+    return module_

@PragmaTwice PragmaTwice changed the title from "[MLIR][Python] Add bindings for PDL native function registering" to "[MLIR][Python] Add bindings for PDL native rewrite function registering" Sep 20, 2025
@makslevental

Man great work - this is way simpler than I thought since we're not actually using any PatternRewriter APIs.

PragmaTwice and others added 4 commits September 21, 2025 11:27
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
 template <typename T,
           typename ResultT = std::conditional_t<
-              std::is_convertible<T, bool>::value, T, std::optional<T>>>
+              std::is_constructible_v<bool, T>, T, std::optional<T>>>
Member Author

@PragmaTwice PragmaTwice Sep 23, 2025

Some classes, such as mlir::Value and mlir::Type, have an explicit operator bool(), and is_convertible ignores explicit conversions. So here we replace it with is_constructible_v to make it work for these types, so that mlir::Value can be used directly instead of the awkward std::optional<mlir::Value> (since mlir::Value is itself nullable).
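The difference between the two traits can be checked in isolation. The sketch below uses a hypothetical `Handle` struct as a stand-in for a nullable type with an explicit `operator bool()`; `OldResultT` and `NewResultT` are illustrative aliases that mirror the selection logic of the changed template parameter, not names from the MLIR header. It assumes C++17.

```cpp
#include <cassert>
#include <optional>
#include <type_traits>

// Hypothetical stand-in for a nullable handle: it can be tested in an `if`,
// but only through an explicit conversion.
struct Handle {
  void *impl = nullptr;
  explicit operator bool() const { return impl != nullptr; }
};

// Implicit conversion to bool is rejected, so is_convertible is false.
static_assert(!std::is_convertible<Handle, bool>::value);
// Direct-initialization `bool(handle)` is allowed, so is_constructible is true.
static_assert(std::is_constructible_v<bool, Handle>);

// Result-type selection before and after the change.
template <typename T>
using OldResultT = std::conditional_t<std::is_convertible<T, bool>::value, T,
                                      std::optional<T>>;
template <typename T>
using NewResultT =
    std::conditional_t<std::is_constructible_v<bool, T>, T, std::optional<T>>;

// The old trait wraps the nullable handle in std::optional; the new one
// returns the handle itself.
static_assert(std::is_same_v<OldResultT<Handle>, std::optional<Handle>>);
static_assert(std::is_same_v<NewResultT<Handle>, Handle>);

// Contextual conversion (as in `!h` or `if (h)`) may use the explicit operator.
inline bool isNull(const Handle &h) { return !h; }
```

Types that convert to `bool` implicitly, such as raw pointers, satisfy both traits, so they select the same result type before and after the change.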

Contributor

Sorry I don't understand this one - also is this needed for this PR?

Member Author

@PragmaTwice PragmaTwice Sep 23, 2025

Yup. Due to this review suggestion #159926 (comment), we now use dyn_cast instead of isa and cast on PDLValue, and this change addresses issues in PDLValue::dyn_cast. (Previously we didn't need this change : )
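As a loose analogy for the two access styles mentioned here, the sketch below contrasts a "query, then cast" pattern with a single "cast or null" call. It uses `std::variant` purely as a stand-in and makes no claim about how PDLValue is implemented; `AttrKind`, `TypeKind`, `AnyValue`, and both helper functions are hypothetical.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <variant>

// Toy value that holds exactly one of several kinds (assumption: two kinds).
struct AttrKind {
  int value;
};
struct TypeKind {
  std::string name;
};
using AnyValue = std::variant<AttrKind, TypeKind>;

// "isa + cast" style: one call to query the kind, a second to access it.
inline std::optional<int> attrViaIsaCast(const AnyValue &v) {
  if (std::holds_alternative<AttrKind>(v))
    return std::get<AttrKind>(v).value;
  return std::nullopt;
}

// "dyn_cast" style: a single call that yields a null pointer on mismatch.
inline std::optional<int> attrViaDynCast(const AnyValue &v) {
  if (const AttrKind *attr = std::get_if<AttrKind>(&v))
    return attr->value;
  return std::nullopt;
}
```

Both helpers return the same results; the second form simply folds the kind check and the access into one step.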

Member Author

I can change the API back to its original form if we want to avoid touching that header. Maybe @jpienaar has an idea for this : )

Contributor

Nah it's fine I just didn't see that comment about dyn_cast

Contributor

one random q: why is this file named .h.inc? @jpienaar?

Contributor

@makslevental makslevental left a comment

LGTM - @ftynse any other asks?

@PragmaTwice
Member Author

Thank you all and I'll merge it soon. 🥳

@PragmaTwice PragmaTwice merged commit b5daf76 into llvm:main Sep 24, 2025
9 checks passed
rupprecht added a commit that referenced this pull request Sep 24, 2025
```
external/llvm-project/mlir/lib/CAPI/Transforms/Rewrite.cpp:17:10: error: use of private header from outside its module: 'mlir/IR/PDLPatternMatch.h.inc' [-Wprivate-header]
   17 | #include "mlir/IR/PDLPatternMatch.h.inc"
      |          ^
```
PragmaTwice added a commit that referenced this pull request Sep 25, 2025
…160520)

This is a follow-up to #159926.

That PR (#159926) exposed native rewrite function registration in PDL
through the C API and Python, enabling use with
`pdl.apply_native_rewrite`.

In this PR, we add support for native constraint functions in PDL via
`pdl.apply_native_constraint`, further completing the PDL API.
zyx-billy added a commit that referenced this pull request Sep 29, 2025
Adds argument names to the method stubs for PDLResultList (from
#159926).