29 changes: 18 additions & 11 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
Expand Down Expand Up @@ -1051,6 +1050,15 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
return LHS->getValue().ult(RHS->getValue()) ? 1 : -1;
}

static inline bool HasBranchWeights(const Instruction *I) {
MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof);
if (ProfMD && ProfMD->getOperand(0))
if (MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0)))
return MDS->getString().equals("branch_weights");

return false;
}

/// Get Weights of a given terminator, the default weight is at the front
/// of the vector. If TI is a conditional eq, we need to swap the branch-weight
/// metadata.
Expand Down Expand Up @@ -1169,8 +1177,8 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(

// Update the branch weight metadata along the way
SmallVector<uint64_t, 8> Weights;
bool PredHasWeights = hasBranchWeightMD(*PTI);
bool SuccHasWeights = hasBranchWeightMD(*TI);
bool PredHasWeights = HasBranchWeights(PTI);
bool SuccHasWeights = HasBranchWeights(TI);

if (PredHasWeights) {
GetBranchWeights(PTI, Weights);
Expand Down Expand Up @@ -2744,8 +2752,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
// the `then` block, then avoid speculating it.
if (!BI->getMetadata(LLVMContext::MD_unpredictable)) {
uint64_t TWeight, FWeight;
if (extractBranchWeights(*BI, TWeight, FWeight) &&
(TWeight + FWeight) != 0) {
if (BI->extractProfMetadata(TWeight, FWeight) && (TWeight + FWeight) != 0) {
uint64_t EndWeight = Invert ? TWeight : FWeight;
BranchProbability BIEndProb =
BranchProbability::getBranchProbability(EndWeight, TWeight + FWeight);
Expand Down Expand Up @@ -3167,7 +3174,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
// from the block that we know is predictably not entered.
if (!DomBI->getMetadata(LLVMContext::MD_unpredictable)) {
uint64_t TWeight, FWeight;
if (extractBranchWeights(*DomBI, TWeight, FWeight) &&
if (DomBI->extractProfMetadata(TWeight, FWeight) &&
(TWeight + FWeight) != 0) {
BranchProbability BITrueProb =
BranchProbability::getBranchProbability(TWeight, TWeight + FWeight);
Expand Down Expand Up @@ -3347,9 +3354,9 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI,
uint64_t &SuccTrueWeight,
uint64_t &SuccFalseWeight) {
bool PredHasWeights =
extractBranchWeights(*PBI, PredTrueWeight, PredFalseWeight);
PBI->extractProfMetadata(PredTrueWeight, PredFalseWeight);
bool SuccHasWeights =
extractBranchWeights(*BI, SuccTrueWeight, SuccFalseWeight);
BI->extractProfMetadata(SuccTrueWeight, SuccFalseWeight);
if (PredHasWeights || SuccHasWeights) {
if (!PredHasWeights)
PredTrueWeight = PredFalseWeight = 1;
Expand Down Expand Up @@ -3377,7 +3384,7 @@ shouldFoldCondBranchesToCommonDestination(BranchInst *BI, BranchInst *PBI,
uint64_t PTWeight, PFWeight;
BranchProbability PBITrueProb, Likely;
if (TTI && !PBI->getMetadata(LLVMContext::MD_unpredictable) &&
extractBranchWeights(*PBI, PTWeight, PFWeight) &&
PBI->extractProfMetadata(PTWeight, PFWeight) &&
(PTWeight + PFWeight) != 0) {
PBITrueProb =
BranchProbability::getBranchProbability(PTWeight, PTWeight + PFWeight);
Expand Down Expand Up @@ -4342,7 +4349,7 @@ bool SimplifyCFGOpt::SimplifySwitchOnSelect(SwitchInst *SI,
// Get weight for TrueBB and FalseBB.
uint32_t TrueWeight = 0, FalseWeight = 0;
SmallVector<uint64_t, 8> Weights;
bool HasWeights = hasBranchWeightMD(*SI);
bool HasWeights = HasBranchWeights(SI);
if (HasWeights) {
GetBranchWeights(SI, Weights);
if (Weights.size() == 1 + SI->getNumCases()) {
Expand Down Expand Up @@ -5202,7 +5209,7 @@ bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI,
BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);

// Update weight for the newly-created conditional branch.
if (hasBranchWeightMD(*SI)) {
if (HasBranchWeights(SI)) {
SmallVector<uint64_t, 8> Weights;
GetBranchWeights(SI, Weights);
if (Weights.size() == 1 + SI->getNumCases()) {
Expand Down