diff --git a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h index 36a996632b71e..625ec8cd4d32a 100644 --- a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h +++ b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h @@ -112,6 +112,13 @@ class ContextNode final { uint64_t entrycount() const { return counters()[0]; } }; + +/// Abstraction for the parameter passed to `__llvm_ctx_profile_fetch`. +class ProfileWriter { +public: + virtual void writeContextual(const ctx_profile::ContextNode &RootNode) = 0; + virtual ~ProfileWriter() = default; +}; } // namespace ctx_profile } // namespace llvm #endif diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp index df30986cdfc69..32d13283c1b48 100644 --- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp +++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp @@ -294,9 +294,7 @@ void __llvm_ctx_profile_start_collection() { __sanitizer::Printf("[ctxprof] Initial NumMemUnits: %zu \n", NumMemUnits); } -bool __llvm_ctx_profile_fetch(void *Data, - bool (*Writer)(void *W, const ContextNode &)) { - assert(Writer); +bool __llvm_ctx_profile_fetch(ProfileWriter &Writer) { __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock( &AllContextsMutex); @@ -308,8 +306,7 @@ bool __llvm_ctx_profile_fetch(void *Data, __sanitizer::Printf("[ctxprof] Contextual Profile is %s\n", "invalid"); return false; } - if (!Writer(Data, *Root->FirstNode)) - return false; + Writer.writeContextual(*Root->FirstNode); } return true; } diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h index 74d346d6e0a07..8a6949d4ec288 100644 --- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h +++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h @@ -169,7 +169,6 @@ void __llvm_ctx_profile_free(); /// The Writer's first parameter plays the role of closure for Writer, and is /// what the caller of __llvm_ctx_profile_fetch passes as the Data parameter. /// The second parameter is the root of a context tree. -bool __llvm_ctx_profile_fetch(void *Data, - bool (*Writer)(void *, const ContextNode &)); +bool __llvm_ctx_profile_fetch(ProfileWriter &); } #endif // CTX_PROFILE_CTXINSTRPROFILING_H_ diff --git a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp index d9f08b1e7efe8..e040b18e2d77a 100644 --- a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp +++ b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp @@ -179,13 +179,15 @@ TEST_F(ContextTest, Dump) { (void)Subctx; __llvm_ctx_profile_release_context(&Root); - struct Writer { + class TestProfileWriter : public ProfileWriter { + public: ContextRoot *const Root; const size_t Entries; bool State = false; - Writer(ContextRoot *Root, size_t Entries) : Root(Root), Entries(Entries) {} + TestProfileWriter(ContextRoot *Root, size_t Entries) + : Root(Root), Entries(Entries) {} - bool write(const ContextNode &Node) { + void writeContextual(const ContextNode &Node) override { EXPECT_FALSE(Root->Taken.TryLock()); EXPECT_EQ(Node.guid(), 1U); EXPECT_EQ(Node.counters()[0], Entries); @@ -202,22 +204,17 @@ TEST_F(ContextTest, Dump) { EXPECT_EQ(SN.callsites_size(), 1U); EXPECT_EQ(SN.subContexts()[0], nullptr); State = true; - return true; } }; - Writer W(&Root, 1); + TestProfileWriter W(&Root, 1); EXPECT_FALSE(W.State); - __llvm_ctx_profile_fetch(&W, [](void *W, const ContextNode &Node) -> bool { - return reinterpret_cast(W)->write(Node); - }); + __llvm_ctx_profile_fetch(W); EXPECT_TRUE(W.State); // this resets all counters but not the internal structure. __llvm_ctx_profile_start_collection(); - Writer W2(&Root, 0); + TestProfileWriter W2(&Root, 0); EXPECT_FALSE(W2.State); - __llvm_ctx_profile_fetch(&W2, [](void *W, const ContextNode &Node) -> bool { - return reinterpret_cast(W)->write(Node); - }); + __llvm_ctx_profile_fetch(W2); EXPECT_TRUE(W2.State); } diff --git a/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp b/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp index 797b871860655..cb69c8826239d 100644 --- a/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp +++ b/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp @@ -15,9 +15,7 @@ #include using namespace llvm::ctx_profile; -extern "C" bool __llvm_ctx_profile_fetch(void *Data, - bool (*Writer)(void *, - const ContextNode &)); +extern "C" bool __llvm_ctx_profile_fetch(ProfileWriter &); // avoid name mangling extern "C" { @@ -46,22 +44,29 @@ __attribute__((noinline)) void theRoot() { // CHECK-NEXT: check even // CHECK-NEXT: check odd -void printProfile(const ContextNode &Node, const std::string &Indent, - const std::string &Increment) { - std::cout << Indent << "Guid: " << Node.guid() << std::endl; - std::cout << Indent << "Entries: " << Node.entrycount() << std::endl; - std::cout << Indent << Node.counters_size() << " counters and " - << Node.callsites_size() << " callsites" << std::endl; - std::cout << Indent << "Counter values: "; - for (uint32_t I = 0U; I < Node.counters_size(); ++I) - std::cout << Node.counters()[I] << " "; - std::cout << std::endl; - for (uint32_t I = 0U; I < Node.callsites_size(); ++I) - for (const auto *N = Node.subContexts()[I]; N; N = N->next()) { - std::cout << Indent << "At Index " << I << ":" << std::endl; - printProfile(*N, Indent + Increment, Increment); - } -} +class TestProfileWriter : public ProfileWriter { + void printProfile(const ContextNode &Node, const std::string &Indent, + const std::string &Increment) { + std::cout << Indent << "Guid: " << Node.guid() << std::endl; + std::cout << Indent << "Entries: " << Node.entrycount() << std::endl; + std::cout << Indent << Node.counters_size() << " counters and " + << Node.callsites_size() << " callsites" << std::endl; + std::cout << Indent << "Counter values: "; + for (uint32_t I = 0U; I < Node.counters_size(); ++I) + std::cout << Node.counters()[I] << " "; + std::cout << std::endl; + for (uint32_t I = 0U; I < Node.callsites_size(); ++I) + for (const auto *N = Node.subContexts()[I]; N; N = N->next()) { + std::cout << Indent << "At Index " << I << ":" << std::endl; + printProfile(*N, Indent + Increment, Increment); + } + } + +public: + void writeContextual(const ContextNode &RootNode) override { + printProfile(RootNode, "", ""); + } +}; // 8657661246551306189 is theRoot. We expect 2 callsites and 2 counters - one // for the entry basic block and one for the loop. @@ -88,11 +93,8 @@ void printProfile(const ContextNode &Node, const std::string &Indent, // CHECK-NEXT: Counter values: 2 1 bool profileWriter() { - return __llvm_ctx_profile_fetch( - nullptr, +[](void *, const ContextNode &Node) { - printProfile(Node, "", " "); - return true; - }); + TestProfileWriter W; + return __llvm_ctx_profile_fetch(W); } int main(int argc, char **argv) { diff --git a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h index 36a996632b71e..625ec8cd4d32a 100644 --- a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h +++ b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h @@ -112,6 +112,13 @@ class ContextNode final { uint64_t entrycount() const { return counters()[0]; } }; + +/// Abstraction for the parameter passed to `__llvm_ctx_profile_fetch`. +class ProfileWriter { +public: + virtual void writeContextual(const ctx_profile::ContextNode &RootNode) = 0; + virtual ~ProfileWriter() = default; +}; } // namespace ctx_profile } // namespace llvm #endif