diff --git a/include/faabric/scheduler/Scheduler.h b/include/faabric/scheduler/Scheduler.h index f95219959..48da69b86 100644 --- a/include/faabric/scheduler/Scheduler.h +++ b/include/faabric/scheduler/Scheduler.h @@ -81,13 +81,12 @@ class Executor uint32_t threadPoolSize = 0; private: - std::string lastSnapshot; - std::atomic claimed = false; std::mutex threadsMutex; std::vector> threadPoolThreads; std::vector> deadThreads; + std::set availablePoolThreads; std::vector> threadTaskQueues; @@ -105,6 +104,10 @@ class Scheduler std::shared_ptr req, bool forceLocal = false); + faabric::util::SchedulingDecision callFunctions( + std::shared_ptr req, + faabric::util::SchedulingDecision& hint); + void reset(); void resetThreadLocalCache(); @@ -204,6 +207,8 @@ class Scheduler std::promise>> localResults; + std::unordered_map> pushedSnapshotsMap; + std::mutex localResultsMutex; // ---- Clients ---- @@ -226,6 +231,15 @@ class Scheduler std::unordered_map> registeredHosts; + faabric::util::SchedulingDecision makeSchedulingDecision( + std::shared_ptr req, + bool forceLocal); + + faabric::util::SchedulingDecision doCallFunctions( + std::shared_ptr req, + faabric::util::SchedulingDecision& decision, + faabric::util::FullLock& lock); + std::shared_ptr claimExecutor( faabric::Message& msg, faabric::util::FullLock& schedulerLock); @@ -233,13 +247,6 @@ class Scheduler std::vector getUnregisteredHosts(const std::string& funcStr, bool noCache = false); - int scheduleFunctionsOnHost( - const std::string& host, - std::shared_ptr req, - faabric::util::SchedulingDecision& decision, - int offset, - faabric::util::SnapshotData* snapshot); - // ---- Accounting and debugging ---- std::vector recordedMessagesAll; std::vector recordedMessagesLocal; diff --git a/include/faabric/transport/PointToPointBroker.h b/include/faabric/transport/PointToPointBroker.h index cb4bd33c8..a596a1fb5 100644 --- a/include/faabric/transport/PointToPointBroker.h +++ b/include/faabric/transport/PointToPointBroker.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -26,10 +27,16 @@ class PointToPointGroup public: static std::shared_ptr getGroup(int groupId); + static std::shared_ptr getOrAwaitGroup(int groupId); + static bool groupExists(int groupId); static void addGroup(int appId, int groupId, int groupSize); + static void addGroupIfNotExists(int appId, int groupId, int groupSize); + + static void clearGroup(int groupId); + static void clear(); PointToPointGroup(int appId, int groupIdIn, int groupSizeIn); @@ -77,10 +84,6 @@ class PointToPointGroup std::queue lockWaiters; void notifyLocked(int groupIdx); - - void masterLock(int groupIdx, bool recursive); - - void masterUnlock(int groupIdx, bool recursive); }; class PointToPointBroker @@ -108,21 +111,23 @@ class PointToPointBroker std::vector recvMessage(int groupId, int sendIdx, int recvIdx); + void clearGroup(int groupId); + void clear(); void resetThreadLocalCache(); private: + faabric::util::SystemConfig& conf; + std::shared_mutex brokerMutex; std::unordered_map> groupIdIdxsMap; std::unordered_map mappings; - std::unordered_map groupMappingsFlags; - std::unordered_map groupMappingMutexes; - std::unordered_map groupMappingCvs; + std::unordered_map groupFlags; - faabric::util::SystemConfig& conf; + faabric::util::FlagWaiter& getGroupFlag(int groupId); }; PointToPointBroker& getPointToPointBroker(); diff --git a/include/faabric/transport/PointToPointServer.h b/include/faabric/transport/PointToPointServer.h index 29a15e77c..d62fe6b76 100644 --- a/include/faabric/transport/PointToPointServer.h +++ b/include/faabric/transport/PointToPointServer.h @@ -11,7 +11,7 @@ class PointToPointServer final : public MessageEndpointServer PointToPointServer(); private: - PointToPointBroker& reg; + PointToPointBroker& broker; void doAsyncRecv(int header, const uint8_t* buffer, diff --git a/include/faabric/util/locks.h b/include/faabric/util/locks.h index ed9037763..1ab7e9fad 100644 --- a/include/faabric/util/locks.h +++ b/include/faabric/util/locks.h @@ -1,10 +1,33 @@ #pragma once +#include + +#include +#include #include #include +#define DEFAULT_FLAG_WAIT_MS 10000 + namespace faabric::util { typedef std::unique_lock UniqueLock; typedef std::unique_lock FullLock; typedef std::shared_lock SharedLock; + +class FlagWaiter +{ + public: + FlagWaiter(int timeoutMsIn = DEFAULT_FLAG_WAIT_MS); + + void waitOnFlag(); + + void setFlag(bool value); + + private: + int timeoutMs; + + std::mutex flagMx; + std::condition_variable cv; + std::atomic flag; +}; } diff --git a/include/faabric/util/snapshot.h b/include/faabric/util/snapshot.h index e0d195f6f..6d0c2748e 100644 --- a/include/faabric/util/snapshot.h +++ b/include/faabric/util/snapshot.h @@ -7,6 +7,7 @@ #include #include +#include namespace faabric::util { @@ -19,7 +20,6 @@ enum SnapshotDataType enum SnapshotMergeOperation { Overwrite, - Ignore, Sum, Product, Subtract, @@ -27,14 +27,6 @@ enum SnapshotMergeOperation Min }; -struct SnapshotMergeRegion -{ - uint32_t offset = 0; - size_t length = 0; - SnapshotDataType dataType = SnapshotDataType::Raw; - SnapshotMergeOperation operation = SnapshotMergeOperation::Overwrite; -}; - class SnapshotDiff { public: @@ -44,6 +36,8 @@ class SnapshotDiff size_t size = 0; const uint8_t* data = nullptr; + bool noChange = false; + SnapshotDiff() = default; SnapshotDiff(SnapshotDataType dataTypeIn, @@ -58,13 +52,19 @@ class SnapshotDiff data = dataIn; size = sizeIn; } +}; - SnapshotDiff(uint32_t offsetIn, const uint8_t* dataIn, size_t sizeIn) - { - offset = offsetIn; - data = dataIn; - size = sizeIn; - } +class SnapshotMergeRegion +{ + public: + uint32_t offset = 0; + size_t length = 0; + SnapshotDataType dataType = SnapshotDataType::Raw; + SnapshotMergeOperation operation = SnapshotMergeOperation::Overwrite; + + void addDiffs(std::vector& diffs, + const uint8_t* original, + const uint8_t* updated); }; class SnapshotData @@ -84,11 +84,12 @@ class SnapshotData void addMergeRegion(uint32_t offset, size_t length, SnapshotDataType dataType, - SnapshotMergeOperation operation); + SnapshotMergeOperation operation, + bool overwrite = false); private: - // Note - we care about the order of this map, as we iterate through it in - // order of offsets + // Note - we care about the order of this map, as we iterate through it + // in order of offsets std::map mergeRegions; }; diff --git a/src/scheduler/Executor.cpp b/src/scheduler/Executor.cpp index 875a62eb1..c084047c5 100644 --- a/src/scheduler/Executor.cpp +++ b/src/scheduler/Executor.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -45,6 +46,11 @@ Executor::Executor(faabric::Message& msg) // Set an ID for this Executor id = conf.endpointHost + "_" + std::to_string(faabric::util::generateGid()); SPDLOG_DEBUG("Starting executor {}", id); + + // Mark all thread pool threads as available + for (int i = 0; i < threadPoolSize; i++) { + availablePoolThreads.insert(i); + } } Executor::~Executor() {} @@ -83,8 +89,6 @@ void Executor::finish() // Reset variables boundMessage.Clear(); - lastSnapshot = ""; - claimed = false; threadPoolThreads.clear(); @@ -108,8 +112,7 @@ void Executor::executeTasks(std::vector msgIdxs, faabric::util::UniqueLock lock(threadsMutex); // Restore if necessary. If we're executing threads on the master host we - // assume we don't need to restore, but for everything else we do. If we've - // already restored from this snapshot, we don't do so again. + // assume we don't need to restore, but for everything else we do. faabric::Message& firstMsg = req->mutable_messages()->at(0); std::string snapshotKey = firstMsg.snapshotkey(); std::string thisHost = faabric::util::getSystemConfig().endpointHost; @@ -117,21 +120,10 @@ void Executor::executeTasks(std::vector msgIdxs, bool isMaster = firstMsg.masterhost() == thisHost; bool isThreads = req->type() == faabric::BatchExecuteRequest::THREADS; bool isSnapshot = !snapshotKey.empty(); - bool alreadyRestored = snapshotKey == lastSnapshot; - if (isSnapshot && !alreadyRestored) { - if ((!isMaster && isThreads) || !isThreads) { - SPDLOG_DEBUG("Restoring {} from snapshot {}", funcStr, snapshotKey); - lastSnapshot = snapshotKey; - restore(firstMsg); - } else { - SPDLOG_DEBUG("Skipping snapshot restore on master {} [{}]", - funcStr, - snapshotKey); - } - } else if (isSnapshot) { - SPDLOG_DEBUG( - "Skipping already restored snapshot {} [{}]", funcStr, snapshotKey); + if (isSnapshot && !isMaster) { + SPDLOG_DEBUG("Restoring {} from snapshot {}", funcStr, snapshotKey); + restore(firstMsg); } // Reset dirty page tracking if we're executing threads. @@ -143,31 +135,58 @@ void Executor::executeTasks(std::vector msgIdxs, } // Set up shared counter for this batch of tasks - auto batchCounter = - std::make_shared>(req->messages_size()); + auto batchCounter = std::make_shared>(msgIdxs.size()); // Work out if we should skip the reset after this batch. This only needs to // happen when we're executing threads on the master host, in which case the // original function call will cause a reset bool skipReset = isMaster && isThreads; - // Iterate through and invoke tasks + // Iterate through and invoke tasks. By default, we allocate tasks + // one-to-one with thread pool threads. Only once the pool is exhausted do + // we start overloading for (int msgIdx : msgIdxs) { const faabric::Message& msg = req->messages().at(msgIdx); - // If executing threads, we must always keep thread pool index zero - // free, as this may be executing the function that spawned them - int threadPoolIdx; - if (isThreads) { - assert(threadPoolSize > 1); - threadPoolIdx = (msg.appidx() % (threadPoolSize - 1)) + 1; + int threadPoolIdx = -1; + if (availablePoolThreads.empty()) { + // Here all threads are still executing, so we have to overload. + // If any tasks are blocking we risk a deadlock, and can no longer + // guarantee the application will finish. + // In general if we're on the master host and this is a thread, we + // should avoid the zeroth and first pool threads as they are likely + // to be the main thread and the zeroth in the communication group, + // so will be blocking. + if (isThreads && isMaster) { + if (threadPoolSize <= 2) { + SPDLOG_ERROR( + "Insufficient pool threads ({}) to overload {} idx {}", + threadPoolSize, + funcStr, + msg.appidx()); + + throw std::runtime_error("Insufficient pool threads"); + } + + threadPoolIdx = (msg.appidx() % (threadPoolSize - 2)) + 2; + } else { + threadPoolIdx = msg.appidx() % threadPoolSize; + } + + SPDLOG_DEBUG("Overloaded app index {} to thread {}", + msg.appidx(), + threadPoolIdx); } else { - threadPoolIdx = msg.appidx() % threadPoolSize; + // Take next from those that are available + threadPoolIdx = *availablePoolThreads.begin(); + availablePoolThreads.erase(threadPoolIdx); + + SPDLOG_TRACE("Assigned app index {} to thread {}", + msg.appidx(), + threadPoolIdx); } // Enqueue the task - SPDLOG_TRACE( - "Assigning app index {} to thread {}", msg.appidx(), threadPoolIdx); threadTaskQueues[threadPoolIdx].enqueue(ExecutorTask( msgIdx, req, batchCounter, needsSnapshotPush, skipReset)); @@ -184,6 +203,8 @@ void Executor::threadPoolThread(int threadPoolIdx) SPDLOG_DEBUG("Thread pool thread {}:{} starting up", id, threadPoolIdx); auto& sch = faabric::scheduler::getScheduler(); + faabric::transport::PointToPointBroker& broker = + faabric::transport::getPointToPointBroker(); const auto& conf = faabric::util::getSystemConfig(); bool selfShutdown = false; @@ -263,6 +284,9 @@ void Executor::threadPoolThread(int threadPoolIdx) faabric::snapshot::getSnapshotRegistry().getSnapshot( msg.snapshotkey()); + SPDLOG_TRACE("Diffing pre and post execution snapshots for {}", + msg.snapshotkey()); + std::vector diffs = snapshotPreExecution.getChangeDiffs(snapshotPostExecution.data, snapshotPostExecution.size); @@ -287,6 +311,12 @@ void Executor::threadPoolThread(int threadPoolIdx) releaseClaim(); } + // Return this thread index to the pool available for scheduling + { + faabric::util::UniqueLock lock(threadsMutex); + availablePoolThreads.insert(threadPoolIdx); + } + // Vacate the slot occupied by this task. This must be done after // releasing the claim on this executor, otherwise the scheduler may try // to schedule another function and be unable to reuse this executor. @@ -334,8 +364,9 @@ void Executor::threadPoolThread(int threadPoolIdx) } // We have to clean up TLS here as this should be the last use of the - // scheduler from this thread + // scheduler and point-to-point broker from this thread sch.resetThreadLocalCache(); + broker.resetThreadLocalCache(); } bool Executor::tryClaim() diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 263ac7d87..946700919 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -8,11 +8,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -117,6 +119,7 @@ void Scheduler::reset() availableHostsCache.clear(); registeredHosts.clear(); threadResults.clear(); + pushedSnapshotsMap.clear(); // Records recordedMessagesAll.clear(); @@ -209,14 +212,9 @@ faabric::util::SchedulingDecision Scheduler::callFunctions( std::shared_ptr req, bool forceLocal) { - // Extract properties of the request - int nMessages = req->messages_size(); - bool isThreads = req->type() == faabric::BatchExecuteRequest::THREADS; - // Note, we assume all the messages are for the same function and have the // same master host faabric::Message& firstMsg = req->mutable_messages()->at(0); - std::string funcStr = faabric::util::funcToString(firstMsg, false); std::string masterHost = firstMsg.masterhost(); if (masterHost.empty()) { std::string funcStrWithId = faabric::util::funcToString(firstMsg, true); @@ -224,169 +222,199 @@ faabric::util::SchedulingDecision Scheduler::callFunctions( throw std::runtime_error("Message with no master host"); } - // Set up scheduling decision - SchedulingDecision decision(firstMsg.appid(), firstMsg.groupid()); - - // TODO - more granular locking, this is conservative - faabric::util::FullLock lock(mx); - // If we're not the master host, we need to forward the request back to the // master host. This will only happen if a nested batch execution happens. - std::vector localMessageIdxs; if (!forceLocal && masterHost != thisHost) { - SPDLOG_DEBUG( - "Forwarding {} {} back to master {}", nMessages, funcStr, masterHost); + std::string funcStr = faabric::util::funcToString(firstMsg, false); + SPDLOG_DEBUG("Forwarding {} back to master {}", funcStr, masterHost); getFunctionCallClient(masterHost).executeFunctions(req); + SchedulingDecision decision(firstMsg.appid(), firstMsg.groupid()); decision.returnHost = masterHost; return decision; } + faabric::util::FullLock lock(mx); + + SchedulingDecision decision = makeSchedulingDecision(req, forceLocal); + + // Send out point-to-point mappings if necessary (unless being forced to + // execute locally, in which case they will be transmitted from the + // master) + if (!forceLocal && (firstMsg.groupid() > 0)) { + broker.setAndSendMappingsFromSchedulingDecision(decision); + } + + // Pass decision as hint + return doCallFunctions(req, decision, lock); +} + +faabric::util::SchedulingDecision Scheduler::makeSchedulingDecision( + std::shared_ptr req, + bool forceLocal) +{ + int nMessages = req->messages_size(); + faabric::Message& firstMsg = req->mutable_messages()->at(0); + std::string funcStr = faabric::util::funcToString(firstMsg, false); + + std::vector hosts; if (forceLocal) { // We're forced to execute locally here so we do all the messages for (int i = 0; i < nMessages; i++) { - localMessageIdxs.emplace_back(i); - decision.addMessage(thisHost, req->messages().at(i)); + hosts.push_back(thisHost); } } else { // At this point we know we're the master host, and we've not been // asked to force full local execution. - // Get a list of other registered hosts - std::set& thisRegisteredHosts = registeredHosts[funcStr]; - - // For threads/ processes we need to have a snapshot key and be - // ready to push the snapshot to other hosts. - // We also have to broadcast the latest snapshots to all registered - // hosts, regardless of whether they're going to execute a function. - // This ensures everything is up to date, and we don't have to - // maintain different records of which hosts hold which updates. - faabric::util::SnapshotData snapshotData; - std::string snapshotKey = firstMsg.snapshotkey(); - bool snapshotNeeded = - req->type() == req->THREADS || req->type() == req->PROCESSES; - - if (snapshotNeeded) { - if (snapshotKey.empty()) { - SPDLOG_ERROR("No snapshot provided for {}", funcStr); - throw std::runtime_error( - "Empty snapshot for distributed threads/ processes"); - } - - snapshotData = - faabric::snapshot::getSnapshotRegistry().getSnapshot(snapshotKey); - - if (!thisRegisteredHosts.empty()) { - std::vector snapshotDiffs = - snapshotData.getDirtyPages(); - - // Do the snapshot diff pushing - if (!snapshotDiffs.empty()) { - for (const auto& h : thisRegisteredHosts) { - SPDLOG_DEBUG("Pushing {} snapshot diffs for {} to {}", - snapshotDiffs.size(), - funcStr, - h); - SnapshotClient& c = getSnapshotClient(h); - c.pushSnapshotDiffs( - snapshotKey, firstMsg.groupid(), snapshotDiffs); - } - } - - // Now reset the dirty page tracking, as we want the next batch - // of diffs to contain everything from now on (including the - // updates sent back from all the threads) - SPDLOG_DEBUG("Resetting dirty tracking after pushing diffs {}", - funcStr); - faabric::util::resetDirtyTracking(); - } - } - // Work out how many we can handle locally - int nLocally; - { - int slots = thisHostResources.slots(); + int slots = thisHostResources.slots(); - // Work out available cores, flooring at zero - int available = - slots - this->thisHostUsedSlots.load(std::memory_order_acquire); - available = std::max(available, 0); + // Work out available cores, flooring at zero + int available = + slots - this->thisHostUsedSlots.load(std::memory_order_acquire); + available = std::max(available, 0); - // Claim as many as we can - nLocally = std::min(available, nMessages); - } + // Claim as many as we can + int nLocally = std::min(available, nMessages); // Add those that can be executed locally - if (nLocally > 0) { - SPDLOG_DEBUG( - "Executing {}/{} {} locally", nLocally, nMessages, funcStr); - for (int i = 0; i < nLocally; i++) { - localMessageIdxs.emplace_back(i); - decision.addMessage(thisHost, req->messages().at(i)); - } + for (int i = 0; i < nLocally; i++) { + hosts.push_back(thisHost); } - // If some are left, we need to distribute - int offset = nLocally; - if (offset < nMessages) { - // Schedule first to already registered hosts + // If some are left, we need to distribute. + // First try and do so on already registered hosts. + int remainder = nMessages - nLocally; + if (remainder > 0) { + std::set& thisRegisteredHosts = + registeredHosts[funcStr]; + for (const auto& h : thisRegisteredHosts) { - int nOnThisHost = scheduleFunctionsOnHost( - h, req, decision, offset, &snapshotData); + // Work out resources on this host + faabric::HostResources r = getHostResources(h); + int available = r.slots() - r.usedslots(); + int nOnThisHost = std::min(available, remainder); + + for (int i = 0; i < nOnThisHost; i++) { + hosts.push_back(h); + } - offset += nOnThisHost; - if (offset >= nMessages) { + remainder -= nOnThisHost; + if (remainder <= 0) { break; } } } - // Now schedule to unregistered hosts if there are some left - if (offset < nMessages) { + // Now schedule to unregistered hosts if there are messages left + if (remainder > 0) { std::vector unregisteredHosts = getUnregisteredHosts(funcStr); - for (auto& h : unregisteredHosts) { + for (const auto& h : unregisteredHosts) { // Skip if this host if (h == thisHost) { continue; } - // Schedule functions on the host - int nOnThisHost = scheduleFunctionsOnHost( - h, req, decision, offset, &snapshotData); + // Work out resources on this host + faabric::HostResources r = getHostResources(h); + int available = r.slots() - r.usedslots(); + int nOnThisHost = std::min(available, remainder); // Register the host if it's exected a function if (nOnThisHost > 0) { - SPDLOG_DEBUG("Registering {} for {}", h, funcStr); registeredHosts[funcStr].insert(h); } - offset += nOnThisHost; - if (offset >= nMessages) { + for (int i = 0; i < nOnThisHost; i++) { + hosts.push_back(h); + } + + remainder -= nOnThisHost; + if (remainder <= 0) { break; } } } // At this point there's no more capacity in the system, so we - // just need to execute locally - if (offset < nMessages) { - SPDLOG_DEBUG("Overloading {}/{} {} locally", - nMessages - offset, - nMessages, - funcStr); + // just need to overload locally + if (remainder > 0) { + SPDLOG_DEBUG( + "Overloading {}/{} {} locally", remainder, nMessages, funcStr); - for (; offset < nMessages; offset++) { - localMessageIdxs.emplace_back(offset); - decision.addMessage(thisHost, req->messages().at(offset)); + for (int i = 0; i < remainder; i++) { + hosts.push_back(thisHost); } } + } + + // Sanity check + assert(hosts.size() == nMessages); - // Sanity check - assert(offset == nMessages); + // Set up decision + SchedulingDecision decision(firstMsg.appid(), firstMsg.groupid()); + for (int i = 0; i < hosts.size(); i++) { + decision.addMessage(hosts.at(i), req->messages().at(i)); } + return decision; +} + +faabric::util::SchedulingDecision Scheduler::callFunctions( + std::shared_ptr req, + faabric::util::SchedulingDecision& hint) +{ + faabric::util::FullLock lock(mx); + return doCallFunctions(req, hint, lock); +} + +faabric::util::SchedulingDecision Scheduler::doCallFunctions( + std::shared_ptr req, + faabric::util::SchedulingDecision& decision, + faabric::util::FullLock& lock) +{ + faabric::Message& firstMsg = req->mutable_messages()->at(0); + std::string funcStr = faabric::util::funcToString(firstMsg, false); + int nMessages = req->messages_size(); + + if (decision.hosts.size() != nMessages) { + SPDLOG_ERROR( + "Passed decision for {} with {} messages, but request has {}", + funcStr, + decision.hosts.size(), + nMessages); + throw std::runtime_error("Invalid scheduler hint for messages"); + } + + // NOTE: we want to schedule things on this host _last_, otherwise functions + // may start executing before all messages have been dispatched, thus + // slowing the remaining scheduling. + // Therefore we want to create a list of unique hosts, with this host last. + std::vector orderedHosts; + { + std::set uniqueHosts(decision.hosts.begin(), + decision.hosts.end()); + bool hasFunctionsOnThisHost = uniqueHosts.contains(thisHost); + + if (hasFunctionsOnThisHost) { + uniqueHosts.erase(thisHost); + } + + orderedHosts = std::vector(uniqueHosts.begin(), uniqueHosts.end()); + + if (hasFunctionsOnThisHost) { + orderedHosts.push_back(thisHost); + } + } + + // ------------------------------------------- + // THREADS + // ------------------------------------------- + bool isThreads = req->type() == faabric::BatchExecuteRequest::THREADS; + // Register thread results if necessary if (isThreads) { for (const auto& m : req->messages()) { @@ -394,58 +422,148 @@ faabric::util::SchedulingDecision Scheduler::callFunctions( } } - // Schedule messages locally if necessary. For threads we only need one - // executor, for anything else we want one Executor per function in flight - if (!localMessageIdxs.empty()) { - // Update slots - this->thisHostUsedSlots.fetch_add((int32_t)localMessageIdxs.size(), - std::memory_order_acquire); - - if (isThreads) { - // Threads use the existing executor. We assume there's only one - // running at a time. - std::vector>& thisExecutors = - executors[funcStr]; - - std::shared_ptr e = nullptr; - if (thisExecutors.empty()) { - // Create executor if not exists - e = claimExecutor(firstMsg, lock); - } else if (thisExecutors.size() == 1) { - // Use existing executor if exists - e = thisExecutors.back(); + // ------------------------------------------- + // SNAPSHOTS + // ------------------------------------------- + + // Push out snapshot diffs to registered hosts. We have to do this to + // *all* hosts, regardless of whether they will be executing functions. + // This greatly simplifies the reasoning about which hosts hold which + // diffs. + std::string snapshotKey = firstMsg.snapshotkey(); + if (!snapshotKey.empty()) { + for (const auto& host : getFunctionRegisteredHosts(firstMsg)) { + SnapshotClient& c = getSnapshotClient(host); + faabric::util::SnapshotData snapshotData = + faabric::snapshot::getSnapshotRegistry().getSnapshot(snapshotKey); + + // See if we've already pushed this snapshot to the given host, + // if so, just push the diffs + if (pushedSnapshotsMap[snapshotKey].contains(host)) { + std::vector snapshotDiffs = + snapshotData.getDirtyPages(); + c.pushSnapshotDiffs( + snapshotKey, firstMsg.groupid(), snapshotDiffs); } else { - SPDLOG_ERROR("Found {} executors for threaded function {}", - thisExecutors.size(), - funcStr); - throw std::runtime_error( - "Expected only one executor for threaded function"); + c.pushSnapshot(snapshotKey, firstMsg.groupid(), snapshotData); + pushedSnapshotsMap[snapshotKey].insert(host); + } + } + } + + // Now reset the dirty page tracking just before we start executing + SPDLOG_DEBUG("Resetting dirty tracking after pushing diffs {}", funcStr); + faabric::util::resetDirtyTracking(); + + // ------------------------------------------- + // EXECTUION + // ------------------------------------------- + + // Iterate through unique hosts and dispatch messages + for (const std::string& host : orderedHosts) { + // Work out which indexes are scheduled on this host + std::vector thisHostIdxs; + for (int i = 0; i < decision.hosts.size(); i++) { + if (decision.hosts.at(i) == host) { + thisHostIdxs.push_back(i); } + } - assert(e != nullptr); + if (host == thisHost) { + // ------------------------------------------- + // LOCAL EXECTUION + // ------------------------------------------- + // For threads we only need one executor, for anything else we want + // one Executor per function in flight. + + if (thisHostIdxs.empty()) { + SPDLOG_DEBUG("Not scheduling any calls to {} out of {} locally", + funcStr, + nMessages); + continue; + } - // Execute the tasks - e->executeTasks(localMessageIdxs, req); - } else { - // Non-threads require one executor per task - for (auto i : localMessageIdxs) { - faabric::Message& localMsg = req->mutable_messages()->at(i); - if (localMsg.executeslocally()) { - faabric::util::UniqueLock resultsLock(localResultsMutex); - localResults.insert( - { localMsg.id(), - std::promise>() }); + SPDLOG_DEBUG("Scheduling {}/{} calls to {} locally", + thisHostIdxs.size(), + nMessages, + funcStr); + + // Update slots + this->thisHostUsedSlots.fetch_add(thisHostIdxs.size(), + std::memory_order_acquire); + + if (isThreads) { + // Threads use the existing executor. We assume there's only + // one running at a time. + std::vector>& thisExecutors = + executors[funcStr]; + + std::shared_ptr e = nullptr; + if (thisExecutors.empty()) { + // Create executor if not exists + e = claimExecutor(firstMsg, lock); + } else if (thisExecutors.size() == 1) { + // Use existing executor if exists + e = thisExecutors.back(); + } else { + SPDLOG_ERROR("Found {} executors for threaded function {}", + thisExecutors.size(), + funcStr); + throw std::runtime_error( + "Expected only one executor for threaded function"); + } + + assert(e != nullptr); + + // Execute the tasks + e->executeTasks(thisHostIdxs, req); + } else { + // Non-threads require one executor per task + for (auto i : thisHostIdxs) { + faabric::Message& localMsg = req->mutable_messages()->at(i); + + if (localMsg.executeslocally()) { + faabric::util::UniqueLock resultsLock( + localResultsMutex); + localResults.insert( + { localMsg.id(), + std::promise< + std::unique_ptr>() }); + } + + std::shared_ptr e = claimExecutor(localMsg, lock); + e->executeTasks({ i }, req); } - std::shared_ptr e = claimExecutor(localMsg, lock); - e->executeTasks({ i }, req); } - } - } + } else { + // ------------------------------------------- + // REMOTE EXECTUION + // ------------------------------------------- - // Send out point-to-point mappings if necessary (unless being forced to - // execute locally, in which case they will be transmitted from the master) - if (!forceLocal && (firstMsg.groupid() > 0)) { - broker.setAndSendMappingsFromSchedulingDecision(decision); + SPDLOG_DEBUG("Scheduling {}/{} calls to {} on {}", + thisHostIdxs.size(), + nMessages, + funcStr, + host); + + // Set up new request + std::shared_ptr hostRequest = + faabric::util::batchExecFactory(); + hostRequest->set_snapshotkey(req->snapshotkey()); + hostRequest->set_type(req->type()); + hostRequest->set_subtype(req->subtype()); + hostRequest->set_contextdata(req->contextdata()); + + // Add messages + for (auto msgIdx : thisHostIdxs) { + auto* newMsg = hostRequest->add_messages(); + *newMsg = req->messages().at(msgIdx); + newMsg->set_executeslocally(false); + } + + // Dispatch the calls + getFunctionCallClient(host).executeFunctions(hostRequest); + } } // Records for tests @@ -508,61 +626,6 @@ void Scheduler::broadcastSnapshotDelete(const faabric::Message& msg, } } -int Scheduler::scheduleFunctionsOnHost( - const std::string& host, - std::shared_ptr req, - SchedulingDecision& decision, - int offset, - faabric::util::SnapshotData* snapshot) -{ - const faabric::Message& firstMsg = req->messages().at(0); - std::string funcStr = faabric::util::funcToString(firstMsg, false); - - int nMessages = req->messages_size(); - int remainder = nMessages - offset; - - // Work out how many we can put on the host - faabric::HostResources r = getHostResources(host); - int available = r.slots() - r.usedslots(); - - // Drop out if none available - if (available <= 0) { - SPDLOG_DEBUG("Not scheduling {} on {}, no resources", funcStr, host); - return 0; - } - - // Set up new request - std::shared_ptr hostRequest = - faabric::util::batchExecFactory(); - hostRequest->set_snapshotkey(req->snapshotkey()); - hostRequest->set_type(req->type()); - hostRequest->set_subtype(req->subtype()); - hostRequest->set_contextdata(req->contextdata()); - - // Add messages - int nOnThisHost = std::min(available, remainder); - for (int i = offset; i < (offset + nOnThisHost); i++) { - auto* newMsg = hostRequest->add_messages(); - *newMsg = req->messages().at(i); - newMsg->set_executeslocally(false); - decision.addMessage(host, req->messages().at(i)); - } - - SPDLOG_DEBUG( - "Sending {}/{} {} to {}", nOnThisHost, nMessages, funcStr, host); - - // Handle snapshots - std::string snapshotKey = firstMsg.snapshotkey(); - if (snapshot != nullptr && !snapshotKey.empty()) { - SnapshotClient& c = getSnapshotClient(host); - c.pushSnapshot(snapshotKey, firstMsg.groupid(), *snapshot); - } - - getFunctionCallClient(host).executeFunctions(hostRequest); - - return nOnThisHost; -} - void Scheduler::callFunction(faabric::Message& msg, bool forceLocal) { // TODO - avoid this copy @@ -646,8 +709,8 @@ std::shared_ptr Scheduler::claimExecutor( int nExecutors = thisExecutors.size(); SPDLOG_DEBUG( "Scaling {} from {} -> {}", funcStr, nExecutors, nExecutors + 1); - // Spinning up a new executor can be lengthy, allow other things to run - // in parallel + // Spinning up a new executor can be lengthy, allow other things + // to run in parallel schedulerLock.unlock(); auto executor = factory->createExecutor(msg); schedulerLock.lock(); @@ -707,11 +770,14 @@ void Scheduler::setFunctionResult(faabric::Message& msg) if (msg.executeslocally()) { faabric::util::UniqueLock resultsLock(localResultsMutex); + auto it = localResults.find(msg.id()); if (it != localResults.end()) { it->second.set_value(std::make_unique(msg)); } - // Sync messages can't have their results read twice, so skip redis + + // Sync messages can't have their results read twice, so skip + // redis if (!msg.isasync()) { return; } @@ -729,8 +795,8 @@ void Scheduler::setFunctionResult(faabric::Message& msg) void Scheduler::registerThread(uint32_t msgId) { - // Here we need to ensure the promise is registered locally so callers can - // start waiting + // Here we need to ensure the promise is registered locally so + // callers can start waiting threadResults[msgId]; } @@ -818,13 +884,13 @@ faabric::Message Scheduler::getFunctionResult(unsigned int messageId, faabric::Message msgResult; if (isBlocking) { - // Blocking version will throw an exception when timing out which is - // handled by the caller. + // Blocking version will throw an exception when timing out + // which is handled by the caller. std::vector result = redis.dequeueBytes(resultKey, timeoutMs); msgResult.ParseFromArray(result.data(), (int)result.size()); } else { - // Non-blocking version will tolerate empty responses, therefore we - // handle the exception here + // Non-blocking version will tolerate empty responses, therefore + // we handle the exception here std::vector result; try { result = redis.dequeueBytes(resultKey, timeoutMs); diff --git a/src/snapshot/SnapshotServer.cpp b/src/snapshot/SnapshotServer.cpp index 7cb778be7..108e3c5d6 100644 --- a/src/snapshot/SnapshotServer.cpp +++ b/src/snapshot/SnapshotServer.cpp @@ -71,10 +71,9 @@ std::unique_ptr SnapshotServer::recvPushSnapshot( throw std::runtime_error("Received snapshot with zero size"); } - SPDLOG_DEBUG("Receiving snapshot {} (size {}, lock {})", + SPDLOG_DEBUG("Receiving snapshot {} (size {})", r->key()->c_str(), - r->contents()->size(), - r->groupid()); + r->contents()->size()); faabric::snapshot::SnapshotRegistry& reg = faabric::snapshot::getSnapshotRegistry(); @@ -83,12 +82,6 @@ std::unique_ptr SnapshotServer::recvPushSnapshot( faabric::util::SnapshotData data; data.size = r->contents()->size(); - // Lock the function group if necessary - if (r->groupid() > 0) { - faabric::transport::PointToPointGroup::getGroup(r->groupid()) - ->localLock(); - } - // TODO - avoid this copy by changing server superclass to allow subclasses // to provide a buffer to receive data. // TODO - work out snapshot ownership here, how do we know when to delete @@ -99,12 +92,6 @@ std::unique_ptr SnapshotServer::recvPushSnapshot( reg.takeSnapshot(r->key()->str(), data, true); - // Unlock the application - if (r->groupid() > 0) { - faabric::transport::PointToPointGroup::getGroup(r->groupid()) - ->localUnlock(); - } - // Send response return std::make_unique(); } @@ -127,6 +114,7 @@ SnapshotServer::recvPushSnapshotDiffs(const uint8_t* buffer, size_t bufferSize) { const SnapshotDiffPushRequest* r = flatbuffers::GetRoot(buffer); + int groupId = r->groupid(); SPDLOG_DEBUG( "Applying {} diffs to snapshot {}", r->chunks()->size(), r->key()->str()); @@ -136,25 +124,33 @@ SnapshotServer::recvPushSnapshotDiffs(const uint8_t* buffer, size_t bufferSize) faabric::snapshot::getSnapshotRegistry(); faabric::util::SnapshotData& snap = reg.getSnapshot(r->key()->str()); - // Lock the function group - if (r->groupid() > 0) { + // Lock the function group if it exists + if (groupId > 0 && + faabric::transport::PointToPointGroup::groupExists(groupId)) { faabric::transport::PointToPointGroup::getGroup(r->groupid()) ->localLock(); } - // Apply diffs to snapshot - for (const auto* r : *r->chunks()) { - uint8_t* dest = snap.data + r->offset(); - switch (r->dataType()) { + // Iterate through the chunks passed in the request + for (const auto* chunk : *r->chunks()) { + uint8_t* dest = snap.data + chunk->offset(); + + SPDLOG_TRACE("Applying snapshot diff to {} at {}-{}", + r->key()->str(), + chunk->offset(), + chunk->offset() + chunk->data()->size()); + + switch (chunk->dataType()) { case (faabric::util::SnapshotDataType::Raw): { - switch (r->mergeOp()) { + switch (chunk->mergeOp()) { case (faabric::util::SnapshotMergeOperation::Overwrite): { - std::memcpy(dest, r->data()->data(), r->data()->size()); + std::memcpy( + dest, chunk->data()->data(), chunk->data()->size()); break; } default: { SPDLOG_ERROR("Unsupported raw merge operation: {}", - r->mergeOp()); + chunk->mergeOp()); throw std::runtime_error( "Unsupported raw merge operation"); } @@ -163,9 +159,9 @@ SnapshotServer::recvPushSnapshotDiffs(const uint8_t* buffer, size_t bufferSize) } case (faabric::util::SnapshotDataType::Int): { const auto* value = - reinterpret_cast(r->data()->data()); + reinterpret_cast(chunk->data()->data()); auto* destValue = reinterpret_cast(dest); - switch (r->mergeOp()) { + switch (chunk->mergeOp()) { case (faabric::util::SnapshotMergeOperation::Sum): { *destValue += *value; break; @@ -188,7 +184,7 @@ SnapshotServer::recvPushSnapshotDiffs(const uint8_t* buffer, size_t bufferSize) } default: { SPDLOG_ERROR("Unsupported int merge operation: {}", - r->mergeOp()); + chunk->mergeOp()); throw std::runtime_error( "Unsupported int merge operation"); } @@ -196,18 +192,22 @@ SnapshotServer::recvPushSnapshotDiffs(const uint8_t* buffer, size_t bufferSize) break; } default: { - SPDLOG_ERROR("Unsupported data type: {}", r->dataType()); + SPDLOG_ERROR("Unsupported data type: {}", chunk->dataType()); throw std::runtime_error("Unsupported merge data type"); } } } - // Unlock - if (r->groupid() > 0) { + // Unlock group if exists + if (groupId > 0 && + faabric::transport::PointToPointGroup::groupExists(groupId)) { faabric::transport::PointToPointGroup::getGroup(r->groupid()) ->localUnlock(); } + // Reset dirty tracking having applied diffs + SPDLOG_DEBUG("Resetting dirty page tracking having applied diffs"); + // Send response return std::make_unique(); } diff --git a/src/transport/PointToPointBroker.cpp b/src/transport/PointToPointBroker.cpp index 38e9e4921..4d9b23fad 100644 --- a/src/transport/PointToPointBroker.cpp +++ b/src/transport/PointToPointBroker.cpp @@ -22,6 +22,8 @@ namespace faabric::transport { static std::unordered_map> groups; +static std::shared_mutex groupsMutex; + // NOTE: Keeping 0MQ sockets in TLS is usually a bad idea, as they _must_ be // closed before the global context. However, in this case it's worth it // to cache the sockets across messages, as otherwise we'd be creating and @@ -74,19 +76,53 @@ std::shared_ptr PointToPointGroup::getGroup(int groupId) return groups.at(groupId); } +std::shared_ptr PointToPointGroup::getOrAwaitGroup( + int groupId) +{ + getPointToPointBroker().waitForMappingsOnThisHost(groupId); + + return getGroup(groupId); +} + bool PointToPointGroup::groupExists(int groupId) { + faabric::util::SharedLock lock(groupsMutex); return groups.find(groupId) != groups.end(); } void PointToPointGroup::addGroup(int appId, int groupId, int groupSize) { - groups.emplace(std::make_pair( - groupId, std::make_shared(appId, groupId, groupSize))); + faabric::util::FullLock lock(groupsMutex); + + if (groups.find(groupId) == groups.end()) { + groups.emplace(std::make_pair( + groupId, + std::make_shared(appId, groupId, groupSize))); + } +} + +void PointToPointGroup::addGroupIfNotExists(int appId, + int groupId, + int groupSize) +{ + if (groupExists(groupId)) { + return; + } + + addGroup(appId, groupId, groupSize); +} + +void PointToPointGroup::clearGroup(int groupId) +{ + faabric::util::FullLock lock(groupsMutex); + + groups.erase(groupId); } void PointToPointGroup::clear() { + faabric::util::FullLock lock(groupsMutex); + groups.clear(); } @@ -102,13 +138,78 @@ PointToPointGroup::PointToPointGroup(int appIdIn, void PointToPointGroup::lock(int groupIdx, bool recursive) { - std::string host = + std::string masterHost = ptpBroker.getHostForReceiver(groupId, POINT_TO_POINT_MASTER_IDX); + std::string lockerHost = ptpBroker.getHostForReceiver(groupId, groupIdx); - if (host == conf.endpointHost) { - masterLock(groupIdx, recursive); + bool masterIsLocal = masterHost == conf.endpointHost; + bool lockerIsLocal = lockerHost == conf.endpointHost; + + // If we're on the master, we need to try and acquire the lock, otherwise we + // send a remote request + if (masterIsLocal) { + bool acquiredLock = false; + { + faabric::util::UniqueLock lock(mx); + + if (recursive && (recursiveLockOwners.empty() || + recursiveLockOwners.top() == groupIdx)) { + // Recursive and either free, or already locked by this idx + recursiveLockOwners.push(groupIdx); + acquiredLock = true; + } else if (!recursive && (lockOwnerIdx == NO_LOCK_OWNER_IDX)) { + // Non-recursive and free + lockOwnerIdx = groupIdx; + acquiredLock = true; + } + } + + if (acquiredLock && lockerIsLocal) { + // Nothing to do now + SPDLOG_TRACE("Group idx {} ({}), locally locked {} (recursive {})", + groupIdx, + lockerHost, + groupId, + recursive); + + } else if (acquiredLock) { + SPDLOG_TRACE("Group idx {} ({}), remotely locked {} (recursive {})", + groupIdx, + lockerHost, + groupId, + recursive); + + // Notify remote locker that they've acquired the lock + notifyLocked(groupIdx); + } else { + // Need to wait to get the lock + lockWaiters.push(groupIdx); + + // Wait here if local, otherwise the remote end will pick up the + // message + if (lockerIsLocal) { + SPDLOG_TRACE( + "Group idx {} ({}), locally awaiting lock {} (recursive {})", + groupIdx, + lockerHost, + groupId, + recursive); + + ptpBroker.recvMessage( + groupId, POINT_TO_POINT_MASTER_IDX, groupIdx); + } else { + // Notify remote locker that they've acquired the lock + SPDLOG_TRACE( + "Group idx {} ({}), remotely awaiting lock {} (recursive {})", + groupIdx, + lockerHost, + groupId, + masterHost, + recursive); + } + } } else { - auto cli = getClient(host); + auto cli = getClient(masterHost); faabric::PointToPointMessage msg; msg.set_groupid(groupId); msg.set_sendidx(groupIdx); @@ -118,59 +219,16 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) groupId, groupIdx, POINT_TO_POINT_MASTER_IDX, - host); + masterHost); + // Send the remote request and await the message saying it's been + // acquired cli->groupLock(appId, groupId, groupIdx, recursive); - // Await ptp response ptpBroker.recvMessage(groupId, POINT_TO_POINT_MASTER_IDX, groupIdx); } } -void PointToPointGroup::masterLock(int groupIdx, bool recursive) -{ - SPDLOG_TRACE("Master lock {}:{}", groupId, groupIdx); - - bool success = false; - { - faabric::util::UniqueLock lock(mx); - if (recursive) { - bool isFree = recursiveLockOwners.empty(); - - bool lockOwnedByThisIdx = - !isFree && (recursiveLockOwners.top() == groupIdx); - - if (isFree || lockOwnedByThisIdx) { - // Recursive and either free, or already locked by this idx - SPDLOG_TRACE("Group idx {} recursively locked {} ({})", - groupIdx, - groupId, - lockWaiters.size()); - recursiveLockOwners.push(groupIdx); - success = true; - } else { - SPDLOG_TRACE("Group idx {} unable to recursively lock {} ({})", - groupIdx, - groupId, - lockWaiters.size()); - } - } else if (lockOwnerIdx == NO_LOCK_OWNER_IDX) { - // Non-recursive and free - SPDLOG_TRACE("Group idx {} locked {}", groupIdx, groupId); - lockOwnerIdx = groupIdx; - success = true; - } else { - // Unable to lock, wait in queue - SPDLOG_TRACE("Group idx {} unable to lock {}", groupIdx, groupId); - lockWaiters.push(groupIdx); - } - } - - if (success) { - notifyLocked(groupIdx); - } -} - void PointToPointGroup::localLock() { LOCK_TIMEOUT(localMx, timeoutMs); @@ -188,7 +246,35 @@ void PointToPointGroup::unlock(int groupIdx, bool recursive) ptpBroker.getHostForReceiver(groupId, POINT_TO_POINT_MASTER_IDX); if (host == conf.endpointHost) { - masterUnlock(groupIdx, recursive); + SPDLOG_TRACE("Group idx {} unlocking {} ({} waiters, recursive {})", + groupIdx, + groupId, + lockWaiters.size(), + recursive); + + faabric::util::UniqueLock lock(mx); + + if (recursive) { + recursiveLockOwners.pop(); + + if (!recursiveLockOwners.empty()) { + return; + } + + if (!lockWaiters.empty()) { + recursiveLockOwners.push(lockWaiters.front()); + notifyLocked(lockWaiters.front()); + lockWaiters.pop(); + } + } else { + lockOwnerIdx = NO_LOCK_OWNER_IDX; + + if (!lockWaiters.empty()) { + lockOwnerIdx = lockWaiters.front(); + notifyLocked(lockWaiters.front()); + lockWaiters.pop(); + } + } } else { auto cli = getClient(host); faabric::PointToPointMessage msg; @@ -196,7 +282,7 @@ void PointToPointGroup::unlock(int groupIdx, bool recursive) msg.set_sendidx(groupIdx); msg.set_recvidx(POINT_TO_POINT_MASTER_IDX); - SPDLOG_TRACE("Remote lock {}:{}:{} to {}", + SPDLOG_TRACE("Remote unlock {}:{}:{} to {}", groupId, groupIdx, POINT_TO_POINT_MASTER_IDX, @@ -206,33 +292,6 @@ void PointToPointGroup::unlock(int groupIdx, bool recursive) } } -void PointToPointGroup::masterUnlock(int groupIdx, bool recursive) -{ - faabric::util::UniqueLock lock(mx); - - if (recursive) { - recursiveLockOwners.pop(); - - if (!recursiveLockOwners.empty()) { - return; - } - - if (!lockWaiters.empty()) { - recursiveLockOwners.push(lockWaiters.front()); - notifyLocked(lockWaiters.front()); - lockWaiters.pop(); - } - } else { - lockOwnerIdx = NO_LOCK_OWNER_IDX; - - if (!lockWaiters.empty()) { - lockOwnerIdx = lockWaiters.front(); - notifyLocked(lockWaiters.front()); - lockWaiters.pop(); - } - } -} - void PointToPointGroup::localUnlock() { localMx.unlock(); @@ -279,10 +338,16 @@ void PointToPointGroup::notify(int groupIdx) { if (groupIdx == POINT_TO_POINT_MASTER_IDX) { for (int i = 1; i < groupSize; i++) { + SPDLOG_TRACE( + "Master group {} waiting for notify from index {}", groupId, i); + ptpBroker.recvMessage(groupId, i, POINT_TO_POINT_MASTER_IDX); + + SPDLOG_TRACE("Master group {} notified by index {}", groupId, i); } } else { std::vector data(1, 0); + SPDLOG_TRACE("Notifying group {} from index {}", groupId, groupIdx); ptpBroker.sendMessage(groupId, groupIdx, POINT_TO_POINT_MASTER_IDX, @@ -364,19 +429,10 @@ PointToPointBroker::setUpLocalMappingsFromSchedulingDecision( decision.appId, groupId, decision.nFunctions); } - { - // Lock this group - faabric::util::UniqueLock lock(groupMappingMutexes[groupId]); - - SPDLOG_TRACE( - "Enabling point-to-point mapping for {}:{}", decision.appId, groupId); + SPDLOG_TRACE( + "Enabling point-to-point mapping for {}:{}", decision.appId, groupId); - // Enable the group - groupMappingsFlags[groupId] = true; - - // Notify waiters - groupMappingCvs[groupId].notify_all(); - } + getGroupFlag(groupId).setFlag(true); return hosts; } @@ -414,34 +470,27 @@ void PointToPointBroker::setAndSendMappingsFromSchedulingDecision( } } -void PointToPointBroker::waitForMappingsOnThisHost(int groupId) +faabric::util::FlagWaiter& PointToPointBroker::getGroupFlag(int groupId) { - // Check if it's been enabled - if (!groupMappingsFlags[groupId]) { - - // Lock this group - faabric::util::UniqueLock lock(groupMappingMutexes[groupId]); - - // Check again - if (!groupMappingsFlags[groupId]) { - // Wait for group to be enabled - auto timePoint = std::chrono::system_clock::now() + - std::chrono::milliseconds(MAPPING_TIMEOUT_MS); - - if (!groupMappingCvs[groupId].wait_until( - lock, timePoint, [this, groupId] { - return groupMappingsFlags[groupId]; - })) { - - SPDLOG_ERROR("Timed out waiting for group mappings {}", - groupId); - throw std::runtime_error( - "Timed out waiting for group mappings"); - } - - SPDLOG_TRACE("Point-to-point mappings for {} ready", groupId); + if (groupFlags.find(groupId) == groupFlags.end()) { + faabric::util::FullLock lock(brokerMutex); + if (groupFlags.find(groupId) == groupFlags.end()) { + return groupFlags[groupId]; } } + + { + faabric::util::SharedLock lock(brokerMutex); + return groupFlags.at(groupId); + } +} + +void PointToPointBroker::waitForMappingsOnThisHost(int groupId) +{ + faabric::util::FlagWaiter& waiter = getGroupFlag(groupId); + + // Check if it's been enabled + waiter.waitOnFlag(); } std::set PointToPointBroker::getIdxsRegisteredForGroup(int groupId) @@ -520,24 +569,42 @@ std::vector PointToPointBroker::recvMessage(int groupId, return messageData.dataCopy(); } +void PointToPointBroker::clearGroup(int groupId) +{ + SPDLOG_TRACE("Clearing point-to-point group {}", groupId); + + faabric::util::FullLock lock(brokerMutex); + + std::set idxs = getIdxsRegisteredForGroup(groupId); + for (auto idxA : idxs) { + for (auto idxB : idxs) { + std::string label = getPointToPointKey(groupId, idxA, idxB); + mappings.erase(label); + } + } + + groupIdIdxsMap.erase(groupId); + + PointToPointGroup::clearGroup(groupId); + + groupFlags.erase(groupId); +} + void PointToPointBroker::clear() { - faabric::util::SharedLock lock(brokerMutex); + faabric::util::FullLock lock(brokerMutex); groupIdIdxsMap.clear(); mappings.clear(); PointToPointGroup::clear(); - groupMappingMutexes.clear(); - groupMappingsFlags.clear(); - groupMappingCvs.clear(); + groupFlags.clear(); } void PointToPointBroker::resetThreadLocalCache() { SPDLOG_TRACE("Resetting point-to-point thread-local cache"); - sendEndpoints.clear(); recvEndpoints.clear(); clients.clear(); diff --git a/src/transport/PointToPointServer.cpp b/src/transport/PointToPointServer.cpp index 7791f0b0f..32db8cac7 100644 --- a/src/transport/PointToPointServer.cpp +++ b/src/transport/PointToPointServer.cpp @@ -16,7 +16,7 @@ PointToPointServer::PointToPointServer() POINT_TO_POINT_SYNC_PORT, POINT_TO_POINT_INPROC_LABEL, faabric::util::getSystemConfig().pointToPointServerThreads) - , reg(getPointToPointBroker()) + , broker(getPointToPointBroker()) {} void PointToPointServer::doAsyncRecv(int header, @@ -28,11 +28,11 @@ void PointToPointServer::doAsyncRecv(int header, PARSE_MSG(faabric::PointToPointMessage, buffer, bufferSize) // Send the message locally to the downstream socket - reg.sendMessage(msg.groupid(), - msg.sendidx(), - msg.recvidx(), - BYTES_CONST(msg.data().c_str()), - msg.data().size()); + broker.sendMessage(msg.groupid(), + msg.sendidx(), + msg.recvidx(), + BYTES_CONST(msg.data().c_str()), + msg.data().size()); break; } case faabric::transport::PointToPointCall::LOCK_GROUP: { @@ -85,7 +85,7 @@ std::unique_ptr PointToPointServer::doRecvMappings( SPDLOG_DEBUG("Receiving {} point-to-point mappings", decision.nFunctions); - reg.setUpLocalMappingsFromSchedulingDecision(decision); + broker.setUpLocalMappingsFromSchedulingDecision(decision); return std::make_unique(); } @@ -121,6 +121,7 @@ void PointToPointServer::recvGroupUnlock(const uint8_t* buffer, void PointToPointServer::onWorkerStop() { // Clear any thread-local cached sockets - reg.resetThreadLocalCache(); + broker.resetThreadLocalCache(); + broker.clear(); } } diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index d155313b4..7c7a40a29 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -13,6 +13,7 @@ faabric_lib(util gids.cpp json.cpp latch.cpp + locks.cpp logging.cpp memory.cpp network.cpp diff --git a/src/util/locks.cpp b/src/util/locks.cpp new file mode 100644 index 000000000..14064effa --- /dev/null +++ b/src/util/locks.cpp @@ -0,0 +1,33 @@ +#include + +namespace faabric::util { + +FlagWaiter::FlagWaiter(int timeoutMsIn) + : timeoutMs(timeoutMsIn) +{} + +void FlagWaiter::waitOnFlag() +{ + // Check + if (flag.load()) { + return; + } + + // Wait for flag to be set + UniqueLock lock(flagMx); + if (!cv.wait_for(lock, std::chrono::milliseconds(timeoutMs), [this] { + return flag.load(); + })) { + + SPDLOG_ERROR("Timed out waiting for flag"); + throw std::runtime_error("Timed out waiting for flag"); + } +} + +void FlagWaiter::setFlag(bool value) +{ + UniqueLock lock(flagMx); + flag.store(value); + cv.notify_all(); +} +} diff --git a/src/util/memory.cpp b/src/util/memory.cpp index 3855aac23..de3b9c95f 100644 --- a/src/util/memory.cpp +++ b/src/util/memory.cpp @@ -82,6 +82,8 @@ AlignedChunk getPageAlignedChunk(long offset, long length) void resetDirtyTracking() { + SPDLOG_DEBUG("Resetting dirty tracking"); + FILE* fd = fopen(CLEAR_REFS, "w"); if (fd == nullptr) { SPDLOG_ERROR("Could not open clear_refs ({})", strerror(errno)); diff --git a/src/util/snapshot.cpp b/src/util/snapshot.cpp index f49e99a55..595061220 100644 --- a/src/util/snapshot.cpp +++ b/src/util/snapshot.cpp @@ -26,7 +26,11 @@ std::vector SnapshotData::getDirtyPages() std::vector diffs; for (int i : dirtyPageNumbers) { uint32_t offset = i * HOST_PAGE_SIZE; - diffs.emplace_back(offset, data + offset, HOST_PAGE_SIZE); + diffs.emplace_back(SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + offset, + data + offset, + HOST_PAGE_SIZE); } SPDLOG_DEBUG("Snapshot has {}/{} dirty pages", diffs.size(), nPages); @@ -37,286 +41,127 @@ std::vector SnapshotData::getDirtyPages() std::vector SnapshotData::getChangeDiffs(const uint8_t* updated, size_t updatedSize) { - // Work out which pages have changed - size_t nThisPages = getRequiredHostPages(size); - std::vector dirtyPageNumbers = - getDirtyPageNumbers(updated, nThisPages); + std::vector diffs; + if (mergeRegions.empty()) { + SPDLOG_DEBUG("No merge regions set, thus no diffs"); + return diffs; + } - SPDLOG_TRACE("Diffing {} pages with {} changed pages and {} merge regions", - nThisPages, - dirtyPageNumbers.size(), - mergeRegions.size()); - - for (auto& m : mergeRegions) { - SPDLOG_TRACE("{} {} merge region at {} {}-{}", - snapshotDataTypeStr(m.second.dataType), - snapshotMergeOpStr(m.second.operation), - m.first, - m.second.offset, - m.second.offset + m.second.length); + for (const auto& mr : mergeRegions) { + SPDLOG_TRACE("Merge region {} {} at {}-{}", + snapshotDataTypeStr(mr.second.dataType), + snapshotMergeOpStr(mr.second.operation), + mr.second.offset, + mr.second.offset + mr.second.length); } - // Get iterator over merge regions + // Work out which pages have changed (these will be sorted) + size_t nThisPages = getRequiredHostPages(updatedSize); + std::vector dirtyPageNumbers = + getDirtyPageNumbers(updated, nThisPages); + + // Iterate through each dirty page, work out if there's an overlapping merge + // region, tell that region to add their diffs to the list std::map::iterator mergeIt = mergeRegions.begin(); - // Get byte-wise diffs _within_ the dirty pages - // - // NOTE - if raw diffs cover page boundaries, they will be split into - // multiple diffs, each of which is page-aligned. - // We can be relatively confident that variables will be page-aligned so - // this shouldn't be a problem. - // - // Merge regions support crossing page boundaries. - // - // For each byte we encounter have the following possible scenarios: - // - // 1. the byte is dirty, and is the start of a new diff - // 2. the byte is dirty, but the byte before was also dirty, so we - // are inside a diff - // 3. the byte is not dirty but the previous one was, so we've reached the - // end of a diff - // 4. the last byte of the page is dirty, so we've also come to the end of - // a diff - // 5. the byte is dirty, but is within a special merge region, in which - // case we need to add a diff for that whole region, then skip - // to the next byte after that region - std::vector diffs; for (int i : dirtyPageNumbers) { - int pageOffset = i * HOST_PAGE_SIZE; - - bool diffInProgress = false; - int diffStart = 0; - int offset = pageOffset; - for (int b = 0; b < HOST_PAGE_SIZE; b++) { - offset = pageOffset + b; - bool isDirtyByte = *(data + offset) != *(updated + offset); - - // Skip any merge regions we've passed - while (mergeIt != mergeRegions.end() && - offset >= - (mergeIt->second.offset + mergeIt->second.length)) { - SnapshotMergeRegion region = mergeIt->second; - SPDLOG_TRACE("At offset {}, past region {} {} {}-{}", - offset, - snapshotDataTypeStr(region.dataType), - snapshotMergeOpStr(region.operation), - region.offset, - region.offset + region.length); - - ++mergeIt; - } - - // Check if we're in a merge region - bool isInMergeRegion = - mergeIt != mergeRegions.end() && - offset >= mergeIt->second.offset && - offset < (mergeIt->second.offset + mergeIt->second.length); - - if (isDirtyByte && isInMergeRegion) { - // If we've entered a merge region with a diff in progress, we - // need to close it off - if (diffInProgress) { - diffs.emplace_back( - diffStart, updated + diffStart, offset - diffStart); - - SPDLOG_TRACE( - "Finished {} {} diff between {}-{} before merge region", - snapshotDataTypeStr(diffs.back().dataType), - snapshotMergeOpStr(diffs.back().operation), - diffs.back().offset, - diffs.back().offset + diffs.back().size); - - diffInProgress = false; - } - - SnapshotMergeRegion region = mergeIt->second; - - // Set up the diff - const uint8_t* updatedValue = updated + region.offset; - const uint8_t* originalValue = data + region.offset; - - SnapshotDiff diff(region.offset, updatedValue, region.length); - diff.dataType = region.dataType; - diff.operation = region.operation; - - // Modify diff data for certain operations - switch (region.dataType) { - case (SnapshotDataType::Int): { - int originalInt = - *(reinterpret_cast(originalValue)); - int updatedInt = - *(reinterpret_cast(updatedValue)); - - switch (region.operation) { - case (SnapshotMergeOperation::Sum): { - // Sums must send the value to be _added_, and - // not the final result - updatedInt -= originalInt; - break; - } - case (SnapshotMergeOperation::Subtract): { - // Subtractions must send the value to be - // subtracted, not the result - updatedInt = originalInt - updatedInt; - break; - } - case (SnapshotMergeOperation::Product): { - // Products must send the value to be - // multiplied, not the result - updatedInt /= originalInt; - break; - } - case (SnapshotMergeOperation::Max): - case (SnapshotMergeOperation::Min): - // Min and max don't need to change - break; - default: { - SPDLOG_ERROR( - "Unhandled integer merge operation: {}", - region.operation); - throw std::runtime_error( - "Unhandled integer merge operation"); - } - } - - // TODO - somehow avoid casting away the const here? - // Modify the memory in-place here - std::memcpy((uint8_t*)updatedValue, - BYTES(&updatedInt), - sizeof(int32_t)); - - break; - } - case (SnapshotDataType::Raw): { - switch (region.operation) { - case (SnapshotMergeOperation::Ignore): { - break; - } - case (SnapshotMergeOperation::Overwrite): { - // Default behaviour - break; - } - default: { - SPDLOG_ERROR( - "Unhandled raw merge operation: {}", - region.operation); - throw std::runtime_error( - "Unhandled raw merge operation"); - } - } + int pageStart = i * HOST_PAGE_SIZE; + int pageEnd = pageStart + HOST_PAGE_SIZE; - break; - } - default: { - SPDLOG_ERROR("Merge region for unhandled data type: {}", - region.dataType); - throw std::runtime_error( - "Merge region for unhandled data type"); - } - } + SPDLOG_TRACE("Checking dirty page {} at {}-{}", i, pageStart, pageEnd); - SPDLOG_TRACE("Diff at {} falls in {} {} merge region {}-{}", - pageOffset + b, - snapshotDataTypeStr(region.dataType), - snapshotMergeOpStr(region.operation), - region.offset, - region.offset + region.length); + // Skip any merge regions we've passed + while (mergeIt != mergeRegions.end() && + (mergeIt->second.offset < pageStart)) { + SPDLOG_TRACE("Gone past {} {} merge region at {}-{}", + snapshotDataTypeStr(mergeIt->second.dataType), + snapshotMergeOpStr(mergeIt->second.operation), + mergeIt->second.offset, + mergeIt->second.offset + mergeIt->second.length); - // Add the diff to the list - if (diff.operation != SnapshotMergeOperation::Ignore) { - diffs.emplace_back(diff); - } - - // Work out the offset where this region ends - int regionEndOffset = - (region.offset - pageOffset) + region.length; - - if (regionEndOffset < HOST_PAGE_SIZE) { - // Skip over this region, still more offsets left in this - // page - SPDLOG_TRACE( - "{} {} merge region {}-{} finished. Skipping to {}", - snapshotDataTypeStr(region.dataType), - snapshotMergeOpStr(region.operation), - region.offset, - region.offset + region.length, - pageOffset + regionEndOffset); - - // Bump the loop variable to the end of this region (note - // that the loop itself will increment onto the next). - b = regionEndOffset - 1; - } else { - // Merge region extends over this page, move onto next - SPDLOG_TRACE( - "{} {} merge region {}-{} over page boundary {} ({}-{})", - snapshotDataTypeStr(region.dataType), - snapshotMergeOpStr(region.operation), - region.offset, - region.offset + region.length, - i, - pageOffset, - pageOffset + HOST_PAGE_SIZE); - - break; - } - } else if (isDirtyByte && !diffInProgress) { - // Diff starts here if it's different and diff not in progress - diffInProgress = true; - diffStart = offset; - - SPDLOG_TRACE("Started Raw Overwrite diff at {}", diffStart); - } else if (!isDirtyByte && diffInProgress) { - // Diff ends if it's not different and diff is in progress - diffInProgress = false; - diffs.emplace_back( - diffStart, updated + diffStart, offset - diffStart); - - SPDLOG_TRACE("Finished {} {} diff between {}-{}", - snapshotDataTypeStr(diffs.back().dataType), - snapshotMergeOpStr(diffs.back().operation), - diffs.back().offset, - diffs.back().offset + diffs.back().size); - } + ++mergeIt; } - // If we've reached the end of this page with a diff in progress, we - // need to close it off - if (diffInProgress) { - offset++; + if (mergeIt == mergeRegions.end()) { + // Done if no more merge regions left + SPDLOG_TRACE("No more merge regions left"); + break; + } - diffs.emplace_back( - diffStart, updated + diffStart, offset - diffStart); + // For each merge region that overlaps this dirty page, get it to add + // its diffs, and move onto the next one + // TODO - make this more efficient by passing in dirty pages to merge + // regions so that they avoid unnecessary work if they're large. + while (mergeIt != mergeRegions.end() && + (mergeIt->second.offset >= pageStart && + mergeIt->second.offset < pageEnd)) { + + uint8_t* original = data; + + // If we're outside the range of the original data, pass a nullptr + if (mergeIt->second.offset > size) { + SPDLOG_TRACE( + "Checking {} {} merge region {}-{} outside original snapshot", + snapshotDataTypeStr(mergeIt->second.dataType), + snapshotMergeOpStr(mergeIt->second.operation), + mergeIt->second.offset, + mergeIt->second.offset + mergeIt->second.length); + + original = nullptr; + } - SPDLOG_TRACE("Found {} {} diff between {}-{} at end of page", - snapshotDataTypeStr(diffs.back().dataType), - snapshotMergeOpStr(diffs.back().operation), - diffs.back().offset, - diffs.back().offset + diffs.back().size); + mergeIt->second.addDiffs(diffs, original, updated); + mergeIt++; } } - // If comparison has more pages than the original, add another diff - // containing all the new pages - if (updatedSize > size) { - diffs.emplace_back(size, updated + size, updatedSize - size); - } - return diffs; } void SnapshotData::addMergeRegion(uint32_t offset, size_t length, SnapshotDataType dataType, - SnapshotMergeOperation operation) + SnapshotMergeOperation operation, + bool overwrite) { SnapshotMergeRegion region{ .offset = offset, .length = length, .dataType = dataType, .operation = operation }; + // Locking as this may be called in bursts by multiple threads faabric::util::UniqueLock lock(snapMx); - mergeRegions[offset] = region; + + if (mergeRegions.find(region.offset) != mergeRegions.end()) { + if (!overwrite) { + SPDLOG_ERROR("Attempting to overwrite existing merge region at {} " + "with {} {} at {}-{}", + region.offset, + snapshotDataTypeStr(dataType), + snapshotMergeOpStr(operation), + region.offset, + region.offset + length); + + throw std::runtime_error("Not able to overwrite merge region"); + } + + SPDLOG_TRACE( + "Overwriting existing merge region at {} with {} {} at {}-{}", + region.offset, + snapshotDataTypeStr(dataType), + snapshotMergeOpStr(operation), + region.offset, + region.offset + length); + } else { + SPDLOG_DEBUG("Adding new {} {} merge region at {}-{}", + snapshotDataTypeStr(dataType), + snapshotMergeOpStr(operation), + region.offset, + region.offset + length); + } + + mergeRegions[region.offset] = region; } std::string snapshotDataTypeStr(SnapshotDataType dt) @@ -338,9 +183,6 @@ std::string snapshotDataTypeStr(SnapshotDataType dt) std::string snapshotMergeOpStr(SnapshotMergeOperation op) { switch (op) { - case (SnapshotMergeOperation::Ignore): { - return "Ignore"; - } case (SnapshotMergeOperation::Max): { return "Max"; } @@ -365,4 +207,159 @@ std::string snapshotMergeOpStr(SnapshotMergeOperation op) } } } + +void SnapshotMergeRegion::addDiffs(std::vector& diffs, + const uint8_t* original, + const uint8_t* updated) +{ + SPDLOG_TRACE("Checking for {} {} merge region at {}-{}", + snapshotDataTypeStr(dataType), + snapshotMergeOpStr(operation), + offset, + offset + length); + + switch (dataType) { + case (SnapshotDataType::Int): { + // Check if the value has changed + const uint8_t* updatedValue = updated + offset; + int updatedInt = *(reinterpret_cast(updatedValue)); + + if (original == nullptr) { + throw std::runtime_error( + "Do not support int operations outside original snapshot"); + } + + const uint8_t* originalValue = original + offset; + int originalInt = *(reinterpret_cast(originalValue)); + + // Skip if no change + if (originalInt == updatedInt) { + return; + } + + // Add the diff + diffs.emplace_back( + dataType, operation, offset, updatedValue, length); + + SPDLOG_TRACE("Adding {} {} diff at {}-{}", + snapshotDataTypeStr(dataType), + snapshotMergeOpStr(operation), + offset, + offset + length); + + // Potentially modify the original in place depending on the + // operation + switch (operation) { + case (SnapshotMergeOperation::Sum): { + // Sums must send the value to be _added_, and + // not the final result + updatedInt -= originalInt; + break; + } + case (SnapshotMergeOperation::Subtract): { + // Subtractions must send the value to be + // subtracted, not the result + updatedInt = originalInt - updatedInt; + break; + } + case (SnapshotMergeOperation::Product): { + // Products must send the value to be + // multiplied, not the result + updatedInt /= originalInt; + break; + } + case (SnapshotMergeOperation::Max): + case (SnapshotMergeOperation::Min): + // Min and max don't need to change + break; + default: { + SPDLOG_ERROR("Unhandled integer merge operation: {}", + operation); + throw std::runtime_error( + "Unhandled integer merge operation"); + } + } + + // TODO - somehow avoid casting away the const here? + // Modify the memory in-place here + std::memcpy( + (uint8_t*)updatedValue, BYTES(&updatedInt), sizeof(int32_t)); + + break; + } + case (SnapshotDataType::Raw): { + switch (operation) { + case (SnapshotMergeOperation::Overwrite): { + // Add subsections of diffs only for the bytes that + // have changed + bool diffInProgress = false; + int diffStart = 0; + for (int b = offset; b <= offset + length; b++) { + bool isDirtyByte = false; + + if (original == nullptr) { + isDirtyByte = true; + } else { + isDirtyByte = *(original + b) != *(updated + b); + } + + SPDLOG_TRACE("BYTE {} dirty {}", b, isDirtyByte); + if (isDirtyByte && !diffInProgress) { + // Diff starts here if it's different and diff + // not in progress + diffInProgress = true; + diffStart = b; + } else if (!isDirtyByte && diffInProgress) { + // Diff ends if it's not different and diff is + // in progress + int diffLength = b - diffStart; + SPDLOG_TRACE("Adding {} {} diff at {}-{}", + snapshotDataTypeStr(dataType), + snapshotMergeOpStr(operation), + diffStart, + diffStart + diffLength); + + diffInProgress = false; + diffs.emplace_back(dataType, + operation, + diffStart, + updated + diffStart, + diffLength); + } + } + + // If we've reached the end of this region with a diff + // in progress, we need to close it off + if (diffInProgress) { + int finalDiffLength = (offset + length) - diffStart + 1; + SPDLOG_TRACE( + "Adding {} {} diff at {}-{} (end of region)", + snapshotDataTypeStr(dataType), + snapshotMergeOpStr(operation), + diffStart, + diffStart + finalDiffLength); + + diffs.emplace_back(dataType, + operation, + diffStart, + updated + diffStart, + finalDiffLength); + } + break; + } + default: { + SPDLOG_ERROR("Unhandled raw merge operation: {}", + operation); + throw std::runtime_error("Unhandled raw merge operation"); + } + } + + break; + } + default: { + SPDLOG_ERROR("Merge region for unhandled data type: {}", dataType); + throw std::runtime_error("Merge region for unhandled data type"); + } + } +} } diff --git a/tests/dist/scheduler/functions.cpp b/tests/dist/scheduler/functions.cpp index 28436f7f4..a91ad0405 100644 --- a/tests/dist/scheduler/functions.cpp +++ b/tests/dist/scheduler/functions.cpp @@ -14,6 +14,7 @@ #include #include #include +#include namespace tests { @@ -60,20 +61,31 @@ int handleFakeDiffsFunction(faabric::scheduler::Executor* exec, { faabric::Message& msg = req->mutable_messages()->at(msgIdx); - faabric::util::SnapshotData snap = exec->snapshot(); - std::string msgInput = msg.inputdata(); std::string snapshotKey = msg.snapshotkey(); - // Modify the executor's memory + faabric::snapshot::SnapshotRegistry& reg = + faabric::snapshot::getSnapshotRegistry(); + + faabric::util::SnapshotData& originalSnap = reg.getSnapshot(snapshotKey); + faabric::util::SnapshotData updatedSnap = exec->snapshot(); + + // Add a single merge region to catch both diffs + int offsetA = 10; + int offsetB = 100; std::vector inputBytes = faabric::util::stringToBytes(msgInput); - std::vector keyBytes = faabric::util::stringToBytes(snapshotKey); - uint32_t offsetA = 10; - uint32_t offsetB = 100; + originalSnap.addMergeRegion( + 0, + offsetB + inputBytes.size() + 10, + faabric::util::SnapshotDataType::Raw, + faabric::util::SnapshotMergeOperation::Overwrite); - std::memcpy(snap.data + offsetA, keyBytes.data(), keyBytes.size()); - std::memcpy(snap.data + offsetB, inputBytes.data(), inputBytes.size()); + // Modify the executor's memory + std::vector keyBytes = faabric::util::stringToBytes(snapshotKey); + std::memcpy(updatedSnap.data + offsetA, keyBytes.data(), keyBytes.size()); + std::memcpy( + updatedSnap.data + offsetB, inputBytes.data(), inputBytes.size()); return 123; } @@ -89,6 +101,9 @@ int handleFakeDiffsThreadedFunction( std::string snapshotKey = "fake-diffs-threaded-snap"; std::string msgInput = msg.inputdata(); + faabric::snapshot::SnapshotRegistry& reg = + faabric::snapshot::getSnapshotRegistry(); + // This function creates a snapshot, then spawns some child threads that // will modify the shared memory. It then awaits the results and checks that // the modifications are synced back to the original host. @@ -104,8 +119,6 @@ int handleFakeDiffsThreadedFunction( snap.data = snapMemory; snap.size = snapSize; - faabric::snapshot::SnapshotRegistry& reg = - faabric::snapshot::getSnapshotRegistry(); reg.takeSnapshot(snapshotKey, snap); auto req = @@ -121,7 +134,7 @@ int handleFakeDiffsThreadedFunction( // Make a small modification to a page that will also be edited by // the child thread to make sure it's not overwritten std::vector localChange(3, i); - uint32_t offset = 2 * i * faabric::util::HOST_PAGE_SIZE; + int offset = 2 * i * faabric::util::HOST_PAGE_SIZE; std::memcpy( snapMemory + offset, localChange.data(), localChange.size()); } @@ -157,7 +170,7 @@ int handleFakeDiffsThreadedFunction( for (int i = 0; i < nThreads; i++) { // Check local modifications std::vector expectedLocal(3, i); - uint32_t localOffset = 2 * i * faabric::util::HOST_PAGE_SIZE; + int localOffset = 2 * i * faabric::util::HOST_PAGE_SIZE; std::vector actualLocal(snapMemory + localOffset, snapMemory + localOffset + expectedLocal.size()); @@ -168,7 +181,7 @@ int handleFakeDiffsThreadedFunction( } // Check remote modifications - uint32_t offset = 2 * i * faabric::util::HOST_PAGE_SIZE + 10; + int offset = 2 * i * faabric::util::HOST_PAGE_SIZE + 10; std::string expectedData("thread_" + std::to_string(i)); auto* charPtr = reinterpret_cast(snapMemory + offset); std::string actual(charPtr); @@ -185,15 +198,33 @@ int handleFakeDiffsThreadedFunction( } } else { + // This is the code that will be executed by the remote threads. + // Add a merge region to catch the modification int idx = msg.appidx(); - uint32_t offset = 2 * idx * faabric::util::HOST_PAGE_SIZE + 10; - // Modify the executor's memory + int regionOffset = 2 * idx * faabric::util::HOST_PAGE_SIZE; + int changeOffset = regionOffset + 10; + + // Get the input data std::vector inputBytes = faabric::util::stringToBytes(msgInput); - faabric::util::SnapshotData snap = exec->snapshot(); - std::memcpy(snap.data + offset, inputBytes.data(), inputBytes.size()); + faabric::util::SnapshotData& originalSnap = + reg.getSnapshot(snapshotKey); + faabric::util::SnapshotData updatedSnap = exec->snapshot(); + + // Make sure it's captured by the region + int regionLength = 20 + inputBytes.size(); + originalSnap.addMergeRegion( + regionOffset, + regionLength, + faabric::util::SnapshotDataType::Raw, + faabric::util::SnapshotMergeOperation::Overwrite); + + // Now modify the memory + std::memcpy(updatedSnap.data + changeOffset, + inputBytes.data(), + inputBytes.size()); return 0; } diff --git a/tests/dist/transport/functions.cpp b/tests/dist/transport/functions.cpp index 559fc75cd..f145e7d14 100644 --- a/tests/dist/transport/functions.cpp +++ b/tests/dist/transport/functions.cpp @@ -1,6 +1,10 @@ #include #include "DistTestExecutor.h" +#include "faabric/scheduler/Scheduler.h" +#include "faabric/util/func.h" +#include "faabric/util/gids.h" +#include "faabric/util/scheduling.h" #include "faabric_utils.h" #include "init.h" @@ -9,6 +13,7 @@ #include #include +using namespace faabric::transport; using namespace faabric::util; namespace tests { @@ -57,6 +62,94 @@ int handlePointToPointFunction( return 0; } +int handleDistributedLock(faabric::scheduler::Executor* exec, + int threadPoolIdx, + int msgIdx, + std::shared_ptr req) +{ + // We need sufficient concurrency here to show up bugs every time + int nWorkers = 10; + int nLoops = 30; + + std::string sharedStateKey = "dist-lock-test"; + + faabric::Message& msg = req->mutable_messages()->at(msgIdx); + + faabric::state::State& state = state::getGlobalState(); + std::shared_ptr stateKv = + state.getKV(msg.user(), sharedStateKey, sizeof(int32_t)); + + if (msg.function() == "lock") { + int initialValue = 0; + int groupId = faabric::util::generateGid(); + + stateKv->set(BYTES(&initialValue)); + + std::shared_ptr nestedReq = + faabric::util::batchExecFactory("ptp", "lock-worker", nWorkers); + for (int i = 0; i < nWorkers; i++) { + faabric::Message& m = nestedReq->mutable_messages()->at(i); + m.set_groupid(groupId); + m.set_groupidx(i); + } + + faabric::scheduler::Scheduler& sch = faabric::scheduler::getScheduler(); + faabric::util::SchedulingDecision decision = + sch.callFunctions(nestedReq); + + // Await results + bool success = true; + for (int msgId : decision.messageIds) { + faabric::Message res = sch.getFunctionResult(msgId, 30000); + if (res.returnvalue() != 0) { + success = false; + } + } + + int finalValue = *(int*)stateKv->get(); + int expectedValue = nWorkers * nLoops; + if (finalValue != expectedValue) { + SPDLOG_ERROR("Distributed lock test failed: {} != {}", + finalValue, + expectedValue); + success = false; + } else { + SPDLOG_ERROR("Distributed lock succeeded, result {}", finalValue); + } + + return success ? 0 : 1; + } + + // Here we want to do something that will mess up if the locking isn't + // working properly, so we perform incremental updates to a bit of shared + // state using a global lock in a tight loop. + std::shared_ptr group = + faabric::transport::PointToPointGroup::getGroup(msg.groupid()); + + for (int i = 0; i < nLoops; i++) { + // Get the lock + group->lock(msg.groupidx(), false); + + // Pull the value + stateKv->pull(); + int* originalValue = (int*)stateKv->get(); + + // Short sleep + int sleepTimeMs = std::rand() % 50; + SLEEP_MS(sleepTimeMs); + + // Increment and push + int newValue = *originalValue + 1; + stateKv->set(BYTES(&newValue)); + stateKv->pushFull(); + + // Unlock + group->unlock(msg.groupidx(), false); + } + + return 0; +} + class DistributedCoordinationTestRunner { public: @@ -322,6 +415,11 @@ void registerTransportTestFunctions() registerDistTestExecutorCallback( "ptp", "barrier-worker", handleDistributedBarrierWorker); + registerDistTestExecutorCallback("ptp", "lock", handleDistributedLock); + + registerDistTestExecutorCallback( + "ptp", "lock-worker", handleDistributedLock); + registerDistTestExecutorCallback("ptp", "notify", handleDistributedNotify); registerDistTestExecutorCallback( diff --git a/tests/dist/transport/test_coordination.cpp b/tests/dist/transport/test_coordination.cpp new file mode 100644 index 000000000..f25608883 --- /dev/null +++ b/tests/dist/transport/test_coordination.cpp @@ -0,0 +1,34 @@ +#include + +#include "faabric_utils.h" +#include "fixtures.h" +#include "init.h" + +#include +#include +#include +#include +#include +#include + +namespace tests { + +TEST_CASE_METHOD(DistTestsFixture, "Test distributed lock", "[ptp][transport]") +{ + // Set up this host's resources + int nLocalSlots = 5; + faabric::HostResources res; + res.set_slots(nLocalSlots); + sch.setThisHostResources(res); + + // Set up the request + std::shared_ptr req = + faabric::util::batchExecFactory("ptp", "lock", 1); + + sch.callFunctions(req); + + faabric::Message& m = req->mutable_messages()->at(0); + faabric::Message result = sch.getFunctionResult(m.id(), 30000); + REQUIRE(result.returnvalue() == 0); +} +} diff --git a/tests/test/scheduler/test_executor.cpp b/tests/test/scheduler/test_executor.cpp index 4dbf2983b..573a08522 100644 --- a/tests/test/scheduler/test_executor.cpp +++ b/tests/test/scheduler/test_executor.cpp @@ -19,6 +19,7 @@ #include using namespace faabric::scheduler; +using namespace faabric::util; namespace tests { @@ -203,13 +204,30 @@ class TestExecutor final : public Executor if (msg.function() == "snap-check") { // Modify a page of the dummy memory uint8_t pageIdx = threadPoolIdx; - SPDLOG_DEBUG("TestExecutor modifying page {} of memory", pageIdx); - uint8_t* offsetPtr = - dummyMemory + (pageIdx * faabric::util::HOST_PAGE_SIZE); - std::vector data = { pageIdx, - (uint8_t)(pageIdx + 1), - (uint8_t)(pageIdx + 2) }; + faabric::util::SnapshotData& snapData = + faabric::snapshot::getSnapshotRegistry().getSnapshot( + msg.snapshotkey()); + + // Avoid writing a zero here as the memory is already zeroed hence + // it's not a change + std::vector data = { (uint8_t)(pageIdx + 1), + (uint8_t)(pageIdx + 2), + (uint8_t)(pageIdx + 3) }; + + // Set up a merge region that should catch the diff + size_t offset = (pageIdx * faabric::util::HOST_PAGE_SIZE); + snapData.addMergeRegion(offset, + data.size() + 10, + SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite); + + SPDLOG_DEBUG("TestExecutor modifying page {} of memory ({}-{})", + pageIdx, + offset, + offset + data.size()); + + uint8_t* offsetPtr = dummyMemory + offset; std::memcpy(offsetPtr, data.data(), data.size()); } @@ -430,94 +448,6 @@ TEST_CASE_METHOD(TestExecutorFixture, } } -TEST_CASE_METHOD(TestExecutorFixture, - "Test executing remote chained threads", - "[executor]") -{ - faabric::util::setMockMode(true); - - std::string thisHost = conf.endpointHost; - - // Add other host to available hosts - std::string otherHost = "other"; - sch.addHostToGlobalSet(otherHost); - - // Make sure we have only enough resources to execute the initial function - faabric::HostResources res; - res.set_slots(1); - sch.setThisHostResources(res); - - // Set up other host to have some resources - faabric::HostResources resOther; - resOther.set_slots(20); - faabric::scheduler::queueResourceResponse(otherHost, resOther); - - // Background thread to execute main function and await results - int nThreads = 8; - auto latch = faabric::util::Latch::create(2); - std::thread t([&latch, nThreads] { - std::shared_ptr req = - faabric::util::batchExecFactory("dummy", "thread-check", 1); - faabric::Message& msg = req->mutable_messages()->at(0); - msg.set_inputdata(std::to_string(nThreads)); - - auto& sch = faabric::scheduler::getScheduler(); - sch.callFunctions(req, false); - - latch->wait(); - - faabric::Message res = sch.getFunctionResult(msg.id(), 2000); - assert(res.returnvalue() == 0); - }); - - // Wait for the remote thread request to have been submitted - auto reqs = faabric::scheduler::getBatchRequests(); - REQUIRE_RETRY(reqs = faabric::scheduler::getBatchRequests(), - reqs.size() == 1); - std::string actualHost = reqs.at(0).first; - REQUIRE(actualHost == otherHost); - - std::shared_ptr distReq = reqs.at(0).second; - REQUIRE(distReq->messages().size() == nThreads); - faabric::Message firstMsg = distReq->messages().at(0); - - // Check restore hasn't been called (as we're the master) - REQUIRE(restoreCount == 0); - - // Check the snapshot has been pushed to the other host - auto snapPushes = faabric::snapshot::getSnapshotPushes(); - REQUIRE(snapPushes.size() == 1); - REQUIRE(snapPushes.at(0).first == otherHost); - - // Now execute request on this host as if we were on the other host - conf.endpointHost = otherHost; - sch.callFunctions(distReq, true); - - // Check restore has been called as we're no longer master - REQUIRE(restoreCount == 1); - - // Wait for the results - auto results = faabric::snapshot::getThreadResults(); - REQUIRE_RETRY(results = faabric::snapshot::getThreadResults(), - results.size() == nThreads); - - // Reset the host config for this host - conf.endpointHost = thisHost; - - // Process the thread results - for (auto& r : results) { - REQUIRE(r.first == thisHost); - auto args = r.second; - sch.setThreadResultLocally(args.first, args.second); - } - - // Rejoin the background thread - latch->wait(); - if (t.joinable()) { - t.join(); - } -} - TEST_CASE_METHOD(TestExecutorFixture, "Test thread results returned on non-master", "[executor]") @@ -718,7 +648,7 @@ TEST_CASE_METHOD(TestExecutorFixture, faabric::util::setMockMode(true); std::string otherHost = "other"; - // Set up a load of messages executing with a different master host + // Set up some messages executing with a different master host std::vector messageIds; for (int i = 0; i < nThreads; i++) { faabric::Message& msg = req->mutable_messages()->at(i); @@ -750,13 +680,11 @@ TEST_CASE_METHOD(TestExecutorFixture, for (int i = 0; i < diffList.size(); i++) { // Check offset and data (according to logic defined in the dummy // executor) - uint8_t pageIndex = i + 1; - REQUIRE(diffList.at(i).offset == - pageIndex * faabric::util::HOST_PAGE_SIZE); + REQUIRE(diffList.at(i).offset == i * faabric::util::HOST_PAGE_SIZE); - std::vector expected = { pageIndex, - (uint8_t)(pageIndex + 1), - (uint8_t)(pageIndex + 2) }; + std::vector expected = { (uint8_t)(i + 1), + (uint8_t)(i + 2), + (uint8_t)(i + 3) }; std::vector actual(diffList.at(i).data, diffList.at(i).data + 3); @@ -878,6 +806,8 @@ TEST_CASE_METHOD(TestExecutorFixture, { faabric::util::setMockMode(true); + conf.overrideCpuCount = 4; + std::string hostOverride = conf.endpointHost; int nMessages = 1; faabric::BatchExecuteRequest::BatchExecuteType requestType = @@ -918,8 +848,6 @@ TEST_CASE_METHOD(TestExecutorFixture, // As we're faking a non-master execution results will be sent back to // the fake master so we can't wait on them, thus have to sleep - SLEEP_MS(1000); - - REQUIRE(resetCount == expectedResets); + REQUIRE_RETRY({}, resetCount == expectedResets); } } diff --git a/tests/test/scheduler/test_scheduler.cpp b/tests/test/scheduler/test_scheduler.cpp index edf9f83b8..69c683119 100644 --- a/tests/test/scheduler/test_scheduler.cpp +++ b/tests/test/scheduler/test_scheduler.cpp @@ -191,6 +191,10 @@ TEST_CASE_METHOD(SlowExecutorFixture, "Test batch scheduling", "[scheduler]") int32_t expectedSubType; std::string expectedContextData; + int thisCores = 5; + faabric::util::SystemConfig& conf = faabric::util::getSystemConfig(); + conf.overrideCpuCount = thisCores; + SECTION("Threads") { execMode = faabric::BatchExecuteRequest::THREADS; @@ -232,7 +236,7 @@ TEST_CASE_METHOD(SlowExecutorFixture, "Test batch scheduling", "[scheduler]") // Mock everything faabric::util::setMockMode(true); - std::string thisHost = faabric::util::getSystemConfig().endpointHost; + std::string thisHost = conf.endpointHost; // Set up another host std::string otherHost = "beta"; @@ -240,7 +244,6 @@ TEST_CASE_METHOD(SlowExecutorFixture, "Test batch scheduling", "[scheduler]") int nCallsOne = 10; int nCallsTwo = 5; - int thisCores = 5; int otherCores = 11; int nCallsOffloadedOne = nCallsOne - thisCores; @@ -390,6 +393,9 @@ TEST_CASE_METHOD(SlowExecutorFixture, "Test overloaded scheduler", "[scheduler]") { + faabric::util::SystemConfig& conf = faabric::util::getSystemConfig(); + conf.overrideCpuCount = 5; + faabric::util::setMockMode(true); faabric::BatchExecuteRequest::BatchExecuteType execMode; @@ -811,28 +817,36 @@ TEST_CASE_METHOD(SlowExecutorFixture, TEST_CASE_METHOD(DummyExecutorFixture, "Test executor reuse", "[scheduler]") { - faabric::Message msgA = faabric::util::messageFactory("foo", "bar"); - faabric::Message msgB = faabric::util::messageFactory("foo", "bar"); - faabric::Message msgC = faabric::util::messageFactory("foo", "bar"); - faabric::Message msgD = faabric::util::messageFactory("foo", "bar"); + std::shared_ptr reqA = + faabric::util::batchExecFactory("foo", "bar", 2); + std::shared_ptr reqB = + faabric::util::batchExecFactory("foo", "bar", 2); + + faabric::Message& msgA = reqA->mutable_messages()->at(0); + faabric::Message& msgB = reqB->mutable_messages()->at(0); // Execute a couple of functions - sch.callFunction(msgA); - sch.callFunction(msgB); - sch.getFunctionResult(msgA.id(), SHORT_TEST_TIMEOUT_MS); - sch.getFunctionResult(msgB.id(), SHORT_TEST_TIMEOUT_MS); + sch.callFunctions(reqA); + for (const auto& m : reqA->messages()) { + faabric::Message res = + sch.getFunctionResult(m.id(), SHORT_TEST_TIMEOUT_MS); + REQUIRE(res.returnvalue() == 0); + } // Check executor count REQUIRE(sch.getFunctionExecutorCount(msgA) == 2); - // Submit a couple more functions - sch.callFunction(msgC); - sch.callFunction(msgD); - sch.getFunctionResult(msgC.id(), SHORT_TEST_TIMEOUT_MS); - sch.getFunctionResult(msgD.id(), SHORT_TEST_TIMEOUT_MS); + // Execute a couple more functions + sch.callFunctions(reqB); + for (const auto& m : reqB->messages()) { + faabric::Message res = + sch.getFunctionResult(m.id(), SHORT_TEST_TIMEOUT_MS); + REQUIRE(res.returnvalue() == 0); + } // Check executor count is still the same REQUIRE(sch.getFunctionExecutorCount(msgA) == 2); + REQUIRE(sch.getFunctionExecutorCount(msgB) == 2); } TEST_CASE_METHOD(DummyExecutorFixture, diff --git a/tests/test/snapshot/test_snapshot_client_server.cpp b/tests/test/snapshot/test_snapshot_client_server.cpp index fa86b703e..bd0764f6e 100644 --- a/tests/test/snapshot/test_snapshot_client_server.cpp +++ b/tests/test/snapshot/test_snapshot_client_server.cpp @@ -17,6 +17,8 @@ #include #include +using namespace faabric::util; + namespace tests { class SnapshotClientServerFixture @@ -40,8 +42,8 @@ class SnapshotClientServerFixture void setUpFunctionGroup(int appId, int groupId) { - faabric::util::SchedulingDecision decision(appId, groupId); - faabric::Message msg = faabric::util::messageFactory("foo", "bar"); + SchedulingDecision decision(appId, groupId); + faabric::Message msg = messageFactory("foo", "bar"); msg.set_appid(appId); msg.set_groupid(groupId); @@ -71,8 +73,8 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, // Prepare some snapshot data std::string snapKeyA = "foo"; std::string snapKeyB = "bar"; - faabric::util::SnapshotData snapA; - faabric::util::SnapshotData snapB; + SnapshotData snapA; + SnapshotData snapB; size_t snapSizeA = 1024; size_t snapSizeB = 500; snapA.size = snapSizeA; @@ -97,8 +99,8 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, // Check snapshots created in registry REQUIRE(reg.getSnapshotCount() == 2); - const faabric::util::SnapshotData& actualA = reg.getSnapshot(snapKeyA); - const faabric::util::SnapshotData& actualB = reg.getSnapshot(snapKeyB); + const SnapshotData& actualA = reg.getSnapshot(snapKeyA); + const SnapshotData& actualB = reg.getSnapshot(snapKeyB); REQUIRE(actualA.size == snapA.size); REQUIRE(actualB.size == snapB.size); @@ -110,8 +112,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, REQUIRE(actualDataB == dataB); } -void checkDiffsApplied(const uint8_t* snapBase, - std::vector diffs) +void checkDiffsApplied(const uint8_t* snapBase, std::vector diffs) { for (const auto& d : diffs) { std::vector actual(snapBase + d.offset, @@ -127,7 +128,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, "Test push snapshot diffs", "[snapshot]") { - std::string thisHost = faabric::util::getSystemConfig().endpointHost; + std::string thisHost = getSystemConfig().endpointHost; // One request with no group, another with a group we must initialise int appId = 111; @@ -137,25 +138,37 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, setUpFunctionGroup(appId, groupIdB); // Set up a snapshot - std::string snapKey = std::to_string(faabric::util::generateGid()); - faabric::util::SnapshotData snap = takeSnapshot(snapKey, 5, true); + std::string snapKey = std::to_string(generateGid()); + SnapshotData snap = takeSnapshot(snapKey, 5, true); // Set up some diffs std::vector diffDataA1 = { 0, 1, 2, 3 }; std::vector diffDataA2 = { 4, 5, 6 }; std::vector diffDataB = { 7, 7, 8, 8, 8 }; - std::vector diffsA; - std::vector diffsB; + std::vector diffsA; + std::vector diffsB; + + SnapshotDiff diffA1(SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + 5, + diffDataA1.data(), + diffDataA1.size()); + + SnapshotDiff diffA2(SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + 2 * HOST_PAGE_SIZE, + diffDataA2.data(), + diffDataA2.size()); - faabric::util::SnapshotDiff diffA1(5, diffDataA1.data(), diffDataA1.size()); - faabric::util::SnapshotDiff diffA2( - 2 * faabric::util::HOST_PAGE_SIZE, diffDataA2.data(), diffDataA2.size()); diffsA = { diffA1, diffA2 }; cli.pushSnapshotDiffs(snapKey, groupIdA, diffsA); - faabric::util::SnapshotDiff diffB( - 3 * faabric::util::HOST_PAGE_SIZE, diffDataB.data(), diffDataB.size()); + SnapshotDiff diffB(SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + 3 * HOST_PAGE_SIZE, + diffDataB.data(), + diffDataB.size()); diffsB = { diffB }; cli.pushSnapshotDiffs(snapKey, groupIdB, diffsB); @@ -171,12 +184,12 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, "[snapshot]") { // Set up a snapshot - std::string snapKey = std::to_string(faabric::util::generateGid()); - faabric::util::SnapshotData snap = takeSnapshot(snapKey, 5, false); + std::string snapKey = std::to_string(generateGid()); + SnapshotData snap = takeSnapshot(snapKey, 5, false); // Set up a couple of ints in the snapshot int offsetA1 = 5; - int offsetA2 = 2 * faabric::util::HOST_PAGE_SIZE; + int offsetA2 = 2 * HOST_PAGE_SIZE; int baseA1 = 25; int baseA2 = 60; @@ -189,22 +202,26 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, int diffIntA1 = 123; int diffIntA2 = 345; - std::vector intDataA1 = - faabric::util::valueToBytes(diffIntA1); - std::vector intDataA2 = - faabric::util::valueToBytes(diffIntA2); + std::vector intDataA1 = valueToBytes(diffIntA1); + std::vector intDataA2 = valueToBytes(diffIntA2); - std::vector diffs; + std::vector diffs; - faabric::util::SnapshotDiff diffA1( - offsetA1, intDataA1.data(), intDataA1.size()); - diffA1.operation = faabric::util::SnapshotMergeOperation::Sum; - diffA1.dataType = faabric::util::SnapshotDataType::Int; + SnapshotDiff diffA1(SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + offsetA1, + intDataA1.data(), + intDataA1.size()); + diffA1.operation = SnapshotMergeOperation::Sum; + diffA1.dataType = SnapshotDataType::Int; - faabric::util::SnapshotDiff diffA2( - offsetA2, intDataA2.data(), intDataA2.size()); - diffA2.operation = faabric::util::SnapshotMergeOperation::Sum; - diffA2.dataType = faabric::util::SnapshotDataType::Int; + SnapshotDiff diffA2(SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + offsetA2, + intDataA2.data(), + intDataA2.size()); + diffA2.operation = SnapshotMergeOperation::Sum; + diffA2.dataType = SnapshotDataType::Int; diffs = { diffA1, diffA2 }; cli.pushSnapshotDiffs(snapKey, 0, diffs); @@ -221,22 +238,20 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, "[snapshot]") { // Set up a snapshot - std::string snapKey = std::to_string(faabric::util::generateGid()); - faabric::util::SnapshotData snap = takeSnapshot(snapKey, 5, false); + std::string snapKey = std::to_string(generateGid()); + SnapshotData snap = takeSnapshot(snapKey, 5, false); int offset = 5; std::vector originalData; std::vector diffData; std::vector expectedData; - faabric::util::SnapshotMergeOperation operation = - faabric::util::SnapshotMergeOperation::Overwrite; - faabric::util::SnapshotDataType dataType = - faabric::util::SnapshotDataType::Raw; + SnapshotMergeOperation operation = SnapshotMergeOperation::Overwrite; + SnapshotDataType dataType = SnapshotDataType::Raw; SECTION("Integer") { - dataType = faabric::util::SnapshotDataType::Int; + dataType = SnapshotDataType::Int; int original = 0; int diff = 0; int expected = 0; @@ -247,7 +262,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, diff = 10; expected = 110; - operation = faabric::util::SnapshotMergeOperation::Sum; + operation = SnapshotMergeOperation::Sum; } SECTION("Subtract") @@ -256,7 +271,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, diff = 10; expected = 90; - operation = faabric::util::SnapshotMergeOperation::Subtract; + operation = SnapshotMergeOperation::Subtract; } SECTION("Product") @@ -265,7 +280,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, diff = 20; expected = 200; - operation = faabric::util::SnapshotMergeOperation::Product; + operation = SnapshotMergeOperation::Product; } SECTION("Min") @@ -284,7 +299,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, expected = 10; } - operation = faabric::util::SnapshotMergeOperation::Min; + operation = SnapshotMergeOperation::Min; } SECTION("Max") @@ -303,22 +318,26 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, expected = 20; } - operation = faabric::util::SnapshotMergeOperation::Max; + operation = SnapshotMergeOperation::Max; } - originalData = faabric::util::valueToBytes(original); - diffData = faabric::util::valueToBytes(diff); - expectedData = faabric::util::valueToBytes(expected); + originalData = valueToBytes(original); + diffData = valueToBytes(diff); + expectedData = valueToBytes(expected); } // Put original data in place std::memcpy(snap.data + offset, originalData.data(), originalData.size()); - faabric::util::SnapshotDiff diff(offset, diffData.data(), diffData.size()); + SnapshotDiff diff(SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + offset, + diffData.data(), + diffData.size()); diff.operation = operation; diff.dataType = dataType; - std::vector diffs = { diff }; + std::vector diffs = { diff }; cli.pushSnapshotDiffs(snapKey, 0, diffs); // Check data is as expected diff --git a/tests/test/snapshot/test_snapshot_diffs.cpp b/tests/test/snapshot/test_snapshot_diffs.cpp index 2269d8ba7..6b4a056e3 100644 --- a/tests/test/snapshot/test_snapshot_diffs.cpp +++ b/tests/test/snapshot/test_snapshot_diffs.cpp @@ -15,10 +15,39 @@ void checkSnapshotDiff(int offset, SnapshotDiff& actual) { REQUIRE(offset == actual.offset); + REQUIRE(actual.size > 0); + REQUIRE(actual.data != nullptr); + std::vector actualData(actual.data, actual.data + actual.size); REQUIRE(data == actualData); } +TEST_CASE_METHOD(SnapshotTestFixture, + "Test no snapshot diffs if no merge regions", + "[snapshot]") +{ + std::string snapKey = "foobar123"; + int snapPages = 5; + SnapshotData snap = takeSnapshot(snapKey, snapPages, true); + + int sharedMemPages = 8; + size_t sharedMemSize = sharedMemPages * HOST_PAGE_SIZE; + uint8_t* sharedMem = allocatePages(sharedMemPages); + + reg.mapSnapshot(snapKey, sharedMem); + + // Make various changes + sharedMem[0] = 1; + sharedMem[2 * HOST_PAGE_SIZE] = 1; + sharedMem[3 * HOST_PAGE_SIZE + 10] = 1; + sharedMem[8 * HOST_PAGE_SIZE - 20] = 1; + + // Check there are no diffs + std::vector changeDiffs = + snap.getChangeDiffs(sharedMem, sharedMemSize); + REQUIRE(changeDiffs.empty()); +} + TEST_CASE_METHOD(SnapshotTestFixture, "Test snapshot diffs", "[snapshot]") { std::string snapKey = "foobar123"; @@ -37,28 +66,69 @@ TEST_CASE_METHOD(SnapshotTestFixture, "Test snapshot diffs", "[snapshot]") // Reset dirty tracking faabric::util::resetDirtyTracking(); - // Set up some chunks of data to write into the memory + // Single change, single merge region std::vector dataA = { 1, 2, 3, 4 }; - std::vector dataB = { 4, 5, 6 }; - std::vector dataC = { 7, 6, 5, 4, 3 }; - std::vector dataD = { 1, 1, 1, 1 }; - - // Set up some offsets, both on and over page boundaries int offsetA = HOST_PAGE_SIZE; - int offsetB = HOST_PAGE_SIZE + 20; - int offsetC = 2 * HOST_PAGE_SIZE - 2; - int offsetD = 3 * HOST_PAGE_SIZE - dataD.size(); - - // Write the data std::memcpy(sharedMem + offsetA, dataA.data(), dataA.size()); - std::memcpy(sharedMem + offsetB, dataB.data(), dataB.size()); + + snap.addMergeRegion(offsetA, + dataA.size(), + SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite); + + // NOTE - deliberately add merge regions out of order + // Diff starting in merge region and overlapping the end + std::vector dataC = { 7, 6, 5, 4, 3, 2, 1 }; + std::vector expectedDataC = { 7, 6, 5, 4, 3 }; + int offsetC = 2 * HOST_PAGE_SIZE; std::memcpy(sharedMem + offsetC, dataC.data(), dataC.size()); + + int regionOffsetC = offsetC - 3; + snap.addMergeRegion(regionOffsetC, + dataC.size(), + SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite); + + // Two changes in single merge region + std::vector dataB1 = { 4, 5, 6 }; + std::vector dataB2 = { 7, 6, 5 }; + int offsetB1 = HOST_PAGE_SIZE + 10; + int offsetB2 = HOST_PAGE_SIZE + 16; + std::memcpy(sharedMem + offsetB1, dataB1.data(), dataB1.size()); + std::memcpy(sharedMem + offsetB2, dataB2.data(), dataB2.size()); + + snap.addMergeRegion(offsetB1, + (offsetB2 - offsetB1) + dataB2.size() + 10, + SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite); + + // Merge region within a change + std::vector dataD = { 1, 1, 2, 2, 3, 3, 4 }; + std::vector expectedDataD = { 2, 2, 3, 3 }; + int offsetD = 3 * HOST_PAGE_SIZE - dataD.size(); std::memcpy(sharedMem + offsetD, dataD.data(), dataD.size()); - // Write the data to the region that exceeds the size of the original - std::vector dataExtra( - (sharedMemPages - snapPages) * HOST_PAGE_SIZE, 5); - std::memcpy(sharedMem + snapSize, dataExtra.data(), dataExtra.size()); + int regionOffsetD = offsetD + 2; + int regionSizeD = dataD.size() - 4; + snap.addMergeRegion(regionOffsetD, + regionSizeD, + SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite); + + // Write some data to the region that exceeds the size of the original, then + // add a merge region larger than it. Anything outside the original snapshot + // should be marked as changed. + std::vector dataExtra = { 2, 2, 2 }; + std::vector expectedDataExtra = { 0, 0, 2, 2, 2, 0, 0, 0 }; + int extraOffset = snapSize + HOST_PAGE_SIZE + 10; + std::memcpy(sharedMem + extraOffset, dataExtra.data(), dataExtra.size()); + + int extraRegionOffset = extraOffset - 2; + int extraRegionSize = dataExtra.size() + 4; + snap.addMergeRegion(extraRegionOffset, + extraRegionSize, + SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite); // Include an offset which doesn't change the data, but will register a // dirty page @@ -73,25 +143,20 @@ TEST_CASE_METHOD(SnapshotTestFixture, "Test snapshot diffs", "[snapshot]") // Check shared memory does have dirty pages (including the non-change) std::vector sharedDirtyPages = getDirtyPageNumbers(sharedMem, sharedMemPages); - std::vector expected = { 1, 2, 3, 5, 6, 7 }; + std::vector expected = { 1, 2, 3, 6 }; REQUIRE(sharedDirtyPages == expected); - // Check change diffs note that diffs across page boundaries will be split - // into two + // Check we have the right number of diffs std::vector changeDiffs = snap.getChangeDiffs(sharedMem, sharedMemSize); - REQUIRE(changeDiffs.size() == 6); - // One chunk will be split over 2 pages - std::vector dataCPart1 = { 7, 6 }; - std::vector dataCPart2 = { 5, 4, 3 }; - int offsetC2 = 2 * HOST_PAGE_SIZE; + REQUIRE(changeDiffs.size() == 6); checkSnapshotDiff(offsetA, dataA, changeDiffs.at(0)); - checkSnapshotDiff(offsetB, dataB, changeDiffs.at(1)); - checkSnapshotDiff(offsetC, dataCPart1, changeDiffs.at(2)); - checkSnapshotDiff(offsetC2, dataCPart2, changeDiffs.at(3)); - checkSnapshotDiff(offsetD, dataD, changeDiffs.at(4)); - checkSnapshotDiff(snapSize, dataExtra, changeDiffs.at(5)); + checkSnapshotDiff(offsetB1, dataB1, changeDiffs.at(1)); + checkSnapshotDiff(offsetB2, dataB2, changeDiffs.at(2)); + checkSnapshotDiff(offsetC, expectedDataC, changeDiffs.at(3)); + checkSnapshotDiff(regionOffsetD, expectedDataD, changeDiffs.at(4)); + checkSnapshotDiff(extraRegionOffset, expectedDataExtra, changeDiffs.at(5)); } } diff --git a/tests/test/transport/test_point_to_point.cpp b/tests/test/transport/test_point_to_point.cpp index 01d58340d..2782cdfac 100644 --- a/tests/test/transport/test_point_to_point.cpp +++ b/tests/test/transport/test_point_to_point.cpp @@ -353,8 +353,6 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, server.setRequestLatch(); cli.groupLock(appId, groupId, groupIdx, recursive); server.awaitRequestLatch(); - - broker.recvMessage(groupId, POINT_TO_POINT_MASTER_IDX, groupIdx); } REQUIRE(group->getLockOwner(recursive) == groupIdx); diff --git a/tests/test/transport/test_point_to_point_groups.cpp b/tests/test/transport/test_point_to_point_groups.cpp index d0c4bfb33..757120b17 100644 --- a/tests/test/transport/test_point_to_point_groups.cpp +++ b/tests/test/transport/test_point_to_point_groups.cpp @@ -164,41 +164,85 @@ TEST_CASE_METHOD(PointToPointGroupFixture, } TEST_CASE_METHOD(PointToPointGroupFixture, - "Test local locking and unlocking", + "Test locking and unlocking", "[ptp][transport]") { - std::atomic sharedInt = 0; int appId = 123; int groupId = 234; - // Arbitrary group size, local locks don't care - auto group = setUpGroup(appId, groupId, 3); + int nThreads = 4; + int nLoops = 50; - group->localLock(); + auto group = setUpGroup(appId, groupId, nThreads); - std::thread tA([&group, &sharedInt] { - group->localLock(); + std::atomic success = true; - assert(sharedInt == 99); - sharedInt = 88; + int criticalVar = 1; - group->localUnlock(); - }); + bool useLocal = false; + bool recursive = false; + SECTION("Local-only") { useLocal = true; } - // Main thread sleep for a while, make sure the other can't run and update - // the counter - SLEEP_MS(1000); + SECTION("Distributed version non-recursive") + { + useLocal = false; + recursive = false; + } - REQUIRE(sharedInt == 0); - sharedInt.store(99); + SECTION("Distributed version recursive") + { + useLocal = false; + recursive = true; + } - group->localUnlock(); + // Create high contention on the critical var that will be detected if + // locking isn't working. + std::vector threads; + for (int i = 0; i < nThreads; i++) { + threads.emplace_back( + [useLocal, recursive, i, nLoops, &group, &criticalVar, &success] { + if (useLocal) { + group->localLock(); + } else { + group->lock(i, recursive); + } + + // Check that while in this critical section, no changes from + // other threads are visible + + criticalVar = 2; + for (int j = 0; j < nLoops; j++) { + // Set the var + criticalVar = i; + + // Sleep a bit + int sleepTimeMs = std::rand() % 30; + SLEEP_MS(sleepTimeMs); + + // Check the var is unchanged by others + if (criticalVar != i) { + SPDLOG_ERROR("Inner loop testing locking got {} != {}", + criticalVar, + i); + success = false; + } + } + + if (useLocal) { + group->localUnlock(); + } else { + group->unlock(i, recursive); + } + }); + } - if (tA.joinable()) { - tA.join(); + for (auto& t : threads) { + if (t.joinable()) { + t.join(); + } } - REQUIRE(sharedInt == 88); + REQUIRE(success); } TEST_CASE_METHOD(PointToPointGroupFixture, diff --git a/tests/test/util/test_locks.cpp b/tests/test/util/test_locks.cpp new file mode 100644 index 000000000..839adbf85 --- /dev/null +++ b/tests/test/util/test_locks.cpp @@ -0,0 +1,82 @@ +#include +#include + +#include +#include +#include + +using namespace faabric::util; + +namespace tests { + +TEST_CASE("Test wait flag", "[util]") +{ + int nThreads = 10; + + FlagWaiter flagA; + FlagWaiter flagB; + + std::shared_ptr latchA1 = Latch::create(nThreads + 1); + std::shared_ptr latchB1 = Latch::create(nThreads + 1); + + std::shared_ptr latchA2 = Latch::create(nThreads + 1); + std::shared_ptr latchB2 = Latch::create(nThreads + 1); + + std::vector expectedUnset(nThreads, 0); + std::vector expectedSet; + + std::vector threadsA; + std::vector threadsB; + + std::vector resultsA(nThreads, 0); + std::vector resultsB(nThreads, 0); + + for (int i = 0; i < nThreads; i++) { + expectedSet.push_back(i); + + threadsA.emplace_back([&flagA, &resultsA, &latchA1, &latchA2, i] { + latchA1->wait(); + flagA.waitOnFlag(); + resultsA[i] = i; + latchA2->wait(); + }); + + threadsB.emplace_back([&flagB, &resultsB, &latchB1, &latchB2, i] { + latchB1->wait(); + flagB.waitOnFlag(); + resultsB[i] = i; + latchB2->wait(); + }); + } + + // Check no results are set initially + latchA1->wait(); + latchB1->wait(); + REQUIRE(resultsA == expectedUnset); + REQUIRE(resultsB == expectedUnset); + + // Set one flag, await latch and check + flagA.setFlag(true); + latchA2->wait(); + REQUIRE(resultsA == expectedSet); + REQUIRE(resultsB == expectedUnset); + + // Set other flag, await latch and check + flagB.setFlag(true); + latchB2->wait(); + REQUIRE(resultsA == expectedSet); + REQUIRE(resultsB == expectedSet); + + for (auto& t : threadsA) { + if (t.joinable()) { + t.join(); + } + } + + for (auto& t : threadsB) { + if (t.joinable()) { + t.join(); + } + } +} +} diff --git a/tests/test/util/test_snapshot.cpp b/tests/test/util/test_snapshot.cpp index 2a3701ae0..2f5e06007 100644 --- a/tests/test/util/test_snapshot.cpp +++ b/tests/test/util/test_snapshot.cpp @@ -66,7 +66,7 @@ class SnapshotMergeTestFixture : public SnapshotTestFixture TEST_CASE_METHOD(SnapshotMergeTestFixture, "Detailed test snapshot merge regions with ints", - "[util]") + "[snapshot][util]") { std::string snapKey = "foobar123"; int snapPages = 5; @@ -141,43 +141,34 @@ TEST_CASE_METHOD(SnapshotMergeTestFixture, REQUIRE(*intBOriginal == originalValueB); // Check diffs themselves - REQUIRE(actualDiffs.size() == 3); + REQUIRE(actualDiffs.size() == 2); SnapshotDiff diffA = actualDiffs.at(0); SnapshotDiff diffB = actualDiffs.at(1); - SnapshotDiff diffOther = actualDiffs.at(2); REQUIRE(diffA.offset == intAOffset); REQUIRE(diffB.offset == intBOffset); - REQUIRE(diffOther.offset == otherOffset); REQUIRE(diffA.operation == SnapshotMergeOperation::Sum); REQUIRE(diffB.operation == SnapshotMergeOperation::Sum); - REQUIRE(diffOther.operation == SnapshotMergeOperation::Overwrite); REQUIRE(diffA.dataType == SnapshotDataType::Int); REQUIRE(diffB.dataType == SnapshotDataType::Int); - REQUIRE(diffOther.dataType == SnapshotDataType::Raw); REQUIRE(diffA.size == sizeof(int32_t)); REQUIRE(diffB.size == sizeof(int32_t)); - REQUIRE(diffOther.size == otherData.size()); // Check that original values have been subtracted from final values for // sums REQUIRE(*(int*)diffA.data == sumValueA); REQUIRE(*(int*)diffB.data == sumValueB); - std::vector actualOtherData(diffOther.data, - diffOther.data + diffOther.size); - REQUIRE(actualOtherData == otherData); - deallocatePages(snap.data, snapPages); } TEST_CASE_METHOD(SnapshotMergeTestFixture, "Test edge-cases of snapshot merge regions", - "[util]") + "[snapshot][util]") { // Region edge cases: // - start @@ -286,7 +277,7 @@ TEST_CASE_METHOD(SnapshotMergeTestFixture, TEST_CASE_METHOD(SnapshotMergeTestFixture, "Test snapshot merge regions", - "[util]") + "[snapshot][util]") { std::string snapKey = "foobar123"; int snapPages = 5; @@ -307,8 +298,6 @@ TEST_CASE_METHOD(SnapshotMergeTestFixture, faabric::util::SnapshotMergeOperation::Overwrite; size_t dataLength = 0; - int expectedNumDiffs = 1; - SECTION("Integer") { int originalValue = 0; @@ -370,25 +359,16 @@ TEST_CASE_METHOD(SnapshotMergeTestFixture, SECTION("Raw") { - dataLength = 2 * sizeof(int32_t); + dataLength = 100; originalData = std::vector(dataLength, 3); - updatedData = originalData; - expectedData = originalData; + updatedData = std::vector(dataLength, 4); + expectedData = updatedData; dataType = faabric::util::SnapshotDataType::Raw; - operation = faabric::util::SnapshotMergeOperation::Overwrite; - SECTION("Ignore") + SECTION("Overwrite") { - operation = faabric::util::SnapshotMergeOperation::Ignore; - - // Scatter some modifications through the updated data, to make sure - // none are picked up - updatedData[0] = 1; - updatedData[sizeof(int32_t) - 2] = 1; - updatedData[sizeof(int32_t) + 10] = 1; - - expectedNumDiffs = 0; + operation = faabric::util::SnapshotMergeOperation::Overwrite; } } @@ -417,25 +397,22 @@ TEST_CASE_METHOD(SnapshotMergeTestFixture, std::vector actualDiffs = snap.getChangeDiffs(sharedMem, sharedMemSize); - // Check number of diffs - REQUIRE(actualDiffs.size() == expectedNumDiffs); - - if (expectedNumDiffs == 1) { - std::vector expectedDiffs = { { dataType, - operation, - offset, - expectedData.data(), - expectedData.size() } }; + // Check diff + REQUIRE(actualDiffs.size() == 1); + std::vector expectedDiffs = { { dataType, + operation, + offset, + expectedData.data(), + expectedData.size() } }; - checkDiffs(actualDiffs, expectedDiffs); - } + checkDiffs(actualDiffs, expectedDiffs); deallocatePages(snap.data, snapPages); } TEST_CASE_METHOD(SnapshotMergeTestFixture, "Test invalid snapshot merges", - "[util]") + "[snapshot][util]") { std::string snapKey = "foobar123"; int snapPages = 3; @@ -500,84 +477,9 @@ TEST_CASE_METHOD(SnapshotMergeTestFixture, deallocatePages(snap.data, snapPages); } -TEST_CASE_METHOD(SnapshotMergeTestFixture, "Test cross-page ignores", "[util]") -{ - int snapPages = 6; - size_t sharedMemSize = snapPages * HOST_PAGE_SIZE; - uint8_t* sharedMem = setUpSnapshot(snapPages); - - // Add ignore regions that cover multiple pages - int ignoreOffsetA = HOST_PAGE_SIZE + 100; - int ignoreLengthA = 2 * HOST_PAGE_SIZE; - snap.addMergeRegion(ignoreOffsetA, - ignoreLengthA, - faabric::util::SnapshotDataType::Raw, - faabric::util::SnapshotMergeOperation::Ignore); - - int ignoreOffsetB = sharedMemSize - (HOST_PAGE_SIZE + 10); - int ignoreLengthB = 30; - snap.addMergeRegion(ignoreOffsetB, - ignoreLengthB, - faabric::util::SnapshotDataType::Raw, - faabric::util::SnapshotMergeOperation::Ignore); - - // Add some modifications that will cause diffs, and some should be ignored - std::vector dataA(10, 1); - std::vector dataB(1, 1); - std::vector dataC(1, 1); - std::vector dataD(3, 1); - - // Not ignored, start of memory - uint32_t offsetA = 0; - std::memcpy(sharedMem + offsetA, dataA.data(), dataA.size()); - - // Not ignored, just before first ignore region - uint32_t offsetB = ignoreOffsetA - 1; - std::memcpy(sharedMem + offsetB, dataB.data(), dataB.size()); - - // Not ignored, just after first ignore region - uint32_t offsetC = ignoreOffsetA + ignoreLengthA; - std::memcpy(sharedMem + offsetC, dataC.data(), dataC.size()); - - // Not ignored, just before second ignore region - uint32_t offsetD = ignoreOffsetB - 4; - std::memcpy(sharedMem + offsetD, dataD.data(), dataD.size()); - - // Ignored, start of first ignore region - sharedMem[ignoreOffsetA] = (uint8_t)1; - - // Ignored, just before page boundary in ignore region - sharedMem[(2 * HOST_PAGE_SIZE) - 1] = (uint8_t)1; - - // Ignored, just after page boundary in ignore region - sharedMem[(2 * HOST_PAGE_SIZE)] = (uint8_t)1; - - // Deliberately don't put any changes after the next page boundary to check - // that it rolls over properly - - // Ignored, just inside second region - sharedMem[ignoreOffsetB + 2] = (uint8_t)1; - - // Ignored, end of second region - sharedMem[ignoreOffsetB + ignoreLengthB - 1] = (uint8_t)1; - - std::vector expectedDiffs = { - { offsetA, dataA.data(), dataA.size() }, - { offsetB, dataB.data(), dataB.size() }, - { offsetC, dataC.data(), dataC.size() }, - { offsetD, dataD.data(), dataD.size() }, - }; - - // Check number of diffs - std::vector actualDiffs = - snap.getChangeDiffs(sharedMem, sharedMemSize); - - checkDiffs(actualDiffs, expectedDiffs); -} - TEST_CASE_METHOD(SnapshotMergeTestFixture, "Test fine-grained byte-wise diffs", - "[util]") + "[snapshot][util]") { int snapPages = 3; size_t sharedMemSize = snapPages * HOST_PAGE_SIZE; @@ -601,12 +503,34 @@ TEST_CASE_METHOD(SnapshotMergeTestFixture, std::memcpy(sharedMem + offsetD, dataD.data(), dataD.size()); std::vector expectedDiffs = { - { offsetA, dataA.data(), dataA.size() }, - { offsetB, dataB.data(), dataB.size() }, - { offsetC, dataC.data(), dataC.size() }, - { offsetD, dataD.data(), dataD.size() }, + { SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + offsetA, + dataA.data(), + dataA.size() }, + { SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + offsetB, + dataB.data(), + dataB.size() }, + { SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + offsetC, + dataC.data(), + dataC.size() }, + { SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite, + offsetD, + dataD.data(), + dataD.size() }, }; + // Add a single merge region for all the changes + snap.addMergeRegion(0, + offsetD + dataD.size() + 20, + SnapshotDataType::Raw, + SnapshotMergeOperation::Overwrite); + // Check number of diffs std::vector actualDiffs = snap.getChangeDiffs(sharedMem, sharedMemSize); @@ -616,39 +540,40 @@ TEST_CASE_METHOD(SnapshotMergeTestFixture, TEST_CASE_METHOD(SnapshotMergeTestFixture, "Test mix of applicable and non-applicable merge regions", - "[util]") + "[snapshot][util]") { int snapPages = 6; size_t sharedMemSize = snapPages * HOST_PAGE_SIZE; uint8_t* sharedMem = setUpSnapshot(snapPages); // Add a couple of merge regions on each page, which should be skipped as - // they won't cover any changes + // they won't overlap any changes for (int i = 0; i < snapPages; i++) { - // Ignore - int iOff = i * HOST_PAGE_SIZE; - snap.addMergeRegion(iOff, + // Overwrite + int skippedOverwriteOffset = i * HOST_PAGE_SIZE; + snap.addMergeRegion(skippedOverwriteOffset, 10, faabric::util::SnapshotDataType::Raw, - faabric::util::SnapshotMergeOperation::Ignore); + faabric::util::SnapshotMergeOperation::Overwrite); // Sum - int sOff = ((i + 1) * HOST_PAGE_SIZE) - (2 * sizeof(int32_t)); - snap.addMergeRegion(sOff, + int skippedSumOffset = + ((i + 1) * HOST_PAGE_SIZE) - (2 * sizeof(int32_t)); + snap.addMergeRegion(skippedSumOffset, sizeof(int32_t), faabric::util::SnapshotDataType::Int, faabric::util::SnapshotMergeOperation::Sum); } - // Add an ignore region that should take effect, along with a corresponding - // change to be ignored - uint32_t ignoreA = (2 * HOST_PAGE_SIZE) + 2; - snap.addMergeRegion(ignoreA, + // Add an overwrite region that should take effect + uint32_t overwriteAOffset = (2 * HOST_PAGE_SIZE) + 20; + snap.addMergeRegion(overwriteAOffset, 20, faabric::util::SnapshotDataType::Raw, - faabric::util::SnapshotMergeOperation::Ignore); - std::vector dataA(10, 1); - std::memcpy(sharedMem + ignoreA, dataA.data(), dataA.size()); + faabric::util::SnapshotMergeOperation::Overwrite); + std::vector overwriteData(10, 1); + std::memcpy( + sharedMem + overwriteAOffset, overwriteData.data(), overwriteData.size()); // Add a sum region and data that should also take effect uint32_t sumOffset = (4 * HOST_PAGE_SIZE) + 100; @@ -664,6 +589,11 @@ TEST_CASE_METHOD(SnapshotMergeTestFixture, // Check diffs std::vector expectedDiffs = { + { faabric::util::SnapshotDataType::Raw, + faabric::util::SnapshotMergeOperation::Overwrite, + overwriteAOffset, + BYTES(overwriteData.data()), + overwriteData.size() }, { faabric::util::SnapshotDataType::Int, faabric::util::SnapshotMergeOperation::Sum, sumOffset,