Skip to content

Commit

Permalink
feat: limit async during read/write for greater consistency
Browse files Browse the repository at this point in the history
BREAKING CHANGE: consumes receiver on read.
  • Loading branch information
joshuef authored and bochaco committed Feb 20, 2023
1 parent c45c63a commit c466644
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 43 deletions.
47 changes: 33 additions & 14 deletions src/connection.rs
Expand Up @@ -11,7 +11,7 @@ use tokio::sync::mpsc::{Receiver, Sender};
use tracing::{trace, warn};

// TODO: this seems arbitrary - it may need tuned or made configurable.
const INCOMING_MESSAGE_BUFFER_LEN: usize = 10_000;
const INCOMING_MESSAGE_BUFFER_LEN: usize = 1;

// Error reason for closing a connection when triggered manually by qp2p apis
const QP2P_CLOSED_CONNECTION: &str = "The connection was closed intentionally by qp2p.";
Expand Down Expand Up @@ -137,13 +137,13 @@ fn build_conn_id(conn: &quinn::Connection) -> String {
fn listen_on_uni_streams(connection: quinn::Connection, tx: Sender<IncomingMsg>) {
let conn_id = build_conn_id(&connection);

let _ = tokio::spawn(async move {
let _handle = tokio::spawn(async move {
trace!("Connection {conn_id}: listening for incoming uni-streams");

loop {
// Wait for an incoming stream.
let uni = connection.accept_uni().await.map_err(ConnectionError::from);
let mut recv = match uni {
let recv = match uni {
Ok(recv) => recv,
Err(err) => {
// In case of a connection error, there is not much we can do.
Expand All @@ -160,11 +160,21 @@ fn listen_on_uni_streams(connection: quinn::Connection, tx: Sender<IncomingMsg>)
let tx = tx.clone();

// Make sure we are able to process multiple streams in parallel.
let _ = tokio::spawn(async move {
let msg = WireMsg::read_from_stream(&mut recv).await;
let _handle = tokio::spawn(async move {
let reserved_sender = match tx.reserve().await {
Ok(p) => p,
Err(error) => {
tracing::error!(
"Could not reserve sender for new conn msg read: {error:?}"
);
return;
}
};

let msg = WireMsg::read_from_stream(recv).await;

// Send away the msg or error
let _ = tx.send(msg.map(|r| (r, None))).await;
reserved_sender.send(msg.map(|r| (r, None)));
});
}

Expand All @@ -176,13 +186,13 @@ fn listen_on_uni_streams(connection: quinn::Connection, tx: Sender<IncomingMsg>)
fn listen_on_bi_streams(connection: quinn::Connection, tx: Sender<IncomingMsg>) {
let conn_id = build_conn_id(&connection);

let _ = tokio::spawn(async move {
let _handle = tokio::spawn(async move {
trace!("Connection {conn_id}: listening for incoming bi-streams");

loop {
// Wait for an incoming stream.
let bi = connection.accept_bi().await.map_err(ConnectionError::from);
let (send, mut recv) = match bi {
let (send, recv) = match bi {
Ok(recv) => recv,
Err(err) => {
// In case of a connection error, there is not much we can do.
Expand All @@ -200,13 +210,22 @@ fn listen_on_bi_streams(connection: quinn::Connection, tx: Sender<IncomingMsg>)
let conn_id = conn_id.clone();

// Make sure we are able to process multiple streams in parallel.
let _ = tokio::spawn(async move {
let msg = WireMsg::read_from_stream(&mut recv).await;
let _handle = tokio::spawn(async move {
let reserved_sender = match tx.reserve().await {
Ok(p) => p,
Err(error) => {
tracing::error!(
"Could not reserve sender for new conn msg read: {error:?}"
);
return;
}
};
let msg = WireMsg::read_from_stream(recv).await;

// Pass the stream, so it can be used to respond to the user message.
let msg = msg.map(|msg| (msg, Some(SendStream::new(send, conn_id.clone()))));
// Send away the msg or error
let _ = tx.send(msg).await;
reserved_sender.send(msg);
trace!("Incoming new msg on conn_id={conn_id} sent to user in upper layer");
});
}
Expand Down Expand Up @@ -347,12 +366,12 @@ impl RecvStream {
}

/// Parse the message sent by the peer over this stream.
pub async fn read(&mut self) -> Result<UsrMsgBytes, RecvError> {
pub async fn read(self) -> Result<UsrMsgBytes, RecvError> {
self.read_wire_msg().await.map(|v| v.0)
}

pub(crate) async fn read_wire_msg(&mut self) -> Result<WireMsg, RecvError> {
WireMsg::read_from_stream(&mut self.inner).await
pub(crate) async fn read_wire_msg(self) -> Result<WireMsg, RecvError> {
WireMsg::read_from_stream(self.inner).await
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/endpoint.rs
Expand Up @@ -150,7 +150,7 @@ pub(super) fn listen_for_incoming_connections(
quinn_endpoint: quinn::Endpoint,
connection_tx: mpsc::Sender<(Connection, ConnectionIncoming)>,
) {
let _ = tokio::spawn(async move {
let _handle = tokio::spawn(async move {
while let Some(quinn_conn) = quinn_endpoint.accept().await {
match quinn_conn.await {
Ok(connection) => {
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Expand Up @@ -322,6 +322,10 @@ pub enum RecvError {
#[error("Connection was lost when trying to receive a message")]
ConnectionLost(#[from] ConnectionError),

/// Connection was lost when trying to receive a message.
#[error("Error reading to end of stream")]
ReadToEndError(#[from] quinn::ReadToEndError),

/// Stream was lost when trying to receive a message.
#[error("Stream was lost when trying to receive a message")]
StreamLost(#[source] StreamError),
Expand Down
74 changes: 46 additions & 28 deletions src/wire_msg.rs
Expand Up @@ -12,11 +12,11 @@ use crate::{
utils,
};
use bytes::{Bytes, BytesMut};
use quinn::VarInt;

use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::fmt;

use std::{convert::TryFrom, time::Instant};
use tracing::trace;
const MSG_HEADER_LEN: usize = 16;
const MSG_PROTOCOL_VERSION: u16 = 0x0002;

Expand All @@ -34,11 +34,28 @@ impl WireMsg {
// Read a message's bytes from the provided stream
/// # Cancellation safety
/// Warning: This method is not cancellation safe!
pub(crate) async fn read_from_stream(recv: &mut quinn::RecvStream) -> Result<Self, RecvError> {
pub(crate) async fn read_from_stream(mut recv: quinn::RecvStream) -> Result<Self, RecvError> {
let mut header_bytes = [0; MSG_HEADER_LEN];
recv.read_exact(&mut header_bytes).await?;

let msg_header = MsgHeader::from_bytes(header_bytes);

let start = Instant::now();
let all_bytes = recv.read_to_end(1024 * 1024 * 100).await?;

let duration = start.elapsed();
trace!(
"Incoming new msg. Reading {:?} bytes took: {:?}",
all_bytes.len(),
duration
);

if all_bytes.is_empty() {
return Err(RecvError::EmptyMsgPayload);
}

let mut bytes = Bytes::from(all_bytes);

// https://github.com/rust-lang/rust/issues/70460 for work on a cleaner alternative:
#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
{
Expand All @@ -49,30 +66,21 @@ impl WireMsg {
let dst_length = msg_header.user_dst_len() as usize;
let payload_length = msg_header.user_payload_len() as usize;

let mut header_data = BytesMut::with_capacity(header_length);
let mut dst_data = BytesMut::with_capacity(dst_length);
let mut payload_data = BytesMut::with_capacity(payload_length);
// buffer capacity does not actually give us length, so this sets us up
header_data.resize(header_length, 0);
dst_data.resize(dst_length, 0);
payload_data.resize(payload_length, 0);
// Check we have all the data and we weren't cut short, otherwise
// the following would panic...
if bytes.len() != (header_length + dst_length + payload_length) {
return Err(RecvError::NotEnoughBytes);
}

// fill up our data vecs from the stream
recv.read_exact(&mut header_data).await?;
recv.read_exact(&mut dst_data).await?;
recv.read_exact(&mut payload_data).await?;
let header_data = bytes.split_to(header_length);

// let sender know we won't receive any more.
let _ = recv.stop(VarInt::from_u32(0));
let dst_data = bytes.split_to(dst_length);
let payload_data = bytes;

if payload_data.is_empty() {
Err(RecvError::EmptyMsgPayload)
} else {
Ok(WireMsg((
header_data.freeze(),
dst_data.freeze(),
payload_data.freeze(),
)))
Ok(WireMsg((header_data, dst_data, payload_data)))
}
}

Expand All @@ -88,13 +96,23 @@ impl WireMsg {

let header_bytes = msg_header.to_bytes();

// Send the header bytes over QUIC
send_stream.write_all(&header_bytes).await?;
let mut all_bytes = BytesMut::with_capacity(
header_bytes.len()
+ msg_header.user_header_len() as usize
+ msg_header.user_dst_len() as usize
+ msg_header.user_payload_len() as usize,
);

all_bytes.extend_from_slice(&header_bytes);
all_bytes.extend_from_slice(msg_head);
all_bytes.extend_from_slice(msg_dst);
all_bytes.extend_from_slice(msg_payload);

// Send message bytes over QUIC
send_stream.write_all(msg_head).await?;
send_stream.write_all(msg_dst).await?;
send_stream.write_all(msg_payload).await?;
let start = Instant::now();
// Send the header bytes over QUIC
send_stream.write_all(&all_bytes).await?;
let duration = start.elapsed();
trace!("Writing {:?} bytes took: {duration:?}", all_bytes.len());

Ok(())
}
Expand Down

0 comments on commit c466644

Please sign in to comment.