@PragmaTwice PragmaTwice commented Sep 27, 2025

In #160520, we discussed the current limitations of PDL rewriting in Python (see this comment). At the moment, new operations cannot be created in native PDL rewrite functions written in Python, because the PatternRewriter APIs are not exposed.

This PR introduces bindings to retrieve the insertion point of the PatternRewriter, enabling users to create new operations within Python rewrite functions. With this capability, more complex rewrites become possible, e.g. ones that use branching or loops while creating ops.

github-actions bot commented Sep 27, 2025

✅ With the latest revision this PR passed the Python code formatter.

Comment on lines 146 to 160
  nb::class_<MlirPatternRewriter>(m, "PatternRewriter")
      .def("ip", [](MlirPatternRewriter rewriter) {
        MlirRewriterBase base = mlirPatternRewriterAsBase(rewriter);
        MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
        MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
        MlirOperation owner = mlirBlockGetParentOperation(block);
        auto ctx = PyMlirContext::forContext(mlirRewriterBaseGetContext(base))
                       ->getRef();
        if (mlirOperationIsNull(op)) {
          auto parent = PyOperation::forOperation(ctx, owner);
          return PyInsertionPoint(PyBlock(parent, block));
        }

        return PyInsertionPoint(*PyOperation::forOperation(ctx, op).get());
      });
Member Author

Not sure if there is a good way to cast Mlir* CAPI types into Py* C++ classes. It seems that here we don't need to care too much about lifetime of blocks/operations (as long as the insertion point does not escape from the scope of the rewrite callback). 🤔

I'll try to define something like class PyPatternRewriter and see if that makes the code cleaner.
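
For readers following the thread, the two cases handled in the snippet above (an operation right after the insertion point, or a null operation meaning "end of block") can be modelled in plain Python. This is only a sketch under stated assumptions, not the binding code: `ModelInsertionPoint` and `model_ip` are hypothetical names, a block is modelled as a list of strings, and `None` stands in for a null `MlirOperation`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelInsertionPoint:
    """Hypothetical stand-in for PyInsertionPoint; ops are plain strings."""

    block: List[str]
    before: Optional[str]  # None models a null MlirOperation (end of block)

    def insert(self, op: str) -> None:
        if self.before is None:
            # Insertion point is at the end of the block: append.
            self.block.append(op)
        else:
            # Insertion point is before a reference op: insert in front of it.
            self.block.insert(self.block.index(self.before), op)


def model_ip(block: List[str], op_after: Optional[str]) -> ModelInsertionPoint:
    """Mirror the null check: choose the block-end or the before-op case."""
    return ModelInsertionPoint(block, op_after)


# Case 1: an op follows the insertion point, so new ops land before it.
body = ["myint.constant", "func.return"]
ip = model_ip(body, "func.return")
ip.insert("new.op0")
ip.insert("new.op1")

# Case 2: nothing follows the insertion point, so new ops are appended.
tail = ["myint.constant"]
model_ip(tail, None).insert("new.op0")
```

In this model, successive inserts before the same reference op keep their creation order, which is the behaviour one would want from an insertion point used inside a `with` block.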

Member Author

Done in 68264af.

@PragmaTwice PragmaTwice marked this pull request as ready for review October 3, 2025 08:47
@llvmbot llvmbot added the mlir label Oct 3, 2025
llvmbot commented Oct 3, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/161001.diff

6 Files Affected:

  • (modified) mlir/include/mlir-c/Rewrite.h (+14)
  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+3)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+2)
  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+31-3)
  • (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+16)
  • (modified) mlir/test/python/integration/dialects/pdl.py (+89-2)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 77be1f480eacf..5dd285ee076c4 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -101,6 +101,12 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
 MLIR_CAPI_EXPORTED MlirBlock
 mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
 
+/// Returns the operation right after the current insertion point
+/// of the rewriter. A null MlirOperation will be returned
+/// if the current insertion point is at the end of the block.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
+
 //===----------------------------------------------------------------------===//
 /// Block and operation creation/insertion/cloning
 //===----------------------------------------------------------------------===//
@@ -310,6 +316,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
 
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Cast the PatternRewriter to a RewriterBase
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 32b2b0c648cff..7b1710656243a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2046,6 +2046,9 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
     : refOperation(beforeOperationBase.getOperation().getRef()),
       block((*refOperation)->getBlock()) {}
 
+PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
+    : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
+
 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
   PyOperation &operation = operationBase.getOperation();
   if (operation.isAttached())
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index edbd73eade906..e706be3b4d32a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -841,6 +841,8 @@ class PyInsertionPoint {
   PyInsertionPoint(const PyBlock &block);
   /// Creates an insertion point positioned before a reference operation.
   PyInsertionPoint(PyOperationBase &beforeOperationBase);
+  /// Creates an insertion point positioned before a reference operation.
+  PyInsertionPoint(PyOperationRef beforeOperationRef);
 
   /// Shortcut to create an insertion point at the beginning of the block.
   static PyInsertionPoint atBlockBegin(PyBlock &block);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 836f44fd7d4be..10b539a7b3c07 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -26,6 +26,31 @@ using namespace mlir::python;
 
 namespace {
 
+class PyPatternRewriter {
+public:
+  PyPatternRewriter(MlirPatternRewriter rewriter)
+      : rewriter(rewriter), base(mlirPatternRewriterAsBase(rewriter)),
+        ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
+
+  PyInsertionPoint getInsertionPoint() const {
+    MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+    MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+
+    if (mlirOperationIsNull(op)) {
+      MlirOperation owner = mlirBlockGetParentOperation(block);
+      auto parent = PyOperation::forOperation(ctx, owner);
+      return PyInsertionPoint(PyBlock(parent, block));
+    }
+
+    return PyInsertionPoint(PyOperation::forOperation(ctx, op));
+  }
+
+private:
+  MlirPatternRewriter rewriter [[maybe_unused]];
+  MlirRewriterBase base;
+  PyMlirContextRef ctx;
+};
+
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 static nb::object objectFromPDLValue(MlirPDLValue value) {
   if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -84,7 +109,8 @@ class PyPDLPatternModule {
            void *userData) -> MlirLogicalResult {
           nb::handle f = nb::handle(static_cast<PyObject *>(userData));
           return logicalResultFromObject(
-              f(rewriter, results, objectsFromPDLValues(nValues, values)));
+              f(PyPatternRewriter(rewriter), results,
+                objectsFromPDLValues(nValues, values)));
         },
         fn.ptr());
   }
@@ -98,7 +124,8 @@ class PyPDLPatternModule {
            void *userData) -> MlirLogicalResult {
           nb::handle f = nb::handle(static_cast<PyObject *>(userData));
           return logicalResultFromObject(
-              f(rewriter, results, objectsFromPDLValues(nValues, values)));
+              f(PyPatternRewriter(rewriter), results,
+                objectsFromPDLValues(nValues, values)));
         },
         fn.ptr());
   }
@@ -143,7 +170,8 @@ class PyFrozenRewritePatternSet {
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
-  nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
+  nb::class_<PyPatternRewriter>(m, "PyPatternRewriter")
+      .def("ip", &PyPatternRewriter::getInsertionPoint);
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8ee6308cadf83..b149d35f0d88b 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -70,6 +70,18 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
   return wrap(unwrap(rewriter)->getBlock());
 }
 
+MlirOperation
+mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
+  mlir::RewriterBase *base = unwrap(rewriter);
+  mlir::Block *block = base->getInsertionBlock();
+  auto it = base->getInsertionPoint();
+  if (it == block->end()) {
+    return {nullptr};
+  }
+
+  return wrap(std::addressof(*it));
+}
+
 //===----------------------------------------------------------------------===//
 /// Block and operation creation/insertion/cloning
 //===----------------------------------------------------------------------===//
@@ -316,6 +328,10 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
   return {rewriter};
 }
 
+MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
+  return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index c8e6197e03842..752d213673a70 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
             print(module)
     return f
 
+
 def get_pdl_patterns():
     # Create a rewrite from add to mul. This will match
     # - operation name is arith.addi
@@ -121,8 +122,10 @@ def load_myint_dialect():
 
 
 # This PDL pattern is to fold constant additions,
-# i.e. add(constant0, constant1) -> constant2
-# where constant2 = constant0 + constant1.
+# including two patterns:
+# 1. add(constant0, constant1) -> constant2
+#    where constant2 = constant0 + constant1;
+# 2. add(x, 0) or add(0, x) -> x.
 def get_pdl_pattern_fold():
     m = Module.create()
     i32 = IntegerType.get_signless(32)
@@ -237,3 +240,87 @@ def test_pdl_register_function_constraint(module_):
     apply_patterns_and_fold_greedily(module_, frozen)
 
     return module_
+
+
+# This pattern is to expand constant to additions
+# unless the constant is no more than 1,
+# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
+def get_pdl_pattern_expand():
+    m = Module.create()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(m.body):
+
+        @pdl.pattern(benefit=1, sym_name="myint_constant_expand")
+        def pat():
+            t = pdl.TypeOp(i32)
+            cst = pdl.AttributeOp()
+            pdl.apply_native_constraint([], "is_one", [cst])
+            op0 = pdl.OperationOp(
+                name="myint.constant", attributes={"value": cst}, types=[t]
+            )
+
+            @pdl.rewrite()
+            def rew():
+                expanded = pdl.apply_native_rewrite(
+                    [pdl.OperationType.get()], "expand", [cst]
+                )
+                pdl.ReplaceOp(op0, with_op=expanded)
+
+    def is_one(rewriter, results, values):
+        cst = values[0].value
+        return cst <= 1
+
+    def expand(rewriter, results, values):
+        cst = values[0].value
+        c1 = cst // 2
+        c2 = cst - c1
+        with rewriter.ip():
+            op1 = Operation.create(
+                "myint.constant",
+                results=[i32],
+                attributes={"value": IntegerAttr.get(i32, c1)},
+            )
+            op2 = Operation.create(
+                "myint.constant",
+                results=[i32],
+                attributes={"value": IntegerAttr.get(i32, c2)},
+            )
+            res = Operation.create(
+                "myint.add", results=[i32], operands=[op1.result, op2.result]
+            )
+        results.append(res)
+
+    pdl_module = PDLModule(m)
+    pdl_module.register_constraint_function("is_one", is_one)
+    pdl_module.register_rewrite_function("expand", expand)
+    return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function_expand
+# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
+# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
+# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
+# CHECK: return %8 : i32
+@construct_and_print_in_module
+def test_pdl_register_function_expand(module_):
+    load_myint_dialect()
+
+    module_ = Module.parse(
+        """
+        func.func @f() -> i32 {
+          %0 = "myint.constant"() { value = 5 }: () -> (i32)
+          return %0 : i32
+        }
+        """
+    )
+
+    frozen = get_pdl_pattern_expand()
+    apply_patterns_and_fold_greedily(module_, frozen)
+
+    return module_

Comment on lines 273 to 291
    def expand(rewriter, results, values):
        cst = values[0].value
        c1 = cst // 2
        c2 = cst - c1
        with rewriter.ip():
            op1 = Operation.create(
                "myint.constant",
                results=[i32],
                attributes={"value": IntegerAttr.get(i32, c1)},
            )
            op2 = Operation.create(
                "myint.constant",
                results=[i32],
                attributes={"value": IntegerAttr.get(i32, c2)},
            )
            res = Operation.create(
                "myint.add", results=[i32], operands=[op1.result, op2.result]
            )
        results.append(res)
Member Author

@PragmaTwice PragmaTwice Oct 3, 2025

This function serves as an example of retrieving and using the rewriter's insertion point.
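
To see why the test expects five constants and four additions for an input of 5, the splitting arithmetic of `expand` can be traced without MLIR. The sketch below is a plain-Python model, not the rewrite itself: `model_expand` is a hypothetical helper, it recurses directly instead of going through the greedy driver, and it numbers values in post-order, which is an assumption chosen to line up with the printed SSA names in the CHECK lines.

```python
from itertools import count


def model_expand(value, ops, fresh):
    """Split `value` as `expand` does (c1 = value // 2, c2 = value - c1)."""
    if value <= 1:
        # Constants no greater than 1 are left alone by the pattern.
        name = f"%{next(fresh)}"
        ops.append(f'{name} = "myint.constant"() {{value = {value} : i32}}')
        return name
    c1 = value // 2
    c2 = value - c1
    lhs = model_expand(c1, ops, fresh)
    rhs = model_expand(c2, ops, fresh)
    name = f"%{next(fresh)}"
    ops.append(f'{name} = "myint.add"({lhs}, {rhs})')
    return name


ops = []
root = model_expand(5, ops, count())
for line in ops:
    print(line)
```

Under these assumptions, 5 splits into 2 + 3, then 2 into 1 + 1 and 3 into 1 + (1 + 1), so the model ends with `%8` as the root addition of `%2` and `%7`.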

PragmaTwice and others added 2 commits October 4, 2025 12:26
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Contributor

@makslevental makslevental left a comment

LGTM!

@PragmaTwice PragmaTwice merged commit 8181c3d into llvm:main Oct 5, 2025
9 checks passed