Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Not use emplace for editing vector elements #143

Merged
merged 3 commits into from
Sep 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/faabric/scheduler/MpiWorld.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ class MpiWorld

double getWTime();

// Returns one flag per slot of mpiMessageEndpoints; a flag is true when the
// corresponding slot holds a non-null endpoint
std::vector<bool> getInitedRemoteMpiEndpoints();

// Returns one flag per slot of unackedMessageBuffers; a flag is true when the
// corresponding slot holds a non-null buffer
std::vector<bool> getInitedUMB();

private:
int id = -1;
int size = -1;
Expand Down
2 changes: 1 addition & 1 deletion include/faabric/transport/MessageEndpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MessageEndpoint
MessageEndpoint(const std::string& hostIn, int portIn, int timeoutMsIn);

// Delete assignment and copy-constructor as we need to be very careful with
// socping and same-thread instantiation
// scoping and same-thread instantiation
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caught this while revisiting the sources, so I figured I'd just change it.

MessageEndpoint& operator=(const MessageEndpoint&) = delete;

MessageEndpoint(const MessageEndpoint& ctx) = delete;
Expand Down
30 changes: 24 additions & 6 deletions src/scheduler/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,9 @@ void MpiWorld::initRemoteMpiEndpoint(int localRank, int remoteRank)
std::pair<int, int> sendRecvPorts = getPortForRanks(localRank, remoteRank);

// Create MPI message endpoint
mpiMessageEndpoints.emplace(
mpiMessageEndpoints.begin() + index,
mpiMessageEndpoints.at(index) =
std::make_unique<faabric::transport::MpiMessageEndpoint>(
otherHost, sendRecvPorts.first, sendRecvPorts.second));
otherHost, sendRecvPorts.first, sendRecvPorts.second);
}

void MpiWorld::sendRemoteMpiMessage(
Expand Down Expand Up @@ -164,9 +163,8 @@ MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank)
assert(index >= 0 && index < size * size);

if (unackedMessageBuffers[index] == nullptr) {
unackedMessageBuffers.emplace(
unackedMessageBuffers.begin() + index,
std::make_shared<faabric::scheduler::MpiMessageBuffer>());
unackedMessageBuffers.at(index) =
std::make_shared<faabric::scheduler::MpiMessageBuffer>();
}

return unackedMessageBuffers[index];
Expand Down Expand Up @@ -1379,6 +1377,26 @@ double MpiWorld::getWTime()
return t / 1000.0;
}

// Reports which remote MPI message endpoints currently exist.
//
// The returned vector has exactly one entry per slot of mpiMessageEndpoints;
// entry i is true when slot i holds a non-null endpoint and false otherwise.
std::vector<bool> MpiWorld::getInitedRemoteMpiEndpoints()
{
    std::vector<bool> retVec(mpiMessageEndpoints.size());
    // Index with size_t so the comparison against size() is unsigned on both
    // sides (avoids a signed/unsigned mismatch)
    for (size_t i = 0; i < mpiMessageEndpoints.size(); i++) {
        retVec.at(i) = mpiMessageEndpoints.at(i) != nullptr;
    }

    return retVec;
}

// Reports which unacked message buffers (UMBs) currently exist.
//
// The returned vector has exactly one entry per slot of
// unackedMessageBuffers; entry i is true when slot i holds a non-null buffer
// and false otherwise.
std::vector<bool> MpiWorld::getInitedUMB()
{
    std::vector<bool> retVec(unackedMessageBuffers.size());
    // Index with size_t so the comparison against size() is unsigned on both
    // sides (avoids a signed/unsigned mismatch)
    for (size_t i = 0; i < unackedMessageBuffers.size(); i++) {
        retVec.at(i) = unackedMessageBuffers.at(i) != nullptr;
    }

    return retVec;
}

std::string MpiWorld::getUser()
{
return user;
Expand Down
190 changes: 190 additions & 0 deletions tests/test/scheduler/test_remote_mpi_worlds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,4 +723,194 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,

thisWorld.destroy();
}

// Checks that a remote message endpoint is created on first use between two
// ranks on different hosts, and that later sends/recvs between the same pair
// of ranks do not create any additional endpoints.
TEST_CASE_METHOD(RemoteMpiTestFixture,
                 "Test remote message endpoint creation",
                 "[mpi]")
{
    // Register two ranks (one on each host)
    setWorldSizes(2, 1, 1);
    int rankA = 0;
    int rankB = 1;
    std::vector<int> messageData = { 0, 1, 2 };
    std::vector<int> messageData2 = { 3, 4 };

    // Init worlds
    MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId);
    faabric::util::setMockMode(false);
    thisWorld.broadcastHostsToRanks();

    std::thread otherWorldThread(
      [this, rankA, rankB, &messageData, &messageData2] {
          otherWorld.initialiseFromMsg(msg);

          // NOTE: plain assert (not REQUIRE) is kept off the main test
          // thread; Catch2 documents its assertion macros as not
          // thread-safe by default.

          // Recv once, directly into a vector so the buffer is released
          // automatically
          std::vector<int> actual(messageData.size());
          otherWorld.recv(rankA,
                          rankB,
                          BYTES(actual.data()),
                          MPI_INT,
                          messageData.size(),
                          MPI_STATUS_IGNORE);
          assert(actual == messageData);

          // Recv a second time
          std::vector<int> actual2(messageData2.size());
          otherWorld.recv(rankA,
                          rankB,
                          BYTES(actual2.data()),
                          MPI_INT,
                          messageData2.size(),
                          MPI_STATUS_IGNORE);
          assert(actual2 == messageData2);

          // Send last message
          otherWorld.send(rankB,
                          rankA,
                          BYTES(messageData.data()),
                          MPI_INT,
                          messageData.size());

          testLatch->wait();

          otherWorld.destroy();
      });

    std::vector<bool> endpointCheck;
    std::vector<bool> expectedEndpoints = { false, true, false, false };

    // Sending a message initialises the remote endpoint
    thisWorld.send(
      rankA, rankB, BYTES(messageData.data()), MPI_INT, messageData.size());

    // Check the right messaging endpoint has been created
    endpointCheck = thisWorld.getInitedRemoteMpiEndpoints();
    REQUIRE(endpointCheck == expectedEndpoints);

    // Sending a second message re-uses the existing endpoint
    thisWorld.send(
      rankA, rankB, BYTES(messageData2.data()), MPI_INT, messageData2.size());

    // Check that no additional endpoints have been created
    endpointCheck = thisWorld.getInitedRemoteMpiEndpoints();
    REQUIRE(endpointCheck == expectedEndpoints);

    // Finally recv a message, the same endpoint should be used again
    std::vector<int> actual(messageData.size());
    thisWorld.recv(rankB,
                   rankA,
                   BYTES(actual.data()),
                   MPI_INT,
                   messageData.size(),
                   MPI_STATUS_IGNORE);
    // REQUIRE rather than assert: this runs on the main test thread and must
    // not be compiled out when NDEBUG is defined
    REQUIRE(actual == messageData);

    // Check that no extra endpoint has been created
    endpointCheck = thisWorld.getInitedRemoteMpiEndpoints();
    REQUIRE(endpointCheck == expectedEndpoints);

    testLatch->wait();

    // Clean up
    if (otherWorldThread.joinable()) {
        otherWorldThread.join();
    }

    thisWorld.destroy();
}

// Checks that an unacked message buffer (UMB) is created for each
// (sender, receiver) pair the first time an asynchronous receive is posted
// for that pair, and that no other UMB slots are touched.
TEST_CASE_METHOD(RemoteMpiTestFixture, "Test UMB creation", "[mpi]")
{
    // Register three ranks
    setWorldSizes(3, 1, 2);
    int thisWorldRank = 0;
    int otherWorldRank1 = 1;
    int otherWorldRank2 = 2;
    std::vector<int> messageData = { 0, 1, 2 };
    std::vector<int> messageData2 = { 3, 4 };

    // Init worlds
    MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId);
    faabric::util::setMockMode(false);
    thisWorld.broadcastHostsToRanks();

    std::thread otherWorldThread([this,
                                  thisWorldRank,
                                  otherWorldRank1,
                                  otherWorldRank2,
                                  &messageData,
                                  &messageData2] {
        otherWorld.initialiseFromMsg(msg);

        // Send message from one rank
        otherWorld.send(otherWorldRank1,
                        thisWorldRank,
                        BYTES(messageData.data()),
                        MPI_INT,
                        messageData.size());

        // Send message from the other rank
        otherWorld.send(otherWorldRank2,
                        thisWorldRank,
                        BYTES(messageData2.data()),
                        MPI_INT,
                        messageData2.size());

        testLatch->wait();

        otherWorld.destroy();
    });

    std::vector<bool> umbCheck;
    std::vector<bool> expectedUmb1 = { false, false, false, true, false,
                                       false, false, false, false };
    std::vector<bool> expectedUmb2 = { false, false, false, true, false,
                                       false, true,  false, false };

    // Irecv a message from one rank, a UMB should be created. The receive
    // buffer is a vector that lives until the end of the test, so it outlives
    // the async request and is released automatically.
    std::vector<int> actual1(messageData.size());
    int recvId1 = thisWorld.irecv(otherWorldRank1,
                                  thisWorldRank,
                                  BYTES(actual1.data()),
                                  MPI_INT,
                                  messageData.size());

    // Check that a UMB has been created
    umbCheck = thisWorld.getInitedUMB();
    REQUIRE(umbCheck == expectedUmb1);

    // Irecv a message from another rank, another UMB should be created.
    // The buffer is sized for the message actually received here.
    std::vector<int> actual2(messageData2.size());
    int recvId2 = thisWorld.irecv(otherWorldRank2,
                                  thisWorldRank,
                                  BYTES(actual2.data()),
                                  MPI_INT,
                                  messageData2.size());

    // Check that an extra UMB has been created
    umbCheck = thisWorld.getInitedUMB();
    REQUIRE(umbCheck == expectedUmb2);

    // Wait for both messages
    thisWorld.awaitAsyncRequest(recvId1);
    thisWorld.awaitAsyncRequest(recvId2);

    // Sanity check the message content. REQUIRE rather than assert: this runs
    // on the main test thread and must not be compiled out under NDEBUG.
    REQUIRE(actual1 == messageData);
    REQUIRE(actual2 == messageData2);

    testLatch->wait();

    // Clean up
    if (otherWorldThread.joinable()) {
        otherWorldThread.join();
    }

    thisWorld.destroy();
}
}