232 changes: 232 additions & 0 deletions llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,218 @@ class OrcRemoteTargetClient
std::vector<EHFrame> RegisteredEHFrames;
};

class RPCMMAlloc : public jitlink::JITLinkMemoryManager::Allocation {
using AllocationMap = DenseMap<unsigned, sys::MemoryBlock>;
using FinalizeContinuation =
jitlink::JITLinkMemoryManager::Allocation::FinalizeContinuation;
using ProtectionFlags = sys::Memory::ProtectionFlags;
using SegmentsRequestMap =
DenseMap<unsigned, jitlink::JITLinkMemoryManager::SegmentRequest>;

RPCMMAlloc(OrcRemoteTargetClient &Client, ResourceIdMgr::ResourceId Id)
: Client(Client), Id(Id) {}

public:
static Expected<std::unique_ptr<RPCMMAlloc>>
Create(OrcRemoteTargetClient &Client, ResourceIdMgr::ResourceId Id,
const SegmentsRequestMap &Request) {
auto *MM = new RPCMMAlloc(Client, Id);

if (Error Err = MM->allocateHostBlocks(Request))
return std::move(Err);

if (Error Err = MM->allocateTargetBlocks())
return std::move(Err);

return std::unique_ptr<RPCMMAlloc>(MM);
}

MutableArrayRef<char> getWorkingMemory(ProtectionFlags Seg) override {
assert(HostSegBlocks.count(Seg) && "No allocation for segment");
return {static_cast<char *>(HostSegBlocks[Seg].base()),
HostSegBlocks[Seg].allocatedSize()};
}

JITTargetAddress getTargetMemory(ProtectionFlags Seg) override {
assert(TargetSegBlocks.count(Seg) && "No allocation for segment");
return pointerToJITTargetAddress(TargetSegBlocks[Seg].base());
}

void finalizeAsync(FinalizeContinuation OnFinalize) override {
// Host allocations (working memory) remain ReadWrite.
OnFinalize(copyAndProtect());
}

Error deallocate() override {
// TODO: Cannot release target allocation. RPCAPI has no function
// symmetric to reserveMem(). Add RPC call like freeMem()?
return errorCodeToError(sys::Memory::releaseMappedMemory(HostAllocation));
}

private:
OrcRemoteTargetClient &Client;
ResourceIdMgr::ResourceId Id;
AllocationMap HostSegBlocks;
AllocationMap TargetSegBlocks;
JITTargetAddress TargetSegmentAddr;
sys::MemoryBlock HostAllocation;

Error allocateHostBlocks(const SegmentsRequestMap &Request) {
unsigned TargetPageSize = Client.getPageSize();

if (!isPowerOf2_64(static_cast<uint64_t>(TargetPageSize)))
return make_error<StringError>("Host page size is not a power of 2",
inconvertibleErrorCode());

auto TotalSize = calcTotalAllocSize(Request, TargetPageSize);
if (!TotalSize)
return TotalSize.takeError();

// Allocate one slab to cover all the segments.
const sys::Memory::ProtectionFlags ReadWrite =
static_cast<sys::Memory::ProtectionFlags>(sys::Memory::MF_READ |
sys::Memory::MF_WRITE);
std::error_code EC;
HostAllocation =
sys::Memory::allocateMappedMemory(*TotalSize, nullptr, ReadWrite, EC);
if (EC)
return errorCodeToError(EC);

char *SlabAddr = static_cast<char *>(HostAllocation.base());
char *SlabAddrEnd = SlabAddr + HostAllocation.allocatedSize();

// Allocate segment memory from the slab.
for (auto &KV : Request) {
const auto &Seg = KV.second;

uint64_t SegmentSize = Seg.getContentSize() + Seg.getZeroFillSize();
uint64_t AlignedSegmentSize = alignTo(SegmentSize, TargetPageSize);

// Zero out zero-fill memory.
char *ZeroFillBegin = SlabAddr + Seg.getContentSize();
memset(ZeroFillBegin, 0, Seg.getZeroFillSize());

// Record the block for this segment.
HostSegBlocks[KV.first] =
sys::MemoryBlock(SlabAddr, AlignedSegmentSize);

SlabAddr += AlignedSegmentSize;
assert(SlabAddr <= SlabAddrEnd && "Out of range");
}

return Error::success();
}

Error allocateTargetBlocks() {
// Reserve memory for all blocks on the target. We need as much space on
// the target as we allocated on the host.
TargetSegmentAddr = Client.reserveMem(Id, HostAllocation.allocatedSize(),
Client.getPageSize());
if (!TargetSegmentAddr)
return make_error<StringError>("Failed to reserve memory on the target",
inconvertibleErrorCode());

// Map memory blocks into the allocation, that match the host allocation.
JITTargetAddress TargetAllocAddr = TargetSegmentAddr;
for (const auto &KV : HostSegBlocks) {
size_t TargetAllocSize = KV.second.allocatedSize();

TargetSegBlocks[KV.first] =
sys::MemoryBlock(jitTargetAddressToPointer<void *>(TargetAllocAddr),
TargetAllocSize);

TargetAllocAddr += TargetAllocSize;
assert(TargetAllocAddr - TargetSegmentAddr <=
HostAllocation.allocatedSize() &&
"Out of range on target");
}

return Error::success();
}

Error copyAndProtect() {
unsigned Permissions = 0u;

// Copy segments one by one.
for (auto &KV : TargetSegBlocks) {
Permissions |= KV.first;

const sys::MemoryBlock &TargetBlock = KV.second;
const sys::MemoryBlock &HostBlock = HostSegBlocks.lookup(KV.first);

size_t TargetAllocSize = TargetBlock.allocatedSize();
auto TargetAllocAddr = pointerToJITTargetAddress(TargetBlock.base());
auto *HostAllocBegin = static_cast<const char *>(HostBlock.base());

bool CopyErr =
Client.writeMem(TargetAllocAddr, HostAllocBegin, TargetAllocSize);
if (CopyErr)
return createStringError(inconvertibleErrorCode(),
"Failed to copy %d segment to the target",
KV.first);
}

// Set permission flags for all segments at once.
bool ProtectErr =
Client.setProtections(Id, TargetSegmentAddr, Permissions);
if (ProtectErr)
return createStringError(inconvertibleErrorCode(),
"Failed to apply permissions for %d segment "
"on the target",
Permissions);
return Error::success();
}

static Expected<size_t>
calcTotalAllocSize(const SegmentsRequestMap &Request,
unsigned TargetPageSize) {
size_t TotalSize = 0;
for (const auto &KV : Request) {
const auto &Seg = KV.second;

if (Seg.getAlignment() > TargetPageSize)
return make_error<StringError>("Cannot request alignment higher than "
"page alignment on target",
inconvertibleErrorCode());

TotalSize = alignTo(TotalSize, TargetPageSize);
TotalSize += Seg.getContentSize();
TotalSize += Seg.getZeroFillSize();
}

return TotalSize;
}
};

class RemoteJITLinkMemoryManager : public jitlink::JITLinkMemoryManager {
public:
RemoteJITLinkMemoryManager(OrcRemoteTargetClient &Client,
ResourceIdMgr::ResourceId Id)
: Client(Client), Id(Id) {}

RemoteJITLinkMemoryManager(const RemoteJITLinkMemoryManager &) = delete;
RemoteJITLinkMemoryManager(RemoteJITLinkMemoryManager &&) = default;

RemoteJITLinkMemoryManager &
operator=(const RemoteJITLinkMemoryManager &) = delete;
RemoteJITLinkMemoryManager &
operator=(RemoteJITLinkMemoryManager &&) = delete;

~RemoteJITLinkMemoryManager() {
Client.destroyRemoteAllocator(Id);
LLVM_DEBUG(dbgs() << "Destroyed remote allocator " << Id << "\n");
}

Expected<std::unique_ptr<Allocation>>
allocate(const SegmentsRequestMap &Request) override {
return RPCMMAlloc::Create(Client, Id, Request);
}

private:
OrcRemoteTargetClient &Client;
ResourceIdMgr::ResourceId Id;
};

/// Remote indirect stubs manager.
class RemoteIndirectStubsManager : public IndirectStubsManager {
public:
Expand Down Expand Up @@ -504,6 +716,14 @@ class OrcRemoteTargetClient
return callB<exec::CallIntVoid>(Addr);
}

/// Call the int(int) function at the given address in the target and return
/// its result.
Expected<int> callIntInt(JITTargetAddress Addr, int Arg) {
LLVM_DEBUG(dbgs() << "Calling int(*)(int) " << format("0x%016" PRIx64, Addr)
<< "\n");
return callB<exec::CallIntInt>(Addr, Arg);
}

/// Call the int(int, char*[]) function at the given address in the target and
/// return its result.
Expected<int> callMain(JITTargetAddress Addr,
Expand Down Expand Up @@ -532,6 +752,18 @@ class OrcRemoteTargetClient
new RemoteRTDyldMemoryManager(*this, Id));
}

/// Create a JITLink-compatible memory manager which will allocate working
/// memory on the host and target memory on the remote target.
Expected<std::unique_ptr<RemoteJITLinkMemoryManager>>
createRemoteJITLinkMemoryManager() {
auto Id = AllocatorIds.getNext();
if (auto Err = callB<mem::CreateRemoteAllocator>(Id))
return std::move(Err);
LLVM_DEBUG(dbgs() << "Created remote allocator " << Id << "\n");
return std::unique_ptr<RemoteJITLinkMemoryManager>(
new RemoteJITLinkMemoryManager(*this, Id));
}

/// Create an RCIndirectStubsManager that will allocate stubs on the remote
/// target.
Expected<std::unique_ptr<RemoteIndirectStubsManager>>
Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,14 @@ namespace exec {
static const char *getName() { return "CallIntVoid"; }
};

/// Call an 'int32_t(int32_t)'-type function on the remote, returns the called
/// function's return value.
class CallIntInt
: public rpc::Function<CallIntInt, int32_t(JITTargetAddress Addr, int)> {
public:
static const char *getName() { return "CallIntInt"; }
};

/// Call an 'int32_t(int32_t, char**)'-type function on the remote, returns the
/// called function's return value.
class CallMain
Expand Down
14 changes: 14 additions & 0 deletions llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class OrcRemoteTargetServer
EHFramesDeregister(std::move(EHFramesDeregister)) {
using ThisT = std::remove_reference_t<decltype(*this)>;
addHandler<exec::CallIntVoid>(*this, &ThisT::handleCallIntVoid);
addHandler<exec::CallIntInt>(*this, &ThisT::handleCallIntInt);
addHandler<exec::CallMain>(*this, &ThisT::handleCallMain);
addHandler<exec::CallVoidVoid>(*this, &ThisT::handleCallVoidVoid);
addHandler<mem::CreateRemoteAllocator>(*this,
Expand Down Expand Up @@ -168,6 +169,19 @@ class OrcRemoteTargetServer
return Result;
}

Expected<int32_t> handleCallIntInt(JITTargetAddress Addr, int Arg) {
using IntIntFnTy = int (*)(int);

IntIntFnTy Fn = reinterpret_cast<IntIntFnTy>(static_cast<uintptr_t>(Addr));

LLVM_DEBUG(dbgs() << " Calling " << format("0x%016x", Addr)
<< " with argument " << Arg << "\n");
int Result = Fn(Arg);
LLVM_DEBUG(dbgs() << " Result = " << Result << "\n");

return Result;
}

Expected<int32_t> handleCallMain(JITTargetAddress Addr,
std::vector<std::string> Args) {
using MainFnTy = int (*)(int, const char *[]);
Expand Down
11 changes: 4 additions & 7 deletions llvm/lib/ExecutionEngine/Orc/Layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,14 @@ IRMaterializationUnit::IRMaterializationUnit(
if (!llvm::empty(getStaticInitGVs(M))) {
size_t Counter = 0;

while (true) {
do {
std::string InitSymbolName;
raw_string_ostream(InitSymbolName)
<< "$." << M.getModuleIdentifier() << ".__inits." << Counter++;
InitSymbol = ES.intern(InitSymbolName);
if (SymbolFlags.count(InitSymbol))
continue;
SymbolFlags[InitSymbol] =
JITSymbolFlags::MaterializationSideEffectsOnly;
break;
}
} while (SymbolFlags.count(InitSymbol));

SymbolFlags[InitSymbol] = JITSymbolFlags::MaterializationSideEffectsOnly;
}
});
}
Expand Down