Skip to content

Commit

Permalink
adding tests for the mpi message buffer + formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
csegarragonz committed Jun 15, 2021
1 parent e986079 commit a4b668e
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 38 deletions.
34 changes: 25 additions & 9 deletions include/faabric/scheduler/MpiMessageBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@ namespace faabric::scheduler {
* still have not waited on (acknowledged). Messages are acknowledged either
* through a call to recv or a call to await. A call to recv will
* acknowledge (i.e. synchronously read from transport buffers) as many
* unacknowleged messages there are, plus one.
* unacknowleged messages there are, plus one. A call to await with a request
* id as a parameter will acknowledge as many unacknowleged messages there are
* until said request id.
*/
class MpiMessageBuffer
{
public:
/* This structure holds the metadata for each Mpi message we keep in the
* buffer. Note that the message field will point to null if unacknowleged
* or to a valid message otherwise.
*/
struct Arguments
{
int requestId;
Expand All @@ -27,29 +33,39 @@ class MpiMessageBuffer
faabric::MPIMessage::MPIMessageType messageType;
};

void addMessage(Arguments arg);

void deleteMessage(const std::list<Arguments>::iterator& argIt);
/* Interface to query the buffer size */

bool isEmpty();

int size();

std::list<Arguments>::iterator getRequestArguments(int requestId);
/* Interface to add and delete messages to the buffer */

std::list<Arguments>::iterator getFirstNullMsgUntil(
const std::list<Arguments>::iterator& argIt);
void addMessage(Arguments arg);

void deleteMessage(const std::list<Arguments>::iterator& argIt);

/* Interface to get a pointer to a message in the MMB */

// Pointer to a message given its request id
std::list<Arguments>::iterator getRequestArguments(int requestId);

// Pointer to the first null-pointing (unacknowleged) message
std::list<Arguments>::iterator getFirstNullMsg();

/* Interface to ask for the number of unacknowleged messages */

// Unacknowledged messages until an iterator (used in await)
int getTotalUnackedMessagesUntil(
const std::list<Arguments>::iterator& argIt);

// Unacknowledged messages in the whole buffer (used in recv)
int getTotalUnackedMessages();

private:
// We keep track of the request id and its arguments. Note that the message
// is part of the arguments and may be null if unacknowleged.
std::list<Arguments> args;

std::list<Arguments>::iterator getFirstNullMsgUntil(
const std::list<Arguments>::iterator& argIt);
};
}
25 changes: 11 additions & 14 deletions src/scheduler/MpiMessageBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,30 @@

namespace faabric::scheduler {
typedef std::list<MpiMessageBuffer::Arguments>::iterator ArgListIterator;
void MpiMessageBuffer::addMessage(Arguments arg)
bool MpiMessageBuffer::isEmpty()
{
// Ensure we are enqueueing a null message (i.e. unacknowleged)
assert(arg.msg == nullptr);

args.push_back(arg);
return args.empty();
}

void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt)
int MpiMessageBuffer::size()
{
args.erase(argIt);
return;
return args.size();
}

bool MpiMessageBuffer::isEmpty()
void MpiMessageBuffer::addMessage(Arguments arg)
{
return args.empty();
args.push_back(arg);
}

int MpiMessageBuffer::size()
void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt)
{
return args.size();
args.erase(argIt);
return;
}

ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId)
{
// The request id must be in the UMB, as an irecv must happen before an
// The request id must be in the MMB, as an irecv must happen before an
// await
ArgListIterator argIt =
std::find_if(args.begin(), args.end(), [requestId](Arguments args) {
Expand All @@ -38,7 +35,7 @@ ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId)

// If it's not there, error out
if (argIt == args.end()) {
SPDLOG_ERROR("Asynchronous request id not in UMB: {}", requestId);
SPDLOG_ERROR("Asynchronous request id not in buffer: {}", requestId);
throw std::runtime_error("Async request not in buffer");
}

Expand Down
18 changes: 3 additions & 15 deletions src/scheduler/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
#include <faabric/util/gids.h>
#include <faabric/util/macros.h>

/* Each MPI rank runs in a separate thread, however they interact with faabric
* as a library. Thus, we use thread_local storage to guarantee that each rank
* sees its own version of these data structures.
/* Each MPI rank runs in a separate thread, thus we use TLS to maintain the
* per-rank data structures.
*/
static thread_local std::vector<
std::unique_ptr<faabric::transport::MpiMessageEndpoint>>
Expand Down Expand Up @@ -112,9 +111,7 @@ MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank)
// thread local nature, we expect it to be quite sparse (i.e. filled with
// nullptr).
if (unackedMessageBuffers.size() == 0) {
for (int i = 0; i < size * size; i++) {
unackedMessageBuffers.emplace_back(nullptr);
}
unackedMessageBuffers.resize(size * size, nullptr);
}

// Get the index for the rank-host pair
Expand Down Expand Up @@ -505,15 +502,6 @@ void MpiWorld::recv(int sendRank,
{
// Sanity-check input parameters
checkRanksRange(sendRank, recvRank);
if (getHostForRank(recvRank) != thisHost) {
SPDLOG_ERROR("Trying to recv message into a non-local rank: {}",
recvRank);
throw std::runtime_error("Receiving message into non-local rank");
}

// Work out whether the message is sent locally or from another host
const std::string otherHost = getHostForRank(sendRank);
bool isLocal = otherHost == thisHost;

// Recv message from underlying transport
std::shared_ptr<faabric::MPIMessage> m =
Expand Down
119 changes: 119 additions & 0 deletions tests/test/scheduler/test_mpi_message_buffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include <catch.hpp>

#include <faabric/mpi/mpi.h>
#include <faabric/scheduler/MpiMessageBuffer.h>
#include <faabric/util/gids.h>

using namespace faabric::scheduler;

MpiMessageBuffer::Arguments genRandomArguments(bool nullMsg = true,
int overrideRequestId = -1)
{
int requestId;
if (overrideRequestId != -1) {
requestId = overrideRequestId;
} else {
requestId = static_cast<int>(faabric::util::generateGid());
}

MpiMessageBuffer::Arguments args = {
requestId, nullptr, 0, 1,
nullptr, MPI_INT, 0, faabric::MPIMessage::NORMAL
};

if (!nullMsg) {
args.msg = std::make_shared<faabric::MPIMessage>();
}

return args;
}

namespace tests {
TEST_CASE("Test adding message to message buffer", "[mpi]")
{
MpiMessageBuffer mmb;
REQUIRE(mmb.isEmpty());

REQUIRE_NOTHROW(mmb.addMessage(genRandomArguments()));
REQUIRE(mmb.size() == 1);
}

TEST_CASE("Test deleting message from message buffer", "[mpi]")
{
MpiMessageBuffer mmb;
REQUIRE(mmb.isEmpty());

mmb.addMessage(genRandomArguments());
REQUIRE(mmb.size() == 1);

auto it = mmb.getFirstNullMsg();
REQUIRE_NOTHROW(mmb.deleteMessage(it));

REQUIRE(mmb.isEmpty());
}

TEST_CASE("Test getting an iterator from a request id", "[mpi]")
{
MpiMessageBuffer mmb;

int requestId = 1337;
mmb.addMessage(genRandomArguments(true, requestId));

auto it = mmb.getRequestArguments(requestId);
REQUIRE(it->requestId == requestId);
}

TEST_CASE("Test getting first null message", "[mpi]")
{
MpiMessageBuffer mmb;

// Add first a non-null message
int requestId1 = 1;
mmb.addMessage(genRandomArguments(false, requestId1));

// Then add a null message
int requestId2 = 2;
mmb.addMessage(genRandomArguments(true, requestId2));

// Query for the first non-null message
auto it = mmb.getFirstNullMsg();
REQUIRE(it->requestId == requestId2);
}

TEST_CASE("Test getting total unacked messages in message buffer", "[mpi]")
{
MpiMessageBuffer mmb;

REQUIRE(mmb.getTotalUnackedMessages() == 0);

// Add a non-null message
mmb.addMessage(genRandomArguments(false));

// Then a couple of null messages
mmb.addMessage(genRandomArguments(true));
mmb.addMessage(genRandomArguments(true));

// Check that we have two unacked messages
REQUIRE(mmb.getTotalUnackedMessages() == 2);
}

TEST_CASE("Test getting total unacked messages in message buffer range",
"[mpi]")
{
MpiMessageBuffer mmb;

// Add a non-null message
mmb.addMessage(genRandomArguments(false));

// Then a couple of null messages
int requestId = 1337;
mmb.addMessage(genRandomArguments(true));
mmb.addMessage(genRandomArguments(true, requestId));

// Get an iterator to our second null message
auto it = mmb.getRequestArguments(requestId);

// Check that we have only one unacked message until the iterator
REQUIRE(mmb.getTotalUnackedMessagesUntil(it) == 1);
}
}

0 comments on commit a4b668e

Please sign in to comment.