diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h
index 2ce29ae1c..12da2166e 100644
--- a/include/faabric/scheduler/MpiWorld.h
+++ b/include/faabric/scheduler/MpiWorld.h
@@ -180,6 +180,10 @@ class MpiWorld
 
     double getWTime();
 
+    std::vector<bool> getInitedRemoteMpiEndpoints();
+
+    std::vector<bool> getInitedUMB();
+
   private:
     int id = -1;
     int size = -1;
diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h
index 1f5e8f07e..43f5ae8ad 100644
--- a/include/faabric/transport/MessageEndpoint.h
+++ b/include/faabric/transport/MessageEndpoint.h
@@ -33,7 +33,7 @@ class MessageEndpoint
     MessageEndpoint(const std::string& hostIn, int portIn, int timeoutMsIn);
 
     // Delete assignment and copy-constructor as we need to be very careful with
-    // socping and same-thread instantiation
+    // scoping and same-thread instantiation
    MessageEndpoint& operator=(const MessageEndpoint&) = delete;
 
    MessageEndpoint(const MessageEndpoint& ctx) = delete;
diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp
index 1be498a53..fcb7ed1b4 100644
--- a/src/scheduler/MpiWorld.cpp
+++ b/src/scheduler/MpiWorld.cpp
@@ -110,10 +110,9 @@ void MpiWorld::initRemoteMpiEndpoint(int localRank, int remoteRank)
     std::pair<int, int> sendRecvPorts = getPortForRanks(localRank, remoteRank);
 
     // Create MPI message endpoint
-    mpiMessageEndpoints.emplace(
-      mpiMessageEndpoints.begin() + index,
+    mpiMessageEndpoints.at(index) =
       std::make_unique<MpiMessageEndpoint>(
-        otherHost, sendRecvPorts.first, sendRecvPorts.second));
+        otherHost, sendRecvPorts.first, sendRecvPorts.second);
 }
 
 void MpiWorld::sendRemoteMpiMessage(
@@ -164,9 +163,8 @@ MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank)
     assert(index >= 0 && index < size * size);
 
     if (unackedMessageBuffers[index] == nullptr) {
-        unackedMessageBuffers.emplace(
-          unackedMessageBuffers.begin() + index,
-          std::make_shared<faabric::scheduler::MpiMessageBuffer>());
+        unackedMessageBuffers.at(index) =
+          std::make_shared<faabric::scheduler::MpiMessageBuffer>();
     }
 
     return unackedMessageBuffers[index];
@@ -1379,6 +1377,26 @@ double MpiWorld::getWTime()
     return t / 1000.0;
 }
 
+std::vector<bool> MpiWorld::getInitedRemoteMpiEndpoints()
+{
+    std::vector<bool> retVec(mpiMessageEndpoints.size());
+    for (int i = 0; i < mpiMessageEndpoints.size(); i++) {
+        retVec.at(i) = mpiMessageEndpoints.at(i) != nullptr;
+    }
+
+    return retVec;
+}
+
+std::vector<bool> MpiWorld::getInitedUMB()
+{
+    std::vector<bool> retVec(unackedMessageBuffers.size());
+    for (int i = 0; i < unackedMessageBuffers.size(); i++) {
+        retVec.at(i) = unackedMessageBuffers.at(i) != nullptr;
+    }
+
+    return retVec;
+}
+
 std::string MpiWorld::getUser()
 {
     return user;
diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp
index 1cceed404..a16febff7 100644
--- a/tests/test/scheduler/test_remote_mpi_worlds.cpp
+++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -723,4 +723,194 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
 
     thisWorld.destroy();
 }
+
+TEST_CASE_METHOD(RemoteMpiTestFixture,
+                 "Test remote message endpoint creation",
+                 "[mpi]")
+{
+    // Register two ranks (one on each host)
+    setWorldSizes(2, 1, 1);
+    int rankA = 0;
+    int rankB = 1;
+    std::vector<int> messageData = { 0, 1, 2 };
+    std::vector<int> messageData2 = { 3, 4 };
+
+    // Init worlds
+    MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+    thisWorld.broadcastHostsToRanks();
+
+    std::thread otherWorldThread(
+      [this, rankA, rankB, &messageData, &messageData2] {
+          otherWorld.initialiseFromMsg(msg);
+
+          // Recv once
+          auto buffer = new int[messageData.size()];
+          otherWorld.recv(rankA,
+                          rankB,
+                          BYTES(buffer),
+                          MPI_INT,
+                          messageData.size(),
+                          MPI_STATUS_IGNORE);
+          std::vector<int> actual(buffer, buffer + messageData.size());
+          assert(actual == messageData);
+
+          // Recv a second time
+          auto buffer2 = new int[messageData2.size()];
+          otherWorld.recv(rankA,
+                          rankB,
+                          BYTES(buffer2),
+                          MPI_INT,
+                          messageData2.size(),
+                          MPI_STATUS_IGNORE);
+          std::vector<int> actual2(buffer2, buffer2 + messageData2.size());
+          assert(actual2 == messageData2);
+
+          // Send last message
+          otherWorld.send(rankB,
+                          rankA,
+                          BYTES(messageData.data()),
+                          MPI_INT,
+                          messageData.size());
+
+          testLatch->wait();
+
+          otherWorld.destroy();
+      });
+
+    std::vector<bool> endpointCheck;
+    std::vector<bool> expectedEndpoints = { false, true, false, false };
+
+    // Sending a message initialises the remote endpoint
+    thisWorld.send(
+      rankA, rankB, BYTES(messageData.data()), MPI_INT, messageData.size());
+
+    // Check the right messaging endpoint has been created
+    endpointCheck = thisWorld.getInitedRemoteMpiEndpoints();
+    REQUIRE(endpointCheck == expectedEndpoints);
+
+    // Sending a second message reuses the existing endpoint
+    thisWorld.send(
+      rankA, rankB, BYTES(messageData2.data()), MPI_INT, messageData2.size());
+
+    // Check that no additional endpoints have been created
+    endpointCheck = thisWorld.getInitedRemoteMpiEndpoints();
+    REQUIRE(endpointCheck == expectedEndpoints);
+
+    // Finally recv a message, the same endpoint should be used again
+    auto buffer = new int[messageData.size()];
+    thisWorld.recv(rankB,
+                   rankA,
+                   BYTES(buffer),
+                   MPI_INT,
+                   messageData.size(),
+                   MPI_STATUS_IGNORE);
+    std::vector<int> actual(buffer, buffer + messageData.size());
+    assert(actual == messageData);
+
+    // Check that no extra endpoint has been created
+    endpointCheck = thisWorld.getInitedRemoteMpiEndpoints();
+    REQUIRE(endpointCheck == expectedEndpoints);
+
+    testLatch->wait();
+
+    // Clean up
+    if (otherWorldThread.joinable()) {
+        otherWorldThread.join();
+    }
+
+    thisWorld.destroy();
+}
+
+TEST_CASE_METHOD(RemoteMpiTestFixture, "Test UMB creation", "[mpi]")
+{
+    // Register three ranks (one local, two remote)
+    setWorldSizes(3, 1, 2);
+    int thisWorldRank = 0;
+    int otherWorldRank1 = 1;
+    int otherWorldRank2 = 2;
+    std::vector<int> messageData = { 0, 1, 2 };
+    std::vector<int> messageData2 = { 3, 4 };
+
+    // Init worlds
+    MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+    thisWorld.broadcastHostsToRanks();
+
+    std::thread otherWorldThread([this,
+                                  thisWorldRank,
+                                  otherWorldRank1,
+                                  otherWorldRank2,
+                                  &messageData,
+                                  &messageData2] {
+        otherWorld.initialiseFromMsg(msg);
+
+        // Send message from one rank
+        otherWorld.send(otherWorldRank1,
+                        thisWorldRank,
+                        BYTES(messageData.data()),
+                        MPI_INT,
+                        messageData.size());
+
+        // Send message from the other rank
+        otherWorld.send(otherWorldRank2,
+                        thisWorldRank,
+                        BYTES(messageData2.data()),
+                        MPI_INT,
+                        messageData2.size());
+
+        testLatch->wait();
+
+        otherWorld.destroy();
+    });
+
+    std::vector<bool> umbCheck;
+    std::vector<bool> expectedUmb1 = { false, false, false, true, false,
+                                       false, false, false, false };
+    std::vector<bool> expectedUmb2 = { false, false, false, true, false,
+                                       false, true,  false, false };
+
+    // Irecv a message from one rank, a UMB should be created
+    auto buffer1 = new int[messageData.size()];
+    int recvId1 = thisWorld.irecv(otherWorldRank1,
+                                  thisWorldRank,
+                                  BYTES(buffer1),
+                                  MPI_INT,
+                                  messageData.size());
+
+    // Check that a UMB has been created
+    umbCheck = thisWorld.getInitedUMB();
+    REQUIRE(umbCheck == expectedUmb1);
+
+    // Irecv a message from the other rank, another UMB should be created
+    auto buffer2 = new int[messageData.size()];
+    int recvId2 = thisWorld.irecv(otherWorldRank2,
+                                  thisWorldRank,
+                                  BYTES(buffer2),
+                                  MPI_INT,
+                                  messageData2.size());
+
+    // Check that an extra UMB has been created
+    umbCheck = thisWorld.getInitedUMB();
+    REQUIRE(umbCheck == expectedUmb2);
+
+    // Wait for both messages
+    thisWorld.awaitAsyncRequest(recvId1);
+    thisWorld.awaitAsyncRequest(recvId2);
+
+    // Sanity check the message content
+    std::vector<int> actual1(buffer1, buffer1 + messageData.size());
+    assert(actual1 == messageData);
+    std::vector<int> actual2(buffer2, buffer2 + messageData2.size());
+    assert(actual2 == messageData2);
+
+    testLatch->wait();
+
+    // Clean up
+    if (otherWorldThread.joinable()) {
+        otherWorldThread.join();
+    }
+
+    thisWorld.destroy();
+}
 }
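
Note on the core fix: `std::vector::emplace` with an iterator *inserts* a new element at that position and shifts every later element along, growing the vector. Since `mpiMessageEndpoints` and `unackedMessageBuffers` are pre-sized with one slot per rank pair, the old code corrupted the index-to-rank-pair mapping on every lazy initialisation; assigning through `at(index)` overwrites the slot in place. A minimal standalone sketch of the difference, where `buggy` and `fixed` are hypothetical stand-ins for those pre-sized member vectors:

#include <cassert>
#include <memory>
#include <vector>

int main()
{
    // Stand-ins for a member vector pre-sized to one slot per rank pair
    std::vector<std::unique_ptr<int>> buggy(4);
    std::vector<std::unique_ptr<int>> fixed(4);

    // Old approach: emplace() at an iterator INSERTS, shifting later
    // slots along, so subsequent indices no longer match rank pairs
    buggy.emplace(buggy.begin() + 1, std::make_unique<int>(42));
    assert(buggy.size() == 5); // vector has grown

    // Fixed approach: assignment through at() overwrites the slot in
    // place, leaving the vector's size and layout untouched
    fixed.at(1) = std::make_unique<int>(42);
    assert(fixed.size() == 4); // layout preserved

    return 0;
}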