Skip to content
14 changes: 14 additions & 0 deletions mlir/include/mlir-c/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
MLIR_CAPI_EXPORTED MlirBlock
mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);

/// Returns the operation right after the current insertion point
/// of the rewriter. A null MlirOperation will be returned
// if the current insertion point is at the end of the block.
MLIR_CAPI_EXPORTED MlirOperation
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);

//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -310,6 +316,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);

//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//

/// Cast the PatternRewriter to a RewriterBase
MLIR_CAPI_EXPORTED MlirRewriterBase
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2046,6 +2046,9 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
: refOperation(beforeOperationBase.getOperation().getRef()),
block((*refOperation)->getBlock()) {}

PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
: refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}

void PyInsertionPoint::insert(PyOperationBase &operationBase) {
PyOperation &operation = operationBase.getOperation();
if (operation.isAttached())
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,8 @@ class PyInsertionPoint {
PyInsertionPoint(const PyBlock &block);
/// Creates an insertion point positioned before a reference operation.
PyInsertionPoint(PyOperationBase &beforeOperationBase);
/// Creates an insertion point positioned before a reference operation.
PyInsertionPoint(PyOperationRef beforeOperationRef);

/// Shortcut to create an insertion point at the beginning of the block.
static PyInsertionPoint atBlockBegin(PyBlock &block);
Expand Down
34 changes: 31 additions & 3 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,30 @@ using namespace mlir::python;

namespace {

class PyPatternRewriter {
public:
PyPatternRewriter(MlirPatternRewriter rewriter)
: base(mlirPatternRewriterAsBase(rewriter)),
ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}

PyInsertionPoint getInsertionPoint() const {
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);

if (mlirOperationIsNull(op)) {
MlirOperation owner = mlirBlockGetParentOperation(block);
auto parent = PyOperation::forOperation(ctx, owner);
return PyInsertionPoint(PyBlock(parent, block));
}

return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}

private:
MlirRewriterBase base;
PyMlirContextRef ctx;
};

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
Expand Down Expand Up @@ -84,7 +108,8 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
f(rewriter, results, objectsFromPDLValues(nValues, values)));
f(PyPatternRewriter(rewriter), results,
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
Expand All @@ -98,7 +123,8 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
f(rewriter, results, objectsFromPDLValues(nValues, values)));
f(PyPatternRewriter(rewriter), results,
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
Expand Down Expand Up @@ -143,7 +169,9 @@ class PyFrozenRewritePatternSet {

/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
"The current insertion point of the PatternRewriter.");
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
return wrap(unwrap(rewriter)->getBlock());
}

MlirOperation
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
mlir::RewriterBase *base = unwrap(rewriter);
mlir::Block *block = base->getInsertionBlock();
mlir::Block::iterator it = base->getInsertionPoint();
if (it == block->end())
return {nullptr};

return wrap(std::addressof(*it));
}

//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -316,6 +327,10 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
return {rewriter};
}

MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down
91 changes: 89 additions & 2 deletions mlir/test/python/integration/dialects/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
print(module)
return f


def get_pdl_patterns():
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
Expand Down Expand Up @@ -121,8 +122,10 @@ def load_myint_dialect():


# This PDL pattern is to fold constant additions,
# i.e. add(constant0, constant1) -> constant2
# where constant2 = constant0 + constant1.
# including two patterns:
# 1. add(constant0, constant1) -> constant2
# where constant2 = constant0 + constant1;
# 2. add(x, 0) or add(0, x) -> x.
def get_pdl_pattern_fold():
m = Module.create()
i32 = IntegerType.get_signless(32)
Expand Down Expand Up @@ -237,3 +240,87 @@ def test_pdl_register_function_constraint(module_):
apply_patterns_and_fold_greedily(module_, frozen)

return module_


# This pattern is to expand constant to additions
# unless the constant is no more than 1,
# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
def get_pdl_pattern_expand():
m = Module.create()
i32 = IntegerType.get_signless(32)
with InsertionPoint(m.body):

@pdl.pattern(benefit=1, sym_name="myint_constant_expand")
def pat():
t = pdl.TypeOp(i32)
cst = pdl.AttributeOp()
pdl.apply_native_constraint([], "is_one", [cst])
op0 = pdl.OperationOp(
name="myint.constant", attributes={"value": cst}, types=[t]
)

@pdl.rewrite()
def rew():
expanded = pdl.apply_native_rewrite(
[pdl.OperationType.get()], "expand", [cst]
)
pdl.ReplaceOp(op0, with_op=expanded)

def is_one(rewriter, results, values):
cst = values[0].value
return cst <= 1

def expand(rewriter, results, values):
cst = values[0].value
c1 = cst // 2
c2 = cst - c1
with rewriter.ip:
op1 = Operation.create(
"myint.constant",
results=[i32],
attributes={"value": IntegerAttr.get(i32, c1)},
)
op2 = Operation.create(
"myint.constant",
results=[i32],
attributes={"value": IntegerAttr.get(i32, c2)},
)
res = Operation.create(
"myint.add", results=[i32], operands=[op1.result, op2.result]
)
results.append(res)

pdl_module = PDLModule(m)
pdl_module.register_constraint_function("is_one", is_one)
pdl_module.register_rewrite_function("expand", expand)
return pdl_module.freeze()


# CHECK-LABEL: TEST: test_pdl_register_function_expand
# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
# CHECK: return %8 : i32
@construct_and_print_in_module
def test_pdl_register_function_expand(module_):
load_myint_dialect()

module_ = Module.parse(
"""
func.func @f() -> i32 {
%0 = "myint.constant"() { value = 5 }: () -> (i32)
return %0 : i32
}
"""
)

frozen = get_pdl_pattern_expand()
apply_patterns_and_fold_greedily(module_, frozen)

return module_