diff --git a/include/faabric/util/queue.h b/include/faabric/util/queue.h
index 9095bde36..7dc8973ef 100644
--- a/include/faabric/util/queue.h
+++ b/include/faabric/util/queue.h
@@ -6,7 +6,7 @@
 
 #include
 
-#define DEFAULT_QUEUE_TIMEOUT_MS 500
+#define DEFAULT_QUEUE_TIMEOUT_MS 5000
 
 namespace faabric::util {
 class QueueTimeoutException : public faabric::util::FaabricException
diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp
index 8f3566a9c..e0a214fdd 100644
--- a/src/scheduler/MpiWorld.cpp
+++ b/src/scheduler/MpiWorld.cpp
@@ -1155,6 +1155,7 @@ void MpiWorld::barrier(int thisRank)
 {
     if (thisRank == 0) {
         // This is the root, hence just does the waiting
+        SPDLOG_TRACE("MPI - barrier init {}", thisRank);
 
         // Await messages from all others
         for (int r = 1; r < size; r++) {
diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp
index e696a4ecb..5144de4ba 100644
--- a/tests/test/scheduler/test_mpi_world.cpp
+++ b/tests/test/scheduler/test_mpi_world.cpp
@@ -8,6 +8,8 @@
 #include
 #include
 
+#include <thread>
+
 using namespace faabric::scheduler;
 
 namespace tests {
@@ -171,6 +173,41 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test cartesian communicator", "[mpi]")
     world.destroy();
 }
 
+TEST_CASE_METHOD(MpiBaseTestFixture, "Test local barrier", "[mpi]")
+{
+    // Create the world
+    int worldSize = 2;
+    MpiWorld world;
+    world.create(msg, worldId, worldSize);
+
+    int rankA1 = 0;
+    int rankA2 = 1;
+    std::vector<int> sendData = { 0, 1, 2 };
+    std::vector<int> recvData = { -1, -1, -1 };
+
+    std::thread senderThread([&world, rankA1, rankA2, &sendData, &recvData] {
+        world.send(
+          rankA1, rankA2, BYTES(sendData.data()), MPI_INT, sendData.size());
+
+        world.barrier(rankA1);
+        assert(sendData == recvData);
+    });
+
+    world.recv(rankA1,
+               rankA2,
+               BYTES(recvData.data()),
+               MPI_INT,
+               recvData.size(),
+               MPI_STATUS_IGNORE);
+
+    REQUIRE(recvData == sendData);
+
+    world.barrier(rankA2);
+
+    senderThread.join();
+    world.destroy();
+}
+
 void checkMessage(faabric::MPIMessage& actualMessage,
                   int worldId,
                   int senderRank,
diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp
index cbe84d573..0669fb245 100644
--- a/tests/test/scheduler/test_remote_mpi_worlds.cpp
+++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -156,6 +156,49 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     localWorld.destroy();
 }
 
+TEST_CASE_METHOD(RemoteMpiTestFixture, "Test barrier across hosts", "[mpi]")
+{
+    // Register two ranks (one on each host)
+    this->setWorldsSizes(2, 1, 1);
+    int rankA = 0;
+    int rankB = 1;
+    std::vector<int> sendData = { 0, 1, 2 };
+    std::vector<int> recvData = { -1, -1, -1 };
+
+    // Init worlds
+    MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
+    faabric::util::setMockMode(false);
+
+    std::thread senderThread([this, rankA, rankB, &sendData, &recvData] {
+        remoteWorld.initialiseFromMsg(msg);
+
+        remoteWorld.send(
+          rankB, rankA, BYTES(sendData.data()), MPI_INT, sendData.size());
+
+        // Barrier on this rank
+        remoteWorld.barrier(rankB);
+        assert(sendData == recvData);
+
+        remoteWorld.destroy();
+    });
+
+    // Receive the message for the given rank
+    localWorld.recv(rankB,
+                    rankA,
+                    BYTES(recvData.data()),
+                    MPI_INT,
+                    recvData.size(),
+                    MPI_STATUS_IGNORE);
+    REQUIRE(recvData == sendData);
+
+    // Call barrier to synchronise remote host
+    localWorld.barrier(rankA);
+
+    // Destroy worlds
+    senderThread.join();
+    localWorld.destroy();
+}
+
 TEST_CASE_METHOD(RemoteMpiTestFixture,
                  "Test sending many messages across host",
                  "[mpi]")