Skip to content

Commit

Permalink
add call records and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
csegarragonz committed Oct 27, 2021
1 parent 1837780 commit 610d1c3
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 0 deletions.
37 changes: 37 additions & 0 deletions include/faabric/util/tracing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include <faabric/proto/faabric.pb.h>

#include <list>
#include <map>

namespace faabric::util::tracing {
enum RecordType
{
MpiPerRankMessageCount
};

class CallRecords
{
public:
void startRecording(const faabric::Message& msg);

void stopRecording(faabric::Message& msg);

void addRecord(int msgId, RecordType recordType, int idToIncrement);

private:
std::shared_ptr<faabric::Message> linkedMsg = nullptr;

std::list<RecordType> onGoingRecordings;

void loadRecordsToMessage(faabric::CallRecords& callRecords,
const RecordType& recordType);

// ----- Per record type data structures -----

std::map<int, int> perRankMsgCount;
};

CallRecords& getCallRecords();
}
17 changes: 17 additions & 0 deletions src/proto/faabric.proto
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,19 @@ message MpiHostsToRanksMessage {
repeated int32 basePorts = 2;
}

// ---------------------------------------------
// PROFILING
// ---------------------------------------------

message MpiPerRankMessageCount {
repeated int32 ranks = 1;
repeated int32 numMessages = 2;
}

message CallRecords {
MpiPerRankMessageCount mpiMsgCount = 1;
}

message Message {
int32 id = 1;
int32 appId = 2;
Expand Down Expand Up @@ -148,6 +161,10 @@ message Message {
string sgxTag = 35;
bytes sgxPolicy = 36;
bytes sgxResult = 37;

// This last struct is used for tracing purposes, it should only be set in
// non-release builds
CallRecords records = 38;
}

// ---------------------------------------------
Expand Down
7 changes: 7 additions & 0 deletions src/scheduler/Executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <faabric/util/macros.h>
#include <faabric/util/memory.h>
#include <faabric/util/queue.h>
#include <faabric/util/tracing.h>
#include <faabric/util/timing.h>

#define POOL_SHUTDOWN -1
Expand Down Expand Up @@ -226,6 +227,9 @@ void Executor::threadPoolThread(int threadPoolIdx)
msg.id(),
isThreads);

// Start recording calls in non-release builds
faabric::util::tracing::getCallRecords().startRecording(msg);

int32_t returnValue;
try {
returnValue =
Expand All @@ -242,6 +246,9 @@ void Executor::threadPoolThread(int threadPoolIdx)
// Set the return value
msg.set_returnvalue(returnValue);

// Stop recording calls
faabric::util::tracing::getCallRecords().stopRecording(msg);

// Decrement the task count
int oldTaskCount = task.batchCounter->fetch_sub(1);
assert(oldTaskCount >= 0);
Expand Down
11 changes: 11 additions & 0 deletions src/scheduler/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <faabric/util/macros.h>
#include <faabric/util/scheduling.h>
#include <faabric/util/testing.h>
#include <faabric/util/tracing.h>

// Each MPI rank runs in a separate thread, thus we use TLS to maintain the
// per-rank data structures
Expand Down Expand Up @@ -33,6 +34,9 @@ static thread_local std::unordered_map<
std::unique_ptr<faabric::transport::AsyncSendMessageEndpoint>>
ranksSendEndpoints;

// Id of the message that created this thread-local instance
static thread_local int thisMsgId;

// This is used for mocking in tests
static std::vector<faabric::MpiHostsToRanksMessage> rankMessages;

Expand Down Expand Up @@ -177,6 +181,7 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
id = newId;
user = call.user();
function = call.function();
thisMsgId = call.id();

size = newSize;

Expand Down Expand Up @@ -245,6 +250,7 @@ void MpiWorld::destroy()
SPDLOG_TRACE("Destroying MPI world {}", id);

// Note that all ranks will call this function.
thisMsgId = 0;

// We must force the destructors for all message endpoints to run here
// rather than at the end of their global thread-local lifespan. If we
Expand Down Expand Up @@ -292,6 +298,7 @@ void MpiWorld::initialiseFromMsg(const faabric::Message& msg)
user = msg.user();
function = msg.function();
size = msg.mpiworldsize();
thisMsgId = msg.id();

// Block until we receive
faabric::MpiHostsToRanksMessage hostRankMsg = recvMpiHostRankMsg();
Expand Down Expand Up @@ -572,6 +579,10 @@ void MpiWorld::send(int sendRank,
SPDLOG_TRACE("MPI - send remote {} -> {}", sendRank, recvRank);
sendRemoteMpiMessage(sendRank, recvRank, m);
}

// In non-release builds, track that we have sent this message
faabric::util::tracing::getCallRecords().addRecord(thisMsgId,
faabric::util::tracing::RecordType::MpiPerRankMessageCount, recvRank);
}

void MpiWorld::recv(int sendRank,
Expand Down
1 change: 1 addition & 0 deletions src/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ faabric_lib(util
snapshot.cpp
state.cpp
string_tools.cpp
tracing.cpp
timing.cpp
testing.cpp
)
116 changes: 116 additions & 0 deletions src/util/tracing.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#include <faabric/util/logging.h>
#include <faabric/util/tracing.h>

namespace faabric::util::tracing {
void CallRecords::startRecording(const faabric::Message& msg)
{
#ifndef NDEBUG
if (linkedMsg != nullptr && linkedMsg->id() != msg.id()) {
SPDLOG_ERROR("CallRecords already linked to a different message: (linked: {} != provided: {})",
linkedMsg->id(), msg.id());
throw std::runtime_error("CallRecords linked to a different message");
} else if (linkedMsg == nullptr) {
linkedMsg = std::make_shared<faabric::Message>(msg);
}
#else
;
#endif
}

void CallRecords::stopRecording(faabric::Message& msg)
{
#ifndef NDEBUG
if (linkedMsg == nullptr || linkedMsg->id() != msg.id()) {
SPDLOG_ERROR("CallRecords not linked to the right message: (linked: {} != provided: {})",
linkedMsg->id(), msg.id());
throw std::runtime_error("CallRecords linked to a different message");
}

linkedMsg = nullptr;

// Update the actual faabric message
faabric::CallRecords recordsMsg;
for (const auto& recordType : onGoingRecordings) {
loadRecordsToMessage(recordsMsg, recordType);
}

// Update the original message
*msg.mutable_records() = recordsMsg;
#else
;
#endif
}

void CallRecords::loadRecordsToMessage(faabric::CallRecords& callRecords,
const RecordType& recordType)
{
#ifndef NDEBUG
switch (recordType) {
case (faabric::util::tracing::RecordType::MpiPerRankMessageCount): {
faabric::MpiPerRankMessageCount msgCount;

for (const auto& it : perRankMsgCount) {
msgCount.add_ranks(it.first);
msgCount.add_nummessages(it.second);
}

*callRecords.mutable_mpimsgcount() = msgCount;
break;
}
default: {
SPDLOG_ERROR("Unsupported record type: {}", recordType);
throw std::runtime_error("Unsupported record type");
}
}
#else
;
#endif
}

void CallRecords::addRecord(int msgId, RecordType recordType, int idToIncrement)
{
#ifndef NDEBUG
// Check message id
if (linkedMsg == nullptr || linkedMsg->id() != msgId) {
SPDLOG_ERROR("CallRecords not linked to the right message: (linked: {} != provided: {})",
linkedMsg->id(), msgId);
throw std::runtime_error("CallRecords linked to a different message");
}

// Add the record to the list of on going records if it is not there
bool mustInit = false;
auto it = std::find(onGoingRecordings.begin(), onGoingRecordings.end(), recordType);
if (it == onGoingRecordings.end()) {
onGoingRecordings.push_back(recordType);
mustInit = true;
}

// Finally increment the corresponding record list
switch (recordType) {
case (faabric::util::tracing::RecordType::MpiPerRankMessageCount): {
if (mustInit) {
for (int i = 0; i < linkedMsg->mpiworldsize(); i++) {
perRankMsgCount[i] = 0;
}
}

++perRankMsgCount.at(idToIncrement);
break;
}
default: {
SPDLOG_ERROR("Unsupported record type: {}", recordType);
throw std::runtime_error("Unsupported record type");
}
}
#else
;
#endif
}


CallRecords& getCallRecords()
{
static thread_local CallRecords callRecords;
return callRecords;
}
}
44 changes: 44 additions & 0 deletions tests/test/util/test_tracing.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <catch.hpp>

#include "faabric_utils.h"

#include <faabric/proto/faabric.pb.h>
#include <faabric/scheduler/MpiWorld.h>
#include <faabric/util/logging.h> // DELETE MEE
#include <faabric/util/macros.h>
#include <faabric/util/tracing.h>

namespace tests {
TEST_CASE_METHOD(MpiTestFixture,
"Test tracing the number of MPI messages",
"[util][tracing]")
{
faabric::util::tracing::getCallRecords().startRecording(msg);

// Send one message
int rankA1 = 0;
int rankA2 = 1;
MPI_Status status{};

std::vector<int> messageData = { 0, 1, 2 };
auto buffer = new int[messageData.size()];

world.send(
rankA1, rankA2, BYTES(messageData.data()), MPI_INT, messageData.size());
world.recv(
rankA1, rankA2, BYTES(buffer), MPI_INT, messageData.size(), &status);

// Stop recording and check we have only recorded one message
faabric::util::tracing::getCallRecords().stopRecording(msg);
REQUIRE(msg.has_records());
REQUIRE(msg.records().has_mpimsgcount());
REQUIRE(msg.records().mpimsgcount().ranks_size() == worldSize);
for (int i = 0; i < worldSize; i++) {
if (i == rankA2) {
REQUIRE(msg.records().mpimsgcount().nummessages(i) == 1);
} else {
REQUIRE(msg.records().mpimsgcount().nummessages(i) == 0);
}
}
}
}

0 comments on commit 610d1c3

Please sign in to comment.