Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llvm][IR] Extend BranchWeightMetadata to track provenance of weights #86609

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -221,5 +221,5 @@ void tu2(int &i) {
}
}

// CHECK: [[BW_LIKELY]] = !{!"branch_weights", i32 2000, i32 1}
// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", i32 1, i32 2000}
// CHECK: [[BW_LIKELY]] = !{!"branch_weights", !"expected", i32 2000, i32 1}
// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", !"expected", i32 1, i32 2000}
7 changes: 7 additions & 0 deletions llvm/docs/BranchWeightMetadata.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ Supported Instructions

Metadata is only assigned to the conditional branches. There are two extra
operands for the true and the false branch.
We optionally track if the metadata was added by ``__builtin_expect`` or
``__builtin_expect_with_probability`` with an optional field ``!"expected"``.

.. code-block:: none

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <TRUE_BRANCH_WEIGHT>,
i32 <FALSE_BRANCH_WEIGHT>
}
Expand All @@ -47,6 +50,7 @@ is always case #0).

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <DEFAULT_BRANCH_WEIGHT>
[ , i32 <CASE_BRANCH_WEIGHT> ... ]
}
Expand All @@ -60,6 +64,7 @@ Branch weights are assigned to every destination.

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <LABEL_BRANCH_WEIGHT>
[ , i32 <LABEL_BRANCH_WEIGHT> ... ]
}
Expand All @@ -75,6 +80,7 @@ block and entry counts which may not be accurate with sampling.

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <CALL_BRANCH_WEIGHT>
}

Expand All @@ -95,6 +101,7 @@ is used.

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <INVOKE_NORMAL_WEIGHT>
[ , i32 <INVOKE_UNWIND_WEIGHT> ]
}
Expand Down
11 changes: 9 additions & 2 deletions llvm/include/llvm/IR/MDBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ class MDBuilder {
//===------------------------------------------------------------------===//

/// Return metadata containing two branch weights.
MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight);
/// @param TrueWeight the weight of the true branch
/// @param FalseWeight the weight of the false branch
/// @param Do these weights come from __builtin_expect*
MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight,
bool IsExpected = false);

/// Return metadata containing two branch weights, with significant bias
/// towards `true` destination.
Expand All @@ -70,7 +74,10 @@ class MDBuilder {
MDNode *createUnlikelyBranchWeights();

/// Return metadata containing a number of branch weights.
MDNode *createBranchWeights(ArrayRef<uint32_t> Weights);
/// @param Weights the weights of all the branches
/// @param Do these weights come from __builtin_expect*
MDNode *createBranchWeights(ArrayRef<uint32_t> Weights,
bool IsExpected = false);

/// Return metadata specifying that a branch or switch is unpredictable.
MDNode *createUnpredictable();
Expand Down
17 changes: 16 additions & 1 deletion llvm/include/llvm/IR/ProfDataUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ MDNode *getBranchWeightMDNode(const Instruction &I);
/// Nullptr otherwise.
MDNode *getValidBranchWeightMDNode(const Instruction &I);

/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
/// intrinsic
bool hasBranchWeightOrigin(const Instruction &I);

/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
/// intrinsic
bool hasBranchWeightOrigin(const MDNode *ProfileData);

/// Return the offset to the first branch weight data
unsigned getBranchWeightOffset(const MDNode *ProfileData);

/// Extract branch weights from MD_prof metadata
///
/// \param ProfileData A pointer to an MDNode.
Expand Down Expand Up @@ -111,7 +122,11 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);

/// Create a new `branch_weights` metadata node and add or overwrite
/// a `prof` metadata reference to instruction `I`.
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
/// \param I the Instruction to set branch weights on.
/// \param Weights an array of weights to set on instruction I.
/// \param IsExpected were these weights added from an llvm.expect* intrinsic.
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
bool IsExpected);

/// Scaling the profile data attached to 'I' using the ratio of S/T.
void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8866,7 +8866,8 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, ModifyDT &ModifiedDT) {
scaleWeights(NewTrueWeight, NewFalseWeight);
Br1->setMetadata(LLVMContext::MD_prof,
MDBuilder(Br1->getContext())
.createBranchWeights(TrueWeight, FalseWeight));
.createBranchWeights(TrueWeight, FalseWeight,
hasBranchWeightOrigin(*Br1)));

NewTrueWeight = TrueWeight;
NewFalseWeight = 2 * FalseWeight;
Expand Down
19 changes: 15 additions & 4 deletions llvm/lib/IR/Instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1268,12 +1268,23 @@ Instruction *Instruction::cloneImpl() const {

void Instruction::swapProfMetadata() {
MDNode *ProfileData = getBranchWeightMDNode(*this);
if (!ProfileData || ProfileData->getNumOperands() != 3)
if (!ProfileData)
return;
unsigned FirstIdx = getBranchWeightOffset(ProfileData);
if (ProfileData->getNumOperands() != 2 + FirstIdx)
return;

// The first operand is the name. Fetch them backwards and build a new one.
Metadata *Ops[] = {ProfileData->getOperand(0), ProfileData->getOperand(2),
ProfileData->getOperand(1)};
unsigned SecondIdx = FirstIdx + 1;
SmallVector<Metadata *, 4> Ops;
// If there are more weights past the second, we can't swap them
if (ProfileData->getNumOperands() > SecondIdx + 1)
return;
for (unsigned Idx = 0; Idx < FirstIdx; ++Idx) {
Ops.push_back(ProfileData->getOperand(Idx));
}
// Switch the order of the weights
Ops.push_back(ProfileData->getOperand(SecondIdx));
Ops.push_back(ProfileData->getOperand(FirstIdx));
setMetadata(LLVMContext::MD_prof,
MDNode::get(ProfileData->getContext(), Ops));
}
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/IR/Instructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5199,7 +5199,11 @@ void SwitchInstProfUpdateWrapper::init() {
if (!ProfileData)
return;

if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
// FIXME: This check belongs in ProfDataUtils. Its almost equivalent to
// getValidBranchWeightMDNode(), but the need to use llvm_unreachable
// makes them slightly different.
if (ProfileData->getNumOperands() !=
SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) {
Comment on lines +5202 to +5206
Copy link
Contributor

@MatzeB MatzeB Apr 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems simple enough to do something about it instead of adding a FIXME? Could for example add a getNumBranchWeights(<profile_data>) API so this can become getNumBranchWeights(ProfileData) != SI.getNumSuccessors()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another good suggestion. Thank you.

llvm_unreachable("number of prof branch_weights metadata operands does "
"not correspond to number of succesors");
}
Expand Down
14 changes: 9 additions & 5 deletions llvm/lib/IR/MDBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ MDNode *MDBuilder::createFPMath(float Accuracy) {
}

MDNode *MDBuilder::createBranchWeights(uint32_t TrueWeight,
uint32_t FalseWeight) {
return createBranchWeights({TrueWeight, FalseWeight});
uint32_t FalseWeight, bool IsExpected) {
return createBranchWeights({TrueWeight, FalseWeight}, IsExpected);
}

MDNode *MDBuilder::createLikelyBranchWeights() {
Expand All @@ -49,15 +49,19 @@ MDNode *MDBuilder::createUnlikelyBranchWeights() {
return createBranchWeights(1, (1U << 20) - 1);
}

MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights) {
MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights,
bool IsExpected) {
assert(Weights.size() >= 1 && "Need at least one branch weights!");

SmallVector<Metadata *, 4> Vals(Weights.size() + 1);
unsigned int Offset = IsExpected ? 2 : 1;
SmallVector<Metadata *, 4> Vals(Weights.size() + Offset);
Vals[0] = createString("branch_weights");
if (IsExpected)
Vals[1] = createString("expected");

Type *Int32Ty = Type::getInt32Ty(Context);
for (unsigned i = 0, e = Weights.size(); i != e; ++i)
Vals[i + 1] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
Vals[i + Offset] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));

return MDNode::get(Context, Vals);
}
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/IR/Metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1196,10 +1196,10 @@ MDNode *MDNode::mergeDirectCallProfMetadata(MDNode *A, MDNode *B,
StringRef AProfName = AMDS->getString();
StringRef BProfName = BMDS->getString();
if (AProfName == "branch_weights" && BProfName == "branch_weights") {
ConstantInt *AInstrWeight =
mdconst::dyn_extract<ConstantInt>(A->getOperand(1));
ConstantInt *BInstrWeight =
mdconst::dyn_extract<ConstantInt>(B->getOperand(1));
ConstantInt *AInstrWeight = mdconst::dyn_extract<ConstantInt>(
A->getOperand(getBranchWeightOffset(A)));
ConstantInt *BInstrWeight = mdconst::dyn_extract<ConstantInt>(
B->getOperand(getBranchWeightOffset(B)));
assert(AInstrWeight && BInstrWeight && "verified by LLVM verifier");
return MDNode::get(Ctx,
{MDHelper.createString("branch_weights"),
Expand Down
40 changes: 31 additions & 9 deletions llvm/lib/IR/ProfDataUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ namespace {
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes

// The index at which the weights vector starts
constexpr unsigned WeightsIdx = 1;

// the minimum number of operands for MD_prof nodes with branch weights
constexpr unsigned MinBWOps = 3;

Expand Down Expand Up @@ -75,15 +72,16 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
assert(isBranchWeightMD(ProfileData) && "wrong metadata");

unsigned NOps = ProfileData->getNumOperands();
unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
Weights.resize(NOps - WeightsIdx);

for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
ConstantInt *Weight =
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(Weight && "Malformed branch_weight in MD_prof node");
assert(Weight->getValue().getActiveBits() <= 32 &&
"Too many bits for uint32_t");
assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
"Too many bits for MD_prof branch_weight");
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
}
}
Expand Down Expand Up @@ -123,6 +121,26 @@ bool hasValidBranchWeightMD(const Instruction &I) {
return getValidBranchWeightMDNode(I);
}

bool hasBranchWeightOrigin(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
return hasBranchWeightOrigin(ProfileData);
}

bool hasBranchWeightOrigin(const MDNode *ProfileData) {
if (!isBranchWeightMD(ProfileData))
return false;
auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
// NOTE: if we ever have more types of branch weight provenance,
// we need to check the string value is "expected". For now, we
// supply a more generic API, and avoid the spurious comparisons.
assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
return ProfDataName != nullptr;
}

unsigned getBranchWeightOffset(const MDNode *ProfileData) {
return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
}

MDNode *getBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (!isBranchWeightMD(ProfileData))
Expand All @@ -132,7 +150,9 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {

MDNode *getValidBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = getBranchWeightMDNode(I);
if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
auto Offset = getBranchWeightOffset(ProfileData);
if (ProfileData &&
ProfileData->getNumOperands() == Offset + I.getNumSuccessors())
return ProfileData;
return nullptr;
}
Expand Down Expand Up @@ -191,7 +211,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
return false;

if (ProfDataName->getString() == "branch_weights") {
for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
unsigned Offset = getBranchWeightOffset(ProfileData);
for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(V && "Malformed branch_weight in MD_prof node");
TotalVal += V->getValue().getZExtValue();
Expand All @@ -212,9 +233,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}

void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
bool IsExpected) {
MDBuilder MDB(I.getContext());
MDNode *BranchWeights = MDB.createBranchWeights(Weights);
MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}

Expand Down
9 changes: 6 additions & 3 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSlotTracker.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Statepoint.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
Expand Down Expand Up @@ -4808,8 +4809,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {

// Check consistency of !prof branch_weights metadata.
if (ProfName == "branch_weights") {
unsigned int Offset = getBranchWeightOffset(MD);
if (isa<InvokeInst>(&I)) {
Check(MD->getNumOperands() == 2 || MD->getNumOperands() == 3,
Check(MD->getNumOperands() == (1 + Offset) ||
MD->getNumOperands() == (2 + Offset),
Comment on lines +4814 to +4815
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More opportunities for a possible getNumBranchWeights(...) API...

"Wrong number of InvokeInst branch_weights operands", MD);
} else {
unsigned ExpectedNumOperands = 0;
Expand All @@ -4829,10 +4832,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
CheckFailed("!prof branch_weights are not allowed for this instruction",
MD);

Check(MD->getNumOperands() == 1 + ExpectedNumOperands,
Check(MD->getNumOperands() == Offset + ExpectedNumOperands,
"Wrong number of operands", MD);
}
for (unsigned i = 1; i < MD->getNumOperands(); ++i) {
for (unsigned i = Offset; i < MD->getNumOperands(); ++i) {
auto &MDO = MD->getOperand(i);
Check(MDO, "second operand should not be null", MD);
Check(mdconst::dyn_extract<ConstantInt>(MDO),
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/Transforms/IPO/SampleProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1662,7 +1662,8 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
else if (OverwriteExistingWeights)
I.setMetadata(LLVMContext::MD_prof, nullptr);
} else if (!isa<IntrinsicInst>(&I)) {
setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])},
/*IsExpected=*/false);
}
}
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
Expand All @@ -1673,7 +1674,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
if (cast<CallBase>(I).isIndirectCall()) {
I.setMetadata(LLVMContext::MD_prof, nullptr);
} else {
setBranchWeights(I, {uint32_t(0)});
setBranchWeights(I, {uint32_t(0)}, /*IsExpected=*/false);
}
}
}
Expand Down Expand Up @@ -1756,7 +1757,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
if (MaxWeight > 0 &&
(!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
setBranchWeights(*TI, Weights);
setBranchWeights(*TI, Weights, /*IsExpected=*/false);
ORE->emit([&]() {
return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
<< "most popular destination for conditional branches at "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1878,7 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
static_cast<uint32_t>(CHRBranchBias.scale(1000)),
static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
};
setBranchWeights(*MergedBR, Weights);
setBranchWeights(*MergedBR, Weights, /*IsExpected=*/false);
CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
<< "\n");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);

if (AttachProfToDirectCall) {
setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
setBranchWeights(NewInst, {static_cast<uint32_t>(Count)},
/*IsExpected=*/false);
}

using namespace ore;
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,8 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
for (auto *Succ : successors(&BB))
Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
if (Weights.size() >= 2)
llvm::setBranchWeights(*BB.getTerminator(), Weights);
llvm::setBranchWeights(*BB.getTerminator(), Weights,
/*IsExpected=*/false);
}

unsigned NumCorruptCoverage = 0;
Expand Down Expand Up @@ -2260,7 +2261,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,

misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);

setBranchWeights(*TI, Weights);
setBranchWeights(*TI, Weights, /*IsExpected=*/false);
if (EmitBranchProbability) {
std::string BrCondStr = getBranchCondString(TI);
if (BrCondStr.empty())
Expand Down
Loading
Loading