From 406a4e09dd7cc89395a39d54d487a7ad7e342d2f Mon Sep 17 00:00:00 2001
From: Jakub Szewczyk
Date: Thu, 26 Aug 2021 20:04:35 +0100
Subject: [PATCH 1/7] Endpoint constructor fix

---
 include/faabric/endpoint/Endpoint.h |  2 +-
 src/endpoint/Endpoint.cpp           | 11 +++++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/include/faabric/endpoint/Endpoint.h b/include/faabric/endpoint/Endpoint.h
index 86679558b..d335b66ea 100644
--- a/include/faabric/endpoint/Endpoint.h
+++ b/include/faabric/endpoint/Endpoint.h
@@ -9,7 +9,7 @@ namespace faabric::endpoint {
 class Endpoint
 {
   public:
-    Endpoint() = default;
+    Endpoint();
 
     Endpoint(int port, int threadCount);
 
diff --git a/src/endpoint/Endpoint.cpp b/src/endpoint/Endpoint.cpp
index da8334164..dd7c667ca 100644
--- a/src/endpoint/Endpoint.cpp
+++ b/src/endpoint/Endpoint.cpp
@@ -1,20 +1,27 @@
 #include
 #include
+#include <faabric/util/config.h>
 #include
 #include
 #include
 
 namespace faabric::endpoint {
+Endpoint::Endpoint()
+  : Endpoint(faabric::util::getSystemConfig().endpointPort,
+             faabric::util::getSystemConfig().endpointNumThreads)
+{}
+
 Endpoint::Endpoint(int portIn, int threadCountIn)
   : port(portIn)
   , threadCount(threadCountIn)
-  , httpEndpoint(Pistache::Address(Pistache::Ipv4::any(), Pistache::Port(port)))
+  , httpEndpoint(
+      Pistache::Address(Pistache::Ipv4::any(), Pistache::Port(portIn)))
 {}
 
 void Endpoint::start(bool awaitSignal)
 {
-    SPDLOG_INFO("Starting HTTP endpoint on {}", port);
+    SPDLOG_INFO("Starting HTTP endpoint on {}, {} threads", port, threadCount);
 
     // Set up signal handler
     sigset_t signals;
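The default constructor above delegates to the two-argument constructor rather than default-initialising members, so the port and thread count are resolved from config exactly once, at construction. A minimal self-contained sketch of the same pattern; the Config type and values here are illustrative stand-ins, not faabric's SystemConfig:

#include <cstdio>

// Illustrative stand-ins for faabric's config; names and values are assumed
struct Config
{
    int endpointPort = 8080;
    int endpointNumThreads = 4;
};

Config& getConfig()
{
    static Config conf;
    return conf;
}

class Endpoint
{
  public:
    // Delegate so defaults are read from config once, at construction time
    Endpoint()
      : Endpoint(getConfig().endpointPort, getConfig().endpointNumThreads)
    {}

    Endpoint(int portIn, int threadCountIn)
      : port(portIn)
      , threadCount(threadCountIn)
    {}

    int port;
    int threadCount;
};

int main()
{
    Endpoint e;
    std::printf("port=%d threads=%d\n", e.port, e.threadCount);
    return 0;
}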
From 42840a856f03e60559e00393e9acb2399a2216bc Mon Sep 17 00:00:00 2001
From: Jakub Szewczyk
Date: Thu, 26 Aug 2021 20:16:51 +0100
Subject: [PATCH 2/7] Redis script for setFunctionResult

---
 include/faabric/redis/Redis.h | 37 ++++++++++++++++++++++++++++-------
 src/redis/Redis.cpp           | 37 +++++++++++++++++++++++++++--------
 src/scheduler/Scheduler.cpp   | 13 +++----------
 3 files changed, 62 insertions(+), 25 deletions(-)

diff --git a/include/faabric/redis/Redis.h b/include/faabric/redis/Redis.h
index d5e089d54..9a952c4a3 100644
--- a/include/faabric/redis/Redis.h
+++ b/include/faabric/redis/Redis.h
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include <string_view>
 #include
 #include
 #include
@@ -24,6 +25,7 @@ class RedisInstance
     explicit RedisInstance(RedisRole role);
 
     std::string delifeqSha;
+    std::string schedPublishSha;
 
     std::string ip;
     std::string hostname;
@@ -35,15 +37,31 @@ class RedisInstance
     std::mutex scriptsLock;
 
     std::string loadScript(redisContext* context,
-                           const std::string& scriptBody);
+                           const std::string_view scriptBody);
 
     // Script to delete a key if it equals a given value
-    const std::string delifeqCmd =
-      "if redis.call(\"GET\", KEYS[1]) == ARGV[1] then \n"
-      "    return redis.call(\"DEL\", KEYS[1]) \n"
-      "else \n"
-      "    return 0 \n"
-      "end";
+    const std::string_view delifeqCmd = R"---(
+if redis.call('GET', KEYS[1]) == ARGV[1] then
+    return redis.call('DEL', KEYS[1])
+else
+    return 0
+end
+)---";
+
+    // Script to push and expire function execution results, avoiding extra
+    // copies and round trips
+    const std::string_view schedPublishCmd = R"---(
+local key = KEYS[1]
+local status_key = KEYS[2]
+local result = ARGV[1]
+local result_expiry = tonumber(ARGV[2])
+local status_expiry = tonumber(ARGV[3])
+redis.call('RPUSH', key, result)
+redis.call('EXPIRE', key, result_expiry)
+redis.call('SET', status_key, result)
+redis.call('EXPIRE', status_key, status_expiry)
+return 0
+)---";
 };
 
@@ -181,6 +199,11 @@ class Redis
                    long buffLen,
                    long nElems);
 
+    // Scheduler result publish
+    void publishSchedulerResult(const std::string& key,
+                                const std::string& status_key,
+                                const std::vector<uint8_t>& result);
+
   private:
     explicit Redis(const RedisInstance& instance);
 
diff --git a/src/redis/Redis.cpp b/src/redis/Redis.cpp
index 8bd7e86c3..60ca1d80f 100644
--- a/src/redis/Redis.cpp
+++ b/src/redis/Redis.cpp
@@ -1,11 +1,12 @@
 #include
+#include
+#include
+#include
 #include
 #include
 #include
-
-#include
-#include
+#include
 #include
 
 namespace faabric::redis {
@@ -27,15 +28,16 @@ RedisInstance::RedisInstance(RedisRole roleIn)
     port = std::stoi(portStr);
 
     // Load scripts
-    if (delifeqSha.empty()) {
+    if (delifeqSha.empty() || schedPublishSha.empty()) {
         std::unique_lock lock(scriptsLock);
-        if (delifeqSha.empty()) {
+        if (delifeqSha.empty() || schedPublishSha.empty()) {
             printf("Loading scripts for Redis instance at %s\n",
                    hostname.c_str());
             redisContext* context = redisConnect(ip.c_str(), port);
 
             delifeqSha = this->loadScript(context, delifeqCmd);
+            schedPublishSha = this->loadScript(context, schedPublishCmd);
 
             redisFree(context);
         }
@@ -43,10 +45,10 @@ RedisInstance::RedisInstance(RedisRole roleIn)
 }
 
 std::string RedisInstance::loadScript(redisContext* context,
-                                      const std::string& scriptBody)
+                                      const std::string_view scriptBody)
 {
-    auto reply =
-      (redisReply*)redisCommand(context, "SCRIPT LOAD %s", scriptBody.c_str());
+    auto reply = (redisReply*)redisCommand(
+      context, "SCRIPT LOAD %b", scriptBody.data(), scriptBody.size());
 
     if (reply == nullptr) {
         throw std::runtime_error("Error loading script from Redis");
@@ -774,4 +776,23 @@ void Redis::dequeueBytes(const std::string& queueName,
     freeReplyObject(reply);
 }
 
+void Redis::publishSchedulerResult(const std::string& key,
+                                   const std::string& status_key,
+                                   const std::vector<uint8_t>& result)
+{
+    auto reply = (redisReply*)redisCommand(context,
+                                           "EVALSHA %s 2 %s %s %b %d %d",
+                                           instance.schedPublishSha.c_str(),
+                                           // keys
+                                           key.c_str(),
+                                           status_key.c_str(),
+                                           // argv
+                                           result.data(),
+                                           result.size(),
+                                           RESULT_KEY_EXPIRY,
+                                           STATUS_KEY_EXPIRY);
+    extractScriptResult(reply);
+}
+
 }
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 85934ea4f..f09083bf3 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -324,8 +324,8 @@ std::vector<std::string> Scheduler::callFunctions(
         if (offset < nMessages) {
             // Schedule first to already registered hosts
             for (const auto& h : thisRegisteredHosts) {
-                int nOnThisHost =
-                  scheduleFunctionsOnHost(h, req, executed, offset, nullptr);
+                int nOnThisHost = scheduleFunctionsOnHost(
+                  h, req, executed, offset, &snapshotData);
 
                 offset += nOnThisHost;
                 if (offset >= nMessages) {
@@ -684,14 +684,7 @@ void Scheduler::setFunctionResult(faabric::Message& msg)
 
     // Write the successful result to the result queue
     std::vector<uint8_t> inputData = faabric::util::messageToBytes(msg);
-    redis.enqueueBytes(key, inputData);
-
-    // Set the result key to expire
-    redis.expire(key, RESULT_KEY_EXPIRY);
-
-    // Set long-lived result for function too
-    redis.set(msg.statuskey(), inputData);
-    redis.expire(key, STATUS_KEY_EXPIRY);
+    redis.publishSchedulerResult(key, msg.statuskey(), inputData);
 }
 
 void Scheduler::registerThread(uint32_t msgId)
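For context, the schedPublishCmd script bundles into a single EVALSHA round trip what previously took four separate Redis commands. A hedged hiredis sketch of the unscripted sequence it replaces; key names and expiry values here are assumed for illustration, and error handling is elided:

#include <hiredis/hiredis.h>

#include <cstdint>
#include <vector>

// Sketch only: the four commands (and four round trips) that the
// schedPublish script collapses into one EVALSHA call
void publishResultUnscripted(redisContext* ctx,
                             const std::vector<uint8_t>& result)
{
    const char* key = "result_key";
    const char* statusKey = "status_key";

    // Push the serialised message for any blocked reader of the result list
    freeReplyObject(
      redisCommand(ctx, "RPUSH %s %b", key, result.data(), result.size()));
    freeReplyObject(redisCommand(ctx, "EXPIRE %s %d", key, 30));

    // Longer-lived status copy for polling clients
    freeReplyObject(
      redisCommand(ctx, "SET %s %b", statusKey, result.data(), result.size()));
    freeReplyObject(redisCommand(ctx, "EXPIRE %s %d", statusKey, 300));
}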
From dde0f8038c836b9f87e448addcb3e0f5f2f32ec3 Mon Sep 17 00:00:00 2001
From: Jakub Szewczyk
Date: Thu, 26 Aug 2021 20:23:28 +0100
Subject: [PATCH 3/7] Avoid lock on Scheduler::vacateSlot

---
 include/faabric/scheduler/Scheduler.h |  1 +
 src/scheduler/Scheduler.cpp           | 13 ++++++++-----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/include/faabric/scheduler/Scheduler.h b/include/faabric/scheduler/Scheduler.h
index e0d1c9147..ae6c7c3bc 100644
--- a/include/faabric/scheduler/Scheduler.h
+++ b/include/faabric/scheduler/Scheduler.h
@@ -203,6 +203,7 @@ class Scheduler
                       const std::string& otherHost);
 
     faabric::HostResources thisHostResources;
+    std::atomic<int32_t> thisHostUsedSlots;
 
     std::set<std::string> availableHostsCache;
     std::unordered_map<std::string, std::set<std::string>> registeredHosts;
 
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index f09083bf3..deab7f863 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -156,8 +156,7 @@ void Scheduler::removeRegisteredHost(const std::string& host,
 
 void Scheduler::vacateSlot()
 {
-    faabric::util::FullLock lock(mx);
-    thisHostResources.set_usedslots(thisHostResources.usedslots() - 1);
+    thisHostUsedSlots.fetch_sub(1, std::memory_order_acq_rel);
 }
 
 void Scheduler::notifyExecutorShutdown(Executor* exec,
@@ -302,7 +301,8 @@ std::vector<std::string> Scheduler::callFunctions(
         int slots = thisHostResources.slots();
 
         // Work out available cores, flooring at zero
-        int available = slots - thisHostResources.usedslots();
+        int available =
+          slots - this->thisHostUsedSlots.load(std::memory_order_acquire);
         available = std::max(available, 0);
 
         // Claim as many as we can
@@ -391,8 +391,8 @@ std::vector<std::string> Scheduler::callFunctions(
     // executor, for anything else we want one Executor per function in flight
     if (!localMessageIdxs.empty()) {
         // Update slots
-        thisHostResources.set_usedslots(thisHostResources.usedslots() +
-                                        localMessageIdxs.size());
+        this->thisHostUsedSlots.fetch_add((int32_t)localMessageIdxs.size(),
+                                          std::memory_order_acq_rel);
 
         if (isThreads) {
             // Threads use the existing executor. We assume there's only one
@@ -779,12 +779,15 @@ faabric::Message Scheduler::getFunctionResult(unsigned int messageId,
 
 faabric::HostResources Scheduler::getThisHostResources()
 {
+    thisHostResources.set_usedslots(
+      this->thisHostUsedSlots.load(std::memory_order_acquire));
     return thisHostResources;
 }
 
 void Scheduler::setThisHostResources(faabric::HostResources& res)
 {
     thisHostResources = res;
+    this->thisHostUsedSlots.store(res.usedslots(), std::memory_order_release);
 }
 
 faabric::HostResources Scheduler::getHostResources(const std::string& host)
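The point of the change above is that vacateSlot() only touches a single counter, so a lone atomic can replace the scheduler's full writer lock on that hot path. A minimal sketch of the pattern with illustrative names; note that in the patch the load-then-add in the claim path still runs under the scheduler lock, so only the vacate path is lock-free:

#include <algorithm>
#include <atomic>
#include <cstdint>

// Illustrative stand-in for the scheduler's slot accounting
struct SlotCounter
{
    int32_t slots = 8;
    std::atomic<int32_t> usedSlots{ 0 };

    // Hot path: a single counter update needs no lock
    void vacate() { usedSlots.fetch_sub(1, std::memory_order_acq_rel); }

    // In the patch this load-then-add runs under the scheduler lock, so the
    // two steps cannot race with other claimers
    int32_t claimUpTo(int32_t wanted)
    {
        int32_t available = std::max<int32_t>(
          slots - usedSlots.load(std::memory_order_acquire), 0);
        int32_t claimed = std::min(available, wanted);
        usedSlots.fetch_add(claimed, std::memory_order_acq_rel);
        return claimed;
    }
};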
From 1e770117878d66f8fb755c81684856d053387b28 Mon Sep 17 00:00:00 2001
From: Jakub Szewczyk
Date: Thu, 26 Aug 2021 20:26:31 +0100
Subject: [PATCH 4/7] Unlock scheduler while spinning up new executors

---
 include/faabric/scheduler/Scheduler.h |  4 +++-
 src/scheduler/Scheduler.cpp           | 16 +++++++++++-----
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/include/faabric/scheduler/Scheduler.h b/include/faabric/scheduler/Scheduler.h
index ae6c7c3bc..48f7f6e5e 100644
--- a/include/faabric/scheduler/Scheduler.h
+++ b/include/faabric/scheduler/Scheduler.h
@@ -215,7 +215,9 @@ class Scheduler
     std::vector<std::string> getUnregisteredHosts(const std::string& funcStr,
                                                   bool noCache = false);
 
-    std::shared_ptr<Executor> claimExecutor(faabric::Message& msg);
+    std::shared_ptr<Executor> claimExecutor(
+      faabric::Message& msg,
+      faabric::util::FullLock& schedulerLock);
 
     faabric::HostResources getHostResources(const std::string& host);
 
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index deab7f863..be09bc4bd 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -403,7 +403,7 @@ std::vector<std::string> Scheduler::callFunctions(
             std::shared_ptr<Executor> e = nullptr;
 
             if (thisExecutors.empty()) {
                 // Create executor if not exists
-                e = claimExecutor(firstMsg);
+                e = claimExecutor(firstMsg, lock);
             } else if (thisExecutors.size() == 1) {
                 // Use existing executor if exists
                 e = thisExecutors.back();
@@ -422,7 +422,7 @@ std::vector<std::string> Scheduler::callFunctions(
         } else {
             // Non-threads require one executor per task
             for (auto i : localMessageIdxs) {
-                std::shared_ptr<Executor> e = claimExecutor(firstMsg);
+                std::shared_ptr<Executor> e = claimExecutor(firstMsg, lock);
                 e->executeTasks({ i }, req);
             }
         }
@@ -598,7 +598,9 @@ Scheduler::getRecordedMessagesShared()
     return recordedMessagesShared;
 }
 
-std::shared_ptr<Executor> Scheduler::claimExecutor(faabric::Message& msg)
+std::shared_ptr<Executor> Scheduler::claimExecutor(
+  faabric::Message& msg,
+  faabric::util::FullLock& schedulerLock)
 {
     std::string funcStr = faabric::util::funcToString(msg, false);
 
@@ -622,8 +624,12 @@ std::shared_ptr<Executor> Scheduler::claimExecutor(
         int nExecutors = thisExecutors.size();
         SPDLOG_DEBUG(
           "Scaling {} from {} -> {}", funcStr, nExecutors, nExecutors + 1);
-
-        thisExecutors.emplace_back(factory->createExecutor(msg));
+        // Spinning up a new executor can be lengthy, allow other things to run
+        // in parallel
+        schedulerLock.unlock();
+        auto executor = factory->createExecutor(msg);
+        schedulerLock.lock();
+        thisExecutors.push_back(std::move(executor));
         claimed = thisExecutors.back();
 
         // Claim it
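Dropping the lock around createExecutor() means anything read before the unlock may be stale once the lock is reacquired; the patch therefore re-reads thisExecutors.back() only after relocking. A small sketch of the unlock-while-constructing pattern, with illustrative types:

#include <memory>
#include <mutex>
#include <vector>

// Sketch of the unlock-while-constructing pattern; types are illustrative.
// Anything read before unlock() must be treated as stale after lock()
// returns, because other threads may have modified the registry in between.
struct Registry
{
    std::mutex mx;
    std::vector<std::shared_ptr<int>> items;

    // Caller passes in the lock it already holds on mx
    std::shared_ptr<int> addSlowly(std::unique_lock<std::mutex>& lock)
    {
        lock.unlock();
        // Expensive construction happens without the lock held
        auto item = std::make_shared<int>(42);
        lock.lock();
        // Append and re-read under the lock, as the patch does with
        // thisExecutors.back()
        items.push_back(std::move(item));
        return items.back();
    }
};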
From 40bf4534cf3f1faae567913fadb0de5501c43211 Mon Sep 17 00:00:00 2001
From: Jakub Szewczyk
Date: Thu, 26 Aug 2021 20:38:56 +0100
Subject: [PATCH 5/7] Bypass Redis in local, sync function calls

---
 include/faabric/scheduler/Scheduler.h   |  5 +++
 src/endpoint/FaabricEndpointHandler.cpp |  5 +++
 src/proto/faabric.proto                 | 11 ++---
 src/scheduler/Scheduler.cpp             | 56 +++++++++++++++++++++++--
 4 files changed, 68 insertions(+), 9 deletions(-)

diff --git a/include/faabric/scheduler/Scheduler.h b/include/faabric/scheduler/Scheduler.h
index 48f7f6e5e..d721adf03 100644
--- a/include/faabric/scheduler/Scheduler.h
+++ b/include/faabric/scheduler/Scheduler.h
@@ -207,6 +207,11 @@ class Scheduler
     std::set<std::string> availableHostsCache;
     std::unordered_map<std::string, std::set<std::string>> registeredHosts;
 
+    std::unordered_map<uint32_t,
+                       std::promise<std::unique_ptr<faabric::Message>>>
+      localResults;
+    std::mutex localResultsMutex;
+
     std::vector<faabric::Message> recordedMessagesAll;
     std::vector<faabric::Message> recordedMessagesLocal;
     std::vector<std::pair<std::string, faabric::Message>>
 
diff --git a/src/endpoint/FaabricEndpointHandler.cpp b/src/endpoint/FaabricEndpointHandler.cpp
index debc5314e..e5669a145 100644
--- a/src/endpoint/FaabricEndpointHandler.cpp
+++ b/src/endpoint/FaabricEndpointHandler.cpp
@@ -110,6 +110,11 @@ std::pair<int, std::string> FaabricEndpointHandler::executeFunction(
     faabric::util::setMessageId(msg);
     std::string thisHost = faabric::util::getSystemConfig().endpointHost;
     msg.set_masterhost(thisHost);
+    // This is set to false by the scheduler if the function ends up being sent
+    // elsewhere
+    if (!msg.isasync()) {
+        msg.set_executeslocally(true);
+    }
 
     auto tid = (pid_t)syscall(SYS_gettid);
     const std::string funcStr = faabric::util::funcToString(msg, true);
 
diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto
index 296f4e232..07c7ef476 100644
--- a/src/proto/faabric.proto
+++ b/src/proto/faabric.proto
@@ -118,13 +118,14 @@ message Message {
     int64 timestamp = 14;
 
     string resultKey = 15;
-    string statusKey = 16;
+    bool executesLocally = 16;
+    string statusKey = 17;
 
-    string executedHost = 17;
-    int64 finishTimestamp = 18;
+    string executedHost = 18;
+    int64 finishTimestamp = 19;
 
-    bool isAsync = 19;
-    bool isPython = 20;
+    bool isAsync = 20;
+    bool isPython = 21;
 
     bool isStatusRequest = 22;
     bool isExecGraphRequest = 23;
 
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index be09bc4bd..a63cf9b57 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -422,7 +422,14 @@ std::vector<std::string> Scheduler::callFunctions(
         } else {
             // Non-threads require one executor per task
             for (auto i : localMessageIdxs) {
-                std::shared_ptr<Executor> e = claimExecutor(firstMsg, lock);
+                faabric::Message& localMsg = req->mutable_messages()->at(i);
+                if (localMsg.executeslocally()) {
+                    faabric::util::UniqueLock resultsLock(localResultsMutex);
+                    localResults.insert(
+                      { localMsg.id(),
+                        std::promise<std::unique_ptr<faabric::Message>>() });
+                }
+                std::shared_ptr<Executor> e = claimExecutor(localMsg, lock);
                 e->executeTasks({ i }, req);
             }
         }
@@ -522,7 +529,9 @@ int Scheduler::scheduleFunctionsOnHost(
     // Add messages
     int nOnThisHost = std::min(available, remainder);
     for (int i = offset; i < (offset + nOnThisHost); i++) {
-        *hostRequest->add_messages() = req->messages().at(i);
+        auto* newMsg = hostRequest->add_messages();
+        *newMsg = req->messages().at(i);
+        newMsg->set_executeslocally(false);
         records.at(i) = host;
     }
 
@@ -683,6 +692,18 @@ void Scheduler::setFunctionResult(faabric::Message& msg)
     // Set finish timestamp
     msg.set_finishtimestamp(faabric::util::getGlobalClock().epochMillis());
 
+    if (msg.executeslocally()) {
+        faabric::util::UniqueLock resultsLock(localResultsMutex);
+        auto it = localResults.find(msg.id());
+        if (it != localResults.end()) {
+            it->second.set_value(std::make_unique<faabric::Message>(msg));
+        }
+        // Sync messages can't have their results read twice, so skip redis
+        if (!msg.isasync()) {
+            return;
+        }
+    }
+
     std::string key = msg.resultkey();
     if (key.empty()) {
         throw std::runtime_error("Result key empty. Cannot publish result");
@@ -744,13 +765,40 @@ int32_t Scheduler::awaitThreadResult(uint32_t messageId)
 faabric::Message Scheduler::getFunctionResult(unsigned int messageId,
                                               int timeoutMs)
 {
+    bool isBlocking = timeoutMs > 0;
+
     if (messageId == 0) {
         throw std::runtime_error("Must provide non-zero message ID");
     }
 
-    redis::Redis& redis = redis::Redis::getQueue();
+    do {
+        std::future<std::unique_ptr<faabric::Message>> fut;
+        {
+            faabric::util::UniqueLock resultsLock(localResultsMutex);
+            auto it = localResults.find(messageId);
+            if (it == localResults.end()) {
+                break; // fallback to redis
+            }
+            fut = it->second.get_future();
+        }
+        if (!isBlocking) {
+            auto status = fut.wait_for(std::chrono::milliseconds(timeoutMs));
+            if (status == std::future_status::timeout) {
+                faabric::Message msgResult;
+                msgResult.set_type(faabric::Message_MessageType_EMPTY);
+                return msgResult;
+            }
+        } else {
+            fut.wait();
+        }
+        {
+            faabric::util::UniqueLock resultsLock(localResultsMutex);
+            localResults.erase(messageId);
+        }
+        return *fut.get();
+    } while (0);
 
-    bool isBlocking = timeoutMs > 0;
+    redis::Redis& redis = redis::Redis::getQueue();
 
     std::string resultKey = faabric::util::resultKeyFromMessageId(messageId);
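The local-results path above is a classic promise/future handoff: the scheduler registers a promise keyed by message ID before dispatch, the executor fulfils it in setFunctionResult, and getFunctionResult waits on the future instead of polling Redis. A stripped-down sketch of the same flow, with std::string standing in for faabric::Message and all names illustrative:

#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Stripped-down local-results handoff; std::string stands in for the
// message type
std::mutex resultsMx;
std::unordered_map<uint32_t, std::promise<std::unique_ptr<std::string>>>
  results;

// Caller registers the promise before dispatching the work
void registerResult(uint32_t id)
{
    std::unique_lock<std::mutex> lock(resultsMx);
    results.emplace(id, std::promise<std::unique_ptr<std::string>>());
}

// Worker fulfils the promise when the result is ready
void setResult(uint32_t id, std::string value)
{
    std::unique_lock<std::mutex> lock(resultsMx);
    auto it = results.find(id);
    if (it != results.end()) {
        it->second.set_value(std::make_unique<std::string>(std::move(value)));
    }
}

// Caller blocks on the future, then cleans up the map entry
std::string getResult(uint32_t id)
{
    std::future<std::unique_ptr<std::string>> fut;
    {
        std::unique_lock<std::mutex> lock(resultsMx);
        fut = results.at(id).get_future();
    }
    std::string out = *fut.get();
    {
        std::unique_lock<std::mutex> lock(resultsMx);
        results.erase(id);
    }
    return out;
}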
From 0a35ed15082d7d41bfea7b72840481dd39c51eb1 Mon Sep 17 00:00:00 2001
From: Jakub Szewczyk
Date: Thu, 26 Aug 2021 20:40:44 +0100
Subject: [PATCH 6/7] Significantly faster (~10x) file-to-bytes reading

---
 src/util/files.cpp | 45 +++++++++++++++++++++++++++------------------
 1 file changed, 27 insertions(+), 18 deletions(-)

diff --git a/src/util/files.cpp b/src/util/files.cpp
index 527e5e141..1bcbcfaaf 100644
--- a/src/util/files.cpp
+++ b/src/util/files.cpp
@@ -8,6 +8,10 @@
 #include
 #include
 
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
 namespace faabric::util {
 std::string readFileToString(const std::string& path)
 {
@@ -21,25 +25,30 @@ std::string readFileToString(const std::string& path)
 
 std::vector<uint8_t> readFileToBytes(const std::string& path)
 {
-    std::ifstream file(path, std::ios::binary);
-
-    // Stop eating new lines in binary mode
-    file.unsetf(std::ios::skipws);
-
-    // Reserve space
-    std::streampos fileSize;
-    file.seekg(0, std::ios::end);
-    fileSize = file.tellg();
-
+    int fd = open(path.c_str(), O_RDONLY);
+    if (fd < 0) {
+        throw std::runtime_error("Couldn't open file " + path);
+    }
+    struct stat statbuf;
+    int staterr = fstat(fd, &statbuf);
+    if (staterr < 0) {
+        throw std::runtime_error("Couldn't stat file " + path);
+    }
+    size_t fsize = statbuf.st_size;
+    // Hint to the kernel that we'll read the file sequentially
+    posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
     std::vector<uint8_t> result;
-    result.reserve(fileSize);
-
-    // Read the data
-    file.seekg(0, std::ios::beg);
-    result.insert(result.begin(),
-                  std::istreambuf_iterator<char>(file),
-                  std::istreambuf_iterator<char>());
-
+    result.resize(fsize);
+    size_t cpos = 0;
+    while (cpos < fsize) {
+        // read(2) may return fewer bytes than requested, so write each chunk
+        // at the current offset and retry until the whole file is in
+        ssize_t rc = read(fd, result.data() + cpos, fsize - cpos);
+        if (rc < 0) {
+            perror("Couldn't read file");
+            throw std::runtime_error("Couldn't read file " + path);
+        } else {
+            cpos += rc;
+        }
+    }
+    close(fd);
     return result;
 }
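read(2) is allowed to return fewer bytes than requested (signals, pipes, very large files), which is why the loop above retries and advances the write offset on every iteration. A standalone sketch of the same short-read-safe loop, with an explicit end-of-file guard added for illustration:

#include <fcntl.h>
#include <unistd.h>

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Standalone version of the short-read-safe loop: each read() lands at
// data() + cpos, so a partial read never clobbers bytes already received
std::vector<uint8_t> readAll(const std::string& path, size_t fsize)
{
    int fd = open(path.c_str(), O_RDONLY);
    if (fd < 0) {
        throw std::runtime_error("open failed: " + path);
    }
    std::vector<uint8_t> result(fsize);
    size_t cpos = 0;
    while (cpos < fsize) {
        ssize_t rc = read(fd, result.data() + cpos, fsize - cpos);
        if (rc < 0) {
            close(fd);
            throw std::runtime_error("read failed: " + path);
        }
        if (rc == 0) {
            break; // unexpected EOF, e.g. the file shrank underneath us
        }
        cpos += (size_t)rc;
    }
    close(fd);
    result.resize(cpos);
    return result;
}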
{}", basePort); - faabric::transport::Message m = ranksRecvEndpoint->recv(); + faabric::transport::Message m = ranksRecvEndpoint->recv().value(); PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); return msg; diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 2ef6d92b6..a62b43d68 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -135,7 +135,7 @@ void MessageEndpoint::doSend(zmq::socket_t& socket, "send") } -Message MessageEndpoint::doRecv(zmq::socket_t& socket, int size) +std::optional MessageEndpoint::doRecv(zmq::socket_t& socket, int size) { assert(tid == std::this_thread::get_id()); assert(size >= 0); @@ -147,7 +147,8 @@ Message MessageEndpoint::doRecv(zmq::socket_t& socket, int size) return recvBuffer(socket, size); } -Message MessageEndpoint::recvBuffer(zmq::socket_t& socket, int size) +std::optional MessageEndpoint::recvBuffer(zmq::socket_t& socket, + int size) { // Pre-allocate buffer to avoid copying data Message msg(size); @@ -158,7 +159,7 @@ Message MessageEndpoint::recvBuffer(zmq::socket_t& socket, int size) if (!res.has_value()) { SPDLOG_TRACE("Timed out receiving message of size {}", size); - throw MessageTimeoutException("Timed out receiving message"); + return std::nullopt; } if (res.has_value() && (res->size != res->untruncated_size)) { @@ -181,7 +182,7 @@ Message MessageEndpoint::recvBuffer(zmq::socket_t& socket, int size) return msg; } -Message MessageEndpoint::recvNoBuffer(zmq::socket_t& socket) +std::optional MessageEndpoint::recvNoBuffer(zmq::socket_t& socket) { // Allocate a message to receive data zmq::message_t msg; @@ -190,7 +191,7 @@ Message MessageEndpoint::recvNoBuffer(zmq::socket_t& socket) auto res = socket.recv(msg); if (!res.has_value()) { SPDLOG_TRACE("Timed out receiving message with no size"); - throw MessageTimeoutException("Timed out receiving message"); + return std::nullopt; } } catch (zmq::error_t& e) { if (e.num() == ZMQ_ETERM) { @@ -274,7 +275,11 @@ Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* data, // Do the receive SPDLOG_TRACE("RECV (REQ) {}", port); - return recvNoBuffer(reqSocket); + auto msgMaybe = recvNoBuffer(reqSocket); + if (!msgMaybe.has_value()) { + throw MessageTimeoutException("SendAwaitResponse timeout"); + } + return msgMaybe.value(); } // ---------------------------------------------- @@ -289,7 +294,7 @@ RecvMessageEndpoint::RecvMessageEndpoint(int portIn, socket = setUpSocket(socketType, portIn); } -Message RecvMessageEndpoint::recv(int size) +std::optional RecvMessageEndpoint::recv(int size) { return doRecv(socket, size); } @@ -302,7 +307,7 @@ AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) {} -Message AsyncRecvMessageEndpoint::recv(int size) +std::optional AsyncRecvMessageEndpoint::recv(int size) { SPDLOG_TRACE("PULL {} ({} bytes)", port, size); return RecvMessageEndpoint::recv(size); @@ -316,7 +321,7 @@ SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::rep) {} -Message SyncRecvMessageEndpoint::recv(int size) +std::optional SyncRecvMessageEndpoint::recv(int size) { SPDLOG_TRACE("RECV (REP) {} ({} bytes)", port, size); return RecvMessageEndpoint::recv(size); diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 52a59eb57..dbcac4529 100644 --- a/src/transport/MessageEndpointServer.cpp +++ 
diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp
index 52a59eb57..dbcac4529 100644
--- a/src/transport/MessageEndpointServer.cpp
+++ b/src/transport/MessageEndpointServer.cpp
@@ -36,67 +36,59 @@ void MessageEndpointServerThread::start(
         latch->wait();
 
         while (true) {
-            bool headerReceived = false;
-            bool bodyReceived = false;
-            try {
-                // Receive header and body
-                Message headerMessage = endpoint->recv();
-                headerReceived = true;
-
-                if (headerMessage.size() == shutdownHeader.size()) {
-                    if (headerMessage.dataCopy() == shutdownHeader) {
-                        SPDLOG_TRACE("Server on {} received shutdown message",
-                                     port);
-                        break;
-                    }
-                }
-
-                if (!headerMessage.more()) {
-                    throw std::runtime_error(
-                      "Header sent without SNDMORE flag");
-                }
-
-                Message body = endpoint->recv();
-                if (body.more()) {
-                    throw std::runtime_error("Body sent with SNDMORE flag");
-                }
-                bodyReceived = true;
-
-                assert(headerMessage.size() == sizeof(uint8_t));
-                uint8_t header = static_cast<uint8_t>(*headerMessage.data());
-
-                if (async) {
-                    // Server-specific async handling
-                    server->doAsyncRecv(header, body.udata(), body.size());
-                } else {
-                    // Server-specific sync handling
-                    std::unique_ptr<google::protobuf::Message> resp =
-                      server->doSyncRecv(header, body.udata(), body.size());
-                    size_t respSize = resp->ByteSizeLong();
-
-                    uint8_t buffer[respSize];
-                    if (!resp->SerializeToArray(buffer, respSize)) {
-                        throw std::runtime_error("Error serialising message");
-                    }
-
-                    // Return the response
-                    static_cast<SyncRecvMessageEndpoint*>(endpoint.get())
-                      ->sendResponse(buffer, respSize);
-                }
-            } catch (MessageTimeoutException& ex) {
-                // If we don't get a header in the timeout, we're ok to just
-                // loop round and try again
-                if (!headerReceived) {
-                    SPDLOG_TRACE("Server on port {}, looping after no message",
-                                 port);
-                    continue;
-                }
-
-                if (headerReceived && !bodyReceived) {
-                    SPDLOG_ERROR(
-                      "Server on port {}, got header, timed out on body", port);
-                    throw;
-                }
-            }
+            // Receive header and body
+            std::optional<Message> headerMessageMaybe = endpoint->recv();
+            if (!headerMessageMaybe.has_value()) {
+                SPDLOG_TRACE("Server on port {}, looping after no message",
+                             port);
+                continue;
+            }
+            Message& headerMessage = headerMessageMaybe.value();
+
+            if (headerMessage.size() == shutdownHeader.size()) {
+                if (headerMessage.dataCopy() == shutdownHeader) {
+                    SPDLOG_TRACE("Server on {} received shutdown message",
+                                 port);
+                    break;
+                }
+            }
+
+            if (!headerMessage.more()) {
+                throw std::runtime_error("Header sent without SNDMORE flag");
+            }
+
+            std::optional<Message> bodyMaybe = endpoint->recv();
+            if (!bodyMaybe.has_value()) {
+                SPDLOG_ERROR("Server on port {}, got header, timed out on body",
+                             port);
+                throw MessageTimeoutException(
+                  "Server, got header, timed out on body");
+            }
+            Message& body = bodyMaybe.value();
+            if (body.more()) {
+                throw std::runtime_error("Body sent with SNDMORE flag");
+            }
+
+            assert(headerMessage.size() == sizeof(uint8_t));
+            uint8_t header = static_cast<uint8_t>(*headerMessage.data());
+
+            if (async) {
+                // Server-specific async handling
+                server->doAsyncRecv(header, body.udata(), body.size());
+            } else {
+                // Server-specific sync handling
+                std::unique_ptr<google::protobuf::Message> resp =
+                  server->doSyncRecv(header, body.udata(), body.size());
+                size_t respSize = resp->ByteSizeLong();
+
+                uint8_t buffer[respSize];
+                if (!resp->SerializeToArray(buffer, respSize)) {
+                    throw std::runtime_error("Error serialising message");
+                }
+
+                // Return the response
+                static_cast<SyncRecvMessageEndpoint*>(endpoint.get())
+                  ->sendResponse(buffer, respSize);
+            }
 
             // Wait on the async latch if necessary
@@ -104,9 +96,6 @@ void MessageEndpointServerThread::start(
                 SPDLOG_TRACE("Server thread waiting on async latch");
                 server->asyncLatch->wait();
             }
-
-            headerReceived = false;
-            bodyReceived = false;
         }
     });
 }
diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp
index 136456238..d5153f0b4 100644
--- a/src/transport/MpiMessageEndpoint.cpp
+++ b/src/transport/MpiMessageEndpoint.cpp
@@ -21,7 +21,11 @@ void MpiMessageEndpoint::sendMpiMessage(
 
 std::shared_ptr<faabric::MPIMessage> MpiMessageEndpoint::recvMpiMessage()
 {
-    Message m = recvSocket.recv();
+    std::optional<Message> mMaybe = recvSocket.recv();
+    if (!mMaybe.has_value()) {
+        throw MessageTimeoutException("Mpi message timeout");
+    }
+    Message& m = mMaybe.value();
     PARSE_MSG(faabric::MPIMessage, m.data(), m.size());
 
     return std::make_shared<faabric::MPIMessage>(msg);
 
diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp
index b42ff3e6d..f8ca68b1b 100644
--- a/tests/test/transport/test_message_endpoint_client.cpp
+++ b/tests/test/transport/test_message_endpoint_client.cpp
@@ -26,7 +26,7 @@ TEST_CASE_METHOD(SchedulerTestFixture,
     src.send(msg, expectedMsg.size());
 
     // Receive message
-    faabric::transport::Message recvMsg = dst.recv();
+    faabric::transport::Message recvMsg = dst.recv().value();
     REQUIRE(recvMsg.size() == expectedMsg.size());
     std::string actualMsg(recvMsg.data(), recvMsg.size());
     REQUIRE(actualMsg == expectedMsg);
@@ -48,7 +48,7 @@ TEST_CASE_METHOD(SchedulerTestFixture,
 
     // Receive message
     AsyncRecvMessageEndpoint dst(TEST_PORT);
-    faabric::transport::Message recvMsg = dst.recv();
+    faabric::transport::Message recvMsg = dst.recv().value();
 
     assert(recvMsg.size() == expectedMsg.size());
     std::string actualMsg(recvMsg.data(), recvMsg.size());
@@ -90,7 +90,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]")
 
     // Receive message
     SyncRecvMessageEndpoint dst(TEST_PORT);
-    faabric::transport::Message recvMsg = dst.recv();
+    faabric::transport::Message recvMsg = dst.recv().value();
     REQUIRE(recvMsg.size() == expectedMsg.size());
     std::string actualMsg(recvMsg.data(), recvMsg.size());
     REQUIRE(actualMsg == expectedMsg);
@@ -125,7 +125,7 @@ TEST_CASE_METHOD(SchedulerTestFixture,
     // Receive messages
     AsyncRecvMessageEndpoint dst(TEST_PORT);
     for (int i = 0; i < numMessages; i++) {
-        faabric::transport::Message recvMsg = dst.recv();
+        faabric::transport::Message recvMsg = dst.recv().value();
         // Check just a subset of the messages
         // Note - this implicitly tests in-order message delivery
        if ((i % (numMessages / 10)) == 0) {
@@ -165,7 +165,7 @@ TEST_CASE_METHOD(SchedulerTestFixture,
     // Receive messages
     AsyncRecvMessageEndpoint dst(TEST_PORT);
     for (int i = 0; i < numSenders * numMessages; i++) {
-        faabric::transport::Message recvMsg = dst.recv();
+        faabric::transport::Message recvMsg = dst.recv().value();
         // Check just a subset of the messages
         if ((i % numMessages) == 0) {
             REQUIRE(recvMsg.size() == expectedMsg.size());