diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..be7e027b95f64 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,6 +102,9 @@ struct RunLivenessAnalysis {
 
   const Liveness *getLiveness(Value val);
 
+  /// Return the configuration of the solver used for this analysis.
+  const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+
 private:
   /// Stores the result of the liveness analysis that was run.
   DataFlowSolver solver;
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 893f66ae33deb..632db0a7b5cd4 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,6 +39,23 @@ struct ForwardIterator {
   }
 };
 
+/// This iterator enumerates the elements in "backward" order: whatever range
+/// is handed to `makeIterable` (the regions of an operation, or a region or
+/// block itself) is returned reversed.
+struct BackwardIterator {
+  template <typename T>
+  static auto makeIterable(T &range) {
+    if constexpr (std::is_same<T, Operation>()) {
+      // Make operations iterable: return the list of regions, reversed so
+      // that regions are enumerated in the same direction as blocks and ops.
+      return llvm::reverse(range.getRegions());
+    } else {
+      // Regions and blocks are already iterable.
+      return llvm::reverse(range);
+    }
+  }
+};
+
 /// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker /// which invokes the callback on an operation before/after all its attached diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index e0c65b0e09774..ba2c56e5d9661 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -33,6 +33,7 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/LivenessAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" @@ -118,8 +119,13 @@ struct RDVFinalCleanupList { /// Return true iff at least one value in `values` is live, given the liveness /// information in `la`. static bool hasLive(ValueRange values, const DenseSet &nonLiveSet, - RunLivenessAnalysis &la) { + const DenseSet &liveSet, RunLivenessAnalysis &la) { for (Value value : values) { + if (liveSet.contains(value)) { + LDBG() << "Value " << value << " is marked live by CallOp"; + return true; + } + if (nonLiveSet.contains(value)) { LDBG() << "Value " << value << " is already marked non-live (dead)"; continue; @@ -144,6 +150,7 @@ static bool hasLive(ValueRange values, const DenseSet &nonLiveSet, /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the /// i-th value in `values` is live, given the liveness information in `la`. static BitVector markLives(ValueRange values, const DenseSet &nonLiveSet, + const DenseSet &liveSet, RunLivenessAnalysis &la) { BitVector lives(values.size(), true); @@ -154,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet &nonLiveSet, << " is already marked non-live (dead) at index " << index; continue; } - + if (liveSet.contains(value)) { + continue; + } const Liveness *liveness = la.getLiveness(value); // It is important to note that when `liveness` is null, we can't tell if // `value` is live or not. 
So, the safe option is to consider it live. Also, @@ -259,8 +268,9 @@ static SmallVector operandsToOpOperands(OperandRange operands) { /// - Return-like static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet &nonLiveSet, - RDVFinalCleanupList &cl) { - if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { + DenseSet &liveSet, RDVFinalCleanupList &cl) { + if (!isMemoryEffectFree(op) || + hasLive(op->getResults(), nonLiveSet, liveSet, la)) { LDBG() << "Simple op is not memory effect free or has live results, " "preserving it: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); @@ -288,7 +298,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, /// (6) Marking all its results as non-live values. static void processFuncOp(FunctionOpInterface funcOp, Operation *module, RunLivenessAnalysis &la, DenseSet &nonLiveSet, - RDVFinalCleanupList &cl) { + DenseSet &liveSet, RDVFinalCleanupList &cl) { LDBG() << "Processing function op: " << OpWithFlags(funcOp, OpPrintingFlags().skipRegions()); if (funcOp.isPublic() || funcOp.isExternal()) { @@ -299,7 +309,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`. SmallVector arguments(funcOp.getArguments()); - BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la); + BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la); nonLiveArgs = nonLiveArgs.flip(); // Do (1). 
@@ -352,7 +362,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
+    BitVector liveCallRets =
+        markLives(callOp->getResults(), nonLiveSet, liveSet, la);
     nonLiveRets &= liveCallRets.flip();
   }
 
@@ -379,6 +390,130 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   }
 }
 
+/// Create a cheap placeholder value with the same type as `oldVal`, inserted
+/// right before `callOp`.
+///
+/// The placeholder is an `arith.constant` holding the zero attribute of the
+/// type; a null Value is returned when the builder has no zero attribute for
+/// that type.
+// NOTE(review): no caller of this helper is visible in this patch - confirm
+// it is used, otherwise it trips -Wunused-function.
+static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
+  OpBuilder builder(callOp.getOperation());
+  Type type = oldVal.getType();
+
+  // Create a zero constant for any supported type.
+  if (TypedAttr zeroAttr = builder.getZeroAttr(type))
+    return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+  return {};
+}
+
+/// Mark `val` and every value it is computed from as live in `liveSet`.
+///
+/// The definition chain is followed backward from `val`: the operands of each
+/// defining op are marked, and when a value is a result of a
+/// RegionBranchOpInterface op, the terminator operands that are forwarded to
+/// that result are marked as well.
+// TODO: Handle BranchOpInterface too.
+static void propagateBackward(Value val, DenseSet<Value> &liveSet) {
+  // An explicit worklist (instead of recursion) keeps long definition chains
+  // from exhausting the stack.
+  SmallVector<Value> worklist = {val};
+  while (!worklist.empty()) {
+    Value current = worklist.pop_back_val();
+    // A value that is already in the set has been handled before.
+    if (!liveSet.insert(current).second)
+      continue;
+
+    Operation *defOp = current.getDefiningOp();
+    if (!defOp)
+      continue;
+
+    // Mark operands of live results as live. This also covers the operands
+    // that a region branch op forwards into its regions.
+    llvm::append_range(worklist, defOp->getOperands());
+
+    auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(defOp);
+    if (!regionBranchOp)
+      continue;
+
+    // `current` is a result of a region branch op: trace back through every
+    // region that can return to the parent op and mark the terminator operand
+    // that is forwarded to `current`.
+    for (Region &region : regionBranchOp->getRegions()) {
+      if (region.empty())
+        continue;
+
+      SmallVector<RegionSuccessor> successors;
+      regionBranchOp.getSuccessorRegions(RegionBranchPoint(&region),
+                                         successors);
+      for (const RegionSuccessor &successor : successors) {
+        if (!successor.isParent())
+          continue;
+
+        ValueRange successorInputs = successor.getSuccessorInputs();
+        // Any block of the region may end in a terminator that returns to the
+        // parent, not only the last one.
+        for (Block &block : region) {
+          if (!block.mightHaveTerminator())
+            continue;
+          auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
+              block.getTerminator());
+          if (!terminator)
+            continue;
+
+          // Pair each successor input with the operand forwarded to it, so
+          // the mapping stays right when the inputs are not all the results.
+          OperandRange forwarded =
+              terminator.getSuccessorOperands(RegionBranchPoint::parent());
+          for (auto [input, operand] : llvm::zip(successorInputs, forwarded))
+            if (input == current)
+              worklist.push_back(operand);
+        }
+      }
+    }
+  }
+}
+
+/// Keep the operands that a call passes to a public function alive.
+///
+/// Only acts when the liveness solver is configured as interprocedural and the
+/// callee resolves to a public function. For every callee argument that
+/// `markLives` reports as non-live, the matching call operand and its
+/// definition chain are added to `liveSet`.
+static void processCallOp(CallOpInterface callOp, Operation *module,
+                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+                          DenseSet<Value> &liveSet) {
+  if (!la.getSolverConfig().isInterprocedural())
+    return;
+
+  // resolveCallable() returns null when the callee cannot be resolved, so the
+  // cast has to tolerate a null operation.
+  Operation *callableOp = callOp.resolveCallable();
+  auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callableOp);
+  if (!funcOp || !funcOp.isPublic())
+    return;
+
+  LDBG() << "processCallOp to a public function: " << funcOp.getName();
+  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+  SmallVector<Value> arguments(funcOp.getArguments());
+  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la).flip();
+  if (nonLiveArgs.none())
+    return;
+
+  LDBG() << funcOp.getName() << " contains NonLive arguments";
+  // The number of operands in the call op may not match the number of
+  // arguments in the func op, so stop at the last operand that exists.
+  OperandRange callOperands = callOp.getArgOperands();
+  for (unsigned index : nonLiveArgs.set_bits()) {
+    if (index >= callOperands.size())
+      break;
+    Value operand = callOperands[index];
+    LDBG() << "mark operand " << index << " live " << operand;
+    propagateBackward(operand, liveSet);
+  }
+}
+
 /// Process a region branch operation `regionBranchOp` using the liveness
 /// information in `la`.
The processing involves two scenarios: /// @@ -411,12 +543,14 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, RunLivenessAnalysis &la, DenseSet &nonLiveSet, + DenseSet &liveSet, RDVFinalCleanupList &cl) { LDBG() << "Processing region branch op: " << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions()); // Mark live results of `regionBranchOp` in `liveResults`. auto markLiveResults = [&](BitVector &liveResults) { - liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); + liveResults = + markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la); }; // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`. @@ -425,7 +559,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, if (region.empty()) continue; SmallVector arguments(region.front().getArguments()); - BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la); + BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la); liveArgs[®ion] = regionLiveArgs; } }; @@ -619,7 +753,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // attributed to something else. // Do (1') and (2'). 
if (isMemoryEffectFree(regionBranchOp.getOperation()) && - !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) { + !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) { cl.operations.push_back(regionBranchOp.getOperation()); return; } @@ -698,7 +832,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, DenseSet &nonLiveSet, - RDVFinalCleanupList &cl) { + DenseSet &liveSet, RDVFinalCleanupList &cl) { LDBG() << "Processing branch op: " << *branchOp; unsigned numSuccessors = branchOp->getNumSuccessors(); @@ -716,7 +850,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, // Do (2) BitVector successorNonLive = - markLives(operandValues, nonLiveSet, la).flip(); + markLives(operandValues, nonLiveSet, liveSet, la).flip(); collectNonLiveValues(nonLiveSet, successorBlock->getArguments(), successorNonLive); @@ -876,26 +1010,29 @@ void RemoveDeadValues::runOnOperation() { // Tracks values eligible for erasure - complements liveness analysis to // identify "droppable" values. DenseSet deadVals; + // mark outgoing arguments to a public function LIVE. We also propagate + // liveness backward. + DenseSet liveVals; // Maintains a list of Ops, values, branches, etc., slated for cleanup at the // end of this pass. 
RDVFinalCleanupList finalCleanupList; - module->walk([&](Operation *op) { + module->walk([&](Operation *op) { if (auto funcOp = dyn_cast(op)) { - processFuncOp(funcOp, module, la, deadVals, finalCleanupList); + processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList); } else if (auto regionBranchOp = dyn_cast(op)) { - processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList); + processRegionBranchOp(regionBranchOp, la, deadVals, liveVals, + finalCleanupList); } else if (auto branchOp = dyn_cast(op)) { - processBranchOp(branchOp, la, deadVals, finalCleanupList); + processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList); } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) { // Nothing to do here because this is a terminator op and it should be // honored with respect to its parent } else if (isa(op)) { - // Nothing to do because this op is associated with a function op and gets - // cleaned when the latter is cleaned. + processCallOp(cast(op), module, la, deadVals, liveVals); } else { - processSimpleOp(op, la, deadVals, finalCleanupList); + processSimpleOp(op, la, deadVals, liveVals, finalCleanupList); } }); diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 56449469dc29f..a60efa45fe943 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -569,6 +569,43 @@ module @return_void_with_unused_argument { call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> () return %unused : memref<4xi32> } + + // the function signature is immutable because it is public. 
+ func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () { + return + } + + // CHECK-LABEL: func.func @main2 + // CHECK: %[[ONE:.*]] = arith.constant 1 : i32 + // CHECK: %[[UNUSED:.*]] = arith.addi %[[ONE]], %[[ONE]] : i32 + // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32> + // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> () + func.func @main2() -> () { + %one = arith.constant 1 : i32 + %scalar = arith.addi %one, %one: i32 + %mem = memref.alloc() : memref<4xf32> + + call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> () + return + } + + // CHECK-LABEL: func.func @main3 + // CHECK: %[[UNUSED:.*]] = scf.if %arg0 -> (i32) + // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32> + // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> () + func.func @main3(%arg0: i1) { + %0 = scf.if %arg0 -> (i32) { + %c1_i32 = arith.constant 1 : i32 + scf.yield %c1_i32 : i32 + } else { + %c0_i32 = arith.constant 0 : i32 + scf.yield %c0_i32 : i32 + } + %mem = memref.alloc() : memref<4xf32> + + call @immutable_fn_with_unused_argument(%0, %mem) : (i32, memref<4xf32>) -> () + return + } } // -----