Skip to content

Commit

Permalink
[mlir] Fix worklist bug in MultiOpPatternRewriteDriver
Browse files Browse the repository at this point in the history
When `strict = true`, only pre-existing and newly-created ops are rewritten and/or folded. Such ops are stored in `strictModeFilteredOps`.

Newly-created ops were previously added to `strictModeFilteredOps` after calling `addToWorklist` (via `GreedyPatternRewriteDriver::notifyOperationInserted`). Therefore, newly-created ops were never added to the worklist.

Also fix a test case that should have gone into an infinite loop (`test.replace_with_new_op` was replaced with itself, which should have caused the op to be rewritten over and over), but did not due to this bug.

Differential Revision: https://reviews.llvm.org/D141141
  • Loading branch information
matthias-springer committed Jan 10, 2023
1 parent 089a544 commit 0e47355
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
2 changes: 1 addition & 1 deletion mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {

private:
void notifyOperationInserted(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationInserted(op);
if (strictMode)
strictModeFilteredOps.insert(op);
GreedyPatternRewriteDriver::notifyOperationInserted(op);
}

void notifyOperationRemoved(Operation *op) override {
Expand Down
26 changes: 21 additions & 5 deletions mlir/test/Transforms/test-strict-pattern-driver.mlir
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s

// CHECK-LABEL: @test_erase
// CHECK-LABEL: func @test_erase
// CHECK: test.arg0
// CHECK: test.arg1
// CHECK-NOT: test.erase_op
func.func @test_erase() {
%0 = "test.arg0"() : () -> (i32)
%1 = "test.arg1"() : () -> (i32)
%erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32)
return
}

// CHECK-LABEL: @test_insert_same_op
// CHECK-LABEL: func @test_insert_same_op
// CHECK: "test.insert_same_op"() {skip = true}
// CHECK: "test.insert_same_op"() {skip = true}
func.func @test_insert_same_op() {
%0 = "test.insert_same_op"() : () -> (i32)
return
}

// CHECK-LABEL: @test_replace_with_same_op
func.func @test_replace_with_same_op() {
%0 = "test.replace_with_same_op"() : () -> (i32)
// CHECK-LABEL: func @test_replace_with_new_op
// CHECK: %[[n:.*]] = "test.new_op"
// CHECK: "test.dummy_user"(%[[n]])
// CHECK: "test.dummy_user"(%[[n]])
func.func @test_replace_with_new_op() {
%0 = "test.replace_with_new_op"() : () -> (i32)
%1 = "test.dummy_user"(%0) : (i32) -> (i32)
%2 = "test.dummy_user"(%0) : (i32) -> (i32)
return
}

// CHECK-LABEL: func @test_replace_with_erase_op
// CHECK-NOT: test.replace_with_new_op
// CHECK-NOT: test.erase_op
func.func @test_replace_with_erase_op() {
"test.replace_with_new_op"() {create_erase_op} : () -> ()
return
}
25 changes: 17 additions & 8 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,12 @@ struct TestStrictPatternDriver

void runOnOperation() override {
mlir::RewritePatternSet patterns(&getContext());
patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(&getContext());
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.insert_same_op" ||
opName == "test.replace_with_same_op" || opName == "test.erase_op") {
opName == "test.replace_with_new_op" || opName == "test.erase_op") {
ops.push_back(op);
}
});
Expand Down Expand Up @@ -260,16 +260,25 @@ struct TestStrictPatternDriver
};

// Replace an operation may introduce the re-visiting of its users.
class ReplaceWithSameOp : public RewritePattern {
class ReplaceWithNewOp : public RewritePattern {
public:
ReplaceWithSameOp(MLIRContext *context)
: RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
ReplaceWithNewOp(MLIRContext *context)
: RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Operation *newOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
op->getOperands(), op->getResultTypes());
Operation *newOp;
if (op->hasAttr("create_erase_op")) {
newOp = rewriter.create(
op->getLoc(),
OperationName("test.erase_op", op->getContext()).getIdentifier(),
ValueRange(), TypeRange());
} else {
newOp = rewriter.create(
op->getLoc(),
OperationName("test.new_op", op->getContext()).getIdentifier(),
op->getOperands(), op->getResultTypes());
}
rewriter.replaceOp(op, newOp->getResults());
return success();
}
Expand Down

0 comments on commit 0e47355

Please sign in to comment.