Skip to content

Commit

Permalink
feat: add Stream when getting next message from connection
Browse files Browse the repository at this point in the history
Adds a new method `next_with_stream` to `ConnectionIncoming` in order
for users to respond to a bidirectional stream.
  • Loading branch information
hansonkd authored and Yoga07 committed Apr 25, 2022
1 parent 92e311e commit cd3b255
Showing 1 changed file with 77 additions and 61 deletions.
138 changes: 77 additions & 61 deletions src/connection.rs
Expand Up @@ -12,7 +12,7 @@ use futures::{
};
use std::{fmt, net::SocketAddr, pin::Pin, sync::Arc, task, time::Duration};
use tokio::{
sync::{mpsc, watch},
sync::{mpsc, watch, Mutex},
time::timeout,
};
use tracing::{trace, warn};
Expand Down Expand Up @@ -263,7 +263,7 @@ impl fmt::Debug for RecvStream {
/// The receiving API for a connection.
#[derive(Debug)]
pub struct ConnectionIncoming {
message_rx: mpsc::Receiver<Result<Bytes, RecvError>>,
message_rx: mpsc::Receiver<Result<(Bytes, Option<Arc<Mutex<SendStream>>>), RecvError>>,
_alive_tx: Arc<watch::Sender<()>>,
}

Expand Down Expand Up @@ -297,6 +297,17 @@ impl ConnectionIncoming {

/// Get the next message sent by the peer, over any stream.
pub async fn next(&mut self) -> Result<Option<Bytes>, RecvError> {
if let Some((bytes, _opt)) = self.next_with_stream().await? {
Ok(Some(bytes))
} else {
Ok(None)
}
}

/// Get the next message sent by the peer, over any stream along with the stream to respond with.
pub async fn next_with_stream(
&mut self,
) -> Result<Option<(Bytes, Option<Arc<Mutex<SendStream>>>)>, RecvError> {
self.message_rx.recv().await.transpose()
}
}
Expand All @@ -312,7 +323,7 @@ fn start_message_listeners(
uni_streams: quinn::IncomingUniStreams,
bi_streams: quinn::IncomingBiStreams,
alive_rx: watch::Receiver<()>,
message_tx: mpsc::Sender<Result<Bytes, RecvError>>,
message_tx: mpsc::Sender<Result<(Bytes, Option<Arc<Mutex<SendStream>>>), RecvError>>,
) {
let _ = tokio::spawn(listen_on_uni_streams(
peer_addr,
Expand All @@ -334,7 +345,7 @@ async fn listen_on_uni_streams(
peer_addr: SocketAddr,
uni_streams: FilterBenignClose<quinn::IncomingUniStreams>,
mut alive_rx: watch::Receiver<()>,
message_tx: mpsc::Sender<Result<Bytes, RecvError>>,
message_tx: mpsc::Sender<Result<(Bytes, Option<Arc<Mutex<SendStream>>>), RecvError>>,
) {
trace!(
"Started listener for incoming uni-streams from {}",
Expand Down Expand Up @@ -388,7 +399,7 @@ async fn listen_on_uni_streams(
break_ = true;
}

if message_tx.send(result).await.is_err() {
if message_tx.send(result.map(|b| (b, None))).await.is_err() {
// if we can't send the result, the receiving end is closed so we should stop processing
break_ = true;
}
Expand All @@ -408,78 +419,83 @@ async fn listen_on_bi_streams(
peer_addr: SocketAddr,
bi_streams: FilterBenignClose<quinn::IncomingBiStreams>,
mut alive_rx: watch::Receiver<()>,
message_tx: mpsc::Sender<Result<Bytes, RecvError>>,
message_tx: mpsc::Sender<Result<(Bytes, Option<Arc<Mutex<SendStream>>>), RecvError>>,
) {
trace!(
"Started listener for incoming bi-streams from {}",
peer_addr
);

let streaming =
bi_streams.try_for_each_concurrent(None, |(mut send_stream, mut recv_stream)| {
let endpoint = &endpoint;
let message_tx = &message_tx;
async move {
trace!("Handling incoming bi-stream from {}", peer_addr);

loop {
match WireMsg::read_from_stream(&mut recv_stream).await {
Err(error) => {
let mut break_ = false;

if let RecvError::ConnectionLost(_) = &error {
break_ = true;
}

if let Err(error) = message_tx.send(Err(error)).await {
// if we can't send the result, the receiving end is closed so we should stop
trace!("Receiver gone, dropping error: {:?}", error);
break_ = true;
}

if break_ {
break;
}
let streaming = bi_streams.try_for_each_concurrent(None, |(send_stream, mut recv_stream)| {
let endpoint = &endpoint;
let message_tx = &message_tx;
async move {
trace!("Handling incoming bi-stream from {}", peer_addr);
let arc_mutex = Arc::new(Mutex::new(SendStream::new(send_stream)));

loop {
match WireMsg::read_from_stream(&mut recv_stream).await {
Err(error) => {
let mut break_ = false;

if let RecvError::ConnectionLost(_) = &error {
break_ = true;
}
Ok(None) => {
break;

if let Err(error) = message_tx.send(Err(error)).await {
// if we can't send the result, the receiving end is closed so we should stop
trace!("Receiver gone, dropping error: {:?}", error);
break_ = true;
}
Ok(Some(WireMsg::UserMsg(msg))) => {
if let Err(msg) = message_tx.send(Ok(msg)).await {
// if we can't send the result, the receiving end is closed so we should stop
trace!("Receiver gone, dropping message: {:?}", msg);
break;
}

if break_ {
break;
}
Ok(Some(WireMsg::EndpointEchoReq)) => {
if let Err(error) =
handle_endpoint_echo(&mut send_stream, peer_addr).await
{
// TODO: consider more carefully how to handle this
warn!("Error handling endpoint echo request: {}", error);
}
}
Ok(None) => {
break;
}
Ok(Some(WireMsg::UserMsg(msg))) => {
if let Err(msg) = message_tx.send(Ok((msg, Some(arc_mutex.clone())))).await
{
// if we can't send the result, the receiving end is closed so we should stop
trace!("Receiver gone, dropping message: {:?}", msg);
break;
}
Ok(Some(WireMsg::EndpointVerificationReq(addr))) => {
if let Err(error) =
handle_endpoint_verification(endpoint, &mut send_stream, addr).await
{
// TODO: consider more carefully how to handle this
warn!("Error handling endpoint verification request: {}", error);
}
}
Ok(Some(WireMsg::EndpointEchoReq)) => {
if let Err(error) =
handle_endpoint_echo(&mut arc_mutex.lock().await.inner, peer_addr).await
{
// TODO: consider more carefully how to handle this
warn!("Error handling endpoint echo request: {}", error);
}
Ok(msg) => {
}
Ok(Some(WireMsg::EndpointVerificationReq(addr))) => {
if let Err(error) = handle_endpoint_verification(
endpoint,
&mut arc_mutex.lock().await.inner,
addr,
)
.await
{
// TODO: consider more carefully how to handle this
warn!(
"Error on bi-stream: {}",
SerializationError::unexpected(&msg)
);
warn!("Error handling endpoint verification request: {}", error);
}
}
Ok(msg) => {
// TODO: consider more carefully how to handle this
warn!(
"Error on bi-stream: {}",
SerializationError::unexpected(&msg)
);
}
}

Ok(())
}
});

Ok(())
}
});

// it's a shame to allocate, but there are `Pin` errors otherwise – and we should only be doing
// this once.
Expand Down

0 comments on commit cd3b255

Please sign in to comment.