[MLIR][Python] Expose the insertion point of pattern rewriter #161001
✅ With the latest revision this PR passed the Python code formatter.
mlir/lib/Bindings/Python/Rewrite.cpp
Outdated
nb::class_<MlirPatternRewriter>(m, "PatternRewriter")
    .def("ip", [](MlirPatternRewriter rewriter) {
      MlirRewriterBase base = mlirPatternRewriterAsBase(rewriter);
      MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
      MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
      MlirOperation owner = mlirBlockGetParentOperation(block);
      auto ctx = PyMlirContext::forContext(mlirRewriterBaseGetContext(base))
                     ->getRef();
      if (mlirOperationIsNull(op)) {
        auto parent = PyOperation::forOperation(ctx, owner);
        return PyInsertionPoint(PyBlock(parent, block));
      }

      return PyInsertionPoint(*PyOperation::forOperation(ctx, op).get());
    });
Not sure if there is a good way to cast Mlir* CAPI types into Py* C++ classes. It seems that here we don't need to care too much about the lifetime of blocks/operations (as long as the insertion point does not escape the scope of the rewrite callback). 🤔
I'll try to define something like class PyPatternRewriter and see if that makes the code cleaner.
Done in 68264af.
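The two branches in the binding above (insert at the end of the block when no operation follows the insertion point, otherwise insert before that operation) can be illustrated with a toy model. This is only a sketch in plain Python: ToyRewriter and resolve_insertion_point are hypothetical names, a block is modelled as a list of strings, and nothing here calls the real MLIR bindings.

```python
class ToyRewriter:
    """Hypothetical stand-in for a rewriter: a block (list of op names) plus a position."""

    def __init__(self, block, position):
        self.block = block
        # position ranges over 0..len(block); len(block) means "at the end".
        self.position = position

    def operation_after_insertion(self):
        # Follows the null-at-end convention: None when no operation follows.
        if self.position == len(self.block):
            return None
        return self.block[self.position]


def resolve_insertion_point(rewriter):
    """Pick between 'end of block' and 'before the following op'."""
    op = rewriter.operation_after_insertion()
    if op is None:
        return ("end_of_block", rewriter.block)
    return ("before", op)


block = ["myint.constant", "func.return"]
print(resolve_insertion_point(ToyRewriter(block, 1)))
print(resolve_insertion_point(ToyRewriter(block, 2)))
```

The point of the sketch is only that a single nullable "operation after the insertion point" query is enough to recover both kinds of insertion point.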
@llvm/pr-subscribers-mlir
Author: Twice (PragmaTwice)
Changes
In #160520, we discussed the current limitations of PDL rewriting in Python (see this comment). At the moment, we cannot create new operations in PDL native (Python) rewrite functions because the PatternRewriter APIs are not exposed.
This PR introduces bindings to retrieve the insertion point of the PatternRewriter, enabling users to create new operations within Python rewrite functions.
Full diff: https://github.com/llvm/llvm-project/pull/161001.diff
6 Files Affected:
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 77be1f480eacf..5dd285ee076c4 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -101,6 +101,12 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
MLIR_CAPI_EXPORTED MlirBlock
mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
+/// Returns the operation right after the current insertion point
+/// of the rewriter. A null MlirOperation will be returned
+// if the current insertion point is at the end of the block.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
+
//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
@@ -310,6 +316,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Cast the PatternRewriter to a RewriterBase
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 32b2b0c648cff..7b1710656243a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2046,6 +2046,9 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
: refOperation(beforeOperationBase.getOperation().getRef()),
block((*refOperation)->getBlock()) {}
+PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
+ : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
+
void PyInsertionPoint::insert(PyOperationBase &operationBase) {
PyOperation &operation = operationBase.getOperation();
if (operation.isAttached())
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index edbd73eade906..e706be3b4d32a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -841,6 +841,8 @@ class PyInsertionPoint {
PyInsertionPoint(const PyBlock &block);
/// Creates an insertion point positioned before a reference operation.
PyInsertionPoint(PyOperationBase &beforeOperationBase);
+ /// Creates an insertion point positioned before a reference operation.
+ PyInsertionPoint(PyOperationRef beforeOperationRef);
/// Shortcut to create an insertion point at the beginning of the block.
static PyInsertionPoint atBlockBegin(PyBlock &block);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 836f44fd7d4be..10b539a7b3c07 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -26,6 +26,31 @@ using namespace mlir::python;
namespace {
+class PyPatternRewriter {
+public:
+ PyPatternRewriter(MlirPatternRewriter rewriter)
+ : rewriter(rewriter), base(mlirPatternRewriterAsBase(rewriter)),
+ ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
+
+ PyInsertionPoint getInsertionPoint() const {
+ MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+ MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+
+ if (mlirOperationIsNull(op)) {
+ MlirOperation owner = mlirBlockGetParentOperation(block);
+ auto parent = PyOperation::forOperation(ctx, owner);
+ return PyInsertionPoint(PyBlock(parent, block));
+ }
+
+ return PyInsertionPoint(PyOperation::forOperation(ctx, op));
+ }
+
+private:
+ MlirPatternRewriter rewriter [[maybe_unused]];
+ MlirRewriterBase base;
+ PyMlirContextRef ctx;
+};
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -84,7 +109,8 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(rewriter, results, objectsFromPDLValues(nValues, values)));
+ f(PyPatternRewriter(rewriter), results,
+ objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
@@ -98,7 +124,8 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(rewriter, results, objectsFromPDLValues(nValues, values)));
+ f(PyPatternRewriter(rewriter), results,
+ objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
@@ -143,7 +170,8 @@ class PyFrozenRewritePatternSet {
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
- nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
+ nb::class_<PyPatternRewriter>(m, "PyPatternRewriter")
+ .def("ip", &PyPatternRewriter::getInsertionPoint);
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8ee6308cadf83..b149d35f0d88b 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -70,6 +70,18 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
return wrap(unwrap(rewriter)->getBlock());
}
+MlirOperation
+mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
+ mlir::RewriterBase *base = unwrap(rewriter);
+ mlir::Block *block = base->getInsertionBlock();
+ auto it = base->getInsertionPoint();
+ if (it == block->end()) {
+ return {nullptr};
+ }
+
+ return wrap(std::addressof(*it));
+}
+
//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
@@ -316,6 +328,10 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
return {rewriter};
}
+MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
+ return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index c8e6197e03842..752d213673a70 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
print(module)
return f
+
def get_pdl_patterns():
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
@@ -121,8 +122,10 @@ def load_myint_dialect():
# This PDL pattern is to fold constant additions,
-# i.e. add(constant0, constant1) -> constant2
-# where constant2 = constant0 + constant1.
+# including two patterns:
+# 1. add(constant0, constant1) -> constant2
+# where constant2 = constant0 + constant1;
+# 2. add(x, 0) or add(0, x) -> x.
def get_pdl_pattern_fold():
m = Module.create()
i32 = IntegerType.get_signless(32)
@@ -237,3 +240,87 @@ def test_pdl_register_function_constraint(module_):
apply_patterns_and_fold_greedily(module_, frozen)
return module_
+
+
+# This pattern is to expand constant to additions
+# unless the constant is no more than 1,
+# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
+def get_pdl_pattern_expand():
+ m = Module.create()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(m.body):
+
+ @pdl.pattern(benefit=1, sym_name="myint_constant_expand")
+ def pat():
+ t = pdl.TypeOp(i32)
+ cst = pdl.AttributeOp()
+ pdl.apply_native_constraint([], "is_one", [cst])
+ op0 = pdl.OperationOp(
+ name="myint.constant", attributes={"value": cst}, types=[t]
+ )
+
+ @pdl.rewrite()
+ def rew():
+ expanded = pdl.apply_native_rewrite(
+ [pdl.OperationType.get()], "expand", [cst]
+ )
+ pdl.ReplaceOp(op0, with_op=expanded)
+
+ def is_one(rewriter, results, values):
+ cst = values[0].value
+ return cst <= 1
+
+ def expand(rewriter, results, values):
+ cst = values[0].value
+ c1 = cst // 2
+ c2 = cst - c1
+ with rewriter.ip():
+ op1 = Operation.create(
+ "myint.constant",
+ results=[i32],
+ attributes={"value": IntegerAttr.get(i32, c1)},
+ )
+ op2 = Operation.create(
+ "myint.constant",
+ results=[i32],
+ attributes={"value": IntegerAttr.get(i32, c2)},
+ )
+ res = Operation.create(
+ "myint.add", results=[i32], operands=[op1.result, op2.result]
+ )
+ results.append(res)
+
+ pdl_module = PDLModule(m)
+ pdl_module.register_constraint_function("is_one", is_one)
+ pdl_module.register_rewrite_function("expand", expand)
+ return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function_expand
+# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
+# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
+# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
+# CHECK: return %8 : i32
+@construct_and_print_in_module
+def test_pdl_register_function_expand(module_):
+ load_myint_dialect()
+
+ module_ = Module.parse(
+ """
+ func.func @f() -> i32 {
+ %0 = "myint.constant"() { value = 5 }: () -> (i32)
+ return %0 : i32
+ }
+ """
+ )
+
+ frozen = get_pdl_pattern_expand()
+ apply_patterns_and_fold_greedily(module_, frozen)
+
+ return module_
def expand(rewriter, results, values):
    cst = values[0].value
    c1 = cst // 2
    c2 = cst - c1
    with rewriter.ip():
        op1 = Operation.create(
            "myint.constant",
            results=[i32],
            attributes={"value": IntegerAttr.get(i32, c1)},
        )
        op2 = Operation.create(
            "myint.constant",
            results=[i32],
            attributes={"value": IntegerAttr.get(i32, c2)},
        )
        res = Operation.create(
            "myint.add", results=[i32], operands=[op1.result, op2.result]
        )
    results.append(res)
This function serves as an example of retrieving and using the insertion point of the rewriter.
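To see why the test input 5 ends up as five constants of value 1 combined by four additions, the splitting arithmetic of expand can be replayed without MLIR. The sketch below is plain Python and assumes the greedy driver keeps reapplying the pattern until every constant is at most 1; expand_constant and count_ops are hypothetical helper names, not part of the PR.

```python
def expand_constant(value):
    """Return an int leaf for constants <= 1, else ("add", lhs, rhs)."""
    if value <= 1:
        return value
    c1 = value // 2      # same split as in the test's expand()
    c2 = value - c1
    return ("add", expand_constant(c1), expand_constant(c2))


def count_ops(tree):
    """Return (number of constants, number of adds) in the expanded tree."""
    if isinstance(tree, int):
        return (1, 0)
    _, lhs, rhs = tree
    lhs_consts, lhs_adds = count_ops(lhs)
    rhs_consts, rhs_adds = count_ops(rhs)
    return (lhs_consts + rhs_consts, lhs_adds + rhs_adds + 1)


tree = expand_constant(5)
print(tree)             # ('add', ('add', 1, 1), ('add', 1, ('add', 1, 1)))
print(count_ops(tree))  # (5, 4)
```

The nested shape lines up with the CHECK lines of the test: one add of two constants, one add of a constant with another add, and a final add combining the two.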
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
LGTM!
In #160520, we discussed the current limitations of PDL rewriting in Python (see this comment). At the moment, we cannot create new operations in PDL native (Python) rewrite functions because the PatternRewriter APIs are not exposed.
This PR introduces bindings to retrieve the insertion point of the PatternRewriter, enabling users to create new operations within Python rewrite functions. With this capability, more complex rewrites, e.g. ones with branching and loops that involve op creation, become possible.