Skip to content

Commit

Permalink
fix(comms): only permit a single inbound messaging substream per peer (
Browse files Browse the repository at this point in the history
…tari-project#5731)

Description
---
Only permit a single inbound messaging session per peer
Only emit MessageReceived event in tests
Ban peer if they exceed the frame size

Motivation and Context
---
A node may initiate boundless inbound messaging sessions. This PR limits
the number of sessions to one. Tari nodes only initiate a single session
as needed so this shouldn't have any effect on non-malicious nodes.

MessageReceived events would be generated for each message received.
This was only used in tests. Since these events aren't read in non-test
contexts and because it can be very busy, this PR disables them by
default and only enables in tests. This does not affect the message
count in the minotari node status line.

How Has This Been Tested?
---
New unit test that checks sessions are closed as long as the first one
is active.

What process can a PR reviewer use to test or verify this change?
---
Messaging still works as expected (ping-peer, discover-peer etc)
<!-- Checklist -->
<!-- 1. Is the title of your PR in the form that would make nice release
notes? The title, excluding the conventional commit
tag, will be included exactly as is in the CHANGELOG, so please think
about it carefully. -->


Breaking Changes
---

- [x] None
- [ ] Requires data directory on base node to be deleted
- [ ] Requires hard fork
- [ ] Other - Please specify

<!-- Does this include a breaking change? If so, include this line as a
footer -->
<!-- BREAKING CHANGE: Description what the user should do, e.g. delete a
database, resync the chain -->
  • Loading branch information
sdbondi committed Sep 4, 2023
1 parent 660a5c1 commit c91a35f
Show file tree
Hide file tree
Showing 13 changed files with 316 additions and 73 deletions.
4 changes: 2 additions & 2 deletions base_layer/core/tests/tests/node_service.rs
Expand Up @@ -308,13 +308,13 @@ async fn propagate_and_forward_invalid_block_hash() {
let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10))
.await
.unwrap();
unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event);
unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &msg_event);

// Bob asks Alice for missing transaction
let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10))
.await
.unwrap();
unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event);
unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &msg_event);
assert_eq!(node_id, alice_node.node_identity.node_id());

// Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be
Expand Down
12 changes: 7 additions & 5 deletions base_layer/p2p/src/initialization.rs
Expand Up @@ -198,7 +198,9 @@ pub async fn initialize_local_test_comms<P: AsRef<Path>>(
.build();

let comms = comms
.add_protocol_extension(MessagingProtocolExtension::new(event_sender.clone(), pipeline))
.add_protocol_extension(
MessagingProtocolExtension::new(event_sender.clone(), pipeline).enable_message_received_event(),
)
.spawn_with_transport(MemoryTransport)
.await?;

Expand Down Expand Up @@ -371,10 +373,10 @@ async fn configure_comms_and_dht(
.build();

let (messaging_events_sender, _) = broadcast::channel(1);
comms = comms.add_protocol_extension(MessagingProtocolExtension::new(
messaging_events_sender,
messaging_pipeline,
));
comms = comms.add_protocol_extension(
MessagingProtocolExtension::new(messaging_events_sender, messaging_pipeline)
.with_ban_duration(config.dht.ban_duration_short),
);

Ok((comms, dht))
}
Expand Down
13 changes: 8 additions & 5 deletions comms/core/src/builder/tests.rs
Expand Up @@ -89,16 +89,19 @@ async fn spawn_node(
let (messaging_events_sender, _) = broadcast::channel(100);
let comms_node = comms_node
.add_protocol_extensions(protocols.into())
.add_protocol_extension(MessagingProtocolExtension::new(
messaging_events_sender.clone(),
pipeline::Builder::new()
.add_protocol_extension(
MessagingProtocolExtension::new(
messaging_events_sender.clone(),
pipeline::Builder::new()
// Outbound messages will be forwarded "as is" to outbound messaging
.with_outbound_pipeline(outbound_rx, identity)
.max_concurrent_inbound_tasks(1)
// Inbound messages will be forwarded "as is" to inbound_tx
.with_inbound_pipeline(SinkService::new(inbound_tx))
.build(),
))
)
.enable_message_received_event(),
)
.spawn_with_transport(MemoryTransport)
.await
.unwrap();
Expand Down Expand Up @@ -251,7 +254,7 @@ async fn peer_to_peer_messaging() {

let events = collect_recv!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10));
events.into_iter().for_each(|m| {
unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &*m);
unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &m);
});

// Send NUM_MSGS messages from node 2 to node 1
Expand Down
33 changes: 28 additions & 5 deletions comms/core/src/protocol/messaging/extension.rs
Expand Up @@ -20,7 +20,7 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::fmt;
use std::{fmt, time::Duration};

use tokio::sync::mpsc;
use tower::Service;
Expand All @@ -31,7 +31,7 @@ use crate::{
message::InboundMessage,
pipeline,
protocol::{
messaging::{protocol::MESSAGING_PROTOCOL, MessagingEventSender},
messaging::{protocol::MESSAGING_PROTOCOL_ID, MessagingEventSender},
ProtocolExtension,
ProtocolExtensionContext,
ProtocolExtensionError,
Expand All @@ -50,11 +50,32 @@ pub const MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE: usize = 30;
pub struct MessagingProtocolExtension<TInPipe, TOutPipe, TOutReq> {
event_tx: MessagingEventSender,
pipeline: pipeline::Config<TInPipe, TOutPipe, TOutReq>,
enable_message_received_event: bool,
ban_duration: Duration,
}

impl<TInPipe, TOutPipe, TOutReq> MessagingProtocolExtension<TInPipe, TOutPipe, TOutReq> {
pub fn new(event_tx: MessagingEventSender, pipeline: pipeline::Config<TInPipe, TOutPipe, TOutReq>) -> Self {
Self { event_tx, pipeline }
Self {
event_tx,
pipeline,
enable_message_received_event: false,
ban_duration: Duration::from_secs(10 * 60),
}
}

/// Enables the MessageReceived event which is disabled by default. This will enable sending the MessageReceived
/// event per message received. This is typically used in tests. If unused it should be disabled to reduce memory
/// usage (not reading the event from the channel).
pub fn enable_message_received_event(mut self) -> Self {
self.enable_message_received_event = true;
self
}

/// Sets the ban duration for peers that violate protocol. Default is 10 minutes.
pub fn with_ban_duration(mut self, ban_duration: Duration) -> Self {
self.ban_duration = ban_duration;
self
}
}

Expand All @@ -70,7 +91,7 @@ where
{
fn install(mut self: Box<Self>, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> {
let (proto_tx, proto_rx) = mpsc::channel(MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE);
context.add_protocol(&[MESSAGING_PROTOCOL.clone()], &proto_tx);
context.add_protocol(&[MESSAGING_PROTOCOL_ID.clone()], &proto_tx);

let (inbound_message_tx, inbound_message_rx) = mpsc::channel(INBOUND_MESSAGE_BUFFER_SIZE);

Expand All @@ -82,7 +103,9 @@ where
self.event_tx,
inbound_message_tx,
context.shutdown_signal(),
);
)
.set_message_received_event_enabled(self.enable_message_received_event)
.with_ban_duration(self.ban_duration);

context.register_complete_signal(messaging.complete_signal());

Expand Down
40 changes: 32 additions & 8 deletions comms/core/src/protocol/messaging/inbound.rs
Expand Up @@ -20,7 +20,7 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::sync::Arc;
use std::io;

use futures::StreamExt;
use log::*;
Expand All @@ -38,19 +38,22 @@ const LOG_TARGET: &str = "comms::protocol::messaging::inbound";
pub struct InboundMessaging {
peer: NodeId,
inbound_message_tx: mpsc::Sender<InboundMessage>,
messaging_events_tx: broadcast::Sender<Arc<MessagingEvent>>,
messaging_events_tx: broadcast::Sender<MessagingEvent>,
enable_message_received_event: bool,
}

impl InboundMessaging {
pub fn new(
peer: NodeId,
inbound_message_tx: mpsc::Sender<InboundMessage>,
messaging_events_tx: broadcast::Sender<Arc<MessagingEvent>>,
messaging_events_tx: broadcast::Sender<MessagingEvent>,
enable_message_received_event: bool,
) -> Self {
Self {
peer,
inbound_message_tx,
messaging_events_tx,
enable_message_received_event,
}
}

Expand Down Expand Up @@ -83,21 +86,39 @@ impl InboundMessaging {
msg_len
);

let event = MessagingEvent::MessageReceived(inbound_msg.source_peer.clone(), inbound_msg.tag);
let message_tag = inbound_msg.tag;

if let Err(err) = self.inbound_message_tx.send(inbound_msg).await {
let tag = err.0.tag;
if self.inbound_message_tx.send(inbound_msg).await.is_err() {
warn!(
target: LOG_TARGET,
"Failed to send InboundMessage {} for peer '{}' because inbound message channel closed",
tag,
message_tag,
peer.short_str(),
);

break;
}

let _result = self.messaging_events_tx.send(Arc::new(event));
if self.enable_message_received_event {
let _result = self
.messaging_events_tx
.send(MessagingEvent::MessageReceived(peer.clone(), message_tag));
}
},
// LengthDelimitedCodec emits a InvalidData io error when the message length exceeds the maximum allowed
Err(err) if err.kind() == io::ErrorKind::InvalidData => {
metrics::error_count(peer).inc();
debug!(
target: LOG_TARGET,
"Failed to receive from peer '{}' because '{}'",
peer.short_str(),
err
);
let _result = self.messaging_events_tx.send(MessagingEvent::ProtocolViolation {
peer_node_id: peer.clone(),
details: err.to_string(),
});
break;
},
Err(err) => {
metrics::error_count(peer).inc();
Expand All @@ -112,6 +133,9 @@ impl InboundMessaging {
}
}

let _ignore = self
.messaging_events_tx
.send(MessagingEvent::InboundProtocolExited(peer.clone()));
metrics::num_sessions().dec();
debug!(
target: LOG_TARGET,
Expand Down
9 changes: 8 additions & 1 deletion comms/core/src/protocol/messaging/mod.rs
Expand Up @@ -37,7 +37,14 @@ mod inbound;
mod metrics;
mod outbound;
mod protocol;
pub use protocol::{MessagingEvent, MessagingEventReceiver, MessagingEventSender, MessagingProtocol, SendFailReason};
pub use protocol::{
MessagingEvent,
MessagingEventReceiver,
MessagingEventSender,
MessagingProtocol,
SendFailReason,
MESSAGING_PROTOCOL_ID,
};

#[cfg(test)]
mod test;
6 changes: 3 additions & 3 deletions comms/core/src/protocol/messaging/outbound.rs
Expand Up @@ -33,7 +33,7 @@ use crate::{
message::OutboundMessage,
multiplexing::Substream,
peer_manager::NodeId,
protocol::messaging::protocol::MESSAGING_PROTOCOL,
protocol::messaging::protocol::MESSAGING_PROTOCOL_ID,
stream_id::StreamId,
};

Expand Down Expand Up @@ -123,7 +123,7 @@ impl OutboundMessaging {
}

metrics::num_sessions().dec();
let _ = messaging_events_tx
let _ignore = messaging_events_tx
.send(MessagingEvent::OutboundProtocolExited(peer_node_id))
.await;
}
Expand Down Expand Up @@ -223,7 +223,7 @@ impl OutboundMessaging {
&mut self,
conn: &mut PeerConnection,
) -> Result<NegotiatedSubstream<Substream>, MessagingProtocolError> {
match conn.open_substream(&MESSAGING_PROTOCOL).await {
match conn.open_substream(&MESSAGING_PROTOCOL_ID).await {
Ok(substream) => Ok(substream),
Err(err) => {
debug!(
Expand Down

0 comments on commit c91a35f

Please sign in to comment.