diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h index c7c405e1423cb..3b2914cdd4c98 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -17,6 +17,7 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/SmallPtrSet.h" #include @@ -200,6 +201,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis { /// which are live from the current block. void visitBranchOperation(BranchOpInterface branch); + /// Visit region branch edges from `predecessorOp` to a list of successors. + /// For each edge, mark the successor program point as executable, and record + /// the predecessor information in its `PredecessorState`. + void visitRegionBranchEdges(RegionBranchOpInterface regionBranchOp, + Operation *predecessorOp, + const SmallVector &successors); + /// Visit the given region branch operation, which defines regions, and /// compute any necessary lattice state. This also resolves the lattice state /// of both the operation results and any nested regions. diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 131c49c44171b..de1ed39ed4fdb 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -444,30 +444,21 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { /// Get the constant values of the operands of an operation. If any of the /// constant value lattices are uninitialized, return std::nullopt to indicate /// the analysis should bail out. -static std::optional> getOperandValuesImpl( - Operation *op, - function_ref *(Value)> getLattice) { +std::optional> +DeadCodeAnalysis::getOperandValues(Operation *op) { SmallVector operands; operands.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { - const Lattice *cv = getLattice(operand); + Lattice *cv = getOrCreate>(operand); + cv->useDefSubscribe(this); // If any of the operands' values are uninitialized, bail out. if (cv->getValue().isUninitialized()) - return {}; + return std::nullopt; operands.push_back(cv->getValue().getConstantValue()); } return operands; } -std::optional> -DeadCodeAnalysis::getOperandValues(Operation *op) { - return getOperandValuesImpl(op, [&](Value value) { - auto *lattice = getOrCreate>(value); - lattice->useDefSubscribe(this); - return lattice; - }); -} - void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { LDBG() << "visitBranchOperation: " << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); @@ -498,23 +489,8 @@ void DeadCodeAnalysis::visitRegionBranchOperation( SmallVector successors; branch.getEntrySuccessorRegions(*operands, successors); - for (const RegionSuccessor &successor : successors) { - // The successor can be either an entry block or the parent operation. - ProgramPoint *point = - successor.getSuccessor() - ? getProgramPointBefore(&successor.getSuccessor()->front()) - : getProgramPointAfter(branch); - // Mark the entry block as executable. - auto *state = getOrCreate(point); - propagateIfChanged(state, state->setToLive()); - LDBG() << "Marked region successor live: " << point; - // Add the parent op as a predecessor. - auto *predecessors = getOrCreate(point); - propagateIfChanged( - predecessors, - predecessors->join(branch, successor.getSuccessorInputs())); - LDBG() << "Added region branch as predecessor for successor: " << point; - } + + visitRegionBranchEdges(branch, branch.getOperation(), successors); } void DeadCodeAnalysis::visitRegionTerminator(Operation *op, @@ -530,26 +506,30 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, else branch.getSuccessorRegions(op->getParentRegion(), successors); - // Mark successor region entry blocks as executable and add this op to the - // list of predecessors. + visitRegionBranchEdges(branch, op, successors); +} + +void DeadCodeAnalysis::visitRegionBranchEdges( + RegionBranchOpInterface regionBranchOp, Operation *predecessorOp, + const SmallVector &successors) { for (const RegionSuccessor &successor : successors) { - PredecessorState *predecessors; - if (Region *region = successor.getSuccessor()) { - auto *state = - getOrCreate(getProgramPointBefore(®ion->front())); - propagateIfChanged(state, state->setToLive()); - LDBG() << "Marked region entry block live for region: " << region; - predecessors = getOrCreate( - getProgramPointBefore(®ion->front())); - } else { - // Add this terminator as a predecessor to the parent op. - predecessors = - getOrCreate(getProgramPointAfter(branch)); - } - propagateIfChanged(predecessors, - predecessors->join(op, successor.getSuccessorInputs())); - LDBG() << "Added region terminator as predecessor for successor: " - << (successor.getSuccessor() ? "region entry" : "parent op"); + // The successor can be either an entry block or the parent operation. + ProgramPoint *point = + successor.getSuccessor() + ? getProgramPointBefore(&successor.getSuccessor()->front()) + : getProgramPointAfter(regionBranchOp); + + // Mark the entry block as executable. + auto *state = getOrCreate(point); + propagateIfChanged(state, state->setToLive()); + LDBG() << "Marked region successor live: " << point; + + // Add the parent op as a predecessor. + auto *predecessors = getOrCreate(point); + propagateIfChanged( + predecessors, + predecessors->join(predecessorOp, successor.getSuccessorInputs())); + LDBG() << "Added region branch as predecessor for successor: " << point; } }