Skip to content

Conversation

matthias-springer
Copy link
Member

In the absence of a dedicated DCE pass, MLIR users sometimes resort to the canonicalizer pass to remove dead IR. The canonicalizer pass is quite expensive to run. This PR adds a lightweight dead code elimination pass that removes dead operation and dead blocks.

The pass performs 3 walks over the input IR.

  1. Visit all operations in reverse dominance order and remove dead ops.
  2. Visit all blocks in forward dominance order. This walk enumerates only reachable blocks. Collect all reachable blocks. (This walk could be combined with Step 1, but that would require an extension of the walker API.)
  3. Visit all blocks and erase the ones that were not visited in Step 2.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Aug 27, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 27, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

In the absence of a dedicated DCE pass, MLIR users sometimes resort to the canonicalizer pass to remove dead IR. The canonicalizer pass is quite expensive to run. This PR adds a lightweight dead code elimination pass that removes dead operation and dead blocks.

The pass performs 3 walks over the input IR.

  1. Visit all operations in reverse dominance order and remove dead ops.
  2. Visit all blocks in forward dominance order. This walk enumerates only reachable blocks. Collect all reachable blocks. (This walk could be combined with Step 1, but that would require an extension of the walker API.)
  3. Visit all blocks and erase the ones that were not visited in Step 2.

Full diff: https://github.com/llvm/llvm-project/pull/106258.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/Passes.h (+4)
  • (modified) mlir/include/mlir/Transforms/Passes.td (+14)
  • (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Transforms/DeadCodeElimination.cpp (+75)
  • (added) mlir/test/Transforms/dead-code-elimination.mlir (+130)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 8e4a43c3f24586..d03b879f405af7 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,6 +33,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_CANONICALIZER
 #define GEN_PASS_DECL_CONTROLFLOWSINK
 #define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_DEADCODEELIMINATION
 #define GEN_PASS_DECL_INLINER
 #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
 #define GEN_PASS_DECL_MEM2REG
@@ -111,6 +112,9 @@ std::unique_ptr<Pass>
 createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
                   std::function<void(OpPassManager &)> defaultPipelineBuilder);
 
+/// Creates an optimization pass to remove dead operations and blocks.
+std::unique_ptr<Pass> createDeadCodeEliminationPass();
+
 /// Creates an optimization pass to remove dead values.
 std::unique_ptr<Pass> createRemoveDeadValuesPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..b10b1be74d0fcb 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -93,6 +93,20 @@ def CSE : Pass<"cse"> {
   ];
 }
 
+def DeadCodeElimination : Pass<"dce"> {
+  let summary = "Remove dead operations and blocks";
+  let description = [{
+    This pass eliminates dead operations and blocks.
+
+    Operations are eliminated if they have no users and no side effects. Blocks
+    are eliminated if they are not reachable.
+
+    Note: Graph regions are currently not supported and skipped by this pass.
+  }];
+
+  let constructor = "mlir::createDeadCodeEliminationPass()";
+}
+
 def RemoveDeadValues : Pass<"remove-dead-values"> {
   let summary = "Remove dead values";
   let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46a..4b90774f972ced 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTransforms
   CompositePass.cpp
   ControlFlowSink.cpp
   CSE.cpp
+  DeadCodeElimination.cpp
   GenerateRuntimeVerification.cpp
   InlinerPass.cpp
   LocationSnapshot.cpp
diff --git a/mlir/lib/Transforms/DeadCodeElimination.cpp b/mlir/lib/Transforms/DeadCodeElimination.cpp
new file mode 100644
index 00000000000000..33a12b84daa46f
--- /dev/null
+++ b/mlir/lib/Transforms/DeadCodeElimination.cpp
@@ -0,0 +1,75 @@
+//===- DeadCodeElimination.cpp - Dead Code Elimination --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_DEADCODEELIMINATION
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct DeadCodeElimination
+    : public impl::DeadCodeEliminationBase<DeadCodeElimination> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void DeadCodeElimination::runOnOperation() {
+  Operation *topLevel = getOperation();
+
+  // Visit operations in reverse dominance order. This visits all users before
+  // their definitions. (Also takes into account unstructured control flow
+  // between blocks.)
+  topLevel->walk<WalkOrder::PostOrder,
+                 ReverseDominanceIterator</*NoGraphRegions=*/false>>(
+      [&](Operation *op) {
+        // Do not remove the top-level op.
+        if (op == topLevel)
+          return WalkResult::advance();
+
+        // Do not remove ops from regions that may be graph regions.
+        if (mayBeGraphRegion(*op->getParentRegion()))
+          return WalkResult::advance();
+
+        // Remove dead ops.
+        if (isOpTriviallyDead(op)) {
+          op->erase();
+          return WalkResult::skip();
+        }
+
+        return WalkResult::advance();
+      });
+
+  // ReverseDominanceIterator does not visit unreachable blocks. Erase those in
+  // a second walk. First collect all reachable blocks.
+  // TODO: Extend walker API to provide a callback for both ops and blocks, so
+  // that reachable blocks can be collected in the same walk.
+  DenseSet<Block *> reachableBlocks;
+  topLevel->walk<WalkOrder::PostOrder,
+                 ForwardDominanceIterator</*NoGraphRegions=*/false>>(
+      [&](Block *block) { reachableBlocks.insert(block); });
+  // Erase all blocks that were not visited. These are unreachable and thus
+  // dead.
+  topLevel->walk<WalkOrder::PostOrder>([&](Block *block) {
+    if (!reachableBlocks.contains(block)) {
+      block->dropAllDefinedValueUses();
+      block->erase();
+    }
+  });
+}
+
+std::unique_ptr<Pass> mlir::createDeadCodeEliminationPass() {
+  return std::make_unique<DeadCodeElimination>();
+}
diff --git a/mlir/test/Transforms/dead-code-elimination.mlir b/mlir/test/Transforms/dead-code-elimination.mlir
new file mode 100644
index 00000000000000..67130bb3366d94
--- /dev/null
+++ b/mlir/test/Transforms/dead-code-elimination.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt -dead-code-elimination -split-input-file %s
+
+// CHECK-LABEL: func @simple_test(
+//  CHECK-SAME:     %[[arg0:.*]]: i16)
+//  CHECK-NEXT:   %[[c5:.*]] = arith.constant 5 : i16
+//  CHECK-NEXT:   %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+//  CHECK-NEXT:   return %[[add]]
+func.func @simple_test(%arg0: i16) -> i16 {
+  %0 = arith.constant 5 : i16
+  %1 = arith.addi %0, %arg0 : i16
+  %2 = arith.addi %1, %1 : i16
+  %3 = arith.addi %2, %1 : i16
+  return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_from_region
+//  CHECK-NEXT:   scf.for {{.*}} {
+//  CHECK-NEXT:     arith.constant
+//  CHECK-NEXT:     "test.print"
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   return
+func.func @eliminate_from_region(%lb: index, %ub: index, %step: index) {
+  scf.for %iv = %lb to %ub step %step {
+    %0 = arith.constant 5 : i16
+    %1 = arith.constant 10 : i16
+    "test.print"(%0) : (i16) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @eliminate_op_with_region
+//  CHECK-NEXT:   return
+func.func @eliminate_op_with_region(%lb: index, %ub: index, %step: index) {
+  %c0 = arith.constant 0 : i16
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%iter = %c0) -> i16 {
+    %0 = arith.constant 5 : i16
+    %added = arith.addi %iter, %0 : i16
+    scf.yield %added : i16
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @unstructured_control_flow(
+//  CHECK-SAME:     %[[arg0:.*]]: i16)
+//  CHECK-NEXT:   %[[c5:.*]] = arith.constant 5 : i16
+//  CHECK-NEXT:   cf.br ^[[bb2:.*]]
+//  CHECK-NEXT: ^[[bb1:.*]]:  // pred
+//  CHECK-NEXT:   cf.br ^[[bb3:.*]]
+//  CHECK-NEXT: ^[[bb2]]:
+//  CHECK-NEXT:   %[[add:.*]] = arith.addi %[[c5]], %[[arg0]]
+//  CHECK-NEXT:   cf.br ^[[bb1]]
+//  CHECK-NEXT: ^[[bb3]]:
+//  CHECK-NEXT:   return %[[add]]
+func.func @unstructured_control_flow(%arg0: i16) -> i16 {
+  %0 = arith.constant 5 : i16
+  cf.br ^bb2
+^bb1:
+  %3 = arith.addi %1, %1 : i16
+  %4 = arith.addi %3, %2 : i16
+  cf.br ^bb3
+^bb2:
+  %1 = arith.addi %0, %arg0 : i16
+  %2 = arith.subi %0, %arg0 : i16
+  cf.br ^bb1
+^bb3:
+  return %1 : i16
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_dead_block()
+//  CHECK-NEXT:   cf.br ^[[bb2:.*]]
+//  CHECK-NEXT: ^[[bb2]]:
+//  CHECK-NEXT:   return
+func.func @remove_dead_block() {
+  cf.br ^bb2
+^bb1:
+  %0 = arith.constant 0 : i16
+  cf.br ^bb2
+^bb2:
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @potentially_side_effecting_op()
+//  CHECK-NEXT:   "test.print"
+//  CHECK-NEXT:   return
+func.func @potentially_side_effecting_op() {
+  "test.print"() : () -> ()
+  return
+}
+
+// -----
+
+// Note: Graph regions are not supported and skipped.
+
+// CHECK-LABEL: test.graph_region {
+//  CHECK-NEXT:   arith.addi
+//  CHECK-NEXT:   arith.constant 5 : i16
+//  CHECK-NEXT:   "test.baz"
+//  CHECK-NEXT: }
+test.graph_region {
+  %1 = arith.addi %0, %0 : i16
+  %0 = arith.constant 5 : i16
+  "test.baz"() : () -> i32
+}
+
+// -----
+
+// CHECK-LABEL: dead_blocks()
+//  CHECK-NEXT:   cf.br ^[[bb3:.*]]
+//  CHECK-NEXT: ^[[bb3]]:
+//  CHECK-NEXT:   return
+func.func @dead_blocks() {
+  cf.br ^bb3
+^bb1:
+  "test.print"() : () -> ()
+  cf.br ^bb2
+^bb2:
+  cf.br ^bb1
+^bb3:
+  return
+}

void DeadCodeElimination::runOnOperation() {
Operation *topLevel = getOperation();

// Visit operations in reverse dominance order. This visits all users before
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not call into eraseUnreachableBlocks/runRegionDCE? They should perform the same set of actions, and also handle things like dead arguments.

Operations are eliminated if they have no users and no side effects. Blocks
are eliminated if they are not reachable.

Note: Graph regions are currently not supported and skipped by this pass.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Graph regions are trivial to support though I believe, can we just add them? Completeness is nice :)

Copy link
Member Author

@matthias-springer matthias-springer Aug 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think graph regions require a worklist based approach. The purpose of the commit was to provide a very fast DCE pass, so ideally I'd not even maintain a worklist if there are only SSA dominance regions. (A simple walk would suffice then.) Let me try to replace the walk with a custom IR traversal...

});
}

std::unique_ptr<Pass> mlir::createDeadCodeEliminationPass() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be auto-generated if you remove the constructor field from the .td file.

@AzizZayed
Copy link

What is the difference between the pass you propose and the existing --remove-dead-values pass here?

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dce_pass branch from f84ba23 to b6d5e8f Compare August 28, 2024 18:35
@matthias-springer
Copy link
Member Author

matthias-springer commented Aug 28, 2024

What is the difference between the pass you propose and the existing --remove-dead-values pass here?

I have to take a closer look at that pass, but it does not support unstructured control flow and seems quite heavyweight. I just wanted to erase dead ops, and for that you don't need any analysis and/or worklist. Just a single walk over the IR. (Removing dead blocks requires a second walk.)

I have to check if it's possible to improve RemoveDeadValues.

@jpienaar
Copy link
Member

In general pro a lightweight pass here. I'd have expected folks to be able to use a greedy pattern rewriter without patterns as a "cheap" DCE too [but it does more than just DCE even without patterns as is being discussed in the other PR ;-)]. Do you perhaps have a nontrivial benchmark that you could compare times here?

@zero9178
Copy link
Member

I'd personally love to have this pass for the purpose of being able to have an O0 pipeline that still satisfies the preconditions of dialect conversion, which requires no blocks with no predecessors to be in the IR. Alternatively one could reconsider lifting this precondition, but I'd consider simple DCE like this to a good fit for an O0 pipeline either way.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants