Skip to content

Commit

Permalink
Adding arbitrary information to execution graph (#166)
Browse files Browse the repository at this point in the history
* add call records and tests

* formatting

* disable tracing by default during tests

* refactor so that recording depends on a message flag, change the message layout and rename to ExecGraphDetail

* update comments

* refactor after offline discussion

* don't log chained calls if exec graph recording is not enabled

* quick test fix

* add serialisation for maps

* fix tests

* self-review

* refactor thread local message's name

* move mpi exec graph tests to separate file

* add checks for serialisation/deserialisation

* cleanup
  • Loading branch information
csegarragonz committed Nov 2, 2021
1 parent dc3d150 commit 832aafe
Show file tree
Hide file tree
Showing 18 changed files with 348 additions and 15 deletions.
4 changes: 2 additions & 2 deletions include/faabric/scheduler/MpiContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ class MpiContext
public:
MpiContext();

int createWorld(const faabric::Message& msg);
int createWorld(faabric::Message& msg);

void joinWorld(const faabric::Message& msg);
void joinWorld(faabric::Message& msg);

bool getIsMpi();

Expand Down
4 changes: 2 additions & 2 deletions include/faabric/scheduler/MpiWorld.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ class MpiWorld
public:
MpiWorld();

void create(const faabric::Message& call, int newId, int newSize);
void create(faabric::Message& call, int newId, int newSize);

void broadcastHostsToRanks();

void initialiseFromMsg(const faabric::Message& msg);
void initialiseFromMsg(faabric::Message& msg);

std::string getHostForRank(int rank);

Expand Down
4 changes: 2 additions & 2 deletions include/faabric/scheduler/MpiWorldRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class MpiWorldRegistry
public:
MpiWorldRegistry() = default;

scheduler::MpiWorld& createWorld(const faabric::Message& msg,
scheduler::MpiWorld& createWorld(faabric::Message& msg,
int worldId,
std::string hostOverride = "");

scheduler::MpiWorld& getOrInitialiseWorld(const faabric::Message& msg);
scheduler::MpiWorld& getOrInitialiseWorld(faabric::Message& msg);

scheduler::MpiWorld& getWorld(int worldId);

Expand Down
19 changes: 19 additions & 0 deletions include/faabric/util/exec_graph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include <faabric/proto/faabric.pb.h>

#include <functional>
#include <list>
#include <map>

namespace faabric::util::exec_graph {
// Stores `value` under `key` in the message's string exec-graph details,
// overwriting any value already stored for that key. Does nothing unless the
// message has its recordExecGraph flag set.
void addDetail(faabric::Message& msg,
               const std::string& key,
               const std::string& value);

// Adds `valueToIncrement` to the integer exec-graph detail stored under `key`
// (a key that is not present yet starts from zero). Does nothing unless the
// message has its recordExecGraph flag set.
void incrementCounter(faabric::Message& msg,
                      const std::string& key,
                      const int valueToIncrement = 1);

// Prefix of the counter keys used to track MPI messages sent per receiving
// rank; the receiving rank number is appended to it.
// Declared `inline` without `static` so that every translation unit including
// this header shares one definition instead of constructing a private copy.
inline const std::string mpiMsgCountPrefix = "mpi-msgcount-torank-";
}
5 changes: 5 additions & 0 deletions src/proto/faabric.proto
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ message Message {
string sgxTag = 38;
bytes sgxPolicy = 39;
bytes sgxResult = 40;

// Exec-graph utils
bool recordExecGraph = 41;
map<string, int32> intExecGraphDetails = 42;
map<string, string> execGraphDetails = 43;
}

// ---------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions src/scheduler/Executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <faabric/util/clock.h>
#include <faabric/util/config.h>
#include <faabric/util/environment.h>
#include <faabric/util/exec_graph.h>
#include <faabric/util/func.h>
#include <faabric/util/gids.h>
#include <faabric/util/logging.h>
Expand Down
4 changes: 2 additions & 2 deletions src/scheduler/MpiContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ MpiContext::MpiContext()
, worldId(-1)
{}

int MpiContext::createWorld(const faabric::Message& msg)
int MpiContext::createWorld(faabric::Message& msg)
{

if (msg.mpirank() > 0) {
Expand All @@ -38,7 +38,7 @@ int MpiContext::createWorld(const faabric::Message& msg)
return worldId;
}

void MpiContext::joinWorld(const faabric::Message& msg)
void MpiContext::joinWorld(faabric::Message& msg)
{
if (!msg.ismpi()) {
// Not an MPI call
Expand Down
23 changes: 20 additions & 3 deletions src/scheduler/MpiWorld.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <faabric/scheduler/MpiWorld.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/util/environment.h>
#include <faabric/util/exec_graph.h>
#include <faabric/util/func.h>
#include <faabric/util/gids.h>
#include <faabric/util/macros.h>
Expand Down Expand Up @@ -33,6 +34,9 @@ static thread_local std::unordered_map<
std::unique_ptr<faabric::transport::AsyncSendMessageEndpoint>>
ranksSendEndpoints;

// Non-owning pointer (not an id) to the message most recently passed to
// MpiWorld::create or MpiWorld::initialiseFromMsg on this thread. It is read
// when sending to decide whether to record exec-graph details on that message.
// NOTE(review): nothing in view resets this pointer, so it presumably must not
// be used after the pointed-to message is destroyed - confirm against callers.
static thread_local faabric::Message* thisRankMsg = nullptr;

// This is used for mocking in tests
static std::vector<faabric::MpiHostsToRanksMessage> rankMessages;

Expand Down Expand Up @@ -172,11 +176,12 @@ MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank)
return unackedMessageBuffers[index];
}

void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
void MpiWorld::create(faabric::Message& call, int newId, int newSize)
{
id = newId;
user = call.user();
function = call.function();
thisRankMsg = &call;

size = newSize;

Expand All @@ -194,7 +199,10 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
msg.set_mpirank(i + 1);
msg.set_mpiworldsize(size);
// Log chained functions to generate execution graphs
sch.logChainedFunction(call.id(), msg.id());
if (thisRankMsg != nullptr && thisRankMsg->recordexecgraph()) {
sch.logChainedFunction(call.id(), msg.id());
msg.set_recordexecgraph(true);
}
}

std::vector<std::string> executedAt;
Expand Down Expand Up @@ -286,12 +294,13 @@ void MpiWorld::destroy()
}
}

void MpiWorld::initialiseFromMsg(const faabric::Message& msg)
void MpiWorld::initialiseFromMsg(faabric::Message& msg)
{
id = msg.mpiworldid();
user = msg.user();
function = msg.function();
size = msg.mpiworldsize();
thisRankMsg = &msg;

// Block until we receive
faabric::MpiHostsToRanksMessage hostRankMsg = recvMpiHostRankMsg();
Expand Down Expand Up @@ -572,6 +581,14 @@ void MpiWorld::send(int sendRank,
SPDLOG_TRACE("MPI - send remote {} -> {}", sendRank, recvRank);
sendRemoteMpiMessage(sendRank, recvRank, m);
}

// If the message is set and recording on, track we have sent this message
if (thisRankMsg != nullptr && thisRankMsg->recordexecgraph()) {
faabric::util::exec_graph::incrementCounter(
*thisRankMsg,
faabric::util::exec_graph::mpiMsgCountPrefix +
std::to_string(recvRank));
}
}

void MpiWorld::recv(int sendRank,
Expand Down
4 changes: 2 additions & 2 deletions src/scheduler/MpiWorldRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ MpiWorldRegistry& getMpiWorldRegistry()
return r;
}

scheduler::MpiWorld& MpiWorldRegistry::createWorld(const faabric::Message& msg,
scheduler::MpiWorld& MpiWorldRegistry::createWorld(faabric::Message& msg,
int worldId,
std::string hostOverride)
{
Expand All @@ -37,7 +37,7 @@ scheduler::MpiWorld& MpiWorldRegistry::createWorld(const faabric::Message& msg,
return worldMap[worldId];
}

MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(const faabric::Message& msg)
MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(faabric::Message& msg)
{
// Create world locally if not exists
int worldId = msg.mpiworldid();
Expand Down
1 change: 1 addition & 0 deletions src/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ faabric_lib(util
crash.cpp
delta.cpp
environment.cpp
exec_graph.cpp
files.cpp
func.cpp
gids.cpp
Expand Down
31 changes: 31 additions & 0 deletions src/util/exec_graph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include <faabric/util/exec_graph.h>
#include <faabric/util/logging.h>
#include <faabric/util/testing.h>

namespace faabric::util::exec_graph {
// Records a string detail on the message's execution graph entry.
// The detail is written (replacing any previous value for `key`) only when
// the message has exec-graph recording switched on.
void addDetail(faabric::Message& msg,
               const std::string& key,
               const std::string& value)
{
    if (msg.recordexecgraph()) {
        auto* details = msg.mutable_execgraphdetails();
        (*details)[key] = value;
    }
}

// Bumps the integer detail stored under `key` by `valueToIncrement`.
// The counter is touched only when the message has exec-graph recording
// switched on; a key seen for the first time starts from zero.
void incrementCounter(faabric::Message& msg,
                      const std::string& key,
                      const int valueToIncrement)
{
    if (msg.recordexecgraph()) {
        auto* counters = msg.mutable_intexecgraphdetails();
        (*counters)[key] += valueToIncrement;
    }
}
}
5 changes: 5 additions & 0 deletions src/util/func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ std::shared_ptr<faabric::Message> messageFactoryShared(

std::string thisHost = faabric::util::getSystemConfig().endpointHost;
ptr->set_masterhost(thisHost);

ptr->set_recordexecgraph(false);

return ptr;
}

Expand All @@ -103,6 +106,8 @@ faabric::Message messageFactory(const std::string& user,
std::string thisHost = faabric::util::getSystemConfig().endpointHost;
msg.set_masterhost(thisHost);

msg.set_recordexecgraph(false);

return msg;
}

Expand Down
111 changes: 111 additions & 0 deletions src/util/json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include <cppcodec/base64_rfc4648.hpp>

#include <sstream>

using namespace rapidjson;

namespace faabric::util {
Expand Down Expand Up @@ -183,6 +185,45 @@ std::string messageToJson(const faabric::Message& msg)
a);
}

if (msg.recordexecgraph()) {
d.AddMember("record_exec_graph", msg.recordexecgraph(), a);

if (msg.execgraphdetails_size() > 0) {
std::stringstream ss;
const auto& map = msg.execgraphdetails();
auto it = map.begin();
while (it != map.end()) {
ss << fmt::format("{}:{}", it->first, it->second);
++it;
if (it != map.end()) {
ss << ",";
}
}

std::string out = ss.str();
d.AddMember(
"exec_graph_detail", Value(out.c_str(), out.size()).Move(), a);
}

if (msg.intexecgraphdetails_size() > 0) {
std::stringstream ss;
const auto& map = msg.intexecgraphdetails();
auto it = map.begin();
while (it != map.end()) {
ss << fmt::format("{}:{}", it->first, it->second);
++it;
if (it != map.end()) {
ss << ",";
}
}

std::string out = ss.str();
d.AddMember("int_exec_graph_detail",
Value(out.c_str(), out.size()).Move(),
a);
}
}

StringBuffer sb;
Writer<StringBuffer> writer(sb);
d.Accept(writer);
Expand Down Expand Up @@ -266,6 +307,54 @@ std::string getStringFromJson(Document& doc,
return std::string(valuePtr, valuePtr + it->value.GetStringLength());
}

// Reads the JSON member `key` and parses it as a comma-separated list of
// "name:value" pairs into a string-to-string map.
//
// Returns an empty map when the member is missing or does not hold a string.
// Each entry is split on its first ':' only, so values may themselves contain
// ':'. Entries without any ':' are malformed and are skipped rather than being
// stored with the whole token as both key and value.
// NOTE: names or values containing ',' cannot be represented in this format.
std::map<std::string, std::string> getStringStringMapFromJson(
  Document& doc,
  const std::string& key)
{
    std::map<std::string, std::string> map;

    Value::MemberIterator it = doc.FindMember(key.c_str());
    // Calling GetString() on a non-string value is invalid, so bail out early
    if (it == doc.MemberEnd() || !it->value.IsString()) {
        return map;
    }

    // Use the explicit length so embedded null characters are preserved
    const std::string serialised(it->value.GetString(),
                                 it->value.GetStringLength());

    std::stringstream ss(serialised);
    std::string keyVal;
    while (std::getline(ss, keyVal, ',')) {
        const auto pos = keyVal.find(':');
        if (pos == std::string::npos) {
            // No separator: not a valid pair
            continue;
        }

        map[keyVal.substr(0, pos)] = keyVal.substr(pos + 1);
    }

    return map;
}

// Reads the JSON member `key` and parses it as a comma-separated list of
// "name:integer" pairs into a string-to-int map.
//
// Returns an empty map when the member is missing or does not hold a string.
// Each entry is split on its first ':'. Entries without any ':' are malformed
// and are skipped rather than having the whole token parsed as the number.
// A value that is not a valid integer still makes std::stoi throw
// (std::invalid_argument / std::out_of_range), as before.
std::map<std::string, int> getStringIntMapFromJson(Document& doc,
                                                   const std::string& key)
{
    std::map<std::string, int> map;

    Value::MemberIterator it = doc.FindMember(key.c_str());
    // Calling GetString() on a non-string value is invalid, so bail out early
    if (it == doc.MemberEnd() || !it->value.IsString()) {
        return map;
    }

    // Use the explicit length so embedded null characters are preserved
    const std::string serialised(it->value.GetString(),
                                 it->value.GetStringLength());

    std::stringstream ss(serialised);
    std::string keyVal;
    while (std::getline(ss, keyVal, ',')) {
        const auto pos = keyVal.find(':');
        if (pos == std::string::npos) {
            // No separator: not a valid pair
            continue;
        }

        map[keyVal.substr(0, pos)] = std::stoi(keyVal.substr(pos + 1));
    }

    return map;
}

faabric::Message jsonToMessage(const std::string& jsonIn)
{
PROF_START(jsonDecode)
Expand Down Expand Up @@ -324,6 +413,28 @@ faabric::Message jsonToMessage(const std::string& jsonIn)
msg.set_sgxpolicy(getStringFromJson(d, "sgxpolicy", ""));
msg.set_sgxresult(getStringFromJson(d, "sgxresult", ""));

msg.set_recordexecgraph(getBoolFromJson(d, "record_exec_graph", false));

// By default, clear the map
msg.clear_execgraphdetails();
// Fill keypairs if found
auto& msgStrMap = *msg.mutable_execgraphdetails();
std::map<std::string, std::string> strMap =
getStringStringMapFromJson(d, "exec_graph_detail");
for (auto& it : strMap) {
msgStrMap[it.first] = it.second;
}

// By default, clear the map
msg.clear_intexecgraphdetails();
// Fill keypairs if found
auto& msgIntMap = *msg.mutable_intexecgraphdetails();
std::map<std::string, int> intMap =
getStringIntMapFromJson(d, "int_exec_graph_detail");
for (auto& it : intMap) {
msgIntMap[it.first] = it.second;
}

PROF_END(jsonDecode)

return msg;
Expand Down
Loading

0 comments on commit 832aafe

Please sign in to comment.