pr comments
csegarragonz committed Jun 14, 2021
1 parent 421234d commit 886d0f5
Showing 10 changed files with 353 additions and 214 deletions.
3 changes: 3 additions & 0 deletions include/faabric/scheduler/MpiWorld.h
@@ -212,12 +212,15 @@ class MpiWorld
    void initLocalQueues();

    // Rank-to-rank sockets for remote messaging
    void initRemoteMpiEndpoint(int sendRank, int recvRank);
    int getMpiPort(int sendRank, int recvRank);
    void sendRemoteMpiMessage(int sendRank,
                              int recvRank,
                              const std::shared_ptr<faabric::MPIMessage>& msg);
    std::shared_ptr<faabric::MPIMessage> recvRemoteMpiMessage(int sendRank,
                                                              int recvRank);
    void closeMpiMessageEndpoints();

    void checkRanksRange(int sendRank, int recvRank);
};
}
2 changes: 2 additions & 0 deletions include/faabric/transport/MessageEndpoint.h
@@ -83,6 +83,8 @@ class RecvMessageEndpoint : public MessageEndpoint
  public:
    RecvMessageEndpoint(int portIn);

    RecvMessageEndpoint(int portIn, const std::string& overrideHost);

    void open(MessageContext& context);

    void close();
4 changes: 4 additions & 0 deletions include/faabric/transport/MpiMessageEndpoint.h
@@ -23,6 +23,10 @@ class MpiMessageEndpoint
  public:
    MpiMessageEndpoint(const std::string& hostIn, int portIn);

    MpiMessageEndpoint(const std::string& hostIn,
                       int portIn,
                       const std::string& overrideRecvHost);

    void sendMpiMessage(const std::shared_ptr<faabric::MPIMessage>& msg);

    std::shared_ptr<faabric::MPIMessage> recvMpiMessage();
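The override constructor matters when both ends of the pair live on the same process or host (e.g. in tests): the recv socket must bind to an explicit local address rather than defaulting to the peer's host. A minimal sketch of the two constructors; "other-host", "this-host" and port 8800 are assumptions for illustration:

#include <faabric/transport/MpiMessageEndpoint.h>

void endpointConstructionSketch()
{
    // Default: connect the send socket to the remote peer; the recv
    // socket binds on the default host for this port
    faabric::transport::MpiMessageEndpoint remote("other-host", 8800);

    // Override: still send to the peer, but bind the recv socket to an
    // explicit local address, as MpiWorld::initRemoteMpiEndpoint does
    // with thisHost
    faabric::transport::MpiMessageEndpoint overridden(
      "other-host", 8800, "this-host");
}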
146 changes: 85 additions & 61 deletions src/scheduler/MpiWorld.cpp
@@ -19,42 +19,65 @@ MpiWorld::MpiWorld()
  , cartProcsPerDim(2)
{}

void MpiWorld::sendRemoteMpiMessage(
  int sendRank,
  int recvRank,
  const std::shared_ptr<faabric::MPIMessage>& msg)
void MpiWorld::initRemoteMpiEndpoint(int sendRank, int recvRank)
{
    // Assert the ranks are sane
    assert(0 <= sendRank && sendRank < size);
    assert(0 <= recvRank && recvRank < size);

    // Initialise the endpoint vector if not initialised
    if (mpiMessageEndpoints.size() == 0) {
    // Resize the message endpoint vector and initialise to null. Note that we
    // allocate size x size slots to cover all possible (sendRank, recvRank)
    // pairs
    if (mpiMessageEndpoints.empty()) {
        for (int i = 0; i < size * size; i++) {
            mpiMessageEndpoints.emplace_back(nullptr);
        }
    }

    // Get host for recv rank
    std::string otherHost;
    std::string recvHost = getHostForRank(recvRank);
    std::string sendHost = getHostForRank(sendRank);
    if (recvHost == sendHost) {
        SPDLOG_ERROR(
          "Send and recv ranks in the same host: SEND {}, RECV {} in {}",
          sendRank,
          recvRank,
          sendHost);
        throw std::runtime_error("Send and recv ranks in the same host");
    } else if (recvHost == thisHost) {
        otherHost = sendHost;
    } else if (sendHost == thisHost) {
        otherHost = recvHost;
    } else {
        SPDLOG_ERROR("Send and recv ranks correspond to remote hosts: SEND {} "
                     "in {}, RECV {} in {}",
                     sendRank,
                     sendHost,
                     recvRank,
                     recvHost);
        throw std::runtime_error("Send and recv ranks in remote hosts");
    }

    // Get the index for the rank-host pair
    int index = getIndexForRanks(sendRank, recvRank);

    // Get port for send-recv pair
    int port = getMpiPort(sendRank, recvRank);

    // Create MPI message endpoint
    mpiMessageEndpoints.emplace(
      mpiMessageEndpoints.begin() + index,
      std::make_unique<faabric::transport::MpiMessageEndpoint>(
        otherHost, port, thisHost));
}
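To make the host-selection rule above concrete, here is a hedged, standalone restatement (function name and host strings are hypothetical) with a worked two-host example:

#include <cassert>
#include <stdexcept>
#include <string>

// Given the hosts of the send and recv ranks, return the remote peer this
// host must connect to; same decision table as initRemoteMpiEndpoint above
std::string pickOtherHost(const std::string& sendHost,
                          const std::string& recvHost,
                          const std::string& thisHost)
{
    if (recvHost == sendHost) {
        throw std::runtime_error("Send and recv ranks in the same host");
    } else if (recvHost == thisHost) {
        return sendHost;
    } else if (sendHost == thisHost) {
        return recvHost;
    }
    throw std::runtime_error("Send and recv ranks in remote hosts");
}

int main()
{
    // Rank 0 on hostA, rank 1 on hostB: each side connects to the other
    assert(pickOtherHost("hostA", "hostB", "hostA") == "hostB");
    assert(pickOtherHost("hostA", "hostB", "hostB") == "hostA");
    return 0;
}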

void MpiWorld::sendRemoteMpiMessage(
  int sendRank,
  int recvRank,
  const std::shared_ptr<faabric::MPIMessage>& msg)
{
    // Get the index for the rank-host pair
    int index = getIndexForRanks(sendRank, recvRank);
    assert(index >= 0 && index < size * size);

    // Lazily initialise send endpoints
    if (mpiMessageEndpoints[index] == nullptr) {
        // Get host for recv rank
        std::string host = getHostForRank(recvRank);
        assert(!host.empty());
        assert(host != thisHost);

        // Get port for send-recv pair
        int port = getMpiPort(sendRank, recvRank);
        // TODO - finer-grained checking when multi-tenant port scheme
        assert(port > 0);

        // Create MPI message endpoint
        mpiMessageEndpoints.emplace(
          mpiMessageEndpoints.begin() + index,
          std::make_unique<faabric::transport::MpiMessageEndpoint>(host, port));
    if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) {
        initRemoteMpiEndpoint(sendRank, recvRank);
    }

    mpiMessageEndpoints[index]->sendMpiMessage(msg);
@@ -64,36 +87,11 @@ std::shared_ptr<faabric::MPIMessage> MpiWorld::recvRemoteMpiMessage(
  int sendRank,
  int recvRank)
{
    // Assert the ranks are sane
    assert(0 <= sendRank && sendRank < size);
    assert(0 <= recvRank && recvRank < size);

    // Initialise the endpoint vector if not initialised
    if (mpiMessageEndpoints.size() == 0) {
        for (int i = 0; i < size * size; i++) {
            mpiMessageEndpoints.emplace_back(nullptr);
        }
    }

    // Get the index for the rank-host pair
    int index = getIndexForRanks(sendRank, recvRank);
    assert(index >= 0 && index < size * size);

    // Lazily initialise send endpoints
    if (mpiMessageEndpoints[index] == nullptr) {
        // Get host for recv rank
        std::string host = getHostForRank(sendRank);
        assert(!host.empty());
        assert(host != thisHost);

        // Get port for send-recv pair
        int port = getMpiPort(sendRank, recvRank);
        // TODO - finer-grained checking when multi-tenant port scheme
        assert(port > 0);

        mpiMessageEndpoints.emplace(
          mpiMessageEndpoints.begin() + index,
          std::make_unique<faabric::transport::MpiMessageEndpoint>(host, port));
    if (mpiMessageEndpoints.empty() || mpiMessageEndpoints[index] == nullptr) {
        initRemoteMpiEndpoint(sendRank, recvRank);
    }

    return mpiMessageEndpoints[index]->recvMpiMessage();
@@ -254,11 +252,6 @@ std::string MpiWorld::getHostForRank(int rank)
{
    assert(rankHosts.size() == size);

    if (rank >= size) {
        throw std::runtime_error(
          fmt::format("Rank bigger than world size ({} > {})", rank, size));
    }

    std::string host = rankHosts[rank];
    if (host.empty()) {
        throw std::runtime_error(
@@ -476,6 +469,14 @@ void MpiWorld::send(int sendRank,
                    int count,
                    faabric::MPIMessage::MPIMessageType messageType)
{
    // Sanity-check input parameters
    checkRanksRange(sendRank, recvRank);
    if (getHostForRank(sendRank) != thisHost) {
        SPDLOG_ERROR("Trying to send message from a non-local rank: {}",
                     sendRank);
        throw std::runtime_error("Sending message from non-local rank");
    }

    // Work out whether the message is sent locally or to another host
    const std::string otherHost = getHostForRank(recvRank);
    bool isLocal = otherHost == thisHost;
@@ -516,8 +517,15 @@ void MpiWorld::recv(int sendRank,
                    MPI_Status* status,
                    faabric::MPIMessage::MPIMessageType messageType)
{
    // Sanity-check input parameters
    checkRanksRange(sendRank, recvRank);
    if (getHostForRank(recvRank) != thisHost) {
        SPDLOG_ERROR("Trying to recv message into a non-local rank: {}",
                     recvRank);
        throw std::runtime_error("Receiving message into non-local rank");
    }

    // Work out whether the message is sent locally or from another host
    assert(thisHost == getHostForRank(recvRank));
    const std::string otherHost = getHostForRank(sendRank);
    bool isLocal = otherHost == thisHost;
@@ -1174,7 +1182,9 @@ void MpiWorld::initLocalQueues()

int MpiWorld::getIndexForRanks(int sendRank, int recvRank)
{
    return sendRank * size + recvRank;
    int index = sendRank * size + recvRank;
    assert(index >= 0 && index < size * size);
    return index;
}
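A quick worked check of the flattened (sendRank, recvRank) indexing that the new assert enforces, assuming a hypothetical world of size 4:

#include <cassert>

int main()
{
    int size = 4; // hypothetical world size
    auto index = [size](int sendRank, int recvRank) {
        return sendRank * size + recvRank;
    };
    assert(index(0, 0) == 0);  // first slot
    assert(index(2, 3) == 11); // 2 * 4 + 3
    assert(index(3, 3) == 15); // last of the size * size = 16 slots
    return 0;
}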

long MpiWorld::getLocalQueueSize(int sendRank, int recvRank)
@@ -1214,4 +1224,18 @@ void MpiWorld::overrideHost(const std::string& newHost)
{
    thisHost = newHost;
}

void MpiWorld::checkRanksRange(int sendRank, int recvRank)
{
    if (sendRank < 0 || sendRank >= size) {
        SPDLOG_ERROR(
          "Send rank outside range: {} not in [0, {})", sendRank, size);
        throw std::runtime_error("Send rank outside range");
    }
    if (recvRank < 0 || recvRank >= size) {
        SPDLOG_ERROR(
          "Recv rank outside range: {} not in [0, {})", recvRank, size);
        throw std::runtime_error("Recv rank outside range");
    }
}
}
9 changes: 7 additions & 2 deletions src/transport/MessageEndpoint.cpp
@@ -279,18 +279,23 @@ RecvMessageEndpoint::RecvMessageEndpoint(int portIn)
  : MessageEndpoint(ANY_HOST, portIn)
{}

RecvMessageEndpoint::RecvMessageEndpoint(int portIn,
                                         const std::string& overrideHost)
  : MessageEndpoint(overrideHost, portIn)
{}

void RecvMessageEndpoint::open(MessageContext& context)
{
    SPDLOG_TRACE(
      fmt::format("Opening socket: {} (RECV {}:{})", id, ANY_HOST, port));
      fmt::format("Opening socket: {} (RECV {}:{})", id, host, port));

    MessageEndpoint::open(context, SocketType::PULL, true);
}

void RecvMessageEndpoint::close()
{
    SPDLOG_TRACE(
      fmt::format("Closing socket: {} (RECV {}:{})", id, ANY_HOST, port));
      fmt::format("Closing socket: {} (RECV {}:{})", id, host, port));

    MessageEndpoint::close(true);
}
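With the override constructor, the bound address is whatever host the endpoint was built with, which is why the trace lines above now print host instead of the hard-coded ANY_HOST. A minimal sketch of binding to an explicit interface; the port and address are assumptions:

#include <faabric/transport/MessageEndpoint.h>

void openLoopbackEndpoint(faabric::transport::MessageContext& context)
{
    // Bind the PULL socket to localhost only, rather than ANY_HOST
    faabric::transport::RecvMessageEndpoint endpoint(8005, "127.0.0.1");
    endpoint.open(context);
    // ... receive messages ...
    endpoint.close();
}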
24 changes: 14 additions & 10 deletions src/transport/MpiMessageEndpoint.cpp
@@ -31,16 +31,24 @@ void sendMpiHostRankMsg(const std::string& hostIn,
MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn)
  : sendMessageEndpoint(hostIn, portIn)
  , recvMessageEndpoint(portIn)
{}
{
    sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
    recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
}

MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn,
                                       int portIn,
                                       const std::string& overrideRecvHost)
  : sendMessageEndpoint(hostIn, portIn)
  , recvMessageEndpoint(portIn, overrideRecvHost)
{
    sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
    recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
}

void MpiMessageEndpoint::sendMpiMessage(
  const std::shared_ptr<faabric::MPIMessage>& msg)
{
    // TODO - is this lazy init very expensive?
    if (sendMessageEndpoint.socket == nullptr) {
        sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
    }

    size_t msgSize = msg->ByteSizeLong();
    {
        uint8_t sMsg[msgSize];
@@ -53,10 +61,6 @@ void MpiMessageEndpoint::sendMpiMessage(

std::shared_ptr<faabric::MPIMessage> MpiMessageEndpoint::recvMpiMessage()
{
    if (recvMessageEndpoint.socket == nullptr) {
        recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext());
    }

    Message m = recvMessageEndpoint.recv();
    PARSE_MSG(faabric::MPIMessage, m.data(), m.size());

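A note on the design change above: the sockets used to be opened lazily on first send/recv (the deleted null checks), and are now opened eagerly in both constructors, so an MpiMessageEndpoint is usable as soon as it exists. A hedged usage sketch; the loopback host, port 9999, and the proto include path are assumptions:

#include <faabric/proto/faabric.pb.h>
#include <faabric/transport/MpiMessageEndpoint.h>

#include <memory>

void eagerOpenSketch()
{
    // Both send and recv sockets open here, inside the constructor
    faabric::transport::MpiMessageEndpoint endpoint(
      "127.0.0.1", 9999, "127.0.0.1");

    // The first send no longer pays the lazy-init branch
    auto msg = std::make_shared<faabric::MPIMessage>();
    endpoint.sendMpiMessage(msg);
    auto received = endpoint.recvMpiMessage();
}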