116 changes: 79 additions & 37 deletions llvm/lib/Transforms/IPO/SampleContextTracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,24 @@ ContextTrieNode::getHottestChildContext(const LineLocation &CallSite) {
return ChildNodeRet;
}

ContextTrieNode &ContextTrieNode::moveToChildContext(
const LineLocation &CallSite, ContextTrieNode &&NodeToMove,
uint32_t ContextFramesToRemove, bool DeleteNode) {
ContextTrieNode &
SampleContextTracker::moveContextSamples(ContextTrieNode &ToNodeParent,
const LineLocation &CallSite,
ContextTrieNode &&NodeToMove) {
uint64_t Hash =
FunctionSamples::getCallSiteHash(NodeToMove.getFuncName(), CallSite);
std::map<uint64_t, ContextTrieNode> &AllChildContext =
ToNodeParent.getAllChildContext();
assert(!AllChildContext.count(Hash) && "Node to remove must exist");
LineLocation OldCallSite = NodeToMove.CallSiteLoc;
ContextTrieNode &OldParentContext = *NodeToMove.getParentContext();
AllChildContext[Hash] = NodeToMove;
ContextTrieNode &NewNode = AllChildContext[Hash];
NewNode.CallSiteLoc = CallSite;
NewNode.setCallSiteLoc(CallSite);

// Walk through nodes in the moved the subtree, and update
// FunctionSamples' context as for the context promotion.
// We also need to set new parant link for all children.
std::queue<ContextTrieNode *> NodeToUpdate;
NewNode.setParentContext(this);
NewNode.setParentContext(&ToNodeParent);
NodeToUpdate.push(&NewNode);

while (!NodeToUpdate.empty()) {
Expand All @@ -88,10 +89,8 @@ ContextTrieNode &ContextTrieNode::moveToChildContext(
FunctionSamples *FSamples = Node->getFunctionSamples();

if (FSamples) {
FSamples->getContext().promoteOnPath(ContextFramesToRemove);
setContextNode(FSamples, Node);
FSamples->getContext().setState(SyntheticContext);
LLVM_DEBUG(dbgs() << " Context promoted to: "
<< FSamples->getContext().toString() << "\n");
}

for (auto &It : Node->getAllChildContext()) {
Expand All @@ -101,10 +100,6 @@ ContextTrieNode &ContextTrieNode::moveToChildContext(
}
}

// Original context no longer needed, destroy if requested.
if (DeleteNode)
OldParentContext.removeChildContext(OldCallSite, NewNode.getFuncName());

return NewNode;
}

Expand Down Expand Up @@ -148,6 +143,10 @@ void ContextTrieNode::setParentContext(ContextTrieNode *Parent) {
ParentContext = Parent;
}

void ContextTrieNode::setCallSiteLoc(const LineLocation &Loc) {
CallSiteLoc = Loc;
}

void ContextTrieNode::dumpNode() {
dbgs() << "Node: " << FuncName << "\n"
<< " Callsite: " << CallSiteLoc << "\n"
Expand Down Expand Up @@ -203,13 +202,23 @@ SampleContextTracker::SampleContextTracker(
SampleContext Context = FuncSample.first;
LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context.toString()
<< "\n");
if (!Context.isBaseContext())
FuncToCtxtProfiles[Context.getName()].insert(FSamples);
ContextTrieNode *NewNode = getOrCreateContextPath(Context, true);
assert(!NewNode->getFunctionSamples() &&
"New node can't have sample profile");
NewNode->setFunctionSamples(FSamples);
}
populateFuncToCtxtMap();
}

void SampleContextTracker::populateFuncToCtxtMap() {
for (auto *Node : *this) {
FunctionSamples *FSamples = Node->getFunctionSamples();
if (FSamples) {
FSamples->getContext().setState(RawContext);
setContextNode(FSamples, Node);
FuncToCtxtProfiles[Node->getFuncName()].push_back(FSamples);
}
}
}

FunctionSamples *
Expand All @@ -232,7 +241,7 @@ SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst,
if (CalleeContext) {
FunctionSamples *FSamples = CalleeContext->getFunctionSamples();
LLVM_DEBUG(if (FSamples) {
dbgs() << " Callee context found: " << FSamples->getContext().toString()
dbgs() << " Callee context found: " << getContextString(CalleeContext)
<< "\n";
});
return FSamples;
Expand Down Expand Up @@ -334,7 +343,7 @@ FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name,
if (Context.hasState(InlinedContext) || Context.hasState(MergedContext))
continue;

ContextTrieNode *FromNode = getContextFor(Context);
ContextTrieNode *FromNode = getContextNodeForProfile(CSamples);
if (FromNode == Node)
continue;

Expand All @@ -355,7 +364,7 @@ void SampleContextTracker::markContextSamplesInlined(
const FunctionSamples *InlinedSamples) {
assert(InlinedSamples && "Expect non-null inlined samples");
LLVM_DEBUG(dbgs() << "Marking context profile as inlined: "
<< InlinedSamples->getContext().toString() << "\n");
<< getContextString(*InlinedSamples) << "\n");
InlinedSamples->getContext().setState(InlinedContext);
}

Expand Down Expand Up @@ -407,16 +416,40 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples();
assert(FromSamples && "Shouldn't promote a context without profile");
LLVM_DEBUG(dbgs() << " Found context tree root to promote: "
<< FromSamples->getContext().toString() << "\n");
<< getContextString(&NodeToPromo) << "\n");

assert(!FromSamples->getContext().hasState(InlinedContext) &&
"Shouldn't promote inlined context profile");
uint32_t ContextFramesToRemove =
FromSamples->getContext().getContextFrames().size() - 1;
return promoteMergeContextSamplesTree(NodeToPromo, RootContext,
ContextFramesToRemove);
return promoteMergeContextSamplesTree(NodeToPromo, RootContext);
}

#ifndef NDEBUG
std::string
SampleContextTracker::getContextString(const FunctionSamples &FSamples) const {
return getContextString(getContextNodeForProfile(&FSamples));
}

std::string
SampleContextTracker::getContextString(ContextTrieNode *Node) const {
SampleContextFrameVector Res;
if (Node == &RootContext)
return std::string();
Res.emplace_back(Node->getFuncName(), LineLocation(0, 0));

ContextTrieNode *PreNode = Node;
Node = Node->getParentContext();
while (Node && Node != &RootContext) {
Res.emplace_back(Node->getFuncName(), PreNode->getCallSiteLoc());
PreNode = Node;
Node = Node->getParentContext();
}

std::reverse(Res.begin(), Res.end());

return SampleContext::getContextString(Res);
}
#endif

void SampleContextTracker::dump() { RootContext.dumpTree(); }

StringRef SampleContextTracker::getFuncNameFor(ContextTrieNode *Node) const {
Expand Down Expand Up @@ -527,8 +560,7 @@ ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) {
}

void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode,
ContextTrieNode &ToNode,
uint32_t ContextFramesToRemove) {
ContextTrieNode &ToNode) {
FunctionSamples *FromSamples = FromNode.getFunctionSamples();
FunctionSamples *ToSamples = ToNode.getFunctionSamples();
if (FromSamples && ToSamples) {
Expand All @@ -541,16 +573,13 @@ void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode,
} else if (FromSamples) {
// Transfer FromSamples from FromNode to ToNode
ToNode.setFunctionSamples(FromSamples);
setContextNode(FromSamples, &ToNode);
FromSamples->getContext().setState(SyntheticContext);
FromSamples->getContext().promoteOnPath(ContextFramesToRemove);
FromNode.setFunctionSamples(nullptr);
}
}

ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent,
uint32_t ContextFramesToRemove) {
assert(ContextFramesToRemove && "Context to remove can't be empty");
ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent) {

// Ignore call site location if destination is top level under root
LineLocation NewCallSiteLoc = LineLocation(0, 0);
Expand All @@ -567,22 +596,25 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
if (!ToNode) {
// Do not delete node to move from its parent here because
// caller is iterating over children of that parent node.
ToNode = &ToNodeParent.moveToChildContext(
NewCallSiteLoc, std::move(FromNode), ContextFramesToRemove, false);
ToNode =
&moveContextSamples(ToNodeParent, NewCallSiteLoc, std::move(FromNode));
LLVM_DEBUG({
dbgs() << " Context promoted and merged to: " << getContextString(ToNode)
<< "\n";
});
} else {
// Destination node exists, merge samples for the context tree
mergeContextNode(FromNode, *ToNode, ContextFramesToRemove);
mergeContextNode(FromNode, *ToNode);
LLVM_DEBUG({
if (ToNode->getFunctionSamples())
dbgs() << " Context promoted and merged to: "
<< ToNode->getFunctionSamples()->getContext().toString() << "\n";
<< getContextString(ToNode) << "\n";
});

// Recursively promote and merge children
for (auto &It : FromNode.getAllChildContext()) {
ContextTrieNode &FromChildNode = It.second;
promoteMergeContextSamplesTree(FromChildNode, *ToNode,
ContextFramesToRemove);
promoteMergeContextSamplesTree(FromChildNode, *ToNode);
}

// Remove children once they're all merged
Expand All @@ -595,4 +627,14 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(

return *ToNode;
}

void SampleContextTracker::createContextLessProfileMap(
SampleProfileMap &ContextLessProfiles) {
for (auto *Node : *this) {
FunctionSamples *FProfile = Node->getFunctionSamples();
// Profile's context can be empty, use ContextNode's func name.
if (FProfile)
ContextLessProfiles[Node->getFuncName()].merge(*FProfile);
}
}
} // namespace llvm
3 changes: 1 addition & 2 deletions llvm/lib/Transforms/IPO/SampleProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1073,8 +1073,7 @@ void SampleProfileLoader::findExternalInlineCandidate(
return;
}

ContextTrieNode *Caller =
ContextTracker->getContextFor(Samples->getContext());
ContextTrieNode *Caller = ContextTracker->getContextNodeForProfile(Samples);
std::queue<ContextTrieNode *> CalleeList;
CalleeList.push(Caller);
while (!CalleeList.empty()) {
Expand Down
83 changes: 40 additions & 43 deletions llvm/tools/llvm-profgen/CSPreInliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,13 @@ static cl::opt<bool> SamplePreInlineReplay(
cl::desc(
"Replay previous inlining and adjust context profile accordingly"));

CSPreInliner::CSPreInliner(SampleProfileMap &Profiles, ProfiledBinary &Binary,
ProfileSummary *Summary)
CSPreInliner::CSPreInliner(SampleContextTracker &Tracker,
ProfiledBinary &Binary, ProfileSummary *Summary)
: UseContextCost(UseContextCostForPreInliner),
// TODO: Pass in a guid-to-name map in order for
// ContextTracker.getFuncNameFor to work, if `Profiles` can have md5 codes
// as their profile context.
ContextTracker(Profiles, nullptr), ProfileMap(Profiles), Binary(Binary),
Summary(Summary) {
ContextTracker(Tracker), Binary(Binary), Summary(Summary) {
// Set default preinliner hot/cold call site threshold tuned with CSSPGO.
// for good performance with reasonable profile size.
if (!SampleHotCallSiteThreshold.getNumOccurrences())
Expand Down Expand Up @@ -107,7 +106,7 @@ bool CSPreInliner::getInlineCandidates(ProfiledCandidateQueue &CQueue,
// current one in the trie is relavent. So we walk the trie instead of call
// targets from function profile.
ContextTrieNode *CallerNode =
ContextTracker.getContextFor(CallerSamples->getContext());
ContextTracker.getContextNodeForProfile(CallerSamples);

bool HasNewCandidate = false;
for (auto &Child : CallerNode->getAllChildContext()) {
Expand All @@ -131,20 +130,19 @@ bool CSPreInliner::getInlineCandidates(ProfiledCandidateQueue &CQueue,
// TODO: call site and callee entry count should be mostly consistent, add
// check for that.
HasNewCandidate = true;
uint32_t CalleeSize = getFuncSize(*CalleeSamples);
uint32_t CalleeSize = getFuncSize(CalleeNode);
CQueue.emplace(CalleeSamples, std::max(CallsiteCount, CalleeEntryCount),
CalleeSize);
}

return HasNewCandidate;
}

uint32_t CSPreInliner::getFuncSize(const FunctionSamples &FSamples) {
if (UseContextCost) {
return Binary.getFuncSizeForContext(FSamples.getContext());
}
uint32_t CSPreInliner::getFuncSize(const ContextTrieNode *ContextNode) {
if (UseContextCost)
return Binary.getFuncSizeForContext(ContextNode);

return FSamples.getBodySamples().size();
return ContextNode->getFunctionSamples()->getBodySamples().size();
}

bool CSPreInliner::shouldInline(ProfiledInlineCandidate &Candidate) {
Expand Down Expand Up @@ -189,7 +187,8 @@ void CSPreInliner::processFunction(const StringRef Name) {
if (!FSamples)
return;

unsigned FuncSize = getFuncSize(*FSamples);
unsigned FuncSize =
getFuncSize(ContextTracker.getContextNodeForProfile(FSamples));
unsigned FuncFinalSize = FuncSize;
unsigned SizeLimit = FuncSize * ProfileInlineGrowthLimit;
SizeLimit = std::min(SizeLimit, (unsigned)ProfileInlineLimitMax);
Expand Down Expand Up @@ -218,11 +217,12 @@ void CSPreInliner::processFunction(const StringRef Name) {
} else {
++PreInlNumCSNotInlined;
}
LLVM_DEBUG(dbgs() << (ShouldInline ? " Inlined" : " Outlined")
<< " context profile for: "
<< Candidate.CalleeSamples->getContext().toString()
<< " (callee size: " << Candidate.SizeCost
<< ", call count:" << Candidate.CallsiteCount << ")\n");
LLVM_DEBUG(
dbgs() << (ShouldInline ? " Inlined" : " Outlined")
<< " context profile for: "
<< ContextTracker.getContextString(*Candidate.CalleeSamples)
<< " (callee size: " << Candidate.SizeCost
<< ", call count:" << Candidate.CallsiteCount << ")\n");
}

if (!CQueue.empty()) {
Expand All @@ -246,7 +246,8 @@ void CSPreInliner::processFunction(const StringRef Name) {
CQueue.pop();
bool WasInlined =
Candidate.CalleeSamples->getContext().hasAttribute(ContextWasInlined);
dbgs() << " " << Candidate.CalleeSamples->getContext().toString()
dbgs() << " "
<< ContextTracker.getContextString(*Candidate.CalleeSamples)
<< " (candidate size:" << Candidate.SizeCost
<< ", call count: " << Candidate.CallsiteCount << ", previously "
<< (WasInlined ? "inlined)\n" : "not inlined)\n");
Expand All @@ -256,19 +257,24 @@ void CSPreInliner::processFunction(const StringRef Name) {

void CSPreInliner::run() {
#ifndef NDEBUG
auto printProfileNames = [](SampleProfileMap &Profiles, bool IsInput) {
dbgs() << (IsInput ? "Input" : "Output") << " context-sensitive profiles ("
<< Profiles.size() << " total):\n";
for (auto &It : Profiles) {
const FunctionSamples &Samples = It.second;
dbgs() << " [" << Samples.getContext().toString() << "] "
<< Samples.getTotalSamples() << ":" << Samples.getHeadSamples()
<< "\n";
auto printProfileNames = [](SampleContextTracker &ContextTracker,
bool IsInput) {
uint32_t Size = 0;
for (auto *Node : ContextTracker) {
FunctionSamples *FSamples = Node->getFunctionSamples();
if (FSamples) {
Size++;
dbgs() << " [" << ContextTracker.getContextString(Node) << "] "
<< FSamples->getTotalSamples() << ":"
<< FSamples->getHeadSamples() << "\n";
}
}
dbgs() << (IsInput ? "Input" : "Output") << " context-sensitive profiles ("
<< Size << " total):\n";
};
#endif

LLVM_DEBUG(printProfileNames(ProfileMap, true));
LLVM_DEBUG(printProfileNames(ContextTracker, true));

// Execute global pre-inliner to estimate a global top-down inline
// decision and merge profiles accordingly. This helps with profile
Expand All @@ -283,24 +289,15 @@ void CSPreInliner::run() {

// Not inlined context profiles are merged into its base, so we can
// trim out such profiles from the output.
std::vector<SampleContext> ProfilesToBeRemoved;
for (auto &It : ProfileMap) {
SampleContext &Context = It.second.getContext();
if (!Context.isBaseContext() && !Context.hasState(InlinedContext)) {
assert(Context.hasState(MergedContext) &&
"Not inlined context profile should be merged already");
ProfilesToBeRemoved.push_back(It.first);
for (auto *Node : ContextTracker) {
FunctionSamples *FProfile = Node->getFunctionSamples();
if (FProfile &&
(Node->getParentContext() != &ContextTracker.getRootContext() &&
!FProfile->getContext().hasState(InlinedContext))) {
Node->setFunctionSamples(nullptr);
}
}

for (auto &ContextName : ProfilesToBeRemoved) {
ProfileMap.erase(ContextName);
}

// Make sure ProfileMap's key is consistent with FunctionSamples' name.
SampleContextTrimmer(ProfileMap).canonicalizeContextProfiles();

FunctionSamples::ProfileIsPreInlined = true;

LLVM_DEBUG(printProfileNames(ProfileMap, false));
LLVM_DEBUG(printProfileNames(ContextTracker, false));
}
7 changes: 3 additions & 4 deletions llvm/tools/llvm-profgen/CSPreInliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ using ProfiledCandidateQueue =
// size by only keep context that is estimated to be inlined.
class CSPreInliner {
public:
CSPreInliner(SampleProfileMap &Profiles, ProfiledBinary &Binary,
CSPreInliner(SampleContextTracker &Tracker, ProfiledBinary &Binary,
ProfileSummary *Summary);
void run();

Expand All @@ -77,10 +77,9 @@ class CSPreInliner {
std::vector<StringRef> buildTopDownOrder();
void processFunction(StringRef Name);
bool shouldInline(ProfiledInlineCandidate &Candidate);
uint32_t getFuncSize(const FunctionSamples &FSamples);
uint32_t getFuncSize(const ContextTrieNode *ContextNode);
bool UseContextCost;
SampleContextTracker ContextTracker;
SampleProfileMap &ProfileMap;
SampleContextTracker &ContextTracker;
ProfiledBinary &Binary;
ProfileSummary *Summary;
};
Expand Down
378 changes: 226 additions & 152 deletions llvm/tools/llvm-profgen/ProfileGenerator.cpp

Large diffs are not rendered by default.

64 changes: 53 additions & 11 deletions llvm/tools/llvm-profgen/ProfileGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using ProbeCounterMap =
class ProfileGeneratorBase {

public:
ProfileGeneratorBase(ProfiledBinary *Binary) : Binary(Binary){};
ProfileGeneratorBase(ProfiledBinary *Binary,
const ContextSampleCounterMap *Counters)
: Binary(Binary), SampleCounters(Counters){};
Expand All @@ -44,7 +45,7 @@ class ProfileGeneratorBase {
create(ProfiledBinary *Binary, const ContextSampleCounterMap *Counters,
bool profileIsCS);
static std::unique_ptr<ProfileGeneratorBase>
create(ProfiledBinary *Binary, const SampleProfileMap &&ProfileMap,
create(ProfiledBinary *Binary, SampleProfileMap &ProfileMap,
bool profileIsCS);
virtual void generateProfile() = 0;
void write();
Expand Down Expand Up @@ -109,7 +110,7 @@ class ProfileGeneratorBase {

StringRef getCalleeNameForOffset(uint64_t TargetOffset);

void computeSummaryAndThreshold();
void computeSummaryAndThreshold(SampleProfileMap &ProfileMap);

void calculateAndShowDensity(const SampleProfileMap &Profiles);

Expand All @@ -120,6 +121,13 @@ class ProfileGeneratorBase {

void collectProfiledFunctions();

bool collectFunctionsFromRawProfile(
std::unordered_set<const BinaryFunction *> &ProfiledFunctions);

// Collect profiled Functions for llvm sample profile input.
virtual bool collectFunctionsFromLLVMProfile(
std::unordered_set<const BinaryFunction *> &ProfiledFunctions) = 0;

// Thresholds from profile summary to answer isHotCount/isColdCount queries.
uint64_t HotCountThreshold;

Expand Down Expand Up @@ -166,15 +174,17 @@ class ProfileGenerator : public ProfileGeneratorBase {
void postProcessProfiles();
void trimColdProfiles(const SampleProfileMap &Profiles,
uint64_t ColdCntThreshold);
bool collectFunctionsFromLLVMProfile(
std::unordered_set<const BinaryFunction *> &ProfiledFunctions) override;
};

class CSProfileGenerator : public ProfileGeneratorBase {
public:
CSProfileGenerator(ProfiledBinary *Binary,
const ContextSampleCounterMap *Counters)
: ProfileGeneratorBase(Binary, Counters){};
CSProfileGenerator(ProfiledBinary *Binary, const SampleProfileMap &&Profiles)
: ProfileGeneratorBase(Binary, std::move(Profiles)){};
CSProfileGenerator(ProfiledBinary *Binary, SampleProfileMap &Profiles)
: ProfileGeneratorBase(Binary), ContextTracker(Profiles, nullptr){};
void generateProfile() override;

// Trim the context stack at a given depth.
Expand Down Expand Up @@ -294,10 +304,15 @@ class CSProfileGenerator : public ProfileGeneratorBase {

private:
void generateLineNumBasedProfile();
// Lookup or create FunctionSamples for the context
FunctionSamples &
getFunctionProfileForContext(const SampleContextFrameVector &Context,
bool WasLeafInlined = false);

FunctionSamples *getOrCreateFunctionSamples(ContextTrieNode *ContextNode,
bool WasLeafInlined = false);

// Lookup or create ContextTrieNode for the context, FunctionSamples is
// created inside this function.
ContextTrieNode *getOrCreateContextNode(const SampleContextFrames Context,
bool WasLeafInlined = false);

// For profiled only functions, on-demand compute their inline context
// function byte size which is used by the pre-inliner.
void computeSizeForProfiledFunctions();
Expand All @@ -307,10 +322,13 @@ class CSProfileGenerator : public ProfileGeneratorBase {

void populateBodySamplesForFunction(FunctionSamples &FunctionProfile,
const RangeSample &RangeCounters);
void populateBoundarySamplesForFunction(SampleContextFrames ContextId,
FunctionSamples *CallerProfile,

void populateBoundarySamplesForFunction(ContextTrieNode *CallerNode,
const BranchSample &BranchCounters);
void populateInferredFunctionSamples();

void populateInferredFunctionSamples(ContextTrieNode &Node);

void updateFunctionSamples();

void generateProbeBasedProfile();

Expand All @@ -320,14 +338,38 @@ class CSProfileGenerator : public ProfileGeneratorBase {
// Fill in boundary samples for a call probe
void populateBoundarySamplesWithProbes(const BranchSample &BranchCounter,
SampleContextFrames ContextStack);

ContextTrieNode *
getContextNodeForLeafProbe(SampleContextFrames ContextStack,
const MCDecodedPseudoProbe *LeafProbe);

// Helper function to get FunctionSamples for the leaf probe
FunctionSamples &
getFunctionProfileForLeafProbe(SampleContextFrames ContextStack,
const MCDecodedPseudoProbe *LeafProbe);

void convertToProfileMap(ContextTrieNode &Node,
SampleContextFrameVector &Context);

void convertToProfileMap();

void computeSummaryAndThreshold();

bool collectFunctionsFromLLVMProfile(
std::unordered_set<const BinaryFunction *> &ProfiledFunctions) override;

ContextTrieNode &getRootContext() { return ContextTracker.getRootContext(); };

// The container for holding the FunctionSamples used by context trie.
std::list<FunctionSamples> FSamplesList;

// Underlying context table serves for sample profile writer.
std::unordered_set<SampleContextFrameVector, SampleContextFrameHash> Contexts;

SampleContextTracker ContextTracker;

bool IsProfileValidOnTrie = true;

public:
// Deduplicate adjacent repeated context sequences up to a given sequence
// length. -1 means no size limit.
Expand Down
16 changes: 7 additions & 9 deletions llvm/tools/llvm-profgen/ProfiledBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,23 @@ void BinarySizeContextTracker::addInstructionForContext(
}

uint32_t
BinarySizeContextTracker::getFuncSizeForContext(const SampleContext &Context) {
BinarySizeContextTracker::getFuncSizeForContext(const ContextTrieNode *Node) {
ContextTrieNode *CurrNode = &RootContext;
ContextTrieNode *PrevNode = nullptr;
SampleContextFrames Frames = Context.getContextFrames();
int32_t I = Frames.size() - 1;

Optional<uint32_t> Size;

// Start from top-level context-less function, traverse down the reverse
// context trie to find the best/longest match for given context, then
// retrieve the size.

while (CurrNode && I >= 0) {
// Process from leaf function to callers (added to context).
const auto &ChildFrame = Frames[I--];
LineLocation CallSiteLoc(0, 0);
while (CurrNode && Node->getParentContext() != nullptr) {
PrevNode = CurrNode;
CurrNode =
CurrNode->getChildContext(ChildFrame.Location, ChildFrame.FuncName);
CurrNode = CurrNode->getChildContext(CallSiteLoc, Node->getFuncName());
if (CurrNode && CurrNode->getFunctionSize())
Size = CurrNode->getFunctionSize().getValue();
CallSiteLoc = Node->getCallSiteLoc();
Node = Node->getParentContext();
}

// If we traversed all nodes along the path of the context and haven't
Expand Down
6 changes: 3 additions & 3 deletions llvm/tools/llvm-profgen/ProfiledBinary.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class BinarySizeContextTracker {
// Get function size with a specific context. When there's no exact match
// for the given context, try to retrieve the size of that function from
// closest matching context.
uint32_t getFuncSizeForContext(const SampleContext &Context);
uint32_t getFuncSizeForContext(const ContextTrieNode *Context);

// For inlinees that are full optimized away, we can establish zero size using
// their remaining probes.
Expand Down Expand Up @@ -485,8 +485,8 @@ class ProfiledBinary {
return &I->second;
}

uint32_t getFuncSizeForContext(SampleContext &Context) {
return FuncSizeTracker.getFuncSizeForContext(Context);
uint32_t getFuncSizeForContext(const ContextTrieNode *ContextNode) {
return FuncSizeTracker.getFuncSizeForContext(ContextNode);
}

// Load the symbols from debug table and populate into symbol list.
Expand Down
3 changes: 1 addition & 2 deletions llvm/tools/llvm-profgen/llvm-profgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ int main(int argc, const char *argv[]) {
std::move(ReaderOrErr.get());
Reader->read();
std::unique_ptr<ProfileGeneratorBase> Generator =
ProfileGeneratorBase::create(Binary.get(),
std::move(Reader->getProfiles()),
ProfileGeneratorBase::create(Binary.get(), Reader->getProfiles(),
Reader->profileIsCS());
Generator->generateProfile();
Generator->write();
Expand Down