Commit

reusing sockets for rank-to-rank communication
csegarragonz committed Jun 9, 2021
1 parent a283e83 commit 149105c
Showing 6 changed files with 172 additions and 68 deletions.
30 changes: 17 additions & 13 deletions include/faabric/scheduler/MpiWorld.h
@@ -6,7 +6,9 @@
#include <faabric/scheduler/FunctionCallClient.h>
#include <faabric/scheduler/InMemoryMessageQueue.h>
#include <faabric/scheduler/MpiThreadPool.h>
#include <faabric/state/StateKeyValue.h>
#include <faabric/transport/MpiMessageEndpoint.h>
#include <faabric/util/logging.h>
#include <faabric/util/timing.h>

#include <atomic>
#include <thread>
@@ -15,8 +17,6 @@ namespace faabric::scheduler {
typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
InMemoryMpiQueue;

std::string getWorldStateKey(int worldId);

class MpiWorld
{
public:
@@ -43,8 +43,6 @@ class MpiWorld

void shutdownThreadPool();

void enqueueMessage(faabric::MPIMessage& msg);

void getCartesianRank(int rank,
int maxDims,
const int* dims,
@@ -59,8 +57,6 @@
int* source,
int* destination);

int getMpiPort(int sendRank, int recvRank);

void send(int sendRank,
int recvRank,
const uint8_t* buffer,
@@ -202,20 +198,28 @@ class MpiWorld
std::string user;
std::string function;

std::shared_ptr<state::StateKeyValue> stateKV;
std::vector<std::string> rankHosts;

std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;

std::shared_ptr<faabric::scheduler::MpiAsyncThreadPool> threadPool;
int getMpiThreadPoolSize();

std::vector<int> cartProcsPerDim;

void closeThreadLocalClients();
/* MPI internal messaging layer */

// Track at which host each rank lives
std::vector<std::string> rankHosts;
int getIndexForRanks(int sendRank, int recvRank);

// In-memory queues for local messaging
std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;
void initLocalQueues();

// Rank-to-rank sockets for remote messaging
int getMpiPort(int sendRank, int recvRank);
void sendRemoteMpiMessage(int sendRank,
int recvRank,
const std::shared_ptr<faabric::MPIMessage>& msg);
std::shared_ptr<faabric::MPIMessage> recvRemoteMpiMessage(int sendRank,
int recvRank);
void closeMpiMessageEndpoints();
};
}
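
The new private members pair a flat endpoint cache with a per-pair index: each ordered (sendRank, recvRank) pair gets its own slot and its own port. A standalone sketch of one way such an index and port could be derived, assuming a row-major layout over a size x size grid and an arbitrary base port (the helper names and the base-port constant below are illustrative, not taken from this commit):

#include <cassert>

// Hypothetical helpers, not part of faabric: one slot per ordered
// (sendRank, recvRank) pair, laid out row-major over a worldSize x worldSize
// grid, with ports offset from an assumed base value.
static int rankPairIndex(int sendRank, int recvRank, int worldSize)
{
    assert(sendRank >= 0 && sendRank < worldSize);
    assert(recvRank >= 0 && recvRank < worldSize);
    return sendRank * worldSize + recvRank;
}

static int rankPairPort(int sendRank, int recvRank, int worldSize)
{
    const int basePort = 10000; // assumed value, not from this commit
    return basePort + rankPairIndex(sendRank, recvRank, worldSize);
}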
25 changes: 21 additions & 4 deletions include/faabric/transport/MpiMessageEndpoint.h
@@ -6,14 +6,31 @@
#include <faabric/transport/macros.h>

namespace faabric::transport {
/* These two functions are used to broadcast the host-to-rank mapping at
 * initialisation time.
 */
faabric::MpiHostsToRanksMessage recvMpiHostRankMsg();

void sendMpiHostRankMsg(const std::string& hostIn,
const faabric::MpiHostsToRanksMessage msg);

void sendMpiMessage(const std::string& hostIn,
int portIn,
const std::shared_ptr<faabric::MPIMessage> msg);
/* This class abstracts the notion of a communication channel between two MPI
* ranks. There will always be one rank local to this host, and one remote.
* Note that the port is unique per (user, function, sendRank, recvRank) tuple.
*/
class MpiMessageEndpoint
{
public:
MpiMessageEndpoint(const std::string& hostIn, int portIn);

std::shared_ptr<faabric::MPIMessage> recvMpiMessage(int portIn);
void sendMpiMessage(const std::shared_ptr<faabric::MPIMessage>& msg);

std::shared_ptr<faabric::MPIMessage> recvMpiMessage();

void close();

private:
SendMessageEndpoint sendMessageEndpoint;
RecvMessageEndpoint recvMessageEndpoint;
};
}
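
A minimal usage sketch of the new class, assuming one endpoint is constructed per remote rank pair and then reused for every message; the host, port, and wrapper function below are placeholders rather than code from this commit:

#include <memory>

#include <faabric/transport/MpiMessageEndpoint.h>

void exampleExchange()
{
    // Placeholder host/port; in faabric these come from the rank-to-host
    // mapping and the per-rank-pair port scheme
    faabric::transport::MpiMessageEndpoint endpoint("10.0.0.2", 10123);

    // Reuse the same endpoint (and hence the same sockets) for every message
    auto msg = std::make_shared<faabric::MPIMessage>();
    endpoint.sendMpiMessage(msg);

    std::shared_ptr<faabric::MPIMessage> reply = endpoint.recvMpiMessage();

    endpoint.close();
}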
121 changes: 91 additions & 30 deletions src/scheduler/MpiWorld.cpp
@@ -1,20 +1,14 @@
#include <faabric/mpi/mpi.h>

#include <faabric/scheduler/MpiThreadPool.h>
#include <faabric/scheduler/MpiWorld.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/transport/MpiMessageEndpoint.h>
#include <faabric/util/environment.h>
#include <faabric/util/func.h>
#include <faabric/util/gids.h>
#include <faabric/util/logging.h>
#include <faabric/util/macros.h>
#include <faabric/util/timing.h>

static thread_local std::unordered_map<int, std::future<void>> futureMap;
static thread_local std::unordered_map<std::string,
faabric::scheduler::FunctionCallClient>
functionCallClients;
static thread_local std::vector<
std::unique_ptr<faabric::transport::MpiMessageEndpoint>>
mpiMessageEndpoints;

namespace faabric::scheduler {
MpiWorld::MpiWorld()
@@ -26,22 +20,85 @@ MpiWorld::MpiWorld()
, cartProcsPerDim(2)
{}

/*
faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient(
const std::string& otherHost)
void MpiWorld::sendRemoteMpiMessage(
int sendRank,
int recvRank,
const std::shared_ptr<faabric::MPIMessage>& msg)
{
// Assert the ranks are sane
assert(0 <= sendRank && sendRank < size);
assert(0 <= recvRank && recvRank < size);

// Initialise the endpoint vector if not initialised
if (mpiMessageEndpoints.size() == 0) {
for (int i = 0; i < size * size; i++) {
mpiMessageEndpoints.emplace_back(nullptr);
}
}

// Get the index for the rank-host pair
int index = getIndexForRanks(sendRank, recvRank);
assert(index >= 0 && index < size * size);

// Lazily initialise send endpoints
if (mpiMessageEndpoints[index] == nullptr) {
// Get host for recv rank
std::string host = getHostForRank(recvRank);
assert(!host.empty());
assert(host != thisHost);

// Get port for send-recv pair
int port = getMpiPort(sendRank, recvRank);
// TODO - finer-grained checking when multi-tenant port scheme
assert(port > 0);

// Create MPI message endpoint
mpiMessageEndpoints.emplace(
mpiMessageEndpoints.begin() + index,
std::make_unique<faabric::transport::MpiMessageEndpoint>(host, port));
}

mpiMessageEndpoints[index]->sendMpiMessage(msg);
}

std::shared_ptr<faabric::MPIMessage> MpiWorld::recvRemoteMpiMessage(
int sendRank,
int recvRank)
{
auto it = functionCallClients.find(otherHost);
if (it == functionCallClients.end()) {
// The second argument is forwarded to the client's constructor
auto _it = functionCallClients.try_emplace(otherHost, otherHost);
if (!_it.second) {
throw std::runtime_error("Error inserting remote endpoint");
// Assert the ranks are sane
assert(0 <= sendRank && sendRank < size);
assert(0 <= recvRank && recvRank < size);

// Initialise the endpoint vector if not initialised
if (mpiMessageEndpoints.size() == 0) {
for (int i = 0; i < size * size; i++) {
mpiMessageEndpoints.emplace_back(nullptr);
}
it = _it.first;
}
return it->second;

// Get the index for the rank-host pair
int index = getIndexForRanks(sendRank, recvRank);
assert(index >= 0 && index < size * size);

// Lazily initialise recv endpoints
if (mpiMessageEndpoints[index] == nullptr) {
// Get host for recv rank
std::string host = getHostForRank(sendRank);
assert(!host.empty());
assert(host != thisHost);

// Get port for send-recv pair
int port = getMpiPort(sendRank, recvRank);
// TODO - finer-grained checking when multi-tenant port scheme
assert(port > 0);

mpiMessageEndpoints.emplace(
mpiMessageEndpoints.begin() + index,
std::make_unique<faabric::transport::MpiMessageEndpoint>(host, port));
}

return mpiMessageEndpoints[index]->recvMpiMessage();
}
*/

int MpiWorld::getMpiThreadPoolSize()
{
@@ -145,25 +202,29 @@ void MpiWorld::shutdownThreadPool()
std::promise<void> p;
threadPool->getMpiReqQueue()->enqueue(
std::make_tuple(QUEUE_SHUTDOWN,
std::bind(&MpiWorld::closeThreadLocalClients, this),
std::bind(&MpiWorld::closeMpiMessageEndpoints, this),
std::move(p)));
}

threadPool->shutdown();

// Lastly clean the main thread as well
closeThreadLocalClients();
closeMpiMessageEndpoints();
}

// TODO - remove
// Clear thread local state
void MpiWorld::closeThreadLocalClients()
void MpiWorld::closeMpiMessageEndpoints()
{
// Close all open sockets
for (auto& s : functionCallClients) {
s.second.close();
if (mpiMessageEndpoints.size() > 0) {
// Close all open sockets
for (auto& e : mpiMessageEndpoints) {
if (e != nullptr) {
e->close();
}
}
mpiMessageEndpoints.clear();
}
functionCallClients.clear();
}

void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal)
@@ -444,7 +505,7 @@ void MpiWorld::send(int sendRank,
getLocalQueue(sendRank, recvRank)->enqueue(std::move(m));
} else {
logger->trace("MPI - send remote {} -> {}", sendRank, recvRank);
faabric::transport::sendMpiMessage(otherHost, getMpiPort(sendRank, recvRank), m);
sendRemoteMpiMessage(sendRank, recvRank, m);
}
}

@@ -468,7 +529,7 @@ void MpiWorld::recv(int sendRank,
m = getLocalQueue(sendRank, recvRank)->dequeue();
} else {
logger->trace("MPI - recv remote {} -> {}", sendRank, recvRank);
m = faabric::transport::recvMpiMessage(getMpiPort(sendRank, recvRank));
m = recvRemoteMpiMessage(sendRank, recvRank);
}
assert(m != nullptr);

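
Both sendRemoteMpiMessage and recvRemoteMpiMessage follow the same pattern: a thread-local vector with size * size slots, grown once and then filled lazily the first time a given rank pair communicates. A generic, self-contained sketch of that lazy-cache pattern (types, names, and the assignment into the cache are simplified here, so this is not a drop-in for the faabric code):

#include <memory>
#include <string>
#include <vector>

// Illustrative stand-in for faabric::transport::MpiMessageEndpoint
struct Endpoint
{
    Endpoint(const std::string& host, int port) {}
};

// Thread-local cache: one slot per ordered (sendRank, recvRank) pair
static thread_local std::vector<std::unique_ptr<Endpoint>> endpointCache;

Endpoint& getOrCreateEndpoint(int sendRank,
                              int recvRank,
                              int worldSize,
                              const std::string& host,
                              int port)
{
    // Grow the cache to its full size the first time it is touched
    if (endpointCache.empty()) {
        endpointCache.resize(worldSize * worldSize);
    }

    // One slot per ordered pair, row-major
    int index = sendRank * worldSize + recvRank;

    // Create the endpoint lazily on first use, then reuse it on every call
    if (endpointCache[index] == nullptr) {
        endpointCache[index] = std::make_unique<Endpoint>(host, port);
    }

    return *endpointCache[index];
}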
42 changes: 28 additions & 14 deletions src/transport/MpiMessageEndpoint.cpp
@@ -28,35 +28,49 @@ void sendMpiHostRankMsg(const std::string& hostIn,
}
}

MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn)
: sendMessageEndpoint(hostIn, portIn)
, recvMessageEndpoint(portIn)
{}

// TODO - reuse clients!!
void sendMpiMessage(const std::string& hostIn, int portIn,
const std::shared_ptr<faabric::MPIMessage> msg)
void MpiMessageEndpoint::sendMpiMessage(
const std::shared_ptr<faabric::MPIMessage>& msg)
{
// TODO - is this lazy init very expensive?
if (sendMessageEndpoint.socket == nullptr) {
sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
}

size_t msgSize = msg->ByteSizeLong();
{
uint8_t sMsg[msgSize];
if (!msg->SerializeToArray(sMsg, msgSize)) {
throw std::runtime_error("Error serialising message");
}
SendMessageEndpoint endpoint(hostIn, portIn);
endpoint.open(getGlobalMessageContext());
endpoint.send(sMsg, msgSize, false);
endpoint.close();
sendMessageEndpoint.send(sMsg, msgSize, false);
}
}

std::shared_ptr<faabric::MPIMessage> recvMpiMessage(int portIn)
std::shared_ptr<faabric::MPIMessage> MpiMessageEndpoint::recvMpiMessage()
{
RecvMessageEndpoint endpoint(portIn);
endpoint.open(getGlobalMessageContext());
if (recvMessageEndpoint.socket == nullptr) {
recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
}

// TODO - preempt data size somehow
Message m = endpoint.recv();
Message m = recvMessageEndpoint.recv();
PARSE_MSG(faabric::MPIMessage, m.data(), m.size());
// Note - This may be very slow as we poll until unbound
endpoint.close();

// TODO - send normal message, not shared_ptr
return std::make_shared<faabric::MPIMessage>(msg);
}

void MpiMessageEndpoint::close()
{
if (sendMessageEndpoint.socket != nullptr) {
sendMessageEndpoint.close();
}
if (recvMessageEndpoint.socket != nullptr) {
recvMessageEndpoint.close();
}
}
}
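
The send and receive paths above round-trip the protobuf message through a raw byte buffer: SerializeToArray on the way out and, via the PARSE_MSG macro, ParseFromArray on the way in (assumed behaviour of the macro). A simplified sketch of that round trip using a heap buffer instead of the stack array, with an assumed include path for the generated MPIMessage header:

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

#include <faabric/proto/faabric.pb.h> // assumed path for the generated MPIMessage

std::vector<uint8_t> serialiseMpiMessage(const faabric::MPIMessage& msg)
{
    // Same steps as MpiMessageEndpoint::sendMpiMessage, minus the socket write
    std::vector<uint8_t> buffer(msg.ByteSizeLong());
    if (!msg.SerializeToArray(buffer.data(), buffer.size())) {
        throw std::runtime_error("Error serialising message");
    }
    return buffer;
}

faabric::MPIMessage deserialiseMpiMessage(const uint8_t* data, size_t size)
{
    // Mirrors what the receive side does before wrapping the result in a shared_ptr
    faabric::MPIMessage msg;
    if (!msg.ParseFromArray(data, size)) {
        throw std::runtime_error("Error deserialising message");
    }
    return msg;
}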
2 changes: 1 addition & 1 deletion tests/test/scheduler/test_mpi_world.cpp
@@ -1,8 +1,8 @@
#include <catch.hpp>

#include <faabric/mpi/mpi.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/scheduler/MpiWorld.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/util/bytes.h>
#include <faabric/util/macros.h>
#include <faabric/util/random.h>