Skip to content

Commit 8181c3d

Browse files
[MLIR][Python] Expose the insertion point of pattern rewriter (#161001)
In [#160520](#160520), we discussed the current limitations of PDL rewriting in Python (see [this comment](#160520 (comment))). At the moment, we cannot create new operations in PDL native (python) rewrite functions because the `PatternRewriter` APIs are not exposed. This PR introduces bindings to retrieve the insertion point of the `PatternRewriter`, enabling users to create new operations within Python rewrite functions. With this capability, more complex rewrites e.g. with branching and loops that involve op creations become possible. --------- Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
1 parent 074308c commit 8181c3d

File tree

6 files changed

+154
-5
lines changed

6 files changed

+154
-5
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
101101
MLIR_CAPI_EXPORTED MlirBlock
102102
mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
103103

104+
/// Returns the operation right after the current insertion point
105+
/// of the rewriter. A null MlirOperation will be returned
106+
// if the current insertion point is at the end of the block.
107+
MLIR_CAPI_EXPORTED MlirOperation
108+
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
109+
104110
//===----------------------------------------------------------------------===//
105111
/// Block and operation creation/insertion/cloning
106112
//===----------------------------------------------------------------------===//
@@ -310,6 +316,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
310316
MlirModule op, MlirFrozenRewritePatternSet patterns,
311317
MlirGreedyRewriteDriverConfig);
312318

319+
//===----------------------------------------------------------------------===//
320+
/// PatternRewriter API
321+
//===----------------------------------------------------------------------===//
322+
323+
/// Cast the PatternRewriter to a RewriterBase
324+
MLIR_CAPI_EXPORTED MlirRewriterBase
325+
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
326+
313327
//===----------------------------------------------------------------------===//
314328
/// PDLPatternModule API
315329
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,6 +2046,9 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
20462046
: refOperation(beforeOperationBase.getOperation().getRef()),
20472047
block((*refOperation)->getBlock()) {}
20482048

2049+
PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
2050+
: refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
2051+
20492052
void PyInsertionPoint::insert(PyOperationBase &operationBase) {
20502053
PyOperation &operation = operationBase.getOperation();
20512054
if (operation.isAttached())

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,8 @@ class PyInsertionPoint {
841841
PyInsertionPoint(const PyBlock &block);
842842
/// Creates an insertion point positioned before a reference operation.
843843
PyInsertionPoint(PyOperationBase &beforeOperationBase);
844+
/// Creates an insertion point positioned before a reference operation.
845+
PyInsertionPoint(PyOperationRef beforeOperationRef);
844846

845847
/// Shortcut to create an insertion point at the beginning of the block.
846848
static PyInsertionPoint atBlockBegin(PyBlock &block);

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,30 @@ using namespace mlir::python;
2626

2727
namespace {
2828

29+
class PyPatternRewriter {
30+
public:
31+
PyPatternRewriter(MlirPatternRewriter rewriter)
32+
: base(mlirPatternRewriterAsBase(rewriter)),
33+
ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
34+
35+
PyInsertionPoint getInsertionPoint() const {
36+
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
37+
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
38+
39+
if (mlirOperationIsNull(op)) {
40+
MlirOperation owner = mlirBlockGetParentOperation(block);
41+
auto parent = PyOperation::forOperation(ctx, owner);
42+
return PyInsertionPoint(PyBlock(parent, block));
43+
}
44+
45+
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
46+
}
47+
48+
private:
49+
MlirRewriterBase base;
50+
PyMlirContextRef ctx;
51+
};
52+
2953
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3054
static nb::object objectFromPDLValue(MlirPDLValue value) {
3155
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -84,7 +108,8 @@ class PyPDLPatternModule {
84108
void *userData) -> MlirLogicalResult {
85109
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
86110
return logicalResultFromObject(
87-
f(rewriter, results, objectsFromPDLValues(nValues, values)));
111+
f(PyPatternRewriter(rewriter), results,
112+
objectsFromPDLValues(nValues, values)));
88113
},
89114
fn.ptr());
90115
}
@@ -98,7 +123,8 @@ class PyPDLPatternModule {
98123
void *userData) -> MlirLogicalResult {
99124
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
100125
return logicalResultFromObject(
101-
f(rewriter, results, objectsFromPDLValues(nValues, values)));
126+
f(PyPatternRewriter(rewriter), results,
127+
objectsFromPDLValues(nValues, values)));
102128
},
103129
fn.ptr());
104130
}
@@ -143,7 +169,9 @@ class PyFrozenRewritePatternSet {
143169

144170
/// Create the `mlir.rewrite` here.
145171
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
146-
nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
172+
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
173+
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
174+
"The current insertion point of the PatternRewriter.");
147175
//----------------------------------------------------------------------------
148176
// Mapping of the PDLResultList and PDLModule
149177
//----------------------------------------------------------------------------

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
7070
return wrap(unwrap(rewriter)->getBlock());
7171
}
7272

73+
MlirOperation
74+
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
75+
mlir::RewriterBase *base = unwrap(rewriter);
76+
mlir::Block *block = base->getInsertionBlock();
77+
mlir::Block::iterator it = base->getInsertionPoint();
78+
if (it == block->end())
79+
return {nullptr};
80+
81+
return wrap(std::addressof(*it));
82+
}
83+
7384
//===----------------------------------------------------------------------===//
7485
/// Block and operation creation/insertion/cloning
7586
//===----------------------------------------------------------------------===//
@@ -317,6 +328,10 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
317328
return {rewriter};
318329
}
319330

331+
MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
332+
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
333+
}
334+
320335
//===----------------------------------------------------------------------===//
321336
/// PDLPatternModule API
322337
//===----------------------------------------------------------------------===//

mlir/test/python/integration/dialects/pdl.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
1616
print(module)
1717
return f
1818

19+
1920
def get_pdl_patterns():
2021
# Create a rewrite from add to mul. This will match
2122
# - operation name is arith.addi
@@ -121,8 +122,10 @@ def load_myint_dialect():
121122

122123

123124
# This PDL pattern is to fold constant additions,
124-
# i.e. add(constant0, constant1) -> constant2
125-
# where constant2 = constant0 + constant1.
125+
# including two patterns:
126+
# 1. add(constant0, constant1) -> constant2
127+
# where constant2 = constant0 + constant1;
128+
# 2. add(x, 0) or add(0, x) -> x.
126129
def get_pdl_pattern_fold():
127130
m = Module.create()
128131
i32 = IntegerType.get_signless(32)
@@ -237,3 +240,87 @@ def test_pdl_register_function_constraint(module_):
237240
apply_patterns_and_fold_greedily(module_, frozen)
238241

239242
return module_
243+
244+
245+
# This pattern is to expand constant to additions
246+
# unless the constant is no more than 1,
247+
# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
248+
def get_pdl_pattern_expand():
249+
m = Module.create()
250+
i32 = IntegerType.get_signless(32)
251+
with InsertionPoint(m.body):
252+
253+
@pdl.pattern(benefit=1, sym_name="myint_constant_expand")
254+
def pat():
255+
t = pdl.TypeOp(i32)
256+
cst = pdl.AttributeOp()
257+
pdl.apply_native_constraint([], "is_one", [cst])
258+
op0 = pdl.OperationOp(
259+
name="myint.constant", attributes={"value": cst}, types=[t]
260+
)
261+
262+
@pdl.rewrite()
263+
def rew():
264+
expanded = pdl.apply_native_rewrite(
265+
[pdl.OperationType.get()], "expand", [cst]
266+
)
267+
pdl.ReplaceOp(op0, with_op=expanded)
268+
269+
def is_one(rewriter, results, values):
270+
cst = values[0].value
271+
return cst <= 1
272+
273+
def expand(rewriter, results, values):
274+
cst = values[0].value
275+
c1 = cst // 2
276+
c2 = cst - c1
277+
with rewriter.ip:
278+
op1 = Operation.create(
279+
"myint.constant",
280+
results=[i32],
281+
attributes={"value": IntegerAttr.get(i32, c1)},
282+
)
283+
op2 = Operation.create(
284+
"myint.constant",
285+
results=[i32],
286+
attributes={"value": IntegerAttr.get(i32, c2)},
287+
)
288+
res = Operation.create(
289+
"myint.add", results=[i32], operands=[op1.result, op2.result]
290+
)
291+
results.append(res)
292+
293+
pdl_module = PDLModule(m)
294+
pdl_module.register_constraint_function("is_one", is_one)
295+
pdl_module.register_rewrite_function("expand", expand)
296+
return pdl_module.freeze()
297+
298+
299+
# CHECK-LABEL: TEST: test_pdl_register_function_expand
300+
# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
301+
# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
302+
# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
303+
# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
304+
# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
305+
# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
306+
# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
307+
# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
308+
# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
309+
# CHECK: return %8 : i32
310+
@construct_and_print_in_module
311+
def test_pdl_register_function_expand(module_):
312+
load_myint_dialect()
313+
314+
module_ = Module.parse(
315+
"""
316+
func.func @f() -> i32 {
317+
%0 = "myint.constant"() { value = 5 }: () -> (i32)
318+
return %0 : i32
319+
}
320+
"""
321+
)
322+
323+
frozen = get_pdl_pattern_expand()
324+
apply_patterns_and_fold_greedily(module_, frozen)
325+
326+
return module_

0 commit comments

Comments
 (0)