diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp
index 541025e17..056b7f2a3 100644
--- a/tests/test/scheduler/test_remote_mpi_worlds.cpp
+++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -7,6 +7,8 @@
 #include
 #include
+#include <thread>
+
 #include

 using namespace faabric::scheduler;

@@ -76,6 +78,50 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]")
     localWorld.destroy();
 }

+TEST_CASE_METHOD(RemoteMpiTestFixture,
+                 "Test sending many messages across host",
+                 "[mpi]")
+{
+    // Register two ranks (one on each host)
+    this->setWorldsSizes(2, 1, 1);
+    int rankA = 0;
+    int rankB = 1;
+    int numMessages = 1000;
+
+    // Init worlds
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    std::thread senderThread([this, rankA, rankB, numMessages] {
+        std::vector<int> messageData = { 0, 1, 2 };
+
+        remoteWorld.initialiseFromMsg(msg);
+
+        for (int i = 0; i < numMessages; i++) {
+            // faabric::util::getLogger()->info("{}", i);
+            remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, sizeof(int));
+        }
+        usleep(1000 * 500);
+        remoteWorld.destroy();
+    });
+
+    int recv;
+    for (int i = 0; i < numMessages; i++) {
+        localWorld.recv(
+          rankB, rankA, BYTES(&recv), MPI_INT, sizeof(int), MPI_STATUS_IGNORE);
+
+        // Check in-order delivery
+        if (i % (numMessages / 10) == 0) {
+            faabric::util::getLogger()->info("recv {} - i {}", recv, i);
+            REQUIRE(recv == i);
+        }
+    }
+
+    // Destroy worlds
+    senderThread.join();
+    localWorld.destroy();
+}
+
 TEST_CASE_METHOD(RemoteMpiTestFixture,
                  "Test collective messaging across hosts",
                  "[mpi]")