From 9f586903229bd5446c0e9e1c493e840cff00b673 Mon Sep 17 00:00:00 2001 From: Alexander Block Date: Sun, 17 Feb 2019 12:38:56 +0100 Subject: [PATCH] Avoid using ordered maps in LLMQ signing code (#2708) * Implement and use SigShareMap instead of ordered map with helper methods The old implementation was relying on the maps being ordered, which allowed us to grab all sig shares for the same signHash by doing range queries on the map. This has the disadvantage of being unnecessarily slow when the maps get larger. Using an unordered map would be the naive solution, but then it's not possible to query by range anymore. The solution now is to have a specialized map "SigShareMap" which is indexed by "SigShareKey". It's internally just an unordered map, indexed by the sign hash and another unordered map for the value, indexed by the quorum member index. * Only use unordered maps/sets in CSigSharesManager These are faster when maps/sets get larger. * Use unorderes sets/maps in CSigningManager --- src/llmq/quorums_signing.cpp | 13 +- src/llmq/quorums_signing.h | 4 +- src/llmq/quorums_signing_shares.cpp | 236 +++++++++++++--------------- src/llmq/quorums_signing_shares.h | 199 +++++++++++++++++++++-- 4 files changed, 298 insertions(+), 154 deletions(-) diff --git a/src/llmq/quorums_signing.cpp b/src/llmq/quorums_signing.cpp index 9530c7955f97c..b8a364c54edda 100644 --- a/src/llmq/quorums_signing.cpp +++ b/src/llmq/quorums_signing.cpp @@ -17,6 +17,7 @@ #include #include +#include namespace llmq { @@ -357,8 +358,8 @@ bool CSigningManager::PreVerifyRecoveredSig(NodeId nodeId, const CRecoveredSig& void CSigningManager::CollectPendingRecoveredSigsToVerify( size_t maxUniqueSessions, - std::map>& retSigShares, - std::map, CQuorumCPtr>& retQuorums) + std::unordered_map>& retSigShares, + std::unordered_map, CQuorumCPtr>& retQuorums) { { LOCK(cs); @@ -366,7 +367,7 @@ void CSigningManager::CollectPendingRecoveredSigsToVerify( return; } - std::set> uniqueSignHashes; + std::unordered_set> uniqueSignHashes; CLLMQUtils::IterateNodesRandom(pendingRecoveredSigs, [&]() { return uniqueSignHashes.size() < maxUniqueSessions; }, [&](NodeId nodeId, std::list& ns) { @@ -423,8 +424,8 @@ void CSigningManager::CollectPendingRecoveredSigsToVerify( bool CSigningManager::ProcessPendingRecoveredSigs(CConnman& connman) { - std::map> recSigsByNode; - std::map, CQuorumCPtr> quorums; + std::unordered_map> recSigsByNode; + std::unordered_map, CQuorumCPtr> quorums; CollectPendingRecoveredSigsToVerify(32, recSigsByNode, quorums); if (recSigsByNode.empty()) { @@ -453,7 +454,7 @@ bool CSigningManager::ProcessPendingRecoveredSigs(CConnman& connman) LogPrint("llmq", "CSigningManager::%s -- verified recovered sig(s). count=%d, vt=%d, nodes=%d\n", __func__, verifyCount, verifyTimer.count(), recSigsByNode.size()); - std::set processed; + std::unordered_set processed; for (auto& p : recSigsByNode) { NodeId nodeId = p.first; auto& v = p.second; diff --git a/src/llmq/quorums_signing.h b/src/llmq/quorums_signing.h index ac40bd43978d1..03bfb2ad520dc 100644 --- a/src/llmq/quorums_signing.h +++ b/src/llmq/quorums_signing.h @@ -135,7 +135,7 @@ class CSigningManager CRecoveredSigsDb db; // Incoming and not verified yet - std::map> pendingRecoveredSigs; + std::unordered_map> pendingRecoveredSigs; // must be protected by cs FastRandomContext rnd; @@ -156,7 +156,7 @@ class CSigningManager void ProcessMessageRecoveredSig(CNode* pfrom, const CRecoveredSig& recoveredSig, CConnman& connman); bool PreVerifyRecoveredSig(NodeId nodeId, const CRecoveredSig& recoveredSig, bool& retBan); - void CollectPendingRecoveredSigsToVerify(size_t maxUniqueSessions, std::map>& retSigShares, std::map, CQuorumCPtr>& retQuorums); + void CollectPendingRecoveredSigsToVerify(size_t maxUniqueSessions, std::unordered_map>& retSigShares, std::unordered_map, CQuorumCPtr>& retQuorums); bool ProcessPendingRecoveredSigs(CConnman& connman); // called from the worker thread of CSigSharesManager void ProcessRecoveredSig(NodeId nodeId, const CRecoveredSig& recoveredSig, const CQuorumCPtr& quorum, CConnman& connman); void Cleanup(); // called from the worker thread of CSigSharesManager diff --git a/src/llmq/quorums_signing_shares.cpp b/src/llmq/quorums_signing_shares.cpp index b316f7440ab3b..cf5a916fcdb44 100644 --- a/src/llmq/quorums_signing_shares.cpp +++ b/src/llmq/quorums_signing_shares.cpp @@ -20,33 +20,6 @@ namespace llmq CSigSharesManager* quorumSigSharesManager = nullptr; -template -static std::pair FindBySignHash(const M& m, const uint256& signHash) -{ - return std::make_pair( - m.lower_bound(std::make_pair(signHash, (uint16_t)0)), - m.upper_bound(std::make_pair(signHash, std::numeric_limits::max())) - ); -} -template -static size_t CountBySignHash(const M& m, const uint256& signHash) -{ - auto itPair = FindBySignHash(m, signHash); - size_t count = 0; - while (itPair.first != itPair.second) { - count++; - ++itPair.first; - } - return count; -} - -template -static void RemoveBySignHash(M& m, const uint256& signHash) -{ - auto itPair = FindBySignHash(m, signHash); - m.erase(itPair.first, itPair.second); -} - void CSigShare::UpdateKey() { key.first = CLLMQUtils::BuildSignHash(*this); @@ -157,9 +130,8 @@ void CSigSharesNodeState::MarkKnows(Consensus::LLMQType llmqType, const uint256& void CSigSharesNodeState::RemoveSession(const uint256& signHash) { sessions.erase(signHash); - //pendingIncomingRecSigs.erase(signHash); - RemoveBySignHash(requestedSigShares, signHash); - RemoveBySignHash(pendingIncomingSigShares, signHash); + requestedSigShares.EraseAllForSignHash(signHash); + pendingIncomingSigShares.EraseAllForSignHash(signHash); } CSigSharesInv CBatchedSigShares::ToInv() const @@ -304,13 +276,13 @@ void CSigSharesManager::ProcessMessageBatchedSigShares(CNode* pfrom, const CBatc for (size_t i = 0; i < batchedSigShares.sigShares.size(); i++) { CSigShare sigShare = batchedSigShares.RebuildSigShare(i); - nodeState.requestedSigShares.erase(sigShare.GetKey()); + nodeState.requestedSigShares.Erase(sigShare.GetKey()); // TODO track invalid sig shares received for PoSe? // It's important to only skip seen *valid* sig shares here. If a node sends us a // batch of mostly valid sig shares with a single invalid one and thus batched // verification fails, we'd skip the valid ones in the future if received from other nodes - if (this->sigShares.count(sigShare.GetKey())) { + if (this->sigShares.Has(sigShare.GetKey())) { continue; } @@ -333,7 +305,7 @@ void CSigSharesManager::ProcessMessageBatchedSigShares(CNode* pfrom, const CBatc LOCK(cs); auto& nodeState = nodeStates[pfrom->id]; for (auto& s : sigShares) { - nodeState.pendingIncomingSigShares.emplace(s.GetKey(), s); + nodeState.pendingIncomingSigShares.Add(s.GetKey(), s); } } @@ -369,7 +341,7 @@ bool CSigSharesManager::PreVerifyBatchedSigShares(NodeId nodeId, const CBatchedS return false; } - std::set dupMembers; + std::unordered_set dupMembers; for (size_t i = 0; i < batchedSigShares.sigShares.size(); i++) { auto quorumMember = batchedSigShares.sigShares[i].first; @@ -394,8 +366,8 @@ bool CSigSharesManager::PreVerifyBatchedSigShares(NodeId nodeId, const CBatchedS void CSigSharesManager::CollectPendingSigSharesToVerify( size_t maxUniqueSessions, - std::map>& retSigShares, - std::map, CQuorumCPtr>& retQuorums) + std::unordered_map>& retSigShares, + std::unordered_map, CQuorumCPtr>& retQuorums) { { LOCK(cs); @@ -409,22 +381,22 @@ void CSigSharesManager::CollectPendingSigSharesToVerify( // invalid, making batch verification fail and revert to per-share verification, which in turn would slow down // the whole verification process - std::set> uniqueSignHashes; + std::unordered_set> uniqueSignHashes; CLLMQUtils::IterateNodesRandom(nodeStates, [&]() { return uniqueSignHashes.size() < maxUniqueSessions; }, [&](NodeId nodeId, CSigSharesNodeState& ns) { - if (ns.pendingIncomingSigShares.empty()) { + if (ns.pendingIncomingSigShares.Empty()) { return false; } - auto& sigShare = ns.pendingIncomingSigShares.begin()->second; + auto& sigShare = *ns.pendingIncomingSigShares.GetFirst(); - bool alreadyHave = this->sigShares.count(sigShare.GetKey()) != 0; + bool alreadyHave = this->sigShares.Has(sigShare.GetKey()); if (!alreadyHave) { uniqueSignHashes.emplace(nodeId, sigShare.GetSignHash()); retSigShares[nodeId].emplace_back(sigShare); } - ns.pendingIncomingSigShares.erase(ns.pendingIncomingSigShares.begin()); - return !ns.pendingIncomingSigShares.empty(); + ns.pendingIncomingSigShares.Erase(sigShare.GetKey()); + return !ns.pendingIncomingSigShares.Empty(); }, rnd); if (retSigShares.empty()) { @@ -456,8 +428,8 @@ void CSigSharesManager::CollectPendingSigSharesToVerify( bool CSigSharesManager::ProcessPendingSigShares(CConnman& connman) { - std::map> sigSharesByNodes; - std::map, CQuorumCPtr> quorums; + std::unordered_map> sigSharesByNodes; + std::unordered_map, CQuorumCPtr> quorums; CollectPendingSigSharesToVerify(32, sigSharesByNodes, quorums); if (sigSharesByNodes.empty()) { @@ -526,7 +498,7 @@ bool CSigSharesManager::ProcessPendingSigShares(CConnman& connman) } // It's ensured that no duplicates are passed to this method -void CSigSharesManager::ProcessPendingSigSharesFromNode(NodeId nodeId, const std::vector& sigShares, const std::map, CQuorumCPtr>& quorums, CConnman& connman) +void CSigSharesManager::ProcessPendingSigSharesFromNode(NodeId nodeId, const std::vector& sigShares, const std::unordered_map, CQuorumCPtr>& quorums, CConnman& connman) { auto& nodeState = nodeStates[nodeId]; @@ -570,12 +542,12 @@ void CSigSharesManager::ProcessSigShare(NodeId nodeId, const CSigShare& sigShare { LOCK(cs); - - if (!sigShares.emplace(sigShare.GetKey(), sigShare).second) { + + if (!sigShares.Add(sigShare.GetKey(), sigShare)) { return; } - sigSharesToAnnounce.emplace(sigShare.GetKey()); + sigSharesToAnnounce.Add(sigShare.GetKey(), true); firstSeenForSessions.emplace(sigShare.GetSignHash(), GetTimeMillis()); if (!quorumNodes.empty()) { @@ -591,7 +563,7 @@ void CSigSharesManager::ProcessSigShare(NodeId nodeId, const CSigShare& sigShare } } - size_t sigShareCount = CountBySignHash(sigShares, sigShare.GetSignHash()); + size_t sigShareCount = sigShares.CountForSignHash(sigShare.GetSignHash()); if (sigShareCount >= quorum->params.threshold) { canTryRecovery = true; } @@ -616,11 +588,14 @@ void CSigSharesManager::TryRecoverSig(const CQuorumCPtr& quorum, const uint256& auto k = std::make_pair(quorum->params.type, id); auto signHash = CLLMQUtils::BuildSignHash(quorum->params.type, quorum->quorumHash, id, msgHash); - auto itPair = FindBySignHash(sigShares, signHash); + auto sigShares = this->sigShares.GetAllForSignHash(signHash); + if (!sigShares) { + return; + } sigSharesForRecovery.reserve((size_t) quorum->params.threshold); idsForRecovery.reserve((size_t) quorum->params.threshold); - for (auto it = itPair.first; it != itPair.second && sigSharesForRecovery.size() < quorum->params.threshold; ++it) { + for (auto it = sigShares->begin(); it != sigShares->end() && sigSharesForRecovery.size() < quorum->params.threshold; ++it) { auto& sigShare = it->second; sigSharesForRecovery.emplace_back(sigShare.sigShare.GetSig()); idsForRecovery.emplace_back(CBLSId::FromHash(quorum->members[sigShare.quorumMember]->proTxHash)); @@ -665,10 +640,10 @@ void CSigSharesManager::TryRecoverSig(const CQuorumCPtr& quorum, const uint256& } // cs must be held -void CSigSharesManager::CollectSigSharesToRequest(std::map>& sigSharesToRequest) +void CSigSharesManager::CollectSigSharesToRequest(std::unordered_map>& sigSharesToRequest) { int64_t now = GetTimeMillis(); - std::map> nodesBySigShares; + std::unordered_map> nodesBySigShares; const size_t maxRequestsForNode = 32; @@ -690,18 +665,17 @@ void CSigSharesManager::CollectSigSharesToRequest(std::mapsecond >= SIG_SHARE_REQUEST_TIMEOUT) { + nodeState.requestedSigShares.EraseIf([&](const SigShareKey& k, int64_t t) { + if (now - t >= SIG_SHARE_REQUEST_TIMEOUT) { // timeout while waiting for this one, so retry it with another node - LogPrint("llmq", "CSigSharesManager::%s -- timeout while waiting for %s-%d, node=%d\n", __func__, - it->first.first.ToString(), it->first.second, nodeId); - it = nodeState.requestedSigShares.erase(it); - } else { - ++it; + LogPrint("llmq", "CSigSharesManager::CollectSigSharesToRequest -- timeout while waiting for %s-%d, node=%d\n", + k.first.ToString(), k.second, nodeId); + return true; } - } + return false; + }); - std::map* invMap = nullptr; + std::unordered_map* invMap = nullptr; for (auto& p2 : nodeState.sessions) { auto& signHash = p2.first; @@ -716,21 +690,21 @@ void CSigSharesManager::CollectSigSharesToRequest(std::map= maxRequestsForNode) { + if (nodeState.requestedSigShares.Size() >= maxRequestsForNode) { // too many pending requests for this node break; } - auto it = sigSharesRequested.find(k); - if (it != sigSharesRequested.end()) { - if (now - it->second.second >= SIG_SHARE_REQUEST_TIMEOUT && nodeId != it->second.first) { + auto p = sigSharesRequested.Get(k); + if (p) { + if (now - p->second >= SIG_SHARE_REQUEST_TIMEOUT && nodeId != p->first) { // other node timed out, re-request from this node LogPrint("llmq", "CSigSharesManager::%s -- other node timeout while waiting for %s-%d, re-request from=%d, node=%d\n", __func__, - it->first.first.ToString(), it->first.second, nodeId, it->second.first); + k.first.ToString(), k.second, nodeId, p->first); } else { continue; } @@ -738,10 +712,10 @@ void CSigSharesManager::CollectSigSharesToRequest(std::map>& sigSharesToSend) +void CSigSharesManager::CollectSigSharesToSend(std::unordered_map>& sigSharesToSend) { for (auto& p : nodeStates) { auto nodeId = p.first; @@ -772,7 +746,7 @@ void CSigSharesManager::CollectSigSharesToSend(std::map* sigSharesToSend2 = nullptr; + std::unordered_map* sigSharesToSend2 = nullptr; for (auto& p2 : nodeState.sessions) { auto& signHash = p2.first; @@ -791,21 +765,20 @@ void CSigSharesManager::CollectSigSharesToSend(std::mapsecond; if (batchedSigShares.sigShares.empty()) { - batchedSigShares.llmqType = sigShare.llmqType; - batchedSigShares.quorumHash = sigShare.quorumHash; - batchedSigShares.id = sigShare.id; - batchedSigShares.msgHash = sigShare.msgHash; + batchedSigShares.llmqType = sigShare->llmqType; + batchedSigShares.quorumHash = sigShare->quorumHash; + batchedSigShares.id = sigShare->id; + batchedSigShares.msgHash = sigShare->msgHash; } - batchedSigShares.sigShares.emplace_back((uint16_t)i, sigShare.sigShare); + batchedSigShares.sigShares.emplace_back((uint16_t)i, sigShare->sigShare); } if (!batchedSigShares.sigShares.empty()) { @@ -820,20 +793,19 @@ void CSigSharesManager::CollectSigSharesToSend(std::map>& sigSharesToAnnounce) +void CSigSharesManager::CollectSigSharesToAnnounce(std::unordered_map>& sigSharesToAnnounce) { - std::set> quorumNodesPrepared; + std::unordered_set> quorumNodesPrepared; - for (auto& sigShareKey : this->sigSharesToAnnounce) { + this->sigSharesToAnnounce.ForEach([&](const SigShareKey& sigShareKey, bool) { auto& signHash = sigShareKey.first; auto quorumMember = sigShareKey.second; - auto sigShareIt = sigShares.find(sigShareKey); - if (sigShareIt == sigShares.end()) { - continue; + const CSigShare* sigShare = sigShares.Get(sigShareKey); + if (!sigShare) { + return; } - auto& sigShare = sigShareIt->second; - auto quorumKey = std::make_pair((Consensus::LLMQType)sigShare.llmqType, sigShare.quorumHash); + auto quorumKey = std::make_pair((Consensus::LLMQType)sigShare->llmqType, sigShare->quorumHash); if (quorumNodesPrepared.emplace(quorumKey).second) { // make sure we announce to at least the nodes which we know through the inter-quorum-communication system auto nodeIds = g_connman->GetMasternodeQuorumNodes(quorumKey.first, quorumKey.second); @@ -860,7 +832,7 @@ void CSigSharesManager::CollectSigSharesToAnnounce(std::mapllmqType, signHash); if (session.knows.inv[quorumMember]) { // he already knows that one @@ -869,18 +841,18 @@ void CSigSharesManager::CollectSigSharesToAnnounce(std::mapllmqType, signHash); } inv.inv[quorumMember] = true; session.knows.inv[quorumMember] = true; } - } + }); // don't announce these anymore // nodes which did not send us a valid sig share before were left out now, but this is ok as it only results in slower // propagation for the first signing session of a fresh quorum. The sig shares should still arrive on all nodes due to // the deterministic inter-quorum-communication system - this->sigSharesToAnnounce.clear(); + this->sigSharesToAnnounce.Clear(); } bool CSigSharesManager::SendMessages() @@ -890,9 +862,9 @@ bool CSigSharesManager::SendMessages() nodesByAddress.emplace(pnode->addr, pnode->id); }); - std::map> sigSharesToRequest; - std::map> sigSharesToSend; - std::map> sigSharesToAnnounce; + std::unordered_map> sigSharesToRequest; + std::unordered_map> sigSharesToSend; + std::unordered_map> sigSharesToAnnounce; { LOCK(cs); @@ -956,13 +928,13 @@ void CSigSharesManager::Cleanup() return; } - std::set> quorumsToCheck; + std::unordered_set> quorumsToCheck; { LOCK(cs); // Remove sessions which timed out - std::set timeoutSessions; + std::unordered_set timeoutSessions; for (auto& p : firstSeenForSessions) { auto& signHash = p.first; int64_t time = p.second; @@ -972,13 +944,15 @@ void CSigSharesManager::Cleanup() } } for (auto& signHash : timeoutSessions) { - size_t count = CountBySignHash(sigShares, signHash); + size_t count = sigShares.CountForSignHash(signHash); if (count > 0) { - auto itPair = FindBySignHash(sigShares, signHash); - auto& firstSigShare = itPair.first->second; + auto m = sigShares.GetAllForSignHash(signHash); + assert(m); + + auto& oneSigShare = m->begin()->second; LogPrintf("CSigSharesManager::%s -- signing session timed out. signHash=%s, id=%s, msgHash=%s, sigShareCount=%d\n", __func__, - signHash.ToString(), firstSigShare.id.ToString(), firstSigShare.msgHash.ToString(), count); + signHash.ToString(), oneSigShare.id.ToString(), oneSigShare.msgHash.ToString(), count); } else { LogPrintf("CSigSharesManager::%s -- signing session timed out. signHash=%s, sigShareCount=%d\n", __func__, signHash.ToString(), count); @@ -987,22 +961,22 @@ void CSigSharesManager::Cleanup() } // Remove sessions which were succesfully recovered - std::set doneSessions; - for (auto& p : sigShares) { - if (doneSessions.count(p.second.GetSignHash())) { - continue; + std::unordered_set doneSessions; + sigShares.ForEach([&](const SigShareKey& k, const CSigShare& sigShare) { + if (doneSessions.count(sigShare.GetSignHash())) { + return; } - if (quorumSigningManager->HasRecoveredSigForSession(p.second.GetSignHash())) { - doneSessions.emplace(p.second.GetSignHash()); + if (quorumSigningManager->HasRecoveredSigForSession(sigShare.GetSignHash())) { + doneSessions.emplace(sigShare.GetSignHash()); } - } + }); for (auto& signHash : doneSessions) { RemoveSigSharesForSession(signHash); } - for (auto& p : sigShares) { - quorumsToCheck.emplace((Consensus::LLMQType)p.second.llmqType, p.second.quorumHash); - } + sigShares.ForEach([&](const SigShareKey& k, const CSigShare& sigShare) { + quorumsToCheck.emplace((Consensus::LLMQType) sigShare.llmqType, sigShare.quorumHash); + }); } // Find quorums which became inactive @@ -1017,19 +991,19 @@ void CSigSharesManager::Cleanup() { // Now delete sessions which are for inactive quorums LOCK(cs); - std::set inactiveQuorumSessions; - for (auto& p : sigShares) { - if (quorumsToCheck.count(std::make_pair((Consensus::LLMQType)p.second.llmqType, p.second.quorumHash))) { - inactiveQuorumSessions.emplace(p.second.GetSignHash()); + std::unordered_set inactiveQuorumSessions; + sigShares.ForEach([&](const SigShareKey& k, const CSigShare& sigShare) { + if (quorumsToCheck.count(std::make_pair((Consensus::LLMQType)sigShare.llmqType, sigShare.quorumHash))) { + inactiveQuorumSessions.emplace(sigShare.GetSignHash()); } - } + }); for (auto& signHash : inactiveQuorumSessions) { RemoveSigSharesForSession(signHash); } } // Find node states for peers that disappeared from CConnman - std::set nodeStatesToDelete; + std::unordered_set nodeStatesToDelete; for (auto& p : nodeStates) { nodeStatesToDelete.emplace(p.first); } @@ -1042,9 +1016,9 @@ void CSigSharesManager::Cleanup() for (auto nodeId : nodeStatesToDelete) { auto& nodeState = nodeStates[nodeId]; // remove global requested state to force a re-request from another node - for (auto& p : nodeState.requestedSigShares) { - sigSharesRequested.erase(p.first); - } + nodeState.requestedSigShares.ForEach([&](const SigShareKey& k, bool) { + sigSharesRequested.Erase(k); + }); nodeStates.erase(nodeId); } @@ -1058,9 +1032,9 @@ void CSigSharesManager::RemoveSigSharesForSession(const uint256& signHash) ns.RemoveSession(signHash); } - RemoveBySignHash(sigSharesRequested, signHash); - RemoveBySignHash(sigSharesToAnnounce, signHash); - RemoveBySignHash(sigShares, signHash); + sigSharesRequested.EraseAllForSignHash(signHash); + sigSharesToAnnounce.EraseAllForSignHash(signHash); + sigShares.EraseAllForSignHash(signHash); firstSeenForSessions.erase(signHash); } @@ -1069,13 +1043,13 @@ void CSigSharesManager::RemoveBannedNodeStates() // Called regularly to cleanup local node states for banned nodes LOCK2(cs_main, cs); - std::set toRemove; + std::unordered_set toRemove; for (auto it = nodeStates.begin(); it != nodeStates.end();) { if (IsBanned(it->first)) { // re-request sigshares from other nodes - for (auto& p : it->second.requestedSigShares) { - sigSharesRequested.erase(p.first); - } + it->second.requestedSigShares.ForEach([&](const SigShareKey& k, int64_t) { + sigSharesRequested.Erase(k); + }); it = nodeStates.erase(it); } else { ++it; @@ -1102,10 +1076,10 @@ void CSigSharesManager::BanNode(NodeId nodeId) auto& nodeState = it->second; // Whatever we requested from him, let's request it from someone else now - for (auto& p : nodeState.requestedSigShares) { - sigSharesRequested.erase(p.first); - } - nodeState.requestedSigShares.clear(); + nodeState.requestedSigShares.ForEach([&](const SigShareKey& k, int64_t) { + sigSharesRequested.Erase(k); + }); + nodeState.requestedSigShares.Clear(); nodeState.banned = true; } diff --git a/src/llmq/quorums_signing_shares.h b/src/llmq/quorums_signing_shares.h index fb25f52f4b09f..e4e5e0994e446 100644 --- a/src/llmq/quorums_signing_shares.h +++ b/src/llmq/quorums_signing_shares.h @@ -18,15 +18,39 @@ #include #include +#include +#include class CEvoDB; class CScheduler; namespace llmq { - // typedef std::pair SigShareKey; +} + +namespace std { + template <> + struct hash + { + std::size_t operator()(const llmq::SigShareKey& k) const + { + return (std::size_t)((k.second + 1) * k.first.GetCheapHash()); + } + }; + template <> + struct hash> + { + std::size_t operator()(const std::pair& k) const + { + return (std::size_t)((k.first + 1) * k.second.GetCheapHash()); + } + }; +} + +namespace llmq +{ // this one does not get transmitted over the wire as it is batched inside CBatchedSigShares class CSigShare @@ -130,6 +154,151 @@ class CBatchedSigShares CSigSharesInv ToInv() const; }; +template +class SigShareMap +{ +private: + std::unordered_map> internalMap; + +public: + bool Add(const SigShareKey& k, const T& v) + { + auto& m = internalMap[k.first]; + return m.emplace(k.second, v).second; + } + + void Erase(const SigShareKey& k) + { + auto it = internalMap.find(k.first); + if (it == internalMap.end()) { + return; + } + it->second.erase(k.second); + if (it->second.empty()) { + internalMap.erase(it); + } + } + + void Clear() + { + internalMap.clear(); + } + + bool Has(const SigShareKey& k) const + { + auto it = internalMap.find(k.first); + if (it == internalMap.end()) { + return false; + } + return it->second.count(k.second) != 0; + } + + T* Get(const SigShareKey& k) + { + auto it = internalMap.find(k.first); + if (it == internalMap.end()) { + return nullptr; + } + + auto jt = it->second.find(k.second); + if (jt == it->second.end()) { + return nullptr; + } + + return &jt->second; + } + + T& GetOrAdd(const SigShareKey& k) + { + T* v = Get(k); + if (!v) { + Add(k, T()); + v = Get(k); + } + return *v; + } + + const T* GetFirst() const + { + if (internalMap.empty()) { + return nullptr; + } + return &internalMap.begin()->second.begin()->second; + } + + size_t Size() const + { + size_t s = 0; + for (auto& p : internalMap) { + s += p.second.size(); + } + return s; + } + + size_t CountForSignHash(const uint256& signHash) const + { + auto it = internalMap.find(signHash); + if (it == internalMap.end()) { + return 0; + } + return it->second.size(); + } + + bool Empty() const + { + return internalMap.empty(); + } + + const std::unordered_map* GetAllForSignHash(const uint256& signHash) + { + auto it = internalMap.find(signHash); + if (it == internalMap.end()) { + return nullptr; + } + return &it->second; + } + + void EraseAllForSignHash(const uint256& signHash) + { + internalMap.erase(signHash); + } + + template + void EraseIf(F&& f) + { + for (auto it = internalMap.begin(); it != internalMap.end(); ) { + SigShareKey k; + k.first = it->first; + for (auto jt = it->second.begin(); jt != it->second.end(); ) { + k.second = jt->first; + if (f(k, jt->second)) { + jt = it->second.erase(jt); + } else { + ++jt; + } + } + if (it->second.empty()) { + it = internalMap.erase(it); + } else { + ++it; + } + } + } + + template + void ForEach(F&& f) + { + for (auto& p : internalMap) { + SigShareKey k; + k.first = p.first; + for (auto& p2 : p.second) { + k.second = p2.first; + f(k, p2.second); + } + } + } +}; + class CSigSharesNodeState { public: @@ -139,14 +308,14 @@ class CSigSharesNodeState CSigSharesInv knows; }; // TODO limit number of sessions per node - std::map sessions; + std::unordered_map sessions; - std::map pendingIncomingSigShares; - std::map requestedSigShares; + SigShareMap pendingIncomingSigShares; + SigShareMap requestedSigShares; // elements are added whenever we receive a valid sig share from this node // this triggers us to send inventory items to him as he seems to be interested in these - std::set> interestedIn; + std::unordered_set> interestedIn; bool banned{false}; @@ -174,12 +343,12 @@ class CSigSharesManager std::thread workThread; CThreadInterrupt workInterrupt; - std::map sigShares; - std::map firstSeenForSessions; + SigShareMap sigShares; + std::unordered_map firstSeenForSessions; - std::map nodeStates; - std::map> sigSharesRequested; - std::set sigSharesToAnnounce; + std::unordered_map nodeStates; + SigShareMap> sigSharesRequested; + SigShareMap sigSharesToAnnounce; std::vector> pendingSigns; @@ -211,10 +380,10 @@ class CSigSharesManager bool VerifySigSharesInv(NodeId from, const CSigSharesInv& inv); bool PreVerifyBatchedSigShares(NodeId nodeId, const CBatchedSigShares& batchedSigShares, bool& retBan); - void CollectPendingSigSharesToVerify(size_t maxUniqueSessions, std::map>& retSigShares, std::map, CQuorumCPtr>& retQuorums); + void CollectPendingSigSharesToVerify(size_t maxUniqueSessions, std::unordered_map>& retSigShares, std::unordered_map, CQuorumCPtr>& retQuorums); bool ProcessPendingSigShares(CConnman& connman); - void ProcessPendingSigSharesFromNode(NodeId nodeId, const std::vector& sigShares, const std::map, CQuorumCPtr>& quorums, CConnman& connman); + void ProcessPendingSigSharesFromNode(NodeId nodeId, const std::vector& sigShares, const std::unordered_map, CQuorumCPtr>& quorums, CConnman& connman); void ProcessSigShare(NodeId nodeId, const CSigShare& sigShare, CConnman& connman, const CQuorumCPtr& quorum); void TryRecoverSig(const CQuorumCPtr& quorum, const uint256& id, const uint256& msgHash, CConnman& connman); @@ -227,9 +396,9 @@ class CSigSharesManager void BanNode(NodeId nodeId); bool SendMessages(); - void CollectSigSharesToRequest(std::map>& sigSharesToRequest); - void CollectSigSharesToSend(std::map>& sigSharesToSend); - void CollectSigSharesToAnnounce(std::map>& sigSharesToAnnounce); + void CollectSigSharesToRequest(std::unordered_map>& sigSharesToRequest); + void CollectSigSharesToSend(std::unordered_map>& sigSharesToSend); + void CollectSigSharesToAnnounce(std::unordered_map>& sigSharesToAnnounce); bool SignPendingSigShares(); void WorkThreadMain(); };