diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h index 51d590be124f1..10c4f0f4bfe12 100644 --- a/llvm/include/llvm/ProfileData/SampleProf.h +++ b/llvm/include/llvm/ProfileData/SampleProf.h @@ -902,6 +902,13 @@ class FunctionSamples { return Ret->second.getCallTargets(); } + /// Returns the call target count of a specific function \p CalleeName at a + /// given location \p Callsite. Returns nullptr if not found. A \p Remapper + /// can be optionally provided to look up a name equivalent to \p CalleeName. + const uint64_t * + findCallTargetAt(const LineLocation &Callsite, StringRef CalleeName, + SampleProfileReaderItaniumRemapper *Remapper) const; + /// Return the function samples at the given callsite location. FunctionSamplesMap &functionSamplesAt(const LineLocation &Loc) { return CallsiteSamples[mapIRLocToProfileLoc(Loc)]; diff --git a/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h b/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h index 7c725a3c1216c..844531d8c2db9 100644 --- a/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h +++ b/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h @@ -408,6 +408,19 @@ SampleProfileLoaderBaseImpl::getInstWeightImpl(const InstructionT &Inst) { Discriminator = DIL->getBaseDiscriminator(); ErrorOr R = FS->findSamplesAt(LineOffset, Discriminator); + if constexpr (std::is_base_of_v) { + // If Inst is a direct function call and matches a sample, we should check + // if the sample contains call target count of the matching function, and + // use that count value instead of sample count, because sample count may + // contain superfluous numbers from other non-matching call targets as a + // result of merging profiles. + if (const CallInst *Call = dyn_cast(&Inst)) + if (const Function *Callee = Call->getCalledFunction()) + if (const uint64_t *CallTargetCount = + FS->findCallTargetAt(LineLocation(LineOffset, Discriminator), + Callee->getName(), Reader->getRemapper())) + R.get() = *CallTargetCount; + } if (R) { bool FirstMark = CoverageTracker.markSamplesUsed(FS, LineOffset, Discriminator, R.get()); diff --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp index 59fa71899ed47..8d3b52075641a 100644 --- a/llvm/lib/ProfileData/SampleProf.cpp +++ b/llvm/lib/ProfileData/SampleProf.cpp @@ -275,6 +275,25 @@ void FunctionSamples::findAllNames(DenseSet &NameSet) const { } } +const uint64_t *FunctionSamples::findCallTargetAt(const LineLocation &Callsite, + StringRef CalleeName, SampleProfileReaderItaniumRemapper *Remapper) const { + const auto &FindRes = BodySamples.find(mapIRLocToProfileLoc(Callsite)); + if (FindRes == BodySamples.end()) + return nullptr; + const auto &CallTargets = FindRes->second.getCallTargets(); + const auto &Ret = CallTargets.find(getRepInFormat(CalleeName)); + if (Ret != CallTargets.end()) + return &Ret->second; + if (Remapper && !UseMD5) { + if (auto RemappedName = Remapper->lookUpNameInProfile(CalleeName)) { + const auto &Ret = CallTargets.find(getRepInFormat(*RemappedName)); + if (Ret != CallTargets.end()) + return &Ret->second; + } + } + return nullptr; +} + const FunctionSamples *FunctionSamples::findFunctionSamplesAt( const LineLocation &Loc, StringRef CalleeName, SampleProfileReaderItaniumRemapper *Remapper) const { diff --git a/llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof b/llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof new file mode 100644 index 0000000000000..f04a1b9b89f18 --- /dev/null +++ b/llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof @@ -0,0 +1,6 @@ +test1:10000:1000 + 2: 456 callee:123 callee2:15 +test2:20000:2000 + 3: 50 callee:30 +test3:30000:3000 + 2: 101 callee_mismatch:45 diff --git a/llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll b/llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll new file mode 100644 index 0000000000000..fd08e4ee3dc29 --- /dev/null +++ b/llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll @@ -0,0 +1,61 @@ +; RUN: opt -S %s -passes=sample-profile -sample-profile-file=%S/Inputs/direct-call-accurate-count.prof -salvage-stale-profile | FileCheck %s +; RUN: llvm-profdata merge --sample --extbinary --use-md5 -output=%t %S/Inputs/direct-call-accurate-count.prof +; RUN: opt -S %s -passes=sample-profile -sample-profile-file=%t -salvage-stale-profile | FileCheck %s + +declare void @callee() #0 + +; CHECK-LABEL: @test1 +define dso_local void @test1(i1 %0) #1 !dbg !3 { +; Add a branch here to prevent the head sample from being unconditionally +; propagated to the entire block overriding the line sample count. + br i1 %0, label %if.then, label %if.end +if.then: + call void @callee(), !dbg !4 +; CHECK: call void @callee(), !dbg !{{[0-9]+}}, !prof ![[BRANCH_WEIGHT1:[0-9]+]] + br label %if.end +if.end: + ret void +} + +; With stale profile +; CHECK-LABEL: @test2 +define dso_local void @test2(i1 %0) #1 !dbg !5 { + br i1 %0, label %if.then, label %if.end +if.then: + call void @callee(), !dbg !6 +; CHECK: call void @callee(), !dbg !{{[0-9]+}}, !prof ![[BRANCH_WEIGHT2:[0-9]+]] + br label %if.end +if.end: + ret void +} + +; Call target is not matched in profile, use sample count. +; CHECK-LABEL: @test3 +define dso_local void @test3(i1 %0) #1 !dbg !7 { + br i1 %0, label %if.then, label %if.end +if.then: + call void @callee(), !dbg !8 +; CHECK: call void @callee(), !dbg !{{[0-9]+}}, !prof ![[BRANCH_WEIGHT3:[0-9]+]] + br label %if.end +if.end: + ret void +} + +attributes #0 = { "use-sample-profile" } +attributes #1 = { "use-sample-profile" } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2} + +!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1) +!1 = !DIFile(filename: "test.cpp", directory: "/") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = distinct !DISubprogram(name: "test1", scope: !1, file: !1, line: 1, unit: !0) +!4 = !DILocation(line: 3, column: 4, scope: !3) +!5 = distinct !DISubprogram(name: "test2", scope: !1, file: !1, line: 11, unit: !0) +!6 = !DILocation(line: 15, column: 4, scope: !5) +!7 = distinct !DISubprogram(name: "test3", scope: !1, file: !1, line: 21, unit: !0) +!8 = !DILocation(line: 23, column: 4, scope: !7) +; CHECK-DAG: ![[BRANCH_WEIGHT1]] = !{!"branch_weights", i32 123} +; CHECK-DAG: ![[BRANCH_WEIGHT2]] = !{!"branch_weights", i32 30} +; CHECK-DAG: ![[BRANCH_WEIGHT3]] = !{!"branch_weights", i32 101}