Skip to content

Commit 27f9ba2

Browse files
committed
net: add V1Transport lock protecting receive state
Rather than relying on the caller to prevent concurrent calls to the various receive-side functions of Transport, introduce a private m_cs_recv inside the implementation to protect the lock state. Of course, this does not remove the need for callers to synchronize calls entirely, as it is a stateful object, and e.g. the order in which Receive(), Complete(), and GetMessage() are called matters. It seems impossible to use a Transport object in a meaningful way in a multi-threaded way without some form of external synchronization, but it still feels safer to make the transport object itself responsible for protecting its internal state.
1 parent 93594e4 commit 27f9ba2

File tree

2 files changed

+43
-24
lines changed

2 files changed

+43
-24
lines changed

src/net.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,7 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
719719

720720
int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
721721
{
722+
AssertLockHeld(m_recv_mutex);
722723
// copy data to temporary parsing buffer
723724
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
724725
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
@@ -759,6 +760,7 @@ int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
759760

760761
int V1Transport::readData(Span<const uint8_t> msg_bytes)
761762
{
763+
AssertLockHeld(m_recv_mutex);
762764
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
763765
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
764766

@@ -776,17 +778,20 @@ int V1Transport::readData(Span<const uint8_t> msg_bytes)
776778

777779
const uint256& V1Transport::GetMessageHash() const
778780
{
779-
assert(Complete());
781+
AssertLockHeld(m_recv_mutex);
782+
assert(CompleteInternal());
780783
if (data_hash.IsNull())
781784
hasher.Finalize(data_hash);
782785
return data_hash;
783786
}
784787

785788
CNetMessage V1Transport::GetMessage(const std::chrono::microseconds time, bool& reject_message)
786789
{
790+
AssertLockNotHeld(m_recv_mutex);
787791
// Initialize out parameter
788792
reject_message = false;
789793
// decompose a single CNetMessage from the TransportDeserializer
794+
LOCK(m_recv_mutex);
790795
CNetMessage msg(std::move(vRecv));
791796

792797
// store message type string, time, and sizes

src/net.h

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,7 @@ class Transport {
259259
virtual ~Transport() {}
260260

261261
// 1. Receiver side functions, for decoding bytes received on the wire into transport protocol
262-
// agnostic CNetMessage (message type & payload) objects. Callers must guarantee that none of
263-
// these functions are called concurrently w.r.t. one another.
262+
// agnostic CNetMessage (message type & payload) objects.
264263

265264
// returns true if the current deserialization is complete
266265
virtual bool Complete() const = 0;
@@ -282,20 +281,22 @@ class V1Transport final : public Transport
282281
private:
283282
const CChainParams& m_chain_params;
284283
const NodeId m_node_id; // Only for logging
285-
mutable CHash256 hasher;
286-
mutable uint256 data_hash;
287-
bool in_data; // parsing header (false) or data (true)
288-
CDataStream hdrbuf; // partially received header
289-
CMessageHeader hdr; // complete header
290-
CDataStream vRecv; // received message data
291-
unsigned int nHdrPos;
292-
unsigned int nDataPos;
293-
294-
const uint256& GetMessageHash() const;
295-
int readHeader(Span<const uint8_t> msg_bytes);
296-
int readData(Span<const uint8_t> msg_bytes);
297-
298-
void Reset() {
284+
mutable Mutex m_recv_mutex; //!< Lock for receive state
285+
mutable CHash256 hasher GUARDED_BY(m_recv_mutex);
286+
mutable uint256 data_hash GUARDED_BY(m_recv_mutex);
287+
bool in_data GUARDED_BY(m_recv_mutex); // parsing header (false) or data (true)
288+
CDataStream hdrbuf GUARDED_BY(m_recv_mutex); // partially received header
289+
CMessageHeader hdr GUARDED_BY(m_recv_mutex); // complete header
290+
CDataStream vRecv GUARDED_BY(m_recv_mutex); // received message data
291+
unsigned int nHdrPos GUARDED_BY(m_recv_mutex);
292+
unsigned int nDataPos GUARDED_BY(m_recv_mutex);
293+
294+
const uint256& GetMessageHash() const EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
295+
int readHeader(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
296+
int readData(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
297+
298+
void Reset() EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex) {
299+
AssertLockHeld(m_recv_mutex);
299300
vRecv.clear();
300301
hdrbuf.clear();
301302
hdrbuf.resize(24);
@@ -306,29 +307,42 @@ class V1Transport final : public Transport
306307
hasher.Reset();
307308
}
308309

310+
bool CompleteInternal() const noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex)
311+
{
312+
AssertLockHeld(m_recv_mutex);
313+
if (!in_data) return false;
314+
return hdr.nMessageSize == nDataPos;
315+
}
316+
309317
public:
310318
V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
311319
: m_chain_params(chain_params),
312320
m_node_id(node_id),
313321
hdrbuf(nTypeIn, nVersionIn),
314322
vRecv(nTypeIn, nVersionIn)
315323
{
324+
LOCK(m_recv_mutex);
316325
Reset();
317326
}
318327

319-
bool Complete() const override
328+
bool Complete() const override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
320329
{
321-
if (!in_data)
322-
return false;
323-
return (hdr.nMessageSize == nDataPos);
330+
AssertLockNotHeld(m_recv_mutex);
331+
return WITH_LOCK(m_recv_mutex, return CompleteInternal());
324332
}
325-
void SetVersion(int nVersionIn) override
333+
334+
void SetVersion(int nVersionIn) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
326335
{
336+
AssertLockNotHeld(m_recv_mutex);
337+
LOCK(m_recv_mutex);
327338
hdrbuf.SetVersion(nVersionIn);
328339
vRecv.SetVersion(nVersionIn);
329340
}
330-
int Read(Span<const uint8_t>& msg_bytes) override
341+
342+
int Read(Span<const uint8_t>& msg_bytes) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
331343
{
344+
AssertLockNotHeld(m_recv_mutex);
345+
LOCK(m_recv_mutex);
332346
int ret = in_data ? readData(msg_bytes) : readHeader(msg_bytes);
333347
if (ret < 0) {
334348
Reset();
@@ -337,7 +351,7 @@ class V1Transport final : public Transport
337351
}
338352
return ret;
339353
}
340-
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override;
354+
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex);
341355

342356
void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override;
343357
};

0 commit comments

Comments
 (0)