-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] Pass OpView subclasses instead of Operation in rewrite patterns #163080
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Python] Pass OpView subclasses instead of Operation in rewrite patterns #163080
Conversation
@llvm/pr-subscribers-mlir Author: Twice (PragmaTwice) ChangesThis is a follow-up PR for #162699. Currently, in the function where we define rewrite patterns, the def to_muli(op, rewriter):
# op is typed ir.Operation instead of arith.AddIOp
pass
patterns.add(arith.AddIOp, to_muli) In this PR, we convert the operation to its corresponding Full diff: https://github.com/llvm/llvm-project/pull/163080.diff 2 Files Affected:
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 47685567d5355..5ddb3fbbb1317 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -197,7 +197,12 @@ class PyRewritePatternSet {
MlirPatternRewriter rewriter,
void *userData) -> MlirLogicalResult {
nb::handle f(static_cast<PyObject *>(userData));
- nb::object res = f(op, PyPatternRewriter(rewriter));
+
+ PyMlirContextRef ctx =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
+
+ nb::object res = f(opView, PyPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
MlirRewritePattern pattern = mlirOpRewritePattenCreate(
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index acf7db23db914..523e6a40f7470 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -17,15 +17,15 @@ def run(f):
def testRewritePattern():
def to_muli(op, rewriter):
with rewriter.ip:
- new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+ new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
rewriter.replace_op(op, new_op.owner)
def constant_1_to_2(op, rewriter):
- c = op.attributes["value"].value
+ c = IntegerAttr(op.value).value
if c != 1:
return True # failed to match
with rewriter.ip:
- new_op = arith.constant(op.result.type, 2, loc=op.location)
+ new_op = arith.constant(op.type, 2, loc=op.location)
rewriter.replace_op(op, [new_op])
with Context():
|
|
||
PyMlirContextRef ctx = | ||
PyMlirContext::forContext(mlirOperationGetContext(op)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should change all the signatures in here to be in terms of Py*
so that we don't have to do this for*
stuff (each one allocates a new Python object...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(but we don't need to do that in this PR)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I tried but here is a C callback (The function signature is defined in C API) so it seems hard to use Py*
types directly? 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM module one nit!
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Thank you all! I'll merge this PR soon. |
This is a follow-up PR for #162699.
Currently, in the function where we define rewrite patterns, the
op
we receive is of typeir.Operation
rather than a specificOpView
type (such asarith.AddIOp
). This means we can’t conveniently access certain parts of the operation — for example, we need to useop.operands[0]
instead ofop.lhs
. The following example code illustrates this situation.In this PR, we convert the operation to its corresponding
OpView
subclass before invoking the rewrite pattern callback, making it much easier to write patterns.