Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/ProfDataUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
/// metadata was found.
bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);

/// Create a new `branch_weights` metadata node and add or overwrite
/// a `prof` metadata reference to instruction `I`.
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);

} // namespace llvm
#endif
7 changes: 7 additions & 0 deletions llvm/lib/IR/ProfDataUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}

void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
MDBuilder MDB(I.getContext());
MDNode *BranchWeights = MDB.createBranchWeights(Weights);
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}

} // namespace llvm
14 changes: 7 additions & 7 deletions llvm/lib/Transforms/IPO/SampleProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/PseudoProbe.h"
#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/ProfileData/InstrProf.h"
Expand Down Expand Up @@ -1710,20 +1711,19 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
else if (OverwriteExistingWeights)
I.setMetadata(LLVMContext::MD_prof, nullptr);
} else if (!isa<IntrinsicInst>(&I)) {
I.setMetadata(LLVMContext::MD_prof,
MDB.createBranchWeights(
{static_cast<uint32_t>(BlockWeights[BB])}));
setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
}
}
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
// Set profile metadata (possibly annotated by LTO prelink) to zero or
// clear it for cold code.
for (auto &I : *BB) {
if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
if (cast<CallBase>(I).isIndirectCall())
if (cast<CallBase>(I).isIndirectCall()) {
I.setMetadata(LLVMContext::MD_prof, nullptr);
else
I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
} else {
setBranchWeights(I, {uint32_t(0)});
}
}
}
}
Expand Down Expand Up @@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
if (MaxWeight > 0 &&
(!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
setBranchWeights(*TI, Weights);
ORE->emit([&]() {
return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
<< "most popular destination for conditional branches at "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
static_cast<uint32_t>(CHRBranchBias.scale(1000)),
static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
};
MDBuilder MDB(F.getContext());
MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
setBranchWeights(*MergedBR, Weights);
CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
<< "\n");
}
Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Value.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);

if (AttachProfToDirectCall) {
MDBuilder MDB(NewInst.getContext());
NewInst.setMetadata(
LLVMContext::MD_prof,
MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
}

using namespace ore;
Expand Down
8 changes: 3 additions & 5 deletions llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
// If A is uncovered, set weight=1.
// This setup will allow BFI to give nonzero profile counts to only covered
// blocks.
SmallVector<unsigned, 4> Weights;
SmallVector<uint32_t, 4> Weights;
for (auto *Succ : successors(&BB))
Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
if (Weights.size() >= 2)
BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
MDB.createBranchWeights(Weights));
llvm::setBranchWeights(*BB.getTerminator(), Weights);
}

unsigned NumCorruptCoverage = 0;
Expand Down Expand Up @@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {

void llvm::setProfMetadata(Module *M, Instruction *TI,
ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
MDBuilder MDB(M->getContext());
assert(MaxCount > 0 && "Bad max count");
uint64_t Scale = calculateCountScale(MaxCount);
SmallVector<unsigned, 4> Weights;
Expand All @@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,

misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);

TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
setBranchWeights(*TI, Weights);
if (EmitBranchProbability) {
std::string BrCondStr = getBranchCondString(TI);
if (BrCondStr.empty())
Expand Down
18 changes: 7 additions & 11 deletions llvm/lib/Transforms/Scalar/JumpThreading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
if (BP >= BranchProbability(50, 100))
continue;

SmallVector<uint32_t, 2> Weights;
uint32_t Weights[2];
if (PredBr->getSuccessor(0) == PredOutEdge.second) {
Weights.push_back(BP.getNumerator());
Weights.push_back(BP.getCompl().getNumerator());
Weights[0] = BP.getNumerator();
Weights[1] = BP.getCompl().getNumerator();
} else {
Weights.push_back(BP.getCompl().getNumerator());
Weights.push_back(BP.getNumerator());
Weights[0] = BP.getCompl().getNumerator();
Weights[1] = BP.getNumerator();
}
PredBr->setMetadata(LLVMContext::MD_prof,
MDBuilder(PredBr->getParent()->getContext())
.createBranchWeights(Weights));
setBranchWeights(*PredBr, Weights);
}
}

Expand Down Expand Up @@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
Weights.push_back(Prob.getNumerator());

auto TI = BB->getTerminator();
TI->setMetadata(
LLVMContext::MD_prof,
MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
setBranchWeights(*TI, Weights);
}
}

Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/MisExpect.h"

Expand Down Expand Up @@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);

SI.setCondition(ArgValue);

SI.setMetadata(LLVMContext::MD_prof,
MDBuilder(CI->getContext()).createBranchWeights(Weights));

setBranchWeights(SI, Weights);
return true;
}

Expand Down
4 changes: 1 addition & 3 deletions llvm/lib/Transforms/Utils/Local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
// Remove weight for this case.
std::swap(Weights[Idx + 1], Weights.back());
Weights.pop_back();
SI->setMetadata(LLVMContext::MD_prof,
MDBuilder(BB->getContext()).
createBranchWeights(Weights));
setBranchWeights(*SI, Weights);
}
// Remove this entry.
BasicBlock *ParentBB = SI->getParent();
Expand Down
17 changes: 4 additions & 13 deletions llvm/lib/Transforms/Utils/LoopPeel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,9 +631,7 @@ struct WeightInfo {
/// To avoid dealing with division rounding we can just multiple both part
/// of weights to E and use weight as (F - I * E, E).
static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
MDBuilder MDB(Term->getContext());
Term->setMetadata(LLVMContext::MD_prof,
MDB.createBranchWeights(Info.Weights));
setBranchWeights(*Term, Info.Weights);
for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
if (SubWeight != 0)
// Don't set the probability of taking the edge from latch to loop header
Expand Down Expand Up @@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
}
}

/// Update the weights of original exiting block after peeling off all
/// iterations.
static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
MDBuilder MDB(Term->getContext());
Term->setMetadata(LLVMContext::MD_prof,
MDB.createBranchWeights(Info.Weights));
}

/// Clones the body of the loop L, putting it between \p InsertTop and \p
/// InsertBot.
/// \param IterNumber The serial number of the iteration currently being
Expand Down Expand Up @@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
}

for (const auto &[Term, Info] : Weights)
fixupBranchWeights(Term, Info);
for (const auto &[Term, Info] : Weights) {
setBranchWeights(*Term, Info.Weights);
}

// Update Metadata for count of peeled off iterations.
unsigned AlreadyPeeled = 0;
Expand Down
19 changes: 10 additions & 9 deletions llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
LoopBackWeight = 0;
}

MDBuilder MDB(LoopBI.getContext());
MDNode *LoopWeightMD =
MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
SuccsSwapped ? ExitWeight1 : LoopBackWeight);
LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
const uint32_t LoopBIWeights[] = {
SuccsSwapped ? LoopBackWeight : ExitWeight1,
SuccsSwapped ? ExitWeight1 : LoopBackWeight,
};
setBranchWeights(LoopBI, LoopBIWeights);
if (HasConditionalPreHeader) {
MDNode *PreHeaderWeightMD =
MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
SuccsSwapped ? ExitWeight0 : EnterWeight);
PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
const uint32_t PreHeaderBIWeights[] = {
SuccsSwapped ? EnterWeight : ExitWeight0,
SuccsSwapped ? ExitWeight0 : EnterWeight,
};
setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
}
}

Expand Down
Loading