diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index 9f0bd37451820..5e0034fc01885 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -62,6 +62,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopInfo.h" @@ -141,18 +142,21 @@ class SelectInstToUnfold { explicit operator bool() const { return SI && SIUse; } }; -void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, +void unfold(DomTreeUpdater *DTU, LoopInfo *LI, BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI, SelectInstToUnfold SIToUnfold, std::vector *NewSIsToUnfold, std::vector *NewBBs); class DFAJumpThreading { public: DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI, + BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI, TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE) - : AC(AC), DT(DT), LI(LI), TTI(TTI), ORE(ORE) {} + : AC(AC), DT(DT), LI(LI), BFI(BFI), BPI(BPI), TTI(TTI), ORE(ORE) {} bool run(Function &F); bool LoopInfoBroken; + bool BFIBPIBroken; private: void @@ -167,7 +171,7 @@ class DFAJumpThreading { std::vector NewSIsToUnfold; std::vector NewBBs; - unfold(&DTU, LI, SIToUnfold, &NewSIsToUnfold, &NewBBs); + unfold(&DTU, LI, BFI, BPI, SIToUnfold, &NewSIsToUnfold, &NewBBs); // Put newly discovered select instructions into the work list. llvm::append_range(Stack, NewSIsToUnfold); @@ -177,6 +181,8 @@ class DFAJumpThreading { AssumptionCache *AC; DominatorTree *DT; LoopInfo *LI; + BlockFrequencyInfo *BFI; + BranchProbabilityInfo *BPI; TargetTransformInfo *TTI; OptimizationRemarkEmitter *ORE; }; @@ -192,7 +198,8 @@ namespace { /// created basic blocks into \p NewBBs. /// /// TODO: merge it with CodeGenPrepare::optimizeSelectInst() if possible. -void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, +void unfold(DomTreeUpdater *DTU, LoopInfo *LI, BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI, SelectInstToUnfold SIToUnfold, std::vector *NewSIsToUnfold, std::vector *NewBBs) { SelectInst *SI = SIToUnfold.getInst(); @@ -200,9 +207,23 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, assert(SI->hasOneUse()); // The select may come indirectly, instead of from where it is defined. BasicBlock *StartBlock = SIUse->getIncomingBlock(*SI->use_begin()); - BranchInst *StartBlockTerm = - dyn_cast(StartBlock->getTerminator()); - assert(StartBlockTerm); + BranchInst *StartBlockTerm = cast(StartBlock->getTerminator()); + + uint64_t TrueWeight = 1; + uint64_t FalseWeight = 1; + // Copy probabilities from 'SI' to the created conditional branch. + SmallVector SIProbs; + if (extractBranchWeights(*SI, TrueWeight, FalseWeight) && + (TrueWeight + FalseWeight) != 0) { + SIProbs.emplace_back(BranchProbability::getBranchProbability( + TrueWeight, TrueWeight + FalseWeight)); + SIProbs.emplace_back(BranchProbability::getBranchProbability( + FalseWeight, TrueWeight + FalseWeight)); + } + if ((TrueWeight + FalseWeight) == 0) + TrueWeight = FalseWeight = 1; + auto FalseProb = BranchProbability::getBranchProbability( + FalseWeight, TrueWeight + FalseWeight); if (StartBlockTerm->isUnconditional()) { BasicBlock *EndBlock = StartBlock->getUniqueSuccessor(); @@ -263,6 +284,13 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, BranchInst::Create(EndBlock, NewBlock, SI->getCondition(), StartBlock); DTU->applyUpdates({{DominatorTree::Insert, StartBlock, EndBlock}, {DominatorTree::Insert, StartBlock, NewBlock}}); + + // Update BPI if exists. + if (BPI && !SIProbs.empty()) + BPI->setEdgeProbability(StartBlock, SIProbs); + // Update the block frequency of NewBlock. + if (BFI) + BFI->setBlockFreq(NewBlock, BFI->getBlockFreq(StartBlock) * FalseProb); } else { BasicBlock *EndBlock = SIUse->getParent(); BasicBlock *NewBlockT = BasicBlock::Create( @@ -336,6 +364,17 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, StartBlockTerm->setSuccessor(SuccNum, NewBlockT); DTU->applyUpdates({{DominatorTree::Delete, StartBlock, EndBlock}, {DominatorTree::Insert, StartBlock, NewBlockT}}); + // Update BPI if exists. + if (BPI && !SIProbs.empty()) + BPI->setEdgeProbability(NewBlockT, SIProbs); + // Update the block frequency of both NewBlockT and NewBlockF. + if (BFI) { + assert(BPI && "BPI should be valid if BFI exists"); + auto NewBlockTFreq = BFI->getBlockFreq(StartBlock) * + BPI->getEdgeProbability(StartBlock, SuccNum); + BFI->setBlockFreq(NewBlockT, NewBlockTFreq); + BFI->setBlockFreq(NewBlockF, NewBlockTFreq * FalseProb); + } } // Preserve loop info @@ -994,6 +1033,7 @@ struct TransformDFA { SmallPtrSet BlocksToClean; BlocksToClean.insert_range(successors(SwitchBlock)); + // TODO: Preserve BFI/BPI during creating exit paths. { DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy); for (const ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) { @@ -1378,7 +1418,7 @@ bool DFAJumpThreading::run(Function &F) { SmallVector ThreadableLoops; bool MadeChanges = false; - LoopInfoBroken = false; + LoopInfoBroken = BFIBPIBroken = false; for (BasicBlock &BB : F) { auto *SI = dyn_cast(BB.getTerminator()); @@ -1431,7 +1471,7 @@ bool DFAJumpThreading::run(Function &F) { for (AllSwitchPaths SwitchPaths : ThreadableLoops) { TransformDFA Transform(&SwitchPaths, DT, AC, TTI, ORE, EphValues); if (Transform.run()) - MadeChanges = LoopInfoBroken = true; + MadeChanges = LoopInfoBroken = BFIBPIBroken = true; } #ifdef EXPENSIVE_CHECKS @@ -1450,9 +1490,11 @@ PreservedAnalyses DFAJumpThreadingPass::run(Function &F, AssumptionCache &AC = AM.getResult(F); DominatorTree &DT = AM.getResult(F); LoopInfo &LI = AM.getResult(F); + BlockFrequencyInfo *BFI = AM.getCachedResult(F); + BranchProbabilityInfo *BPI = AM.getCachedResult(F); TargetTransformInfo &TTI = AM.getResult(F); OptimizationRemarkEmitter ORE(&F); - DFAJumpThreading ThreadImpl(&AC, &DT, &LI, &TTI, &ORE); + DFAJumpThreading ThreadImpl(&AC, &DT, &LI, BFI, BPI, &TTI, &ORE); if (!ThreadImpl.run(F)) return PreservedAnalyses::all(); @@ -1460,5 +1502,9 @@ PreservedAnalyses DFAJumpThreadingPass::run(Function &F, PA.preserve(); if (!ThreadImpl.LoopInfoBroken) PA.preserve(); + if (!ThreadImpl.BFIBPIBroken) { + PA.preserve(); + PA.preserve(); + } return PA; }