diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h index 15b0bdf557fb1e..1044a634408ca3 100644 --- a/llvm/include/llvm/IR/Instruction.h +++ b/llvm/include/llvm/IR/Instruction.h @@ -335,11 +335,6 @@ class Instruction : public User, /// Sets the AA metadata on this instruction from the AAMDNodes structure. void setAAMetadata(const AAMDNodes &N); - /// Retrieve the raw weight values of a conditional branch or select. - /// Returns true on success with profile weights filled in. - /// Returns false if no metadata or invalid metadata was found. - bool extractProfMetadata(uint64_t &TrueVal, uint64_t &FalseVal) const; - /// Retrieve total raw weight values of a branch. /// Returns true on success with profile total weights filled in. /// Returns false if no metadata was found. diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index b6c53a3dc1a54f..0051c41536b55f 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -16,15 +16,15 @@ bool isBranchWeightMD(const MDNode *ProfileData); /// Checks if an instructions has Branch Weight Metadata /// /// \param I The instruction to check -/// \return True if I has an MD_prof node containing Branch Weights. False +/// \returns True if I has an MD_prof node containing Branch Weights. False /// otherwise. bool hasBranchWeightMD(const Instruction &I); /// Extract branch weights from MD_prof metadata /// /// \param ProfileData A pointer to an MDNode. -/// \param Weights An output vector to fill with branch weights -/// \return True if weights were extracted, False otherwise. When false Weights +/// \param [out] Weights An output vector to fill with branch weights +/// \returns True if weights were extracted, False otherwise. When false Weights /// will be cleared. bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights); @@ -32,24 +32,28 @@ bool extractBranchWeights(const MDNode *ProfileData, /// Extract branch weights attatched to an Instruction /// /// \param I The Instruction to extract weights from. -/// \param Weights An output vector to fill with branch weights -/// \return True if weights were extracted, False otherwise. When false Weights +/// \param [out] Weights An output vector to fill with branch weights +/// \returns True if weights were extracted, False otherwise. When false Weights /// will be cleared. bool extractBranchWeights(const Instruction &I, SmallVectorImpl &Weights); -/// Retrieve the raw weight values of a conditional branch or select. -/// Returns true on success with profile weights filled in. -/// Returns false if no metadata or invalid metadata was found. +/// Extract branch weights from a conditional branch or select Instruction. +/// +/// \param I The instruction to extract branch weights from. +/// \param [out] TrueVal will contain the branch weight for the True branch +/// \param [out] FalseVal will contain the branch weight for the False branch +/// \returns True on success with profile weights filled in. False if no +/// metadata or invalid metadata was found. bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, uint64_t &FalseVal); /// Retrieve the total of all weights from MD_prof data. /// /// \param ProfileData The profile data to extract the total weight from -/// \param TotalWeights input variable to fill with total weights -/// \return true on success with profile total weights filled in. -/// \return false if no metadata was found. +/// \param [out] TotalWeights input variable to fill with total weights +/// \returns True on success with profile total weights filled in. False if no +/// metadata was found. bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights); } // namespace llvm diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp index f45728768fcdb0..8918fb967594e4 100644 --- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp +++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" @@ -401,24 +402,18 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { SmallVector Weights; SmallVector UnreachableIdxs; SmallVector ReachableIdxs; - Weights.reserve(TI->getNumSuccessors()); - for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) { - ConstantInt *Weight = - mdconst::dyn_extract(WeightsNode->getOperand(I)); - if (!Weight) - return false; - assert(Weight->getValue().getActiveBits() <= 32 && - "Too many bits for uint32_t"); - Weights.push_back(Weight->getZExtValue()); - WeightSum += Weights.back(); + + extractBranchWeights(*TI, Weights); + for (unsigned I = 0, E = Weights.size(); I != E; ++I) { + WeightSum += Weights[I]; const LoopBlock SrcLoopBB = getLoopBlock(BB); - const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1)); + const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I)); auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB}); if (EstimatedWeight && *EstimatedWeight <= static_cast(BlockExecWeight::UNREACHABLE)) - UnreachableIdxs.push_back(I - 1); + UnreachableIdxs.push_back(I); else - ReachableIdxs.push_back(I - 1); + ReachableIdxs.push_back(I); } assert(Weights.size() == TI->getNumSuccessors() && "Checked above"); diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp index d5d08f04d4cd8c..b100fbe2b33c40 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -65,6 +65,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Statepoint.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" @@ -6620,7 +6621,7 @@ static bool isFormingBranchFromSelectProfitable(const TargetTransformInfo *TTI, // If metadata tells us that the select condition is obviously predictable, // then we want to replace the select with a branch. uint64_t TrueWeight, FalseWeight; - if (SI->extractProfMetadata(TrueWeight, FalseWeight)) { + if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) { uint64_t Max = std::max(TrueWeight, FalseWeight); uint64_t Sum = TrueWeight + FalseWeight; if (Sum != 0) { @@ -8366,7 +8367,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) { // Another choice is to assume TrueProb for BB1 equals to TrueProb for // TmpBB, but the math is more complicated. uint64_t TrueWeight, FalseWeight; - if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) { + if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) { uint64_t NewTrueWeight = TrueWeight; uint64_t NewFalseWeight = TrueWeight + 2 * FalseWeight; scaleWeights(NewTrueWeight, NewFalseWeight); @@ -8399,7 +8400,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) { // assumes that // FalseProb for BB1 == TrueProb for BB1 * FalseProb for TmpBB. uint64_t TrueWeight, FalseWeight; - if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) { + if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) { uint64_t NewTrueWeight = 2 * TrueWeight + FalseWeight; uint64_t NewFalseWeight = FalseWeight; scaleWeights(NewTrueWeight, NewFalseWeight); diff --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp index 011f55efce1d5f..a7a698744591ca 100644 --- a/llvm/lib/CodeGen/SelectOptimize.cpp +++ b/llvm/lib/CodeGen/SelectOptimize.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/ScaledNumber.h" @@ -655,7 +656,7 @@ bool SelectOptimize::hasExpensiveColdOperand( const SmallVector &ASI) { bool ColdOperand = false; uint64_t TrueWeight, FalseWeight, TotalWeight; - if (ASI.front()->extractProfMetadata(TrueWeight, FalseWeight)) { + if (extractBranchWeights(*ASI.front(), TrueWeight, FalseWeight)) { uint64_t MinWeight = std::min(TrueWeight, FalseWeight); TotalWeight = TrueWeight + FalseWeight; // Is there a path with frequency extractProfMetadata(TrueWeight, FalseWeight)) { + if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) { uint64_t Max = std::max(TrueWeight, FalseWeight); uint64_t Sum = TrueWeight + FalseWeight; if (Sum != 0) { @@ -959,7 +960,7 @@ SelectOptimize::getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost, const SelectInst *SI) { Scaled64 PredPathCost; uint64_t TrueWeight, FalseWeight; - if (SI->extractProfMetadata(TrueWeight, FalseWeight)) { + if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) { uint64_t SumWeight = TrueWeight + FalseWeight; if (SumWeight != 0) { PredPathCost = TrueCost * Scaled64::get(TrueWeight) + diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp index 2a1a514922fdc9..a2027439240fe5 100644 --- a/llvm/lib/IR/Metadata.cpp +++ b/llvm/lib/IR/Metadata.cpp @@ -40,6 +40,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/TrackingMDRef.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" @@ -1493,31 +1494,6 @@ void Instruction::getAllMetadataImpl( Value::getAllMetadata(Result); } -bool Instruction::extractProfMetadata(uint64_t &TrueVal, - uint64_t &FalseVal) const { - assert( - (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select) && - "Looking for branch weights on something besides branch or select"); - - auto *ProfileData = getMetadata(LLVMContext::MD_prof); - if (!ProfileData || ProfileData->getNumOperands() != 3) - return false; - - auto *ProfDataName = dyn_cast(ProfileData->getOperand(0)); - if (!ProfDataName || !ProfDataName->getString().equals("branch_weights")) - return false; - - auto *CITrue = mdconst::dyn_extract(ProfileData->getOperand(1)); - auto *CIFalse = mdconst::dyn_extract(ProfileData->getOperand(2)); - if (!CITrue || !CIFalse) - return false; - - TrueVal = CITrue->getValue().getZExtValue(); - FalseVal = CIFalse->getValue().getZExtValue(); - - return true; -} - bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const { assert( (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select || @@ -1526,32 +1502,7 @@ bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const { getOpcode() == Instruction::Switch) && "Looking for branch weights on something besides branch"); - TotalVal = 0; - auto *ProfileData = getMetadata(LLVMContext::MD_prof); - if (!ProfileData) - return false; - - auto *ProfDataName = dyn_cast(ProfileData->getOperand(0)); - if (!ProfDataName) - return false; - - if (ProfDataName->getString().equals("branch_weights")) { - TotalVal = 0; - for (unsigned i = 1; i < ProfileData->getNumOperands(); i++) { - auto *V = mdconst::dyn_extract(ProfileData->getOperand(i)); - if (!V) - return false; - TotalVal += V->getValue().getZExtValue(); - } - return true; - } else if (ProfDataName->getString().equals("VP") && - ProfileData->getNumOperands() > 3) { - TotalVal = mdconst::dyn_extract(ProfileData->getOperand(2)) - ->getValue() - .getZExtValue(); - return true; - } - return false; + return ::extractProfTotalWeight(getMetadata(LLVMContext::MD_prof), TotalVal); } void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) { diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp index cf728933c08d25..c8945879d0bc35 100644 --- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp +++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp @@ -15,6 +15,7 @@ #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetSchedule.h" #include "llvm/IR/IntrinsicsPowerPC.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" @@ -757,7 +758,7 @@ bool PPCTTIImpl::isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE, if (BranchInst *BI = dyn_cast(TI)) { uint64_t TrueWeight = 0, FalseWeight = 0; if (!BI->isConditional() || - !BI->extractProfMetadata(TrueWeight, FalseWeight)) + !extractBranchWeights(*BI, TrueWeight, FalseWeight)) continue; // If the exit path is more frequent than the loop path, diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp index 54c72bdbb203f9..ec2e7fb215b2f3 100644 --- a/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -40,6 +40,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/User.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -717,7 +718,7 @@ static bool hasProfileData(const Function &F, const FunctionOutliningInfo &OI) { if (!BR || BR->isUnconditional()) continue; uint64_t T, F; - if (BR->extractProfMetadata(T, F)) + if (extractBranchWeights(*BR, T, F)) return true; } return false; diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index c4512d0222cde6..90cc61c6b291ba 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -91,6 +91,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/ProfileSummary.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" @@ -2067,7 +2068,7 @@ template <> struct DOTGraphTraits : DefaultDOTGraphTraits { // Display scaled counts for SELECT instruction: OS << "SELECT : { T = "; uint64_t TC, FC; - bool HasProf = I.extractProfMetadata(TC, FC); + bool HasProf = extractBranchWeights(I, TC, FC); if (!HasProf) OS << "Unknown, F = Unknown }\\l"; else diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index b31eab50c5ecca..0113e7bf570dfc 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -54,6 +54,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" @@ -216,7 +217,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { return; uint64_t TrueWeight, FalseWeight; - if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight)) + if (!extractBranchWeights(*CondBr, TrueWeight, FalseWeight)) return; if (TrueWeight + FalseWeight == 0) @@ -279,7 +280,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { // With PGO, this can be used to refine even existing profile data with // context information. This needs to be done after more performance // testing. - if (PredBr->extractProfMetadata(PredTrueWeight, PredFalseWeight)) + if (extractBranchWeights(*PredBr, PredTrueWeight, PredFalseWeight)) continue; // We can not infer anything useful when BP >= 50%, because BP is the diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp index f093fea19c4d49..9a7f9df7a330f0 100644 --- a/llvm/lib/Transforms/Utils/LoopPeel.cpp +++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -532,7 +533,7 @@ static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR, uint64_t &ExitWeight, uint64_t &FallThroughWeight) { uint64_t TrueWeight, FalseWeight; - if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) + if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight)) return; unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1; ExitWeight = HeaderIdx ? TrueWeight : FalseWeight; diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 023a0afd329b5a..1c44ccb589bccf 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -30,6 +30,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -471,7 +472,7 @@ static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop, uint64_t TrueWeight, FalseWeight; BranchInst *LatchBR = cast(OrigLoop->getLoopLatch()->getTerminator()); - if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) + if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight)) return; uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader() ? FalseWeight diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 349063dd5e892e..03272e5f8a5499 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -38,6 +38,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -789,7 +790,7 @@ getEstimatedTripCount(BranchInst *ExitingBranch, Loop *L, // know the number of times the backedge was taken, vs. the number of times // we exited the loop. uint64_t LoopWeight, ExitWeight; - if (!ExitingBranch->extractProfMetadata(LoopWeight, ExitWeight)) + if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) return None; if (L->contains(ExitingBranch->getSuccessor(1))) diff --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp index d85e9ddb2beec8..8f94902b331fd8 100644 --- a/llvm/lib/Transforms/Utils/MisExpect.cpp +++ b/llvm/lib/Transforms/Utils/MisExpect.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -118,34 +119,6 @@ void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx, namespace llvm { namespace misexpect { -// Helper function to extract branch weights into a vector -Optional> extractWeights(Instruction *I, - LLVMContext &Ctx) { - assert(I && "MisExpect::extractWeights given invalid pointer"); - - auto *ProfileData = I->getMetadata(LLVMContext::MD_prof); - if (!ProfileData) - return None; - - unsigned NOps = ProfileData->getNumOperands(); - if (NOps < 3) - return None; - - auto *ProfDataName = dyn_cast(ProfileData->getOperand(0)); - if (!ProfDataName || !ProfDataName->getString().equals("branch_weights")) - return None; - - SmallVector Weights(NOps - 1); - for (unsigned Idx = 1; Idx < NOps; Idx++) { - ConstantInt *Value = - mdconst::dyn_extract(ProfileData->getOperand(Idx)); - uint32_t V = Value->getZExtValue(); - Weights[Idx - 1] = V; - } - - return Weights; -} - // TODO: when clang allows c++17, use std::clamp instead uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) { if (value > hi) @@ -218,19 +191,17 @@ void verifyMisExpect(Instruction &I, ArrayRef RealWeights, void checkBackendInstrumentation(Instruction &I, const ArrayRef RealWeights) { - auto ExpectedWeightsOpt = extractWeights(&I, I.getContext()); - if (!ExpectedWeightsOpt) + SmallVector ExpectedWeights; + if (!extractBranchWeights(I, ExpectedWeights)) return; - auto ExpectedWeights = ExpectedWeightsOpt.value(); verifyMisExpect(I, RealWeights, ExpectedWeights); } void checkFrontendInstrumentation(Instruction &I, const ArrayRef ExpectedWeights) { - auto RealWeightsOpt = extractWeights(&I, I.getContext()); - if (!RealWeightsOpt) + SmallVector RealWeights; + if (!extractBranchWeights(I, RealWeights)) return; - auto RealWeights = RealWeightsOpt.value(); verifyMisExpect(I, RealWeights, ExpectedWeights); } diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 1806081678a867..bba831521ff598 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -57,6 +57,7 @@ #include "llvm/IR/NoFolder.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -1050,15 +1051,6 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1, return LHS->getValue().ult(RHS->getValue()) ? 1 : -1; } -static inline bool HasBranchWeights(const Instruction *I) { - MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof); - if (ProfMD && ProfMD->getOperand(0)) - if (MDString *MDS = dyn_cast(ProfMD->getOperand(0))) - return MDS->getString().equals("branch_weights"); - - return false; -} - /// Get Weights of a given terminator, the default weight is at the front /// of the vector. If TI is a conditional eq, we need to swap the branch-weight /// metadata. @@ -1177,8 +1169,8 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding( // Update the branch weight metadata along the way SmallVector Weights; - bool PredHasWeights = HasBranchWeights(PTI); - bool SuccHasWeights = HasBranchWeights(TI); + bool PredHasWeights = hasBranchWeightMD(*PTI); + bool SuccHasWeights = hasBranchWeightMD(*TI); if (PredHasWeights) { GetBranchWeights(PTI, Weights); @@ -2752,7 +2744,8 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // the `then` block, then avoid speculating it. if (!BI->getMetadata(LLVMContext::MD_unpredictable)) { uint64_t TWeight, FWeight; - if (BI->extractProfMetadata(TWeight, FWeight) && (TWeight + FWeight) != 0) { + if (extractBranchWeights(*BI, TWeight, FWeight) && + (TWeight + FWeight) != 0) { uint64_t EndWeight = Invert ? TWeight : FWeight; BranchProbability BIEndProb = BranchProbability::getBranchProbability(EndWeight, TWeight + FWeight); @@ -3174,7 +3167,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // from the block that we know is predictably not entered. if (!DomBI->getMetadata(LLVMContext::MD_unpredictable)) { uint64_t TWeight, FWeight; - if (DomBI->extractProfMetadata(TWeight, FWeight) && + if (extractBranchWeights(*DomBI, TWeight, FWeight) && (TWeight + FWeight) != 0) { BranchProbability BITrueProb = BranchProbability::getBranchProbability(TWeight, TWeight + FWeight); @@ -3354,9 +3347,9 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI, uint64_t &SuccTrueWeight, uint64_t &SuccFalseWeight) { bool PredHasWeights = - PBI->extractProfMetadata(PredTrueWeight, PredFalseWeight); + extractBranchWeights(*PBI, PredTrueWeight, PredFalseWeight); bool SuccHasWeights = - BI->extractProfMetadata(SuccTrueWeight, SuccFalseWeight); + extractBranchWeights(*BI, SuccTrueWeight, SuccFalseWeight); if (PredHasWeights || SuccHasWeights) { if (!PredHasWeights) PredTrueWeight = PredFalseWeight = 1; @@ -3384,7 +3377,7 @@ shouldFoldCondBranchesToCommonDestination(BranchInst *BI, BranchInst *PBI, uint64_t PTWeight, PFWeight; BranchProbability PBITrueProb, Likely; if (TTI && !PBI->getMetadata(LLVMContext::MD_unpredictable) && - PBI->extractProfMetadata(PTWeight, PFWeight) && + extractBranchWeights(*PBI, PTWeight, PFWeight) && (PTWeight + PFWeight) != 0) { PBITrueProb = BranchProbability::getBranchProbability(PTWeight, PTWeight + PFWeight); @@ -4349,7 +4342,7 @@ bool SimplifyCFGOpt::SimplifySwitchOnSelect(SwitchInst *SI, // Get weight for TrueBB and FalseBB. uint32_t TrueWeight = 0, FalseWeight = 0; SmallVector Weights; - bool HasWeights = HasBranchWeights(SI); + bool HasWeights = hasBranchWeightMD(*SI); if (HasWeights) { GetBranchWeights(SI, Weights); if (Weights.size() == 1 + SI->getNumCases()) { @@ -5209,7 +5202,7 @@ bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI, BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest); // Update weight for the newly-created conditional branch. - if (HasBranchWeights(SI)) { + if (hasBranchWeightMD(*SI)) { SmallVector Weights; GetBranchWeights(SI, Weights); if (Weights.size() == 1 + SI->getNumCases()) {