-
Notifications
You must be signed in to change notification settings - Fork 10.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add setBranchWeigths convenience function. NFC #72446
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
Add `setBranchWeights` convenience function to ProfDataUtils.h and use it where appropriate.
a32bc7c
to
5113b4f
Compare
Adding a global |
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-pgo Author: Matthias Braun (MatzeB) ChangesAdd Full diff: https://github.com/llvm/llvm-project/pull/72446.diff 11 Files Affected:
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b61199372de0de8..255fa2ff1c79065 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
/// metadata was found.
bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
+/// Create a new `branch_weights` metadata node and add or overwrite
+/// a `prof` metadata reference to instruction `I`.
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
+
} // namespace llvm
#endif
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 77b3c1cb95d686c..29536b0b090cd76 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
@@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+ MDBuilder MDB(I.getContext());
+ MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+ I.setMetadata(LLVMContext::MD_prof, BranchWeights);
+}
+
} // namespace llvm
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 063f7b42022ff83..6c6f0a0eca72a7a 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -56,6 +56,7 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/PseudoProbe.h"
#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/ProfileData/InstrProf.h"
@@ -1710,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
else if (OverwriteExistingWeights)
I.setMetadata(LLVMContext::MD_prof, nullptr);
} else if (!isa<IntrinsicInst>(&I)) {
- I.setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(
- {static_cast<uint32_t>(BlockWeights[BB])}));
+ setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
}
}
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1720,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
// clear it for cold code.
for (auto &I : *BB) {
if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
- if (cast<CallBase>(I).isIndirectCall())
+ if (cast<CallBase>(I).isIndirectCall()) {
I.setMetadata(LLVMContext::MD_prof, nullptr);
- else
- I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
+ } else {
+ setBranchWeights(I, {uint32_t(0)});
+ }
}
}
}
@@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
if (MaxWeight > 0 &&
(!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
- TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ setBranchWeights(*TI, Weights);
ORE->emit([&]() {
return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
<< "most popular destination for conditional branches at "
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 97cf583510f9395..0a3d8d6000cf47d 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
static_cast<uint32_t>(CHRBranchBias.scale(1000)),
static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
};
- MDBuilder MDB(F.getContext());
- MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ setBranchWeights(*MergedBR, Weights);
CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
<< "\n");
}
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index 5c9799235017a8a..7344fea17517191 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -26,6 +26,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Value.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/Casting.h"
@@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
if (AttachProfToDirectCall) {
- MDBuilder MDB(NewInst.getContext());
- NewInst.setMetadata(
- LLVMContext::MD_prof,
- MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
+ setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
}
using namespace ore;
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index aea0f2f7cae7786..4a5a0b25bebbaf1 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
// If A is uncovered, set weight=1.
// This setup will allow BFI to give nonzero profile counts to only covered
// blocks.
- SmallVector<unsigned, 4> Weights;
+ SmallVector<uint32_t, 4> Weights;
for (auto *Succ : successors(&BB))
Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
if (Weights.size() >= 2)
- BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Weights));
+ llvm::setBranchWeights(*BB.getTerminator(), Weights);
}
unsigned NumCorruptCoverage = 0;
@@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
void llvm::setProfMetadata(Module *M, Instruction *TI,
ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
- MDBuilder MDB(M->getContext());
assert(MaxCount > 0 && "Bad max count");
uint64_t Scale = calculateCountScale(MaxCount);
SmallVector<unsigned, 4> Weights;
@@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
- TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ setBranchWeights(*TI, Weights);
if (EmitBranchProbability) {
std::string BrCondStr = getBranchCondString(TI);
if (BrCondStr.empty())
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 7a8128c5b6c0901..2d899f100f8154d 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
if (BP >= BranchProbability(50, 100))
continue;
- SmallVector<uint32_t, 2> Weights;
+ uint32_t Weights[2];
if (PredBr->getSuccessor(0) == PredOutEdge.second) {
- Weights.push_back(BP.getNumerator());
- Weights.push_back(BP.getCompl().getNumerator());
+ Weights[0] = BP.getNumerator();
+ Weights[1] = BP.getCompl().getNumerator();
} else {
- Weights.push_back(BP.getCompl().getNumerator());
- Weights.push_back(BP.getNumerator());
+ Weights[0] = BP.getCompl().getNumerator();
+ Weights[1] = BP.getNumerator();
}
- PredBr->setMetadata(LLVMContext::MD_prof,
- MDBuilder(PredBr->getParent()->getContext())
- .createBranchWeights(Weights));
+ setBranchWeights(*PredBr, Weights);
}
}
@@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
Weights.push_back(Prob.getNumerator());
auto TI = BB->getTerminator();
- TI->setMetadata(
- LLVMContext::MD_prof,
- MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+ setBranchWeights(*TI, Weights);
}
}
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index ac87ee736c0d169..6f87e4d91d2c794 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -20,6 +20,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/MisExpect.h"
@@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
SI.setCondition(ArgValue);
-
- SI.setMetadata(LLVMContext::MD_prof,
- MDBuilder(CI->getContext()).createBranchWeights(Weights));
-
+ setBranchWeights(SI, Weights);
return true;
}
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index aacf66bfe38eb91..c8c8d02cdcbb2b3 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
// Remove weight for this case.
std::swap(Weights[Idx + 1], Weights.back());
Weights.pop_back();
- SI->setMetadata(LLVMContext::MD_prof,
- MDBuilder(BB->getContext()).
- createBranchWeights(Weights));
+ setBranchWeights(*SI, Weights);
}
// Remove this entry.
BasicBlock *ParentBB = SI->getParent();
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 2881444206b0bb5..7566f70661baf48 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -631,9 +631,7 @@ struct WeightInfo {
/// To avoid dealing with division rounding we can just multiple both part
/// of weights to E and use weight as (F - I * E, E).
static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
- MDBuilder MDB(Term->getContext());
- Term->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Info.Weights));
+ setBranchWeights(*Term, Info.Weights);
for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
if (SubWeight != 0)
// Don't set the probability of taking the edge from latch to loop header
@@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
}
}
-/// Update the weights of original exiting block after peeling off all
-/// iterations.
-static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
- MDBuilder MDB(Term->getContext());
- Term->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Info.Weights));
-}
-
/// Clones the body of the loop L, putting it between \p InsertTop and \p
/// InsertBot.
/// \param IterNumber The serial number of the iteration currently being
@@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
}
- for (const auto &[Term, Info] : Weights)
- fixupBranchWeights(Term, Info);
+ for (const auto &[Term, Info] : Weights) {
+ setBranchWeights(*Term, Info.Weights);
+ }
// Update Metadata for count of peeled off iterations.
unsigned AlreadyPeeled = 0;
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 012aa5dbb9ca004..ae155ac082d8111 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
LoopBackWeight = 0;
}
- MDBuilder MDB(LoopBI.getContext());
- MDNode *LoopWeightMD =
- MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
- SuccsSwapped ? ExitWeight1 : LoopBackWeight);
- LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
+ const uint32_t LoopBIWeights[] = {
+ SuccsSwapped ? LoopBackWeight : ExitWeight1,
+ SuccsSwapped ? ExitWeight1 : LoopBackWeight,
+ };
+ setBranchWeights(LoopBI, LoopBIWeights);
if (HasConditionalPreHeader) {
- MDNode *PreHeaderWeightMD =
- MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
- SuccsSwapped ? ExitWeight0 : EnterWeight);
- PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
+ const uint32_t PreHeaderBIWeights[] = {
+ SuccsSwapped ? EnterWeight : ExitWeight0,
+ SuccsSwapped ? ExitWeight0 : EnterWeight,
+ };
+ setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
}
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: should the commit title say [nfc]?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I was surprised that Instruction
didn't implement this already, but I think this is a good change. Thanks for the nice improvement.
Add `setBranchWeights` convenience function to ProfDataUtils.h and use it where appropriate.
Add `setBranchWeights` convenience function to ProfDataUtils.h and use it where appropriate.
Add
setBranchWeights
convenience function to ProfDataUtils.h and useit where appropriate.