Skip to content

Conversation

PragmaTwice
Copy link
Member

@PragmaTwice PragmaTwice commented Oct 9, 2025

This PR adds support for defining custom RewritePattern implementations directly in the Python bindings.

Previously, users could define similar patterns using the PDL dialect’s bindings. However, for more complex patterns, this often required writing multiple Python callbacks as PDL native constraints or rewrite functions, which made the overall logic less intuitive—though it could be more performant than a pure Python implementation (especially for simple patterns).

With this change, we introduce an additional, straightforward way to define patterns purely in Python, complementing the existing PDL-based approach.

Example

def to_muli(op, rewriter):
    with rewriter.ip:
        new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
    rewriter.replace_op(op, new_op.owner)

with Context():
    patterns = RewritePatternSet()
    patterns.add(arith.AddIOp, to_muli) # a pattern that rewrites arith.addi to arith.muli
    frozen = patterns.freeze()

    module = ...
    apply_patterns_and_fold_greedily(module, frozen)

Copy link

github-actions bot commented Oct 9, 2025

✅ With the latest revision this PR passed the Python code formatter.

@PragmaTwice PragmaTwice marked this pull request as ready for review October 10, 2025 03:42
@llvmbot llvmbot added the mlir label Oct 10, 2025
@PragmaTwice
Copy link
Member Author

I think it is ready for review now : ) Feel free to comment!

@llvmbot
Copy link
Member

llvmbot commented Oct 10, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

This PR adds support for defining custom RewritePattern implementations directly in the Python bindings.

Previously, users could define similar patterns using the PDL dialect’s bindings. However, for more complex patterns, this often required writing multiple Python callbacks as PDL native constraints or rewrite functions, which made the overall logic less intuitive—though it could be more performant than a pure Python implementation (especially for simple patterns).

With this change, we introduce an additional, straightforward way to define patterns purely in Python, complementing the existing PDL-based approach.

Example

def to_muli(op, rewriter, pattern):
    with rewriter.ip:
        new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
    rewriter.replace_op(op, new_op.owner)

with Context():
    patterns = RewritePatternSet()
    patterns.add(arith.AddIOp, to_muli) # a pattern that rewrites arith.addi to arith.muli
    frozen = patterns.freeze()

    module = ...
    apply_patterns_and_fold_greedily(module, frozen)

Full diff: https://github.com/llvm/llvm-project/pull/162699.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/Rewrite.h (+55-2)
  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+116-4)
  • (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+89-8)
  • (added) mlir/test/python/rewrite.py (+77)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 5dd285ee076c4..66a9a5de1669d 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
+DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -302,11 +303,15 @@ MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
 /// FrozenRewritePatternSet API
 //===----------------------------------------------------------------------===//
 
+/// Freeze the given MlirRewritePatternSet to a MlirFrozenRewritePatternSet.
+/// Note that the ownership of the input set is transferred into the frozen set
+/// after this call.
 MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
-mlirFreezeRewritePattern(MlirRewritePatternSet op);
+mlirFreezeRewritePattern(MlirRewritePatternSet set);
 
+/// Destroy the given MlirFrozenRewritePatternSet.
 MLIR_CAPI_EXPORTED void
-mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set);
 
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
     MlirOperation op, MlirFrozenRewritePatternSet patterns,
@@ -324,6 +329,54 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
 MLIR_CAPI_EXPORTED MlirRewriterBase
 mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
 
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+/// PatternBenefit represents the benefit of a pattern match.
+typedef unsigned short MlirPatternBenefit;
+
+/// Callbacks to construct a rewrite pattern.
+typedef struct {
+  /// Optional constructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// The callback function to match against code rooted at the specified
+  /// operation, and perform the rewrite if the match is successful,
+  /// corresponding to RewritePattern::matchAndRewrite.
+  MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
+                                       MlirOperation op,
+                                       MlirPatternRewriter rewriter,
+                                       void *userData);
+} MlirRewritePatternCallbacks;
+
+/// Create a rewrite pattern that matches the operation
+/// with the given rootName, corresponding to mlir::OpRewritePattern.
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames);
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+/// Create an empty MlirRewritePatternSet.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetCreate(MlirContext context);
+
+/// Destruct the given MlirRewritePatternSet.
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
+
+/// Add the given MlirRewritePattern into a MlirRewritePatternSet.
+/// Note that the ownership of the pattern is transferred to the set after this
+/// call.
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                                                 MlirRewritePattern pattern);
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d9703c82e8..07559457f2f2f 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,16 @@ class PyPatternRewriter {
     return PyInsertionPoint(PyOperation::forOperation(ctx, op));
   }
 
+  void replaceOp(MlirOperation op, MlirOperation newOp) {
+    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+  }
+
+  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+  }
+
+  void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
 private:
   MlirRewriterBase base;
   PyMlirContextRef ctx;
@@ -165,13 +175,115 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
+class PyRewritePatternSet {
+public:
+  PyRewritePatternSet(MlirContext ctx)
+      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+  ~PyRewritePatternSet() {
+    if (set.ptr)
+      mlirRewritePatternSetDestroy(set);
+  }
+
+  void add(MlirStringRef rootName, MlirPatternBenefit benefit,
+           const nb::callable &matchAndRewrite) {
+    MlirRewritePatternCallbacks callbacks;
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
+                                   MlirPatternRewriter rewriter,
+                                   void *userData) -> MlirLogicalResult {
+      nb::handle f(static_cast<PyObject *>(userData));
+      nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
+      return logicalResultFromObject(res);
+    };
+    MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        /* nGeneratedNames */ 0,
+        /* generatedNames */ nullptr);
+    mlirRewritePatternSetAdd(set, pattern);
+  }
+
+  PyFrozenRewritePatternSet freeze() {
+    MlirRewritePatternSet s = set;
+    set.ptr = nullptr;
+    return mlirFreezeRewritePattern(s);
+  }
+
+private:
+  MlirRewritePatternSet set;
+  MlirContext ctx;
+};
+
 } // namespace
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
-  nb::class_<PyPatternRewriter>(m, "PatternRewriter")
-      .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
-                   "The current insertion point of the PatternRewriter.");
+  //----------------------------------------------------------------------------
+  // Mapping of the PatternRewriter
+  //----------------------------------------------------------------------------
+  nb::
+      class_<PyPatternRewriter>(m, "PatternRewriter")
+          .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+                       "The current insertion point of the PatternRewriter.")
+          .def(
+              "replace_op",
+              [](PyPatternRewriter &self, MlirOperation op,
+                 MlirOperation newOp) { self.replaceOp(op, newOp); },
+              "Replace an operation with a new operation.",
+              // clang-format off
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
+                ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+              // clang-format on
+              )
+          .def(
+              "replace_op",
+              [](PyPatternRewriter &self, MlirOperation op,
+                 const std::vector<MlirValue> &values) {
+                self.replaceOp(op, values);
+              },
+              "Replace an operation with a list of values.",
+              // clang-format off
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
+                ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+              // clang-format on
+              )
+          .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+               // clang-format off
+                nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+               // clang-format on
+          );
+
+  //----------------------------------------------------------------------------
+  // Mapping of the RewritePatternSet
+  //----------------------------------------------------------------------------
+  nb::class_<MlirRewritePattern>(m, "RewritePattern");
+  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+      .def(
+          "__init__",
+          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+            new (&self) PyRewritePatternSet(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def(
+          "add",
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             unsigned benefit) {
+            std::string opName =
+                nb::cast<std::string>(root.attr("OPERATION_NAME"));
+            self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+                     fn);
+          },
+          "root"_a, "fn"_a, "benefit"_a = 1,
+          "Add a new rewrite pattern on the given root operation with the "
+          "callable as the matching and rewriting function and the given "
+          "benefit.")
+      .def("freeze", &PyRewritePatternSet::freeze,
+           "Freeze the pattern set into a frozen one.");
+
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
@@ -237,7 +349,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "freeze",
           [](PyPDLPatternModule &self) {
-            return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+            return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
                 mlirRewritePatternSetFromPDLPatternModule(self.get())));
           },
           nb::keep_alive<0, 1>())
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b991f5d..d7c8e53f2bba6 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -270,9 +270,9 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
 /// RewritePatternSet and FrozenRewritePatternSet API
 //===----------------------------------------------------------------------===//
 
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+static inline mlir::RewritePatternSet *unwrap(MlirRewritePatternSet module) {
   assert(module.ptr && "unexpected null module");
-  return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
+  return static_cast<mlir::RewritePatternSet *>(module.ptr);
 }
 
 static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
@@ -290,15 +290,16 @@ wrap(mlir::FrozenRewritePatternSet *module) {
   return {module};
 }
 
-MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
-  auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
-  op.ptr = nullptr;
+MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet set) {
+  auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
+  set.ptr = nullptr;
   return wrap(m);
 }
 
-void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
-  delete unwrap(op);
-  op.ptr = nullptr;
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
+  delete unwrap(set);
+  set.ptr = nullptr;
 }
 
 MlirLogicalResult
@@ -332,6 +333,86 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
   return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
 }
 
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) {
+  assert(pattern.ptr && "unexpected null pattern");
+  return static_cast<const mlir::RewritePattern *>(pattern.ptr);
+}
+
+inline MlirRewritePattern wrap(const mlir::RewritePattern *pattern) {
+  return {pattern};
+}
+
+namespace mlir {
+
+class ExternalRewritePattern : public mlir::RewritePattern {
+public:
+  ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+                         StringRef rootName, PatternBenefit benefit,
+                         MLIRContext *context,
+                         ArrayRef<StringRef> generatedNames)
+      : RewritePattern(rootName, benefit, context, generatedNames),
+        callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  ~ExternalRewritePattern() {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+        wrap(&rewriter), userData));
+  }
+
+private:
+  MlirRewritePatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  std::vector<mlir::StringRef> generatedNamesVec;
+  generatedNamesVec.reserve(nGeneratedNames);
+  for (size_t i = 0; i < nGeneratedNames; ++i) {
+    generatedNamesVec.push_back(unwrap(generatedNames[i]));
+  }
+  return wrap(new mlir::ExternalRewritePattern(
+      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+      unwrap(context), generatedNamesVec));
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+  return wrap(new mlir::RewritePatternSet(unwrap(context)));
+}
+
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+  delete unwrap(set);
+}
+
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                              MlirRewritePattern pattern) {
+  std::unique_ptr<mlir::RewritePattern> patternPtr(
+      const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+  pattern.ptr = nullptr;
+  unwrap(set)->add(std::move(patternPtr));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000000000..cbc3a4043f96c
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,77 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def log(*args):
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
+
+
+def run(f):
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+    def to_muli(op, rewriter, pattern):
+        with rewriter.ip:
+            new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+        rewriter.replace_op(op, new_op.owner)
+
+    def constant_1_to_2(op, rewriter, pattern):
+        c = op.attributes["value"].value
+        if c != 1:
+            return True # failed to match
+        with rewriter.ip:
+            new_op = arith.constant(op.result.type, 2, loc=op.location)
+        rewriter.replace_op(op, [new_op])
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add(arith.AddIOp, to_muli)
+        patterns.add(arith.ConstantOp, constant_1_to_2)
+        frozen = patterns.freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+        # CHECK: return %0 : i64
+        print(module)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @const() -> (i64, i64) {
+                %0 = arith.constant 1 : i64
+                %1 = arith.constant 3 : i64
+                return %0, %1 : i64, i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: %c3_i64 = arith.constant 3 : i64
+        # CHECK: return %c2_i64, %c3_i64 : i64, i64
+        print(module)

Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Copy link

github-actions bot commented Oct 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

PragmaTwice and others added 2 commits October 11, 2025 10:01
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great.

@PragmaTwice
Copy link
Member Author

Thank you! I'll merge it soon.

@PragmaTwice PragmaTwice merged commit 7aec3f2 into llvm:main Oct 11, 2025
10 checks passed
PragmaTwice added a commit that referenced this pull request Oct 11, 2025
…162974)

This is a follow-up PR of #162699.

In this PR we clean CAPI and Python bindings of MLIR rewrite part by:
- remove all manually-defined `wrap`/`unwrap` functions;
- remove useless nanobind-defined Python class `RewritePattern`.
PragmaTwice added a commit that referenced this pull request Oct 13, 2025
… patterns (#163080)

This is a follow-up PR for #162699.

Currently, in the function where we define rewrite patterns, the `op` we
receive is of type `ir.Operation` rather than a specific `OpView` type
(such as `arith.AddIOp`). This means we can’t conveniently access
certain parts of the operation — for example, we need to use
`op.operands[0]` instead of `op.lhs`. The following example code
illustrates this situation.

```python
def to_muli(op, rewriter):
  # op is typed ir.Operation instead of arith.AddIOp
  pass

patterns.add(arith.AddIOp, to_muli)
```

In this PR, we convert the operation to its corresponding `OpView`
subclass before invoking the rewrite pattern callback, making it much
easier to write patterns.

---------

Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants