Skip to content

Commit

Permalink
Fixes for MPI message graph (#173)
Browse files Browse the repository at this point in the history
* fixes for exec graph to work

* add regression test
  • Loading branch information
csegarragonz committed Nov 9, 2021
1 parent 832aafe commit f7a910b
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 5 deletions.
4 changes: 4 additions & 0 deletions include/faabric/scheduler/MpiWorld.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ class MpiWorld

std::vector<bool> getInitedUMB();

/* Profiling */

// Records the given message as the one thisRankMsg points at. The message
// is taken by reference and its address is stored, so it is not copied.
void setMsgForRank(faabric::Message& msg);

private:
int id = -1;
int size = -1;
Expand Down
6 changes: 5 additions & 1 deletion src/scheduler/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ void MpiWorld::initialiseFromMsg(faabric::Message& msg)
user = msg.user();
function = msg.function();
size = msg.mpiworldsize();
thisRankMsg = &msg;

// Block until we receive
faabric::MpiHostsToRanksMessage hostRankMsg = recvMpiHostRankMsg();
Expand All @@ -324,6 +323,11 @@ void MpiWorld::initialiseFromMsg(faabric::Message& msg)
initLocalQueues();
}

// Points thisRankMsg at the given message. Only the address is stored (no
// copy is made), so the caller must keep msg alive for as long as
// thisRankMsg may be dereferenced.
// NOTE(review): thisRankMsg is declared outside this view - confirm its
// storage (per-world vs per-thread) before relying on it across threads.
void MpiWorld::setMsgForRank(faabric::Message& msg)
{
thisRankMsg = &msg;
}

std::string MpiWorld::getHostForRank(int rank)
{
assert(rankHosts.size() == size);
Expand Down
6 changes: 5 additions & 1 deletion src/scheduler/MpiWorldRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(faabric::Message& msg)

{
faabric::util::SharedLock lock(registryMutex);
return worldMap[worldId];
MpiWorld& world = worldMap[worldId];
if (msg.recordexecgraph()) {
world.setMsgForRank(msg);
}
return world;
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/util/json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,11 @@ std::string messageToJson(const faabric::Message& msg)
}

std::string out = ss.str();
d.AddMember("int_exec_graph_detail",
Value(out.c_str(), out.size()).Move(),
a);

// Build the Value with the allocator so the string is copied into the
// document; a Value that only references `out` would dangle, as `out` goes
// out of scope before the document does
Value value = Value(out.c_str(), out.size(), a);
d.AddMember("int_exec_graph_detail", value, a);
}
}

Expand Down
56 changes: 56 additions & 0 deletions tests/test/scheduler/test_mpi_exec_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <faabric/scheduler/MpiWorld.h>
#include <faabric/util/exec_graph.h>
#include <faabric/util/json.h>
#include <faabric/util/macros.h>

namespace tests {
Expand Down Expand Up @@ -74,4 +75,59 @@ TEST_CASE_METHOD(MpiTestFixture,
REQUIRE(msg.intexecgraphdetails_size() == 0);
REQUIRE(msg.execgraphdetails_size() == 0);
}

// Regression test: a rank that obtains its world on a separate thread through
// getOrInitialiseWorld must still have its sent messages recorded in its own
// faabric::Message when exec graph recording is enabled.
TEST_CASE_METHOD(MpiBaseTestFixture,
                 "Test different threads populate the graph",
                 "[util][exec-graph]")
{
    int rank = 0;
    int otherRank = 1;
    int worldSize = 2;
    int worldId = 123;

    faabric::Message msg = faabric::util::messageFactory("mpi", "hellompi");
    msg.set_ismpi(true);
    msg.set_recordexecgraph(true);
    msg.set_mpiworldsize(worldSize);
    msg.set_mpiworldid(worldId);

    // Both ranks share the same settings and differ only in their rank
    faabric::Message otherMsg = msg;
    otherMsg.set_mpirank(otherRank);
    msg.set_mpirank(rank);

    faabric::scheduler::MpiWorld& thisWorld =
      faabric::scheduler::getMpiWorldRegistry().createWorld(msg, worldId);

    std::vector<int> messageData = { 0, 1, 2 };

    // The receive buffer is owned by a vector so that it is released on every
    // exit path (previously a raw new[] that was never deleted)
    std::vector<int> buffer(messageData.size());

    // The other rank looks the world up and sends from its own thread
    std::thread otherWorldThread([&messageData, &otherMsg, rank, otherRank] {
        faabric::scheduler::MpiWorld& otherWorld =
          faabric::scheduler::getMpiWorldRegistry().getOrInitialiseWorld(
            otherMsg);

        otherWorld.send(otherRank,
                        rank,
                        BYTES(messageData.data()),
                        MPI_INT,
                        messageData.size());

        otherWorld.destroy();
    });

    thisWorld.recv(otherRank,
                   rank,
                   BYTES(buffer.data()),
                   MPI_INT,
                   messageData.size(),
                   nullptr);

    thisWorld.destroy();

    if (otherWorldThread.joinable()) {
        otherWorldThread.join();
    }

    // The sender's message should hold exactly one entry, keyed by the
    // destination rank, counting the single message that was sent
    std::string expectedKey =
      faabric::util::exec_graph::mpiMsgCountPrefix + std::to_string(rank);
    REQUIRE(otherMsg.mpirank() == otherRank);
    REQUIRE(otherMsg.intexecgraphdetails_size() == 1);
    REQUIRE(otherMsg.intexecgraphdetails().count(expectedKey) == 1);
    REQUIRE(otherMsg.intexecgraphdetails().at(expectedKey) == 1);
}
}

0 comments on commit f7a910b

Please sign in to comment.