Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
csegarragonz committed Jun 15, 2021
1 parent 3055782 commit f59c056
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 91 deletions.
45 changes: 28 additions & 17 deletions include/faabric/scheduler/MpiMessageBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,25 @@ class MpiMessageBuffer
* buffer. Note that the message field will point to null if unacknowledged
* or to a valid message otherwise.
*/
struct Arguments
class PendingAsyncMpiMessage
{
int requestId;
std::shared_ptr<faabric::MPIMessage> msg;
int sendRank;
int recvRank;
uint8_t* buffer;
faabric_datatype_t* dataType;
int count;
faabric::MPIMessage::MPIMessageType messageType;
public:
int requestId = -1;
std::shared_ptr<faabric::MPIMessage> msg = nullptr;
int sendRank = -1;
int recvRank = -1;
uint8_t* buffer = nullptr;
faabric_datatype_t* dataType = nullptr;
int count = -1;
faabric::MPIMessage::MPIMessageType messageType =
faabric::MPIMessage::NORMAL;

bool isAcknowledged() { return msg != nullptr; }

void acknowledge(std::shared_ptr<faabric::MPIMessage> msgIn)
{
msg = msgIn;
}
};

/* Interface to query the buffer size */
Expand All @@ -41,31 +50,33 @@ class MpiMessageBuffer

/* Interface to add and delete messages to the buffer */

void addMessage(Arguments arg);
void addMessage(PendingAsyncMpiMessage msg);

void deleteMessage(const std::list<Arguments>::iterator& argIt);
void deleteMessage(
const std::list<PendingAsyncMpiMessage>::iterator& msgIt);

/* Interface to get a pointer to a message in the MMB */

// Pointer to a message given its request id
std::list<Arguments>::iterator getRequestArguments(int requestId);
std::list<PendingAsyncMpiMessage>::iterator getRequestPendingMsg(
int requestId);

// Pointer to the first null-pointing (unacknowledged) message
std::list<Arguments>::iterator getFirstNullMsg();
std::list<PendingAsyncMpiMessage>::iterator getFirstNullMsg();

/* Interface to ask for the number of unacknowledged messages */

// Unacknowledged messages until an iterator (used in await)
int getTotalUnackedMessagesUntil(
const std::list<Arguments>::iterator& argIt);
const std::list<PendingAsyncMpiMessage>::iterator& msgIt);

// Unacknowledged messages in the whole buffer (used in recv)
int getTotalUnackedMessages();

private:
std::list<Arguments> args;
std::list<PendingAsyncMpiMessage> pendingMsgs;

std::list<Arguments>::iterator getFirstNullMsgUntil(
const std::list<Arguments>::iterator& argIt);
std::list<PendingAsyncMpiMessage>::iterator getFirstNullMsgUntil(
const std::list<PendingAsyncMpiMessage>::iterator& msgIt);
};
}
4 changes: 2 additions & 2 deletions include/faabric/scheduler/MpiWorld.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ class MpiWorld
void closeMpiMessageEndpoints();

// Support for asynchronous communications
std::shared_ptr<faabric::scheduler::MpiMessageBuffer>
getUnackedMessageBuffer(int sendRank, int recvRank);
std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
int recvRank);
std::shared_ptr<faabric::MPIMessage> recvBatchReturnLast(int sendRank,
int recvRank,
int batchSize = 0);
Expand Down
56 changes: 30 additions & 26 deletions src/scheduler/MpiMessageBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,72 @@
#include <faabric/util/logging.h>

namespace faabric::scheduler {
typedef std::list<MpiMessageBuffer::Arguments>::iterator ArgListIterator;
typedef std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator
MpiMessageIterator;
bool MpiMessageBuffer::isEmpty()
{
return args.empty();
return pendingMsgs.empty();
}

int MpiMessageBuffer::size()
{
return args.size();
return pendingMsgs.size();
}

void MpiMessageBuffer::addMessage(Arguments arg)
void MpiMessageBuffer::addMessage(PendingAsyncMpiMessage msg)
{
args.push_back(arg);
pendingMsgs.push_back(msg);
}

void MpiMessageBuffer::deleteMessage(const ArgListIterator& argIt)
void MpiMessageBuffer::deleteMessage(const MpiMessageIterator& msgIt)
{
args.erase(argIt);
pendingMsgs.erase(msgIt);
}

ArgListIterator MpiMessageBuffer::getRequestArguments(int requestId)
MpiMessageIterator MpiMessageBuffer::getRequestPendingMsg(int requestId)
{
// The request id must be in the MMB, as an irecv must happen before an
// await
ArgListIterator argIt =
std::find_if(args.begin(), args.end(), [requestId](Arguments args) {
return args.requestId == requestId;
});
MpiMessageIterator msgIt =
std::find_if(pendingMsgs.begin(),
pendingMsgs.end(),
[requestId](PendingAsyncMpiMessage pendingMsg) {
return pendingMsg.requestId == requestId;
});

// If it's not there, error out
if (argIt == args.end()) {
if (msgIt == pendingMsgs.end()) {
SPDLOG_ERROR("Asynchronous request id not in buffer: {}", requestId);
throw std::runtime_error("Async request not in buffer");
}

return argIt;
return msgIt;
}

ArgListIterator MpiMessageBuffer::getFirstNullMsgUntil(
const ArgListIterator& argItEnd)
MpiMessageIterator MpiMessageBuffer::getFirstNullMsgUntil(
const MpiMessageIterator& msgItEnd)
{
return std::find_if(args.begin(), argItEnd, [](Arguments args) {
return args.msg == nullptr;
});
return std::find_if(
pendingMsgs.begin(), msgItEnd, [](PendingAsyncMpiMessage pendingMsg) {
return pendingMsg.msg == nullptr;
});
}

ArgListIterator MpiMessageBuffer::getFirstNullMsg()
MpiMessageIterator MpiMessageBuffer::getFirstNullMsg()
{
return getFirstNullMsgUntil(args.end());
return getFirstNullMsgUntil(pendingMsgs.end());
}

int MpiMessageBuffer::getTotalUnackedMessagesUntil(
const ArgListIterator& argItEnd)
const MpiMessageIterator& msgItEnd)
{
ArgListIterator firstNull = getFirstNullMsgUntil(argItEnd);
return std::distance(firstNull, argItEnd);
MpiMessageIterator firstNull = getFirstNullMsgUntil(msgItEnd);
return std::distance(firstNull, msgItEnd);
}

int MpiMessageBuffer::getTotalUnackedMessages()
{
ArgListIterator firstNull = getFirstNullMsg();
return std::distance(firstNull, args.end());
MpiMessageIterator firstNull = getFirstNullMsg();
return std::distance(firstNull, pendingMsgs.end());
}
}
79 changes: 43 additions & 36 deletions src/scheduler/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,19 @@ void MpiWorld::destroy()
// Request to rank map should be empty
if (!reqIdToRanks.empty()) {
SPDLOG_ERROR(
"Destroying the MPI world with {} outstanding async requests",
"Destroying the MPI world with {} outstanding irecv requests",
reqIdToRanks.size());
throw std::runtime_error("Destroying world with outstanding requests");
}

// iSend set should be empty
if (!iSendRequests.empty()) {
SPDLOG_ERROR(
"Destroying the MPI world with {} outstanding isend requests",
iSendRequests.size());
throw std::runtime_error("Destroying world with outstanding requests");
}

// Destroy once per host the shared resources
if (!isDestroyed.test_and_set()) {
// Wait (forever) until all ranks are done consuming their queues to
Expand Down Expand Up @@ -423,15 +431,19 @@ int MpiWorld::irecv(int sendRank,
int requestId = (int)faabric::util::generateGid();
reqIdToRanks.try_emplace(requestId, sendRank, recvRank);

// Enqueue a request with a null-pointing message (i.e. unacknowledged) and
// the generated request id
faabric::scheduler::MpiMessageBuffer::Arguments args = {
requestId, nullptr, sendRank, recvRank,
buffer, dataType, count, messageType
};
// Enqueue an unacknowledged request (no message)
faabric::scheduler::MpiMessageBuffer::PendingAsyncMpiMessage pendingMsg;
pendingMsg.requestId = requestId;
pendingMsg.sendRank = sendRank;
pendingMsg.recvRank = recvRank;
pendingMsg.buffer = buffer;
pendingMsg.dataType = dataType;
pendingMsg.count = count;
pendingMsg.messageType = messageType;
assert(!pendingMsg.isAcknowledged());

auto umb = getUnackedMessageBuffer(sendRank, recvRank);
umb->addMessage(args);
umb->addMessage(pendingMsg);

return requestId;
}
Expand Down Expand Up @@ -806,29 +818,29 @@ void MpiWorld::awaitAsyncRequest(int requestId)
std::shared_ptr<faabric::scheduler::MpiMessageBuffer> umb =
getUnackedMessageBuffer(sendRank, recvRank);

std::list<MpiMessageBuffer::Arguments>::iterator argsIndex =
umb->getRequestArguments(requestId);
std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator msgIt =
umb->getRequestPendingMsg(requestId);

std::shared_ptr<faabric::MPIMessage> m;
if (argsIndex->msg != nullptr) {
if (msgIt->msg != nullptr) {
// This id has already been acknowledged by a recv call, so do the recv
m = argsIndex->msg;
m = msgIt->msg;
} else {
// We need to acknowledge all messages not acknowledged from the
// beginning until us
m = recvBatchReturnLast(
sendRank, recvRank, umb->getTotalUnackedMessagesUntil(argsIndex) + 1);
sendRank, recvRank, umb->getTotalUnackedMessagesUntil(msgIt) + 1);
}

doRecv(m,
argsIndex->buffer,
argsIndex->dataType,
argsIndex->count,
msgIt->buffer,
msgIt->dataType,
msgIt->count,
MPI_STATUS_IGNORE,
argsIndex->messageType);
msgIt->messageType);

// Remove the acknowledged indexes from the UMB
umb->deleteMessage(argsIndex);
umb->deleteMessage(msgIt);
}

void MpiWorld::reduce(int sendRank,
Expand Down Expand Up @@ -1213,47 +1225,42 @@ MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize)
// Recv message: first we receive all messages for which there is an id
// in the unacknowledged buffer but no msg. Note that these messages
// (batchSize - 1) were `irecv`-ed before ours.
std::shared_ptr<faabric::MPIMessage> m;
auto argsIt = umb->getFirstNullMsg();
std::shared_ptr<faabric::MPIMessage> ourMsg;
auto msgIt = umb->getFirstNullMsg();
if (isLocal) {
// First receive messages that happened before us
for (int i = 0; i < batchSize - 1; i++) {
SPDLOG_TRACE("MPI - pending recv {} -> {}", sendRank, recvRank);
auto _m = getLocalQueue(sendRank, recvRank)->dequeue();

assert(_m != nullptr);
assert(argsIt->msg == nullptr);
auto pendingMsg = getLocalQueue(sendRank, recvRank)->dequeue();

// Put the unacked message in the UMB
argsIt->msg = _m;
argsIt++;
assert(!msgIt->isAcknowledged());
msgIt->acknowledge(pendingMsg);
msgIt++;
}

// Finally receive the message corresponding to us
SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank);
m = getLocalQueue(sendRank, recvRank)->dequeue();
ourMsg = getLocalQueue(sendRank, recvRank)->dequeue();
} else {
// First receive messages that happened before us
for (int i = 0; i < batchSize - 1; i++) {
SPDLOG_TRACE(
"MPI - pending remote recv {} -> {}", sendRank, recvRank);
auto _m = recvRemoteMpiMessage(sendRank, recvRank);

assert(_m != nullptr);
assert(argsIt->msg == nullptr);
auto pendingMsg = recvRemoteMpiMessage(sendRank, recvRank);

// Put the unacked message in the UMB
argsIt->msg = _m;
argsIt++;
assert(!msgIt->isAcknowledged());
msgIt->acknowledge(pendingMsg);
msgIt++;
}

// Finally receive the message corresponding to us
SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank);
m = recvRemoteMpiMessage(sendRank, recvRank);
ourMsg = recvRemoteMpiMessage(sendRank, recvRank);
}
assert(m != nullptr);

return m;
return ourMsg;
}

int MpiWorld::getIndexForRanks(int sendRank, int recvRank)
Expand Down
19 changes: 9 additions & 10 deletions tests/test/scheduler/test_mpi_message_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

using namespace faabric::scheduler;

MpiMessageBuffer::Arguments genRandomArguments(bool nullMsg = true,
int overrideRequestId = -1)
MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments(
bool nullMsg = true,
int overrideRequestId = -1)
{
int requestId;
if (overrideRequestId != -1) {
Expand All @@ -16,16 +17,14 @@ MpiMessageBuffer::Arguments genRandomArguments(bool nullMsg = true,
requestId = static_cast<int>(faabric::util::generateGid());
}

MpiMessageBuffer::Arguments args = {
requestId, nullptr, 0, 1,
nullptr, MPI_INT, 0, faabric::MPIMessage::NORMAL
};
MpiMessageBuffer::PendingAsyncMpiMessage pendingMsg;
pendingMsg.requestId = requestId;

if (!nullMsg) {
args.msg = std::make_shared<faabric::MPIMessage>();
pendingMsg.msg = std::make_shared<faabric::MPIMessage>();
}

return args;
return pendingMsg;
}

namespace tests {
Expand Down Expand Up @@ -59,7 +58,7 @@ TEST_CASE("Test getting an iterator from a request id", "[mpi]")
int requestId = 1337;
mmb.addMessage(genRandomArguments(true, requestId));

auto it = mmb.getRequestArguments(requestId);
auto it = mmb.getRequestPendingMsg(requestId);
REQUIRE(it->requestId == requestId);
}

Expand Down Expand Up @@ -111,7 +110,7 @@ TEST_CASE("Test getting total unacked messages in message buffer range",
mmb.addMessage(genRandomArguments(true, requestId));

// Get an iterator to our second null message
auto it = mmb.getRequestArguments(requestId);
auto it = mmb.getRequestPendingMsg(requestId);

// Check that we have only one unacked message until the iterator
REQUIRE(mmb.getTotalUnackedMessagesUntil(it) == 1);
Expand Down
Loading

0 comments on commit f59c056

Please sign in to comment.