Queue-less remote send/recv (#105)
* adding p2p send/recv methods

* don't check queueing when queues aren't used

* removing sendMPImessage from the function call API

* update tests

* reusing sockets for rank-to-rank communication

* adding more tests

* pr comments
csegarragonz committed Jun 15, 2021
1 parent 9faf426 commit e3331a1
Showing 18 changed files with 550 additions and 361 deletions.
11 changes: 5 additions & 6 deletions include/faabric/scheduler/FunctionCallApi.h
@@ -4,11 +4,10 @@ namespace faabric::scheduler {
enum FunctionCalls
{
NoFunctionCall = 0,
MpiMessage = 1,
ExecuteFunctions = 2,
Flush = 3,
Unregister = 4,
GetResources = 5,
SetThreadResult = 6,
ExecuteFunctions = 1,
Flush = 2,
Unregister = 3,
GetResources = 4,
SetThreadResult = 5,
};
}
4 changes: 0 additions & 4 deletions include/faabric/scheduler/FunctionCallClient.h
@@ -19,8 +19,6 @@ std::vector<
std::pair<std::string, std::shared_ptr<faabric::BatchExecuteRequest>>>
getBatchRequests();

std::vector<std::pair<std::string, faabric::MPIMessage>> getMPIMessages();

std::vector<std::pair<std::string, faabric::ResponseRequest>>
getResourceRequests();

@@ -44,8 +42,6 @@ class FunctionCallClient : public faabric::transport::MessageEndpointClient

void sendFlush();

void sendMPIMessage(const std::shared_ptr<faabric::MPIMessage> msg);

faabric::HostResources getResources();

void executeFunctions(
4 changes: 0 additions & 4 deletions include/faabric/scheduler/FunctionCallServer.h
@@ -2,8 +2,6 @@

#include <faabric/proto/faabric.pb.h>
#include <faabric/scheduler/FunctionCallApi.h>
#include <faabric/scheduler/MpiWorld.h>
#include <faabric/scheduler/MpiWorldRegistry.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/transport/MessageEndpointClient.h>
#include <faabric/transport/MessageEndpointServer.h>
@@ -25,8 +23,6 @@ class FunctionCallServer final

/* Function call server API */

void recvMpiMessage(faabric::transport::Message& body);

void recvFlush(faabric::transport::Message& body);

void recvExecuteFunctions(faabric::transport::Message& body);
34 changes: 20 additions & 14 deletions include/faabric/scheduler/MpiWorld.h
@@ -6,7 +6,9 @@
#include <faabric/scheduler/FunctionCallClient.h>
#include <faabric/scheduler/InMemoryMessageQueue.h>
#include <faabric/scheduler/MpiThreadPool.h>
#include <faabric/state/StateKeyValue.h>
#include <faabric/transport/MpiMessageEndpoint.h>
#include <faabric/util/logging.h>
#include <faabric/util/timing.h>

#include <atomic>
#include <thread>
@@ -15,8 +17,6 @@ namespace faabric::scheduler {
typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
InMemoryMpiQueue;

std::string getWorldStateKey(int worldId);

class MpiWorld
{
public:
@@ -43,8 +43,6 @@ class MpiWorld

void shutdownThreadPool();

void enqueueMessage(faabric::MPIMessage& msg);

void getCartesianRank(int rank,
int maxDims,
const int* dims,
@@ -198,23 +196,31 @@
std::string user;
std::string function;

std::shared_ptr<state::StateKeyValue> stateKV;
std::vector<std::string> rankHosts;

std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;

std::shared_ptr<faabric::scheduler::MpiAsyncThreadPool> threadPool;
int getMpiThreadPoolSize();

std::vector<int> cartProcsPerDim;

faabric::scheduler::FunctionCallClient& getFunctionCallClient(
const std::string& otherHost);

void closeThreadLocalClients();
/* MPI internal messaging layer */

// Track at which host each rank lives
std::vector<std::string> rankHosts;
int getIndexForRanks(int sendRank, int recvRank);

// In-memory queues for local messaging
std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;
void initLocalQueues();

// Rank-to-rank sockets for remote messaging
void initRemoteMpiEndpoint(int sendRank, int recvRank);
int getMpiPort(int sendRank, int recvRank);
void sendRemoteMpiMessage(int sendRank,
int recvRank,
const std::shared_ptr<faabric::MPIMessage>& msg);
std::shared_ptr<faabric::MPIMessage> recvRemoteMpiMessage(int sendRank,
int recvRank);
void closeMpiMessageEndpoints();

void checkRanksRange(int sendRank, int recvRank);
};
}
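
The reworked private members sketch out the new send path: rankHosts records which host each rank lives on, messages between two local ranks go through the in-memory queues, and anything else goes over a dedicated rank-to-rank socket. A minimal, self-contained sketch of that routing decision, using assumed names rather than the commit's actual MpiWorld code:

#include <iostream>
#include <string>
#include <vector>

// Sketch only: stands in for MpiWorld's rankHosts lookup. The sending rank is
// assumed to run on thisHost; delivery is local when the receiver does too.
bool isLocalDelivery(const std::vector<std::string>& rankHosts,
                     const std::string& thisHost,
                     int recvRank)
{
    // Local -> push onto an in-memory queue; remote -> use the rank-to-rank socket
    return rankHosts.at(recvRank) == thisHost;
}

int main()
{
    std::vector<std::string> rankHosts = { "hostA", "hostA", "hostB" };
    std::cout << isLocalDelivery(rankHosts, "hostA", 1) << std::endl; // 1 (local queue)
    std::cout << isLocalDelivery(rankHosts, "hostA", 2) << std::endl; // 0 (remote socket)
    return 0;
}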
2 changes: 2 additions & 0 deletions include/faabric/transport/MessageEndpoint.h
@@ -83,6 +83,8 @@ class RecvMessageEndpoint : public MessageEndpoint
public:
RecvMessageEndpoint(int portIn);

RecvMessageEndpoint(int portIn, const std::string& overrideHost);

void open(MessageContext& context);

void close();
27 changes: 27 additions & 0 deletions include/faabric/transport/MpiMessageEndpoint.h
@@ -6,8 +6,35 @@
#include <faabric/transport/macros.h>

namespace faabric::transport {
/* These two abstract methods are used to broadcast the host-rank mapping at
* initialisation time.
*/
faabric::MpiHostsToRanksMessage recvMpiHostRankMsg();

void sendMpiHostRankMsg(const std::string& hostIn,
const faabric::MpiHostsToRanksMessage msg);

/* This class abstracts the notion of a communication channel between two MPI
* ranks. There will always be one rank local to this host, and one remote.
* Note that the port is unique per (user, function, sendRank, recvRank) tuple.
*/
class MpiMessageEndpoint
{
public:
MpiMessageEndpoint(const std::string& hostIn, int portIn);

MpiMessageEndpoint(const std::string& hostIn,
int portIn,
const std::string& overrideRecvHost);

void sendMpiMessage(const std::shared_ptr<faabric::MPIMessage>& msg);

std::shared_ptr<faabric::MPIMessage> recvMpiMessage();

void close();

private:
SendMessageEndpoint sendMessageEndpoint;
RecvMessageEndpoint recvMessageEndpoint;
};
}
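
Going by the declarations above, each MpiMessageEndpoint pairs a send socket pointed at the remote host with a receive socket bound locally on the same per-pair port. A rough usage sketch; only the MpiMessageEndpoint methods come from this header, while the host name, port number and MPIMessage field names are illustrative assumptions:

#include <memory>

#include <faabric/proto/faabric.pb.h>
#include <faabric/transport/MpiMessageEndpoint.h>

void exampleExchange()
{
    // One endpoint per (user, function, sendRank, recvRank) tuple; the port is
    // assumed to have been derived from the rank pair beforehand.
    faabric::transport::MpiMessageEndpoint endpoint("remote-host", 8800);

    auto msg = std::make_shared<faabric::MPIMessage>();
    msg->set_sender(0);      // field names are assumptions, not taken from the diff
    msg->set_destination(1);

    endpoint.sendMpiMessage(msg);

    // Blocks until the paired endpoint on the other host sends something back
    std::shared_ptr<faabric::MPIMessage> reply = endpoint.recvMpiMessage();

    endpoint.close();
}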
3 changes: 1 addition & 2 deletions include/faabric/transport/common.h
@@ -5,8 +5,7 @@
#define DEFAULT_SNAPSHOT_HOST "0.0.0.0"
#define STATE_PORT 8003
#define FUNCTION_CALL_PORT 8004
#define MPI_MESSAGE_PORT 8005
#define SNAPSHOT_PORT 8006
#define SNAPSHOT_PORT 8005
#define REPLY_PORT_OFFSET 100

#define MPI_PORT 8800
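
With the shared MPI_MESSAGE_PORT gone, SNAPSHOT_PORT moves down to 8005 and MPI traffic instead runs over per-rank-pair ports offset from MPI_PORT (8800). One plausible way such a port could be derived, shown as an assumption about the scheme rather than the commit's actual formula:

// Assumed scheme: flatten the ordered (sendRank, recvRank) pair into an index
// and offset it from the MPI base port, so every pair gets its own socket.
constexpr int MPI_PORT_BASE = 8800;

int getMpiPortSketch(int sendRank, int recvRank, int worldSize)
{
    int index = sendRank * worldSize + recvRank;
    return MPI_PORT_BASE + index;
}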
19 changes: 0 additions & 19 deletions src/scheduler/FunctionCallClient.cpp
@@ -19,8 +19,6 @@ static std::vector<
std::pair<std::string, std::shared_ptr<faabric::BatchExecuteRequest>>>
batchMessages;

static std::vector<std::pair<std::string, faabric::MPIMessage>> mpiMessages;

static std::vector<std::pair<std::string, faabric::ResponseRequest>>
resourceRequests;

@@ -48,11 +46,6 @@ getBatchRequests()
return batchMessages;
}

std::vector<std::pair<std::string, faabric::MPIMessage>> getMPIMessages()
{
return mpiMessages;
}

std::vector<std::pair<std::string, faabric::ResponseRequest>>
getResourceRequests()
{
@@ -74,7 +67,6 @@ void clearMockRequests()
{
functionCalls.clear();
batchMessages.clear();
mpiMessages.clear();
resourceRequests.clear();
unregisterRequests.clear();

@@ -113,17 +105,6 @@ void FunctionCallClient::sendFlush()
}
}

void FunctionCallClient::sendMPIMessage(
const std::shared_ptr<faabric::MPIMessage> msg)
{
if (faabric::util::isMockMode()) {
faabric::util::UniqueLock lock(mockMutex);
mpiMessages.emplace_back(host, *msg);
} else {
SEND_MESSAGE_PTR(faabric::scheduler::FunctionCalls::MpiMessage, msg);
}
}

faabric::HostResources FunctionCallClient::getResources()
{
faabric::ResponseRequest request;
12 changes: 0 additions & 12 deletions src/scheduler/FunctionCallServer.cpp
@@ -27,9 +27,6 @@ void FunctionCallServer::doRecv(faabric::transport::Message& header,
assert(header.size() == sizeof(uint8_t));
uint8_t call = static_cast<uint8_t>(*header.data());
switch (call) {
case faabric::scheduler::FunctionCalls::MpiMessage:
this->recvMpiMessage(body);
break;
case faabric::scheduler::FunctionCalls::Flush:
this->recvFlush(body);
break;
@@ -48,15 +45,6 @@
}
}

void FunctionCallServer::recvMpiMessage(faabric::transport::Message& body)
{
PARSE_MSG(faabric::MPIMessage, body.data(), body.size())

MpiWorldRegistry& registry = getMpiWorldRegistry();
MpiWorld& world = registry.getWorld(msg.worldid());
world.enqueueMessage(msg);
}

void FunctionCallServer::recvFlush(faabric::transport::Message& body)
{
PARSE_MSG(faabric::ResponseRequest, body.data(), body.size());