diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bc80693a9..58b8536ba 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -68,7 +68,7 @@ jobs: run: inv dev.cc faabric_tests # --- Tests --- - name: "Run tests" - run: LOG_LEVEL=trace ./bin/faabric_tests + run: ./bin/faabric_tests working-directory: /build/faabric/static dist-tests: diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index 256d23f97..fc91d4bbb 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -27,7 +27,7 @@ class MpiWorld std::string getHostForRank(int rank); - void setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg); + void setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg); std::string getUser(); @@ -205,8 +205,11 @@ class MpiWorld void initLocalQueues(); // Rank-to-rank sockets for remote messaging - void initRemoteMpiEndpoint(int sendRank, int recvRank); - int getMpiPort(int sendRank, int recvRank); + std::vector basePorts; + std::vector initLocalBasePorts( + const std::vector& executedAt); + void initRemoteMpiEndpoint(int localRank, int remoteRank); + std::pair getPortForRanks(int localRank, int remoteRank); void sendRemoteMpiMessage(int sendRank, int recvRank, const std::shared_ptr& msg); diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 658448455..4e463e07a 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -83,8 +83,6 @@ class RecvMessageEndpoint : public MessageEndpoint public: RecvMessageEndpoint(int portIn); - RecvMessageEndpoint(int portIn, const std::string& overrideHost); - void open(MessageContext& context); void close(); diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 79c402a12..991331bc5 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -23,9 +23,7 @@ class MpiMessageEndpoint public: MpiMessageEndpoint(const std::string& hostIn, int portIn); - MpiMessageEndpoint(const std::string& hostIn, - int portIn, - const std::string& overrideRecvHost); + MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort); void sendMpiMessage(const std::shared_ptr& msg); diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto index 80103ca1d..a44a4668c 100644 --- a/src/proto/faabric.proto +++ b/src/proto/faabric.proto @@ -88,6 +88,7 @@ message MPIMessage { // fields. message MpiHostsToRanksMessage { repeated string hosts = 1; + repeated int32 basePorts = 2; } message Message { diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index f9716001a..5f9e7bfc1 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -26,8 +26,12 @@ MpiWorld::MpiWorld() , cartProcsPerDim(2) {} -void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank) +void MpiWorld::initRemoteMpiEndpoint(int localRank, int remoteRank) { + SPDLOG_TRACE("Open MPI endpoint between ranks (local-remote) {} - {}", + localRank, + remoteRank); + // Resize the message endpoint vector and initialise to null. Note that we // allocate size x size slots to cover all possible (sendRank, recvRank) // pairs @@ -37,42 +41,20 @@ void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank) } } - // Get host for recv rank - std::string otherHost; - std::string recvHost = getHostForRank(recvRank); - std::string sendHost = getHostForRank(sendRank); - if (recvHost == sendHost) { - SPDLOG_ERROR( - "Send and recv ranks in the same host: SEND {}, RECV{} in {}", - sendRank, - recvRank, - sendHost); - throw std::runtime_error("Send and recv ranks in the same host"); - } else if (recvHost == thisHost) { - otherHost = sendHost; - } else if (sendHost == thisHost) { - otherHost = recvHost; - } else { - SPDLOG_ERROR("Send and recv ranks correspond to remote hosts: SEND {} " - "in {}, RECV {} in {}", - sendRank, - sendHost, - recvRank, - recvHost); - throw std::runtime_error("Send and recv ranks in remote hosts"); - } + // Get host for remote rank + std::string otherHost = getHostForRank(remoteRank); // Get the index for the rank-host pair - int index = getIndexForRanks(sendRank, recvRank); + int index = getIndexForRanks(localRank, remoteRank); // Get port for send-recv pair - int port = getMpiPort(sendRank, recvRank); + std::pair sendRecvPorts = getPortForRanks(localRank, remoteRank); // Create MPI message endpoint mpiMessageEndpoints.emplace( mpiMessageEndpoints.begin() + index, std::make_unique( - otherHost, port, thisHost)); + otherHost, sendRecvPorts.first, sendRecvPorts.second)); } void MpiWorld::sendRemoteMpiMessage( @@ -81,6 +63,8 @@ void MpiWorld::sendRemoteMpiMessage( const std::shared_ptr& msg) { // Get the index for the rank-host pair + // Note - message endpoints are identified by a (localRank, remoteRank) + // pair, not a (sendRank, recvRank) one int index = getIndexForRanks(sendRank, recvRank); if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) { @@ -95,10 +79,12 @@ std::shared_ptr MpiWorld::recvRemoteMpiMessage( int recvRank) { // Get the index for the rank-host pair - int index = getIndexForRanks(sendRank, recvRank); + // Note - message endpoints are identified by a (localRank, remoteRank) + // pair, not a (sendRank, recvRank) one + int index = getIndexForRanks(recvRank, sendRank); if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) { - initRemoteMpiEndpoint(sendRank, recvRank); + initRemoteMpiEndpoint(recvRank, sendRank); } return mpiMessageEndpoints[index]->recvMpiMessage(); @@ -160,7 +146,14 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) // Register hosts to rank mappings on this host faabric::MpiHostsToRanksMessage hostRankMsg; *hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() }; - setAllRankHosts(hostRankMsg); + + // Prepare the base port for each rank + std::vector basePortForRank = initLocalBasePorts(executedAt); + *hostRankMsg.mutable_baseports() = { basePortForRank.begin(), + basePortForRank.end() }; + + // Register hosts to rank mappins on this host + setAllRankHostsPorts(hostRankMsg); // Set up a list of hosts to broadcast to (excluding this host) std::set hosts(executedAt.begin(), executedAt.end()); @@ -249,7 +242,7 @@ void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal) // Block until we receive faabric::MpiHostsToRanksMessage hostRankMsg = faabric::transport::recvMpiHostRankMsg(); - setAllRankHosts(hostRankMsg); + setAllRankHostsPorts(hostRankMsg); // Initialise the memory queues for message reception initLocalQueues(); @@ -269,16 +262,51 @@ std::string MpiWorld::getHostForRank(int rank) return host; } +// Returns a pair (sendPort, recvPort) +// To assign the send and recv ports, we follow a protocol establishing: +// 1) Port range (offset) corresponding to the world that receives +// 2) Within a world's port range, port corresponding to the outcome of +// getIndexForRanks(localRank, remoteRank) Where local and remote are +// relative to the world whose port range we are in +std::pair MpiWorld::getPortForRanks(int localRank, int remoteRank) +{ + std::pair sendRecvPortPair; + + // Get base port for local and remote worlds + int localBasePort = basePorts[localRank]; + int remoteBasePort = basePorts[remoteRank]; + assert(localBasePort != remoteBasePort); + + // Assign send port + // 1) Port range corresponding to remote world, as they are receiving + // 2) Index switching localRank and remoteRank, as remote rank is "local" + // to the remote world + sendRecvPortPair.first = + remoteBasePort + getIndexForRanks(remoteRank, localRank); + + // Assign recv port + // 1) Port range corresponding to our world, as we are the one's receiving + // 2) Port using our local rank as `localRank`, as we are in the local + // offset + sendRecvPortPair.second = + localBasePort + getIndexForRanks(localRank, remoteRank); + + return sendRecvPortPair; +} + // Prepare the host-rank map with a vector containing _all_ ranks // Note - this method should be called by only one rank. This is enforced in // the world registry -void MpiWorld::setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg) +void MpiWorld::setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg) { // Assert we are only setting the values once assert(rankHosts.size() == 0); + assert(basePorts.size() == 0); assert(msg.hosts().size() == size); + assert(msg.baseports().size() == size); rankHosts = { msg.hosts().begin(), msg.hosts().end() }; + basePorts = { msg.baseports().begin(), msg.baseports().end() }; } void MpiWorld::getCartesianRank(int rank, @@ -448,15 +476,6 @@ int MpiWorld::irecv(int sendRank, return requestId; } -int MpiWorld::getMpiPort(int sendRank, int recvRank) -{ - // TODO - get port in a multi-tenant-safe manner - int basePort = MPI_PORT; - int rankOffset = sendRank * size + recvRank; - - return basePort + rankOffset; -} - void MpiWorld::send(int sendRank, int recvRank, const uint8_t* buffer, @@ -1204,6 +1223,34 @@ void MpiWorld::initLocalQueues() } } +// Here we rely on the scheduler returning a list of hosts where equal +// hosts are always contiguous with the exception of the master host +// (thisHost) which may appear repeated at the end if the system is +// overloaded. +std::vector MpiWorld::initLocalBasePorts( + const std::vector& executedAt) +{ + std::vector basePortForRank; + basePortForRank.reserve(size); + + std::string lastHost = thisHost; + int lastPort = MPI_PORT; + for (const auto& host : executedAt) { + if (host == thisHost) { + basePortForRank.push_back(MPI_PORT); + } else if (host == lastHost) { + basePortForRank.push_back(lastPort); + } else { + lastHost = host; + lastPort += size * size; + basePortForRank.push_back(lastPort); + } + } + + assert(basePortForRank.size() == size); + return basePortForRank; +} + std::shared_ptr MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize) { diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 5f87412b3..1cb71a871 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -279,11 +279,6 @@ RecvMessageEndpoint::RecvMessageEndpoint(int portIn) : MessageEndpoint(ANY_HOST, portIn) {} -RecvMessageEndpoint::RecvMessageEndpoint(int portIn, - const std::string& overrideHost) - : MessageEndpoint(overrideHost, portIn) -{} - void RecvMessageEndpoint::open(MessageContext& context) { SPDLOG_TRACE( diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index 4aebbefcd..67be90801 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -37,10 +37,10 @@ MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn) } MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, - int portIn, - const std::string& overrideRecvHost) - : sendMessageEndpoint(hostIn, portIn) - , recvMessageEndpoint(portIn, overrideRecvHost) + int sendPort, + int recvPort) + : sendMessageEndpoint(hostIn, sendPort) + , recvMessageEndpoint(recvPort) { sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp index a5b6a4ced..a27ae9bf7 100644 --- a/tests/test/transport/test_mpi_message_endpoint.cpp +++ b/tests/test/transport/test_mpi_message_endpoint.cpp @@ -31,8 +31,8 @@ TEST_CASE_METHOD(MessageContextFixture, "[transport]") { std::string thisHost = faabric::util::getSystemConfig().endpointHost; - MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, thisHost); - MpiMessageEndpoint recvEndpoint(thisHost, 9999, LOCALHOST); + MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, 9998); + MpiMessageEndpoint recvEndpoint(thisHost, 9998, 9999); std::shared_ptr expected = std::make_shared();