From 1a34a0077937449a803358437ddc0b46596c2a6c Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Tue, 8 Jun 2021 08:15:51 +0000 Subject: [PATCH 1/7] adding p2p send/recv methods --- include/faabric/scheduler/MpiWorld.h | 2 ++ .../faabric/transport/MpiMessageEndpoint.h | 6 ++++ src/scheduler/MpiWorld.cpp | 29 ++++++++++++++--- src/transport/MpiMessageEndpoint.cpp | 32 +++++++++++++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index efe6ad1b9..3514be218 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -59,6 +59,8 @@ class MpiWorld int* source, int* destination); + int getMpiPort(int sendRank, int recvRank); + void send(int sendRank, int recvRank, const uint8_t* buffer, diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 252a75bcc..0641b5449 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -10,4 +10,10 @@ faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); void sendMpiHostRankMsg(const std::string& hostIn, const faabric::MpiHostsToRanksMessage msg); + +void sendMpiMessage(const std::string& hostIn, + int portIn, + const std::shared_ptr msg); + +std::shared_ptr recvMpiMessage(int portIn); } diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 61af649ad..389d8df28 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -396,6 +396,15 @@ int MpiWorld::irecv(int sendRank, return requestId; } +int MpiWorld::getMpiPort(int sendRank, int recvRank) +{ + // TODO - get port in a multi-tenant-safe manner + int basePort = MPI_PORT; + int rankOffset = sendRank * size + recvRank; + + return basePort + rankOffset; +} + void MpiWorld::send(int sendRank, int recvRank, const uint8_t* buffer, @@ -443,10 +452,22 @@ void MpiWorld::recv(int sendRank, MPI_Status* status, faabric::MPIMessage::MPIMessageType messageType) { - // Listen to the in-memory queue for this rank and message type - SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank); - std::shared_ptr m = - getLocalQueue(sendRank, recvRank)->dequeue(); + // Work out whether the message is sent locally or from another host + assert(thisHost == getHostForRank(recvRank)); + const std::string otherHost = getHostForRank(sendRank); + bool isLocal = otherHost == thisHost; + + // Recv message + std::shared_ptr m; + if (isLocal) { + SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank); + // TODO - change to ipc sockets + m = getLocalQueue(sendRank, recvRank)->dequeue(); + } else { + SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank); + m = faabric::transport::recvMpiMessage(getMpiPort(sendRank, recvRank)); + } + assert(m != nullptr); // Assert message integrity // Note - this checks won't happen in Release builds diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index 0076ccc65..d8727c5cc 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -27,4 +27,36 @@ void sendMpiHostRankMsg(const std::string& hostIn, endpoint.close(); } } + + +// TODO - reuse clients!! 
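// For illustration: with MpiWorld::getMpiPort as added earlier in this patch,
// every (sendRank, recvRank) pair gets a dedicated port. For a world of size 4
// and MPI_PORT = 8800:
//   getMpiPort(0, 1) = 8800 + 0 * 4 + 1 = 8801
//   getMpiPort(1, 0) = 8800 + 1 * 4 + 0 = 8804
// so the two directions of a rank pair use separate sockets. Two concurrent
// worlds of the same size would currently collide on these ports, which is
// what the multi-tenant TODO in getMpiPort refers to.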
+void sendMpiMessage(const std::string& hostIn, int portIn, + const std::shared_ptr msg) +{ + size_t msgSize = msg->ByteSizeLong(); + { + uint8_t sMsg[msgSize]; + if (!msg->SerializeToArray(sMsg, msgSize)) { + throw std::runtime_error("Error serialising message"); + } + SendMessageEndpoint endpoint(hostIn, portIn); + endpoint.open(getGlobalMessageContext()); + endpoint.send(sMsg, msgSize, false); + endpoint.close(); + } +} + +std::shared_ptr recvMpiMessage(int portIn) +{ + RecvMessageEndpoint endpoint(portIn); + endpoint.open(getGlobalMessageContext()); + // TODO - preempt data size somehow + Message m = endpoint.recv(); + PARSE_MSG(faabric::MPIMessage, m.data(), m.size()); + // Note - This may be very slow as we poll until unbound + endpoint.close(); + + // TODO - send normal message, not shared_ptr + return std::make_shared(msg); +} } From f4c0b5ebcf8da4f52731bef31386283943291384 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Tue, 8 Jun 2021 08:19:02 +0000 Subject: [PATCH 2/7] not check queueing when we don't use queues --- src/scheduler/MpiWorld.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 389d8df28..9327bca77 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -461,7 +461,6 @@ void MpiWorld::recv(int sendRank, std::shared_ptr m; if (isLocal) { SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank); - // TODO - change to ipc sockets m = getLocalQueue(sendRank, recvRank)->dequeue(); } else { SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank); From 185658de8e0ad167ebb09c23338a03dc1d97c747 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Tue, 8 Jun 2021 14:16:48 +0000 Subject: [PATCH 3/7] removing sendMPImessage from the function call API --- include/faabric/scheduler/FunctionCallApi.h | 11 ++- .../faabric/scheduler/FunctionCallClient.h | 4 -- .../faabric/scheduler/FunctionCallServer.h | 4 -- include/faabric/scheduler/MpiWorld.h | 3 - include/faabric/transport/common.h | 3 +- src/scheduler/FunctionCallClient.cpp | 19 ----- src/scheduler/FunctionCallServer.cpp | 12 ---- src/scheduler/MpiWorld.cpp | 21 ++---- .../scheduler/test_function_client_server.cpp | 71 ------------------- 9 files changed, 10 insertions(+), 138 deletions(-) diff --git a/include/faabric/scheduler/FunctionCallApi.h b/include/faabric/scheduler/FunctionCallApi.h index 04e18725d..0da78587c 100644 --- a/include/faabric/scheduler/FunctionCallApi.h +++ b/include/faabric/scheduler/FunctionCallApi.h @@ -4,11 +4,10 @@ namespace faabric::scheduler { enum FunctionCalls { NoFunctionCall = 0, - MpiMessage = 1, - ExecuteFunctions = 2, - Flush = 3, - Unregister = 4, - GetResources = 5, - SetThreadResult = 6, + ExecuteFunctions = 1, + Flush = 2, + Unregister = 3, + GetResources = 4, + SetThreadResult = 5, }; } diff --git a/include/faabric/scheduler/FunctionCallClient.h b/include/faabric/scheduler/FunctionCallClient.h index 9df02f79d..55e105231 100644 --- a/include/faabric/scheduler/FunctionCallClient.h +++ b/include/faabric/scheduler/FunctionCallClient.h @@ -19,8 +19,6 @@ std::vector< std::pair>> getBatchRequests(); -std::vector> getMPIMessages(); - std::vector> getResourceRequests(); @@ -44,8 +42,6 @@ class FunctionCallClient : public faabric::transport::MessageEndpointClient void sendFlush(); - void sendMPIMessage(const std::shared_ptr msg); - faabric::HostResources getResources(); void executeFunctions( diff --git a/include/faabric/scheduler/FunctionCallServer.h b/include/faabric/scheduler/FunctionCallServer.h index 
71778a618..962cc7439 100644 --- a/include/faabric/scheduler/FunctionCallServer.h +++ b/include/faabric/scheduler/FunctionCallServer.h @@ -2,8 +2,6 @@ #include #include -#include -#include #include #include #include @@ -25,8 +23,6 @@ class FunctionCallServer final /* Function call server API */ - void recvMpiMessage(faabric::transport::Message& body); - void recvFlush(faabric::transport::Message& body); void recvExecuteFunctions(faabric::transport::Message& body); diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index 3514be218..5a0bb9dee 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -210,9 +210,6 @@ class MpiWorld std::vector cartProcsPerDim; - faabric::scheduler::FunctionCallClient& getFunctionCallClient( - const std::string& otherHost); - void closeThreadLocalClients(); int getIndexForRanks(int sendRank, int recvRank); diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index ad8b84f40..b7957801e 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -5,8 +5,7 @@ #define DEFAULT_SNAPSHOT_HOST "0.0.0.0" #define STATE_PORT 8003 #define FUNCTION_CALL_PORT 8004 -#define MPI_MESSAGE_PORT 8005 -#define SNAPSHOT_PORT 8006 +#define SNAPSHOT_PORT 8005 #define REPLY_PORT_OFFSET 100 #define MPI_PORT 8800 diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 74056b89c..a32701e8e 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -19,8 +19,6 @@ static std::vector< std::pair>> batchMessages; -static std::vector> mpiMessages; - static std::vector> resourceRequests; @@ -48,11 +46,6 @@ getBatchRequests() return batchMessages; } -std::vector> getMPIMessages() -{ - return mpiMessages; -} - std::vector> getResourceRequests() { @@ -74,7 +67,6 @@ void clearMockRequests() { functionCalls.clear(); batchMessages.clear(); - mpiMessages.clear(); resourceRequests.clear(); unregisterRequests.clear(); @@ -113,17 +105,6 @@ void FunctionCallClient::sendFlush() } } -void FunctionCallClient::sendMPIMessage( - const std::shared_ptr msg) -{ - if (faabric::util::isMockMode()) { - faabric::util::UniqueLock lock(mockMutex); - mpiMessages.emplace_back(host, *msg); - } else { - SEND_MESSAGE_PTR(faabric::scheduler::FunctionCalls::MpiMessage, msg); - } -} - faabric::HostResources FunctionCallClient::getResources() { faabric::ResponseRequest request; diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 0f2526b89..55a2416d9 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -27,9 +27,6 @@ void FunctionCallServer::doRecv(faabric::transport::Message& header, assert(header.size() == sizeof(uint8_t)); uint8_t call = static_cast(*header.data()); switch (call) { - case faabric::scheduler::FunctionCalls::MpiMessage: - this->recvMpiMessage(body); - break; case faabric::scheduler::FunctionCalls::Flush: this->recvFlush(body); break; @@ -48,15 +45,6 @@ void FunctionCallServer::doRecv(faabric::transport::Message& header, } } -void FunctionCallServer::recvMpiMessage(faabric::transport::Message& body) -{ - PARSE_MSG(faabric::MPIMessage, body.data(), body.size()) - - MpiWorldRegistry& registry = getMpiWorldRegistry(); - MpiWorld& world = registry.getWorld(msg.worldid()); - world.enqueueMessage(msg); -} - void FunctionCallServer::recvFlush(faabric::transport::Message& body) { PARSE_MSG(faabric::ResponseRequest, 
body.data(), body.size()); diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 9327bca77..51e2bffb1 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -25,6 +25,7 @@ MpiWorld::MpiWorld() , cartProcsPerDim(2) {} +/* faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient( const std::string& otherHost) { @@ -39,6 +40,7 @@ faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient( } return it->second; } +*/ int MpiWorld::getMpiThreadPoolSize() { @@ -152,6 +154,7 @@ void MpiWorld::shutdownThreadPool() closeThreadLocalClients(); } +// TODO - remove // Clear thread local state void MpiWorld::closeThreadLocalClients() { @@ -440,7 +443,7 @@ void MpiWorld::send(int sendRank, getLocalQueue(sendRank, recvRank)->enqueue(std::move(m)); } else { SPDLOG_TRACE("MPI - send remote {} -> {}", sendRank, recvRank); - getFunctionCallClient(otherHost).sendMPIMessage(m); + faabric::transport::sendMpiMessage(otherHost, getMpiPort(sendRank, recvRank), m); } } @@ -1080,22 +1083,6 @@ void MpiWorld::barrier(int thisRank) } } -void MpiWorld::enqueueMessage(faabric::MPIMessage& msg) -{ - if (msg.worldid() != id) { - SPDLOG_ERROR( - "Queueing message not meant for this world (msg={}, this={})", - msg.worldid(), - id); - throw std::runtime_error("Queueing message not for this world"); - } - - SPDLOG_TRACE( - "Queueing message locally {} -> {}", msg.sender(), msg.destination()); - getLocalQueue(msg.sender(), msg.destination()) - ->enqueue(std::make_shared(msg)); -} - std::shared_ptr MpiWorld::getLocalQueue(int sendRank, int recvRank) { diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index f90684f77..6019ae191 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -51,77 +51,6 @@ class ClientServerFixture } }; -TEST_CASE_METHOD(ClientServerFixture, "Test sending MPI message", "[scheduler]") -{ - auto& sch = faabric::scheduler::getScheduler(); - - // Force the scheduler to initialise a world in the remote host by setting - // a world size bigger than the slots available locally - int worldSize = 2; - faabric::HostResources localResources; - localResources.set_slots(1); - localResources.set_usedslots(1); - faabric::HostResources otherResources; - otherResources.set_slots(1); - - // Set up a remote host - std::string otherHost = LOCALHOST; - sch.addHostToGlobalSet(otherHost); - - // Mock everything to make sure the other host has resources as well - faabric::util::setMockMode(true); - sch.setThisHostResources(localResources); - faabric::scheduler::queueResourceResponse(otherHost, otherResources); - - // Create an MPI world on this host and one on a "remote" host - const char* user = "mpi"; - const char* func = "hellompi"; - int worldId = 123; - faabric::Message msg; - msg.set_user(user); - msg.set_function(func); - msg.set_mpiworldid(worldId); - msg.set_mpiworldsize(worldSize); - faabric::util::messageFactory(user, func); - - scheduler::MpiWorldRegistry& registry = getMpiWorldRegistry(); - scheduler::MpiWorld& localWorld = registry.createWorld(msg, worldId); - - scheduler::MpiWorld remoteWorld; - remoteWorld.overrideHost(otherHost); - remoteWorld.initialiseFromMsg(msg); - - // Register a rank on each - int rankLocal = 0; - int rankRemote = 1; - - // Undo the mocking, so we actually send the MPI message - faabric::util::setMockMode(false); - - // Create a message - faabric::MPIMessage mpiMsg; - 
mpiMsg.set_worldid(worldId); - mpiMsg.set_sender(rankRemote); - mpiMsg.set_destination(rankLocal); - - // Send the message - cli.sendMPIMessage(std::make_shared(mpiMsg)); - usleep(1000 * TEST_TIMEOUT_MS); - - // Make sure the message has been put on the right queue locally - std::shared_ptr queue = - localWorld.getLocalQueue(rankRemote, rankLocal); - REQUIRE(queue->size() == 1); - const std::shared_ptr actualMessage = queue->dequeue(); - - REQUIRE(actualMessage->worldid() == worldId); - REQUIRE(actualMessage->sender() == rankRemote); - - localWorld.destroy(); - remoteWorld.destroy(); - registry.clear(); -} - TEST_CASE_METHOD(ClientServerFixture, "Test sending flush message", "[scheduler]") From 3a291698d5e34731bc5ce848a775a560aa950106 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Wed, 9 Jun 2021 09:56:30 +0000 Subject: [PATCH 4/7] update tests --- tests/test/scheduler/test_mpi_world.cpp | 4 +-- .../test/scheduler/test_remote_mpi_worlds.cpp | 27 ------------------- 2 files changed, 1 insertion(+), 30 deletions(-) diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp index 402ebec8e..a344421d1 100644 --- a/tests/test/scheduler/test_mpi_world.cpp +++ b/tests/test/scheduler/test_mpi_world.cpp @@ -1,12 +1,10 @@ #include #include -#include -#include #include +#include #include #include -#include #include #include diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 947311eb1..c9031a7e2 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include #include @@ -32,11 +31,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]") TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") { - // Start a server on this host - FunctionCallServer server; - server.start(); - usleep(1000 * 100); - // Register two ranks (one on each host) this->setWorldsSizes(2, 1, 1); int rankA = 0; @@ -53,19 +47,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); usleep(1000 * 100); - SECTION("Check queueing") - { - REQUIRE(localWorld.getLocalQueueSize(rankB, rankA) == 1); - - // Check message content - faabric::MPIMessage actualMessage = - *(localWorld.getLocalQueue(rankB, rankA)->dequeue()); - REQUIRE(actualMessage.worldid() == worldId); - REQUIRE(actualMessage.count() == messageData.size()); - REQUIRE(actualMessage.sender() == rankB); - REQUIRE(actualMessage.destination() == rankA); - } - SECTION("Check recv") { // Receive the message for the given rank @@ -85,18 +66,12 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") // Destroy worlds localWorld.destroy(); remoteWorld.destroy(); - - server.stop(); } TEST_CASE_METHOD(RemoteMpiTestFixture, "Test collective messaging across hosts", "[mpi]") { - FunctionCallServer server; - server.start(); - usleep(1000 * 100); - // Here we rely on the scheduler running out of resources, and overloading // the localWorld with ranks 4 and 5 int thisWorldSize = 6; @@ -292,7 +267,5 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, // Destroy worlds localWorld.destroy(); remoteWorld.destroy(); - - server.stop(); } } From 13200da7a5c3b2161cfafd50750a4589030a61d7 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Wed, 9 Jun 2021 14:12:58 +0000 Subject: [PATCH 5/7] reusing sockets for rank-to-rank communication --- 
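Notes: the gist of this change is that, instead of opening and closing a socket
per message (as the sendMpiMessage/recvMpiMessage free functions did), each
(sendRank, recvRank) pair now keeps a long-lived MpiMessageEndpoint, cached in
a thread-local vector indexed by sendRank * worldSize + recvRank and created
lazily on first use. A rough sketch of that pattern, with simplified names and
assuming the MpiMessageEndpoint class introduced in this patch:

    #include <memory>
    #include <string>
    #include <vector>

    #include <faabric/transport/MpiMessageEndpoint.h>

    using faabric::transport::MpiMessageEndpoint;

    // One slot per (sendRank, recvRank) pair, kept per thread
    static thread_local std::vector<std::unique_ptr<MpiMessageEndpoint>>
      endpoints;

    MpiMessageEndpoint& getOrCreateEndpoint(int sendRank,
                                            int recvRank,
                                            int worldSize,
                                            const std::string& otherHost,
                                            int port)
    {
        // Allocate worldSize * worldSize empty slots on first use
        if (endpoints.empty()) {
            endpoints.resize(worldSize * worldSize);
        }

        // Same index scheme as MpiWorld::getIndexForRanks
        int idx = sendRank * worldSize + recvRank;

        // Create the endpoint for this pair only when first needed
        if (endpoints[idx] == nullptr) {
            endpoints[idx] =
              std::make_unique<MpiMessageEndpoint>(otherHost, port);
        }

        return *endpoints[idx];
    }

In MpiWorld the equivalent logic lives in sendRemoteMpiMessage and
recvRemoteMpiMessage (and is factored into initRemoteMpiEndpoint in the last
patch of this series), using getIndexForRanks for the index and getMpiPort for
the port.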
include/faabric/scheduler/MpiWorld.h | 30 +++-- .../faabric/transport/MpiMessageEndpoint.h | 25 +++- src/scheduler/MpiWorld.cpp | 121 +++++++++++++----- src/transport/MpiMessageEndpoint.cpp | 43 ++++--- tests/test/scheduler/test_mpi_world.cpp | 2 +- .../test/scheduler/test_remote_mpi_worlds.cpp | 20 ++- 6 files changed, 172 insertions(+), 69 deletions(-) diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index 5a0bb9dee..baecbe33a 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -6,7 +6,9 @@ #include #include #include -#include +#include +#include +#include #include #include @@ -15,8 +17,6 @@ namespace faabric::scheduler { typedef faabric::util::Queue> InMemoryMpiQueue; -std::string getWorldStateKey(int worldId); - class MpiWorld { public: @@ -43,8 +43,6 @@ class MpiWorld void shutdownThreadPool(); - void enqueueMessage(faabric::MPIMessage& msg); - void getCartesianRank(int rank, int maxDims, const int* dims, @@ -59,8 +57,6 @@ class MpiWorld int* source, int* destination); - int getMpiPort(int sendRank, int recvRank); - void send(int sendRank, int recvRank, const uint8_t* buffer, @@ -200,20 +196,28 @@ class MpiWorld std::string user; std::string function; - std::shared_ptr stateKV; - std::vector rankHosts; - - std::vector> localQueues; - std::shared_ptr threadPool; int getMpiThreadPoolSize(); std::vector cartProcsPerDim; - void closeThreadLocalClients(); + /* MPI internal messaging layer */ + // Track at which host each rank lives + std::vector rankHosts; int getIndexForRanks(int sendRank, int recvRank); + // In-memory queues for local messaging + std::vector> localQueues; void initLocalQueues(); + + // Rank-to-rank sockets for remote messaging + int getMpiPort(int sendRank, int recvRank); + void sendRemoteMpiMessage(int sendRank, + int recvRank, + const std::shared_ptr& msg); + std::shared_ptr recvRemoteMpiMessage(int sendRank, + int recvRank); + void closeMpiMessageEndpoints(); }; } diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 0641b5449..6cc1552dd 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -6,14 +6,31 @@ #include namespace faabric::transport { +/* These two abstract methods are used to broadcast the host-rank mapping at + * initialisation time. + */ faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); void sendMpiHostRankMsg(const std::string& hostIn, const faabric::MpiHostsToRanksMessage msg); -void sendMpiMessage(const std::string& hostIn, - int portIn, - const std::shared_ptr msg); +/* This class abstracts the notion of a communication channel between two MPI + * ranks. There will always be one rank local to this host, and one remote. + * Note that the port is unique per (user, function, sendRank, recvRank) tuple. 
+ */ +class MpiMessageEndpoint +{ + public: + MpiMessageEndpoint(const std::string& hostIn, int portIn); -std::shared_ptr recvMpiMessage(int portIn); + void sendMpiMessage(const std::shared_ptr& msg); + + std::shared_ptr recvMpiMessage(); + + void close(); + + private: + SendMessageEndpoint sendMessageEndpoint; + RecvMessageEndpoint recvMessageEndpoint; +}; } diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 51e2bffb1..ff223eccb 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -1,20 +1,14 @@ -#include - -#include #include #include -#include #include #include #include -#include #include -#include static thread_local std::unordered_map> futureMap; -static thread_local std::unordered_map - functionCallClients; +static thread_local std::vector< + std::unique_ptr> + mpiMessageEndpoints; namespace faabric::scheduler { MpiWorld::MpiWorld() @@ -25,22 +19,85 @@ MpiWorld::MpiWorld() , cartProcsPerDim(2) {} -/* -faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient( - const std::string& otherHost) +void MpiWorld::sendRemoteMpiMessage( + int sendRank, + int recvRank, + const std::shared_ptr& msg) +{ + // Assert the ranks are sane + assert(0 <= sendRank && sendRank < size); + assert(0 <= recvRank && recvRank < size); + + // Initialise the endpoint vector if not initialised + if (mpiMessageEndpoints.size() == 0) { + for (int i = 0; i < size * size; i++) { + mpiMessageEndpoints.emplace_back(nullptr); + } + } + + // Get the index for the rank-host pair + int index = getIndexForRanks(sendRank, recvRank); + assert(index >= 0 && index < size * size); + + // Lazily initialise send endpoints + if (mpiMessageEndpoints[index] == nullptr) { + // Get host for recv rank + std::string host = getHostForRank(recvRank); + assert(!host.empty()); + assert(host != thisHost); + + // Get port for send-recv pair + int port = getMpiPort(sendRank, recvRank); + // TODO - finer-grained checking when multi-tenant port scheme + assert(port > 0); + + // Create MPI message endpoint + mpiMessageEndpoints.emplace( + mpiMessageEndpoints.begin() + index, + std::make_unique(host, port)); + } + + mpiMessageEndpoints[index]->sendMpiMessage(msg); +} + +std::shared_ptr MpiWorld::recvRemoteMpiMessage( + int sendRank, + int recvRank) { - auto it = functionCallClients.find(otherHost); - if (it == functionCallClients.end()) { - // The second argument is forwarded to the client's constructor - auto _it = functionCallClients.try_emplace(otherHost, otherHost); - if (!_it.second) { - throw std::runtime_error("Error inserting remote endpoint"); + // Assert the ranks are sane + assert(0 <= sendRank && sendRank < size); + assert(0 <= recvRank && recvRank < size); + + // Initialise the endpoint vector if not initialised + if (mpiMessageEndpoints.size() == 0) { + for (int i = 0; i < size * size; i++) { + mpiMessageEndpoints.emplace_back(nullptr); } - it = _it.first; } - return it->second; + + // Get the index for the rank-host pair + int index = getIndexForRanks(sendRank, recvRank); + assert(index >= 0 && index < size * size); + + // Lazily initialise send endpoints + if (mpiMessageEndpoints[index] == nullptr) { + // Get host for recv rank + std::string host = getHostForRank(sendRank); + assert(!host.empty()); + assert(host != thisHost); + + // Get port for send-recv pair + int port = getMpiPort(sendRank, recvRank); + // TODO - finer-grained checking when multi-tenant port scheme + assert(port > 0); + + mpiMessageEndpoints.emplace( + mpiMessageEndpoints.begin() + index, + 
std::make_unique(host, port)); + } + + return mpiMessageEndpoints[index]->recvMpiMessage(); } -*/ int MpiWorld::getMpiThreadPoolSize() { @@ -144,25 +201,29 @@ void MpiWorld::shutdownThreadPool() std::promise p; threadPool->getMpiReqQueue()->enqueue( std::make_tuple(QUEUE_SHUTDOWN, - std::bind(&MpiWorld::closeThreadLocalClients, this), + std::bind(&MpiWorld::closeMpiMessageEndpoints, this), std::move(p))); } threadPool->shutdown(); // Lastly clean the main thread as well - closeThreadLocalClients(); + closeMpiMessageEndpoints(); } // TODO - remove // Clear thread local state -void MpiWorld::closeThreadLocalClients() +void MpiWorld::closeMpiMessageEndpoints() { - // Close all open sockets - for (auto& s : functionCallClients) { - s.second.close(); + if (mpiMessageEndpoints.size() > 0) { + // Close all open sockets + for (auto& e : mpiMessageEndpoints) { + if (e != nullptr) { + e->close(); + } + } + mpiMessageEndpoints.clear(); } - functionCallClients.clear(); } void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal) @@ -443,7 +504,7 @@ void MpiWorld::send(int sendRank, getLocalQueue(sendRank, recvRank)->enqueue(std::move(m)); } else { SPDLOG_TRACE("MPI - send remote {} -> {}", sendRank, recvRank); - faabric::transport::sendMpiMessage(otherHost, getMpiPort(sendRank, recvRank), m); + sendRemoteMpiMessage(sendRank, recvRank, m); } } @@ -467,7 +528,7 @@ void MpiWorld::recv(int sendRank, m = getLocalQueue(sendRank, recvRank)->dequeue(); } else { SPDLOG_TRACE("MPI - recv remote {} -> {}", sendRank, recvRank); - m = faabric::transport::recvMpiMessage(getMpiPort(sendRank, recvRank)); + m = recvRemoteMpiMessage(sendRank, recvRank); } assert(m != nullptr); diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index d8727c5cc..897c51cb9 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -28,35 +28,48 @@ void sendMpiHostRankMsg(const std::string& hostIn, } } +MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn) + : sendMessageEndpoint(hostIn, portIn) + , recvMessageEndpoint(portIn) +{} -// TODO - reuse clients!! -void sendMpiMessage(const std::string& hostIn, int portIn, - const std::shared_ptr msg) +void MpiMessageEndpoint::sendMpiMessage( + const std::shared_ptr& msg) { + // TODO - is this lazy init very expensive? 
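// Opening lazily here means only the first send on this endpoint pays the
// socket set-up cost; later sends just hit the null check, and the socket
// stays open until close() is called on the endpoint.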
+ if (sendMessageEndpoint.socket == nullptr) { + sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); + } + size_t msgSize = msg->ByteSizeLong(); { uint8_t sMsg[msgSize]; if (!msg->SerializeToArray(sMsg, msgSize)) { throw std::runtime_error("Error serialising message"); } - SendMessageEndpoint endpoint(hostIn, portIn); - endpoint.open(getGlobalMessageContext()); - endpoint.send(sMsg, msgSize, false); - endpoint.close(); + sendMessageEndpoint.send(sMsg, msgSize, false); } } -std::shared_ptr recvMpiMessage(int portIn) +std::shared_ptr MpiMessageEndpoint::recvMpiMessage() { - RecvMessageEndpoint endpoint(portIn); - endpoint.open(getGlobalMessageContext()); - // TODO - preempt data size somehow - Message m = endpoint.recv(); + if (recvMessageEndpoint.socket == nullptr) { + recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); + } + + Message m = recvMessageEndpoint.recv(); PARSE_MSG(faabric::MPIMessage, m.data(), m.size()); - // Note - This may be very slow as we poll until unbound - endpoint.close(); - // TODO - send normal message, not shared_ptr return std::make_shared(msg); } + +void MpiMessageEndpoint::close() +{ + if (sendMessageEndpoint.socket != nullptr) { + sendMessageEndpoint.close(); + } + if (recvMessageEndpoint.socket != nullptr) { + recvMessageEndpoint.close(); + } +} } diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp index a344421d1..3d626e10e 100644 --- a/tests/test/scheduler/test_mpi_world.cpp +++ b/tests/test/scheduler/test_mpi_world.cpp @@ -1,8 +1,8 @@ #include #include -#include #include +#include #include #include #include diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index c9031a7e2..541025e17 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -7,6 +7,8 @@ #include #include +#include + using namespace faabric::scheduler; namespace tests { @@ -39,13 +41,19 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") // Init worlds MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - remoteWorld.initialiseFromMsg(msg); faabric::util::setMockMode(false); - // Send a message that should get sent to this host - remoteWorld.send( - rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); - usleep(1000 * 100); + std::thread senderThread([this, rankA, rankB] { + std::vector messageData = { 0, 1, 2 }; + + remoteWorld.initialiseFromMsg(msg); + + // Send a message that should get sent to this host + remoteWorld.send( + rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); + usleep(1000 * 500); + remoteWorld.destroy(); + }); SECTION("Check recv") { @@ -64,8 +72,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") } // Destroy worlds + senderThread.join(); localWorld.destroy(); - remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, From 421234d21a7041b81f308dea43448965dff91359 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Wed, 9 Jun 2021 15:48:27 +0000 Subject: [PATCH 6/7] adding more tests --- .../test/scheduler/test_remote_mpi_worlds.cpp | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 541025e17..48a36ddaf 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -7,6 +7,8 @@ 
#include #include +#include + #include using namespace faabric::scheduler; @@ -76,6 +78,48 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") localWorld.destroy(); } +TEST_CASE_METHOD(RemoteMpiTestFixture, + "Test sending many messages across host", + "[mpi]") +{ + // Register two ranks (one on each host) + this->setWorldsSizes(2, 1, 1); + int rankA = 0; + int rankB = 1; + int numMessages = 1000; + + // Init worlds + MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + faabric::util::setMockMode(false); + + std::thread senderThread([this, rankA, rankB, numMessages] { + std::vector messageData = { 0, 1, 2 }; + + remoteWorld.initialiseFromMsg(msg); + + for (int i = 0; i < numMessages; i++) { + remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, sizeof(int)); + } + usleep(1000 * 500); + remoteWorld.destroy(); + }); + + int recv; + for (int i = 0; i < numMessages; i++) { + localWorld.recv( + rankB, rankA, BYTES(&recv), MPI_INT, sizeof(int), MPI_STATUS_IGNORE); + + // Check in-order delivery + if (i % (numMessages / 10) == 0) { + REQUIRE(recv == i); + } + } + + // Destroy worlds + senderThread.join(); + localWorld.destroy(); +} + TEST_CASE_METHOD(RemoteMpiTestFixture, "Test collective messaging across hosts", "[mpi]") From 886d0f5ff3385ec399a2b327202f49704aafcbc4 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Mon, 14 Jun 2021 14:41:51 +0000 Subject: [PATCH 7/7] pr comments --- include/faabric/scheduler/MpiWorld.h | 3 + include/faabric/transport/MessageEndpoint.h | 2 + .../faabric/transport/MpiMessageEndpoint.h | 4 + src/scheduler/MpiWorld.cpp | 146 +++++---- src/transport/MessageEndpoint.cpp | 9 +- src/transport/MpiMessageEndpoint.cpp | 24 +- .../test/scheduler/test_remote_mpi_worlds.cpp | 302 ++++++++++-------- .../test_message_endpoint_client.cpp | 13 - .../transport/test_mpi_message_endpoint.cpp | 50 +++ tests/utils/fixtures.h | 14 + 10 files changed, 353 insertions(+), 214 deletions(-) create mode 100644 tests/test/transport/test_mpi_message_endpoint.cpp diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index baecbe33a..dfe07c99e 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -212,6 +212,7 @@ class MpiWorld void initLocalQueues(); // Rank-to-rank sockets for remote messaging + void initRemoteMpiEndpoint(int sendRank, int recvRank); int getMpiPort(int sendRank, int recvRank); void sendRemoteMpiMessage(int sendRank, int recvRank, @@ -219,5 +220,7 @@ class MpiWorld std::shared_ptr recvRemoteMpiMessage(int sendRank, int recvRank); void closeMpiMessageEndpoints(); + + void checkRanksRange(int sendRank, int recvRank); }; } diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 4e463e07a..658448455 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -83,6 +83,8 @@ class RecvMessageEndpoint : public MessageEndpoint public: RecvMessageEndpoint(int portIn); + RecvMessageEndpoint(int portIn, const std::string& overrideHost); + void open(MessageContext& context); void close(); diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 6cc1552dd..40067fc6b 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -23,6 +23,10 @@ class MpiMessageEndpoint public: MpiMessageEndpoint(const std::string& hostIn, int portIn); + MpiMessageEndpoint(const 
std::string& hostIn, + int portIn, + const std::string& overrideRecvHost); + void sendMpiMessage(const std::shared_ptr& msg); std::shared_ptr recvMpiMessage(); diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index ff223eccb..077de1deb 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -19,42 +19,65 @@ MpiWorld::MpiWorld() , cartProcsPerDim(2) {} -void MpiWorld::sendRemoteMpiMessage( - int sendRank, - int recvRank, - const std::shared_ptr& msg) +void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank) { - // Assert the ranks are sane - assert(0 <= sendRank && sendRank < size); - assert(0 <= recvRank && recvRank < size); - - // Initialise the endpoint vector if not initialised - if (mpiMessageEndpoints.size() == 0) { + // Resize the message endpoint vector and initialise to null. Note that we + // allocate size x size slots to cover all possible (sendRank, recvRank) + // pairs + if (mpiMessageEndpoints.empty()) { for (int i = 0; i < size * size; i++) { mpiMessageEndpoints.emplace_back(nullptr); } } + // Get host for recv rank + std::string otherHost; + std::string recvHost = getHostForRank(recvRank); + std::string sendHost = getHostForRank(sendRank); + if (recvHost == sendHost) { + SPDLOG_ERROR( + "Send and recv ranks in the same host: SEND {}, RECV{} in {}", + sendRank, + recvRank, + sendHost); + throw std::runtime_error("Send and recv ranks in the same host"); + } else if (recvHost == thisHost) { + otherHost = sendHost; + } else if (sendHost == thisHost) { + otherHost = recvHost; + } else { + SPDLOG_ERROR("Send and recv ranks correspond to remote hosts: SEND {} " + "in {}, RECV {} in {}", + sendRank, + sendHost, + recvRank, + recvHost); + throw std::runtime_error("Send and recv ranks in remote hosts"); + } + + // Get the index for the rank-host pair + int index = getIndexForRanks(sendRank, recvRank); + + // Get port for send-recv pair + int port = getMpiPort(sendRank, recvRank); + + // Create MPI message endpoint + mpiMessageEndpoints.emplace( + mpiMessageEndpoints.begin() + index, + std::make_unique( + otherHost, port, thisHost)); +} + +void MpiWorld::sendRemoteMpiMessage( + int sendRank, + int recvRank, + const std::shared_ptr& msg) +{ // Get the index for the rank-host pair int index = getIndexForRanks(sendRank, recvRank); - assert(index >= 0 && index < size * size); - // Lazily initialise send endpoints - if (mpiMessageEndpoints[index] == nullptr) { - // Get host for recv rank - std::string host = getHostForRank(recvRank); - assert(!host.empty()); - assert(host != thisHost); - - // Get port for send-recv pair - int port = getMpiPort(sendRank, recvRank); - // TODO - finer-grained checking when multi-tenant port scheme - assert(port > 0); - - // Create MPI message endpoint - mpiMessageEndpoints.emplace( - mpiMessageEndpoints.begin() + index, - std::make_unique(host, port)); + if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) { + initRemoteMpiEndpoint(sendRank, recvRank); } mpiMessageEndpoints[index]->sendMpiMessage(msg); @@ -64,36 +87,11 @@ std::shared_ptr MpiWorld::recvRemoteMpiMessage( int sendRank, int recvRank) { - // Assert the ranks are sane - assert(0 <= sendRank && sendRank < size); - assert(0 <= recvRank && recvRank < size); - - // Initialise the endpoint vector if not initialised - if (mpiMessageEndpoints.size() == 0) { - for (int i = 0; i < size * size; i++) { - mpiMessageEndpoints.emplace_back(nullptr); - } - } - // Get the index for the rank-host pair int index = getIndexForRanks(sendRank, 
recvRank); - assert(index >= 0 && index < size * size); - // Lazily initialise send endpoints - if (mpiMessageEndpoints[index] == nullptr) { - // Get host for recv rank - std::string host = getHostForRank(sendRank); - assert(!host.empty()); - assert(host != thisHost); - - // Get port for send-recv pair - int port = getMpiPort(sendRank, recvRank); - // TODO - finer-grained checking when multi-tenant port scheme - assert(port > 0); - - mpiMessageEndpoints.emplace( - mpiMessageEndpoints.begin() + index, - std::make_unique(host, port)); + if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) { + initRemoteMpiEndpoint(sendRank, recvRank); } return mpiMessageEndpoints[index]->recvMpiMessage(); @@ -254,11 +252,6 @@ std::string MpiWorld::getHostForRank(int rank) { assert(rankHosts.size() == size); - if (rank >= size) { - throw std::runtime_error( - fmt::format("Rank bigger than world size ({} > {})", rank, size)); - } - std::string host = rankHosts[rank]; if (host.empty()) { throw std::runtime_error( @@ -476,6 +469,14 @@ void MpiWorld::send(int sendRank, int count, faabric::MPIMessage::MPIMessageType messageType) { + // Sanity-check input parameters + checkRanksRange(sendRank, recvRank); + if (getHostForRank(sendRank) != thisHost) { + SPDLOG_ERROR("Trying to send message from a non-local rank: {}", + sendRank); + throw std::runtime_error("Sending message from non-local rank"); + } + // Work out whether the message is sent locally or to another host const std::string otherHost = getHostForRank(recvRank); bool isLocal = otherHost == thisHost; @@ -516,8 +517,15 @@ void MpiWorld::recv(int sendRank, MPI_Status* status, faabric::MPIMessage::MPIMessageType messageType) { + // Sanity-check input parameters + checkRanksRange(sendRank, recvRank); + if (getHostForRank(recvRank) != thisHost) { + SPDLOG_ERROR("Trying to recv message into a non-local rank: {}", + recvRank); + throw std::runtime_error("Receiving message into non-local rank"); + } + // Work out whether the message is sent locally or from another host - assert(thisHost == getHostForRank(recvRank)); const std::string otherHost = getHostForRank(sendRank); bool isLocal = otherHost == thisHost; @@ -1174,7 +1182,9 @@ void MpiWorld::initLocalQueues() int MpiWorld::getIndexForRanks(int sendRank, int recvRank) { - return sendRank * size + recvRank; + int index = sendRank * size + recvRank; + assert(index >= 0 && index < size * size); + return index; } long MpiWorld::getLocalQueueSize(int sendRank, int recvRank) @@ -1214,4 +1224,18 @@ void MpiWorld::overrideHost(const std::string& newHost) { thisHost = newHost; } + +void MpiWorld::checkRanksRange(int sendRank, int recvRank) +{ + if (sendRank < 0 || sendRank >= size) { + SPDLOG_ERROR( + "Send rank outside range: {} not in [0, {})", sendRank, size); + throw std::runtime_error("Send rank outside range"); + } + if (recvRank < 0 || recvRank >= size) { + SPDLOG_ERROR( + "Recv rank outside range: {} not in [0, {})", recvRank, size); + throw std::runtime_error("Recv rank outside range"); + } +} } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 9f68e9327..5f87412b3 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -279,10 +279,15 @@ RecvMessageEndpoint::RecvMessageEndpoint(int portIn) : MessageEndpoint(ANY_HOST, portIn) {} +RecvMessageEndpoint::RecvMessageEndpoint(int portIn, + const std::string& overrideHost) + : MessageEndpoint(overrideHost, portIn) +{} + void RecvMessageEndpoint::open(MessageContext& context) { 
SPDLOG_TRACE( - fmt::format("Opening socket: {} (RECV {}:{})", id, ANY_HOST, port)); + fmt::format("Opening socket: {} (RECV {}:{})", id, host, port)); MessageEndpoint::open(context, SocketType::PULL, true); } @@ -290,7 +295,7 @@ void RecvMessageEndpoint::open(MessageContext& context) void RecvMessageEndpoint::close() { SPDLOG_TRACE( - fmt::format("Closing socket: {} (RECV {}:{})", id, ANY_HOST, port)); + fmt::format("Closing socket: {} (RECV {}:{})", id, host, port)); MessageEndpoint::close(true); } diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index 897c51cb9..4aebbefcd 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -31,16 +31,24 @@ void sendMpiHostRankMsg(const std::string& hostIn, MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn) : sendMessageEndpoint(hostIn, portIn) , recvMessageEndpoint(portIn) -{} +{ + sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); + recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); +} + +MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, + int portIn, + const std::string& overrideRecvHost) + : sendMessageEndpoint(hostIn, portIn) + , recvMessageEndpoint(portIn, overrideRecvHost) +{ + sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); + recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); +} void MpiMessageEndpoint::sendMpiMessage( const std::shared_ptr& msg) { - // TODO - is this lazy init very expensive? - if (sendMessageEndpoint.socket == nullptr) { - sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); - } - size_t msgSize = msg->ByteSizeLong(); { uint8_t sMsg[msgSize]; @@ -53,10 +61,6 @@ void MpiMessageEndpoint::sendMpiMessage( std::shared_ptr MpiMessageEndpoint::recvMpiMessage() { - if (recvMessageEndpoint.socket == nullptr) { - recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); - } - Message m = recvMessageEndpoint.recv(); PARSE_MSG(faabric::MPIMessage, m.data(), m.size()); diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 48a36ddaf..84aeaa2cf 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -14,6 +14,28 @@ using namespace faabric::scheduler; namespace tests { +class RemoteCollectiveTestFixture : public RemoteMpiTestFixture +{ + public: + RemoteCollectiveTestFixture() + : thisWorldSize(6) + , remoteRankA(1) + , remoteRankB(2) + , remoteRankC(3) + , localRankA(4) + , localRankB(5) + , remoteWorldRanks({ remoteRankB, remoteRankC, remoteRankA }) + , localWorldRanks({ localRankB, localRankA, 0 }) + {} + + protected: + int thisWorldSize; + int remoteRankA, remoteRankB, remoteRankC; + int localRankA, localRankB; + std::vector remoteWorldRanks; + std::vector localWorldRanks; +}; + TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]") { // Allocate two ranks in total, one rank per host @@ -120,35 +142,23 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, localWorld.destroy(); } -TEST_CASE_METHOD(RemoteMpiTestFixture, - "Test collective messaging across hosts", +TEST_CASE_METHOD(RemoteCollectiveTestFixture, + "Test broadcast across hosts", "[mpi]") { // Here we rely on the scheduler running out of resources, and overloading // the localWorld with ranks 4 and 5 - int thisWorldSize = 6; this->setWorldsSizes(thisWorldSize, 1, 3); - int remoteRankA = 1; - int 
remoteRankB = 2; - int remoteRankC = 3; - int localRankA = 4; - int localRankB = 5; + std::vector messageData = { 0, 1, 2 }; // Init worlds MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - remoteWorld.initialiseFromMsg(msg); faabric::util::setMockMode(false); - // Note that ranks are deliberately out of order - std::vector remoteWorldRanks = { remoteRankB, - remoteRankC, - remoteRankA }; - std::vector localWorldRanks = { localRankB, localRankA, 0 }; + std::thread senderThread([this, &messageData] { + remoteWorld.initialiseFromMsg(msg); - SECTION("Broadcast") - { // Broadcast a message - std::vector messageData = { 0, 1, 2 }; remoteWorld.broadcast( remoteRankB, BYTES(messageData.data()), MPI_INT, messageData.size()); @@ -161,28 +171,46 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, std::vector actual(3, -1); remoteWorld.recv( remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); - REQUIRE(actual == messageData); + assert(actual == messageData); } - // Check the local host - for (int rank : localWorldRanks) { - std::vector actual(3, -1); - localWorld.recv( - remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); - REQUIRE(actual == messageData); - } + remoteWorld.destroy(); + }); + + // Check the local host + for (int rank : localWorldRanks) { + std::vector actual(3, -1); + localWorld.recv( + remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); + REQUIRE(actual == messageData); } - SECTION("Scatter") - { - // Build the data - int nPerRank = 4; - int dataSize = nPerRank * thisWorldSize; - std::vector messageData(dataSize, 0); - for (int i = 0; i < dataSize; i++) { - messageData[i] = i; - } + senderThread.join(); + localWorld.destroy(); +} + +TEST_CASE_METHOD(RemoteCollectiveTestFixture, + "Test scatter across hosts", + "[mpi]") +{ + // Here we rely on the scheduler running out of resources, and overloading + // the localWorld with ranks 4 and 5 + this->setWorldsSizes(thisWorldSize, 1, 3); + // Init worlds + MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + faabric::util::setMockMode(false); + + // Build the data + int nPerRank = 4; + int dataSize = nPerRank * thisWorldSize; + std::vector messageData(dataSize, 0); + for (int i = 0; i < dataSize; i++) { + messageData[i] = i; + } + + std::thread senderThread([this, nPerRank, &messageData] { + remoteWorld.initialiseFromMsg(msg); // Do the scatter std::vector actual(nPerRank, -1); remoteWorld.scatter(remoteRankB, @@ -195,7 +223,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, nPerRank); // Check for root - REQUIRE(actual == std::vector({ 8, 9, 10, 11 })); + assert(actual == std::vector({ 8, 9, 10, 11 })); // Check for other remote ranks remoteWorld.scatter(remoteRankB, @@ -206,7 +234,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, BYTES(actual.data()), MPI_INT, nPerRank); - REQUIRE(actual == std::vector({ 4, 5, 6, 7 })); + assert(actual == std::vector({ 4, 5, 6, 7 })); remoteWorld.scatter(remoteRankB, remoteRankC, @@ -216,108 +244,126 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, BYTES(actual.data()), MPI_INT, nPerRank); - REQUIRE(actual == std::vector({ 12, 13, 14, 15 })); - - // Check for local ranks - localWorld.scatter(remoteRankB, - 0, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); - REQUIRE(actual == std::vector({ 0, 1, 2, 3 })); - - localWorld.scatter(remoteRankB, - localRankB, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); - REQUIRE(actual == std::vector({ 20, 21, 22, 23 })); - - localWorld.scatter(remoteRankB, - 
localRankA, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); - REQUIRE(actual == std::vector({ 16, 17, 18, 19 })); - } + assert(actual == std::vector({ 12, 13, 14, 15 })); - SECTION("Gather and allgather") - { - // Build the data for each rank - int nPerRank = 4; - std::vector> rankData; - for (int i = 0; i < thisWorldSize; i++) { - std::vector thisRankData; - for (int j = 0; j < nPerRank; j++) { - thisRankData.push_back((i * nPerRank) + j); - } + remoteWorld.destroy(); + }); - rankData.push_back(thisRankData); - } + // Check for local ranks + std::vector actual(nPerRank, -1); + localWorld.scatter(remoteRankB, + 0, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); + REQUIRE(actual == std::vector({ 0, 1, 2, 3 })); + + localWorld.scatter(remoteRankB, + localRankB, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); + REQUIRE(actual == std::vector({ 20, 21, 22, 23 })); + + localWorld.scatter(remoteRankB, + localRankA, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); + REQUIRE(actual == std::vector({ 16, 17, 18, 19 })); + + senderThread.join(); + localWorld.destroy(); +} + +TEST_CASE_METHOD(RemoteCollectiveTestFixture, + "Test gather across hosts", + "[mpi]") +{ + // Here we rely on the scheduler running out of resources, and overloading + // the localWorld with ranks 4 and 5 + this->setWorldsSizes(thisWorldSize, 1, 3); + + // Init worlds + MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + faabric::util::setMockMode(false); - // Build the expectation - std::vector expected; - for (int i = 0; i < thisWorldSize * nPerRank; i++) { - expected.push_back(i); + // Build the data for each rank + int nPerRank = 4; + std::vector> rankData; + for (int i = 0; i < thisWorldSize; i++) { + std::vector thisRankData; + for (int j = 0; j < nPerRank; j++) { + thisRankData.push_back((i * nPerRank) + j); } - SECTION("Gather") - { - std::vector actual(thisWorldSize * nPerRank, -1); - - // Call gather for each rank other than the root (out of order) - int root = localRankA; - for (int rank : remoteWorldRanks) { - remoteWorld.gather(rank, - root, - BYTES(rankData[rank].data()), - MPI_INT, - nPerRank, - nullptr, - MPI_INT, - nPerRank); - } + rankData.push_back(thisRankData); + } - for (int rank : localWorldRanks) { - if (rank == root) { - continue; - } - localWorld.gather(rank, - root, - BYTES(rankData[rank].data()), - MPI_INT, - nPerRank, - nullptr, - MPI_INT, - nPerRank); - } + // Build the expectation + std::vector expected; + for (int i = 0; i < thisWorldSize * nPerRank; i++) { + expected.push_back(i); + } + + std::vector actual(thisWorldSize * nPerRank, -1); + + // Call gather for each rank other than the root (out of order) + int root = localRankA; + std::thread senderThread([this, root, &rankData, nPerRank] { + remoteWorld.initialiseFromMsg(msg); + + for (int rank : remoteWorldRanks) { + remoteWorld.gather(rank, + root, + BYTES(rankData[rank].data()), + MPI_INT, + nPerRank, + nullptr, + MPI_INT, + nPerRank); + } - // Call gather for root - localWorld.gather(root, - root, - BYTES(rankData[root].data()), - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); - - // Check data - REQUIRE(actual == expected); + remoteWorld.destroy(); + }); + + for (int rank : localWorldRanks) { + if (rank == root) { + continue; } + localWorld.gather(rank, + root, + BYTES(rankData[rank].data()), + MPI_INT, + nPerRank, + nullptr, + MPI_INT, + nPerRank); } - // Destroy 
worlds + // Call gather for root + localWorld.gather(root, + root, + BYTES(rankData[root].data()), + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); + + // Check data + REQUIRE(actual == expected); + + senderThread.join(); localWorld.destroy(); - remoteWorld.destroy(); } } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 9fe72c518..7119602ab 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -13,19 +13,6 @@ const int testPort = 9999; const int testReplyPort = 9996; namespace tests { -class MessageContextFixture : public SchedulerTestFixture -{ - protected: - MessageContext& context; - - public: - MessageContextFixture() - : context(getGlobalMessageContext()) - {} - - ~MessageContextFixture() { context.close(); } -}; - TEST_CASE_METHOD(MessageContextFixture, "Test open/close one client", "[transport]") diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp new file mode 100644 index 000000000..a5b6a4ced --- /dev/null +++ b/tests/test/transport/test_mpi_message_endpoint.cpp @@ -0,0 +1,50 @@ +#include + +#include +#include + +using namespace faabric::transport; + +namespace tests { +TEST_CASE_METHOD(MessageContextFixture, + "Test send and recv the hosts to rank message", + "[transport]") +{ + // Prepare message + std::vector expected = { "foo", "bar" }; + faabric::MpiHostsToRanksMessage sendMsg; + *sendMsg.mutable_hosts() = { expected.begin(), expected.end() }; + sendMpiHostRankMsg(LOCALHOST, sendMsg); + + // Send message + faabric::MpiHostsToRanksMessage actual = recvMpiHostRankMsg(); + + // Checks + REQUIRE(actual.hosts().size() == expected.size()); + for (int i = 0; i < actual.hosts().size(); i++) { + REQUIRE(actual.hosts().Get(i) == expected[i]); + } +} + +TEST_CASE_METHOD(MessageContextFixture, + "Test send and recv an MPI message", + "[transport]") +{ + std::string thisHost = faabric::util::getSystemConfig().endpointHost; + MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, thisHost); + MpiMessageEndpoint recvEndpoint(thisHost, 9999, LOCALHOST); + + std::shared_ptr expected = + std::make_shared(); + expected->set_id(1337); + + sendEndpoint.sendMpiMessage(expected); + std::shared_ptr actual = recvEndpoint.recvMpiMessage(); + + // Checks + REQUIRE(expected->id() == actual->id()); + + REQUIRE_NOTHROW(sendEndpoint.close()); + REQUIRE_NOTHROW(recvEndpoint.close()); +} +} diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 1568b3cf4..b102db298 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -117,6 +118,19 @@ class ConfTestFixture faabric::util::SystemConfig& conf; }; +class MessageContextFixture : public SchedulerTestFixture +{ + protected: + faabric::transport::MessageContext& context; + + public: + MessageContextFixture() + : context(faabric::transport::getGlobalMessageContext()) + {} + + ~MessageContextFixture() { context.close(); } +}; + class MpiBaseTestFixture : public SchedulerTestFixture { public: