From 0e4735546e6bbcfd5d11d0a6b8b68cb9ccad9b41 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 10 Jan 2023 15:30:49 +0100 Subject: [PATCH] [mlir] Fix worklist bug in MultiOpPatternRewriteDriver When `strict = true`, only pre-existing and newly-created ops are rewritten and/or folded. Such ops are stored in `strictModeFilteredOps`. Newly-created ops were previously added to `strictModeFilteredOps` after calling `addToWorklist` (via `GreedyPatternRewriteDriver::notifyOperationInserted`). Therefore, newly-created ops were never added to the worklist. Also fix a test case that should have gone into an infinite loop (`test.replace_with_new_op` was replaced with itself, which should have caused the op to be rewritten over and over), but did not due to this bug. Differential Revision: https://reviews.llvm.org/D141141 --- .../Utils/GreedyPatternRewriteDriver.cpp | 2 +- .../test-strict-pattern-driver.mlir | 26 +++++++++++++++---- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 25 ++++++++++++------ 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 5005a08bc29bb8..cdb0b78c7a74ec 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -558,9 +558,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { private: void notifyOperationInserted(Operation *op) override { - GreedyPatternRewriteDriver::notifyOperationInserted(op); if (strictMode) strictModeFilteredOps.insert(op); + GreedyPatternRewriteDriver::notifyOperationInserted(op); } void notifyOperationRemoved(Operation *op) override { diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir index 51d296935a97b0..8c6eaf345d92de 100644 --- a/mlir/test/Transforms/test-strict-pattern-driver.mlir +++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir @@ -1,6 +1,9 @@ // RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s -// CHECK-LABEL: @test_erase +// CHECK-LABEL: func @test_erase +// CHECK: test.arg0 +// CHECK: test.arg1 +// CHECK-NOT: test.erase_op func.func @test_erase() { %0 = "test.arg0"() : () -> (i32) %1 = "test.arg1"() : () -> (i32) @@ -8,16 +11,29 @@ func.func @test_erase() { return } -// CHECK-LABEL: @test_insert_same_op +// CHECK-LABEL: func @test_insert_same_op +// CHECK: "test.insert_same_op"() {skip = true} +// CHECK: "test.insert_same_op"() {skip = true} func.func @test_insert_same_op() { %0 = "test.insert_same_op"() : () -> (i32) return } -// CHECK-LABEL: @test_replace_with_same_op -func.func @test_replace_with_same_op() { - %0 = "test.replace_with_same_op"() : () -> (i32) +// CHECK-LABEL: func @test_replace_with_new_op +// CHECK: %[[n:.*]] = "test.new_op" +// CHECK: "test.dummy_user"(%[[n]]) +// CHECK: "test.dummy_user"(%[[n]]) +func.func @test_replace_with_new_op() { + %0 = "test.replace_with_new_op"() : () -> (i32) %1 = "test.dummy_user"(%0) : (i32) -> (i32) %2 = "test.dummy_user"(%0) : (i32) -> (i32) return } + +// CHECK-LABEL: func @test_replace_with_erase_op +// CHECK-NOT: test.replace_with_new_op +// CHECK-NOT: test.erase_op +func.func @test_replace_with_erase_op() { + "test.replace_with_new_op"() {create_erase_op} : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 9b74e808506f19..2573f76deb691f 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -220,12 +220,12 @@ struct TestStrictPatternDriver void runOnOperation() override { mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); SmallVector ops; getOperation()->walk([&](Operation *op) { StringRef opName = op->getName().getStringRef(); if (opName == "test.insert_same_op" || - opName == "test.replace_with_same_op" || opName == "test.erase_op") { + opName == "test.replace_with_new_op" || opName == "test.erase_op") { ops.push_back(op); } }); @@ -260,16 +260,25 @@ struct TestStrictPatternDriver }; // Replace an operation may introduce the re-visiting of its users. - class ReplaceWithSameOp : public RewritePattern { + class ReplaceWithNewOp : public RewritePattern { public: - ReplaceWithSameOp(MLIRContext *context) - : RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {} + ReplaceWithNewOp(MLIRContext *context) + : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - Operation *newOp = - rewriter.create(op->getLoc(), op->getName().getIdentifier(), - op->getOperands(), op->getResultTypes()); + Operation *newOp; + if (op->hasAttr("create_erase_op")) { + newOp = rewriter.create( + op->getLoc(), + OperationName("test.erase_op", op->getContext()).getIdentifier(), + ValueRange(), TypeRange()); + } else { + newOp = rewriter.create( + op->getLoc(), + OperationName("test.new_op", op->getContext()).getIdentifier(), + op->getOperands(), op->getResultTypes()); + } rewriter.replaceOp(op, newOp->getResults()); return success(); }