From 149105c1d39d5d53979b860118f8db55507a006b Mon Sep 17 00:00:00 2001
From: Carlos Segarra
Date: Wed, 9 Jun 2021 14:12:58 +0000
Subject: [PATCH] reusing sockets for rank-to-rank communication

---
 include/faabric/scheduler/MpiWorld.h          |  30 +++--
 .../faabric/transport/MpiMessageEndpoint.h    |  25 +++-
 src/scheduler/MpiWorld.cpp                    | 121 +++++++++++++-----
 src/transport/MpiMessageEndpoint.cpp          |  42 ++++--
 tests/test/scheduler/test_mpi_world.cpp       |   2 +-
 .../test/scheduler/test_remote_mpi_worlds.cpp |  20 ++-
 6 files changed, 172 insertions(+), 68 deletions(-)

diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h
index f397294ac..0c6a83ebf 100644
--- a/include/faabric/scheduler/MpiWorld.h
+++ b/include/faabric/scheduler/MpiWorld.h
@@ -6,7 +6,9 @@
 #include 
 #include 
 #include 
-#include 
+#include 
+#include 
+#include 
 #include 
 #include 
@@ -15,8 +17,6 @@ namespace faabric::scheduler {
 typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>> InMemoryMpiQueue;
 
-std::string getWorldStateKey(int worldId);
-
 class MpiWorld
 {
   public:
@@ -43,8 +43,6 @@ class MpiWorld
     void shutdownThreadPool();
 
-    void enqueueMessage(faabric::MPIMessage& msg);
-
     void getCartesianRank(int rank,
                           int maxDims,
                           const int* dims,
@@ -59,8 +57,6 @@ class MpiWorld
                  int* source,
                  int* destination);
 
-    int getMpiPort(int sendRank, int recvRank);
-
     void send(int sendRank,
               int recvRank,
               const uint8_t* buffer,
@@ -202,20 +198,28 @@ class MpiWorld
     std::string user;
     std::string function;
 
-    std::shared_ptr<state::StateKeyValue> stateKV;
-    std::vector<std::string> rankHosts;
-
-    std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;
-
     std::shared_ptr<MpiAsyncThreadPool> threadPool;
     int getMpiThreadPoolSize();
 
     std::vector<int> cartProcsPerDim;
 
-    void closeThreadLocalClients();
+    /* MPI internal messaging layer */
 
+    // Track at which host each rank lives
+    std::vector<std::string> rankHosts;
     int getIndexForRanks(int sendRank, int recvRank);
 
+    // In-memory queues for local messaging
+    std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;
     void initLocalQueues();
+
+    // Rank-to-rank sockets for remote messaging
+    int getMpiPort(int sendRank, int recvRank);
+    void sendRemoteMpiMessage(int sendRank,
+                              int recvRank,
+                              const std::shared_ptr<faabric::MPIMessage>& msg);
+    std::shared_ptr<faabric::MPIMessage> recvRemoteMpiMessage(int sendRank,
+                                                              int recvRank);
+    void closeMpiMessageEndpoints();
 };
 }
diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h
index 0641b5449..a8895c53b 100644
--- a/include/faabric/transport/MpiMessageEndpoint.h
+++ b/include/faabric/transport/MpiMessageEndpoint.h
@@ -6,14 +6,31 @@
 #include 
 
 namespace faabric::transport {
+/* These two functions are used to broadcast the host-to-rank mapping at
+ * initialisation time.
+ */
 faabric::MpiHostsToRanksMessage recvMpiHostRankMsg();
 void sendMpiHostRankMsg(const std::string& hostIn,
                         const faabric::MpiHostsToRanksMessage msg);
 
-void sendMpiMessage(const std::string& hostIn,
-                    int portIn,
-                    const std::shared_ptr<faabric::MPIMessage> msg);
+/* This class abstracts the notion of a communication channel between two MPI
+ * ranks. There will always be one rank local to this host and one remote.
+ * Note that the port is unique per (user, function, sendRank, recvRank) tuple.
+ */
+class MpiMessageEndpoint
+{
+  public:
+    MpiMessageEndpoint(const std::string& hostIn, int portIn);
 
-std::shared_ptr<faabric::MPIMessage> recvMpiMessage(int portIn);
+    void sendMpiMessage(const std::shared_ptr<faabric::MPIMessage>& msg);
+
+    std::shared_ptr<faabric::MPIMessage> recvMpiMessage();
+
+    void close();
+
+  private:
+    SendMessageEndpoint sendMessageEndpoint;
+    RecvMessageEndpoint recvMessageEndpoint;
+};
 }
diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp
index cdb0b13b4..c00c352f5 100644
--- a/src/scheduler/MpiWorld.cpp
+++ b/src/scheduler/MpiWorld.cpp
@@ -1,20 +1,14 @@
-#include 
-
-#include 
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
 
 static thread_local std::unordered_map<int, std::future<void>> futureMap;
 
-static thread_local std::unordered_map<std::string,
-                                       faabric::scheduler::FunctionCallClient>
-  functionCallClients;
+static thread_local std::vector<
+  std::unique_ptr<faabric::transport::MpiMessageEndpoint>>
+  mpiMessageEndpoints;
 
 namespace faabric::scheduler {
 MpiWorld::MpiWorld()
@@ -26,22 +20,85 @@ MpiWorld::MpiWorld()
   , cartProcsPerDim(2)
 {}
 
-/*
-faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient(
-  const std::string& otherHost)
+void MpiWorld::sendRemoteMpiMessage(
+  int sendRank,
+  int recvRank,
+  const std::shared_ptr<faabric::MPIMessage>& msg)
+{
+    // Assert the ranks are sane
+    assert(0 <= sendRank && sendRank < size);
+    assert(0 <= recvRank && recvRank < size);
+
+    // Initialise the endpoint vector if not initialised
+    if (mpiMessageEndpoints.size() == 0) {
+        for (int i = 0; i < size * size; i++) {
+            mpiMessageEndpoints.emplace_back(nullptr);
+        }
+    }
+
+    // Get the index for this (send, recv) rank pair
+    int index = getIndexForRanks(sendRank, recvRank);
+    assert(index >= 0 && index < size * size);
+
+    // Lazily initialise send endpoints
+    if (mpiMessageEndpoints[index] == nullptr) {
+        // Get host for recv rank
+        std::string host = getHostForRank(recvRank);
+        assert(!host.empty());
+        assert(host != thisHost);
+
+        // Get port for send-recv pair
+        int port = getMpiPort(sendRank, recvRank);
+        // TODO - finer-grained checking when multi-tenant port scheme
+        assert(port > 0);
+
+        // Create MPI message endpoint
+        mpiMessageEndpoints.emplace(
+          mpiMessageEndpoints.begin() + index,
+          std::make_unique<faabric::transport::MpiMessageEndpoint>(host, port));
+    }
+
+    mpiMessageEndpoints[index]->sendMpiMessage(msg);
+}
+
+std::shared_ptr<faabric::MPIMessage> MpiWorld::recvRemoteMpiMessage(
+  int sendRank,
+  int recvRank)
 {
-    auto it = functionCallClients.find(otherHost);
-    if (it == functionCallClients.end()) {
-        // The second argument is forwarded to the client's constructor
-        auto _it = functionCallClients.try_emplace(otherHost, otherHost);
-        if (!_it.second) {
-            throw std::runtime_error("Error inserting remote endpoint");
+    // Assert the ranks are sane
+    assert(0 <= sendRank && sendRank < size);
+    assert(0 <= recvRank && recvRank < size);
+
+    // Initialise the endpoint vector if not initialised
+    if (mpiMessageEndpoints.size() == 0) {
+        for (int i = 0; i < size * size; i++) {
+            mpiMessageEndpoints.emplace_back(nullptr);
         }
-        it = _it.first;
     }
-    return it->second;
+
+    // Get the index for this (send, recv) rank pair
+    int index = getIndexForRanks(sendRank, recvRank);
+    assert(index >= 0 && index < size * size);
+
+    // Lazily initialise recv endpoints
+    if (mpiMessageEndpoints[index] == nullptr) {
+        // Get host for the send rank (the remote peer)
+        std::string host = getHostForRank(sendRank);
+        assert(!host.empty());
+        assert(host != thisHost);
+
+        // Get port for send-recv pair
+        int port = getMpiPort(sendRank, recvRank);
+        // TODO - finer-grained checking when multi-tenant port scheme
+        assert(port > 0);
+
+        mpiMessageEndpoints.emplace(
+          mpiMessageEndpoints.begin() + index,
+          std::make_unique<faabric::transport::MpiMessageEndpoint>(host, port));
+    }
+
+    return mpiMessageEndpoints[index]->recvMpiMessage();
 }
-*/
 
 int MpiWorld::getMpiThreadPoolSize()
 {
@@ -145,25 +202,29 @@ void MpiWorld::shutdownThreadPool()
         std::promise<void> p;
         threadPool->getMpiReqQueue()->enqueue(
          std::make_tuple(QUEUE_SHUTDOWN,
-                          std::bind(&MpiWorld::closeThreadLocalClients, this),
+                          std::bind(&MpiWorld::closeMpiMessageEndpoints, this),
                           std::move(p)));
     }
 
     threadPool->shutdown();
 
     // Lastly clean the main thread as well
-    closeThreadLocalClients();
+    closeMpiMessageEndpoints();
 }
 
 // TODO - remove
 // Clear thread local state
-void MpiWorld::closeThreadLocalClients()
+void MpiWorld::closeMpiMessageEndpoints()
 {
-    // Close all open sockets
-    for (auto& s : functionCallClients) {
-        s.second.close();
+    if (mpiMessageEndpoints.size() > 0) {
+        // Close all open sockets
+        for (auto& e : mpiMessageEndpoints) {
+            if (e != nullptr) {
+                e->close();
+            }
+        }
+        mpiMessageEndpoints.clear();
     }
-    functionCallClients.clear();
 }
 
 void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal)
@@ -444,7 +505,7 @@ void MpiWorld::send(int sendRank,
         getLocalQueue(sendRank, recvRank)->enqueue(std::move(m));
     } else {
         logger->trace("MPI - send remote {} -> {}", sendRank, recvRank);
-        faabric::transport::sendMpiMessage(otherHost, getMpiPort(sendRank, recvRank), m);
+        sendRemoteMpiMessage(sendRank, recvRank, m);
     }
 }
 
@@ -468,7 +529,7 @@ void MpiWorld::recv(int sendRank,
         m = getLocalQueue(sendRank, recvRank)->dequeue();
     } else {
         logger->trace("MPI - recv remote {} -> {}", sendRank, recvRank);
-        m = faabric::transport::recvMpiMessage(getMpiPort(sendRank, recvRank));
+        m = recvRemoteMpiMessage(sendRank, recvRank);
     }
 
     assert(m != nullptr);
diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp
index d8727c5cc..0d39b3d5d 100644
--- a/src/transport/MpiMessageEndpoint.cpp
+++ b/src/transport/MpiMessageEndpoint.cpp
@@ -28,35 +28,49 @@ void sendMpiHostRankMsg(const std::string& hostIn,
     }
 }
 
+MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn)
+  : sendMessageEndpoint(hostIn, portIn)
+  , recvMessageEndpoint(portIn)
+{}
+
-// TODO - reuse clients!!
-void sendMpiMessage(const std::string& hostIn, int portIn,
-                    const std::shared_ptr<faabric::MPIMessage> msg)
+void MpiMessageEndpoint::sendMpiMessage(
+  const std::shared_ptr<faabric::MPIMessage>& msg)
 {
+    // TODO - is this lazy init very expensive?
+    if (sendMessageEndpoint.socket == nullptr) {
+        sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
+    }
+
     size_t msgSize = msg->ByteSizeLong();
     {
         uint8_t sMsg[msgSize];
         if (!msg->SerializeToArray(sMsg, msgSize)) {
             throw std::runtime_error("Error serialising message");
         }
-        SendMessageEndpoint endpoint(hostIn, portIn);
-        endpoint.open(getGlobalMessageContext());
-        endpoint.send(sMsg, msgSize, false);
-        endpoint.close();
+        sendMessageEndpoint.send(sMsg, msgSize, false);
     }
 }
 
-std::shared_ptr<faabric::MPIMessage> recvMpiMessage(int portIn)
+std::shared_ptr<faabric::MPIMessage> MpiMessageEndpoint::recvMpiMessage()
 {
-    RecvMessageEndpoint endpoint(portIn);
-    endpoint.open(getGlobalMessageContext());
+    if (recvMessageEndpoint.socket == nullptr) {
+        recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
+    }
+
     // TODO - preempt data size somehow
-    Message m = endpoint.recv();
+    Message m = recvMessageEndpoint.recv();
     PARSE_MSG(faabric::MPIMessage, m.data(), m.size());
 
-    // Note - This may be very slow as we poll until unbound
-    endpoint.close();
-
     // TODO - send normal message, not shared_ptr
     return std::make_shared<faabric::MPIMessage>(msg);
 }
+
+void MpiMessageEndpoint::close()
+{
+    if (sendMessageEndpoint.socket != nullptr) {
+        sendMessageEndpoint.close();
+    }
+    if (recvMessageEndpoint.socket != nullptr) {
+        recvMessageEndpoint.close();
+    }
+}
 }
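For orientation, the snippet below (not part of the patch, and with made-up host and port values) shows how a caller is expected to drive the new class: construct one endpoint per rank pair, reuse it for every message between those ranks, and close it explicitly at shutdown. As set up by the constructor above, the send socket connects out to the remote host on the pair's port while the recv socket binds the same port locally, so the matching endpoint on the other host completes the channel.

// Hypothetical usage only - the values and the surrounding function are made up
#include <memory>

#include <faabric/proto/faabric.pb.h>
#include <faabric/transport/MpiMessageEndpoint.h>

void exampleRankToRankChannel()
{
    // One endpoint per (sendRank, recvRank) pair, constructed once
    faabric::transport::MpiMessageEndpoint endpoint("10.0.1.2", 8800);

    // Sockets are opened lazily on first use, then reused for every message
    auto msg = std::make_shared<faabric::MPIMessage>();
    endpoint.sendMpiMessage(msg);

    // Close both underlying sockets when the world shuts down
    endpoint.close();
}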
diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp
index a344421d1..3d626e10e 100644
--- a/tests/test/scheduler/test_mpi_world.cpp
+++ b/tests/test/scheduler/test_mpi_world.cpp
@@ -1,8 +1,8 @@
 #include 
 #include 
-#include 
 #include 
+#include 
 #include 
 #include 
 #include 
diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp
index c9031a7e2..541025e17 100644
--- a/tests/test/scheduler/test_remote_mpi_worlds.cpp
+++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -7,6 +7,8 @@
 #include 
 #include 
 
+#include <thread>
+
 using namespace faabric::scheduler;
 
 namespace tests {
@@ -39,13 +41,19 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]")
 
     // Init worlds
     MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
-    remoteWorld.initialiseFromMsg(msg);
     faabric::util::setMockMode(false);
 
-    // Send a message that should get sent to this host
-    remoteWorld.send(
-      rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size());
-    usleep(1000 * 100);
+    std::thread senderThread([this, rankA, rankB] {
+        std::vector<int> messageData = { 0, 1, 2 };
+
+        remoteWorld.initialiseFromMsg(msg);
+
+        // Send a message that should get sent to this host
+        remoteWorld.send(
+          rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size());
+        usleep(1000 * 500);
+        remoteWorld.destroy();
+    });
 
     SECTION("Check recv")
     {
@@ -64,8 +72,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]")
     }
 
     // Destroy worlds
+    senderThread.join();
     localWorld.destroy();
-    remoteWorld.destroy();
 }
 
 TEST_CASE_METHOD(RemoteMpiTestFixture,