From a4b668e11b9d1d23f67a9a8f12be25b701d229ef Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Tue, 15 Jun 2021 10:00:15 +0000 Subject: [PATCH] adding tests for the mpi message buffer + formatting --- include/faabric/scheduler/MpiMessageBuffer.h | 34 +++-- src/scheduler/MpiMessageBuffer.cpp | 25 ++-- src/scheduler/MpiWorld.cpp | 18 +-- .../scheduler/test_mpi_message_buffer.cpp | 119 ++++++++++++++++++ 4 files changed, 158 insertions(+), 38 deletions(-) create mode 100644 tests/test/scheduler/test_mpi_message_buffer.cpp diff --git a/include/faabric/scheduler/MpiMessageBuffer.h b/include/faabric/scheduler/MpiMessageBuffer.h index 88f4ad7f6..06d23345f 100644 --- a/include/faabric/scheduler/MpiMessageBuffer.h +++ b/include/faabric/scheduler/MpiMessageBuffer.h @@ -10,11 +10,17 @@ namespace faabric::scheduler { * still have not waited on (acknowledged). Messages are acknowledged either * through a call to recv or a call to await. A call to recv will * acknowledge (i.e. synchronously read from transport buffers) as many - * unacknowleged messages there are, plus one. + * unacknowledged messages there are, plus one. A call to await with a request + * id as a parameter will acknowledge as many unacknowledged messages there are + * until said request id. */ class MpiMessageBuffer { public: + /* This structure holds the metadata for each MPI message we keep in the + * buffer. Note that the message field will point to null if unacknowledged + * or to a valid message otherwise. 
+ */ struct Arguments { int requestId; @@ -27,29 +33,39 @@ class MpiMessageBuffer faabric::MPIMessage::MPIMessageType messageType; }; - void addMessage(Arguments arg); - - void deleteMessage(const std::list::iterator& argIt); + /* Interface to query the buffer size */ bool isEmpty(); int size(); - std::list::iterator getRequestArguments(int requestId); + /* Interface to add and delete messages to the buffer */ - std::list::iterator getFirstNullMsgUntil( - const std::list::iterator& argIt); + void addMessage(Arguments arg); + + void deleteMessage(const std::list::iterator& argIt); + + /* Interface to get a pointer to a message in the MMB */ + + // Pointer to a message given its request id + std::list::iterator getRequestArguments(int requestId); + // Pointer to the first null-pointing (unacknowledged) message std::list::iterator getFirstNullMsg(); + /* Interface to ask for the number of unacknowledged messages */ + + // Unacknowledged messages until an iterator (used in await) int getTotalUnackedMessagesUntil( const std::list::iterator& argIt); + // Unacknowledged messages in the whole buffer (used in recv) int getTotalUnackedMessages(); private: - // We keep track of the request id and its arguments. Note that the message - // is part of the arguments and may be null if unacknowleged. std::list args; + + std::list::iterator getFirstNullMsgUntil( + const std::list::iterator& argIt); }; } diff --git a/src/scheduler/MpiMessageBuffer.cpp b/src/scheduler/MpiMessageBuffer.cpp index 3ddd2395a..ec13a2f8b 100644 --- a/src/scheduler/MpiMessageBuffer.cpp +++ b/src/scheduler/MpiMessageBuffer.cpp @@ -3,33 +3,30 @@ namespace faabric::scheduler { typedef std::list::iterator ArgListIterator; -void MpiMessageBuffer::addMessage(Arguments arg) +bool MpiMessageBuffer::isEmpty() { - // Ensure we are enqueueing a null message (i.e. 
unacknowleged) - assert(arg.msg == nullptr); - - args.push_back(arg); + return args.empty(); } -void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt) +int MpiMessageBuffer::size() { - args.erase(argIt); - return; + return args.size(); } -bool MpiMessageBuffer::isEmpty() +void MpiMessageBuffer::addMessage(Arguments arg) { - return args.empty(); + args.push_back(arg); } -int MpiMessageBuffer::size() +void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt) { - return args.size(); + args.erase(argIt); + return; } ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId) { - // The request id must be in the UMB, as an irecv must happen before an + // The request id must be in the MMB, as an irecv must happen before an // await ArgListIterator argIt = std::find_if(args.begin(), args.end(), [requestId](Arguments args) { @@ -38,7 +35,7 @@ ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId) // If it's not there, error out if (argIt == args.end()) { - SPDLOG_ERROR("Asynchronous request id not in UMB: {}", requestId); + SPDLOG_ERROR("Asynchronous request id not in buffer: {}", requestId); throw std::runtime_error("Async request not in buffer"); } diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 3ade3ff44..d9eb0b607 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -5,9 +5,8 @@ #include #include -/* Each MPI rank runs in a separate thread, however they interact with faabric - * as a library. Thus, we use thread_local storage to guarantee that each rank - * sees its own version of these data structures. +/* Each MPI rank runs in a separate thread, thus we use TLS to maintain the + * per-rank data structures. */ static thread_local std::vector< std::unique_ptr> @@ -112,9 +111,7 @@ MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank) // thread local nature, we expect it to be quite sparse (i.e. filled with // nullptr). 
if (unackedMessageBuffers.size() == 0) { - for (int i = 0; i < size * size; i++) { - unackedMessageBuffers.emplace_back(nullptr); - } + unackedMessageBuffers.resize(size * size, nullptr); } // Get the index for the rank-host pair @@ -505,15 +502,6 @@ void MpiWorld::recv(int sendRank, { // Sanity-check input parameters checkRanksRange(sendRank, recvRank); - if (getHostForRank(recvRank) != thisHost) { - SPDLOG_ERROR("Trying to recv message into a non-local rank: {}", - recvRank); - throw std::runtime_error("Receiving message into non-local rank"); - } - - // Work out whether the message is sent locally or from another host - const std::string otherHost = getHostForRank(sendRank); - bool isLocal = otherHost == thisHost; // Recv message from underlying transport std::shared_ptr m = diff --git a/tests/test/scheduler/test_mpi_message_buffer.cpp b/tests/test/scheduler/test_mpi_message_buffer.cpp new file mode 100644 index 000000000..16665771f --- /dev/null +++ b/tests/test/scheduler/test_mpi_message_buffer.cpp @@ -0,0 +1,119 @@ +#include + +#include +#include +#include + +using namespace faabric::scheduler; + +MpiMessageBuffer::Arguments genRandomArguments(bool nullMsg = true, + int overrideRequestId = -1) +{ + int requestId; + if (overrideRequestId != -1) { + requestId = overrideRequestId; + } else { + requestId = static_cast(faabric::util::generateGid()); + } + + MpiMessageBuffer::Arguments args = { + requestId, nullptr, 0, 1, + nullptr, MPI_INT, 0, faabric::MPIMessage::NORMAL + }; + + if (!nullMsg) { + args.msg = std::make_shared(); + } + + return args; +} + +namespace tests { +TEST_CASE("Test adding message to message buffer", "[mpi]") +{ + MpiMessageBuffer mmb; + REQUIRE(mmb.isEmpty()); + + REQUIRE_NOTHROW(mmb.addMessage(genRandomArguments())); + REQUIRE(mmb.size() == 1); +} + +TEST_CASE("Test deleting message from message buffer", "[mpi]") +{ + MpiMessageBuffer mmb; + REQUIRE(mmb.isEmpty()); + + mmb.addMessage(genRandomArguments()); + REQUIRE(mmb.size() == 1); + + 
+ auto it = mmb.getFirstNullMsg(); + REQUIRE_NOTHROW(mmb.deleteMessage(it)); + + REQUIRE(mmb.isEmpty()); +} + +TEST_CASE("Test getting an iterator from a request id", "[mpi]") +{ + MpiMessageBuffer mmb; + + int requestId = 1337; + mmb.addMessage(genRandomArguments(true, requestId)); + + auto it = mmb.getRequestArguments(requestId); + REQUIRE(it->requestId == requestId); +} + +TEST_CASE("Test getting first null message", "[mpi]") +{ + MpiMessageBuffer mmb; + + // Add first a non-null message + int requestId1 = 1; + mmb.addMessage(genRandomArguments(false, requestId1)); + + // Then add a null message + int requestId2 = 2; + mmb.addMessage(genRandomArguments(true, requestId2)); + + // Query for the first null message + auto it = mmb.getFirstNullMsg(); + REQUIRE(it->requestId == requestId2); +} + +TEST_CASE("Test getting total unacked messages in message buffer", "[mpi]") +{ + MpiMessageBuffer mmb; + + REQUIRE(mmb.getTotalUnackedMessages() == 0); + + // Add a non-null message + mmb.addMessage(genRandomArguments(false)); + + // Then a couple of null messages + mmb.addMessage(genRandomArguments(true)); + mmb.addMessage(genRandomArguments(true)); + + // Check that we have two unacked messages + REQUIRE(mmb.getTotalUnackedMessages() == 2); +} + +TEST_CASE("Test getting total unacked messages in message buffer range", + "[mpi]") +{ + MpiMessageBuffer mmb; + + // Add a non-null message + mmb.addMessage(genRandomArguments(false)); + + // Then a couple of null messages + int requestId = 1337; + mmb.addMessage(genRandomArguments(true)); + mmb.addMessage(genRandomArguments(true, requestId)); + + // Get an iterator to our second null message + auto it = mmb.getRequestArguments(requestId); + + // Check that we have only one unacked message until the iterator + REQUIRE(mmb.getTotalUnackedMessagesUntil(it) == 1); +} +}