diff --git a/include/faabric/scheduler/MpiMessageBuffer.h b/include/faabric/scheduler/MpiMessageBuffer.h
new file mode 100644
index 000000000..59cc39120
--- /dev/null
+++ b/include/faabric/scheduler/MpiMessageBuffer.h
@@ -0,0 +1,82 @@
+#include
+#include
+
+#include
+#include
+
+namespace faabric::scheduler {
+/* The MPI message buffer (MMB) keeps track of the asynchronous
+ * messages that we must have received (i.e. through an irecv call) but we
+ * still have not waited on (acknowledged). Messages are acknowledged either
+ * through a call to recv or a call to await. A call to recv will
+ * acknowledge (i.e. synchronously read from transport buffers) as many
+ * unacknowledged messages as there are. A call to await with a request
+ * id as a parameter will acknowledge as many unacknowledged messages as
+ * there are until said request id.
+ */
+class MpiMessageBuffer
+{
+  public:
+    /* This structure holds the metadata for each MPI message we keep in the
+     * buffer. Note that the message field will be null if the message is
+     * unacknowledged, and point to a valid message otherwise.
+     */
+    class PendingAsyncMpiMessage
+    {
+      public:
+        int requestId = -1;
+        std::shared_ptr<faabric::MPIMessage> msg = nullptr;
+        int sendRank = -1;
+        int recvRank = -1;
+        uint8_t* buffer = nullptr;
+        faabric_datatype_t* dataType = nullptr;
+        int count = -1;
+        faabric::MPIMessage::MPIMessageType messageType =
+          faabric::MPIMessage::NORMAL;
+
+        bool isAcknowledged() { return msg != nullptr; }
+
+        void acknowledge(std::shared_ptr<faabric::MPIMessage> msgIn)
+        {
+            msg = msgIn;
+        }
+    };
+
+    /* Interface to query the buffer size */
+
+    bool isEmpty();
+
+    int size();
+
+    /* Interface to add and delete messages to the buffer */
+
+    void addMessage(PendingAsyncMpiMessage msg);
+
+    void deleteMessage(
+      const std::list<PendingAsyncMpiMessage>::iterator& msgIt);
+
+    /* Interface to get a pointer to a message in the MMB */
+
+    // Pointer to a message given its request id
+    std::list<PendingAsyncMpiMessage>::iterator getRequestPendingMsg(
+      int requestId);
+
+    // Pointer to the first null-pointing (unacknowledged) message
+    std::list<PendingAsyncMpiMessage>::iterator getFirstNullMsg();
+
+    /* Interface to ask for the number of unacknowledged messages */
+
+    // Unacknowledged messages until an iterator (used in await)
+    int getTotalUnackedMessagesUntil(
+      const std::list<PendingAsyncMpiMessage>::iterator& msgIt);
+
+    // Unacknowledged messages in the whole buffer (used in recv)
+    int getTotalUnackedMessages();
+
+  private:
+    std::list<PendingAsyncMpiMessage> pendingMsgs;
+
+    std::list<PendingAsyncMpiMessage>::iterator getFirstNullMsgUntil(
+      const std::list<PendingAsyncMpiMessage>::iterator& msgIt);
+};
+}
diff --git a/include/faabric/scheduler/MpiThreadPool.h b/include/faabric/scheduler/MpiThreadPool.h
deleted file mode 100644
index 4483c769d..000000000
--- a/include/faabric/scheduler/MpiThreadPool.h
+++ /dev/null
@@ -1,40 +0,0 @@
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include
-
-#define QUEUE_SHUTDOWN -1
-
-namespace faabric::scheduler {
-typedef std::tuple<int, std::function<void(void)>, std::promise<void>>
-  ReqQueueType;
-typedef faabric::util::Queue<ReqQueueType> MpiReqQueue;
-
-class MpiAsyncThreadPool
-{
-  public:
-    explicit MpiAsyncThreadPool(int nThreads);
-
-    void shutdown();
-
-    int size;
-
-    std::shared_ptr<MpiReqQueue> getMpiReqQueue();
-
-  private:
-    std::vector<std::thread> threadPool;
-    std::atomic<bool> isShutdown;
-
-    std::shared_ptr<MpiReqQueue> localReqQueue;
-
-    void entrypoint(int i);
-};
-}
diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h
index dfe07c99e..fc91d4bbb 100644
--- a/include/faabric/scheduler/MpiWorld.h
+++ b/include/faabric/scheduler/MpiWorld.h
@@ -3,15 +3,13 @@
 #include
 #include
-#include
 #include
-#include
+#include
 #include
 #include
 #include
 #include
-#include
 
 namespace faabric::scheduler {
 typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
@@ -29,7 +27,7 @@ class MpiWorld
 
     std::string getHostForRank(int rank);
 
-    void setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg);
+    void setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg);
 
     std::string getUser();
 
@@ -41,8 +39,6 @@ class MpiWorld
 
     void destroy();
 
-    void shutdownThreadPool();
-
     void getCartesianRank(int rank,
                          int maxDims,
                          const int* dims,
@@ -196,9 +192,6 @@ class MpiWorld
     std::string user;
     std::string function;
 
-    std::shared_ptr<MpiAsyncThreadPool> threadPool;
-    int getMpiThreadPoolSize();
-
     std::vector<int> cartProcsPerDim;
 
     /* MPI internal messaging layer */
@@ -212,8 +205,11 @@ class MpiWorld
     void initLocalQueues();
 
     // Rank-to-rank sockets for remote messaging
-    void initRemoteMpiEndpoint(int sendRank, int recvRank);
-    int getMpiPort(int sendRank, int recvRank);
+    std::vector<int> basePorts;
+    std::vector<int> initLocalBasePorts(
+      const std::vector<std::string>& executedAt);
+    void initRemoteMpiEndpoint(int localRank, int remoteRank);
+    std::pair<int, int> getPortForRanks(int localRank, int remoteRank);
     void sendRemoteMpiMessage(int sendRank,
                              int recvRank,
                              const std::shared_ptr<faabric::MPIMessage>& msg);
@@ -221,6 +217,24 @@ class MpiWorld
                                                               int recvRank);
     void closeMpiMessageEndpoints();
 
+    // Support for asynchronous communications
+    std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
+                                                              int recvRank);
+    std::shared_ptr<faabric::MPIMessage> recvBatchReturnLast(int sendRank,
+                                                             int recvRank,
+                                                             int batchSize = 0);
+
+    /* Helper methods */
+    void checkRanksRange(int sendRank, int recvRank);
+
+    // Abstraction of the bulk of the recv work, shared among various functions
+    void doRecv(std::shared_ptr<faabric::MPIMessage> m,
+                uint8_t* buffer,
+                faabric_datatype_t* dataType,
+                int count,
+                MPI_Status* status,
+                faabric::MPIMessage::MPIMessageType messageType =
+                  faabric::MPIMessage::NORMAL);
 };
 }
diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h
index 658448455..4e463e07a 100644
--- a/include/faabric/transport/MessageEndpoint.h
+++ b/include/faabric/transport/MessageEndpoint.h
@@ -83,8 +83,6 @@ class RecvMessageEndpoint : public MessageEndpoint
   public:
    RecvMessageEndpoint(int portIn);
 
-    RecvMessageEndpoint(int portIn, const std::string& overrideHost);
-
    void open(MessageContext& context);
 
    void close();
diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h
index 40067fc6b..991331bc5 100644
--- a/include/faabric/transport/MpiMessageEndpoint.h
+++ b/include/faabric/transport/MpiMessageEndpoint.h
@@ -14,8 +14,8 @@ faabric::MpiHostsToRanksMessage recvMpiHostRankMsg();
 void sendMpiHostRankMsg(const std::string& hostIn,
                        const faabric::MpiHostsToRanksMessage msg);
 
-/* This class abstracts the notion of a communication channel between two MPI
- * ranks. There will always be one rank local to this host, and one remote.
+/* This class abstracts the notion of a communication channel between two remote
+ * MPI ranks. There will always be one rank local to this host, and one remote.
 * Note that the port is unique per (user, function, sendRank, recvRank) tuple.
 */
 class MpiMessageEndpoint
@@ -23,9 +23,7 @@ class MpiMessageEndpoint
   public:
    MpiMessageEndpoint(const std::string& hostIn, int portIn);
 
-    MpiMessageEndpoint(const std::string& hostIn,
-                       int portIn,
-                       const std::string& overrideRecvHost);
+    MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort);
 
    void sendMpiMessage(const std::shared_ptr<faabric::MPIMessage>& msg);
 
diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto
index 80103ca1d..a44a4668c 100644
--- a/src/proto/faabric.proto
+++ b/src/proto/faabric.proto
@@ -88,6 +88,7 @@ message MPIMessage {
 // fields.
 message MpiHostsToRanksMessage {
    repeated string hosts = 1;
+    repeated int32 basePorts = 2;
 }
 
 message Message {
diff --git a/src/scheduler/CMakeLists.txt b/src/scheduler/CMakeLists.txt
index ce4cb75d1..4e8c271ad 100644
--- a/src/scheduler/CMakeLists.txt
+++ b/src/scheduler/CMakeLists.txt
@@ -10,7 +10,7 @@ set(LIB_FILES
    SnapshotServer.cpp
    SnapshotClient.cpp
    MpiContext.cpp
-    MpiThreadPool.cpp
+    MpiMessageBuffer.cpp
    MpiWorldRegistry.cpp
    MpiWorld.cpp
    ${HEADERS}
diff --git a/src/scheduler/MpiMessageBuffer.cpp b/src/scheduler/MpiMessageBuffer.cpp
new file mode 100644
index 000000000..8b2445e40
--- /dev/null
+++ b/src/scheduler/MpiMessageBuffer.cpp
@@ -0,0 +1,73 @@
+#include
+#include
+
+namespace faabric::scheduler {
+typedef std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator
+  MpiMessageIterator;
+
+bool MpiMessageBuffer::isEmpty()
+{
+    return pendingMsgs.empty();
+}
+
+int MpiMessageBuffer::size()
+{
+    return pendingMsgs.size();
+}
+
+void MpiMessageBuffer::addMessage(PendingAsyncMpiMessage msg)
+{
+    pendingMsgs.push_back(msg);
+}
+
+void MpiMessageBuffer::deleteMessage(const MpiMessageIterator& msgIt)
+{
+    pendingMsgs.erase(msgIt);
+}
+
+MpiMessageIterator MpiMessageBuffer::getRequestPendingMsg(int requestId)
+{
+    // The request id must be in the MMB, as an irecv must happen before an
+    // await
+    MpiMessageIterator msgIt =
+      std::find_if(pendingMsgs.begin(),
+                   pendingMsgs.end(),
+                   [requestId](PendingAsyncMpiMessage pendingMsg) {
+                       return pendingMsg.requestId == requestId;
+                   });
+
+    // If it's not there, error out
+    if (msgIt == pendingMsgs.end()) {
+        SPDLOG_ERROR("Asynchronous request id not in buffer: {}", requestId);
+        throw std::runtime_error("Async request not in buffer");
+    }
+
+    return msgIt;
+}
+
+MpiMessageIterator MpiMessageBuffer::getFirstNullMsgUntil(
+  const MpiMessageIterator& msgItEnd)
+{
+    return std::find_if(
+      pendingMsgs.begin(), msgItEnd, [](PendingAsyncMpiMessage pendingMsg) {
+          return pendingMsg.msg == nullptr;
+      });
+}
+
+MpiMessageIterator MpiMessageBuffer::getFirstNullMsg()
+{
+    return getFirstNullMsgUntil(pendingMsgs.end());
+}
+
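+// Note: messages are acknowledged in FIFO order, so acknowledged entries
+// always precede unacknowledged (null) ones and the counts below only need
+// the distance from the first null message. For example, after three irecv
+// calls and an await on the second request id, the first entry is
+// acknowledged, the awaited one is removed, and one null entry remains.
+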
+int MpiMessageBuffer::getTotalUnackedMessagesUntil(
+  const MpiMessageIterator& msgItEnd)
+{
+    MpiMessageIterator firstNull = getFirstNullMsgUntil(msgItEnd);
+    return std::distance(firstNull, msgItEnd);
+}
+
+int MpiMessageBuffer::getTotalUnackedMessages()
+{
+    MpiMessageIterator firstNull = getFirstNullMsg();
+    return std::distance(firstNull, pendingMsgs.end());
+}
+}
diff --git a/src/scheduler/MpiThreadPool.cpp b/src/scheduler/MpiThreadPool.cpp
deleted file mode 100644
index 6f6f088db..000000000
--- a/src/scheduler/MpiThreadPool.cpp
+++ /dev/null
@@ -1,67 +0,0 @@
-#include
-#include
-
-namespace faabric::scheduler {
-MpiAsyncThreadPool::MpiAsyncThreadPool(int nThreads)
-  : size(nThreads)
-  , isShutdown(false)
-{
-    SPDLOG_DEBUG("Starting an MpiAsyncThreadPool of size {}", nThreads);
-
-    // Initialize async. req queue
-    localReqQueue = std::make_shared<MpiReqQueue>();
-
-    // Initialize thread pool
-    for (int i = 0; i < nThreads; ++i) {
-        threadPool.emplace_back(
-          std::bind(&MpiAsyncThreadPool::entrypoint, this, i));
-    }
-}
-
-void MpiAsyncThreadPool::shutdown()
-{
-    SPDLOG_DEBUG("Shutting down MpiAsyncThreadPool");
-
-    for (auto& thread : threadPool) {
-        if (thread.joinable()) {
-            thread.join();
-        }
-    }
-}
-
-std::shared_ptr<MpiReqQueue> MpiAsyncThreadPool::getMpiReqQueue()
-{
-    return this->localReqQueue;
-}
-
-void MpiAsyncThreadPool::entrypoint(int i)
-{
-    faabric::scheduler::ReqQueueType req;
-
-    while (!this->isShutdown) {
-        req = getMpiReqQueue()->dequeue();
-
-        int id = std::get<0>(req);
-        std::function<void(void)> func = std::get<1>(req);
-        std::promise<void> promise = std::move(std::get<2>(req));
-
-        // Detect shutdown condition
-        if (id == QUEUE_SHUTDOWN) {
-            // The shutdown tuple includes a TLS cleanup function that we run
-            // _once per thread_ and exit
-            func();
-            if (!this->isShutdown) {
-                this->isShutdown = true;
-            }
-            SPDLOG_TRACE("Mpi thread {}/{} shut down", i + 1, size);
-            break;
-        }
-
-        // Do the job without holding any locks
-        func();
-
-        // Notify we are done via the future
-        promise.set_value();
-    }
-}
-}
diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp
index 077de1deb..5f9e7bfc1 100644
--- a/src/scheduler/MpiWorld.cpp
+++ b/src/scheduler/MpiWorld.cpp
@@ -5,10 +5,17 @@
 #include
 #include
 
-static thread_local std::unordered_map<int, std::future<void>> futureMap;
+/* Each MPI rank runs in a separate thread, thus we use TLS to maintain the
+ * per-rank data structures.
+ */
 static thread_local std::vector<
  std::unique_ptr<faabric::transport::MpiMessageEndpoint>>
  mpiMessageEndpoints;
+static thread_local std::vector<
+  std::shared_ptr<faabric::scheduler::MpiMessageBuffer>>
+  unackedMessageBuffers;
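+// Both the endpoint and the buffer vectors are lazily sized to
+// worldSize * worldSize entries, one per rank pair, and are indexed with
+// getIndexForRanks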
+static thread_local std::set<int> iSendRequests;
+static thread_local std::map<int, std::pair<int, int>> reqIdToRanks;
 
 namespace faabric::scheduler {
 MpiWorld::MpiWorld()
@@ -19,8 +26,12 @@ MpiWorld::MpiWorld()
  , cartProcsPerDim(2)
 {}
 
-void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank)
+void MpiWorld::initRemoteMpiEndpoint(int localRank, int remoteRank)
 {
+    SPDLOG_TRACE("Open MPI endpoint between ranks (local-remote) {} - {}",
+                 localRank,
+                 remoteRank);
+
    // Resize the message endpoint vector and initialise to null. Note that we
    // allocate size x size slots to cover all possible (sendRank, recvRank)
    // pairs
@@ -30,42 +41,20 @@ void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank)
        }
    }
 
-    // Get host for recv rank
-    std::string otherHost;
-    std::string recvHost = getHostForRank(recvRank);
-    std::string sendHost = getHostForRank(sendRank);
-    if (recvHost == sendHost) {
-        SPDLOG_ERROR(
-          "Send and recv ranks in the same host: SEND {}, RECV{} in {}",
-          sendRank,
-          recvRank,
-          sendHost);
-        throw std::runtime_error("Send and recv ranks in the same host");
-    } else if (recvHost == thisHost) {
-        otherHost = sendHost;
-    } else if (sendHost == thisHost) {
-        otherHost = recvHost;
-    } else {
-        SPDLOG_ERROR("Send and recv ranks correspond to remote hosts: SEND {} "
-                     "in {}, RECV {} in {}",
-                     sendRank,
-                     sendHost,
-                     recvRank,
-                     recvHost);
-        throw std::runtime_error("Send and recv ranks in remote hosts");
-    }
+    // Get host for remote rank
+    std::string otherHost = getHostForRank(remoteRank);
 
    // Get the index for the rank-host pair
-    int index = getIndexForRanks(sendRank, recvRank);
+    int index = getIndexForRanks(localRank, remoteRank);
 
    // Get port for send-recv pair
-    int port = getMpiPort(sendRank, recvRank);
+    std::pair<int, int> sendRecvPorts = getPortForRanks(localRank, remoteRank);
 
    // Create MPI message endpoint
    mpiMessageEndpoints.emplace(
      mpiMessageEndpoints.begin() + index,
      std::make_unique<faabric::transport::MpiMessageEndpoint>(
-        otherHost, port, thisHost));
+        otherHost, sendRecvPorts.first, sendRecvPorts.second));
 }
 
 void MpiWorld::sendRemoteMpiMessage(
@@ -74,6 +63,8 @@ void MpiWorld::sendRemoteMpiMessage(
  const std::shared_ptr<faabric::MPIMessage>& msg)
 {
    // Get the index for the rank-host pair
+    // Note - message endpoints are identified by a (localRank, remoteRank)
+    // pair, not a (sendRank, recvRank) one
    int index = getIndexForRanks(sendRank, recvRank);
 
    if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) {
@@ -88,34 +79,38 @@ std::shared_ptr<faabric::MPIMessage> MpiWorld::recvRemoteMpiMessage(
  int recvRank)
 {
    // Get the index for the rank-host pair
-    int index = getIndexForRanks(sendRank, recvRank);
+    // Note - message endpoints are identified by a (localRank, remoteRank)
+    // pair, not a (sendRank, recvRank) one
+    int index = getIndexForRanks(recvRank, sendRank);
 
    if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) {
-        initRemoteMpiEndpoint(sendRank, recvRank);
+        initRemoteMpiEndpoint(recvRank, sendRank);
    }
 
    return mpiMessageEndpoints[index]->recvMpiMessage();
 }
 
-int MpiWorld::getMpiThreadPoolSize()
+std::shared_ptr<faabric::scheduler::MpiMessageBuffer>
+MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank)
 {
-    int usableCores = faabric::util::getUsableCores();
-    int worldSize = size;
+    // We want to lazily initialise this data structure because, given its
+    // thread local nature, we expect it to be quite sparse (i.e. filled with
+    // nullptr).
+    if (unackedMessageBuffers.size() == 0) {
+        unackedMessageBuffers.resize(size * size, nullptr);
+    }
+
+    // Get the index for the rank-host pair
+    int index = getIndexForRanks(sendRank, recvRank);
+    assert(index >= 0 && index < size * size);
 
-    if ((worldSize > usableCores) && (worldSize % usableCores != 0)) {
-        SPDLOG_WARN("Over-provisioning threads in the MPI thread pool.");
-        SPDLOG_WARN("To avoid this, set an MPI world size multiple of the "
-                    "number of cores per machine.");
+    if (unackedMessageBuffers[index] == nullptr) {
+        unackedMessageBuffers.emplace(
+          unackedMessageBuffers.begin() + index,
+          std::make_shared<faabric::scheduler::MpiMessageBuffer>());
    }
-    // Note - adding one to the worldSize to prevent deadlocking in certain
-    // corner-cases.
-    // For instance, if issuing `worldSize` non-blocking recvs, followed by
-    // `worldSize` non-blocking sends, and nothing else, the application will
-    // deadlock as all worker threads will be blocking on `recv` calls. This
-    // scenario is remote, but feasible. We _assume_ that following the same
-    // pattern but doing `worldSize + 1` calls is deliberately malicious, and
-    // we can confidently fail and deadlock.
-    return std::min(worldSize + 1, usableCores);
+
+    return unackedMessageBuffers[index];
 }
 
 void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
@@ -125,8 +120,6 @@
    function = call.function();
    size = newSize;
 
-    threadPool = std::make_shared<MpiAsyncThreadPool>(
-      getMpiThreadPoolSize());
 
    auto& sch = faabric::scheduler::getScheduler();
 
@@ -153,7 +146,14 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
    // Register hosts to rank mappings on this host
    faabric::MpiHostsToRanksMessage hostRankMsg;
    *hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() };
-    setAllRankHosts(hostRankMsg);
+
+    // Prepare the base port for each rank
+    std::vector<int> basePortForRank = initLocalBasePorts(executedAt);
+    *hostRankMsg.mutable_baseports() = { basePortForRank.begin(),
+                                         basePortForRank.end() };
+
+    // Register hosts to rank mappings on this host
+    setAllRankHostsPorts(hostRankMsg);
 
    // Set up a list of hosts to broadcast to (excluding this host)
    std::set<std::string> hosts(executedAt.begin(), executedAt.end());
@@ -170,13 +170,51 @@
 
 void MpiWorld::destroy()
 {
-    // Destroy once per host
-    if (!isDestroyed.test_and_set()) {
-        shutdownThreadPool();
+    // Destroy the rank-specific data structures once per thread
+    // Remote message endpoints
+    if (!mpiMessageEndpoints.empty()) {
+        for (auto& e : mpiMessageEndpoints) {
+            if (e != nullptr) {
+                e->close();
+            }
+        }
+        mpiMessageEndpoints.clear();
+    }
+
+    // Unacked message buffers
+    if (!unackedMessageBuffers.empty()) {
+        for (auto& umb : unackedMessageBuffers) {
+            if (umb != nullptr) {
+                if (!umb->isEmpty()) {
+                    SPDLOG_ERROR("Destroying the MPI world with outstanding {}"
+                                 " messages in the message buffer",
+                                 umb->size());
+                    throw std::runtime_error(
+                      "Destroying world with a non-empty MPI message buffer");
+                }
+            }
+        }
+        unackedMessageBuffers.clear();
+    }
 
-        // Note - we are deliberately not deleting the KV in the global state
-        // TODO - find a way to do this only from the master client
+    // Request to rank map should be empty
+    if (!reqIdToRanks.empty()) {
+        SPDLOG_ERROR(
+          "Destroying the MPI world with {} outstanding irecv requests",
+          reqIdToRanks.size());
+        throw std::runtime_error("Destroying world with outstanding requests");
+    }
+
+    // iSend set should be empty
+    if (!iSendRequests.empty()) {
+        SPDLOG_ERROR(
+          "Destroying the MPI world with {} outstanding isend requests",
+          iSendRequests.size());
+        throw std::runtime_error("Destroying world with outstanding requests");
+    }
+
+    // Destroy the shared resources once per host
+    if (!isDestroyed.test_and_set()) {
        // Wait (forever) until all ranks are done consuming their queues to
        // clear them.
        // Note - this means that an application with outstanding messages, i.e.
@@ -190,40 +228,6 @@ void MpiWorld::destroy()
    }
 }
 
-void MpiWorld::shutdownThreadPool()
-{
-    // When shutting down the thread pool, we also make sure we clean all thread
-    // local state by sending a clear message to the queue. Currently, we only
-    // need to close the function call clients
-    for (int i = 0; i < threadPool->size; i++) {
-        std::promise<void> p;
-        threadPool->getMpiReqQueue()->enqueue(
-          std::make_tuple(QUEUE_SHUTDOWN,
-                          std::bind(&MpiWorld::closeMpiMessageEndpoints, this),
-                          std::move(p)));
-    }
-
-    threadPool->shutdown();
-
-    // Lastly clean the main thread as well
-    closeMpiMessageEndpoints();
-}
-
-// TODO - remove
-// Clear thread local state
-void MpiWorld::closeMpiMessageEndpoints()
-{
-    if (mpiMessageEndpoints.size() > 0) {
-        // Close all open sockets
-        for (auto& e : mpiMessageEndpoints) {
-            if (e != nullptr) {
-                e->close();
-            }
-        }
-        mpiMessageEndpoints.clear();
-    }
-}
-
 void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal)
 {
    id = msg.mpiworldid();
@@ -231,9 +235,6 @@ void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal)
    function = msg.function();
    size = msg.mpiworldsize();
 
-    threadPool = std::make_shared<MpiAsyncThreadPool>(
-      getMpiThreadPoolSize());
-
    // Sometimes for testing purposes we may want to initialise a world in the
    // _same_ host we have created one (note that this would never happen in
    // reality). If so, we skip initialising resources already initialised
 
    // Block until we receive
    faabric::MpiHostsToRanksMessage hostRankMsg =
      faabric::transport::recvMpiHostRankMsg();
-    setAllRankHosts(hostRankMsg);
+    setAllRankHostsPorts(hostRankMsg);
 
    // Initialise the memory queues for message reception
    initLocalQueues();
@@ -261,16 +262,51 @@ std::string MpiWorld::getHostForRank(int rank)
 
    return host;
 }
 
+// Returns a pair (sendPort, recvPort)
+// To assign the send and recv ports, we follow a protocol establishing:
+// 1) Port range (offset) corresponding to the world that receives
+// 2) Within a world's port range, port corresponding to the outcome of
+//    getIndexForRanks(localRank, remoteRank), where local and remote are
+//    relative to the world whose port range we are in
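+// For example, for a world of size 4 with ranks 0-1 on the master host A and
+// ranks 2-3 on host B, initLocalBasePorts gives A the base port MPI_PORT and
+// B the base port MPI_PORT + 16. Rank 0 talking to rank 2 then gets
+// sendPort = B's base + getIndexForRanks(2, 0) = MPI_PORT + 24, and
+// recvPort = A's base + getIndexForRanks(0, 2) = MPI_PORT + 2; rank 2
+// derives the mirrored pair, so each side sends where the other receives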
+std::pair<int, int> MpiWorld::getPortForRanks(int localRank, int remoteRank)
+{
+    std::pair<int, int> sendRecvPortPair;
+
+    // Get base port for local and remote worlds
+    int localBasePort = basePorts[localRank];
+    int remoteBasePort = basePorts[remoteRank];
+    assert(localBasePort != remoteBasePort);
+
+    // Assign send port
+    // 1) Port range corresponding to remote world, as they are receiving
+    // 2) Index switching localRank and remoteRank, as remote rank is "local"
+    //    to the remote world
+    sendRecvPortPair.first =
+      remoteBasePort + getIndexForRanks(remoteRank, localRank);
+
+    // Assign recv port
+    // 1) Port range corresponding to our world, as we are the ones receiving
+    // 2) Port using our local rank as `localRank`, as we are indexing within
+    //    the local port range
+    sendRecvPortPair.second =
+      localBasePort + getIndexForRanks(localRank, remoteRank);
+
+    return sendRecvPortPair;
+}
+
 // Prepare the host-rank map with a vector containing _all_ ranks
 // Note - this method should be called by only one rank. This is enforced in
 // the world registry
-void MpiWorld::setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg)
+void MpiWorld::setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg)
 {
    // Assert we are only setting the values once
    assert(rankHosts.size() == 0);
+    assert(basePorts.size() == 0);
 
    assert(msg.hosts().size() == size);
+    assert(msg.baseports().size() == size);
    rankHosts = { msg.hosts().begin(), msg.hosts().end() };
+    basePorts = { msg.baseports().begin(), msg.baseports().end() };
 }
 
 void MpiWorld::getCartesianRank(int rank,
@@ -394,6 +430,10 @@ void MpiWorld::shiftCartesianCoords(int rank,
      getRankFromCoords(source, dispCoordsBwd.data());
 }
 
+// Sending is already asynchronous in both transport layers we use: in-memory
+// queues for local messages, and ZeroMQ sockets for remote messages. Thus,
+// we can just send normally and return a requestId. Upon await, we'll return
+// immediately.
 int MpiWorld::isend(int sendRank,
                    int recvRank,
                    const uint8_t* buffer,
@@ -402,23 +442,9 @@ int MpiWorld::isend(int sendRank,
                    faabric::MPIMessage::MPIMessageType messageType)
 {
    int requestId = (int)faabric::util::generateGid();
+    iSendRequests.insert(requestId);
 
-    std::promise<void> resultPromise;
-    std::future<void> resultFuture = resultPromise.get_future();
-    threadPool->getMpiReqQueue()->enqueue(
-      std::make_tuple(requestId,
-                      std::bind(&MpiWorld::send,
-                                this,
-                                sendRank,
-                                recvRank,
-                                buffer,
-                                dataType,
-                                count,
-                                messageType),
-                      std::move(resultPromise)));
-
-    // Place the promise in a map to wait for it later
-    futureMap.emplace(std::make_pair(requestId, std::move(resultFuture)));
+    send(sendRank, recvRank, buffer, dataType, count, messageType);
 
    return requestId;
 }
@@ -431,37 +457,25 @@ int MpiWorld::irecv(int sendRank,
                    faabric::MPIMessage::MPIMessageType messageType)
 {
    int requestId = (int)faabric::util::generateGid();
-
-    std::promise<void> resultPromise;
-    std::future<void> resultFuture = resultPromise.get_future();
-    threadPool->getMpiReqQueue()->enqueue(
-      std::make_tuple(requestId,
-                      std::bind(&MpiWorld::recv,
-                                this,
-                                sendRank,
-                                recvRank,
-                                buffer,
-                                dataType,
-                                count,
-                                nullptr,
-                                messageType),
-                      std::move(resultPromise)));
-
-    // Place the promise in a map to wait for it later
-    futureMap.emplace(std::make_pair(requestId, std::move(resultFuture)));
+    reqIdToRanks.try_emplace(requestId, sendRank, recvRank);
+
+    // Enqueue an unacknowledged request (no message)
+    faabric::scheduler::MpiMessageBuffer::PendingAsyncMpiMessage pendingMsg;
+    pendingMsg.requestId = requestId;
+    pendingMsg.sendRank = sendRank;
+    pendingMsg.recvRank = recvRank;
+    pendingMsg.buffer = buffer;
+    pendingMsg.dataType = dataType;
+    pendingMsg.count = count;
+    pendingMsg.messageType = messageType;
+    assert(!pendingMsg.isAcknowledged());
+
+    auto umb = getUnackedMessageBuffer(sendRank, recvRank);
+    umb->addMessage(pendingMsg);
 
    return requestId;
 }
 
-int MpiWorld::getMpiPort(int sendRank, int recvRank)
-{
-    // TODO - get port in a multi-tenant-safe manner
-    int basePort = MPI_PORT;
-    int rankOffset = sendRank * size + recvRank;
-
-    return basePort + rankOffset;
-}
-
 void MpiWorld::send(int sendRank,
                    int recvRank,
                    const uint8_t* buffer,
@@ -519,27 +533,22 @@ void MpiWorld::recv(int sendRank,
 {
    // Sanity-check input parameters
    checkRanksRange(sendRank, recvRank);
-    if (getHostForRank(recvRank) != thisHost) {
-        SPDLOG_ERROR("Trying to recv message into a non-local rank: {}",
-                     recvRank);
-        throw std::runtime_error("Receiving message into non-local rank");
-    }
 
-    // Work out whether the message is sent locally or from another host
-    const std::string otherHost = getHostForRank(sendRank);
-    bool isLocal = otherHost == thisHost;
+    // Recv message from underlying transport
+    std::shared_ptr<faabric::MPIMessage> m =
+      recvBatchReturnLast(sendRank, recvRank);
 
-    // Recv message
-    std::shared_ptr<faabric::MPIMessage> m;
-    if (isLocal) {
-        SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank);
-        m = getLocalQueue(sendRank, recvRank)->dequeue();
-    } else {
-        SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank);
-        m = recvRemoteMpiMessage(sendRank, recvRank);
-    }
-    assert(m != nullptr);
+    // Do the processing
+    doRecv(m, buffer, dataType, count, status, messageType);
+}
 
+void MpiWorld::doRecv(std::shared_ptr<faabric::MPIMessage> m,
+                      uint8_t* buffer,
+                      faabric_datatype_t* dataType,
+                      int count,
+                      MPI_Status* status,
+                      faabric::MPIMessage::MPIMessageType messageType)
+{
    // Assert message integrity
    // Note - this checks won't happen in Release builds
    assert(m->messagetype() == messageType);
@@ -806,17 +815,51 @@ void MpiWorld::awaitAsyncRequest(int requestId)
 {
    SPDLOG_TRACE("MPI - await {}", requestId);
 
-    auto it = futureMap.find(requestId);
-    if (it == futureMap.end()) {
-        throw std::runtime_error(
-          fmt::format("Error: waiting for unrecognized request {}", requestId));
+    auto iSendIt = iSendRequests.find(requestId);
+    if (iSendIt != iSendRequests.end()) {
+        iSendRequests.erase(iSendIt);
+        return;
+    }
+
+    // Get the corresponding send and recv ranks
+    auto it = reqIdToRanks.find(requestId);
+    // If the request id is not in the map, the application either has issued an
+    // await without a previous isend/irecv, or the actual request id
+    // has been corrupted. In any case, we error out.
+    if (it == reqIdToRanks.end()) {
+        SPDLOG_ERROR("Asynchronous request id not recognized: {}", requestId);
+        throw std::runtime_error("Unrecognized async request id");
+    }
+    int sendRank = it->second.first;
+    int recvRank = it->second.second;
+    reqIdToRanks.erase(it);
+
+    std::shared_ptr<faabric::scheduler::MpiMessageBuffer> umb =
+      getUnackedMessageBuffer(sendRank, recvRank);
+
+    std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator msgIt =
+      umb->getRequestPendingMsg(requestId);
+
+    std::shared_ptr<faabric::MPIMessage> m;
+    if (msgIt->msg != nullptr) {
+        // This id has already been acknowledged by a recv call, so do the recv
+        m = msgIt->msg;
+    } else {
+        // We need to acknowledge all messages not acknowledged from the
+        // beginning until us
+        m = recvBatchReturnLast(
+          sendRank, recvRank, umb->getTotalUnackedMessagesUntil(msgIt) + 1);
    }
 
-    // This call blocks until requestId has finished.
-    it->second.wait();
-    futureMap.erase(it);
+    doRecv(m,
+           msgIt->buffer,
+           msgIt->dataType,
+           msgIt->count,
+           MPI_STATUS_IGNORE,
+           msgIt->messageType);
 
-    SPDLOG_DEBUG("Finished awaitAsyncRequest on {}", requestId);
+    // Remove the acknowledged message from the UMB
+    umb->deleteMessage(msgIt);
 }
 
 void MpiWorld::reduce(int sendRank,
@@ -1180,6 +1223,93 @@ void MpiWorld::initLocalQueues()
    }
 }
 
+// Here we rely on the scheduler returning a list of hosts where equal
+// hosts are always contiguous with the exception of the master host
+// (thisHost) which may appear repeated at the end if the system is
+// overloaded.
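+// For example, with size = 4 and executedAt = {A, A, B, B} where A is
+// thisHost, the assigned base ports are {MPI_PORT, MPI_PORT, MPI_PORT + 16,
+// MPI_PORT + 16}, i.e. each host gets a disjoint range of size * size ports.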
+std::vector<int> MpiWorld::initLocalBasePorts(
+  const std::vector<std::string>& executedAt)
+{
+    std::vector<int> basePortForRank;
+    basePortForRank.reserve(size);
+
+    std::string lastHost = thisHost;
+    int lastPort = MPI_PORT;
+    for (const auto& host : executedAt) {
+        if (host == thisHost) {
+            basePortForRank.push_back(MPI_PORT);
+        } else if (host == lastHost) {
+            basePortForRank.push_back(lastPort);
+        } else {
+            lastHost = host;
+            lastPort += size * size;
+            basePortForRank.push_back(lastPort);
+        }
+    }
+
+    assert(basePortForRank.size() == size);
+    return basePortForRank;
+}
+
+std::shared_ptr<faabric::MPIMessage>
+MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize)
+{
+    std::shared_ptr<faabric::scheduler::MpiMessageBuffer> umb =
+      getUnackedMessageBuffer(sendRank, recvRank);
+
+    // When calling from recv, we set the batch size to zero and work
+    // out the total here. We want to acknowledge _all_ unacknowledged messages
+    // _and then_ receive ours (which is not in the MMB).
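+    // For example, with two outstanding irecvs, a blocking recv uses a batch
+    // size of three: it dequeues and acknowledges the two pending messages,
+    // and then returns the third (its own) to the caller.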
+    if (batchSize == 0) {
+        batchSize = umb->getTotalUnackedMessages() + 1;
+    }
+
+    // Work out whether the message is sent locally or from another host
+    assert(thisHost == getHostForRank(recvRank));
+    const std::string otherHost = getHostForRank(sendRank);
+    bool isLocal = otherHost == thisHost;
+
+    // Recv message: first we receive all messages for which there is an id
+    // in the unacknowledged buffer but no msg. Note that these messages
+    // (batchSize - 1) were `irecv`-ed before ours.
+    std::shared_ptr<faabric::MPIMessage> ourMsg;
+    auto msgIt = umb->getFirstNullMsg();
+    if (isLocal) {
+        // First receive messages that happened before us
+        for (int i = 0; i < batchSize - 1; i++) {
+            SPDLOG_TRACE("MPI - pending recv {} -> {}", sendRank, recvRank);
+            auto pendingMsg = getLocalQueue(sendRank, recvRank)->dequeue();
+
+            // Put the unacked message in the UMB
+            assert(!msgIt->isAcknowledged());
+            msgIt->acknowledge(pendingMsg);
+            msgIt++;
+        }
+
+        // Finally receive the message corresponding to us
+        SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank);
+        ourMsg = getLocalQueue(sendRank, recvRank)->dequeue();
+    } else {
+        // First receive messages that happened before us
+        for (int i = 0; i < batchSize - 1; i++) {
+            SPDLOG_TRACE(
+              "MPI - pending remote recv {} -> {}", sendRank, recvRank);
+            auto pendingMsg = recvRemoteMpiMessage(sendRank, recvRank);
+
+            // Put the unacked message in the UMB
+            assert(!msgIt->isAcknowledged());
+            msgIt->acknowledge(pendingMsg);
+            msgIt++;
+        }
+
+        // Finally receive the message corresponding to us
+        SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank);
+        ourMsg = recvRemoteMpiMessage(sendRank, recvRank);
+    }
+
+    return ourMsg;
+}
+
 int MpiWorld::getIndexForRanks(int sendRank, int recvRank)
 {
    int index = sendRank * size + recvRank;
diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp
index 5f87412b3..1cb71a871 100644
--- a/src/transport/MessageEndpoint.cpp
+++ b/src/transport/MessageEndpoint.cpp
@@ -279,11 +279,6 @@ RecvMessageEndpoint::RecvMessageEndpoint(int portIn)
  : MessageEndpoint(ANY_HOST, portIn)
 {}
 
-RecvMessageEndpoint::RecvMessageEndpoint(int portIn,
-                                         const std::string& overrideHost)
-  : MessageEndpoint(overrideHost, portIn)
-{}
-
 void RecvMessageEndpoint::open(MessageContext& context)
 {
    SPDLOG_TRACE(
diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp
index 4aebbefcd..67be90801 100644
--- a/src/transport/MpiMessageEndpoint.cpp
+++ b/src/transport/MpiMessageEndpoint.cpp
@@ -37,10 +37,10 @@ MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn)
 }
 
 MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn,
-                                       int portIn,
-                                       const std::string& overrideRecvHost)
-  : sendMessageEndpoint(hostIn, portIn)
-  , recvMessageEndpoint(portIn, overrideRecvHost)
+                                       int sendPort,
+                                       int recvPort)
+  : sendMessageEndpoint(hostIn, sendPort)
+  , recvMessageEndpoint(recvPort)
 {
    sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
    recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
diff --git a/tests/test/scheduler/test_mpi_message_buffer.cpp b/tests/test/scheduler/test_mpi_message_buffer.cpp
new file mode 100644
index 000000000..40a845c0e
--- /dev/null
+++ b/tests/test/scheduler/test_mpi_message_buffer.cpp
@@ -0,0 +1,118 @@
+#include
+
+#include
+#include
+#include
+
+using namespace faabric::scheduler;
+
+MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments(
+  bool nullMsg = true,
+  int overrideRequestId = -1)
+{
+    int requestId;
+    if (overrideRequestId != -1) {
+        requestId = overrideRequestId;
+    } else {
+        requestId = static_cast<int>(faabric::util::generateGid());
+    }
+
+    MpiMessageBuffer::PendingAsyncMpiMessage pendingMsg;
+    pendingMsg.requestId = requestId;
+
+    if (!nullMsg) {
+        pendingMsg.msg = std::make_shared<faabric::MPIMessage>();
+    }
+
+    return pendingMsg;
+}
+
+namespace tests {
+TEST_CASE("Test adding message to message buffer", "[mpi]")
+{
+    MpiMessageBuffer mmb;
+    REQUIRE(mmb.isEmpty());
+
+    REQUIRE_NOTHROW(mmb.addMessage(genRandomArguments()));
+    REQUIRE(mmb.size() == 1);
+}
+
+TEST_CASE("Test deleting message from message buffer", "[mpi]")
+{
+    MpiMessageBuffer mmb;
+    REQUIRE(mmb.isEmpty());
+
+    mmb.addMessage(genRandomArguments());
+    REQUIRE(mmb.size() == 1);
+
+    auto it = mmb.getFirstNullMsg();
+    REQUIRE_NOTHROW(mmb.deleteMessage(it));
+
+    REQUIRE(mmb.isEmpty());
+}
+
+TEST_CASE("Test getting an iterator from a request id", "[mpi]")
+{
+    MpiMessageBuffer mmb;
+
+    int requestId = 1337;
+    mmb.addMessage(genRandomArguments(true, requestId));
+
+    auto it = mmb.getRequestPendingMsg(requestId);
+    REQUIRE(it->requestId == requestId);
+}
+
+TEST_CASE("Test getting first null message", "[mpi]")
+{
+    MpiMessageBuffer mmb;
+
+    // Add first a non-null message
+    int requestId1 = 1;
+    mmb.addMessage(genRandomArguments(false, requestId1));
+
+    // Then add a null message
+    int requestId2 = 2;
+    mmb.addMessage(genRandomArguments(true, requestId2));
+
+    // Query for the first null message
+    auto it = mmb.getFirstNullMsg();
+    REQUIRE(it->requestId == requestId2);
+}
+
+TEST_CASE("Test getting total unacked messages in message buffer", "[mpi]")
+{
+    MpiMessageBuffer mmb;
+
+    REQUIRE(mmb.getTotalUnackedMessages() == 0);
+
+    // Add a non-null message
+    mmb.addMessage(genRandomArguments(false));
+
+    // Then a couple of null messages
+    mmb.addMessage(genRandomArguments(true));
+    mmb.addMessage(genRandomArguments(true));
+
+    // Check that we have two unacked messages
+    REQUIRE(mmb.getTotalUnackedMessages() == 2);
+}
+
+TEST_CASE("Test getting total unacked messages in message buffer range",
+          "[mpi]")
+{
+    MpiMessageBuffer mmb;
+
+    // Add a non-null message
+    mmb.addMessage(genRandomArguments(false));
+
+    // Then a couple of null messages
+    int requestId = 1337;
+    mmb.addMessage(genRandomArguments(true));
+    mmb.addMessage(genRandomArguments(true, requestId));
+
+    // Get an iterator to our second null message
+    auto it = mmb.getRequestPendingMsg(requestId);
+
+    // Check that we have only one unacked message until the iterator
+    REQUIRE(mmb.getTotalUnackedMessagesUntil(it) == 1);
+}
+}
diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp
index 3d626e10e..a4a3b0391 100644
--- a/tests/test/scheduler/test_mpi_world.cpp
+++ b/tests/test/scheduler/test_mpi_world.cpp
@@ -1202,4 +1202,61 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test all-to-all", "[mpi]")
 
    world.destroy();
 }
+
+TEST_CASE_METHOD(MpiTestFixture,
+                 "Test can't destroy world with outstanding requests",
+                 "[mpi]")
+{
+    int rankA = 0;
+    int rankB = 1;
+    int data = 9;
+    int actual = -1;
+
+    SECTION("Outstanding irecv")
+    {
+        world.send(rankA, rankB, BYTES(&data), MPI_INT, 1);
+        int recvId = world.irecv(rankA, rankB, BYTES(&actual), MPI_INT, 1);
+
+        REQUIRE_THROWS(world.destroy());
+
+        world.awaitAsyncRequest(recvId);
+        REQUIRE(actual == data);
+    }
+
+    SECTION("Outstanding acknowledged irecv")
+    {
+        int data2 = 14;
+        int actual2 = -1;
+
+        world.send(rankA, rankB, BYTES(&data), MPI_INT, 1);
+        world.send(rankA, rankB, BYTES(&data2), MPI_INT, 1);
+        int recvId = world.irecv(rankA, rankB, BYTES(&actual), MPI_INT, 1);
+        int recvId2 = world.irecv(rankA, rankB, BYTES(&actual2), MPI_INT, 1);
+
+        REQUIRE_THROWS(world.destroy());
+
+        // Wait for the second request, which will acknowledge the first one
+        // but not remove it from the pending message buffer
+        world.awaitAsyncRequest(recvId2);
+
+        REQUIRE_THROWS(world.destroy());
+
+        // Wait for the first one
+        world.awaitAsyncRequest(recvId);
+
+        REQUIRE(actual == data);
+        REQUIRE(actual2 == data2);
+    }
+
+    SECTION("Outstanding isend")
+    {
+        int sendId = world.isend(rankA, rankB, BYTES(&data), MPI_INT, 1);
+        world.recv(rankA, rankB, BYTES(&actual), MPI_INT, 1, MPI_STATUS_IGNORE);
+
+        REQUIRE_THROWS(world.destroy());
+
+        world.awaitAsyncRequest(sendId);
+        REQUIRE(actual == data);
+    }
+}
 }
diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp
index 84aeaa2cf..cbe84d573 100644
--- a/tests/test/scheduler/test_remote_mpi_worlds.cpp
+++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -67,33 +67,89 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]")
    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
    faabric::util::setMockMode(false);
 
-    std::thread senderThread([this, rankA, rankB] {
-        std::vector<int> messageData = { 0, 1, 2 };
+    std::thread senderThread([this, rankA, rankB, &messageData] {
+        remoteWorld.initialiseFromMsg(msg);
+
+        // Send a message that should get sent to this host
+        remoteWorld.send(
+          rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size());
+
+        usleep(1000 * 500);
+
+        remoteWorld.destroy();
+    });
+
+    // Receive the message for the given rank
+    MPI_Status status{};
+    auto buffer = new int[messageData.size()];
+    localWorld.recv(
+      rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status);
+
+    std::vector<int> actual(buffer, buffer + messageData.size());
+    REQUIRE(actual == messageData);
+
+    REQUIRE(status.MPI_SOURCE == rankB);
+    REQUIRE(status.MPI_ERROR == MPI_SUCCESS);
+    REQUIRE(status.bytesSize == messageData.size() * sizeof(int));
+
+    // Destroy worlds
+    senderThread.join();
+    localWorld.destroy();
+}
+
+TEST_CASE_METHOD(RemoteMpiTestFixture,
+                 "Test send and recv across hosts",
+                 "[mpi]")
+{
+    // Register two ranks (one on each host)
+    this->setWorldsSizes(2, 1, 1);
+    int rankA = 0;
+    int rankB = 1;
+    std::vector<int> messageData = { 0, 1, 2 };
+    std::vector<int> messageData2 = { 3, 4, 5 };
+
+    // Init worlds
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    std::thread senderThread([this, rankA, rankB, &messageData, &messageData2] {
+        remoteWorld.initialiseFromMsg(msg);
 
        // Send a message that should get sent to this host
        remoteWorld.send(
          rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size());
+
+        // Now recv
+        auto buffer = new int[messageData2.size()];
+        remoteWorld.recv(rankA,
+                         rankB,
+                         BYTES(buffer),
+                         MPI_INT,
+                         messageData2.size(),
+                         MPI_STATUS_IGNORE);
+        std::vector<int> actual(buffer, buffer + messageData2.size());
+        REQUIRE(actual == messageData2);
+
        usleep(1000 * 500);
+
        remoteWorld.destroy();
    });
 
-    SECTION("Check recv")
-    {
-        // Receive the message for the given rank
-        MPI_Status status{};
-        auto buffer = new int[messageData.size()];
-        localWorld.recv(
-          rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status);
+    // Receive the message for the given rank
+    MPI_Status status{};
+    auto buffer = new int[messageData.size()];
+    localWorld.recv(
+      rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status);
+    std::vector<int> actual(buffer, buffer + messageData.size());
+    REQUIRE(actual == messageData);
 
-        std::vector<int> actual(buffer, buffer + messageData.size());
-        REQUIRE(actual == messageData);
+    // Now send a message
+    localWorld.send(
+      rankA, rankB, BYTES(messageData2.data()), MPI_INT, messageData2.size());
 
-        REQUIRE(status.MPI_SOURCE == rankB);
-        REQUIRE(status.MPI_ERROR == MPI_SUCCESS);
-        REQUIRE(status.bytesSize == messageData.size() * sizeof(int));
-    }
+    REQUIRE(status.MPI_SOURCE == rankB);
+    REQUIRE(status.MPI_ERROR == MPI_SUCCESS);
+    REQUIRE(status.bytesSize == messageData.size() * sizeof(int));
 
    // Destroy worlds
    senderThread.join();
@@ -115,21 +171,21 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
    faabric::util::setMockMode(false);
 
    std::thread senderThread([this, rankA, rankB, numMessages] {
-        std::vector<int> messageData = { 0, 1, 2 };
-
        remoteWorld.initialiseFromMsg(msg);
 
        for (int i = 0; i < numMessages; i++) {
-            remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, sizeof(int));
+            remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1);
        }
 
+        usleep(1000 * 500);
+
        remoteWorld.destroy();
    });
 
    int recv;
    for (int i = 0; i < numMessages; i++) {
        localWorld.recv(
-          rankB, rankA, BYTES(&recv), MPI_INT, sizeof(int), MPI_STATUS_IGNORE);
+          rankB, rankA, BYTES(&recv), MPI_INT, 1, MPI_STATUS_IGNORE);
 
        // Check in-order delivery
        if (i % (numMessages / 10) == 0) {
@@ -174,6 +230,8 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture,
            assert(actual == messageData);
        }
 
+        usleep(1000 * 500);
+
        remoteWorld.destroy();
    });
 
@@ -246,6 +304,8 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture,
                         nPerRank);
        assert(actual == std::vector<int>({ 12, 13, 14, 15 }));
 
+        usleep(1000 * 500);
+
        remoteWorld.destroy();
    });
 
@@ -333,6 +393,8 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture,
                             nPerRank);
        }
 
+        usleep(1000 * 500);
+
        remoteWorld.destroy();
    });
 
@@ -366,4 +428,191 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture,
    senderThread.join();
    localWorld.destroy();
 }
+
+TEST_CASE_METHOD(RemoteMpiTestFixture,
+                 "Test sending sync and async message to same host",
+                 "[mpi]")
+{
+    // Allocate two ranks in total, one rank per host
+    this->setWorldsSizes(2, 1, 1);
+    int sendRank = 1;
+    int recvRank = 0;
+    std::vector<int> messageData = { 0, 1, 2 };
+
+    // Init world
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    std::thread senderThread([this, sendRank, recvRank, &messageData] {
+        remoteWorld.initialiseFromMsg(msg);
+
+        // Send message twice
+        remoteWorld.send(sendRank,
+                         recvRank,
+                         BYTES(messageData.data()),
+                         MPI_INT,
+                         messageData.size());
+        remoteWorld.send(sendRank,
+                         recvRank,
+                         BYTES(messageData.data()),
+                         MPI_INT,
+                         messageData.size());
+
+        usleep(1000 * 500);
+
+        remoteWorld.destroy();
+    });
+
+    // Receive one message asynchronously
+    std::vector<int> asyncMessage(messageData.size(), 0);
+    int recvId = localWorld.irecv(sendRank,
+                                  recvRank,
+                                  BYTES(asyncMessage.data()),
+                                  MPI_INT,
+                                  asyncMessage.size());
+
+    // Receive one message synchronously
+    std::vector<int> syncMessage(messageData.size(), 0);
+    localWorld.recv(sendRank,
+                    recvRank,
+                    BYTES(syncMessage.data()),
+                    MPI_INT,
+                    syncMessage.size(),
+                    MPI_STATUS_IGNORE);
+
+    // Wait for the async message
+    localWorld.awaitAsyncRequest(recvId);
+
+    // Checks
+    REQUIRE(syncMessage == messageData);
+    REQUIRE(asyncMessage == messageData);
+
+    // Destroy world
+    senderThread.join();
+    localWorld.destroy();
+}
+
+TEST_CASE_METHOD(RemoteMpiTestFixture,
+                 "Test receiving remote async requests out of order",
+                 "[mpi]")
+{
+    // Allocate two ranks in total, one rank per host
+    this->setWorldsSizes(2, 1, 1);
+    int sendRank = 1;
+    int recvRank = 0;
+
+    // Init world
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    std::thread senderThread([this, sendRank, recvRank] {
+        remoteWorld.initialiseFromMsg(msg);
+
+        // Send different messages
+        for (int i = 0; i < 3; i++) {
+            remoteWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1);
+        }
+
+        usleep(1000 * 500);
+
+        remoteWorld.destroy();
+    });
+
+    // Receive two messages asynchronously
+    int recv1, recv2, recv3;
+    int recvId1 =
+      localWorld.irecv(sendRank, recvRank, BYTES(&recv1), MPI_INT, 1);
+
+    int recvId2 =
+      localWorld.irecv(sendRank, recvRank, BYTES(&recv2), MPI_INT, 1);
+
+    // Receive one message synchronously
+    localWorld.recv(
+      sendRank, recvRank, BYTES(&recv3), MPI_INT, 1, MPI_STATUS_IGNORE);
+
+    SECTION("Wait out of order")
+    {
+        localWorld.awaitAsyncRequest(recvId2);
+        localWorld.awaitAsyncRequest(recvId1);
+    }
+
+    SECTION("Wait in order")
+    {
+        localWorld.awaitAsyncRequest(recvId1);
+        localWorld.awaitAsyncRequest(recvId2);
+    }
+
+    // Checks
+    REQUIRE(recv1 == 0);
+    REQUIRE(recv2 == 1);
+    REQUIRE(recv3 == 2);
+
+    // Destroy world
+    senderThread.join();
+    localWorld.destroy();
+}
+
+TEST_CASE_METHOD(RemoteMpiTestFixture,
+                 "Test ring sendrecv across hosts",
+                 "[mpi]")
+{
+    // Allocate two ranks in total, one rank per host
+    this->setWorldsSizes(3, 1, 2);
+    int worldSize = 3;
+    std::vector<int> localRanks = { 0 };
+
+    // Init world
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    std::thread senderThread([this, worldSize] {
+        std::vector<int> remoteRanks = { 1, 2 };
+        remoteWorld.initialiseFromMsg(msg);
+
+        // Send different messages
+        for (auto& rank : remoteRanks) {
+            int left = rank > 0 ? rank - 1 : worldSize - 1;
+            int right = (rank + 1) % worldSize;
+            int recvData = -1;
+
+            remoteWorld.sendRecv(BYTES(&rank),
+                                 1,
+                                 MPI_INT,
+                                 right,
+                                 BYTES(&recvData),
+                                 1,
+                                 MPI_INT,
+                                 left,
+                                 rank,
+                                 MPI_STATUS_IGNORE);
+        }
+
+        usleep(1000 * 500);
+
+        remoteWorld.destroy();
+    });
+
+    for (auto& rank : localRanks) {
+        int left = rank > 0 ? rank - 1 : worldSize - 1;
+        int right = (rank + 1) % worldSize;
+        int recvData = -1;
+
+        localWorld.sendRecv(BYTES(&rank),
+                            1,
+                            MPI_INT,
+                            right,
+                            BYTES(&recvData),
+                            1,
+                            MPI_INT,
+                            left,
+                            rank,
+                            MPI_STATUS_IGNORE);
+
+        REQUIRE(recvData == left);
+    }
+
+    // Destroy world
+    senderThread.join();
+    localWorld.destroy();
+}
 }
diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp
index a5b6a4ced..a27ae9bf7 100644
--- a/tests/test/transport/test_mpi_message_endpoint.cpp
+++ b/tests/test/transport/test_mpi_message_endpoint.cpp
@@ -31,8 +31,8 @@ TEST_CASE_METHOD(MessageContextFixture,
                 "[transport]")
 {
    std::string thisHost = faabric::util::getSystemConfig().endpointHost;
-    MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, thisHost);
-    MpiMessageEndpoint recvEndpoint(thisHost, 9999, LOCALHOST);
+    MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, 9998);
+    MpiMessageEndpoint recvEndpoint(thisHost, 9998, 9999);
 
    std::shared_ptr<faabric::MPIMessage> expected =
      std::make_shared<faabric::MPIMessage>();