Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalised point-to-point messaging #151

Merged
merged 23 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/faabric/runner/FaabricMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <faabric/scheduler/Scheduler.h>
#include <faabric/snapshot/SnapshotServer.h>
#include <faabric/state/StateServer.h>
#include <faabric/transport/PointToPointServer.h>
#include <faabric/util/config.h>

namespace faabric::runner {
Expand All @@ -23,11 +24,14 @@ class FaabricMain

void startSnapshotServer();

void startPointToPointServer();

void shutdown();

private:
faabric::state::StateServer stateServer;
faabric::scheduler::FunctionCallServer functionServer;
faabric::snapshot::SnapshotServer snapshotServer;
faabric::transport::PointToPointServer pointToPointServer;
};
}
22 changes: 21 additions & 1 deletion include/faabric/transport/MessageEndpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,18 @@ class AsyncSendMessageEndpoint final : public MessageEndpoint

void send(const uint8_t* data, size_t dataSize, bool more = false);

zmq::socket_t pushSocket;
zmq::socket_t socket;
};

class AsyncInternalSendMessageEndpoint final : public MessageEndpoint
{
public:
AsyncInternalSendMessageEndpoint(const std::string& inProcLabel,
int timeoutMs = DEFAULT_RECV_TIMEOUT_MS);

void send(const uint8_t* data, size_t dataSize, bool more = false);

zmq::socket_t socket;
};

class SyncSendMessageEndpoint final : public MessageEndpoint
Expand Down Expand Up @@ -183,6 +194,15 @@ class AsyncRecvMessageEndpoint final : public RecvMessageEndpoint
std::optional<Message> recv(int size = 0) override;
};

class AsyncInternalRecvMessageEndpoint final : public RecvMessageEndpoint
{
public:
AsyncInternalRecvMessageEndpoint(const std::string& inprocLabel,
int timeoutMs = DEFAULT_RECV_TIMEOUT_MS);

std::optional<Message> recv(int size = 0) override;
};

class SyncRecvMessageEndpoint final : public RecvMessageEndpoint
{
public:
Expand Down
1 change: 0 additions & 1 deletion include/faabric/transport/MessageEndpointClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class MessageEndpointClient
protected:
const std::string host;

private:
const int asyncPort;

const int syncPort;
Expand Down
2 changes: 2 additions & 0 deletions include/faabric/transport/MessageEndpointServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class MessageEndpointServer

virtual void stop();

virtual void onWorkerStop();
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hook allows servers to perform custom shutdown when workers stop.


void setRequestLatch();

void awaitRequestLatch();
Expand Down
52 changes: 52 additions & 0 deletions include/faabric/transport/PointToPointBroker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#pragma once

#include <faabric/scheduler/Scheduler.h>
#include <faabric/transport/PointToPointClient.h>

#include <set>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <vector>

namespace faabric::transport {
class PointToPointBroker
{
public:
PointToPointBroker();

std::string getHostForReceiver(int appId, int recvIdx);

void setHostForReceiver(int appId, int recvIdx, const std::string& host);

void broadcastMappings(int appId);

void sendMappings(int appId, const std::string& host);

std::set<int> getIdxsRegisteredForApp(int appId);

void sendMessage(int appId,
int sendIdx,
int recvIdx,
const uint8_t* buffer,
size_t bufferSize);

std::vector<uint8_t> recvMessage(int appId, int sendIdx, int recvIdx);

void clear();

void resetThreadLocalCache();

private:
std::shared_mutex brokerMutex;

std::unordered_map<int, std::set<int>> appIdxs;
std::unordered_map<std::string, std::string> mappings;

std::shared_ptr<PointToPointClient> getClient(const std::string& host);

faabric::scheduler::Scheduler& sch;
};

PointToPointBroker& getPointToPointBroker();
}
10 changes: 10 additions & 0 deletions include/faabric/transport/PointToPointCall.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#pragma once

namespace faabric::transport {

enum PointToPointCall
{
MAPPING = 0,
MESSAGE = 1
};
}
25 changes: 25 additions & 0 deletions include/faabric/transport/PointToPointClient.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <faabric/proto/faabric.pb.h>
#include <faabric/transport/MessageEndpointClient.h>

namespace faabric::transport {

std::vector<std::pair<std::string, faabric::PointToPointMappings>>
getSentMappings();

std::vector<std::pair<std::string, faabric::PointToPointMessage>>
getSentPointToPointMessages();

void clearSentMessages();

class PointToPointClient : public faabric::transport::MessageEndpointClient
{
public:
PointToPointClient(const std::string& hostIn);

void sendMappings(faabric::PointToPointMappings& mappings);

void sendMessage(faabric::PointToPointMessage& msg);
};
}
29 changes: 29 additions & 0 deletions include/faabric/transport/PointToPointServer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <faabric/transport/MessageEndpointServer.h>
#include <faabric/transport/PointToPointBroker.h>

namespace faabric::transport {

class PointToPointServer final : public MessageEndpointServer
{
public:
PointToPointServer();

private:
PointToPointBroker& reg;

void doAsyncRecv(int header,
const uint8_t* buffer,
size_t bufferSize) override;

std::unique_ptr<google::protobuf::Message>
doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override;

void onWorkerStop() override;

std::unique_ptr<google::protobuf::Message> doRecvMappings(
const uint8_t* buffer,
size_t bufferSize);
};
}
4 changes: 4 additions & 0 deletions include/faabric/transport/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
#define SNAPSHOT_ASYNC_PORT 8007
#define SNAPSHOT_SYNC_PORT 8008
#define SNAPSHOT_INPROC_LABEL "snapshot"

#define POINT_TO_POINT_ASYNC_PORT 8009
#define POINT_TO_POINT_SYNC_PORT 8010
#define POINT_TO_POINT_INPROC_LABEL "ptp"
2 changes: 2 additions & 0 deletions include/faabric/util/bytes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ std::vector<uint8_t> stringToBytes(const std::string& str);

std::string bytesToString(const std::vector<uint8_t>& bytes);

std::string formatByteArrayToIntString(const std::vector<uint8_t>& bytes);

void trimTrailingZeros(std::vector<uint8_t>& vectorIn);

int safeCopyToBuffer(const std::vector<uint8_t>& dataIn,
Expand Down
1 change: 1 addition & 0 deletions include/faabric/util/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class SystemConfig
int functionServerThreads;
int stateServerThreads;
int snapshotServerThreads;
int pointToPointServerThreads;

SystemConfig();

Expand Down
21 changes: 21 additions & 0 deletions src/proto/faabric.proto
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,24 @@ message StateAppendedResponse {
string key = 2;
repeated AppendedValue values = 3;
}

// ---------------------------------------------
// POINT-TO-POINT
// ---------------------------------------------

message PointToPointMessage {
int32 appId = 1;
int32 sendIdx = 2;
int32 recvIdx = 3;

bytes data = 4;
}

message PointToPointMappings {
message PointToPointMapping {
int32 appId = 1;
int32 recvIdx = 2;
string host = 3;
}
repeated PointToPointMapping mappings = 1;
}
12 changes: 12 additions & 0 deletions src/runner/FaabricMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ void FaabricMain::startBackground()
// Snapshots
startSnapshotServer();

// Point-to-point messaging
startPointToPointServer();

// Work sharing
startFunctionCallServer();
}
Expand Down Expand Up @@ -71,6 +74,12 @@ void FaabricMain::startSnapshotServer()
snapshotServer.start();
}

void FaabricMain::startPointToPointServer()
{
SPDLOG_INFO("Starting point-to-point server");
pointToPointServer.start();
}

void FaabricMain::startStateServer()
{
// Skip state server if not in inmemory mode
Expand Down Expand Up @@ -99,6 +108,9 @@ void FaabricMain::shutdown()
SPDLOG_INFO("Waiting for the snapshot server to finish");
snapshotServer.stop();

SPDLOG_INFO("Waiting for the point-to-point server to finish");
pointToPointServer.stop();

auto& sch = faabric::scheduler::getScheduler();
sch.shutdown();

Expand Down
6 changes: 6 additions & 0 deletions src/transport/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ set(HEADERS
"${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointClient.h"
"${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointServer.h"
"${FAABRIC_INCLUDE_DIR}/faabric/transport/MpiMessageEndpoint.h"
"${FAABRIC_INCLUDE_DIR}/faabric/transport/PointToPointBroker.h"
"${FAABRIC_INCLUDE_DIR}/faabric/transport/PointToPointClient.h"
"${FAABRIC_INCLUDE_DIR}/faabric/transport/PointToPointServer.h"
)

set(LIB_FILES
Expand All @@ -20,6 +23,9 @@ set(LIB_FILES
MessageEndpointClient.cpp
MessageEndpointServer.cpp
MpiMessageEndpoint.cpp
PointToPointBroker.cpp
PointToPointClient.cpp
PointToPointServer.cpp
${HEADERS}
)

Expand Down
38 changes: 35 additions & 3 deletions src/transport/MessageEndpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,22 +313,39 @@ AsyncSendMessageEndpoint::AsyncSendMessageEndpoint(const std::string& hostIn,
int timeoutMs)
: MessageEndpoint(hostIn, portIn, timeoutMs)
{
pushSocket =
socket =
setUpSocket(zmq::socket_type::push, MessageEndpointConnectType::CONNECT);
}

void AsyncSendMessageEndpoint::sendHeader(int header)
{
uint8_t headerBytes = static_cast<uint8_t>(header);
doSend(pushSocket, &headerBytes, sizeof(headerBytes), true);
doSend(socket, &headerBytes, sizeof(headerBytes), true);
}

void AsyncSendMessageEndpoint::send(const uint8_t* data,
size_t dataSize,
bool more)
{
SPDLOG_TRACE("PUSH {} ({} bytes, more {})", address, dataSize, more);
doSend(pushSocket, data, dataSize, more);
doSend(socket, data, dataSize, more);
}

AsyncInternalSendMessageEndpoint::AsyncInternalSendMessageEndpoint(
const std::string& inprocLabel,
int timeoutMs)
: MessageEndpoint("inproc://" + inprocLabel, timeoutMs)
{
socket =
setUpSocket(zmq::socket_type::push, MessageEndpointConnectType::CONNECT);
}

void AsyncInternalSendMessageEndpoint::send(const uint8_t* data,
size_t dataSize,
bool more)
{
SPDLOG_TRACE("PUSH {} ({} bytes, more {})", address, dataSize, more);
doSend(socket, data, dataSize, more);
}

// ----------------------------------------------
Expand Down Expand Up @@ -495,6 +512,21 @@ std::optional<Message> AsyncRecvMessageEndpoint::recv(int size)
return RecvMessageEndpoint::recv(size);
}

AsyncInternalRecvMessageEndpoint::AsyncInternalRecvMessageEndpoint(
const std::string& inprocLabel,
int timeoutMs)
: RecvMessageEndpoint(inprocLabel,
timeoutMs,
zmq::socket_type::pull,
MessageEndpointConnectType::BIND)
{}

std::optional<Message> AsyncInternalRecvMessageEndpoint::recv(int size)
{
SPDLOG_TRACE("PULL {} ({} bytes)", address, size);
return RecvMessageEndpoint::recv(size);
}

// ----------------------------------------------
// SYNC RECV ENDPOINT
// ----------------------------------------------
Expand Down
8 changes: 8 additions & 0 deletions src/transport/MessageEndpointServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ void MessageEndpointServerHandler::start(
}
}

// Perform the tidy-up
server->onWorkerStop();

// Just before the thread dies, check if there's something
// waiting on the shutdown latch
if (server->shutdownLatch != nullptr) {
Expand Down Expand Up @@ -286,6 +289,11 @@ void MessageEndpointServer::stop()
syncHandler.join();
}

void MessageEndpointServer::onWorkerStop()
{
// Nothing to do by default
}

void MessageEndpointServer::setRequestLatch()
{
requestLatch = faabric::util::Latch::create(2);
Expand Down
Loading