From c466644fdf4dc720099217b538d32e8bfb51f23b Mon Sep 17 00:00:00 2001 From: Josh Wilson Date: Wed, 15 Feb 2023 11:52:01 +0100 Subject: [PATCH] feat: limit async during read/write for greater consistency BREAKING CHANGE: consumes receiver on read. --- src/connection.rs | 47 +++++++++++++++++++++--------- src/endpoint.rs | 2 +- src/error.rs | 4 +++ src/wire_msg.rs | 74 +++++++++++++++++++++++++++++------------------ 4 files changed, 84 insertions(+), 43 deletions(-) diff --git a/src/connection.rs b/src/connection.rs index 69247310..74e75054 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -11,7 +11,7 @@ use tokio::sync::mpsc::{Receiver, Sender}; use tracing::{trace, warn}; // TODO: this seems arbitrary - it may need tuned or made configurable. -const INCOMING_MESSAGE_BUFFER_LEN: usize = 10_000; +const INCOMING_MESSAGE_BUFFER_LEN: usize = 1; // Error reason for closing a connection when triggered manually by qp2p apis const QP2P_CLOSED_CONNECTION: &str = "The connection was closed intentionally by qp2p."; @@ -137,13 +137,13 @@ fn build_conn_id(conn: &quinn::Connection) -> String { fn listen_on_uni_streams(connection: quinn::Connection, tx: Sender) { let conn_id = build_conn_id(&connection); - let _ = tokio::spawn(async move { + let _handle = tokio::spawn(async move { trace!("Connection {conn_id}: listening for incoming uni-streams"); loop { // Wait for an incoming stream. let uni = connection.accept_uni().await.map_err(ConnectionError::from); - let mut recv = match uni { + let recv = match uni { Ok(recv) => recv, Err(err) => { // In case of a connection error, there is not much we can do. @@ -160,11 +160,21 @@ fn listen_on_uni_streams(connection: quinn::Connection, tx: Sender) let tx = tx.clone(); // Make sure we are able to process multiple streams in parallel. - let _ = tokio::spawn(async move { - let msg = WireMsg::read_from_stream(&mut recv).await; + let _handle = tokio::spawn(async move { + let reserved_sender = match tx.reserve().await { + Ok(p) => p, + Err(error) => { + tracing::error!( + "Could not reserve sender for new conn msg read: {error:?}" + ); + return; + } + }; + + let msg = WireMsg::read_from_stream(recv).await; // Send away the msg or error - let _ = tx.send(msg.map(|r| (r, None))).await; + reserved_sender.send(msg.map(|r| (r, None))); }); } @@ -176,13 +186,13 @@ fn listen_on_uni_streams(connection: quinn::Connection, tx: Sender) fn listen_on_bi_streams(connection: quinn::Connection, tx: Sender) { let conn_id = build_conn_id(&connection); - let _ = tokio::spawn(async move { + let _handle = tokio::spawn(async move { trace!("Connection {conn_id}: listening for incoming bi-streams"); loop { // Wait for an incoming stream. let bi = connection.accept_bi().await.map_err(ConnectionError::from); - let (send, mut recv) = match bi { + let (send, recv) = match bi { Ok(recv) => recv, Err(err) => { // In case of a connection error, there is not much we can do. @@ -200,13 +210,22 @@ fn listen_on_bi_streams(connection: quinn::Connection, tx: Sender) let conn_id = conn_id.clone(); // Make sure we are able to process multiple streams in parallel. - let _ = tokio::spawn(async move { - let msg = WireMsg::read_from_stream(&mut recv).await; + let _handle = tokio::spawn(async move { + let reserved_sender = match tx.reserve().await { + Ok(p) => p, + Err(error) => { + tracing::error!( + "Could not reserve sender for new conn msg read: {error:?}" + ); + return; + } + }; + let msg = WireMsg::read_from_stream(recv).await; // Pass the stream, so it can be used to respond to the user message. let msg = msg.map(|msg| (msg, Some(SendStream::new(send, conn_id.clone())))); // Send away the msg or error - let _ = tx.send(msg).await; + reserved_sender.send(msg); trace!("Incoming new msg on conn_id={conn_id} sent to user in upper layer"); }); } @@ -347,12 +366,12 @@ impl RecvStream { } /// Parse the message sent by the peer over this stream. - pub async fn read(&mut self) -> Result { + pub async fn read(self) -> Result { self.read_wire_msg().await.map(|v| v.0) } - pub(crate) async fn read_wire_msg(&mut self) -> Result { - WireMsg::read_from_stream(&mut self.inner).await + pub(crate) async fn read_wire_msg(self) -> Result { + WireMsg::read_from_stream(self.inner).await } } diff --git a/src/endpoint.rs b/src/endpoint.rs index a5491dde..2ff81da7 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -150,7 +150,7 @@ pub(super) fn listen_for_incoming_connections( quinn_endpoint: quinn::Endpoint, connection_tx: mpsc::Sender<(Connection, ConnectionIncoming)>, ) { - let _ = tokio::spawn(async move { + let _handle = tokio::spawn(async move { while let Some(quinn_conn) = quinn_endpoint.accept().await { match quinn_conn.await { Ok(connection) => { diff --git a/src/error.rs b/src/error.rs index 97fada29..0e6594a4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -322,6 +322,10 @@ pub enum RecvError { #[error("Connection was lost when trying to receive a message")] ConnectionLost(#[from] ConnectionError), + /// Connection was lost when trying to receive a message. + #[error("Error reading to end of stream")] + ReadToEndError(#[from] quinn::ReadToEndError), + /// Stream was lost when trying to receive a message. #[error("Stream was lost when trying to receive a message")] StreamLost(#[source] StreamError), diff --git a/src/wire_msg.rs b/src/wire_msg.rs index 6f59c414..6a0a8221 100644 --- a/src/wire_msg.rs +++ b/src/wire_msg.rs @@ -12,11 +12,11 @@ use crate::{ utils, }; use bytes::{Bytes, BytesMut}; -use quinn::VarInt; + use serde::{Deserialize, Serialize}; -use std::convert::TryFrom; use std::fmt; - +use std::{convert::TryFrom, time::Instant}; +use tracing::trace; const MSG_HEADER_LEN: usize = 16; const MSG_PROTOCOL_VERSION: u16 = 0x0002; @@ -34,11 +34,28 @@ impl WireMsg { // Read a message's bytes from the provided stream /// # Cancellation safety /// Warning: This method is not cancellation safe! - pub(crate) async fn read_from_stream(recv: &mut quinn::RecvStream) -> Result { + pub(crate) async fn read_from_stream(mut recv: quinn::RecvStream) -> Result { let mut header_bytes = [0; MSG_HEADER_LEN]; recv.read_exact(&mut header_bytes).await?; let msg_header = MsgHeader::from_bytes(header_bytes); + + let start = Instant::now(); + let all_bytes = recv.read_to_end(1024 * 1024 * 100).await?; + + let duration = start.elapsed(); + trace!( + "Incoming new msg. Reading {:?} bytes took: {:?}", + all_bytes.len(), + duration + ); + + if all_bytes.is_empty() { + return Err(RecvError::EmptyMsgPayload); + } + + let mut bytes = Bytes::from(all_bytes); + // https://github.com/rust-lang/rust/issues/70460 for work on a cleaner alternative: #[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))] { @@ -49,30 +66,21 @@ impl WireMsg { let dst_length = msg_header.user_dst_len() as usize; let payload_length = msg_header.user_payload_len() as usize; - let mut header_data = BytesMut::with_capacity(header_length); - let mut dst_data = BytesMut::with_capacity(dst_length); - let mut payload_data = BytesMut::with_capacity(payload_length); - // buffer capacity does not actually give us length, so this sets us up - header_data.resize(header_length, 0); - dst_data.resize(dst_length, 0); - payload_data.resize(payload_length, 0); + // Check we have all the data and we weren't cut short, otherwise + // the following would panic... + if bytes.len() != (header_length + dst_length + payload_length) { + return Err(RecvError::NotEnoughBytes); + } - // fill up our data vecs from the stream - recv.read_exact(&mut header_data).await?; - recv.read_exact(&mut dst_data).await?; - recv.read_exact(&mut payload_data).await?; + let header_data = bytes.split_to(header_length); - // let sender know we won't receive any more. - let _ = recv.stop(VarInt::from_u32(0)); + let dst_data = bytes.split_to(dst_length); + let payload_data = bytes; if payload_data.is_empty() { Err(RecvError::EmptyMsgPayload) } else { - Ok(WireMsg(( - header_data.freeze(), - dst_data.freeze(), - payload_data.freeze(), - ))) + Ok(WireMsg((header_data, dst_data, payload_data))) } } @@ -88,13 +96,23 @@ impl WireMsg { let header_bytes = msg_header.to_bytes(); - // Send the header bytes over QUIC - send_stream.write_all(&header_bytes).await?; + let mut all_bytes = BytesMut::with_capacity( + header_bytes.len() + + msg_header.user_header_len() as usize + + msg_header.user_dst_len() as usize + + msg_header.user_payload_len() as usize, + ); + + all_bytes.extend_from_slice(&header_bytes); + all_bytes.extend_from_slice(msg_head); + all_bytes.extend_from_slice(msg_dst); + all_bytes.extend_from_slice(msg_payload); - // Send message bytes over QUIC - send_stream.write_all(msg_head).await?; - send_stream.write_all(msg_dst).await?; - send_stream.write_all(msg_payload).await?; + let start = Instant::now(); + // Send the header bytes over QUIC + send_stream.write_all(&all_bytes).await?; + let duration = start.elapsed(); + trace!("Writing {:?} bytes took: {duration:?}", all_bytes.len()); Ok(()) }