diff --git a/base_layer/core/tests/tests/node_service.rs b/base_layer/core/tests/tests/node_service.rs index 002553c4f4..17008373f2 100644 --- a/base_layer/core/tests/tests/node_service.rs +++ b/base_layer/core/tests/tests/node_service.rs @@ -308,13 +308,13 @@ async fn propagate_and_forward_invalid_block_hash() { let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) .await .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); + unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &msg_event); // Bob asks Alice for missing transaction let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) .await .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); + unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &msg_event); assert_eq!(node_id, alice_node.node_identity.node_id()); // Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 5f22b7b422..c3ce05ccb9 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -198,7 +198,9 @@ pub async fn initialize_local_test_comms>( .build(); let comms = comms - .add_protocol_extension(MessagingProtocolExtension::new(event_sender.clone(), pipeline)) + .add_protocol_extension( + MessagingProtocolExtension::new(event_sender.clone(), pipeline).enable_message_received_event(), + ) .spawn_with_transport(MemoryTransport) .await?; @@ -371,10 +373,10 @@ async fn configure_comms_and_dht( .build(); let (messaging_events_sender, _) = broadcast::channel(1); - comms = comms.add_protocol_extension(MessagingProtocolExtension::new( - messaging_events_sender, - messaging_pipeline, - )); + comms = comms.add_protocol_extension( + MessagingProtocolExtension::new(messaging_events_sender, messaging_pipeline) + .with_ban_duration(config.dht.ban_duration_short), + ); Ok((comms, dht)) } diff --git a/comms/core/src/builder/tests.rs b/comms/core/src/builder/tests.rs index a2f3eb657f..527d3fd0ec 100644 --- a/comms/core/src/builder/tests.rs +++ b/comms/core/src/builder/tests.rs @@ -89,16 +89,19 @@ async fn spawn_node( let (messaging_events_sender, _) = broadcast::channel(100); let comms_node = comms_node .add_protocol_extensions(protocols.into()) - .add_protocol_extension(MessagingProtocolExtension::new( - messaging_events_sender.clone(), - pipeline::Builder::new() + .add_protocol_extension( + MessagingProtocolExtension::new( + messaging_events_sender.clone(), + pipeline::Builder::new() // Outbound messages will be forwarded "as is" to outbound messaging .with_outbound_pipeline(outbound_rx, identity) .max_concurrent_inbound_tasks(1) // Inbound messages will be forwarded "as is" to inbound_tx .with_inbound_pipeline(SinkService::new(inbound_tx)) .build(), - )) + ) + .enable_message_received_event(), + ) .spawn_with_transport(MemoryTransport) .await .unwrap(); @@ -251,7 +254,7 @@ async fn peer_to_peer_messaging() { let events = collect_recv!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); events.into_iter().for_each(|m| { - unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &*m); + unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &m); }); // Send NUM_MSGS messages from node 2 to node 1 diff --git a/comms/core/src/protocol/messaging/extension.rs b/comms/core/src/protocol/messaging/extension.rs index 4f31841c7c..cb3f7ecd48 100644 --- a/comms/core/src/protocol/messaging/extension.rs +++ b/comms/core/src/protocol/messaging/extension.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::fmt; +use std::{fmt, time::Duration}; use tokio::sync::mpsc; use tower::Service; @@ -31,7 +31,7 @@ use crate::{ message::InboundMessage, pipeline, protocol::{ - messaging::{protocol::MESSAGING_PROTOCOL, MessagingEventSender}, + messaging::{protocol::MESSAGING_PROTOCOL_ID, MessagingEventSender}, ProtocolExtension, ProtocolExtensionContext, ProtocolExtensionError, @@ -50,11 +50,32 @@ pub const MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE: usize = 30; pub struct MessagingProtocolExtension { event_tx: MessagingEventSender, pipeline: pipeline::Config, + enable_message_received_event: bool, + ban_duration: Duration, } impl MessagingProtocolExtension { pub fn new(event_tx: MessagingEventSender, pipeline: pipeline::Config) -> Self { - Self { event_tx, pipeline } + Self { + event_tx, + pipeline, + enable_message_received_event: false, + ban_duration: Duration::from_secs(10 * 60), + } + } + + /// Enables the MessageReceived event which is disabled by default. This will enable sending the MessageReceived + /// event per message received. This is typically used in tests. If unused it should be disabled to reduce memory + /// usage (not reading the event from the channel). + pub fn enable_message_received_event(mut self) -> Self { + self.enable_message_received_event = true; + self + } + + /// Sets the ban duration for peers that violate protocol. Default is 10 minutes. + pub fn with_ban_duration(mut self, ban_duration: Duration) -> Self { + self.ban_duration = ban_duration; + self } } @@ -70,7 +91,7 @@ where { fn install(mut self: Box, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> { let (proto_tx, proto_rx) = mpsc::channel(MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE); - context.add_protocol(&[MESSAGING_PROTOCOL.clone()], &proto_tx); + context.add_protocol(&[MESSAGING_PROTOCOL_ID.clone()], &proto_tx); let (inbound_message_tx, inbound_message_rx) = mpsc::channel(INBOUND_MESSAGE_BUFFER_SIZE); @@ -82,7 +103,9 @@ where self.event_tx, inbound_message_tx, context.shutdown_signal(), - ); + ) + .set_message_received_event_enabled(self.enable_message_received_event) + .with_ban_duration(self.ban_duration); context.register_complete_signal(messaging.complete_signal()); diff --git a/comms/core/src/protocol/messaging/inbound.rs b/comms/core/src/protocol/messaging/inbound.rs index d62c415d76..99ed90d218 100644 --- a/comms/core/src/protocol/messaging/inbound.rs +++ b/comms/core/src/protocol/messaging/inbound.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::sync::Arc; +use std::io; use futures::StreamExt; use log::*; @@ -38,19 +38,22 @@ const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; pub struct InboundMessaging { peer: NodeId, inbound_message_tx: mpsc::Sender, - messaging_events_tx: broadcast::Sender>, + messaging_events_tx: broadcast::Sender, + enable_message_received_event: bool, } impl InboundMessaging { pub fn new( peer: NodeId, inbound_message_tx: mpsc::Sender, - messaging_events_tx: broadcast::Sender>, + messaging_events_tx: broadcast::Sender, + enable_message_received_event: bool, ) -> Self { Self { peer, inbound_message_tx, messaging_events_tx, + enable_message_received_event, } } @@ -83,21 +86,39 @@ impl InboundMessaging { msg_len ); - let event = MessagingEvent::MessageReceived(inbound_msg.source_peer.clone(), inbound_msg.tag); + let message_tag = inbound_msg.tag; - if let Err(err) = self.inbound_message_tx.send(inbound_msg).await { - let tag = err.0.tag; + if self.inbound_message_tx.send(inbound_msg).await.is_err() { warn!( target: LOG_TARGET, "Failed to send InboundMessage {} for peer '{}' because inbound message channel closed", - tag, + message_tag, peer.short_str(), ); break; } - let _result = self.messaging_events_tx.send(Arc::new(event)); + if self.enable_message_received_event { + let _result = self + .messaging_events_tx + .send(MessagingEvent::MessageReceived(peer.clone(), message_tag)); + } + }, + // LengthDelimitedCodec emits a InvalidData io error when the message length exceeds the maximum allowed + Err(err) if err.kind() == io::ErrorKind::InvalidData => { + metrics::error_count(peer).inc(); + debug!( + target: LOG_TARGET, + "Failed to receive from peer '{}' because '{}'", + peer.short_str(), + err + ); + let _result = self.messaging_events_tx.send(MessagingEvent::ProtocolViolation { + peer_node_id: peer.clone(), + details: err.to_string(), + }); + break; }, Err(err) => { metrics::error_count(peer).inc(); @@ -112,6 +133,9 @@ impl InboundMessaging { } } + let _ignore = self + .messaging_events_tx + .send(MessagingEvent::InboundProtocolExited(peer.clone())); metrics::num_sessions().dec(); debug!( target: LOG_TARGET, diff --git a/comms/core/src/protocol/messaging/mod.rs b/comms/core/src/protocol/messaging/mod.rs index 9b45008474..3410c2b60d 100644 --- a/comms/core/src/protocol/messaging/mod.rs +++ b/comms/core/src/protocol/messaging/mod.rs @@ -37,7 +37,14 @@ mod inbound; mod metrics; mod outbound; mod protocol; -pub use protocol::{MessagingEvent, MessagingEventReceiver, MessagingEventSender, MessagingProtocol, SendFailReason}; +pub use protocol::{ + MessagingEvent, + MessagingEventReceiver, + MessagingEventSender, + MessagingProtocol, + SendFailReason, + MESSAGING_PROTOCOL_ID, +}; #[cfg(test)] mod test; diff --git a/comms/core/src/protocol/messaging/outbound.rs b/comms/core/src/protocol/messaging/outbound.rs index 9f8ff2831e..7160ad5c21 100644 --- a/comms/core/src/protocol/messaging/outbound.rs +++ b/comms/core/src/protocol/messaging/outbound.rs @@ -33,7 +33,7 @@ use crate::{ message::OutboundMessage, multiplexing::Substream, peer_manager::NodeId, - protocol::messaging::protocol::MESSAGING_PROTOCOL, + protocol::messaging::protocol::MESSAGING_PROTOCOL_ID, stream_id::StreamId, }; @@ -123,7 +123,7 @@ impl OutboundMessaging { } metrics::num_sessions().dec(); - let _ = messaging_events_tx + let _ignore = messaging_events_tx .send(MessagingEvent::OutboundProtocolExited(peer_node_id)) .await; } @@ -223,7 +223,7 @@ impl OutboundMessaging { &mut self, conn: &mut PeerConnection, ) -> Result, MessagingProtocolError> { - match conn.open_substream(&MESSAGING_PROTOCOL).await { + match conn.open_substream(&MESSAGING_PROTOCOL_ID).await { Ok(substream) => Ok(substream), Err(err) => { debug!( diff --git a/comms/core/src/protocol/messaging/protocol.rs b/comms/core/src/protocol/messaging/protocol.rs index a11d6e57e1..cc6078242b 100644 --- a/comms/core/src/protocol/messaging/protocol.rs +++ b/comms/core/src/protocol/messaging/protocol.rs @@ -23,7 +23,8 @@ use std::{ collections::{hash_map::Entry, HashMap}, fmt, - sync::Arc, + fmt::Display, + time::Duration, }; use bytes::Bytes; @@ -33,6 +34,7 @@ use thiserror::Error; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::{broadcast, mpsc}, + task::JoinHandle, }; use tokio_util::codec::{Framed, LengthDelimitedCodec}; @@ -51,13 +53,13 @@ use crate::{ }; const LOG_TARGET: &str = "comms::protocol::messaging"; -pub(super) static MESSAGING_PROTOCOL: Bytes = Bytes::from_static(b"t/msg/0.1"); +pub static MESSAGING_PROTOCOL_ID: Bytes = Bytes::from_static(b"t/msg/0.1"); const INTERNAL_MESSAGING_EVENT_CHANNEL_SIZE: usize = 10; const MAX_FRAME_LENGTH: usize = 8 * 1_024 * 1_024; -pub type MessagingEventSender = broadcast::Sender>; -pub type MessagingEventReceiver = broadcast::Receiver>; +pub type MessagingEventSender = broadcast::Sender; +pub type MessagingEventReceiver = broadcast::Receiver; /// The reason for dial failure. This enum should contain simple variants which describe the kind of failure that /// occurred @@ -76,20 +78,24 @@ pub enum SendFailReason { } /// Events emitted by the messaging protocol. -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum MessagingEvent { MessageReceived(NodeId, MessageTag), - InvalidMessageReceived(NodeId), OutboundProtocolExited(NodeId), + InboundProtocolExited(NodeId), + ProtocolViolation { peer_node_id: NodeId, details: String }, } impl fmt::Display for MessagingEvent { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use MessagingEvent::{InvalidMessageReceived, MessageReceived, OutboundProtocolExited}; + use MessagingEvent::*; match self { - MessageReceived(node_id, tag) => write!(f, "MessageReceived({}, {})", node_id.short_str(), tag), - InvalidMessageReceived(node_id) => write!(f, "InvalidMessageReceived({})", node_id.short_str()), + MessageReceived(node_id, tag) => write!(f, "MessageReceived({}, {})", node_id, tag), OutboundProtocolExited(node_id) => write!(f, "OutboundProtocolExited({})", node_id), + InboundProtocolExited(node_id) => write!(f, "InboundProtocolExited({})", node_id), + ProtocolViolation { peer_node_id, details } => { + write!(f, "ProtocolViolation({}, {})", peer_node_id, details) + }, } } } @@ -99,8 +105,11 @@ pub struct MessagingProtocol { connectivity: ConnectivityRequester, proto_notification: mpsc::Receiver>, active_queues: HashMap>, + active_inbound: HashMap>, outbound_message_rx: mpsc::UnboundedReceiver, messaging_events_tx: MessagingEventSender, + enable_message_received_event: bool, + ban_duration: Option, inbound_message_tx: mpsc::Sender, internal_messaging_event_tx: mpsc::Sender, internal_messaging_event_rx: mpsc::Receiver, @@ -128,10 +137,13 @@ impl MessagingProtocol { connectivity, proto_notification, outbound_message_rx, + active_inbound: Default::default(), active_queues: Default::default(), messaging_events_tx, + enable_message_received_event: false, internal_messaging_event_rx, internal_messaging_event_tx, + ban_duration: None, retry_queue_tx, retry_queue_rx, inbound_message_tx, @@ -140,6 +152,19 @@ impl MessagingProtocol { } } + /// Set to true to enable emitting the MessageReceived event for each message received. Typically only useful in + /// tests. + pub fn set_message_received_event_enabled(mut self, enabled: bool) -> Self { + self.enable_message_received_event = enabled; + self + } + + /// Sets a custom ban duration. Banning is disabled by default. + pub fn with_ban_duration(mut self, ban_duration: Duration) -> Self { + self.ban_duration = Some(ban_duration); + self + } + /// Returns a signal that resolves when this actor exits. pub fn complete_signal(&self) -> ShutdownSignal { self.complete_trigger.to_signal() @@ -152,7 +177,7 @@ impl MessagingProtocol { loop { tokio::select! { Some(event) = self.internal_messaging_event_rx.recv() => { - self.handle_internal_messaging_event(event); + self.handle_internal_messaging_event(event).await; }, Some(msg) = self.retry_queue_rx.recv() => { @@ -193,17 +218,17 @@ impl MessagingProtocol { framing::canonical(socket, MAX_FRAME_LENGTH) } - fn handle_internal_messaging_event(&mut self, event: MessagingEvent) { - use MessagingEvent::OutboundProtocolExited; + async fn handle_internal_messaging_event(&mut self, event: MessagingEvent) { + use MessagingEvent::*; trace!(target: LOG_TARGET, "Internal messaging event '{}'", event); - match event { + match &event { OutboundProtocolExited(node_id) => { debug!( target: LOG_TARGET, "Outbound protocol handler exited for peer `{}`", node_id.short_str() ); - if self.active_queues.remove(&node_id).is_none() { + if self.active_queues.remove(node_id).is_none() { debug!( target: LOG_TARGET, "OutboundProtocolExited event, but MessagingProtocol has no record of the outbound protocol \ @@ -211,13 +236,30 @@ impl MessagingProtocol { node_id.short_str() ); } - let _result = self.messaging_events_tx.send(Arc::new(OutboundProtocolExited(node_id))); }, - evt => { - // Forward the event - let _result = self.messaging_events_tx.send(Arc::new(evt)); + InboundProtocolExited(node_id) => { + debug!( + target: LOG_TARGET, + "Inbound protocol handler exited for peer `{}`", + node_id.short_str() + ); + if self.active_inbound.remove(node_id).is_none() { + debug!( + target: LOG_TARGET, + "InboundProtocolExited event, but MessagingProtocol has no record of the inbound protocol \ + for peer `{}`", + node_id.short_str() + ); + } + }, + ProtocolViolation { peer_node_id, details } => { + self.ban_peer(peer_node_id.clone(), details.to_string()).await; }, + _ => {}, } + + // Forward the event + let _result = self.messaging_events_tx.send(event); } fn handle_retry_queue_messages(&mut self, msg: OutboundMessage) -> Result<(), MessagingProtocolError> { @@ -281,10 +323,27 @@ impl MessagingProtocol { } fn spawn_inbound_handler(&mut self, peer: NodeId, substream: Substream) { + if let Some(handle) = self.active_inbound.get(&peer) { + if handle.is_finished() { + self.active_inbound.remove(&peer); + } else { + debug!( + target: LOG_TARGET, + "InboundMessaging for peer '{}' already exists", peer.short_str() + ); + return; + } + } let messaging_events_tx = self.messaging_events_tx.clone(); let inbound_message_tx = self.inbound_message_tx.clone(); - let inbound_messaging = InboundMessaging::new(peer, inbound_message_tx, messaging_events_tx); - tokio::spawn(inbound_messaging.run(substream)); + let inbound_messaging = InboundMessaging::new( + peer.clone(), + inbound_message_tx, + messaging_events_tx, + self.enable_message_received_event, + ); + let handle = tokio::spawn(inbound_messaging.run(substream)); + self.active_inbound.insert(peer, handle); } fn handle_protocol_notification(&mut self, notification: ProtocolNotification) { @@ -301,4 +360,35 @@ impl MessagingProtocol { }, } } + + async fn ban_peer(&mut self, peer_node_id: NodeId, reason: T) { + warn!( + target: LOG_TARGET, + "Banning peer '{}' because it violated the messaging protocol: {}", peer_node_id.short_str(), reason + ); + if let Some(handle) = self.active_inbound.remove(&peer_node_id) { + handle.abort(); + } + drop(self.active_queues.remove(&peer_node_id)); + match self.ban_duration { + Some(ban_duration) => { + if let Err(err) = self + .connectivity + .ban_peer_until(peer_node_id.clone(), ban_duration, reason.to_string()) + .await + { + error!( + target: LOG_TARGET, + "Failed to ban peer '{}' because '{:?}'", peer_node_id.short_str(), err + ); + } + }, + None => { + warn!( + target: LOG_TARGET, + "Banning disabled in MessagingProtocol, so peer '{peer_node_id}' will not be banned (reason: {reason})", + ); + }, + } + } } diff --git a/comms/core/src/protocol/messaging/test.rs b/comms/core/src/protocol/messaging/test.rs index 50deff8ae7..be35886aed 100644 --- a/comms/core/src/protocol/messaging/test.rs +++ b/comms/core/src/protocol/messaging/test.rs @@ -33,7 +33,7 @@ use tokio::{ time, }; -use super::protocol::{MessagingEvent, MessagingEventReceiver, MessagingProtocol, MESSAGING_PROTOCOL}; +use super::protocol::{MessagingEvent, MessagingEventReceiver, MessagingProtocol, MESSAGING_PROTOCOL_ID}; use crate::{ message::{InboundMessage, MessageTag, MessagingReplyRx, OutboundMessage}, multiplexing::Substream, @@ -81,7 +81,8 @@ async fn spawn_messaging_protocol() -> ( events_tx, inbound_msg_tx, shutdown.to_signal(), - ); + ) + .set_message_received_event_enabled(true); tokio::spawn(msg_proto.run()); ( @@ -123,7 +124,7 @@ async fn new_inbound_substream_handling() { let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); proto_tx .send(ProtocolNotification::new( - MESSAGING_PROTOCOL.clone(), + MESSAGING_PROTOCOL_ID.clone(), ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours), )) .await @@ -146,7 +147,7 @@ async fn new_inbound_substream_handling() { .await .unwrap() .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(node_id, tag) = &*event); + unpack_enum!(MessagingEvent::MessageReceived(node_id, tag) = &event); assert_eq!(tag, &expected_tag); assert_eq!(*node_id, expected_node_id); } @@ -188,7 +189,7 @@ async fn send_message_dial_failed() { request_tx.send(out_msg).unwrap(); let event = event_tx.recv().await.unwrap(); - unpack_enum!(MessagingEvent::OutboundProtocolExited(_node_id) = &*event); + unpack_enum!(MessagingEvent::OutboundProtocolExited(_node_id) = &event); let reply = reply_rx.await.unwrap().unwrap_err(); unpack_enum!(SendFailReason::PeerDialFailed = reply); @@ -260,7 +261,7 @@ async fn send_message_substream_bulk_failure() { .await .unwrap() .unwrap(); - unpack_enum!(MessagingEvent::OutboundProtocolExited(node_id) = &*event); + unpack_enum!(MessagingEvent::OutboundProtocolExited(node_id) = &event); assert_eq!(node_id, peer_node_id); } @@ -339,3 +340,93 @@ async fn many_concurrent_send_message_requests_that_fail() { let results = unordered.collect::>().await; assert!(results.into_iter().map(|r| r.unwrap()).all(|r| r.is_err())); } + +#[tokio::test] +async fn new_inbound_substream_only_single_session_permitted() { + let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, _, _shutdown) = spawn_messaging_protocol().await; + + let expected_node_id = node_id::random(); + let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng); + peer_manager + .add_peer(Peer::new( + pk.clone(), + expected_node_id.clone(), + MultiaddressesWithStats::default(), + PeerFlags::empty(), + PeerFeatures::COMMUNICATION_CLIENT, + Default::default(), + Default::default(), + )) + .await + .unwrap(); + + // Create connected memory sockets - we use each end of the connection as if they exist on different nodes + let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await; + + // Notify the messaging protocol that a new substream has been established that wants to talk the messaging. + let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + proto_tx + .send(ProtocolNotification::new( + MESSAGING_PROTOCOL_ID.clone(), + ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours), + )) + .await + .unwrap(); + + // First stream is open + let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); + + // Open another one for messaging + let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + proto_tx + .send(ProtocolNotification::new( + MESSAGING_PROTOCOL_ID.clone(), + ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours2), + )) + .await + .unwrap(); + + // Check the second stream closes immediately + let stream_theirs2 = muxer_theirs.incoming_mut().next().await.unwrap(); + let mut framed_ours2 = MessagingProtocol::framed(stream_theirs2); + let next = framed_ours2.next().await; + // The stream is closed + assert!(next.is_none()); + + // The first stream is still active + let mut framed_theirs = MessagingProtocol::framed(stream_theirs); + + framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); + + let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(in_msg.source_peer, expected_node_id); + assert_eq!(in_msg.body, TEST_MSG1); + + // Close the first + framed_theirs.close().await.unwrap(); + + // Open another one for messaging + let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap(); + proto_tx + .send(ProtocolNotification::new( + MESSAGING_PROTOCOL_ID.clone(), + ProtocolEvent::NewInboundSubstream(expected_node_id.clone(), stream_ours2), + )) + .await + .unwrap(); + + let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); + let mut framed_theirs = MessagingProtocol::framed(stream_theirs); + framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); + + // The second message comes through + let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(in_msg.source_peer, expected_node_id); + assert_eq!(in_msg.body, TEST_MSG1); +} diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index dc5dc6fbe7..2eb5550acc 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -465,7 +465,7 @@ pub async fn do_store_and_forward_message_propagation( let msg = time::timeout(Duration::from_secs(2), s.recv()).await; match msg { Ok(Ok(evt)) => { - if let MessagingEvent::MessageReceived(_, tag) = &*evt { + if let MessagingEvent::MessageReceived(_, tag) = &evt { println!("{} received propagated SAF message ({})", neighbour, tag); } }, @@ -745,11 +745,9 @@ impl TestNode { loop { let event = messaging_events.recv().await; use MessagingEvent::MessageReceived; - match event.as_deref() { + match event { Ok(MessageReceived(peer_node_id, _)) => { - messaging_events_tx - .send((Clone::clone(peer_node_id), node_id.clone())) - .unwrap(); + messaging_events_tx.send((peer_node_id, node_id.clone())).unwrap(); }, Err(broadcast::error::RecvError::Closed) => { break; @@ -970,7 +968,9 @@ async fn setup_comms_dht( let (messaging_events_tx, _) = broadcast::channel(100); let comms = comms .add_rpc_server(RpcServer::new().add_service(dht.rpc_service())) - .add_protocol_extension(MessagingProtocolExtension::new(messaging_events_tx.clone(), pipeline)) + .add_protocol_extension( + MessagingProtocolExtension::new(messaging_events_tx.clone(), pipeline).enable_message_received_event(), + ) .spawn_with_transport(MemoryTransport) .await .unwrap(); diff --git a/comms/dht/src/peer_validator.rs b/comms/dht/src/peer_validator.rs index 9c4b37a958..a0252ca86e 100644 --- a/comms/dht/src/peer_validator.rs +++ b/comms/dht/src/peer_validator.rs @@ -26,6 +26,7 @@ use tari_comms::{ peer_manager::{NodeId, Peer, PeerFlags}, peer_validator, peer_validator::{find_most_recent_claim, PeerValidatorError}, + protocol::messaging::MESSAGING_PROTOCOL_ID, }; use crate::{rpc::UnvalidatedPeerInfo, DhtConfig}; @@ -89,13 +90,16 @@ impl<'a> PeerValidator<'a> { let node_id = NodeId::from_public_key(&new_peer.public_key); let mut peer = existing_peer.unwrap_or_else(|| { + // All peer speak messaging protocol, as an optimisation we'll include it here so that we can do optimistic + // protocol negotiation. + let default_protocols = vec![MESSAGING_PROTOCOL_ID.clone()]; Peer::new( new_peer.public_key.clone(), node_id, MultiaddressesWithStats::default(), PeerFlags::default(), most_recent_claim.features, - vec![], + default_protocols, String::new(), ) }); diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index bbcbd6edb4..2eb654af18 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod harness; -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use harness::*; use tari_comms::{ @@ -244,7 +244,7 @@ async fn test_dht_store_forward() { .unwrap(); // Wait for node C to and receive a response from the SAF request let event = collect_try_recv!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = event.get(0).unwrap().as_ref()); + unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = event.get(0).unwrap()); let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); assert_eq!( @@ -908,21 +908,18 @@ async fn test_dht_header_not_malleable() { node_C.shutdown().await; } -fn filter_received(events: Vec>) -> Vec> { +fn filter_received(events: Vec) -> Vec { events .into_iter() - .filter(|e| match &**e { - MessagingEvent::MessageReceived(_, _) => true, - _ => unreachable!(), - }) + .filter(|e| matches!(e, MessagingEvent::MessageReceived(_, _))) .collect() } -fn count_messages_received(events: &[Arc], node_ids: &[&NodeId]) -> usize { +fn count_messages_received(events: &[MessagingEvent], node_ids: &[&NodeId]) -> usize { events .iter() .filter(|event| { - unpack_enum!(MessagingEvent::MessageReceived(recv_node_id, _tag) = &***event); + unpack_enum!(MessagingEvent::MessageReceived(recv_node_id, _tag) = &**event); node_ids.iter().any(|n| recv_node_id == *n) }) .count() diff --git a/comms/dht/tests/harness.rs b/comms/dht/tests/harness.rs index bc9abbc8d5..b9baea0f60 100644 --- a/comms/dht/tests/harness.rs +++ b/comms/dht/tests/harness.rs @@ -52,7 +52,7 @@ pub struct TestNode { pub comms: CommsNode, pub dht: Dht, pub inbound_messages: mpsc::Receiver, - pub messaging_events: broadcast::Sender>, + pub messaging_events: broadcast::Sender, pub shutdown: Shutdown, } @@ -197,7 +197,9 @@ pub async fn setup_comms_dht( let (event_tx, _) = broadcast::channel(100); let comms = comms - .add_protocol_extension(MessagingProtocolExtension::new(event_tx.clone(), pipeline)) + .add_protocol_extension( + MessagingProtocolExtension::new(event_tx.clone(), pipeline).enable_message_received_event(), + ) .spawn_with_transport(MemoryTransport) .await .unwrap();