Skip to content

Commit

Permalink
[ORC] Add a public unsafe-operations helper for SymbolStringPtr.
Browse files Browse the repository at this point in the history
SymbolStringPoolEntryUnsafe provides unsafe access to SymbolStringPtr objects,
allowing clients to manually retain and release pool entries, or consume or
create SymbolStringPtr instances without affecting an entry's ref-count. This
can be useful when writing C APIs that need to handle SymbolStringPtrs.

As part of this patch the LLVM-C API implementation is updated to use the new
utility, rather than the old, private OrcV2CAPIHelper utility.
  • Loading branch information
lhames committed Nov 27, 2023
1 parent 7138fab commit 56c72c7
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 64 deletions.
44 changes: 43 additions & 1 deletion llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class NonOwningSymbolStringPtr;
class SymbolStringPool {
friend class SymbolStringPoolTest;
friend class SymbolStringPtrBase;
friend class SymbolStringPoolEntryUnsafe;

// Implemented in DebugUtils.h.
friend raw_ostream &operator<<(raw_ostream &OS, const SymbolStringPool &SSP);
Expand Down Expand Up @@ -134,8 +135,8 @@ class SymbolStringPtrBase {

/// Pointer to a pooled string representing a symbol name.
class SymbolStringPtr : public SymbolStringPtrBase {
friend class OrcV2CAPIHelper;
friend class SymbolStringPool;
friend class SymbolStringPoolEntryUnsafe;
friend struct DenseMapInfo<SymbolStringPtr>;

public:
Expand Down Expand Up @@ -189,6 +190,47 @@ class SymbolStringPtr : public SymbolStringPtrBase {
}
};

/// Provides unsafe access to ownership operations on SymbolStringPtr.
/// This class can be used to manage SymbolStringPtr instances from C.
class SymbolStringPoolEntryUnsafe {
public:
using PoolEntry = SymbolStringPool::PoolMapEntry;

SymbolStringPoolEntryUnsafe(PoolEntry *E) : E(E) {}

/// Create an unsafe pool entry ref without changing the ref-count.
static SymbolStringPoolEntryUnsafe from(const SymbolStringPtr &S) {
return S.S;
}

/// Consumes the given SymbolStringPtr without releasing the pool entry.
static SymbolStringPoolEntryUnsafe take(SymbolStringPtr &&S) {
PoolEntry *E = nullptr;
std::swap(E, S.S);
return E;
}

PoolEntry *rawPtr() { return E; }

/// Creates a SymbolStringPtr for this entry, with the SymbolStringPtr
/// retaining the entry as usual.
SymbolStringPtr copyToSymbolStringPtr() { return SymbolStringPtr(E); }

/// Creates a SymbolStringPtr for this entry *without* performing a retain
/// operation during construction.
SymbolStringPtr moveToSymbolStringPtr() {
SymbolStringPtr S;
std::swap(S.S, E);
return S;
}

void retain() { ++E->getValue(); }
void release() { --E->getValue(); }

private:
PoolEntry *E = nullptr;
};

/// Non-owning SymbolStringPool entry pointer. Instances are comparable with
/// SymbolStringPtr instances and guaranteed to have the same hash, but do not
/// affect the ref-count of the pooled string (and are therefore cheaper to
Expand Down
94 changes: 31 additions & 63 deletions llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,42 +27,6 @@ class InProgressLookupState;

class OrcV2CAPIHelper {
public:
using PoolEntry = SymbolStringPtr::PoolEntry;
using PoolEntryPtr = SymbolStringPtr::PoolEntryPtr;

// Move from SymbolStringPtr to PoolEntryPtr (no change in ref count).
static PoolEntryPtr moveFromSymbolStringPtr(SymbolStringPtr S) {
PoolEntryPtr Result = nullptr;
std::swap(Result, S.S);
return Result;
}

// Move from a PoolEntryPtr to a SymbolStringPtr (no change in ref count).
static SymbolStringPtr moveToSymbolStringPtr(PoolEntryPtr P) {
SymbolStringPtr S;
S.S = P;
return S;
}

// Copy a pool entry to a SymbolStringPtr (increments ref count).
static SymbolStringPtr copyToSymbolStringPtr(PoolEntryPtr P) {
return SymbolStringPtr(P);
}

static PoolEntryPtr getRawPoolEntryPtr(const SymbolStringPtr &S) {
return S.S;
}

static void retainPoolEntry(PoolEntryPtr P) {
SymbolStringPtr S(P);
S.S = nullptr;
}

static void releasePoolEntry(PoolEntryPtr P) {
SymbolStringPtr S;
S.S = P;
}

static InProgressLookupState *extractLookupState(LookupState &LS) {
return LS.IPLS.release();
}
Expand All @@ -75,10 +39,16 @@ class OrcV2CAPIHelper {
} // namespace orc
} // namespace llvm

inline LLVMOrcSymbolStringPoolEntryRef wrap(SymbolStringPoolEntryUnsafe E) {
return reinterpret_cast<LLVMOrcSymbolStringPoolEntryRef>(E.rawPtr());
}

inline SymbolStringPoolEntryUnsafe unwrap(LLVMOrcSymbolStringPoolEntryRef E) {
return reinterpret_cast<SymbolStringPoolEntryUnsafe::PoolEntry *>(E);
}

DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ExecutionSession, LLVMOrcExecutionSessionRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(SymbolStringPool, LLVMOrcSymbolStringPoolRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(OrcV2CAPIHelper::PoolEntry,
LLVMOrcSymbolStringPoolEntryRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(MaterializationUnit,
LLVMOrcMaterializationUnitRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(MaterializationResponsibility,
Expand Down Expand Up @@ -136,7 +106,7 @@ class OrcCAPIMaterializationUnit : public llvm::orc::MaterializationUnit {

private:
void discard(const JITDylib &JD, const SymbolStringPtr &Name) override {
Discard(Ctx, wrap(&JD), wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)));
Discard(Ctx, wrap(&JD), wrap(SymbolStringPoolEntryUnsafe::from(Name)));
}

std::string Name;
Expand Down Expand Up @@ -184,7 +154,7 @@ static SymbolMap toSymbolMap(LLVMOrcCSymbolMapPairs Syms, size_t NumPairs) {
SymbolMap SM;
for (size_t I = 0; I != NumPairs; ++I) {
JITSymbolFlags Flags = toJITSymbolFlags(Syms[I].Sym.Flags);
SM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] = {
SM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] = {
ExecutorAddr(Syms[I].Sym.Address), Flags};
}
return SM;
Expand All @@ -199,7 +169,7 @@ toSymbolDependenceMap(LLVMOrcCDependenceMapPairs Pairs, size_t NumPairs) {

for (size_t J = 0; J != Pairs[I].Names.Length; ++J) {
auto Sym = Pairs[I].Names.Symbols[J];
Names.insert(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Sym)));
Names.insert(unwrap(Sym).moveToSymbolStringPtr());
}
SDM[JD] = Names;
}
Expand Down Expand Up @@ -309,7 +279,7 @@ class CAPIDefinitionGenerator final : public DefinitionGenerator {
CLookupSet.reserve(LookupSet.size());
for (auto &KV : LookupSet) {
LLVMOrcSymbolStringPoolEntryRef Name =
::wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(KV.first));
::wrap(SymbolStringPoolEntryUnsafe::from(KV.first));
LLVMOrcSymbolLookupFlags SLF = fromSymbolLookupFlags(KV.second);
CLookupSet.push_back({Name, SLF});
}
Expand Down Expand Up @@ -353,8 +323,7 @@ void LLVMOrcSymbolStringPoolClearDeadEntries(LLVMOrcSymbolStringPoolRef SSP) {

LLVMOrcSymbolStringPoolEntryRef
LLVMOrcExecutionSessionIntern(LLVMOrcExecutionSessionRef ES, const char *Name) {
return wrap(
OrcV2CAPIHelper::moveFromSymbolStringPtr(unwrap(ES)->intern(Name)));
return wrap(SymbolStringPoolEntryUnsafe::take(unwrap(ES)->intern(Name)));
}

void LLVMOrcExecutionSessionLookup(
Expand All @@ -374,7 +343,7 @@ void LLVMOrcExecutionSessionLookup(

SymbolLookupSet SLS;
for (size_t I = 0; I != SymbolsSize; ++I)
SLS.add(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Symbols[I].Name)),
SLS.add(unwrap(Symbols[I].Name).moveToSymbolStringPtr(),
toSymbolLookupFlags(Symbols[I].LookupFlags));

unwrap(ES)->lookup(
Expand All @@ -384,7 +353,7 @@ void LLVMOrcExecutionSessionLookup(
SmallVector<LLVMOrcCSymbolMapPair> CResult;
for (auto &KV : *Result)
CResult.push_back(LLVMOrcCSymbolMapPair{
wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(KV.first)),
wrap(SymbolStringPoolEntryUnsafe::from(KV.first)),
fromExecutorSymbolDef(KV.second)});
HandleResult(LLVMErrorSuccess, CResult.data(), CResult.size(), Ctx);
} else
Expand All @@ -394,15 +363,15 @@ void LLVMOrcExecutionSessionLookup(
}

void LLVMOrcRetainSymbolStringPoolEntry(LLVMOrcSymbolStringPoolEntryRef S) {
OrcV2CAPIHelper::retainPoolEntry(unwrap(S));
unwrap(S).retain();
}

void LLVMOrcReleaseSymbolStringPoolEntry(LLVMOrcSymbolStringPoolEntryRef S) {
OrcV2CAPIHelper::releasePoolEntry(unwrap(S));
unwrap(S).release();
}

const char *LLVMOrcSymbolStringPoolEntryStr(LLVMOrcSymbolStringPoolEntryRef S) {
return unwrap(S)->getKey().data();
return unwrap(S).rawPtr()->getKey().data();
}

LLVMOrcResourceTrackerRef
Expand Down Expand Up @@ -452,10 +421,10 @@ LLVMOrcMaterializationUnitRef LLVMOrcCreateCustomMaterializationUnit(
LLVMOrcMaterializationUnitDestroyFunction Destroy) {
SymbolFlagsMap SFM;
for (size_t I = 0; I != NumSyms; ++I)
SFM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] =
SFM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] =
toJITSymbolFlags(Syms[I].Flags);

auto IS = OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(InitSym));
auto IS = unwrap(InitSym).moveToSymbolStringPtr();

return wrap(new OrcCAPIMaterializationUnit(
Name, std::move(SFM), std::move(IS), Ctx, Materialize, Discard, Destroy));
Expand All @@ -476,9 +445,8 @@ LLVMOrcMaterializationUnitRef LLVMOrcLazyReexports(
for (size_t I = 0; I != NumPairs; ++I) {
auto pair = CallableAliases[I];
JITSymbolFlags Flags = toJITSymbolFlags(pair.Entry.Flags);
SymbolStringPtr Name =
OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(pair.Entry.Name));
SAM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(pair.Name))] =
SymbolStringPtr Name = unwrap(pair.Entry.Name).moveToSymbolStringPtr();
SAM[unwrap(pair.Name).moveToSymbolStringPtr()] =
SymbolAliasMapEntry(Name, Flags);
}

Expand Down Expand Up @@ -511,7 +479,7 @@ LLVMOrcCSymbolFlagsMapPairs LLVMOrcMaterializationResponsibilityGetSymbols(
safe_malloc(Symbols.size() * sizeof(LLVMOrcCSymbolFlagsMapPair)));
size_t I = 0;
for (auto const &pair : Symbols) {
auto Name = wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(pair.first));
auto Name = wrap(SymbolStringPoolEntryUnsafe::from(pair.first));
auto Flags = pair.second;
Result[I] = {Name, fromJITSymbolFlags(Flags)};
I++;
Expand All @@ -528,7 +496,7 @@ LLVMOrcSymbolStringPoolEntryRef
LLVMOrcMaterializationResponsibilityGetInitializerSymbol(
LLVMOrcMaterializationResponsibilityRef MR) {
auto Sym = unwrap(MR)->getInitializerSymbol();
return wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Sym));
return wrap(SymbolStringPoolEntryUnsafe::from(Sym));
}

LLVMOrcSymbolStringPoolEntryRef *
Expand All @@ -541,7 +509,7 @@ LLVMOrcMaterializationResponsibilityGetRequestedSymbols(
Symbols.size() * sizeof(LLVMOrcSymbolStringPoolEntryRef)));
size_t I = 0;
for (auto &Name : Symbols) {
Result[I] = wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name));
Result[I] = wrap(SymbolStringPoolEntryUnsafe::from(Name));
I++;
}
*NumSymbols = Symbols.size();
Expand Down Expand Up @@ -569,7 +537,7 @@ LLVMErrorRef LLVMOrcMaterializationResponsibilityDefineMaterializing(
LLVMOrcCSymbolFlagsMapPairs Syms, size_t NumSyms) {
SymbolFlagsMap SFM;
for (size_t I = 0; I != NumSyms; ++I)
SFM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] =
SFM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] =
toJITSymbolFlags(Syms[I].Flags);

return wrap(unwrap(MR)->defineMaterializing(std::move(SFM)));
Expand All @@ -588,7 +556,7 @@ LLVMErrorRef LLVMOrcMaterializationResponsibilityDelegate(
LLVMOrcMaterializationResponsibilityRef *Result) {
SymbolNameSet Syms;
for (size_t I = 0; I != NumSymbols; I++) {
Syms.insert(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Symbols[I])));
Syms.insert(unwrap(Symbols[I]).moveToSymbolStringPtr());
}
auto OtherMR = unwrap(MR)->delegate(Syms);

Expand All @@ -605,7 +573,7 @@ void LLVMOrcMaterializationResponsibilityAddDependencies(
LLVMOrcCDependenceMapPairs Dependencies, size_t NumPairs) {

SymbolDependenceMap SDM = toSymbolDependenceMap(Dependencies, NumPairs);
auto Sym = OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Name));
auto Sym = unwrap(Name).moveToSymbolStringPtr();
unwrap(MR)->addDependencies(Sym, SDM);
}

Expand Down Expand Up @@ -698,7 +666,7 @@ LLVMErrorRef LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess(
DynamicLibrarySearchGenerator::SymbolPredicate Pred;
if (Filter)
Pred = [=](const SymbolStringPtr &Name) -> bool {
return Filter(FilterCtx, wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)));
return Filter(FilterCtx, wrap(SymbolStringPoolEntryUnsafe::from(Name)));
};

auto ProcessSymsGenerator =
Expand All @@ -724,7 +692,7 @@ LLVMErrorRef LLVMOrcCreateDynamicLibrarySearchGeneratorForPath(
DynamicLibrarySearchGenerator::SymbolPredicate Pred;
if (Filter)
Pred = [=](const SymbolStringPtr &Name) -> bool {
return Filter(FilterCtx, wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)));
return Filter(FilterCtx, wrap(SymbolStringPoolEntryUnsafe::from(Name)));
};

auto LibrarySymsGenerator =
Expand Down Expand Up @@ -992,7 +960,7 @@ char LLVMOrcLLJITGetGlobalPrefix(LLVMOrcLLJITRef J) {

LLVMOrcSymbolStringPoolEntryRef
LLVMOrcLLJITMangleAndIntern(LLVMOrcLLJITRef J, const char *UnmangledName) {
return wrap(OrcV2CAPIHelper::moveFromSymbolStringPtr(
return wrap(SymbolStringPoolEntryUnsafe::take(
unwrap(J)->mangleAndIntern(UnmangledName)));
}

Expand Down
38 changes: 38 additions & 0 deletions llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,42 @@ TEST_F(SymbolStringPoolTest, NonOwningPointerRefCounts) {
<< "Copy-assignment of NonOwningSymbolStringPtr changed ref-count";
}
}

TEST_F(SymbolStringPoolTest, SymbolStringPoolEntryUnsafe) {

auto A = SP.intern("a");
EXPECT_EQ(getRefCount(A), 1U);

{
// Try creating an unsafe pool entry ref from the given SymbolStringPtr.
// This should not affect the ref-count.
auto AUnsafe = SymbolStringPoolEntryUnsafe::from(A);
EXPECT_EQ(getRefCount(A), 1U);

// Create a new SymbolStringPtr from the unsafe ref. This should increment
// the ref-count.
auto ACopy = AUnsafe.copyToSymbolStringPtr();
EXPECT_EQ(getRefCount(A), 2U);
}

{
// Create a copy of the original string. Move it into an unsafe ref, and
// then move it back. None of these operations should affect the ref-count.
auto ACopy = A;
EXPECT_EQ(getRefCount(A), 2U);
auto AUnsafe = SymbolStringPoolEntryUnsafe::take(std::move(ACopy));
EXPECT_EQ(getRefCount(A), 2U);
ACopy = AUnsafe.moveToSymbolStringPtr();
EXPECT_EQ(getRefCount(A), 2U);
}

// Test manual retain / release.
auto AUnsafe = SymbolStringPoolEntryUnsafe::from(A);
EXPECT_EQ(getRefCount(A), 1U);
AUnsafe.retain();
EXPECT_EQ(getRefCount(A), 2U);
AUnsafe.release();
EXPECT_EQ(getRefCount(A), 1U);
}

} // namespace

3 comments on commit 56c72c7

@kstoimenov
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this broke https://lab.llvm.org/buildbot/#/builders/168/builds/17115/steps/10/logs/stdio. Could you please revert or fix?

@lhames
Copy link
Contributor Author

@lhames lhames commented on 56c72c7 Nov 28, 2023 via email

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lhames
Copy link
Contributor Author

@lhames lhames commented on 56c72c7 Nov 28, 2023 via email

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.