From 3101386d2793b74f4f16fff6190ce17bc0a0dca5 Mon Sep 17 00:00:00 2001
From: Carlos Segarra
Date: Mon, 14 Jun 2021 08:28:39 +0000
Subject: [PATCH] adding more tests

---
 .../test/scheduler/test_remote_mpi_worlds.cpp | 145 +++++++++++++++++-
 1 file changed, 137 insertions(+), 8 deletions(-)

diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp
index a0ff83363..fe1e42dfd 100644
--- a/tests/test/scheduler/test_remote_mpi_worlds.cpp
+++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -53,7 +53,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]")
         // Send a message that should get sent to this host
         remoteWorld.send(
           rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size());
-        usleep(1000 * 500);
 
         remoteWorld.destroy();
     });
@@ -93,21 +92,18 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     faabric::util::setMockMode(false);
 
     std::thread senderThread([this, rankA, rankB, numMessages] {
-        std::vector<int> messageData = { 0, 1, 2 };
-
         remoteWorld.initialiseFromMsg(msg);
 
         for (int i = 0; i < numMessages; i++) {
-            remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, sizeof(int));
+            remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1);
         }
-        usleep(1000 * 500);
         remoteWorld.destroy();
     });
 
     int recv;
     for (int i = 0; i < numMessages; i++) {
         localWorld.recv(
-          rankB, rankA, BYTES(&recv), MPI_INT, sizeof(int), MPI_STATUS_IGNORE);
+          rankB, rankA, BYTES(&recv), MPI_INT, 1, MPI_STATUS_IGNORE);
 
         // Check in-order delivery
         if (i % (numMessages / 10) == 0) {
@@ -331,7 +327,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     int recvRank = 0;
     std::vector<int> messageData = { 0, 1, 2 };
 
-    // Initi world
+    // Init world
     MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
     faabric::util::setMockMode(false);
 
@@ -352,7 +348,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
                          MPI_INT,
                          messageData.size());
 
-        usleep(1000 * 500);
         remoteWorld.destroy();
     });
 
@@ -384,4 +379,138 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     senderThread.join();
     localWorld.destroy();
 }
+
+TEST_CASE_METHOD(RemoteMpiTestFixture,
+                 "Test receiving remote async requests out of order",
+                 "[mpi]")
+{
+    // Allocate two ranks in total, one rank per host
+    this->setWorldsSizes(2, 1, 1);
+    int sendRank = 1;
+    int recvRank = 0;
+
+    // Init world
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    std::thread senderThread([this, sendRank, recvRank] {
+        remoteWorld.initialiseFromMsg(msg);
+
+        // Send different messages
+        for (int i = 0; i < 3; i++) {
+            remoteWorld.send(sendRank,
+                             recvRank,
+                             BYTES(&i),
+                             MPI_INT,
+                             1);
+        }
+
+        remoteWorld.destroy();
+    });
+
+    // Receive two messages asynchronously
+    int recv1, recv2, recv3;
+    int recvId1 = localWorld.irecv(sendRank,
+                                   recvRank,
+                                   BYTES(&recv1),
+                                   MPI_INT,
+                                   1);
+
+    int recvId2 = localWorld.irecv(sendRank,
+                                   recvRank,
+                                   BYTES(&recv2),
+                                   MPI_INT,
+                                   1);
+
+    // Receive one message synchronously
+    localWorld.recv(sendRank,
+                    recvRank,
+                    BYTES(&recv3),
+                    MPI_INT,
+                    1,
+                    MPI_STATUS_IGNORE);
+
+    SECTION("Wait out of order")
+    {
+        localWorld.awaitAsyncRequest(recvId2);
+        localWorld.awaitAsyncRequest(recvId1);
+    }
+
+    SECTION("Wait in order")
+    {
+        localWorld.awaitAsyncRequest(recvId1);
+        localWorld.awaitAsyncRequest(recvId2);
+    }
+
+    // Checks
+    REQUIRE(recv1 == 0);
+    REQUIRE(recv2 == 1);
+    REQUIRE(recv3 == 2);
+
+    // Destroy world
+    senderThread.join();
+    localWorld.destroy();
+}
+
+TEST_CASE_METHOD(RemoteMpiTestFixture,
+                 "Test ring sendrecv across hosts",
+                 "[mpi]")
+{
+    // Allocate three ranks in total: one local and two remote
+    this->setWorldsSizes(3, 1, 2);
+    int worldSize = 3;
+    std::vector<int> localRanks = {0};
+
+    // Init world
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    std::thread senderThread([this, worldSize] {
+        std::vector<int> remoteRanks = {1, 2};
+        remoteWorld.initialiseFromMsg(msg);
+
+        // Send and receive a message around the ring
+        for (auto& rank : remoteRanks) {
+            int left = rank > 0 ? rank - 1 : worldSize - 1;
+            int right = (rank + 1) % worldSize;
+            int recvData = -1;
+
+            remoteWorld.sendRecv(BYTES(&rank),
+                                 1,
+                                 MPI_INT,
+                                 right,
+                                 BYTES(&recvData),
+                                 1,
+                                 MPI_INT,
+                                 left,
+                                 rank,
+                                 MPI_STATUS_IGNORE);
+        }
+
+        remoteWorld.destroy();
+    });
+
+    for (auto& rank : localRanks) {
+        int left = rank > 0 ? rank - 1 : worldSize - 1;
+        int right = (rank + 1) % worldSize;
+        int recvData = -1;
+
+        localWorld.sendRecv(BYTES(&rank),
+                            1,
+                            MPI_INT,
+                            right,
+                            BYTES(&recvData),
+                            1,
+                            MPI_INT,
+                            left,
+                            rank,
+                            MPI_STATUS_IGNORE);
+
+        REQUIRE(recvData == left);
+    }
+
+    // Destroy world
+    senderThread.join();
+    localWorld.destroy();
+}
 }