Skip to content

Conversation

PragmaTwice
Copy link
Member

This is a follow-up PR for #162699.

Currently, in the function where we define rewrite patterns, the op we receive is of type ir.Operation rather than a specific OpView type (such as arith.AddIOp). This means we can’t conveniently access certain parts of the operation — for example, we need to use op.operands[0] instead of op.lhs. The following example code illustrates this situation.

def to_muli(op, rewriter):
  # op is typed ir.Operation instead of arith.AddIOp
  pass

patterns.add(arith.AddIOp, to_muli)

In this PR, we convert the operation to its corresponding OpView subclass before invoking the rewrite pattern callback, making it much easier to write patterns.

@llvmbot
Copy link
Member

llvmbot commented Oct 12, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

This is a follow-up PR for #162699.

Currently, in the function where we define rewrite patterns, the op we receive is of type ir.Operation rather than a specific OpView type (such as arith.AddIOp). This means we can’t conveniently access certain parts of the operation — for example, we need to use op.operands[0] instead of op.lhs. The following example code illustrates this situation.

def to_muli(op, rewriter):
  # op is typed ir.Operation instead of arith.AddIOp
  pass

patterns.add(arith.AddIOp, to_muli)

In this PR, we convert the operation to its corresponding OpView subclass before invoking the rewrite pattern callback, making it much easier to write patterns.


Full diff: https://github.com/llvm/llvm-project/pull/163080.diff

2 Files Affected:

  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+6-1)
  • (modified) mlir/test/python/rewrite.py (+3-3)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 47685567d5355..5ddb3fbbb1317 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -197,7 +197,12 @@ class PyRewritePatternSet {
                                    MlirPatternRewriter rewriter,
                                    void *userData) -> MlirLogicalResult {
       nb::handle f(static_cast<PyObject *>(userData));
-      nb::object res = f(op, PyPatternRewriter(rewriter));
+
+      PyMlirContextRef ctx =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
+
+      nb::object res = f(opView, PyPatternRewriter(rewriter));
       return logicalResultFromObject(res);
     };
     MlirRewritePattern pattern = mlirOpRewritePattenCreate(
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index acf7db23db914..523e6a40f7470 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -17,15 +17,15 @@ def run(f):
 def testRewritePattern():
     def to_muli(op, rewriter):
         with rewriter.ip:
-            new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+            new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
         rewriter.replace_op(op, new_op.owner)
 
     def constant_1_to_2(op, rewriter):
-        c = op.attributes["value"].value
+        c = IntegerAttr(op.value).value
         if c != 1:
             return True  # failed to match
         with rewriter.ip:
-            new_op = arith.constant(op.result.type, 2, loc=op.location)
+            new_op = arith.constant(op.type, 2, loc=op.location)
         rewriter.replace_op(op, [new_op])
 
     with Context():

Comment on lines +200 to +202

PyMlirContextRef ctx =
PyMlirContext::forContext(mlirOperationGetContext(op));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should change all the signatures in here to be in terms of Py* so that we don't have to do this for* stuff (each one allocates a new Python object...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(but we don't need to do that in this PR)

Copy link
Member Author

@PragmaTwice PragmaTwice Oct 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I tried but here is a C callback (The function signature is defined in C API) so it seems hard to use Py* types directly? 🤔

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM module one nit!

PragmaTwice and others added 2 commits October 13, 2025 09:22
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
@llvmbot llvmbot added the mlir:python MLIR Python bindings label Oct 13, 2025
@PragmaTwice
Copy link
Member Author

Thank you all! I'll merge this PR soon.

@PragmaTwice PragmaTwice merged commit 06e2c78 into llvm:main Oct 13, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants