From 886d0f5ff3385ec399a2b327202f49704aafcbc4 Mon Sep 17 00:00:00 2001
From: Carlos Segarra
Date: Mon, 14 Jun 2021 14:41:51 +0000
Subject: [PATCH] pr comments

---
 include/faabric/scheduler/MpiWorld.h          |   3 +
 include/faabric/transport/MessageEndpoint.h   |   2 +
 .../faabric/transport/MpiMessageEndpoint.h    |   4 +
 src/scheduler/MpiWorld.cpp                    | 146 +++++----
 src/transport/MessageEndpoint.cpp             |   9 +-
 src/transport/MpiMessageEndpoint.cpp          |  24 +-
 .../test/scheduler/test_remote_mpi_worlds.cpp | 302 ++++++++++--------
 .../test_message_endpoint_client.cpp          |  13 -
 .../transport/test_mpi_message_endpoint.cpp   |  50 +++
 tests/utils/fixtures.h                        |  14 +
 10 files changed, 353 insertions(+), 214 deletions(-)
 create mode 100644 tests/test/transport/test_mpi_message_endpoint.cpp

diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h
index baecbe33a..dfe07c99e 100644
--- a/include/faabric/scheduler/MpiWorld.h
+++ b/include/faabric/scheduler/MpiWorld.h
@@ -212,6 +212,7 @@ class MpiWorld
     void initLocalQueues();
 
     // Rank-to-rank sockets for remote messaging
+    void initRemoteMpiEndpoint(int sendRank, int recvRank);
     int getMpiPort(int sendRank, int recvRank);
     void sendRemoteMpiMessage(int sendRank,
                               int recvRank,
@@ -219,5 +220,7 @@
     std::shared_ptr<faabric::MPIMessage> recvRemoteMpiMessage(int sendRank,
                                                               int recvRank);
     void closeMpiMessageEndpoints();
+
+    void checkRanksRange(int sendRank, int recvRank);
 };
 }
diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h
index 4e463e07a..658448455 100644
--- a/include/faabric/transport/MessageEndpoint.h
+++ b/include/faabric/transport/MessageEndpoint.h
@@ -83,6 +83,8 @@ class RecvMessageEndpoint : public MessageEndpoint
   public:
     RecvMessageEndpoint(int portIn);
 
+    RecvMessageEndpoint(int portIn, const std::string& overrideHost);
+
     void open(MessageContext& context);
 
     void close();
diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h
index 6cc1552dd..40067fc6b 100644
--- a/include/faabric/transport/MpiMessageEndpoint.h
+++ b/include/faabric/transport/MpiMessageEndpoint.h
@@ -23,6 +23,10 @@ class MpiMessageEndpoint
   public:
     MpiMessageEndpoint(const std::string& hostIn, int portIn);
 
+    MpiMessageEndpoint(const std::string& hostIn,
+                       int portIn,
+                       const std::string& overrideRecvHost);
+
     void sendMpiMessage(const std::shared_ptr<faabric::MPIMessage>& msg);
 
     std::shared_ptr<faabric::MPIMessage> recvMpiMessage();
diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp
index ff223eccb..077de1deb 100644
--- a/src/scheduler/MpiWorld.cpp
+++ b/src/scheduler/MpiWorld.cpp
@@ -19,42 +19,65 @@ MpiWorld::MpiWorld()
   , cartProcsPerDim(2)
 {}
 
-void MpiWorld::sendRemoteMpiMessage(
-  int sendRank,
-  int recvRank,
-  const std::shared_ptr<faabric::MPIMessage>& msg)
+void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank)
 {
-    // Assert the ranks are sane
-    assert(0 <= sendRank && sendRank < size);
-    assert(0 <= recvRank && recvRank < size);
-
-    // Initialise the endpoint vector if not initialised
-    if (mpiMessageEndpoints.size() == 0) {
+    // Resize the message endpoint vector and initialise to null. Note that we
+    // allocate size x size slots to cover all possible (sendRank, recvRank)
+    // pairs
+    if (mpiMessageEndpoints.empty()) {
         for (int i = 0; i < size * size; i++) {
             mpiMessageEndpoints.emplace_back(nullptr);
         }
     }
 
+    // Get host for recv rank
+    std::string otherHost;
+    std::string recvHost = getHostForRank(recvRank);
+    std::string sendHost = getHostForRank(sendRank);
+    if (recvHost == sendHost) {
+        SPDLOG_ERROR(
+          "Send and recv ranks in the same host: SEND {}, RECV{} in {}",
+          sendRank,
+          recvRank,
+          sendHost);
+        throw std::runtime_error("Send and recv ranks in the same host");
+    } else if (recvHost == thisHost) {
+        otherHost = sendHost;
+    } else if (sendHost == thisHost) {
+        otherHost = recvHost;
+    } else {
+        SPDLOG_ERROR("Send and recv ranks correspond to remote hosts: SEND {} "
+                     "in {}, RECV {} in {}",
+                     sendRank,
+                     sendHost,
+                     recvRank,
+                     recvHost);
+        throw std::runtime_error("Send and recv ranks in remote hosts");
+    }
+
+    // Get the index for the rank-host pair
+    int index = getIndexForRanks(sendRank, recvRank);
+
+    // Get port for send-recv pair
+    int port = getMpiPort(sendRank, recvRank);
+
+    // Create MPI message endpoint
+    mpiMessageEndpoints.emplace(
+      mpiMessageEndpoints.begin() + index,
+      std::make_unique<faabric::transport::MpiMessageEndpoint>(
+        otherHost, port, thisHost));
+}
+
+void MpiWorld::sendRemoteMpiMessage(
+  int sendRank,
+  int recvRank,
+  const std::shared_ptr<faabric::MPIMessage>& msg)
+{
     // Get the index for the rank-host pair
     int index = getIndexForRanks(sendRank, recvRank);
-    assert(index >= 0 && index < size * size);
 
-    // Lazily initialise send endpoints
-    if (mpiMessageEndpoints[index] == nullptr) {
-        // Get host for recv rank
-        std::string host = getHostForRank(recvRank);
-        assert(!host.empty());
-        assert(host != thisHost);
-
-        // Get port for send-recv pair
-        int port = getMpiPort(sendRank, recvRank);
-        // TODO - finer-grained checking when multi-tenant port scheme
-        assert(port > 0);
-
-        // Create MPI message endpoint
-        mpiMessageEndpoints.emplace(
-          mpiMessageEndpoints.begin() + index,
-          std::make_unique<faabric::transport::MpiMessageEndpoint>(host, port));
+    if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) {
+        initRemoteMpiEndpoint(sendRank, recvRank);
     }
 
     mpiMessageEndpoints[index]->sendMpiMessage(msg);
@@ -64,36 +87,11 @@ std::shared_ptr<faabric::MPIMessage> MpiWorld::recvRemoteMpiMessage(
   int sendRank,
   int recvRank)
 {
-    // Assert the ranks are sane
-    assert(0 <= sendRank && sendRank < size);
-    assert(0 <= recvRank && recvRank < size);
-
-    // Initialise the endpoint vector if not initialised
-    if (mpiMessageEndpoints.size() == 0) {
-        for (int i = 0; i < size * size; i++) {
-            mpiMessageEndpoints.emplace_back(nullptr);
-        }
-    }
-
     // Get the index for the rank-host pair
     int index = getIndexForRanks(sendRank, recvRank);
-    assert(index >= 0 && index < size * size);
 
-    // Lazily initialise send endpoints
-    if (mpiMessageEndpoints[index] == nullptr) {
-        // Get host for recv rank
-        std::string host = getHostForRank(sendRank);
-        assert(!host.empty());
-        assert(host != thisHost);
-
-        // Get port for send-recv pair
-        int port = getMpiPort(sendRank, recvRank);
-        // TODO - finer-grained checking when multi-tenant port scheme
-        assert(port > 0);
-
-        mpiMessageEndpoints.emplace(
-          mpiMessageEndpoints.begin() + index,
-          std::make_unique<faabric::transport::MpiMessageEndpoint>(host, port));
+    if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) {
+        initRemoteMpiEndpoint(sendRank, recvRank);
     }
 
     return mpiMessageEndpoints[index]->recvMpiMessage();
@@ -254,11 +252,6 @@ std::string MpiWorld::getHostForRank(int rank)
 {
     assert(rankHosts.size() == size);
 
-    if (rank >= size) {
-        throw std::runtime_error(
-          fmt::format("Rank bigger than world size ({} > {})", rank, size));
-    }
-
     std::string host = rankHosts[rank];
     if (host.empty()) {
         throw std::runtime_error(
@@ -476,6 +469,14 @@ void MpiWorld::send(int sendRank,
                     int count,
                     faabric::MPIMessage::MPIMessageType messageType)
 {
+    // Sanity-check input parameters
+    checkRanksRange(sendRank, recvRank);
+    if (getHostForRank(sendRank) != thisHost) {
+        SPDLOG_ERROR("Trying to send message from a non-local rank: {}",
+                     sendRank);
+        throw std::runtime_error("Sending message from non-local rank");
+    }
+
     // Work out whether the message is sent locally or to another host
     const std::string otherHost = getHostForRank(recvRank);
     bool isLocal = otherHost == thisHost;
@@ -516,8 +517,15 @@ void MpiWorld::recv(int sendRank,
                     MPI_Status* status,
                     faabric::MPIMessage::MPIMessageType messageType)
 {
+    // Sanity-check input parameters
+    checkRanksRange(sendRank, recvRank);
+    if (getHostForRank(recvRank) != thisHost) {
+        SPDLOG_ERROR("Trying to recv message into a non-local rank: {}",
+                     recvRank);
+        throw std::runtime_error("Receiving message into non-local rank");
+    }
+
     // Work out whether the message is sent locally or from another host
-    assert(thisHost == getHostForRank(recvRank));
     const std::string otherHost = getHostForRank(sendRank);
     bool isLocal = otherHost == thisHost;
 
@@ -1174,7 +1182,9 @@ void MpiWorld::initLocalQueues()
 
 int MpiWorld::getIndexForRanks(int sendRank, int recvRank)
 {
-    return sendRank * size + recvRank;
+    int index = sendRank * size + recvRank;
+    assert(index >= 0 && index < size * size);
+    return index;
 }
 
 long MpiWorld::getLocalQueueSize(int sendRank, int recvRank)
@@ -1214,4 +1224,18 @@ void MpiWorld::overrideHost(const std::string& newHost)
 {
     thisHost = newHost;
 }
+
+void MpiWorld::checkRanksRange(int sendRank, int recvRank)
+{
+    if (sendRank < 0 || sendRank >= size) {
+        SPDLOG_ERROR(
+          "Send rank outside range: {} not in [0, {})", sendRank, size);
+        throw std::runtime_error("Send rank outside range");
+    }
+    if (recvRank < 0 || recvRank >= size) {
+        SPDLOG_ERROR(
+          "Recv rank outside range: {} not in [0, {})", recvRank, size);
+        throw std::runtime_error("Recv rank outside range");
+    }
+}
 }
diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp
index 9f68e9327..5f87412b3 100644
--- a/src/transport/MessageEndpoint.cpp
+++ b/src/transport/MessageEndpoint.cpp
@@ -279,10 +279,15 @@ RecvMessageEndpoint::RecvMessageEndpoint(int portIn)
   : MessageEndpoint(ANY_HOST, portIn)
 {}
 
+RecvMessageEndpoint::RecvMessageEndpoint(int portIn,
+                                         const std::string& overrideHost)
+  : MessageEndpoint(overrideHost, portIn)
+{}
+
 void RecvMessageEndpoint::open(MessageContext& context)
 {
     SPDLOG_TRACE(
-      fmt::format("Opening socket: {} (RECV {}:{})", id, ANY_HOST, port));
+      fmt::format("Opening socket: {} (RECV {}:{})", id, host, port));
 
     MessageEndpoint::open(context, SocketType::PULL, true);
 }
@@ -290,7 +295,7 @@ void RecvMessageEndpoint::open(MessageContext& context)
 void RecvMessageEndpoint::close()
 {
     SPDLOG_TRACE(
-      fmt::format("Closing socket: {} (RECV {}:{})", id, ANY_HOST, port));
+      fmt::format("Closing socket: {} (RECV {}:{})", id, host, port));
 
     MessageEndpoint::close(true);
 }
diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp
index 897c51cb9..4aebbefcd 100644
--- a/src/transport/MpiMessageEndpoint.cpp
+++ b/src/transport/MpiMessageEndpoint.cpp
@@ -31,16 +31,24 @@ void sendMpiHostRankMsg(const std::string& hostIn,
 MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn)
   : sendMessageEndpoint(hostIn, portIn)
   , recvMessageEndpoint(portIn)
-{}
+{
+    sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
+    recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
+}
+
+MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn,
+                                       int portIn,
+                                       const std::string& overrideRecvHost)
+  : sendMessageEndpoint(hostIn, portIn)
+  , recvMessageEndpoint(portIn, overrideRecvHost)
+{
+    sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
+    recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
+}
 
 void MpiMessageEndpoint::sendMpiMessage(
   const std::shared_ptr<faabric::MPIMessage>& msg)
 {
-    // TODO - is this lazy init very expensive?
-    if (sendMessageEndpoint.socket == nullptr) {
-        sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
-    }
-
     size_t msgSize = msg->ByteSizeLong();
     {
         uint8_t sMsg[msgSize];
@@ -53,10 +61,6 @@ void MpiMessageEndpoint::sendMpiMessage(
 
 std::shared_ptr<faabric::MPIMessage> MpiMessageEndpoint::recvMpiMessage()
 {
-    if (recvMessageEndpoint.socket == nullptr) {
-        recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
-    }
-
     Message m = recvMessageEndpoint.recv();
     PARSE_MSG(faabric::MPIMessage, m.data(), m.size());
 
diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp
index 48a36ddaf..84aeaa2cf 100644
--- a/tests/test/scheduler/test_remote_mpi_worlds.cpp
+++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -14,6 +14,28 @@ using namespace faabric::scheduler;
 
 namespace tests {
+class RemoteCollectiveTestFixture : public RemoteMpiTestFixture
+{
+  public:
+    RemoteCollectiveTestFixture()
+      : thisWorldSize(6)
+      , remoteRankA(1)
+      , remoteRankB(2)
+      , remoteRankC(3)
+      , localRankA(4)
+      , localRankB(5)
+      , remoteWorldRanks({ remoteRankB, remoteRankC, remoteRankA })
+      , localWorldRanks({ localRankB, localRankA, 0 })
+    {}
+
+  protected:
+    int thisWorldSize;
+    int remoteRankA, remoteRankB, remoteRankC;
+    int localRankA, localRankB;
+    std::vector<int> remoteWorldRanks;
+    std::vector<int> localWorldRanks;
+};
+
 TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]")
 {
     // Allocate two ranks in total, one rank per host
@@ -120,35 +142,23 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     localWorld.destroy();
 }
 
-TEST_CASE_METHOD(RemoteMpiTestFixture,
-                 "Test collective messaging across hosts",
+TEST_CASE_METHOD(RemoteCollectiveTestFixture,
+                 "Test broadcast across hosts",
                  "[mpi]")
 {
     // Here we rely on the scheduler running out of resources, and overloading
     // the localWorld with ranks 4 and 5
-    int thisWorldSize = 6;
     this->setWorldsSizes(thisWorldSize, 1, 3);
-    int remoteRankA = 1;
-    int remoteRankB = 2;
-    int remoteRankC = 3;
-    int localRankA = 4;
-    int localRankB = 5;
+
+    std::vector<int> messageData = { 0, 1, 2 };
 
     // Init worlds
     MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
-    remoteWorld.initialiseFromMsg(msg);
     faabric::util::setMockMode(false);
 
-    // Note that ranks are deliberately out of order
-    std::vector<int> remoteWorldRanks = { remoteRankB,
-                                          remoteRankC,
-                                          remoteRankA };
-    std::vector<int> localWorldRanks = { localRankB, localRankA, 0 };
+    std::thread senderThread([this, &messageData] {
+        remoteWorld.initialiseFromMsg(msg);
 
-    SECTION("Broadcast")
-    {
         // Broadcast a message
-        std::vector<int> messageData = { 0, 1, 2 };
         remoteWorld.broadcast(
           remoteRankB, BYTES(messageData.data()), MPI_INT, messageData.size());
@@ -161,28 +171,46 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
             std::vector<int> actual(3, -1);
             remoteWorld.recv(
              remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr);
-            REQUIRE(actual == messageData);
+            assert(actual == messageData);
         }
 
-        // Check the local host
-        for (int rank : localWorldRanks) {
-            std::vector<int> actual(3, -1);
-            localWorld.recv(
-              remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr);
-            REQUIRE(actual == messageData);
-        }
+        remoteWorld.destroy();
+    });
+
+    // Check the local host
+    for (int rank : localWorldRanks) {
+        std::vector<int> actual(3, -1);
+        localWorld.recv(
+          remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr);
+        REQUIRE(actual == messageData);
     }
 
-    SECTION("Scatter")
-    {
-        // Build the data
-        int nPerRank = 4;
-        int dataSize = nPerRank * thisWorldSize;
-        std::vector<int> messageData(dataSize, 0);
-        for (int i = 0; i < dataSize; i++) {
-            messageData[i] = i;
-        }
+    senderThread.join();
+    localWorld.destroy();
+}
+
+TEST_CASE_METHOD(RemoteCollectiveTestFixture,
+                 "Test scatter across hosts",
+                 "[mpi]")
+{
+    // Here we rely on the scheduler running out of resources, and overloading
+    // the localWorld with ranks 4 and 5
+    this->setWorldsSizes(thisWorldSize, 1, 3);
+    // Init worlds
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    // Build the data
+    int nPerRank = 4;
+    int dataSize = nPerRank * thisWorldSize;
+    std::vector<int> messageData(dataSize, 0);
+    for (int i = 0; i < dataSize; i++) {
+        messageData[i] = i;
+    }
+
+    std::thread senderThread([this, nPerRank, &messageData] {
+        remoteWorld.initialiseFromMsg(msg);
 
         // Do the scatter
         std::vector<int> actual(nPerRank, -1);
         remoteWorld.scatter(remoteRankB,
@@ -195,7 +223,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
                             nPerRank);
 
         // Check for root
-        REQUIRE(actual == std::vector<int>({ 8, 9, 10, 11 }));
+        assert(actual == std::vector<int>({ 8, 9, 10, 11 }));
 
         // Check for other remote ranks
         remoteWorld.scatter(remoteRankB,
@@ -206,7 +234,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
                             BYTES(actual.data()),
                             MPI_INT,
                             nPerRank);
-        REQUIRE(actual == std::vector<int>({ 4, 5, 6, 7 }));
+        assert(actual == std::vector<int>({ 4, 5, 6, 7 }));
 
         remoteWorld.scatter(remoteRankB,
                             remoteRankC,
@@ -216,108 +244,126 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
                             BYTES(actual.data()),
                             MPI_INT,
                             nPerRank);
-        REQUIRE(actual == std::vector<int>({ 12, 13, 14, 15 }));
-
-        // Check for local ranks
-        localWorld.scatter(remoteRankB,
-                           0,
-                           nullptr,
-                           MPI_INT,
-                           nPerRank,
-                           BYTES(actual.data()),
-                           MPI_INT,
-                           nPerRank);
-        REQUIRE(actual == std::vector<int>({ 0, 1, 2, 3 }));
-
-        localWorld.scatter(remoteRankB,
-                           localRankB,
-                           nullptr,
-                           MPI_INT,
-                           nPerRank,
-                           BYTES(actual.data()),
-                           MPI_INT,
-                           nPerRank);
-        REQUIRE(actual == std::vector<int>({ 20, 21, 22, 23 }));
-
-        localWorld.scatter(remoteRankB,
-                           localRankA,
-                           nullptr,
-                           MPI_INT,
-                           nPerRank,
-                           BYTES(actual.data()),
-                           MPI_INT,
-                           nPerRank);
-        REQUIRE(actual == std::vector<int>({ 16, 17, 18, 19 }));
-    }
-
-    SECTION("Gather and allgather")
-    {
-        // Build the data for each rank
-        int nPerRank = 4;
-        std::vector<std::vector<int>> rankData;
-        for (int i = 0; i < thisWorldSize; i++) {
-            std::vector<int> thisRankData;
-            for (int j = 0; j < nPerRank; j++) {
-                thisRankData.push_back((i * nPerRank) + j);
-            }
-
-            rankData.push_back(thisRankData);
-        }
-
-        // Build the expectation
-        std::vector<int> expected;
-        for (int i = 0; i < thisWorldSize * nPerRank; i++) {
-            expected.push_back(i);
-        }
-
-        SECTION("Gather")
-        {
-            std::vector<int> actual(thisWorldSize * nPerRank, -1);
-
-            // Call gather for each rank other than the root (out of order)
-            int root = localRankA;
-            for (int rank : remoteWorldRanks) {
-                remoteWorld.gather(rank,
-                                   root,
-                                   BYTES(rankData[rank].data()),
-                                   MPI_INT,
-                                   nPerRank,
-                                   nullptr,
-                                   MPI_INT,
-                                   nPerRank);
-            }
-
-            for (int rank : localWorldRanks) {
-                if (rank == root) {
-                    continue;
-                }
-                localWorld.gather(rank,
-                                  root,
-                                  BYTES(rankData[rank].data()),
-                                  MPI_INT,
-                                  nPerRank,
-                                  nullptr,
-                                  MPI_INT,
-                                  nPerRank);
-            }
-
-            // Call gather for root
-            localWorld.gather(root,
-                              root,
-                              BYTES(rankData[root].data()),
-                              MPI_INT,
-                              nPerRank,
-                              BYTES(actual.data()),
-                              MPI_INT,
-                              nPerRank);
-
-            // Check data
-            REQUIRE(actual == expected);
-        }
-    }
-
-    // Destroy worlds
+        assert(actual == std::vector<int>({ 12, 13, 14, 15 }));
+
+        remoteWorld.destroy();
+    });
+
+    // Check for local ranks
+    std::vector<int> actual(nPerRank, -1);
+    localWorld.scatter(remoteRankB,
+                       0,
+                       nullptr,
+                       MPI_INT,
+                       nPerRank,
+                       BYTES(actual.data()),
+                       MPI_INT,
+                       nPerRank);
+    REQUIRE(actual == std::vector<int>({ 0, 1, 2, 3 }));
+
+    localWorld.scatter(remoteRankB,
+                       localRankB,
+                       nullptr,
+                       MPI_INT,
+                       nPerRank,
+                       BYTES(actual.data()),
+                       MPI_INT,
+                       nPerRank);
+    REQUIRE(actual == std::vector<int>({ 20, 21, 22, 23 }));
+
+    localWorld.scatter(remoteRankB,
+                       localRankA,
+                       nullptr,
+                       MPI_INT,
+                       nPerRank,
+                       BYTES(actual.data()),
+                       MPI_INT,
+                       nPerRank);
+    REQUIRE(actual == std::vector<int>({ 16, 17, 18, 19 }));
+
+    senderThread.join();
+    localWorld.destroy();
+}
+
+TEST_CASE_METHOD(RemoteCollectiveTestFixture,
+                 "Test gather across hosts",
+                 "[mpi]")
+{
+    // Here we rely on the scheduler running out of resources, and overloading
+    // the localWorld with ranks 4 and 5
+    this->setWorldsSizes(thisWorldSize, 1, 3);
+
+    // Init worlds
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    // Build the data for each rank
+    int nPerRank = 4;
+    std::vector<std::vector<int>> rankData;
+    for (int i = 0; i < thisWorldSize; i++) {
+        std::vector<int> thisRankData;
+        for (int j = 0; j < nPerRank; j++) {
+            thisRankData.push_back((i * nPerRank) + j);
+        }
+
+        rankData.push_back(thisRankData);
+    }
+
+    // Build the expectation
+    std::vector<int> expected;
+    for (int i = 0; i < thisWorldSize * nPerRank; i++) {
+        expected.push_back(i);
+    }
+
+    std::vector<int> actual(thisWorldSize * nPerRank, -1);
+
+    // Call gather for each rank other than the root (out of order)
+    int root = localRankA;
+    std::thread senderThread([this, root, &rankData, nPerRank] {
+        remoteWorld.initialiseFromMsg(msg);
+
+        for (int rank : remoteWorldRanks) {
+            remoteWorld.gather(rank,
+                               root,
+                               BYTES(rankData[rank].data()),
+                               MPI_INT,
+                               nPerRank,
+                               nullptr,
+                               MPI_INT,
+                               nPerRank);
+        }
+
+        remoteWorld.destroy();
+    });
+
+    for (int rank : localWorldRanks) {
+        if (rank == root) {
+            continue;
+        }
+        localWorld.gather(rank,
+                          root,
+                          BYTES(rankData[rank].data()),
+                          MPI_INT,
+                          nPerRank,
+                          nullptr,
+                          MPI_INT,
+                          nPerRank);
+    }
+
+    // Call gather for root
+    localWorld.gather(root,
+                      root,
+                      BYTES(rankData[root].data()),
+                      MPI_INT,
+                      nPerRank,
+                      BYTES(actual.data()),
+                      MPI_INT,
+                      nPerRank);
+
+    // Check data
+    REQUIRE(actual == expected);
+
+    senderThread.join();
     localWorld.destroy();
-    remoteWorld.destroy();
 }
 }
diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp
index 9fe72c518..7119602ab 100644
--- a/tests/test/transport/test_message_endpoint_client.cpp
+++ b/tests/test/transport/test_message_endpoint_client.cpp
@@ -13,19 +13,6 @@ const int testPort = 9999;
 const int testReplyPort = 9996;
 
 namespace tests {
-class MessageContextFixture : public SchedulerTestFixture
-{
-  protected:
-    MessageContext& context;
-
-  public:
-    MessageContextFixture()
-      : context(getGlobalMessageContext())
-    {}
-
-    ~MessageContextFixture() { context.close(); }
-};
-
 TEST_CASE_METHOD(MessageContextFixture,
                  "Test open/close one client",
                  "[transport]")
diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp
new file mode 100644
index 000000000..a5b6a4ced
--- /dev/null
+++ b/tests/test/transport/test_mpi_message_endpoint.cpp
@@ -0,0 +1,50 @@
+#include <catch2/catch.hpp>
+
+#include <faabric/transport/MpiMessageEndpoint.h>
+#include <faabric_utils.h>
+
+using namespace faabric::transport;
+
+namespace tests {
+TEST_CASE_METHOD(MessageContextFixture,
+                 "Test send and recv the hosts to rank message",
+                 "[transport]")
+{
+    // Prepare message
+    std::vector<std::string> expected = { "foo", "bar" };
+    faabric::MpiHostsToRanksMessage sendMsg;
+    *sendMsg.mutable_hosts() = { expected.begin(), expected.end() };
+    sendMpiHostRankMsg(LOCALHOST, sendMsg);
+
+    // Send message
+    faabric::MpiHostsToRanksMessage actual = recvMpiHostRankMsg();
+
+    // Checks
+    REQUIRE(actual.hosts().size() == expected.size());
+    for (int i = 0; i < actual.hosts().size(); i++) {
+        REQUIRE(actual.hosts().Get(i) == expected[i]);
+    }
+}
+
+TEST_CASE_METHOD(MessageContextFixture,
+                 "Test send and recv an MPI message",
+                 "[transport]")
+{
+    std::string thisHost = faabric::util::getSystemConfig().endpointHost;
+    MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, thisHost);
+    MpiMessageEndpoint recvEndpoint(thisHost, 9999, LOCALHOST);
+
+    std::shared_ptr<faabric::MPIMessage> expected =
+      std::make_shared<faabric::MPIMessage>();
+    expected->set_id(1337);
+
+    sendEndpoint.sendMpiMessage(expected);
+    std::shared_ptr<faabric::MPIMessage> actual = recvEndpoint.recvMpiMessage();
+
+    // Checks
+    REQUIRE(expected->id() == actual->id());
+
+    REQUIRE_NOTHROW(sendEndpoint.close());
+    REQUIRE_NOTHROW(recvEndpoint.close());
+}
+}
diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h
index 1568b3cf4..b102db298 100644
--- a/tests/utils/fixtures.h
+++ b/tests/utils/fixtures.h
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include <faabric/transport/MessageContext.h>
 #include
 #include
 #include
@@ -117,6 +118,19 @@ class ConfTestFixture
     faabric::util::SystemConfig& conf;
 };
 
+class MessageContextFixture : public SchedulerTestFixture
+{
+  protected:
+    faabric::transport::MessageContext& context;
+
+  public:
+    MessageContextFixture()
+      : context(faabric::transport::getGlobalMessageContext())
+    {}
+
+    ~MessageContextFixture() { context.close(); }
+};
+
 class MpiBaseTestFixture : public SchedulerTestFixture
 {
   public: