diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 28b4a01cf0ecd..55addfdb693e4 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -248,6 +248,7 @@ def RemoveDeadValues : Pass<"remove-dead-values"> { ``` }]; let constructor = "mlir::createRemoveDeadValuesPass()"; + let dependentDialects = ["ub::UBDialect"]; } def PrintIRPass : Pass<"print-ir"> { diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 54b67f5c7a91e..06161293e907f 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -39,4 +39,5 @@ add_mlir_library(MLIRTransforms MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils + MLIRUBDialect ) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 989c614ef6617..e9ced064c9884 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -33,6 +33,7 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/LivenessAnalysis.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" @@ -260,6 +261,22 @@ static SmallVector operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet &nonLiveSet, RDVFinalCleanupList &cl) { + // Operations that have dead operands can be erased regardless of their + // side effects. The liveness analysis would not have marked an SSA value as + // "dead" if it had a side-effecting user that is reachable. + bool hasDeadOperand = + markLives(op->getOperands(), nonLiveSet, la).flip().any(); + if (hasDeadOperand) { + LDBG() << "Simple op has dead operands, so the op must be dead: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + assert(!hasLive(op->getResults(), nonLiveSet, la) && + "expected the op to have no live results"); + cl.operations.push_back(op); + collectNonLiveValues(nonLiveSet, op->getResults(), + BitVector(op->getNumResults(), true)); + return; + } + if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { LDBG() << "Simple op is not memory effect free or has live results, " "preserving it: " @@ -361,6 +378,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // block other than the entry block, because every block has a terminator. for (Block &block : funcOp.getBlocks()) { Operation *returnOp = block.getTerminator(); + if (!returnOp->hasTrait()) + continue; if (returnOp && returnOp->getNumOperands() == numReturns) cl.operands.push_back({returnOp, nonLiveRets}); } @@ -700,7 +719,11 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, } /// Steps to process a `BranchOpInterface` operation: -/// Iterate through each successor block of `branchOp`. +/// +/// When a non-forwarded operand is dead (e.g., the condition value of a +/// conditional branch op), the entire operation is dead. +/// +/// Otherwise, iterate through each successor block of `branchOp`. /// (1) For each successor block, gather all operands from all successors. /// (2) Fetch their associated liveness analysis data and collect for future /// removal. @@ -711,7 +734,22 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, DenseSet &nonLiveSet, RDVFinalCleanupList &cl) { LDBG() << "Processing branch op: " << *branchOp; + + // Check for dead non-forwarded operands. + BitVector deadNonForwardedOperands = + markLives(branchOp->getOperands(), nonLiveSet, la).flip(); unsigned numSuccessors = branchOp->getNumSuccessors(); + for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { + SuccessorOperands successorOperands = + branchOp.getSuccessorOperands(succIdx); + // Remove all non-forwarded operands from the bit vector. + for (OpOperand &opOperand : successorOperands.getMutableForwardedOperands()) + deadNonForwardedOperands[opOperand.getOperandNumber()] = false; + } + if (deadNonForwardedOperands.any()) { + cl.operations.push_back(branchOp.getOperation()); + return; + } for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { Block *successorBlock = branchOp->getSuccessor(succIdx); @@ -786,9 +824,14 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { // 3. Operations LDBG() << "Cleaning up " << list.operations.size() << " operations"; - for (auto &op : list.operations) { + for (Operation *op : list.operations) { LDBG() << "Erasing operation: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); + if (op->hasTrait()) { + // When erasing a terminator, insert an unreachable op in its place. + OpBuilder b(op); + ub::UnreachableOp::create(b, op->getLoc()); + } op->dropAllUses(); op->erase(); } diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 4bae85dcf4f7d..71306676d48e9 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -118,6 +118,17 @@ func.func @main(%arg0 : i32) { // ----- +// CHECK-LABEL: func.func private @clean_func_op_remove_side_effecting_op() { +// CHECK-NEXT: return +// CHECK-NEXT: } +func.func private @clean_func_op_remove_side_effecting_op(%arg0: i32) -> (i32) { + // vector.print has a side effect but the op is dead. + vector.print %arg0 : i32 + return %arg0 : i32 +} + +// ----- + // %arg0 is not live because it is never used. %arg1 is not live because its // user `arith.addi` doesn't have any uses and the value that it is forwarded to // (%non_live_0) also doesn't have any uses. @@ -687,3 +698,19 @@ func.func @op_block_have_dead_arg(%arg0: index, %arg1: index, %arg2: i1) { // CHECK-NEXT: return return } + +// ----- + +// CHECK-LABEL: func private @remove_dead_branch_op() +// CHECK-NEXT: ub.unreachable +// CHECK-NEXT: ^{{.*}}: +// CHECK-NEXT: return +// CHECK-NEXT: ^{{.*}}: +// CHECK-NEXT: return +func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + return %arg0 : i64 +^bb2: + return %arg1 : i64 +}