Skip to content

Commit

Permalink
[ORC] Fix SimpleRemoteEPC data races.
Browse files Browse the repository at this point in the history
Adds a 'start' method to SimpleRemoteEPCTransport to defer transport startup
until the client has been configured. This avoids races on client members if the
first message arrives while the client is being configured.

Also fixes races on the file descriptors in FDSimpleRemoteEPCTransport.
  • Loading branch information
lhames committed Sep 27, 2021
1 parent acd1399 commit 4b37462
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 57 deletions.
Expand Up @@ -21,6 +21,7 @@
#include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h"
#include "llvm/Support/Error.h"

#include <atomic>
#include <mutex>
#include <string>
#include <thread>
Expand Down Expand Up @@ -77,6 +78,13 @@ class SimpleRemoteEPCTransport {
public:
virtual ~SimpleRemoteEPCTransport();

/// Called during setup of the client to indicate that the client is ready
/// to receive messages.
///
/// Transport objects should not access the client until this method is
/// called. Deferring startup until the client has been configured avoids
/// races on client members if the first message arrives while the client is
/// still being configured.
///
/// Returns Error::success() once the transport has been started; callers
/// abandon setup if a failure value is returned.
virtual Error start() = 0;

/// Send a SimpleRemoteEPC message.
///
/// This function may be called concurrently. Subclasses should implement
Expand Down Expand Up @@ -107,14 +115,17 @@ class FDSimpleRemoteEPCTransport : public SimpleRemoteEPCTransport {

~FDSimpleRemoteEPCTransport() override;

Error start() override;

Error sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
ExecutorAddr TagAddr, ArrayRef<char> ArgBytes) override;

void disconnect() override;

private:
FDSimpleRemoteEPCTransport(SimpleRemoteEPCTransportClient &C, int InFD,
int OutFD);
int OutFD)
: C(C), InFD(InFD), OutFD(OutFD) {}

Error readBytes(char *Dst, size_t Size, bool *IsEOF = nullptr);
int writeBytes(const char *Src, size_t Size);
Expand All @@ -124,6 +135,7 @@ class FDSimpleRemoteEPCTransport : public SimpleRemoteEPCTransport {
SimpleRemoteEPCTransportClient &C;
std::thread ListenerThread;
int InFD, OutFD;
std::atomic<bool> Disconnected{false};
};

struct RemoteSymbolLookupSetElement {
Expand Down
18 changes: 3 additions & 15 deletions llvm/include/llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h
Expand Up @@ -37,21 +37,12 @@ class SimpleRemoteEPC : public ExecutorProcessControl,
Create(TransportTCtorArgTs &&...TransportTCtorArgs) {
std::unique_ptr<SimpleRemoteEPC> SREPC(
new SimpleRemoteEPC(std::make_shared<SymbolStringPool>()));

// Prepare for setup packet.
std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
auto EIF = EIP.get_future();
SREPC->prepareToReceiveSetupMessage(EIP);
auto T = TransportT::Create(
*SREPC, std::forward<TransportTCtorArgTs>(TransportTCtorArgs)...);
if (!T)
return T.takeError();
auto EI = EIF.get();
if (!EI) {
(*T)->disconnect();
return EI.takeError();
}
if (auto Err = SREPC->setup(std::move(*T), std::move(*EI)))
SREPC->T = std::move(*T);
if (auto Err = SREPC->setup())
return joinErrors(std::move(Err), SREPC->disconnect());
return std::move(SREPC);
}
Expand Down Expand Up @@ -96,10 +87,7 @@ class SimpleRemoteEPC : public ExecutorProcessControl,

Error handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
SimpleRemoteEPCArgBytesVector ArgBytes);
void prepareToReceiveSetupMessage(
std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> &ExecInfoP);
Error setup(std::unique_ptr<SimpleRemoteEPCTransport> T,
SimpleRemoteEPCExecutorInfo EI);
Error setup();

Error handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
SimpleRemoteEPCArgBytesVector ArgBytes);
Expand Down
Expand Up @@ -104,6 +104,8 @@ class SimpleRemoteEPCServer : public SimpleRemoteEPCTransportClient {
if (!T)
return T.takeError();
Server->T = std::move(*T);
if (auto Err = Server->T->start())
return std::move(Err);

// If transport creation succeeds then start up services.
Server->Services = std::move(S.services());
Expand Down
43 changes: 18 additions & 25 deletions llvm/lib/ExecutionEngine/Orc/Shared/SimpleRemoteEPCUtils.cpp
Expand Up @@ -69,18 +69,18 @@ FDSimpleRemoteEPCTransport::Create(SimpleRemoteEPCTransportClient &C, int InFD,
#endif
}

FDSimpleRemoteEPCTransport::FDSimpleRemoteEPCTransport(
SimpleRemoteEPCTransportClient &C, int InFD, int OutFD)
: C(C), InFD(InFD), OutFD(OutFD) {
FDSimpleRemoteEPCTransport::~FDSimpleRemoteEPCTransport() {
#if LLVM_ENABLE_THREADS
ListenerThread = std::thread([this]() { listenLoop(); });
ListenerThread.join();
#endif
}

FDSimpleRemoteEPCTransport::~FDSimpleRemoteEPCTransport() {
Error FDSimpleRemoteEPCTransport::start() {
#if LLVM_ENABLE_THREADS
ListenerThread.join();
ListenerThread = std::thread([this]() { listenLoop(); });
return Error::success();
#endif
llvm_unreachable("Should not be called with LLVM_ENABLE_THREADS=Off");
}

Error FDSimpleRemoteEPCTransport::sendMessage(SimpleRemoteEPCOpcode OpC,
Expand All @@ -98,7 +98,7 @@ Error FDSimpleRemoteEPCTransport::sendMessage(SimpleRemoteEPCOpcode OpC,
TagAddr.getValue();

std::lock_guard<std::mutex> Lock(M);
if (OutFD == -1)
if (Disconnected)
return make_error<StringError>("FD-transport disconnected",
inconvertibleErrorCode());
if (int ErrNo = writeBytes(HeaderBuffer, FDMsgHeader::Size))
Expand All @@ -109,28 +109,21 @@ Error FDSimpleRemoteEPCTransport::sendMessage(SimpleRemoteEPCOpcode OpC,
}

void FDSimpleRemoteEPCTransport::disconnect() {
int CloseInFD = -1, CloseOutFD = -1;
{
std::lock_guard<std::mutex> Lock(M);
std::swap(InFD, CloseInFD);
std::swap(OutFD, CloseOutFD);
}
if (Disconnected)
return; // Return if already disconnected.

// If CloseOutFD == CloseInFD then set CloseOutFD to -1 up-front so that we
// don't double-close.
if (CloseOutFD == CloseInFD)
CloseOutFD = -1;
Disconnected = true;
bool CloseOutFD = InFD != OutFD;

// Close InFD.
if (CloseInFD != -1)
while (close(CloseInFD) == -1) {
if (errno == EBADF)
break;
}
while (close(InFD) == -1) {
if (errno == EBADF)
break;
}

// Close OutFD.
if (CloseOutFD != -1) {
while (close(CloseOutFD) == -1) {
if (CloseOutFD) {
while (close(OutFD) == -1) {
if (errno == EBADF)
break;
}
Expand Down Expand Up @@ -160,7 +153,7 @@ Error FDSimpleRemoteEPCTransport::readBytes(char *Dst, size_t Size,
continue;
else {
std::lock_guard<std::mutex> Lock(M);
if (InFD == -1 && IsEOF) { // Disconnected locally. Pretend this is EOF.
if (Disconnected && IsEOF) { // disconnect called, pretend this is EOF.
*IsEOF = true;
return Error::success();
}
Expand Down
43 changes: 27 additions & 16 deletions llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp
Expand Up @@ -238,12 +238,17 @@ Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
return Error::success();
}

void SimpleRemoteEPC::prepareToReceiveSetupMessage(
std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> &ExecInfoP) {
Error SimpleRemoteEPC::setup() {
using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;

std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
auto EIF = EIP.get_future();

// Prepare a handler for the setup packet.
PendingCallWrapperResults[0] =
[&](shared::WrapperFunctionResult SetupMsgBytes) {
if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
ExecInfoP.set_value(
EIP.set_value(
make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
return;
}
Expand All @@ -252,29 +257,35 @@ void SimpleRemoteEPC::prepareToReceiveSetupMessage(
shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
SimpleRemoteEPCExecutorInfo EI;
if (SPSSerialize::deserialize(IB, EI))
ExecInfoP.set_value(EI);
EIP.set_value(EI);
else
ExecInfoP.set_value(make_error<StringError>(
EIP.set_value(make_error<StringError>(
"Could not deserialize setup message", inconvertibleErrorCode()));
};
}

Error SimpleRemoteEPC::setup(std::unique_ptr<SimpleRemoteEPCTransport> T,
SimpleRemoteEPCExecutorInfo EI) {
using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
// Start the transport.
if (auto Err = T->start())
return Err;

// Wait for setup packet to arrive.
auto EI = EIF.get();
if (!EI) {
T->disconnect();
return EI.takeError();
}

LLVM_DEBUG({
dbgs() << "SimpleRemoteEPC received setup message:\n"
<< " Triple: " << EI.TargetTriple << "\n"
<< " Page size: " << EI.PageSize << "\n"
<< " Triple: " << EI->TargetTriple << "\n"
<< " Page size: " << EI->PageSize << "\n"
<< " Bootstrap symbols:\n";
for (const auto &KV : EI.BootstrapSymbols)
for (const auto &KV : EI->BootstrapSymbols)
dbgs() << " " << KV.first() << ": "
<< formatv("{0:x16}", KV.second.getValue()) << "\n";
});
this->T = std::move(T);
TargetTriple = Triple(EI.TargetTriple);
PageSize = EI.PageSize;
BootstrapSymbols = std::move(EI.BootstrapSymbols);
TargetTriple = Triple(EI->TargetTriple);
PageSize = EI->PageSize;
BootstrapSymbols = std::move(EI->BootstrapSymbols);

if (auto Err = getBootstrapSymbols(
{{JDI.JITDispatchContext, ExecutorSessionObjectName},
Expand Down

0 comments on commit 4b37462

Please sign in to comment.