Skip to content

Commit

Permalink
Add in-proc PAIR sockets (#206)
Browse files Browse the repository at this point in the history
* Add in-proc pair sockets

* Misplaced macro

* PR comments

* More PR comments
  • Loading branch information
Shillaker committed Dec 28, 2021
1 parent 747def5 commit b6b3e3d
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 1 deletion.
20 changes: 20 additions & 0 deletions include/faabric/transport/MessageEndpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,26 @@ class SyncRecvMessageEndpoint final : public RecvMessageEndpoint
void sendResponse(const uint8_t* data, int size);
};

class AsyncDirectRecvEndpoint final : public RecvMessageEndpoint
{
public:
AsyncDirectRecvEndpoint(const std::string& inprocLabel,
int timeoutMs = DEFAULT_RECV_TIMEOUT_MS);

std::optional<Message> recv(int size = 0) override;
};

class AsyncDirectSendEndpoint final : public MessageEndpoint
{
public:
AsyncDirectSendEndpoint(const std::string& inProcLabel,
int timeoutMs = DEFAULT_RECV_TIMEOUT_MS);

void send(const uint8_t* data, size_t dataSize, bool more = false);

zmq::socket_t socket;
};

class MessageTimeoutException final : public faabric::util::FaabricException
{
public:
Expand Down
48 changes: 48 additions & 0 deletions src/transport/MessageEndpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ zmq::socket_t socketFactory(zmq::socket_type socketType,
CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind")
break;
}
case zmq::socket_type::pair: {
SPDLOG_TRACE("Bind socket: pair {} (timeout {}ms)",
address,
timeoutMs);
CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind")
break;
}
case zmq::socket_type::pub: {
SPDLOG_TRACE(
"Bind socket: pub {} (timeout {}ms)", address, timeoutMs);
Expand Down Expand Up @@ -123,6 +130,13 @@ zmq::socket_t socketFactory(zmq::socket_type socketType,
}
case (MessageEndpointConnectType::CONNECT): {
switch (socketType) {
case zmq::socket_type::pair: {
SPDLOG_TRACE("Connect socket: pair {} (timeout {}ms)",
address,
timeoutMs);
CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect")
break;
}
case zmq::socket_type::pull: {
SPDLOG_TRACE("Connect socket: pull {} (timeout {}ms)",
address,
Expand Down Expand Up @@ -559,4 +573,38 @@ void SyncRecvMessageEndpoint::sendResponse(const uint8_t* data, int size)
SPDLOG_TRACE("REP {} ({} bytes)", address, size);
doSend(socket, data, size, false);
}

// ----------------------------------------------
// INTERNAL DIRECT MESSAGE ENDPOINTS
// ----------------------------------------------

AsyncDirectRecvEndpoint::AsyncDirectRecvEndpoint(const std::string& inprocLabel,
int timeoutMs)
: RecvMessageEndpoint(inprocLabel,
timeoutMs,
zmq::socket_type::pair,
MessageEndpointConnectType::BIND)
{}

std::optional<Message> AsyncDirectRecvEndpoint::recv(int size)
{
SPDLOG_TRACE("PAIR recv {} ({} bytes)", address, size);
return RecvMessageEndpoint::recv(size);
}

AsyncDirectSendEndpoint::AsyncDirectSendEndpoint(const std::string& inprocLabel,
int timeoutMs)
: MessageEndpoint("inproc://" + inprocLabel, timeoutMs)
{
socket =
setUpSocket(zmq::socket_type::pair, MessageEndpointConnectType::CONNECT);
}

void AsyncDirectSendEndpoint::send(const uint8_t* data,
size_t dataSize,
bool more)
{
SPDLOG_TRACE("PAIR send {} ({} bytes, more {})", address, dataSize, more);
doSend(socket, data, dataSize, more);
}
}
110 changes: 109 additions & 1 deletion tests/test/transport/test_message_endpoint_client.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include "faabric_utils.h"
#include <catch2/catch.hpp>

#include <atomic>
#include <thread>
#include <unistd.h>

#include <faabric/transport/MessageEndpoint.h>
#include <faabric/util/latch.h>
#include <faabric/util/macros.h>

using namespace faabric::transport;
Expand Down Expand Up @@ -224,6 +226,112 @@ TEST_CASE_METHOD(SchedulerTestFixture,
}
}

#endif
TEST_CASE_METHOD(SchedulerTestFixture, "Test direct messaging", "[transport]")
{
std::string expected = "Direct hello";
const uint8_t* msg = BYTES_CONST(expected.c_str());

std::string inprocLabel = "direct-test";

AsyncDirectSendEndpoint sender(inprocLabel);
sender.send(msg, expected.size());

AsyncDirectRecvEndpoint receiver(inprocLabel);

std::string actual;
SECTION("Recv with size")
{
faabric::transport::Message recvMsg =
receiver.recv(expected.size()).value();
actual = std::string(recvMsg.data(), recvMsg.size());
}

SECTION("Recv no size")
{
faabric::transport::Message recvMsg = receiver.recv().value();
actual = std::string(recvMsg.data(), recvMsg.size());
}

REQUIRE(actual == expected);
}

TEST_CASE_METHOD(SchedulerTestFixture,
"Stress test direct messaging",
"[transport]")
{
int nMessages = 1000;
int nPairs = 3;
std::string inprocLabel = "direct-test-";

std::shared_ptr<faabric::util::Latch> startLatch =
faabric::util::Latch::create(nPairs + 1);

std::vector<std::thread> senders;
std::vector<std::thread> receivers;

for (int i = 0; i < nPairs; i++) {
senders.emplace_back([i, nMessages, inprocLabel, &startLatch] {
std::string thisLabel = inprocLabel + std::to_string(i);
AsyncDirectSendEndpoint sender(thisLabel);

for (int m = 0; m < nMessages; m++) {
std::string expected =
"Direct hello " + std::to_string(i) + "_" + std::to_string(m);
const uint8_t* msg = BYTES_CONST(expected.c_str());
sender.send(msg, expected.size());

if (m % 100 == 0) {
SLEEP_MS(10);
}

// Make main thread wait until messages are queued (to check no
// issue with connecting before binding)
if (m == 10) {
startLatch->wait();
}
}
});
}

// Wait for queued messages
startLatch->wait();

std::atomic<bool> success = true;
for (int i = 0; i < nPairs; i++) {
receivers.emplace_back([i, nMessages, inprocLabel, &success] {
std::string thisLabel = inprocLabel + std::to_string(i);
AsyncDirectRecvEndpoint receiver(thisLabel);

// Receive messages
for (int m = 0; m < nMessages; m++) {
faabric::transport::Message recvMsg = receiver.recv().value();
std::string actual(recvMsg.data(), recvMsg.size());

std::string expected =
"Direct hello " + std::to_string(i) + "_" + std::to_string(m);

if (actual != expected) {
success.store(false);
}
}
});
}

REQUIRE(success.load(std::memory_order_acquire));

for (auto& t : senders) {
if (t.joinable()) {
t.join();
}
}

for (auto& t : receivers) {
if (t.joinable()) {
t.join();
}
}
}

#endif // End ThreadSanitizer exclusion

}

0 comments on commit b6b3e3d

Please sign in to comment.