Skip to content

Commit

Permalink
[mlir] Add RewriterBase::replaceAllUsesWith for Blocks.
Browse files Browse the repository at this point in the history
When changing IR in a RewriterPattern, all changes must go through the
rewriter. There are several convenience functions in RewriterBase that
help with high-level modifications, such as replaceAllUsesWith for
Values, but there is currently none to do the same task for Blocks.

Reviewed By: mehdi_amini, ingomueller-net

Differential Revision: https://reviews.llvm.org/D142525
  • Loading branch information
ingomueller-net committed Feb 15, 2023
1 parent 72429a4 commit 4bba8bd
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 12 deletions.
11 changes: 10 additions & 1 deletion mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,16 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
/// Find uses of `from` and replace them with `to`. It also marks every
/// modified uses and notifies the rewriter that an in-place operation
/// modification is about to happen.
void replaceAllUsesWith(Value from, Value to);
void replaceAllUsesWith(Value from, Value to) {
return replaceAllUsesWith(from.getImpl(), to);
}
template <typename OperandType, typename ValueT>
void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
Operation *op = operand.getOwner();
updateRootInPlace(op, [&]() { operand.set(to); });
}
}

/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. It also marks every modified uses and notifies the rewriter that an
Expand Down
8 changes: 0 additions & 8 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,14 +309,6 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
source->erase();
}

/// Find uses of `from` and replace it with `to`
void RewriterBase::replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
updateRootInPlace(op, [&]() { operand.set(to); });
}
}

/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. It also marks every modified uses and notifies the rewriter that an
/// in-place operation modification is about to happen.
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Transforms/test-strict-pattern-driver.mlir
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
// RUN: mlir-opt \
// RUN: -test-strict-pattern-driver="strictness=AnyOp" \
// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-AN

// RUN: mlir-opt \
// RUN: -test-strict-pattern-driver="strictness=ExistingAndNewOps" \
// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EN
Expand Down Expand Up @@ -58,3 +62,24 @@ func.func @test_replace_with_erase_op() {
"test.replace_with_new_op"() {create_erase_op} : () -> ()
return
}

// -----

// CHECK-AN-LABEL: func @test_trigger_rewrite_through_block
// CHECK-AN: "test.change_block_op"()[^[[BB0:.*]], ^[[BB0]]]
// CHECK-AN: return
// CHECK-AN: ^[[BB1:[^:]*]]:
// CHECK-AN: "test.implicit_change_op"()[^[[BB1]]]
func.func @test_trigger_rewrite_through_block() {
return
^bb1:
// Uses bb1. ChangeBlockOp replaces that and all other usages of bb1 with bb2.
"test.change_block_op"() [^bb1, ^bb2] : () -> ()
^bb2:
return
^bb3:
// Also uses bb1. ChangeBlockOp replaces that usage with bb2. This triggers
// this op being put on the worklist, which triggers ImplicitChangeOp, which,
// in turn, replaces the successor with bb3.
"test.implicit_change_op"() [^bb1] : () -> ()
}
63 changes: 60 additions & 3 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,19 @@ struct TestStrictPatternDriver
void runOnOperation() override {
MLIRContext *ctx = &getContext();
mlir::RewritePatternSet patterns(ctx);
patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(ctx);
patterns.add<
// clang-format off
InsertSameOp,
ReplaceWithNewOp,
EraseOp,
ChangeBlockOp,
ImplicitChangeOp
// clang-format on
>(ctx);
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.insert_same_op" ||
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
opName == "test.replace_with_new_op" || opName == "test.erase_op") {
ops.push_back(op);
}
Expand Down Expand Up @@ -342,7 +350,7 @@ struct TestStrictPatternDriver
}
};

// Remove an operation may introduce the re-visiting of its opreands.
// Remove an operation may introduce the re-visiting of its operands.
class EraseOp : public RewritePattern {
public:
EraseOp(MLIRContext *context)
Expand All @@ -353,6 +361,55 @@ struct TestStrictPatternDriver
return success();
}
};

// The following two patterns test RewriterBase::replaceAllUsesWith.
//
// That function replaces all usages of a Block (or a Value) with another one
// *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
// with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
// worklist: when an op is modified, it is added to the worklist. The two
// patterns below make the tracking observable: ChangeBlockOp replaces all
// usages of a block and that pattern is applied because the corresponding ops
// are put on the initial worklist (see above). ImplicitChangeOp does an
// unrelated change but ops of the corresponding type are *not* on the initial
// worklist, so the effect of the second pattern is only visible if the
// tracking and subsequent adding to the worklist actually works.

// Replace all usages of the first successor with the second successor.
class ChangeBlockOp : public RewritePattern {
public:
ChangeBlockOp(MLIRContext *context)
: RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumSuccessors() < 2)
return failure();
Block *firstSuccessor = op->getSuccessor(0);
Block *secondSuccessor = op->getSuccessor(1);
if (firstSuccessor == secondSuccessor)
return failure();
// This is the function being tested:
rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
// Using the following line instead would make the test fail:
// firstSuccessor->replaceAllUsesWith(secondSuccessor);
return success();
}
};

// Changes the successor to the parent block.
class ImplicitChangeOp : public RewritePattern {
public:
ImplicitChangeOp(MLIRContext *context)
: RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
return failure();
rewriter.updateRootInPlace(
op, [&]() { op->setSuccessor(op->getBlock(), 0); });
return success();
}
};
};

} // namespace
Expand Down

0 comments on commit 4bba8bd

Please sign in to comment.