diff --git a/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h b/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h index 497e29da98bd5..f47956a65f2e7 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h @@ -32,6 +32,7 @@ class NonOwningSymbolStringPtr; class SymbolStringPool { friend class SymbolStringPoolTest; friend class SymbolStringPtrBase; + friend class SymbolStringPoolEntryUnsafe; // Implemented in DebugUtils.h. friend raw_ostream &operator<<(raw_ostream &OS, const SymbolStringPool &SSP); @@ -134,8 +135,8 @@ class SymbolStringPtrBase { /// Pointer to a pooled string representing a symbol name. class SymbolStringPtr : public SymbolStringPtrBase { - friend class OrcV2CAPIHelper; friend class SymbolStringPool; + friend class SymbolStringPoolEntryUnsafe; friend struct DenseMapInfo; public: @@ -189,6 +190,47 @@ class SymbolStringPtr : public SymbolStringPtrBase { } }; +/// Provides unsafe access to ownership operations on SymbolStringPtr. +/// This class can be used to manage SymbolStringPtr instances from C. +class SymbolStringPoolEntryUnsafe { +public: + using PoolEntry = SymbolStringPool::PoolMapEntry; + + SymbolStringPoolEntryUnsafe(PoolEntry *E) : E(E) {} + + /// Create an unsafe pool entry ref without changing the ref-count. + static SymbolStringPoolEntryUnsafe from(const SymbolStringPtr &S) { + return S.S; + } + + /// Consumes the given SymbolStringPtr without releasing the pool entry. + static SymbolStringPoolEntryUnsafe take(SymbolStringPtr &&S) { + PoolEntry *E = nullptr; + std::swap(E, S.S); + return E; + } + + PoolEntry *rawPtr() { return E; } + + /// Creates a SymbolStringPtr for this entry, with the SymbolStringPtr + /// retaining the entry as usual. + SymbolStringPtr copyToSymbolStringPtr() { return SymbolStringPtr(E); } + + /// Creates a SymbolStringPtr for this entry *without* performing a retain + /// operation during construction. + SymbolStringPtr moveToSymbolStringPtr() { + SymbolStringPtr S; + std::swap(S.S, E); + return S; + } + + void retain() { ++E->getValue(); } + void release() { --E->getValue(); } + +private: + PoolEntry *E = nullptr; +}; + /// Non-owning SymbolStringPool entry pointer. Instances are comparable with /// SymbolStringPtr instances and guaranteed to have the same hash, but do not /// affect the ref-count of the pooled string (and are therefore cheaper to diff --git a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp index a73aec6d98c64..72314cceedf33 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp @@ -27,42 +27,6 @@ class InProgressLookupState; class OrcV2CAPIHelper { public: - using PoolEntry = SymbolStringPtr::PoolEntry; - using PoolEntryPtr = SymbolStringPtr::PoolEntryPtr; - - // Move from SymbolStringPtr to PoolEntryPtr (no change in ref count). - static PoolEntryPtr moveFromSymbolStringPtr(SymbolStringPtr S) { - PoolEntryPtr Result = nullptr; - std::swap(Result, S.S); - return Result; - } - - // Move from a PoolEntryPtr to a SymbolStringPtr (no change in ref count). - static SymbolStringPtr moveToSymbolStringPtr(PoolEntryPtr P) { - SymbolStringPtr S; - S.S = P; - return S; - } - - // Copy a pool entry to a SymbolStringPtr (increments ref count). - static SymbolStringPtr copyToSymbolStringPtr(PoolEntryPtr P) { - return SymbolStringPtr(P); - } - - static PoolEntryPtr getRawPoolEntryPtr(const SymbolStringPtr &S) { - return S.S; - } - - static void retainPoolEntry(PoolEntryPtr P) { - SymbolStringPtr S(P); - S.S = nullptr; - } - - static void releasePoolEntry(PoolEntryPtr P) { - SymbolStringPtr S; - S.S = P; - } - static InProgressLookupState *extractLookupState(LookupState &LS) { return LS.IPLS.release(); } @@ -75,10 +39,16 @@ class OrcV2CAPIHelper { } // namespace orc } // namespace llvm +inline LLVMOrcSymbolStringPoolEntryRef wrap(SymbolStringPoolEntryUnsafe E) { + return reinterpret_cast(E.rawPtr()); +} + +inline SymbolStringPoolEntryUnsafe unwrap(LLVMOrcSymbolStringPoolEntryRef E) { + return reinterpret_cast(E); +} + DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ExecutionSession, LLVMOrcExecutionSessionRef) DEFINE_SIMPLE_CONVERSION_FUNCTIONS(SymbolStringPool, LLVMOrcSymbolStringPoolRef) -DEFINE_SIMPLE_CONVERSION_FUNCTIONS(OrcV2CAPIHelper::PoolEntry, - LLVMOrcSymbolStringPoolEntryRef) DEFINE_SIMPLE_CONVERSION_FUNCTIONS(MaterializationUnit, LLVMOrcMaterializationUnitRef) DEFINE_SIMPLE_CONVERSION_FUNCTIONS(MaterializationResponsibility, @@ -136,7 +106,7 @@ class OrcCAPIMaterializationUnit : public llvm::orc::MaterializationUnit { private: void discard(const JITDylib &JD, const SymbolStringPtr &Name) override { - Discard(Ctx, wrap(&JD), wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name))); + Discard(Ctx, wrap(&JD), wrap(SymbolStringPoolEntryUnsafe::from(Name))); } std::string Name; @@ -184,7 +154,7 @@ static SymbolMap toSymbolMap(LLVMOrcCSymbolMapPairs Syms, size_t NumPairs) { SymbolMap SM; for (size_t I = 0; I != NumPairs; ++I) { JITSymbolFlags Flags = toJITSymbolFlags(Syms[I].Sym.Flags); - SM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] = { + SM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] = { ExecutorAddr(Syms[I].Sym.Address), Flags}; } return SM; @@ -199,7 +169,7 @@ toSymbolDependenceMap(LLVMOrcCDependenceMapPairs Pairs, size_t NumPairs) { for (size_t J = 0; J != Pairs[I].Names.Length; ++J) { auto Sym = Pairs[I].Names.Symbols[J]; - Names.insert(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Sym))); + Names.insert(unwrap(Sym).moveToSymbolStringPtr()); } SDM[JD] = Names; } @@ -309,7 +279,7 @@ class CAPIDefinitionGenerator final : public DefinitionGenerator { CLookupSet.reserve(LookupSet.size()); for (auto &KV : LookupSet) { LLVMOrcSymbolStringPoolEntryRef Name = - ::wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(KV.first)); + ::wrap(SymbolStringPoolEntryUnsafe::from(KV.first)); LLVMOrcSymbolLookupFlags SLF = fromSymbolLookupFlags(KV.second); CLookupSet.push_back({Name, SLF}); } @@ -353,8 +323,7 @@ void LLVMOrcSymbolStringPoolClearDeadEntries(LLVMOrcSymbolStringPoolRef SSP) { LLVMOrcSymbolStringPoolEntryRef LLVMOrcExecutionSessionIntern(LLVMOrcExecutionSessionRef ES, const char *Name) { - return wrap( - OrcV2CAPIHelper::moveFromSymbolStringPtr(unwrap(ES)->intern(Name))); + return wrap(SymbolStringPoolEntryUnsafe::take(unwrap(ES)->intern(Name))); } void LLVMOrcExecutionSessionLookup( @@ -374,7 +343,7 @@ void LLVMOrcExecutionSessionLookup( SymbolLookupSet SLS; for (size_t I = 0; I != SymbolsSize; ++I) - SLS.add(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Symbols[I].Name)), + SLS.add(unwrap(Symbols[I].Name).moveToSymbolStringPtr(), toSymbolLookupFlags(Symbols[I].LookupFlags)); unwrap(ES)->lookup( @@ -384,7 +353,7 @@ void LLVMOrcExecutionSessionLookup( SmallVector CResult; for (auto &KV : *Result) CResult.push_back(LLVMOrcCSymbolMapPair{ - wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(KV.first)), + wrap(SymbolStringPoolEntryUnsafe::from(KV.first)), fromExecutorSymbolDef(KV.second)}); HandleResult(LLVMErrorSuccess, CResult.data(), CResult.size(), Ctx); } else @@ -394,15 +363,15 @@ void LLVMOrcExecutionSessionLookup( } void LLVMOrcRetainSymbolStringPoolEntry(LLVMOrcSymbolStringPoolEntryRef S) { - OrcV2CAPIHelper::retainPoolEntry(unwrap(S)); + unwrap(S).retain(); } void LLVMOrcReleaseSymbolStringPoolEntry(LLVMOrcSymbolStringPoolEntryRef S) { - OrcV2CAPIHelper::releasePoolEntry(unwrap(S)); + unwrap(S).release(); } const char *LLVMOrcSymbolStringPoolEntryStr(LLVMOrcSymbolStringPoolEntryRef S) { - return unwrap(S)->getKey().data(); + return unwrap(S).rawPtr()->getKey().data(); } LLVMOrcResourceTrackerRef @@ -452,10 +421,10 @@ LLVMOrcMaterializationUnitRef LLVMOrcCreateCustomMaterializationUnit( LLVMOrcMaterializationUnitDestroyFunction Destroy) { SymbolFlagsMap SFM; for (size_t I = 0; I != NumSyms; ++I) - SFM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] = + SFM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] = toJITSymbolFlags(Syms[I].Flags); - auto IS = OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(InitSym)); + auto IS = unwrap(InitSym).moveToSymbolStringPtr(); return wrap(new OrcCAPIMaterializationUnit( Name, std::move(SFM), std::move(IS), Ctx, Materialize, Discard, Destroy)); @@ -476,9 +445,8 @@ LLVMOrcMaterializationUnitRef LLVMOrcLazyReexports( for (size_t I = 0; I != NumPairs; ++I) { auto pair = CallableAliases[I]; JITSymbolFlags Flags = toJITSymbolFlags(pair.Entry.Flags); - SymbolStringPtr Name = - OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(pair.Entry.Name)); - SAM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(pair.Name))] = + SymbolStringPtr Name = unwrap(pair.Entry.Name).moveToSymbolStringPtr(); + SAM[unwrap(pair.Name).moveToSymbolStringPtr()] = SymbolAliasMapEntry(Name, Flags); } @@ -511,7 +479,7 @@ LLVMOrcCSymbolFlagsMapPairs LLVMOrcMaterializationResponsibilityGetSymbols( safe_malloc(Symbols.size() * sizeof(LLVMOrcCSymbolFlagsMapPair))); size_t I = 0; for (auto const &pair : Symbols) { - auto Name = wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(pair.first)); + auto Name = wrap(SymbolStringPoolEntryUnsafe::from(pair.first)); auto Flags = pair.second; Result[I] = {Name, fromJITSymbolFlags(Flags)}; I++; @@ -528,7 +496,7 @@ LLVMOrcSymbolStringPoolEntryRef LLVMOrcMaterializationResponsibilityGetInitializerSymbol( LLVMOrcMaterializationResponsibilityRef MR) { auto Sym = unwrap(MR)->getInitializerSymbol(); - return wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Sym)); + return wrap(SymbolStringPoolEntryUnsafe::from(Sym)); } LLVMOrcSymbolStringPoolEntryRef * @@ -541,7 +509,7 @@ LLVMOrcMaterializationResponsibilityGetRequestedSymbols( Symbols.size() * sizeof(LLVMOrcSymbolStringPoolEntryRef))); size_t I = 0; for (auto &Name : Symbols) { - Result[I] = wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)); + Result[I] = wrap(SymbolStringPoolEntryUnsafe::from(Name)); I++; } *NumSymbols = Symbols.size(); @@ -569,7 +537,7 @@ LLVMErrorRef LLVMOrcMaterializationResponsibilityDefineMaterializing( LLVMOrcCSymbolFlagsMapPairs Syms, size_t NumSyms) { SymbolFlagsMap SFM; for (size_t I = 0; I != NumSyms; ++I) - SFM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] = + SFM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] = toJITSymbolFlags(Syms[I].Flags); return wrap(unwrap(MR)->defineMaterializing(std::move(SFM))); @@ -588,7 +556,7 @@ LLVMErrorRef LLVMOrcMaterializationResponsibilityDelegate( LLVMOrcMaterializationResponsibilityRef *Result) { SymbolNameSet Syms; for (size_t I = 0; I != NumSymbols; I++) { - Syms.insert(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Symbols[I]))); + Syms.insert(unwrap(Symbols[I]).moveToSymbolStringPtr()); } auto OtherMR = unwrap(MR)->delegate(Syms); @@ -605,7 +573,7 @@ void LLVMOrcMaterializationResponsibilityAddDependencies( LLVMOrcCDependenceMapPairs Dependencies, size_t NumPairs) { SymbolDependenceMap SDM = toSymbolDependenceMap(Dependencies, NumPairs); - auto Sym = OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Name)); + auto Sym = unwrap(Name).moveToSymbolStringPtr(); unwrap(MR)->addDependencies(Sym, SDM); } @@ -698,7 +666,7 @@ LLVMErrorRef LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess( DynamicLibrarySearchGenerator::SymbolPredicate Pred; if (Filter) Pred = [=](const SymbolStringPtr &Name) -> bool { - return Filter(FilterCtx, wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name))); + return Filter(FilterCtx, wrap(SymbolStringPoolEntryUnsafe::from(Name))); }; auto ProcessSymsGenerator = @@ -724,7 +692,7 @@ LLVMErrorRef LLVMOrcCreateDynamicLibrarySearchGeneratorForPath( DynamicLibrarySearchGenerator::SymbolPredicate Pred; if (Filter) Pred = [=](const SymbolStringPtr &Name) -> bool { - return Filter(FilterCtx, wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name))); + return Filter(FilterCtx, wrap(SymbolStringPoolEntryUnsafe::from(Name))); }; auto LibrarySymsGenerator = @@ -992,7 +960,7 @@ char LLVMOrcLLJITGetGlobalPrefix(LLVMOrcLLJITRef J) { LLVMOrcSymbolStringPoolEntryRef LLVMOrcLLJITMangleAndIntern(LLVMOrcLLJITRef J, const char *UnmangledName) { - return wrap(OrcV2CAPIHelper::moveFromSymbolStringPtr( + return wrap(SymbolStringPoolEntryUnsafe::take( unwrap(J)->mangleAndIntern(UnmangledName))); } diff --git a/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp b/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp index fc864ab1131b2..cd1cecd3244d6 100644 --- a/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp @@ -142,4 +142,42 @@ TEST_F(SymbolStringPoolTest, NonOwningPointerRefCounts) { << "Copy-assignment of NonOwningSymbolStringPtr changed ref-count"; } } + +TEST_F(SymbolStringPoolTest, SymbolStringPoolEntryUnsafe) { + + auto A = SP.intern("a"); + EXPECT_EQ(getRefCount(A), 1U); + + { + // Try creating an unsafe pool entry ref from the given SymbolStringPtr. + // This should not affect the ref-count. + auto AUnsafe = SymbolStringPoolEntryUnsafe::from(A); + EXPECT_EQ(getRefCount(A), 1U); + + // Create a new SymbolStringPtr from the unsafe ref. This should increment + // the ref-count. + auto ACopy = AUnsafe.copyToSymbolStringPtr(); + EXPECT_EQ(getRefCount(A), 2U); + } + + { + // Create a copy of the original string. Move it into an unsafe ref, and + // then move it back. None of these operations should affect the ref-count. + auto ACopy = A; + EXPECT_EQ(getRefCount(A), 2U); + auto AUnsafe = SymbolStringPoolEntryUnsafe::take(std::move(ACopy)); + EXPECT_EQ(getRefCount(A), 2U); + ACopy = AUnsafe.moveToSymbolStringPtr(); + EXPECT_EQ(getRefCount(A), 2U); + } + + // Test manual retain / release. + auto AUnsafe = SymbolStringPoolEntryUnsafe::from(A); + EXPECT_EQ(getRefCount(A), 1U); + AUnsafe.retain(); + EXPECT_EQ(getRefCount(A), 2U); + AUnsafe.release(); + EXPECT_EQ(getRefCount(A), 1U); +} + } // namespace