From 832aafe84bb62e7fe12dcc4eddab81030b64a765 Mon Sep 17 00:00:00 2001 From: Carlos Date: Tue, 2 Nov 2021 09:44:43 +0000 Subject: [PATCH] Adding arbitrary information to execution graph (#166) * add call records and tests * formatting * disable tracing by default during tests * re-factor to be used depending on message flag, change the message layout and rename to ExecGraphDetail * update comments * refactor after offline discussion * don't log chained calls if recording exec graph is not set * quick test fix * add serialisation for maps * fix tests * self-review * refactor thread local message's name * move mpi exec graph tests to separate file * add checks for serialisation/deserialisation * cleanup --- include/faabric/scheduler/MpiContext.h | 4 +- include/faabric/scheduler/MpiWorld.h | 4 +- include/faabric/scheduler/MpiWorldRegistry.h | 4 +- include/faabric/util/exec_graph.h | 19 ++++ src/proto/faabric.proto | 5 + src/scheduler/Executor.cpp | 1 + src/scheduler/MpiContext.cpp | 4 +- src/scheduler/MpiWorld.cpp | 23 +++- src/scheduler/MpiWorldRegistry.cpp | 4 +- src/util/CMakeLists.txt | 1 + src/util/exec_graph.cpp | 31 ++++++ src/util/func.cpp | 5 + src/util/json.cpp | 111 +++++++++++++++++++ tests/test/scheduler/test_exec_graph.cpp | 41 ++++++- tests/test/scheduler/test_mpi_exec_graph.cpp | 77 +++++++++++++ tests/test/util/test_json.cpp | 8 ++ tests/utils/faabric_utils.h | 16 +++ tests/utils/message_utils.cpp | 5 + 18 files changed, 348 insertions(+), 15 deletions(-) create mode 100644 include/faabric/util/exec_graph.h create mode 100644 src/util/exec_graph.cpp create mode 100644 tests/test/scheduler/test_mpi_exec_graph.cpp diff --git a/include/faabric/scheduler/MpiContext.h b/include/faabric/scheduler/MpiContext.h index d898793f1..7fedc3de7 100644 --- a/include/faabric/scheduler/MpiContext.h +++ b/include/faabric/scheduler/MpiContext.h @@ -9,9 +9,9 @@ class MpiContext public: MpiContext(); - int createWorld(const faabric::Message& msg); + int createWorld(faabric::Message& msg); - void joinWorld(const faabric::Message& msg); + void joinWorld(faabric::Message& msg); bool getIsMpi(); diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index d927bb5c6..119abf28e 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -21,11 +21,11 @@ class MpiWorld public: MpiWorld(); - void create(const faabric::Message& call, int newId, int newSize); + void create(faabric::Message& call, int newId, int newSize); void broadcastHostsToRanks(); - void initialiseFromMsg(const faabric::Message& msg); + void initialiseFromMsg(faabric::Message& msg); std::string getHostForRank(int rank); diff --git a/include/faabric/scheduler/MpiWorldRegistry.h b/include/faabric/scheduler/MpiWorldRegistry.h index 303c099df..73ae0788a 100644 --- a/include/faabric/scheduler/MpiWorldRegistry.h +++ b/include/faabric/scheduler/MpiWorldRegistry.h @@ -8,11 +8,11 @@ class MpiWorldRegistry public: MpiWorldRegistry() = default; - scheduler::MpiWorld& createWorld(const faabric::Message& msg, + scheduler::MpiWorld& createWorld(faabric::Message& msg, int worldId, std::string hostOverride = ""); - scheduler::MpiWorld& getOrInitialiseWorld(const faabric::Message& msg); + scheduler::MpiWorld& getOrInitialiseWorld(faabric::Message& msg); scheduler::MpiWorld& getWorld(int worldId); diff --git a/include/faabric/util/exec_graph.h b/include/faabric/util/exec_graph.h new file mode 100644 index 000000000..bc000a3ad --- /dev/null +++ b/include/faabric/util/exec_graph.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include +#include +#include + +namespace faabric::util::exec_graph { +void addDetail(faabric::Message& msg, + const std::string& key, + const std::string& value); + +void incrementCounter(faabric::Message& msg, + const std::string& key, + const int valueToIncrement = 1); + +static inline std::string const mpiMsgCountPrefix = "mpi-msgcount-torank-"; +} diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto index 65b2d800c..20e89e65c 100644 --- a/src/proto/faabric.proto +++ b/src/proto/faabric.proto @@ -158,6 +158,11 @@ message Message { string sgxTag = 38; bytes sgxPolicy = 39; bytes sgxResult = 40; + + // Exec-graph utils + bool recordExecGraph = 41; + map intExecGraphDetails = 42; + map execGraphDetails = 43; } // --------------------------------------------- diff --git a/src/scheduler/Executor.cpp b/src/scheduler/Executor.cpp index 0a2c77c67..875a62eb1 100644 --- a/src/scheduler/Executor.cpp +++ b/src/scheduler/Executor.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include diff --git a/src/scheduler/MpiContext.cpp b/src/scheduler/MpiContext.cpp index 4674da0b1..ba895f3c9 100644 --- a/src/scheduler/MpiContext.cpp +++ b/src/scheduler/MpiContext.cpp @@ -11,7 +11,7 @@ MpiContext::MpiContext() , worldId(-1) {} -int MpiContext::createWorld(const faabric::Message& msg) +int MpiContext::createWorld(faabric::Message& msg) { if (msg.mpirank() > 0) { @@ -38,7 +38,7 @@ int MpiContext::createWorld(const faabric::Message& msg) return worldId; } -void MpiContext::joinWorld(const faabric::Message& msg) +void MpiContext::joinWorld(faabric::Message& msg) { if (!msg.ismpi()) { // Not an MPI call diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 79908a3ae..ab814f4d5 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,9 @@ static thread_local std::unordered_map< std::unique_ptr> ranksSendEndpoints; +// Id of the message that created this thread-local instance +static thread_local faabric::Message* thisRankMsg = nullptr; + // This is used for mocking in tests static std::vector rankMessages; @@ -172,11 +176,12 @@ MpiWorld::getUnackedMessageBuffer(int sendRank, int recvRank) return unackedMessageBuffers[index]; } -void MpiWorld::create(const faabric::Message& call, int newId, int newSize) +void MpiWorld::create(faabric::Message& call, int newId, int newSize) { id = newId; user = call.user(); function = call.function(); + thisRankMsg = &call; size = newSize; @@ -194,7 +199,10 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) msg.set_mpirank(i + 1); msg.set_mpiworldsize(size); // Log chained functions to generate execution graphs - sch.logChainedFunction(call.id(), msg.id()); + if (thisRankMsg != nullptr && thisRankMsg->recordexecgraph()) { + sch.logChainedFunction(call.id(), msg.id()); + msg.set_recordexecgraph(true); + } } std::vector executedAt; @@ -286,12 +294,13 @@ void MpiWorld::destroy() } } -void MpiWorld::initialiseFromMsg(const faabric::Message& msg) +void MpiWorld::initialiseFromMsg(faabric::Message& msg) { id = msg.mpiworldid(); user = msg.user(); function = msg.function(); size = msg.mpiworldsize(); + thisRankMsg = &msg; // Block until we receive faabric::MpiHostsToRanksMessage hostRankMsg = recvMpiHostRankMsg(); @@ -572,6 +581,14 @@ void MpiWorld::send(int sendRank, SPDLOG_TRACE("MPI - send remote {} -> {}", sendRank, recvRank); sendRemoteMpiMessage(sendRank, recvRank, m); } + + // If the message is set and recording on, track we have sent this message + if (thisRankMsg != nullptr && thisRankMsg->recordexecgraph()) { + faabric::util::exec_graph::incrementCounter( + *thisRankMsg, + faabric::util::exec_graph::mpiMsgCountPrefix + + std::to_string(recvRank)); + } } void MpiWorld::recv(int sendRank, diff --git a/src/scheduler/MpiWorldRegistry.cpp b/src/scheduler/MpiWorldRegistry.cpp index 37a93866a..2a89647a8 100644 --- a/src/scheduler/MpiWorldRegistry.cpp +++ b/src/scheduler/MpiWorldRegistry.cpp @@ -10,7 +10,7 @@ MpiWorldRegistry& getMpiWorldRegistry() return r; } -scheduler::MpiWorld& MpiWorldRegistry::createWorld(const faabric::Message& msg, +scheduler::MpiWorld& MpiWorldRegistry::createWorld(faabric::Message& msg, int worldId, std::string hostOverride) { @@ -37,7 +37,7 @@ scheduler::MpiWorld& MpiWorldRegistry::createWorld(const faabric::Message& msg, return worldMap[worldId]; } -MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(const faabric::Message& msg) +MpiWorld& MpiWorldRegistry::getOrInitialiseWorld(faabric::Message& msg) { // Create world locally if not exists int worldId = msg.mpiworldid(); diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index bee41ed71..d155313b4 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -7,6 +7,7 @@ faabric_lib(util crash.cpp delta.cpp environment.cpp + exec_graph.cpp files.cpp func.cpp gids.cpp diff --git a/src/util/exec_graph.cpp b/src/util/exec_graph.cpp new file mode 100644 index 000000000..6f9774f74 --- /dev/null +++ b/src/util/exec_graph.cpp @@ -0,0 +1,31 @@ +#include +#include +#include + +namespace faabric::util::exec_graph { +void addDetail(faabric::Message& msg, + const std::string& key, + const std::string& value) +{ + if (!msg.recordexecgraph()) { + return; + } + + auto& stringMap = *msg.mutable_execgraphdetails(); + + stringMap[key] = value; +} + +void incrementCounter(faabric::Message& msg, + const std::string& key, + const int valueToIncrement) +{ + if (!msg.recordexecgraph()) { + return; + } + + auto& stringMap = *msg.mutable_intexecgraphdetails(); + + stringMap[key] += valueToIncrement; +} +} diff --git a/src/util/func.cpp b/src/util/func.cpp index 0bef897d7..a08c220a0 100644 --- a/src/util/func.cpp +++ b/src/util/func.cpp @@ -88,6 +88,9 @@ std::shared_ptr messageFactoryShared( std::string thisHost = faabric::util::getSystemConfig().endpointHost; ptr->set_masterhost(thisHost); + + ptr->set_recordexecgraph(false); + return ptr; } @@ -103,6 +106,8 @@ faabric::Message messageFactory(const std::string& user, std::string thisHost = faabric::util::getSystemConfig().endpointHost; msg.set_masterhost(thisHost); + msg.set_recordexecgraph(false); + return msg; } diff --git a/src/util/json.cpp b/src/util/json.cpp index 605f933b6..bdcb06676 100644 --- a/src/util/json.cpp +++ b/src/util/json.cpp @@ -8,6 +8,8 @@ #include +#include + using namespace rapidjson; namespace faabric::util { @@ -183,6 +185,45 @@ std::string messageToJson(const faabric::Message& msg) a); } + if (msg.recordexecgraph()) { + d.AddMember("record_exec_graph", msg.recordexecgraph(), a); + + if (msg.execgraphdetails_size() > 0) { + std::stringstream ss; + const auto& map = msg.execgraphdetails(); + auto it = map.begin(); + while (it != map.end()) { + ss << fmt::format("{}:{}", it->first, it->second); + ++it; + if (it != map.end()) { + ss << ","; + } + } + + std::string out = ss.str(); + d.AddMember( + "exec_graph_detail", Value(out.c_str(), out.size()).Move(), a); + } + + if (msg.intexecgraphdetails_size() > 0) { + std::stringstream ss; + const auto& map = msg.intexecgraphdetails(); + auto it = map.begin(); + while (it != map.end()) { + ss << fmt::format("{}:{}", it->first, it->second); + ++it; + if (it != map.end()) { + ss << ","; + } + } + + std::string out = ss.str(); + d.AddMember("int_exec_graph_detail", + Value(out.c_str(), out.size()).Move(), + a); + } + } + StringBuffer sb; Writer writer(sb); d.Accept(writer); @@ -266,6 +307,54 @@ std::string getStringFromJson(Document& doc, return std::string(valuePtr, valuePtr + it->value.GetStringLength()); } +std::map getStringStringMapFromJson( + Document& doc, + const std::string& key) +{ + std::map map; + + Value::MemberIterator it = doc.FindMember(key.c_str()); + if (it == doc.MemberEnd()) { + return map; + } + + const char* valuePtr = it->value.GetString(); + std::stringstream ss( + std::string(valuePtr, valuePtr + it->value.GetStringLength())); + std::string keyVal; + while (std::getline(ss, keyVal, ',')) { + auto pos = keyVal.find(":"); + std::string key = keyVal.substr(0, pos); + map[key] = keyVal.erase(0, pos + sizeof(char)); + } + + return map; +} + +std::map getStringIntMapFromJson(Document& doc, + const std::string& key) +{ + std::map map; + + Value::MemberIterator it = doc.FindMember(key.c_str()); + if (it == doc.MemberEnd()) { + return map; + } + + const char* valuePtr = it->value.GetString(); + std::stringstream ss( + std::string(valuePtr, valuePtr + it->value.GetStringLength())); + std::string keyVal; + while (std::getline(ss, keyVal, ',')) { + auto pos = keyVal.find(":"); + std::string key = keyVal.substr(0, pos); + int val = std::stoi(keyVal.erase(0, pos + sizeof(char))); + map[key] = val; + } + + return map; +} + faabric::Message jsonToMessage(const std::string& jsonIn) { PROF_START(jsonDecode) @@ -324,6 +413,28 @@ faabric::Message jsonToMessage(const std::string& jsonIn) msg.set_sgxpolicy(getStringFromJson(d, "sgxpolicy", "")); msg.set_sgxresult(getStringFromJson(d, "sgxresult", "")); + msg.set_recordexecgraph(getBoolFromJson(d, "record_exec_graph", false)); + + // By default, clear the map + msg.clear_execgraphdetails(); + // Fill keypairs if found + auto& msgStrMap = *msg.mutable_execgraphdetails(); + std::map strMap = + getStringStringMapFromJson(d, "exec_graph_detail"); + for (auto& it : strMap) { + msgStrMap[it.first] = it.second; + } + + // By default, clear the map + msg.clear_intexecgraphdetails(); + // Fill keypairs if found + auto& msgIntMap = *msg.mutable_intexecgraphdetails(); + std::map intMap = + getStringIntMapFromJson(d, "int_exec_graph_detail"); + for (auto& it : intMap) { + msgIntMap[it.first] = it.second; + } + PROF_END(jsonDecode) return msg; diff --git a/tests/test/scheduler/test_exec_graph.cpp b/tests/test/scheduler/test_exec_graph.cpp index 755053ce6..b1619b171 100644 --- a/tests/test/scheduler/test_exec_graph.cpp +++ b/tests/test/scheduler/test_exec_graph.cpp @@ -7,12 +7,13 @@ #include #include #include +#include #include using namespace scheduler; namespace tests { -TEST_CASE("Test execution graph", "[scheduler]") +TEST_CASE("Test execution graph", "[scheduler][exec-graph]") { faabric::Message msgA = faabric::util::messageFactory("demo", "echo"); faabric::Message msgB1 = faabric::util::messageFactory("demo", "echo"); @@ -73,10 +74,11 @@ TEST_CASE("Test execution graph", "[scheduler]") TEST_CASE_METHOD(MpiBaseTestFixture, "Test MPI execution graph", - "[mpi][scheduler]") + "[mpi][scheduler][exec-graph]") { faabric::scheduler::MpiWorld world; msg.set_ismpi(true); + msg.set_recordexecgraph(true); // Update the result for the master message sch.setFunctionResult(msg); @@ -96,6 +98,7 @@ TEST_CASE_METHOD(MpiBaseTestFixture, messages.at(rank).set_mpiworldid(worldId); messages.at(rank).set_mpirank(rank); messages.at(rank).set_mpiworldsize(worldSize); + messages.at(rank).set_recordexecgraph(true); } world.create(msg, worldId, worldSize); @@ -142,4 +145,38 @@ TEST_CASE_METHOD(MpiBaseTestFixture, checkExecGraphEquality(expected, actual); } + +TEST_CASE("Test exec graph details", "[util][exec-graph]") +{ + faabric::Message msg = faabric::util::messageFactory("foo", "bar"); + std::string expectedKey = "foo"; + std::string expectedStringValue = "bar"; + int expectedIntValue = 1; + + // By default, recording is disabled + REQUIRE(msg.recordexecgraph() == false); + + // If we add a recording while disabled, nothing changes + faabric::util::exec_graph::incrementCounter( + msg, expectedKey, expectedIntValue); + faabric::util::exec_graph::addDetail(msg, expectedKey, expectedStringValue); + REQUIRE(msg.intexecgraphdetails_size() == 0); + REQUIRE(msg.execgraphdetails_size() == 0); + + // We can turn it on + msg.set_recordexecgraph(true); + + // We can add records either to a string or to an int map + faabric::util::exec_graph::incrementCounter( + msg, expectedKey, expectedIntValue); + faabric::util::exec_graph::addDetail(msg, expectedKey, expectedStringValue); + + // Both change the behaviour of the underlying message + REQUIRE(msg.intexecgraphdetails_size() == 1); + REQUIRE(msg.execgraphdetails_size() == 1); + REQUIRE(msg.intexecgraphdetails().count(expectedKey) == 1); + REQUIRE(msg.intexecgraphdetails().at(expectedKey) == expectedIntValue); + REQUIRE(msg.execgraphdetails().count(expectedKey) == 1); + REQUIRE(msg.execgraphdetails().at(expectedKey) == expectedStringValue); +} } diff --git a/tests/test/scheduler/test_mpi_exec_graph.cpp b/tests/test/scheduler/test_mpi_exec_graph.cpp new file mode 100644 index 000000000..4bfe17b0b --- /dev/null +++ b/tests/test/scheduler/test_mpi_exec_graph.cpp @@ -0,0 +1,77 @@ +#include + +#include "faabric_utils.h" + +#include +#include +#include + +namespace tests { +TEST_CASE_METHOD(MpiTestFixture, + "Test tracing the number of MPI messages", + "[util][exec-graph]") +{ + msg.set_recordexecgraph(true); + + // Send one message + int rankA1 = 0; + int rankA2 = 1; + MPI_Status status{}; + + std::vector messageData = { 0, 1, 2 }; + auto buffer = new int[messageData.size()]; + + int numToSend = 10; + std::string expectedKey = + faabric::util::exec_graph::mpiMsgCountPrefix + std::to_string(rankA2); + + for (int i = 0; i < numToSend; i++) { + world.send(rankA1, + rankA2, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); + world.recv( + rankA1, rankA2, BYTES(buffer), MPI_INT, messageData.size(), &status); + } + + REQUIRE(msg.intexecgraphdetails_size() == 1); + REQUIRE(msg.execgraphdetails_size() == 0); + REQUIRE(msg.intexecgraphdetails().count(expectedKey) == 1); + REQUIRE(msg.intexecgraphdetails().at(expectedKey) == numToSend); +} + +TEST_CASE_METHOD(MpiTestFixture, + "Test tracing is disabled if flag in message not set", + "[util][exec-graph]") +{ + // Disable test mode and set message flag to true + msg.set_recordexecgraph(false); + + // Send one message + int rankA1 = 0; + int rankA2 = 1; + MPI_Status status{}; + + std::vector messageData = { 0, 1, 2 }; + auto buffer = new int[messageData.size()]; + + int numToSend = 10; + std::string expectedKey = + faabric::util::exec_graph::mpiMsgCountPrefix + std::to_string(rankA2); + + for (int i = 0; i < numToSend; i++) { + world.send(rankA1, + rankA2, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); + world.recv( + rankA1, rankA2, BYTES(buffer), MPI_INT, messageData.size(), &status); + } + + // Stop recording and check we have recorded no message + REQUIRE(msg.intexecgraphdetails_size() == 0); + REQUIRE(msg.execgraphdetails_size() == 0); +} +} diff --git a/tests/test/util/test_json.cpp b/tests/test/util/test_json.cpp index 2e2ecfa5a..f8ac1766d 100644 --- a/tests/test/util/test_json.cpp +++ b/tests/test/util/test_json.cpp @@ -4,6 +4,8 @@ #include +#include + using namespace faabric::util; namespace tests { @@ -39,6 +41,12 @@ TEST_CASE("Test message to JSON round trip", "[util]") msg.set_sgxpolicy("test policy string"); msg.set_sgxresult("test result string"); + msg.set_recordexecgraph(true); + auto& map = *msg.mutable_execgraphdetails(); + map["foo"] = "bar"; + auto& intMap = *msg.mutable_intexecgraphdetails(); + intMap["foo"] = 0; + SECTION("Dodgy characters") { msg.set_inputdata("[0], %$ 2233 9"); } SECTION("Bytes") diff --git a/tests/utils/faabric_utils.h b/tests/utils/faabric_utils.h index 3a9ea939c..b6c50c5a2 100644 --- a/tests/utils/faabric_utils.h +++ b/tests/utils/faabric_utils.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "fixtures.h" #include @@ -63,6 +65,20 @@ using namespace faabric; namespace tests { void cleanFaabric(); +template +void checkMessageMapEquality(T mapA, T mapB) +{ + REQUIRE(mapA.size() == mapB.size()); + auto itA = mapA.begin(); + auto itB = mapB.begin(); + while (itA != mapA.end() && itB != mapB.end()) { + REQUIRE(itA->first == itB->first); + REQUIRE(itA->second == itB->second); + itA++; + itB++; + } +} + void checkMessageEquality(const faabric::Message& msgA, const faabric::Message& msgB); diff --git a/tests/utils/message_utils.cpp b/tests/utils/message_utils.cpp index d441d26b8..58924958c 100644 --- a/tests/utils/message_utils.cpp +++ b/tests/utils/message_utils.cpp @@ -47,5 +47,10 @@ void checkMessageEquality(const faabric::Message& msgA, REQUIRE(msgA.sgxtag() == msgB.sgxtag()); REQUIRE(msgA.sgxpolicy() == msgB.sgxpolicy()); REQUIRE(msgA.sgxresult() == msgB.sgxresult()); + + REQUIRE(msgA.recordexecgraph() == msgB.recordexecgraph()); + checkMessageMapEquality(msgA.execgraphdetails(), msgB.execgraphdetails()); + checkMessageMapEquality(msgA.intexecgraphdetails(), + msgB.intexecgraphdetails()); } }