Queue-less remote send/recv #105

Merged · 7 commits · Jun 15, 2021
11 changes: 5 additions & 6 deletions include/faabric/scheduler/FunctionCallApi.h
@@ -4,11 +4,10 @@ namespace faabric::scheduler {
enum FunctionCalls
{
NoFunctionCall = 0,
MpiMessage = 1,
ExecuteFunctions = 2,
Flush = 3,
Unregister = 4,
GetResources = 5,
SetThreadResult = 6,
ExecuteFunctions = 1,
Flush = 2,
Unregister = 3,
GetResources = 4,
SetThreadResult = 5,
};
}
4 changes: 0 additions & 4 deletions include/faabric/scheduler/FunctionCallClient.h
@@ -19,8 +19,6 @@ std::vector<
std::pair<std::string, std::shared_ptr<faabric::BatchExecuteRequest>>>
getBatchRequests();

std::vector<std::pair<std::string, faabric::MPIMessage>> getMPIMessages();

std::vector<std::pair<std::string, faabric::ResponseRequest>>
getResourceRequests();

@@ -44,8 +42,6 @@ class FunctionCallClient : public faabric::transport::MessageEndpointClient

void sendFlush();

void sendMPIMessage(const std::shared_ptr<faabric::MPIMessage> msg);

faabric::HostResources getResources();

void executeFunctions(
4 changes: 0 additions & 4 deletions include/faabric/scheduler/FunctionCallServer.h
@@ -2,8 +2,6 @@

#include <faabric/proto/faabric.pb.h>
#include <faabric/scheduler/FunctionCallApi.h>
#include <faabric/scheduler/MpiWorld.h>
#include <faabric/scheduler/MpiWorldRegistry.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/transport/MessageEndpointClient.h>
#include <faabric/transport/MessageEndpointServer.h>
@@ -25,8 +23,6 @@ class FunctionCallServer final

/* Function call server API */

void recvMpiMessage(faabric::transport::Message& body);

void recvFlush(faabric::transport::Message& body);

void recvExecuteFunctions(faabric::transport::Message& body);
34 changes: 20 additions & 14 deletions include/faabric/scheduler/MpiWorld.h
@@ -6,7 +6,9 @@
#include <faabric/scheduler/FunctionCallClient.h>
#include <faabric/scheduler/InMemoryMessageQueue.h>
#include <faabric/scheduler/MpiThreadPool.h>
#include <faabric/state/StateKeyValue.h>
#include <faabric/transport/MpiMessageEndpoint.h>
#include <faabric/util/logging.h>
#include <faabric/util/timing.h>

#include <atomic>
#include <thread>
@@ -15,8 +17,6 @@ namespace faabric::scheduler {
typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
InMemoryMpiQueue;

std::string getWorldStateKey(int worldId);

class MpiWorld
{
public:
@@ -43,8 +43,6 @@

void shutdownThreadPool();

void enqueueMessage(faabric::MPIMessage& msg);

void getCartesianRank(int rank,
int maxDims,
const int* dims,
@@ -198,23 +196,31 @@
std::string user;
std::string function;

std::shared_ptr<state::StateKeyValue> stateKV;
std::vector<std::string> rankHosts;

std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;

std::shared_ptr<faabric::scheduler::MpiAsyncThreadPool> threadPool;
int getMpiThreadPoolSize();

std::vector<int> cartProcsPerDim;

faabric::scheduler::FunctionCallClient& getFunctionCallClient(
const std::string& otherHost);

void closeThreadLocalClients();
/* MPI internal messaging layer */

// Track at which host each rank lives
std::vector<std::string> rankHosts;
int getIndexForRanks(int sendRank, int recvRank);

// In-memory queues for local messaging
std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;
void initLocalQueues();

// Rank-to-rank sockets for remote messaging
void initRemoteMpiEndpoint(int sendRank, int recvRank);
int getMpiPort(int sendRank, int recvRank);
void sendRemoteMpiMessage(int sendRank,
int recvRank,
const std::shared_ptr<faabric::MPIMessage>& msg);
std::shared_ptr<faabric::MPIMessage> recvRemoteMpiMessage(int sendRank,
int recvRank);
void closeMpiMessageEndpoints();

void checkRanksRange(int sendRank, int recvRank);
};
}
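
The regrouped private API above spells out the new design: in-memory queues survive only for rank pairs on the same host, while remote pairs talk over dedicated rank-to-rank sockets rather than taking the old queue hop through the FunctionCallClient. Below is a minimal sketch of the resulting send path, assuming a send() method and a thisHost member along these lines; the actual logic lives in src/scheduler/MpiWorld.cpp and may differ.

// Sketch only, not the PR's verbatim implementation
void MpiWorld::send(int sendRank,
                    int recvRank,
                    const std::shared_ptr<faabric::MPIMessage>& msg)
{
    checkRanksRange(sendRank, recvRank);

    if (rankHosts[recvRank] == thisHost) {
        // Local receiver: push onto its in-memory queue
        localQueues[getIndexForRanks(sendRank, recvRank)]->enqueue(msg);
    } else {
        // Remote receiver: write straight to the rank-to-rank socket
        sendRemoteMpiMessage(sendRank, recvRank, msg);
    }
}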
2 changes: 2 additions & 0 deletions include/faabric/transport/MessageEndpoint.h
@@ -83,6 +83,8 @@ class RecvMessageEndpoint : public MessageEndpoint
public:
RecvMessageEndpoint(int portIn);

RecvMessageEndpoint(int portIn, const std::string& overrideHost);

void open(MessageContext& context);

void close();
27 changes: 27 additions & 0 deletions include/faabric/transport/MpiMessageEndpoint.h
@@ -6,8 +6,35 @@
#include <faabric/transport/macros.h>

namespace faabric::transport {
/* These two helper functions are used to broadcast the host-to-rank mapping
 * at initialisation time.
 */
faabric::MpiHostsToRanksMessage recvMpiHostRankMsg();

void sendMpiHostRankMsg(const std::string& hostIn,
const faabric::MpiHostsToRanksMessage msg);

/* This class abstracts the notion of a communication channel between two MPI
* ranks. There will always be one rank local to this host, and one remote.
* Note that the port is unique per (user, function, sendRank, recvRank) tuple.
*/
class MpiMessageEndpoint
{
public:
MpiMessageEndpoint(const std::string& hostIn, int portIn);

MpiMessageEndpoint(const std::string& hostIn,
int portIn,
const std::string& overrideRecvHost);

void sendMpiMessage(const std::shared_ptr<faabric::MPIMessage>& msg);

std::shared_ptr<faabric::MPIMessage> recvMpiMessage();

void close();

private:
SendMessageEndpoint sendMessageEndpoint;
RecvMessageEndpoint recvMessageEndpoint;
};
}
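
For context, here is a hypothetical end-to-end use of this class across two hosts. The host names, port, and message below are placeholders, and the constructor is assumed to connect the send socket to hostIn while binding the recv socket locally.

// Host A side of the channel (peer rank lives on host B)
faabric::transport::MpiMessageEndpoint endpointA("hostB", 8800);
auto msg = std::make_shared<faabric::MPIMessage>();
endpointA.sendMpiMessage(msg);

// Host B side of the same channel
faabric::transport::MpiMessageEndpoint endpointB("hostA", 8800);
std::shared_ptr<faabric::MPIMessage> received = endpointB.recvMpiMessage();

// Both sides tear down their sockets when the world shuts down
endpointA.close();
endpointB.close();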
3 changes: 1 addition & 2 deletions include/faabric/transport/common.h
@@ -5,8 +5,7 @@
#define DEFAULT_SNAPSHOT_HOST "0.0.0.0"
#define STATE_PORT 8003
#define FUNCTION_CALL_PORT 8004
#define MPI_MESSAGE_PORT 8005
#define SNAPSHOT_PORT 8006
#define SNAPSHOT_PORT 8005
#define REPLY_PORT_OFFSET 100

#define MPI_PORT 8800
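
With the shared MPI_MESSAGE_PORT gone, each (sendRank, recvRank) pair gets its own port offset from MPI_PORT. Below is a self-contained sketch of one plausible scheme; the flattening function and the world-size handling are assumptions, not necessarily what MpiWorld::getMpiPort() actually does.

#include <cassert>

#define MPI_PORT 8800

// Hypothetical: flatten the ordered (sendRank, recvRank) pair into an index
int getIndexForRanks(int sendRank, int recvRank, int worldSize)
{
    return sendRank * worldSize + recvRank;
}

int getMpiPort(int sendRank, int recvRank, int worldSize)
{
    return MPI_PORT + getIndexForRanks(sendRank, recvRank, worldSize);
}

int main()
{
    // Each direction of a channel gets a distinct port, so the two
    // sockets of an MpiMessageEndpoint never collide
    assert(getMpiPort(1, 2, 4) != getMpiPort(2, 1, 4));
    return 0;
}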
19 changes: 0 additions & 19 deletions src/scheduler/FunctionCallClient.cpp
@@ -19,8 +19,6 @@ static std::vector<
std::pair<std::string, std::shared_ptr<faabric::BatchExecuteRequest>>>
batchMessages;

static std::vector<std::pair<std::string, faabric::MPIMessage>> mpiMessages;

static std::vector<std::pair<std::string, faabric::ResponseRequest>>
resourceRequests;

@@ -48,11 +46,6 @@ getBatchRequests()
return batchMessages;
}

std::vector<std::pair<std::string, faabric::MPIMessage>> getMPIMessages()
{
return mpiMessages;
}

std::vector<std::pair<std::string, faabric::ResponseRequest>>
getResourceRequests()
{
@@ -74,7 +67,6 @@ void clearMockRequests()
{
functionCalls.clear();
batchMessages.clear();
mpiMessages.clear();
resourceRequests.clear();
unregisterRequests.clear();

@@ -113,17 +105,6 @@ void FunctionCallClient::sendFlush()
}
}

void FunctionCallClient::sendMPIMessage(
const std::shared_ptr<faabric::MPIMessage> msg)
{
if (faabric::util::isMockMode()) {
faabric::util::UniqueLock lock(mockMutex);
mpiMessages.emplace_back(host, *msg);
} else {
SEND_MESSAGE_PTR(faabric::scheduler::FunctionCalls::MpiMessage, msg);
}
}

faabric::HostResources FunctionCallClient::getResources()
{
faabric::ResponseRequest request;
12 changes: 0 additions & 12 deletions src/scheduler/FunctionCallServer.cpp
@@ -27,9 +27,6 @@ void FunctionCallServer::doRecv(faabric::transport::Message& header,
assert(header.size() == sizeof(uint8_t));
uint8_t call = static_cast<uint8_t>(*header.data());
switch (call) {
case faabric::scheduler::FunctionCalls::MpiMessage:
this->recvMpiMessage(body);
break;
case faabric::scheduler::FunctionCalls::Flush:
this->recvFlush(body);
break;
@@ -48,15 +45,6 @@
}
}

void FunctionCallServer::recvMpiMessage(faabric::transport::Message& body)
{
PARSE_MSG(faabric::MPIMessage, body.data(), body.size())

MpiWorldRegistry& registry = getMpiWorldRegistry();
MpiWorld& world = registry.getWorld(msg.worldid());
world.enqueueMessage(msg);
}

void FunctionCallServer::recvFlush(faabric::transport::Message& body)
{
PARSE_MSG(faabric::ResponseRequest, body.data(), body.size());