switching to per-world port range
csegarragonz committed Jun 16, 2021
1 parent 0a7ab7c commit 88bcb7d
Showing 9 changed files with 104 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -68,7 +68,7 @@ jobs:
run: inv dev.cc faabric_tests
# --- Tests ---
- name: "Run tests"
run: LOG_LEVEL=trace ./bin/faabric_tests
run: ./bin/faabric_tests
working-directory: /build/faabric/static

dist-tests:
9 changes: 6 additions & 3 deletions include/faabric/scheduler/MpiWorld.h
@@ -27,7 +27,7 @@ class MpiWorld

std::string getHostForRank(int rank);

void setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg);
void setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg);

std::string getUser();

@@ -205,8 +205,11 @@ class MpiWorld
void initLocalQueues();

// Rank-to-rank sockets for remote messaging
void initRemoteMpiEndpoint(int sendRank, int recvRank);
int getMpiPort(int sendRank, int recvRank);
std::vector<int> basePorts;
std::vector<int> initLocalBasePorts(
const std::vector<std::string>& executedAt);
void initRemoteMpiEndpoint(int localRank, int remoteRank);
std::pair<int, int> getPortForRanks(int localRank, int remoteRank);
void sendRemoteMpiMessage(int sendRank,
int recvRank,
const std::shared_ptr<faabric::MPIMessage>& msg);
2 changes: 0 additions & 2 deletions include/faabric/transport/MessageEndpoint.h
@@ -83,8 +83,6 @@ class RecvMessageEndpoint : public MessageEndpoint
public:
RecvMessageEndpoint(int portIn);

RecvMessageEndpoint(int portIn, const std::string& overrideHost);

void open(MessageContext& context);

void close();
4 changes: 1 addition & 3 deletions include/faabric/transport/MpiMessageEndpoint.h
@@ -23,9 +23,7 @@ class MpiMessageEndpoint
public:
MpiMessageEndpoint(const std::string& hostIn, int portIn);

MpiMessageEndpoint(const std::string& hostIn,
int portIn,
const std::string& overrideRecvHost);
MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort);

void sendMpiMessage(const std::shared_ptr<faabric::MPIMessage>& msg);

1 change: 1 addition & 0 deletions src/proto/faabric.proto
@@ -88,6 +88,7 @@ message MPIMessage {
// fields.
message MpiHostsToRanksMessage {
repeated string hosts = 1;
repeated int32 basePorts = 2;
}

message Message {
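As a rough illustration of how the two repeated fields are meant to line up (an editor's sketch, not code from this commit: host names, port numbers, and the include path are assumptions), hosts and basePorts are parallel arrays indexed by rank, populated together in the same way MpiWorld::create does further down:

#include <string>
#include <vector>

#include <faabric/proto/faabric.pb.h> // assumed location of the generated header

void buildHostRankMsgSketch()
{
    // Hypothetical placement: ranks 0-1 on host "A" (base port 8800),
    // ranks 2-3 on host "B" (base port 8816)
    std::vector<std::string> executedAt = { "A", "A", "B", "B" };
    std::vector<int> basePortForRank = { 8800, 8800, 8816, 8816 };

    faabric::MpiHostsToRanksMessage hostRankMsg;
    *hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() };
    *hostRankMsg.mutable_baseports() = { basePortForRank.begin(),
                                         basePortForRank.end() };
}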
131 changes: 89 additions & 42 deletions src/scheduler/MpiWorld.cpp
@@ -26,8 +26,12 @@ MpiWorld::MpiWorld()
, cartProcsPerDim(2)
{}

void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank)
void MpiWorld::initRemoteMpiEndpoint(int localRank, int remoteRank)
{
SPDLOG_TRACE("Open MPI endpoint between ranks (local-remote) {} - {}",
localRank,
remoteRank);

// Resize the message endpoint vector and initialise to null. Note that we
// allocate size x size slots to cover all possible (sendRank, recvRank)
// pairs
@@ -37,42 +41,20 @@ void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank)
}
}

// Get host for recv rank
std::string otherHost;
std::string recvHost = getHostForRank(recvRank);
std::string sendHost = getHostForRank(sendRank);
if (recvHost == sendHost) {
SPDLOG_ERROR(
"Send and recv ranks in the same host: SEND {}, RECV{} in {}",
sendRank,
recvRank,
sendHost);
throw std::runtime_error("Send and recv ranks in the same host");
} else if (recvHost == thisHost) {
otherHost = sendHost;
} else if (sendHost == thisHost) {
otherHost = recvHost;
} else {
SPDLOG_ERROR("Send and recv ranks correspond to remote hosts: SEND {} "
"in {}, RECV {} in {}",
sendRank,
sendHost,
recvRank,
recvHost);
throw std::runtime_error("Send and recv ranks in remote hosts");
}
// Get host for remote rank
std::string otherHost = getHostForRank(remoteRank);

// Get the index for the rank-host pair
int index = getIndexForRanks(sendRank, recvRank);
int index = getIndexForRanks(localRank, remoteRank);

// Get port for send-recv pair
int port = getMpiPort(sendRank, recvRank);
std::pair<int, int> sendRecvPorts = getPortForRanks(localRank, remoteRank);

// Create MPI message endpoint
mpiMessageEndpoints.emplace(
mpiMessageEndpoints.begin() + index,
std::make_unique<faabric::transport::MpiMessageEndpoint>(
otherHost, port, thisHost));
otherHost, sendRecvPorts.first, sendRecvPorts.second));
}

void MpiWorld::sendRemoteMpiMessage(
@@ -81,6 +63,8 @@ void MpiWorld::sendRemoteMpiMessage(
const std::shared_ptr<faabric::MPIMessage>& msg)
{
// Get the index for the rank-host pair
// Note - message endpoints are identified by a (localRank, remoteRank)
// pair, not a (sendRank, recvRank) one
int index = getIndexForRanks(sendRank, recvRank);

if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) {
@@ -95,10 +79,12 @@ std::shared_ptr<faabric::MPIMessage> MpiWorld::recvRemoteMpiMessage(
int recvRank)
{
// Get the index for the rank-host pair
int index = getIndexForRanks(sendRank, recvRank);
// Note - message endpoints are identified by a (localRank, remoteRank)
// pair, not a (sendRank, recvRank) one
int index = getIndexForRanks(recvRank, sendRank);

if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) {
initRemoteMpiEndpoint(sendRank, recvRank);
initRemoteMpiEndpoint(recvRank, sendRank);
}

return mpiMessageEndpoints[index]->recvMpiMessage();
@@ -160,7 +146,14 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
// Register hosts to rank mappings on this host
faabric::MpiHostsToRanksMessage hostRankMsg;
*hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() };
setAllRankHosts(hostRankMsg);

// Prepare the base port for each rank
std::vector<int> basePortForRank = initLocalBasePorts(executedAt);
*hostRankMsg.mutable_baseports() = { basePortForRank.begin(),
basePortForRank.end() };

// Register hosts to rank mappings on this host
setAllRankHostsPorts(hostRankMsg);

// Set up a list of hosts to broadcast to (excluding this host)
std::set<std::string> hosts(executedAt.begin(), executedAt.end());
@@ -249,7 +242,7 @@ void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal)
// Block until we receive
faabric::MpiHostsToRanksMessage hostRankMsg =
faabric::transport::recvMpiHostRankMsg();
setAllRankHosts(hostRankMsg);
setAllRankHostsPorts(hostRankMsg);

// Initialise the memory queues for message reception
initLocalQueues();
@@ -269,16 +262,51 @@ std::string MpiWorld::getHostForRank(int rank)
return host;
}

// Returns a pair (sendPort, recvPort)
// To assign the send and recv ports, we follow a protocol establishing:
// 1) Port range (offset) corresponding to the world that receives
// 2) Within a world's port range, port corresponding to the outcome of
// getIndexForRanks(localRank, remoteRank), where local and remote are
// relative to the world whose port range we are in
std::pair<int, int> MpiWorld::getPortForRanks(int localRank, int remoteRank)
{
std::pair<int, int> sendRecvPortPair;

// Get base port for local and remote worlds
int localBasePort = basePorts[localRank];
int remoteBasePort = basePorts[remoteRank];
assert(localBasePort != remoteBasePort);

// Assign send port
// 1) Port range corresponding to remote world, as they are receiving
// 2) Index switching localRank and remoteRank, as remote rank is "local"
// to the remote world
sendRecvPortPair.first =
remoteBasePort + getIndexForRanks(remoteRank, localRank);

// Assign recv port
// 1) Port range corresponding to our world, as we are the ones receiving
// 2) Port using our local rank as `localRank`, as we are in the local
// offset
sendRecvPortPair.second =
localBasePort + getIndexForRanks(localRank, remoteRank);

return sendRecvPortPair;
}

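As a worked example of this pairing (an editor's sketch, not code from this commit; it assumes getIndexForRanks(a, b) resolves to a * size + b, as the removed getMpiPort did, and uses a hypothetical MPI_PORT of 8800): for a world of size 2 with rank 0 on host A and rank 1 on host B, each rank's send port ends up equal to the other rank's recv port.

#include <cassert>
#include <utility>
#include <vector>

int main()
{
    const int size = 2;                              // one rank per host
    const int mpiPort = 8800;                        // hypothetical MPI_PORT
    // Rank 0 on host A keeps MPI_PORT, rank 1 on host B gets the next block
    std::vector<int> basePorts = { mpiPort, mpiPort + size * size };

    // Assumed behaviour of getIndexForRanks
    auto index = [&](int a, int b) { return a * size + b; };

    // Same pairing as getPortForRanks above: (sendPort, recvPort)
    auto portsFor = [&](int localRank, int remoteRank) {
        return std::make_pair(
          basePorts[remoteRank] + index(remoteRank, localRank), // send port
          basePorts[localRank] + index(localRank, remoteRank)); // recv port
    };

    auto rank0 = portsFor(0, 1); // { 8806, 8801 } under these assumptions
    auto rank1 = portsFor(1, 0); // { 8801, 8806 }

    // Rank 0 sends to the port rank 1 receives on, and vice versa
    assert(rank0.first == rank1.second);
    assert(rank0.second == rank1.first);
    return 0;
}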
// Prepare the host-rank map with a vector containing _all_ ranks
// Note - this method should be called by only one rank. This is enforced in
// the world registry
void MpiWorld::setAllRankHosts(const faabric::MpiHostsToRanksMessage& msg)
void MpiWorld::setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg)
{
// Assert we are only setting the values once
assert(rankHosts.size() == 0);
assert(basePorts.size() == 0);

assert(msg.hosts().size() == size);
assert(msg.baseports().size() == size);
rankHosts = { msg.hosts().begin(), msg.hosts().end() };
basePorts = { msg.baseports().begin(), msg.baseports().end() };
}

void MpiWorld::getCartesianRank(int rank,
@@ -448,15 +476,6 @@ int MpiWorld::irecv(int sendRank,
return requestId;
}

int MpiWorld::getMpiPort(int sendRank, int recvRank)
{
// TODO - get port in a multi-tenant-safe manner
int basePort = MPI_PORT;
int rankOffset = sendRank * size + recvRank;

return basePort + rankOffset;
}

void MpiWorld::send(int sendRank,
int recvRank,
const uint8_t* buffer,
@@ -1204,6 +1223,34 @@ void MpiWorld::initLocalQueues()
}
}

// Here we rely on the scheduler returning a list of hosts where equal
// hosts are always contiguous, with the exception of the master host
// (thisHost), which may appear repeated at the end if the system is
// overloaded.
std::vector<int> MpiWorld::initLocalBasePorts(
const std::vector<std::string>& executedAt)
{
std::vector<int> basePortForRank;
basePortForRank.reserve(size);

std::string lastHost = thisHost;
int lastPort = MPI_PORT;
for (const auto& host : executedAt) {
if (host == thisHost) {
basePortForRank.push_back(MPI_PORT);
} else if (host == lastHost) {
basePortForRank.push_back(lastPort);
} else {
lastHost = host;
lastPort += size * size;
basePortForRank.push_back(lastPort);
}
}

assert(basePortForRank.size() == size);
return basePortForRank;
}
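To make the layout concrete, here is an editor's sketch of the same loop run on a hypothetical scheduling decision (host names, world size, and the MPI_PORT value of 8800 are assumptions, not taken from this commit):

#include <cassert>
#include <string>
#include <vector>

int main()
{
    const int size = 4;
    const int mpiPort = 8800;                         // hypothetical MPI_PORT
    const std::string thisHost = "A";                 // master host
    // Scheduler decision: ranks 0-1 on the master host, ranks 2-3 on "B"
    const std::vector<std::string> executedAt = { "A", "A", "B", "B" };

    std::vector<int> basePortForRank;
    basePortForRank.reserve(size);

    std::string lastHost = thisHost;
    int lastPort = mpiPort;
    for (const auto& host : executedAt) {
        if (host == thisHost) {
            basePortForRank.push_back(mpiPort);       // master keeps MPI_PORT
        } else if (host == lastHost) {
            basePortForRank.push_back(lastPort);      // same host, same range
        } else {
            lastHost = host;
            lastPort += size * size;                  // new host, next block
            basePortForRank.push_back(lastPort);
        }
    }

    // Each distinct host ends up with its own block of size * size ports
    assert((basePortForRank == std::vector<int>{ 8800, 8800, 8816, 8816 }));
    return 0;
}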

std::shared_ptr<faabric::MPIMessage>
MpiWorld::recvBatchReturnLast(int sendRank, int recvRank, int batchSize)
{
5 changes: 0 additions & 5 deletions src/transport/MessageEndpoint.cpp
@@ -279,11 +279,6 @@ RecvMessageEndpoint::RecvMessageEndpoint(int portIn)
: MessageEndpoint(ANY_HOST, portIn)
{}

RecvMessageEndpoint::RecvMessageEndpoint(int portIn,
const std::string& overrideHost)
: MessageEndpoint(overrideHost, portIn)
{}

void RecvMessageEndpoint::open(MessageContext& context)
{
SPDLOG_TRACE(
8 changes: 4 additions & 4 deletions src/transport/MpiMessageEndpoint.cpp
@@ -37,10 +37,10 @@ MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn)
}

MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn,
int portIn,
const std::string& overrideRecvHost)
: sendMessageEndpoint(hostIn, portIn)
, recvMessageEndpoint(portIn, overrideRecvHost)
int sendPort,
int recvPort)
: sendMessageEndpoint(hostIn, sendPort)
, recvMessageEndpoint(recvPort)
{
sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
4 changes: 2 additions & 2 deletions tests/test/transport/test_mpi_message_endpoint.cpp
@@ -31,8 +31,8 @@ TEST_CASE_METHOD(MessageContextFixture,
"[transport]")
{
std::string thisHost = faabric::util::getSystemConfig().endpointHost;
MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, thisHost);
MpiMessageEndpoint recvEndpoint(thisHost, 9999, LOCALHOST);
MpiMessageEndpoint sendEndpoint(LOCALHOST, 9999, 9998);
MpiMessageEndpoint recvEndpoint(thisHost, 9998, 9999);

std::shared_ptr<faabric::MPIMessage> expected =
std::make_shared<faabric::MPIMessage>();
