diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h index 893f66ae33deb..5766d262796d6 100644 --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -39,6 +39,20 @@ struct ForwardIterator { } }; +/// This iterator enumerates the elements in "backward" order. +struct BackwardIterator { + template + static auto makeIterable(T &range) { + if constexpr (std::is_same()) { + /// Make operations iterable: return the list of regions. + return llvm::reverse(range.getRegions()); + } else { + /// Regions and block are already iterable. + return llvm::reverse(range); + } + } +}; + /// A utility class to encode the current walk stage for "generic" walkers. /// When walking an operation, we can either choose a Pre/Post order walker /// which invokes the callback on an operation before/after all its attached diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 0e84b6dd17f29..d3e088dce8bf0 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -116,8 +116,15 @@ struct RDVFinalCleanupList { /// Return true iff at least one value in `values` is live, given the liveness /// information in `la`. static bool hasLive(ValueRange values, const DenseSet &nonLiveSet, + const DenseSet &liveSet, + RunLivenessAnalysis &la) { for (Value value : values) { + if (liveSet.contains(value)) { + LDBG() << "Value " << value << " is marked live by CallOp"; + return true; + } + if (nonLiveSet.contains(value)) { LDBG() << "Value " << value << " is already marked non-live (dead)"; continue; @@ -257,8 +264,9 @@ static SmallVector operandsToOpOperands(OperandRange operands) { /// - Return-like static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet &nonLiveSet, - RDVFinalCleanupList &cl) { - if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { + DenseSet &liveSet, RDVFinalCleanupList &cl) { + if (!isMemoryEffectFree(op) || + hasLive(op->getResults(), nonLiveSet, liveSet, la)) { LDBG() << "Simple op is not memory effect free or has live results, " "preserving it: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); @@ -376,6 +384,31 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, } } +static void processCallOp(CallOpInterface callOp, Operation *module, + RunLivenessAnalysis &la, DenseSet &liveSet) { + auto callable = callOp.getCallableForCallee(); + + if (auto symbolRef = callable.dyn_cast()) { + Operation *calleeOp = SymbolTable::lookupSymbolIn(module, symbolRef); + + if (auto funcOp = + llvm::dyn_cast_or_null(calleeOp)) { + // Ensure the outgoing arguments of PUBLIC functions are live + // because processFuncOp can not process them. + // + // Liveness treats the external function as a blackbox. + if (funcOp.isPublic()) { + for (Value arg : callOp.getArgOperands()) { + const Liveness *liveness = la.getLiveness(arg); + if (liveness && !liveness->isLive) { + liveSet.insert(arg); + } + } + } + } + } +} + /// Process a region branch operation `regionBranchOp` using the liveness /// information in `la`. The processing involves two scenarios: /// @@ -408,6 +441,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, RunLivenessAnalysis &la, DenseSet &nonLiveSet, + DenseSet &liveSet, RDVFinalCleanupList &cl) { LDBG() << "Processing region branch op: " << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions()); @@ -616,7 +650,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // attributed to something else. // Do (1') and (2'). if (isMemoryEffectFree(regionBranchOp.getOperation()) && - !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) { + !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) { cl.operations.push_back(regionBranchOp.getOperation()); return; } @@ -834,16 +868,19 @@ void RemoveDeadValues::runOnOperation() { // Tracks values eligible for erasure - complements liveness analysis to // identify "droppable" values. DenseSet deadVals; + // mark outgoing arguments to a public function LIVE. + DenseSet liveVals; // Maintains a list of Ops, values, branches, etc., slated for cleanup at the // end of this pass. RDVFinalCleanupList finalCleanupList; - module->walk([&](Operation *op) { + module->walk([&](Operation *op) { if (auto funcOp = dyn_cast(op)) { processFuncOp(funcOp, module, la, deadVals, finalCleanupList); } else if (auto regionBranchOp = dyn_cast(op)) { - processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList); + processRegionBranchOp(regionBranchOp, la, deadVals, liveVals, + finalCleanupList); } else if (auto branchOp = dyn_cast(op)) { processBranchOp(branchOp, la, deadVals, finalCleanupList); } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) { @@ -852,8 +889,13 @@ void RemoveDeadValues::runOnOperation() { } else if (isa(op)) { // Nothing to do because this op is associated with a function op and gets // cleaned when the latter is cleaned. + // + // The only exception is public callee. By default, Liveness analysis is + // inter-procedural. Unused arguments of a public function nonLive and are + // propagated to the caller. processCallOp puts them to liveVals. + processCallOp(cast(op), module, la, liveVals); } else { - processSimpleOp(op, la, deadVals, finalCleanupList); + processSimpleOp(op, la, deadVals, liveVals, finalCleanupList); } }); diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index fa2c145bd3701..1580009c74d4d 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -569,6 +569,24 @@ module @return_void_with_unused_argument { call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> () return %unused : memref<4xi32> } + + // the function is immutable because it is public. + func.func public @immutable_fn_return_void_with_unused_argument(%arg0: i32, %unused: i32) -> () { + %sum = arith.addi %arg0, %arg0 : i32 + %c0 = arith.constant 0 : index + %buf = memref.alloc() : memref<1xi32> + memref.store %sum, %buf[%c0] : memref<1xi32> + return + } + // CHECK-LABEL: func.func @main2 + // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32) + // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32 + // CHECK: call @immutable_fn_return_void_with_unused_argument(%[[ARG0_MAIN]], %[[UNUSED]]) : (i32, i32) -> () + func.func @main2(%arg0: i32) -> () { + %zero = arith.constant 0 : i32 + call @immutable_fn_return_void_with_unused_argument(%arg0, %zero) : (i32, i32) -> () + return + } } // -----