[mlir][Transforms] Add dead code elimination pass #106258
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
Changes
In the absence of a dedicated DCE pass, MLIR users sometimes resort to the canonicalizer pass to remove dead IR. The canonicalizer pass is quite expensive to run. This PR adds a lightweight dead code elimination pass that removes dead operations and dead blocks. The pass performs 3 walks over the input IR: one to erase dead operations, one to collect the reachable blocks, and one to erase the blocks that were not reached.
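To make the first walk concrete, here is a minimal standalone sketch of a reverse walk over a single block. It is a toy model, not the MLIR implementation: `ToyOp`, `eraseDeadOps`, and `simpleTest` are hypothetical names, and use counts are tracked by hand instead of through MLIR's use lists. It assumes every definition precedes its users, which is what lets one reverse pass remove a whole chain of dead ops.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for an operation (hypothetical; not mlir::Operation).
struct ToyOp {
  std::string name;
  std::vector<int> operands; // indices of the ops whose results are used
  bool hasSideEffects;       // e.g. a return or a print
};

// One reverse walk over a block in which definitions precede their users.
// Returns the names of the surviving ops in their original order.
std::vector<std::string> eraseDeadOps(const std::vector<ToyOp> &ops) {
  std::vector<int> useCount(ops.size(), 0);
  for (const ToyOp &op : ops)
    for (int operand : op.operands) {
      assert(operand >= 0 && operand < static_cast<int>(ops.size()));
      ++useCount[operand];
    }

  std::vector<bool> erased(ops.size(), false);
  for (int i = static_cast<int>(ops.size()) - 1; i >= 0; --i) {
    if (ops[i].hasSideEffects || useCount[i] > 0)
      continue;
    // Dead: erase it and release its operands. Their definitions are
    // visited later in this same walk and may now be dead as well.
    erased[i] = true;
    for (int operand : ops[i].operands)
      --useCount[operand];
  }

  std::vector<std::string> live;
  for (std::size_t i = 0; i < ops.size(); ++i)
    if (!erased[i])
      live.push_back(ops[i].name);
  return live;
}

// Modeled on @simple_test from the test file in this PR. The block
// argument %arg0 is not an op, so it is not tracked as an operand.
std::vector<ToyOp> simpleTest() {
  return {
      {"%0 = constant", {}, false},
      {"%1 = addi %0, %arg0", {0}, false},
      {"%2 = addi %1, %1", {1, 1}, false},
      {"%3 = addi %2, %1", {2, 1}, false},
      {"return %1", {1}, true},
  };
}
```

Calling `eraseDeadOps(simpleTest())` keeps the constant, the first `addi`, and the `return`, which corresponds to the CHECK lines of `@simple_test`.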
Full diff: https://github.com/llvm/llvm-project/pull/106258.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 8e4a43c3f24586..d03b879f405af7 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,6 +33,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_CANONICALIZER
#define GEN_PASS_DECL_CONTROLFLOWSINK
#define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_DEADCODEELIMINATION
#define GEN_PASS_DECL_INLINER
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
#define GEN_PASS_DECL_MEM2REG
@@ -111,6 +112,9 @@ std::unique_ptr<Pass>
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
std::function<void(OpPassManager &)> defaultPipelineBuilder);
+/// Creates an optimization pass to remove dead operations and blocks.
+std::unique_ptr<Pass> createDeadCodeEliminationPass();
+
/// Creates an optimization pass to remove dead values.
std::unique_ptr<Pass> createRemoveDeadValuesPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..b10b1be74d0fcb 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -93,6 +93,20 @@ def CSE : Pass<"cse"> {
];
}
+def DeadCodeElimination : Pass<"dce"> {
+ let summary = "Remove dead operations and blocks";
+ let description = [{
+ This pass eliminates dead operations and blocks.
+
+ Operations are eliminated if they have no users and no side effects. Blocks
+ are eliminated if they are not reachable.
+
+ Note: Graph regions are currently not supported and skipped by this pass.
+ }];
+
+ let constructor = "mlir::createDeadCodeEliminationPass()";
+}
+
def RemoveDeadValues : Pass<"remove-dead-values"> {
let summary = "Remove dead values";
let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46a..4b90774f972ced 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTransforms
CompositePass.cpp
ControlFlowSink.cpp
CSE.cpp
+ DeadCodeElimination.cpp
GenerateRuntimeVerification.cpp
InlinerPass.cpp
LocationSnapshot.cpp
diff --git a/mlir/lib/Transforms/DeadCodeElimination.cpp b/mlir/lib/Transforms/DeadCodeElimination.cpp
new file mode 100644
index 00000000000000..33a12b84daa46f
--- /dev/null
+++ b/mlir/lib/Transforms/DeadCodeElimination.cpp
@@ -0,0 +1,75 @@
+//===- DeadCodeElimination.cpp - Dead Code Elimination --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_DEADCODEELIMINATION
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct DeadCodeElimination
+ : public impl::DeadCodeEliminationBase<DeadCodeElimination> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void DeadCodeElimination::runOnOperation() {
+ Operation *topLevel = getOperation();
+
+ // Visit operations in reverse dominance order. This visits all users before
+ // their definitions. (Also takes into account unstructured control flow
+ // between blocks.)
+ topLevel->walk<WalkOrder::PostOrder,
+ ReverseDominanceIterator</*NoGraphRegions=*/false>>(
+ [&](Operation *op) {
+ // Do not remove the top-level op.
+ if (op == topLevel)
+ return WalkResult::advance();
+
+ // Do not remove ops from regions that may be graph regions.
+ if (mayBeGraphRegion(*op->getParentRegion()))
+ return WalkResult::advance();
+
+ // Remove dead ops.
+ if (isOpTriviallyDead(op)) {
+ op->erase();
+ return WalkResult::skip();
+ }
+
+ return WalkResult::advance();
+ });
+
+ // ReverseDominanceIterator does not visit unreachable blocks. Erase those in
+ // a second walk. First collect all reachable blocks.
+ // TODO: Extend walker API to provide a callback for both ops and blocks, so
+ // that reachable blocks can be collected in the same walk.
+ DenseSet<Block *> reachableBlocks;
+ topLevel->walk<WalkOrder::PostOrder,
+ ForwardDominanceIterator</*NoGraphRegions=*/false>>(
+ [&](Block *block) { reachableBlocks.insert(block); });
+ // Erase all blocks that were not visited. These are unreachable and thus
+ // dead.
+ topLevel->walk<WalkOrder::PostOrder>([&](Block *block) {
+ if (!reachableBlocks.contains(block)) {
+ block->dropAllDefinedValueUses();
+ block->erase();
+ }
+ });
+}
+
+std::unique_ptr<Pass> mlir::createDeadCodeEliminationPass() {
+ return std::make_unique<DeadCodeElimination>();
+}
diff --git a/mlir/test/Transforms/dead-code-elimination.mlir b/mlir/test/Transforms/dead-code-elimination.mlir
new file mode 100644
index 00000000000000..67130bb3366d94
--- /dev/null
+++ b/mlir/test/Transforms/dead-code-elimination.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt -dce -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @simple_test(
+// CHECK-SAME: %[[arg0:.*]]: i16)
+// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : i16
+// CHECK-NEXT: %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+// CHECK-NEXT: return %[[add]]
+func.func @simple_test(%arg0: i16) -> i16 {
+ %0 = arith.constant 5 : i16
+ %1 = arith.addi %0, %arg0 : i16
+ %2 = arith.addi %1, %1 : i16
+ %3 = arith.addi %2, %1 : i16
+ return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_from_region
+// CHECK-NEXT: scf.for {{.*}} {
+// CHECK-NEXT: arith.constant
+// CHECK-NEXT: "test.print"
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+func.func @eliminate_from_region(%lb: index, %ub: index, %step: index) {
+ scf.for %iv = %lb to %ub step %step {
+ %0 = arith.constant 5 : i16
+ %1 = arith.constant 10 : i16
+ "test.print"(%0) : (i16) -> ()
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_op_with_region
+// CHECK-NEXT: return
+func.func @eliminate_op_with_region(%lb: index, %ub: index, %step: index) {
+ %c0 = arith.constant 0 : i16
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%iter = %c0) -> i16 {
+ %0 = arith.constant 5 : i16
+ %added = arith.addi %iter, %0 : i16
+ scf.yield %added : i16
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @unstructured_control_flow(
+// CHECK-SAME: %[[arg0:.*]]: i16)
+// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : i16
+// CHECK-NEXT: cf.br ^[[bb2:.*]]
+// CHECK-NEXT: ^[[bb1:.*]]: // pred
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+// CHECK-NEXT: cf.br ^[[bb1]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: return %[[add]]
+func.func @unstructured_control_flow(%arg0: i16) -> i16 {
+ %0 = arith.constant 5 : i16
+ cf.br ^bb2
+^bb1:
+ %3 = arith.addi %1, %1 : i16
+ %4 = arith.addi %3, %2 : i16
+ cf.br ^bb3
+^bb2:
+ %1 = arith.addi %0, %arg0 : i16
+ %2 = arith.subi %0, %arg0 : i16
+ cf.br ^bb1
+^bb3:
+ return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_dead_block()
+// CHECK-NEXT: cf.br ^[[bb2:.*]]
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: return
+func.func @remove_dead_block() {
+ cf.br ^bb2
+^bb1:
+ %0 = arith.constant 0 : i16
+ cf.br ^bb2
+^bb2:
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @potentially_side_effecting_op()
+// CHECK-NEXT: "test.print"
+// CHECK-NEXT: return
+func.func @potentially_side_effecting_op() {
+ "test.print"() : () -> ()
+ return
+}
+
+// -----
+
+// Note: Graph regions are not supported and skipped.
+
+// CHECK-LABEL: test.graph_region {
+// CHECK-NEXT: arith.addi
+// CHECK-NEXT: arith.constant 5 : i16
+// CHECK-NEXT: "test.baz"
+// CHECK-NEXT: }
+test.graph_region {
+ %1 = arith.addi %0, %0 : i16
+ %0 = arith.constant 5 : i16
+ "test.baz"() : () -> i32
+}
+
+// -----
+
+// CHECK-LABEL: dead_blocks()
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: return
+func.func @dead_blocks() {
+ cf.br ^bb3
+^bb1:
+ "test.print"() : () -> ()
+ cf.br ^bb2
+^bb2:
+ cf.br ^bb1
+^bb3:
+ return
+}
void DeadCodeElimination::runOnOperation() {
  Operation *topLevel = getOperation();

  // Visit operations in reverse dominance order. This visits all users before
Why not call into eraseUnreachableBlocks/runRegionDCE? They should perform the same set of actions, and also handle things like dead arguments.
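For readers who want a picture of what removing unreachable blocks involves, here is a standalone toy sketch. It is not the implementation of `eraseUnreachableBlocks` or of this PR: `ToyBlock`, `keepReachableBlocks`, and `deadBlocksTest` are hypothetical names, and the CFG is a plain successor list. The point it illustrates is that a block is dead when the entry block cannot reach it, even if it still has predecessors.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for a block (hypothetical; not mlir::Block).
struct ToyBlock {
  std::string label;
  std::vector<int> successors; // indices of the branch targets
};

// Marks every block reachable from the entry block (index 0) with a
// depth-first traversal and returns the labels of the blocks that stay.
std::vector<std::string>
keepReachableBlocks(const std::vector<ToyBlock> &blocks) {
  std::vector<bool> reachable(blocks.size(), false);
  std::vector<int> stack;
  if (!blocks.empty()) {
    reachable[0] = true;
    stack.push_back(0);
  }
  while (!stack.empty()) {
    int current = stack.back();
    stack.pop_back();
    for (int succ : blocks[current].successors) {
      assert(succ >= 0 && succ < static_cast<int>(blocks.size()));
      if (!reachable[succ]) {
        reachable[succ] = true;
        stack.push_back(succ);
      }
    }
  }

  std::vector<std::string> kept;
  for (std::size_t i = 0; i < blocks.size(); ++i)
    if (reachable[i])
      kept.push_back(blocks[i].label);
  return kept;
}

// Modeled on @dead_blocks from the test file in this PR: ^bb1 and ^bb2
// branch to each other but are never reached from the entry block.
std::vector<ToyBlock> deadBlocksTest() {
  return {
      {"entry", {3}},
      {"^bb1", {2}},
      {"^bb2", {1}},
      {"^bb3", {}},
  };
}
```

On this input only `entry` and `^bb3` survive. The cycle between `^bb1` and `^bb2` shows why checking for "no predecessors" alone would not be enough.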
Operations are eliminated if they have no users and no side effects. Blocks
are eliminated if they are not reachable.

Note: Graph regions are currently not supported and skipped by this pass.
Graph regions are trivial to support though I believe, can we just add them? Completeness is nice :)
I think graph regions require a worklist based approach. The purpose of the commit was to provide a very fast DCE pass, so ideally I'd not even maintain a worklist if there are only SSA dominance regions. (A simple walk would suffice then.) Let me try to replace the walk with a custom IR traversal...
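The worklist idea mentioned here can be sketched as follows. This is a standalone toy model under stated assumptions, not code from the PR: `GraphOp`, `eraseDeadOpsWithWorklist`, and `graphRegionTest` are hypothetical names, and `"test.baz"` is assumed to be potentially side-effecting. In a graph region a definition may appear after its user, so a single reverse walk can visit a definition while it still has a user and keep it. A worklist revisits the definition once its last user is gone.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for an operation in a graph region (hypothetical).
struct GraphOp {
  std::string name;
  std::vector<int> operands; // indices of defining ops, in any order
  bool hasSideEffects;
};

// Worklist-based DCE that does not rely on definitions preceding users.
// Returns the names of the surviving ops in their original order.
std::vector<std::string>
eraseDeadOpsWithWorklist(const std::vector<GraphOp> &ops) {
  const int numOps = static_cast<int>(ops.size());
  std::vector<int> useCount(ops.size(), 0);
  for (const GraphOp &op : ops)
    for (int operand : op.operands) {
      assert(operand >= 0 && operand < numOps);
      ++useCount[operand];
    }

  std::vector<bool> erased(ops.size(), false);
  std::vector<int> worklist;
  for (int i = 0; i < numOps; ++i)
    worklist.push_back(i);

  while (!worklist.empty()) {
    int i = worklist.back();
    worklist.pop_back();
    if (erased[i] || ops[i].hasSideEffects || useCount[i] > 0)
      continue;
    erased[i] = true;
    // Re-enqueue each definition that just lost its last user.
    for (int operand : ops[i].operands)
      if (--useCount[operand] == 0)
        worklist.push_back(operand);
  }

  std::vector<std::string> live;
  for (std::size_t i = 0; i < ops.size(); ++i)
    if (!erased[i])
      live.push_back(ops[i].name);
  return live;
}

// Modeled on the test.graph_region case: %1 uses %0 before %0 is defined.
std::vector<GraphOp> graphRegionTest() {
  return {
      {"%1 = addi %0, %0", {1, 1}, false},
      {"%0 = constant", {}, false},
      {"test.baz", {}, true}, // assumption: treated as side-effecting
  };
}
```

On this input the sketch removes both `%1` and `%0` and keeps only `test.baz`. The pass in this PR instead skips graph regions and leaves all three ops in place, which avoids the cost of maintaining a worklist.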
  });
}

std::unique_ptr<Pass> mlir::createDeadCodeEliminationPass() {
This should be auto-generated if you remove the constructor field from the .td file.
What is the difference between the pass you propose and the existing |
Force-pushed from f84ba23 to b6d5e8f.
I have to take a closer look at that pass, but it does not support unstructured control flow and seems quite heavyweight. I just wanted to erase dead ops, and for that you don't need any analysis and/or worklist. Just a single walk over the IR. (Removing dead blocks requires a second walk.) I have to check if it's possible to improve |
In general, I'm in favor of a lightweight pass here. I'd have expected folks to be able to use a greedy pattern rewriter without patterns as a "cheap" DCE too [but it does more than just DCE even without patterns, as is being discussed in the other PR ;-)]. Do you perhaps have a nontrivial benchmark that you could use to compare times here?
I'd personally love to have this pass for the purpose of being able to have an |