From c91a35f82557afd39c9b83f643876630bb4275c5 Mon Sep 17 00:00:00 2001 From: Stan Bondi Date: Mon, 4 Sep 2023 16:07:18 +0400 Subject: [PATCH] fix(comms): only permit a single inbound messaging substream per peer (#5731) Description --- Only permit a single inbound messaging session per peer. Emit the MessageReceived event only in tests. Ban peers that exceed the maximum frame size. Motivation and Context --- A node may initiate an unbounded number of inbound messaging sessions. This PR limits the number of inbound sessions per peer to one. Tari nodes only initiate a single session as needed, so this should have no effect on non-malicious nodes. Previously, a MessageReceived event was emitted for every message received, but these events are only consumed in tests. Since they aren't read in non-test contexts and the event channel can be very busy, this PR disables them by default and enables them only in tests. This does not affect the message count in the minotari node status line. How Has This Been Tested? --- New unit test that checks that additional inbound sessions are closed while the first one remains active. What process can a PR reviewer use to test or verify this change? --- Verify that messaging still works as expected (ping-peer, discover-peer, etc.). A brief usage sketch of the new MessagingProtocolExtension options follows the diff. Breaking Changes --- - [x] None - [ ] Requires data directory on base node to be deleted - [ ] Requires hard fork - [ ] Other - Please specify --- base_layer/core/tests/tests/node_service.rs | 4 +- base_layer/p2p/src/initialization.rs | 12 +- comms/core/src/builder/tests.rs | 13 +- .../core/src/protocol/messaging/extension.rs | 33 ++++- comms/core/src/protocol/messaging/inbound.rs | 40 ++++-- comms/core/src/protocol/messaging/mod.rs | 9 +- comms/core/src/protocol/messaging/outbound.rs | 6 +- comms/core/src/protocol/messaging/protocol.rs | 130 +++++++++++++++--- comms/core/src/protocol/messaging/test.rs | 103 +++++++++++++- comms/dht/examples/memory_net/utilities.rs | 12 +- comms/dht/src/peer_validator.rs | 6 +- comms/dht/tests/dht.rs | 15 +- comms/dht/tests/harness.rs | 6 +- 13 files changed, 316 insertions(+), 73 deletions(-) diff --git a/base_layer/core/tests/tests/node_service.rs b/base_layer/core/tests/tests/node_service.rs index 002553c4f4..17008373f2 100644 --- a/base_layer/core/tests/tests/node_service.rs +++ b/base_layer/core/tests/tests/node_service.rs @@ -308,13 +308,13 @@ async fn propagate_and_forward_invalid_block_hash() { let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) .await .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); + unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &msg_event); // Bob asks Alice for missing transaction let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) .await .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); + unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &msg_event); assert_eq!(node_id, alice_node.node_identity.node_id()); // Checking a negative: Bob should not have propagated this hash to Carol.
If Bob does, this assertion will be diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 5f22b7b422..c3ce05ccb9 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -198,7 +198,9 @@ pub async fn initialize_local_test_comms>( .build(); let comms = comms - .add_protocol_extension(MessagingProtocolExtension::new(event_sender.clone(), pipeline)) + .add_protocol_extension( + MessagingProtocolExtension::new(event_sender.clone(), pipeline).enable_message_received_event(), + ) .spawn_with_transport(MemoryTransport) .await?; @@ -371,10 +373,10 @@ async fn configure_comms_and_dht( .build(); let (messaging_events_sender, _) = broadcast::channel(1); - comms = comms.add_protocol_extension(MessagingProtocolExtension::new( - messaging_events_sender, - messaging_pipeline, - )); + comms = comms.add_protocol_extension( + MessagingProtocolExtension::new(messaging_events_sender, messaging_pipeline) + .with_ban_duration(config.dht.ban_duration_short), + ); Ok((comms, dht)) } diff --git a/comms/core/src/builder/tests.rs b/comms/core/src/builder/tests.rs index a2f3eb657f..527d3fd0ec 100644 --- a/comms/core/src/builder/tests.rs +++ b/comms/core/src/builder/tests.rs @@ -89,16 +89,19 @@ async fn spawn_node( let (messaging_events_sender, _) = broadcast::channel(100); let comms_node = comms_node .add_protocol_extensions(protocols.into()) - .add_protocol_extension(MessagingProtocolExtension::new( - messaging_events_sender.clone(), - pipeline::Builder::new() + .add_protocol_extension( + MessagingProtocolExtension::new( + messaging_events_sender.clone(), + pipeline::Builder::new() // Outbound messages will be forwarded "as is" to outbound messaging .with_outbound_pipeline(outbound_rx, identity) .max_concurrent_inbound_tasks(1) // Inbound messages will be forwarded "as is" to inbound_tx .with_inbound_pipeline(SinkService::new(inbound_tx)) .build(), - )) + ) + .enable_message_received_event(), + ) .spawn_with_transport(MemoryTransport) .await .unwrap(); @@ -251,7 +254,7 @@ async fn peer_to_peer_messaging() { let events = collect_recv!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); events.into_iter().for_each(|m| { - unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &*m); + unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &m); }); // Send NUM_MSGS messages from node 2 to node 1 diff --git a/comms/core/src/protocol/messaging/extension.rs b/comms/core/src/protocol/messaging/extension.rs index 4f31841c7c..cb3f7ecd48 100644 --- a/comms/core/src/protocol/messaging/extension.rs +++ b/comms/core/src/protocol/messaging/extension.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use std::fmt; +use std::{fmt, time::Duration}; use tokio::sync::mpsc; use tower::Service; @@ -31,7 +31,7 @@ use crate::{ message::InboundMessage, pipeline, protocol::{ - messaging::{protocol::MESSAGING_PROTOCOL, MessagingEventSender}, + messaging::{protocol::MESSAGING_PROTOCOL_ID, MessagingEventSender}, ProtocolExtension, ProtocolExtensionContext, ProtocolExtensionError, @@ -50,11 +50,32 @@ pub const MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE: usize = 30; pub struct MessagingProtocolExtension { event_tx: MessagingEventSender, pipeline: pipeline::Config, + enable_message_received_event: bool, + ban_duration: Duration, } impl MessagingProtocolExtension { pub fn new(event_tx: MessagingEventSender, pipeline: pipeline::Config) -> Self { - Self { event_tx, pipeline } + Self { + event_tx, + pipeline, + enable_message_received_event: false, + ban_duration: Duration::from_secs(10 * 60), + } + } + + /// Enables the MessageReceived event which is disabled by default. This will enable sending the MessageReceived + /// event per message received. This is typically used in tests. If unused it should be disabled to reduce memory + /// usage (not reading the event from the channel). + pub fn enable_message_received_event(mut self) -> Self { + self.enable_message_received_event = true; + self + } + + /// Sets the ban duration for peers that violate protocol. Default is 10 minutes. + pub fn with_ban_duration(mut self, ban_duration: Duration) -> Self { + self.ban_duration = ban_duration; + self } } @@ -70,7 +91,7 @@ where { fn install(mut self: Box, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> { let (proto_tx, proto_rx) = mpsc::channel(MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE); - context.add_protocol(&[MESSAGING_PROTOCOL.clone()], &proto_tx); + context.add_protocol(&[MESSAGING_PROTOCOL_ID.clone()], &proto_tx); let (inbound_message_tx, inbound_message_rx) = mpsc::channel(INBOUND_MESSAGE_BUFFER_SIZE); @@ -82,7 +103,9 @@ where self.event_tx, inbound_message_tx, context.shutdown_signal(), - ); + ) + .set_message_received_event_enabled(self.enable_message_received_event) + .with_ban_duration(self.ban_duration); context.register_complete_signal(messaging.complete_signal()); diff --git a/comms/core/src/protocol/messaging/inbound.rs b/comms/core/src/protocol/messaging/inbound.rs index d62c415d76..99ed90d218 100644 --- a/comms/core/src/protocol/messaging/inbound.rs +++ b/comms/core/src/protocol/messaging/inbound.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use std::sync::Arc; +use std::io; use futures::StreamExt; use log::*; @@ -38,19 +38,22 @@ const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; pub struct InboundMessaging { peer: NodeId, inbound_message_tx: mpsc::Sender, - messaging_events_tx: broadcast::Sender>, + messaging_events_tx: broadcast::Sender, + enable_message_received_event: bool, } impl InboundMessaging { pub fn new( peer: NodeId, inbound_message_tx: mpsc::Sender, - messaging_events_tx: broadcast::Sender>, + messaging_events_tx: broadcast::Sender, + enable_message_received_event: bool, ) -> Self { Self { peer, inbound_message_tx, messaging_events_tx, + enable_message_received_event, } } @@ -83,21 +86,39 @@ impl InboundMessaging { msg_len ); - let event = MessagingEvent::MessageReceived(inbound_msg.source_peer.clone(), inbound_msg.tag); + let message_tag = inbound_msg.tag; - if let Err(err) = self.inbound_message_tx.send(inbound_msg).await { - let tag = err.0.tag; + if self.inbound_message_tx.send(inbound_msg).await.is_err() { warn!( target: LOG_TARGET, "Failed to send InboundMessage {} for peer '{}' because inbound message channel closed", - tag, + message_tag, peer.short_str(), ); break; } - let _result = self.messaging_events_tx.send(Arc::new(event)); + if self.enable_message_received_event { + let _result = self + .messaging_events_tx + .send(MessagingEvent::MessageReceived(peer.clone(), message_tag)); + } + }, + // LengthDelimitedCodec emits a InvalidData io error when the message length exceeds the maximum allowed + Err(err) if err.kind() == io::ErrorKind::InvalidData => { + metrics::error_count(peer).inc(); + debug!( + target: LOG_TARGET, + "Failed to receive from peer '{}' because '{}'", + peer.short_str(), + err + ); + let _result = self.messaging_events_tx.send(MessagingEvent::ProtocolViolation { + peer_node_id: peer.clone(), + details: err.to_string(), + }); + break; }, Err(err) => { metrics::error_count(peer).inc(); @@ -112,6 +133,9 @@ impl InboundMessaging { } } + let _ignore = self + .messaging_events_tx + .send(MessagingEvent::InboundProtocolExited(peer.clone())); metrics::num_sessions().dec(); debug!( target: LOG_TARGET, diff --git a/comms/core/src/protocol/messaging/mod.rs b/comms/core/src/protocol/messaging/mod.rs index 9b45008474..3410c2b60d 100644 --- a/comms/core/src/protocol/messaging/mod.rs +++ b/comms/core/src/protocol/messaging/mod.rs @@ -37,7 +37,14 @@ mod inbound; mod metrics; mod outbound; mod protocol; -pub use protocol::{MessagingEvent, MessagingEventReceiver, MessagingEventSender, MessagingProtocol, SendFailReason}; +pub use protocol::{ + MessagingEvent, + MessagingEventReceiver, + MessagingEventSender, + MessagingProtocol, + SendFailReason, + MESSAGING_PROTOCOL_ID, +}; #[cfg(test)] mod test; diff --git a/comms/core/src/protocol/messaging/outbound.rs b/comms/core/src/protocol/messaging/outbound.rs index 9f8ff2831e..7160ad5c21 100644 --- a/comms/core/src/protocol/messaging/outbound.rs +++ b/comms/core/src/protocol/messaging/outbound.rs @@ -33,7 +33,7 @@ use crate::{ message::OutboundMessage, multiplexing::Substream, peer_manager::NodeId, - protocol::messaging::protocol::MESSAGING_PROTOCOL, + protocol::messaging::protocol::MESSAGING_PROTOCOL_ID, stream_id::StreamId, }; @@ -123,7 +123,7 @@ impl OutboundMessaging { } metrics::num_sessions().dec(); - let _ = messaging_events_tx + let _ignore = messaging_events_tx .send(MessagingEvent::OutboundProtocolExited(peer_node_id)) .await; } @@ -223,7 +223,7 @@ impl OutboundMessaging { &mut self, conn: &mut PeerConnection, ) -> Result, 
MessagingProtocolError> { - match conn.open_substream(&MESSAGING_PROTOCOL).await { + match conn.open_substream(&MESSAGING_PROTOCOL_ID).await { Ok(substream) => Ok(substream), Err(err) => { debug!( diff --git a/comms/core/src/protocol/messaging/protocol.rs b/comms/core/src/protocol/messaging/protocol.rs index a11d6e57e1..cc6078242b 100644 --- a/comms/core/src/protocol/messaging/protocol.rs +++ b/comms/core/src/protocol/messaging/protocol.rs @@ -23,7 +23,8 @@ use std::{ collections::{hash_map::Entry, HashMap}, fmt, - sync::Arc, + fmt::Display, + time::Duration, }; use bytes::Bytes; @@ -33,6 +34,7 @@ use thiserror::Error; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::{broadcast, mpsc}, + task::JoinHandle, }; use tokio_util::codec::{Framed, LengthDelimitedCodec}; @@ -51,13 +53,13 @@ use crate::{ }; const LOG_TARGET: &str = "comms::protocol::messaging"; -pub(super) static MESSAGING_PROTOCOL: Bytes = Bytes::from_static(b"t/msg/0.1"); +pub static MESSAGING_PROTOCOL_ID: Bytes = Bytes::from_static(b"t/msg/0.1"); const INTERNAL_MESSAGING_EVENT_CHANNEL_SIZE: usize = 10; const MAX_FRAME_LENGTH: usize = 8 * 1_024 * 1_024; -pub type MessagingEventSender = broadcast::Sender>; -pub type MessagingEventReceiver = broadcast::Receiver>; +pub type MessagingEventSender = broadcast::Sender; +pub type MessagingEventReceiver = broadcast::Receiver; /// The reason for dial failure. This enum should contain simple variants which describe the kind of failure that /// occurred @@ -76,20 +78,24 @@ pub enum SendFailReason { } /// Events emitted by the messaging protocol. -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum MessagingEvent { MessageReceived(NodeId, MessageTag), - InvalidMessageReceived(NodeId), OutboundProtocolExited(NodeId), + InboundProtocolExited(NodeId), + ProtocolViolation { peer_node_id: NodeId, details: String }, } impl fmt::Display for MessagingEvent { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use MessagingEvent::{InvalidMessageReceived, MessageReceived, OutboundProtocolExited}; + use MessagingEvent::*; match self { - MessageReceived(node_id, tag) => write!(f, "MessageReceived({}, {})", node_id.short_str(), tag), - InvalidMessageReceived(node_id) => write!(f, "InvalidMessageReceived({})", node_id.short_str()), + MessageReceived(node_id, tag) => write!(f, "MessageReceived({}, {})", node_id, tag), OutboundProtocolExited(node_id) => write!(f, "OutboundProtocolExited({})", node_id), + InboundProtocolExited(node_id) => write!(f, "InboundProtocolExited({})", node_id), + ProtocolViolation { peer_node_id, details } => { + write!(f, "ProtocolViolation({}, {})", peer_node_id, details) + }, } } } @@ -99,8 +105,11 @@ pub struct MessagingProtocol { connectivity: ConnectivityRequester, proto_notification: mpsc::Receiver>, active_queues: HashMap>, + active_inbound: HashMap>, outbound_message_rx: mpsc::UnboundedReceiver, messaging_events_tx: MessagingEventSender, + enable_message_received_event: bool, + ban_duration: Option, inbound_message_tx: mpsc::Sender, internal_messaging_event_tx: mpsc::Sender, internal_messaging_event_rx: mpsc::Receiver, @@ -128,10 +137,13 @@ impl MessagingProtocol { connectivity, proto_notification, outbound_message_rx, + active_inbound: Default::default(), active_queues: Default::default(), messaging_events_tx, + enable_message_received_event: false, internal_messaging_event_rx, internal_messaging_event_tx, + ban_duration: None, retry_queue_tx, retry_queue_rx, inbound_message_tx, @@ -140,6 +152,19 @@ impl MessagingProtocol { } } + /// Set to true to enable emitting the 
MessageReceived event for each message received. Typically only useful in + /// tests. + pub fn set_message_received_event_enabled(mut self, enabled: bool) -> Self { + self.enable_message_received_event = enabled; + self + } + + /// Sets a custom ban duration. Banning is disabled by default. + pub fn with_ban_duration(mut self, ban_duration: Duration) -> Self { + self.ban_duration = Some(ban_duration); + self + } + /// Returns a signal that resolves when this actor exits. pub fn complete_signal(&self) -> ShutdownSignal { self.complete_trigger.to_signal() @@ -152,7 +177,7 @@ impl MessagingProtocol { loop { tokio::select! { Some(event) = self.internal_messaging_event_rx.recv() => { - self.handle_internal_messaging_event(event); + self.handle_internal_messaging_event(event).await; }, Some(msg) = self.retry_queue_rx.recv() => { @@ -193,17 +218,17 @@ impl MessagingProtocol { framing::canonical(socket, MAX_FRAME_LENGTH) } - fn handle_internal_messaging_event(&mut self, event: MessagingEvent) { - use MessagingEvent::OutboundProtocolExited; + async fn handle_internal_messaging_event(&mut self, event: MessagingEvent) { + use MessagingEvent::*; trace!(target: LOG_TARGET, "Internal messaging event '{}'", event); - match event { + match &event { OutboundProtocolExited(node_id) => { debug!( target: LOG_TARGET, "Outbound protocol handler exited for peer `{}`", node_id.short_str() ); - if self.active_queues.remove(&node_id).is_none() { + if self.active_queues.remove(node_id).is_none() { debug!( target: LOG_TARGET, "OutboundProtocolExited event, but MessagingProtocol has no record of the outbound protocol \ @@ -211,13 +236,30 @@ impl MessagingProtocol { node_id.short_str() ); } - let _result = self.messaging_events_tx.send(Arc::new(OutboundProtocolExited(node_id))); }, - evt => { - // Forward the event - let _result = self.messaging_events_tx.send(Arc::new(evt)); + InboundProtocolExited(node_id) => { + debug!( + target: LOG_TARGET, + "Inbound protocol handler exited for peer `{}`", + node_id.short_str() + ); + if self.active_inbound.remove(node_id).is_none() { + debug!( + target: LOG_TARGET, + "InboundProtocolExited event, but MessagingProtocol has no record of the inbound protocol \ + for peer `{}`", + node_id.short_str() + ); + } + }, + ProtocolViolation { peer_node_id, details } => { + self.ban_peer(peer_node_id.clone(), details.to_string()).await; }, + _ => {}, } + + // Forward the event + let _result = self.messaging_events_tx.send(event); } fn handle_retry_queue_messages(&mut self, msg: OutboundMessage) -> Result<(), MessagingProtocolError> { @@ -281,10 +323,27 @@ impl MessagingProtocol { } fn spawn_inbound_handler(&mut self, peer: NodeId, substream: Substream) { + if let Some(handle) = self.active_inbound.get(&peer) { + if handle.is_finished() { + self.active_inbound.remove(&peer); + } else { + debug!( + target: LOG_TARGET, + "InboundMessaging for peer '{}' already exists", peer.short_str() + ); + return; + } + } let messaging_events_tx = self.messaging_events_tx.clone(); let inbound_message_tx = self.inbound_message_tx.clone(); - let inbound_messaging = InboundMessaging::new(peer, inbound_message_tx, messaging_events_tx); - tokio::spawn(inbound_messaging.run(substream)); + let inbound_messaging = InboundMessaging::new( + peer.clone(), + inbound_message_tx, + messaging_events_tx, + self.enable_message_received_event, + ); + let handle = tokio::spawn(inbound_messaging.run(substream)); + self.active_inbound.insert(peer, handle); } fn handle_protocol_notification(&mut self, notification: 
ProtocolNotification) { @@ -301,4 +360,35 @@ impl MessagingProtocol { }, } } + + async fn ban_peer(&mut self, peer_node_id: NodeId, reason: T) { + warn!( + target: LOG_TARGET, + "Banning peer '{}' because it violated the messaging protocol: {}", peer_node_id.short_str(), reason + ); + if let Some(handle) = self.active_inbound.remove(&peer_node_id) { + handle.abort(); + } + drop(self.active_queues.remove(&peer_node_id)); + match self.ban_duration { + Some(ban_duration) => { + if let Err(err) = self + .connectivity + .ban_peer_until(peer_node_id.clone(), ban_duration, reason.to_string()) + .await + { + error!( + target: LOG_TARGET, + "Failed to ban peer '{}' because '{:?}'", peer_node_id.short_str(), err + ); + } + }, + None => { + warn!( + target: LOG_TARGET, + "Banning disabled in MessagingProtocol, so peer '{peer_node_id}' will not be banned (reason: {reason})", + ); + }, + } + } } diff --git a/comms/core/src/protocol/messaging/test.rs b/comms/core/src/protocol/messaging/test.rs index 50deff8ae7..be35886aed 100644 --- a/comms/core/src/protocol/messaging/test.rs +++ b/comms/core/src/protocol/messaging/test.rs @@ -33,7 +33,7 @@ use tokio::{ time, }; -use super::protocol::{MessagingEvent, MessagingEventReceiver, MessagingProtocol, MESSAGING_PROTOCOL}; +use super::protocol::{MessagingEvent, MessagingEventReceiver, MessagingProtocol, MESSAGING_PROTOCOL_ID}; use crate::{ message::{InboundMessage, MessageTag, MessagingReplyRx, OutboundMessage}, multiplexing::Substream, @@ -81,7 +81,8 @@ async fn spawn_messaging_protocol() -> ( events_tx, inbound_msg_tx, shutdown.to_signal(), - ); + ) + .set_message_received_event_enabled(true); tokio::spawn(msg_proto.run()); ( @@ -123,7 +124,7 @@ async fn new_inbound_substream_handling() { let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); proto_tx .send(ProtocolNotification::new( - MESSAGING_PROTOCOL.clone(), + MESSAGING_PROTOCOL_ID.clone(), ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours), )) .await @@ -146,7 +147,7 @@ async fn new_inbound_substream_handling() { .await .unwrap() .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(node_id, tag) = &*event); + unpack_enum!(MessagingEvent::MessageReceived(node_id, tag) = &event); assert_eq!(tag, &expected_tag); assert_eq!(*node_id, expected_node_id); } @@ -188,7 +189,7 @@ async fn send_message_dial_failed() { request_tx.send(out_msg).unwrap(); let event = event_tx.recv().await.unwrap(); - unpack_enum!(MessagingEvent::OutboundProtocolExited(_node_id) = &*event); + unpack_enum!(MessagingEvent::OutboundProtocolExited(_node_id) = &event); let reply = reply_rx.await.unwrap().unwrap_err(); unpack_enum!(SendFailReason::PeerDialFailed = reply); @@ -260,7 +261,7 @@ async fn send_message_substream_bulk_failure() { .await .unwrap() .unwrap(); - unpack_enum!(MessagingEvent::OutboundProtocolExited(node_id) = &*event); + unpack_enum!(MessagingEvent::OutboundProtocolExited(node_id) = &event); assert_eq!(node_id, peer_node_id); } @@ -339,3 +340,93 @@ async fn many_concurrent_send_message_requests_that_fail() { let results = unordered.collect::>().await; assert!(results.into_iter().map(|r| r.unwrap()).all(|r| r.is_err())); } + +#[tokio::test] +async fn new_inbound_substream_only_single_session_permitted() { + let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, _, _shutdown) = spawn_messaging_protocol().await; + + let expected_node_id = node_id::random(); + let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng); + peer_manager + .add_peer(Peer::new( + pk.clone(), + 
expected_node_id.clone(), + MultiaddressesWithStats::default(), + PeerFlags::empty(), + PeerFeatures::COMMUNICATION_CLIENT, + Default::default(), + Default::default(), + )) + .await + .unwrap(); + + // Create connected memory sockets - we use each end of the connection as if they exist on different nodes + let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await; + + // Notify the messaging protocol that a new substream has been established that wants to talk the messaging. + let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + proto_tx + .send(ProtocolNotification::new( + MESSAGING_PROTOCOL_ID.clone(), + ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours), + )) + .await + .unwrap(); + + // First stream is open + let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); + + // Open another one for messaging + let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + proto_tx + .send(ProtocolNotification::new( + MESSAGING_PROTOCOL_ID.clone(), + ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours2), + )) + .await + .unwrap(); + + // Check the second stream closes immediately + let stream_theirs2 = muxer_theirs.incoming_mut().next().await.unwrap(); + let mut framed_ours2 = MessagingProtocol::framed(stream_theirs2); + let next = framed_ours2.next().await; + // The stream is closed + assert!(next.is_none()); + + // The first stream is still active + let mut framed_theirs = MessagingProtocol::framed(stream_theirs); + + framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); + + let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(in_msg.source_peer, expected_node_id); + assert_eq!(in_msg.body, TEST_MSG1); + + // Close the first + framed_theirs.close().await.unwrap(); + + // Open another one for messaging + let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + proto_tx + .send(ProtocolNotification::new( + MESSAGING_PROTOCOL_ID.clone(), + ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours2), + )) + .await + .unwrap(); + + let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); + let mut framed_theirs = MessagingProtocol::framed(stream_theirs); + framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); + + // The second message comes through + let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(in_msg.source_peer, expected_node_id); + assert_eq!(in_msg.body, TEST_MSG1); +} diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index dc5dc6fbe7..2eb5550acc 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -465,7 +465,7 @@ pub async fn do_store_and_forward_message_propagation( let msg = time::timeout(Duration::from_secs(2), s.recv()).await; match msg { Ok(Ok(evt)) => { - if let MessagingEvent::MessageReceived(_, tag) = &*evt { + if let MessagingEvent::MessageReceived(_, tag) = &evt { println!("{} received propagated SAF message ({})", neighbour, tag); } }, @@ -745,11 +745,9 @@ impl TestNode { loop { let event = messaging_events.recv().await; use MessagingEvent::MessageReceived; - match event.as_deref() { + match event { Ok(MessageReceived(peer_node_id, _)) => { - messaging_events_tx - .send((Clone::clone(peer_node_id), node_id.clone())) - .unwrap(); + 
messaging_events_tx.send((peer_node_id, node_id.clone())).unwrap(); }, Err(broadcast::error::RecvError::Closed) => { break; @@ -970,7 +968,9 @@ async fn setup_comms_dht( let (messaging_events_tx, _) = broadcast::channel(100); let comms = comms .add_rpc_server(RpcServer::new().add_service(dht.rpc_service())) - .add_protocol_extension(MessagingProtocolExtension::new(messaging_events_tx.clone(), pipeline)) + .add_protocol_extension( + MessagingProtocolExtension::new(messaging_events_tx.clone(), pipeline).enable_message_received_event(), + ) .spawn_with_transport(MemoryTransport) .await .unwrap(); diff --git a/comms/dht/src/peer_validator.rs b/comms/dht/src/peer_validator.rs index 9c4b37a958..a0252ca86e 100644 --- a/comms/dht/src/peer_validator.rs +++ b/comms/dht/src/peer_validator.rs @@ -26,6 +26,7 @@ use tari_comms::{ peer_manager::{NodeId, Peer, PeerFlags}, peer_validator, peer_validator::{find_most_recent_claim, PeerValidatorError}, + protocol::messaging::MESSAGING_PROTOCOL_ID, }; use crate::{rpc::UnvalidatedPeerInfo, DhtConfig}; @@ -89,13 +90,16 @@ impl<'a> PeerValidator<'a> { let node_id = NodeId::from_public_key(&new_peer.public_key); let mut peer = existing_peer.unwrap_or_else(|| { + // All peer speak messaging protocol, as an optimisation we'll include it here so that we can do optimistic + // protocol negotiation. + let default_protocols = vec![MESSAGING_PROTOCOL_ID.clone()]; Peer::new( new_peer.public_key.clone(), node_id, MultiaddressesWithStats::default(), PeerFlags::default(), most_recent_claim.features, - vec![], + default_protocols, String::new(), ) }); diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index bbcbd6edb4..2eb654af18 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
mod harness; -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use harness::*; use tari_comms::{ @@ -244,7 +244,7 @@ async fn test_dht_store_forward() { .unwrap(); // Wait for node C to and receive a response from the SAF request let event = collect_try_recv!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = event.get(0).unwrap().as_ref()); + unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = event.get(0).unwrap()); let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); assert_eq!( @@ -908,21 +908,18 @@ async fn test_dht_header_not_malleable() { node_C.shutdown().await; } -fn filter_received(events: Vec>) -> Vec> { +fn filter_received(events: Vec) -> Vec { events .into_iter() - .filter(|e| match &**e { - MessagingEvent::MessageReceived(_, _) => true, - _ => unreachable!(), - }) + .filter(|e| matches!(e, MessagingEvent::MessageReceived(_, _))) .collect() } -fn count_messages_received(events: &[Arc], node_ids: &[&NodeId]) -> usize { +fn count_messages_received(events: &[MessagingEvent], node_ids: &[&NodeId]) -> usize { events .iter() .filter(|event| { - unpack_enum!(MessagingEvent::MessageReceived(recv_node_id, _tag) = &***event); + unpack_enum!(MessagingEvent::MessageReceived(recv_node_id, _tag) = &**event); node_ids.iter().any(|n| recv_node_id == *n) }) .count() diff --git a/comms/dht/tests/harness.rs b/comms/dht/tests/harness.rs index bc9abbc8d5..b9baea0f60 100644 --- a/comms/dht/tests/harness.rs +++ b/comms/dht/tests/harness.rs @@ -52,7 +52,7 @@ pub struct TestNode { pub comms: CommsNode, pub dht: Dht, pub inbound_messages: mpsc::Receiver, - pub messaging_events: broadcast::Sender>, + pub messaging_events: broadcast::Sender, pub shutdown: Shutdown, } @@ -197,7 +197,9 @@ pub async fn setup_comms_dht( let (event_tx, _) = broadcast::channel(100); let comms = comms - .add_protocol_extension(MessagingProtocolExtension::new(event_tx.clone(), pipeline)) + .add_protocol_extension( + MessagingProtocolExtension::new(event_tx.clone(), pipeline).enable_message_received_event(), + ) .spawn_with_transport(MemoryTransport) .await .unwrap();
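Reviewer note (not part of the patch): the sketch below shows how a caller might wire up the two new MessagingProtocolExtension builder options and consume the messaging event stream, which after this change carries plain (Clone) MessagingEvent values instead of Arc-wrapped ones and gains two new variants. The `comms`, `event_tx` and `pipeline` bindings in the commented fragment are placeholders for the existing setup in initialization.rs, and `spawn_event_logger` is a hypothetical helper used purely for illustration.

// 1) Wiring the new builder options (mirrors the `initialization.rs` hunk above; `comms`,
//    `event_tx` and `pipeline` are placeholders for the existing setup, not names added by this PR):
//
//     let comms = comms.add_protocol_extension(
//         MessagingProtocolExtension::new(event_tx.clone(), pipeline)
//             .enable_message_received_event()             // MessageReceived is off by default
//             .with_ban_duration(Duration::from_secs(60)), // extension default is 10 minutes
//     );
//
// 2) Consuming the event stream (hypothetical helper, exercising the new variants):

use tari_comms::protocol::messaging::{MessagingEvent, MessagingEventSender};

fn spawn_event_logger(event_tx: &MessagingEventSender) -> tokio::task::JoinHandle<()> {
    let mut events = event_tx.subscribe();
    tokio::spawn(async move {
        // Stops when the channel closes or this receiver lags too far behind.
        while let Ok(event) = events.recv().await {
            match event {
                // Only emitted when `enable_message_received_event()` was called (tests).
                MessagingEvent::MessageReceived(node_id, tag) => {
                    println!("received message {} from {}", tag, node_id);
                },
                // Emitted just before the offending peer is banned, e.g. when an inbound
                // frame exceeds MAX_FRAME_LENGTH (8 MiB).
                MessagingEvent::ProtocolViolation { peer_node_id, details } => {
                    println!("protocol violation by {}: {}", peer_node_id, details);
                },
                MessagingEvent::InboundProtocolExited(node_id) => {
                    println!("inbound messaging session ended for {}", node_id);
                },
                MessagingEvent::OutboundProtocolExited(node_id) => {
                    println!("outbound messaging session ended for {}", node_id);
                },
            }
        }
    })
}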