fix: avoid race condition between pending frames and closing stream #156

Merged
merged 15 commits on May 24, 2023
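In short, as the diff suggests: `Stream::poll_flush` could resolve successfully while data frames were still sitting in the channel between the stream and the connection task, so closing the stream could race with (and drop) those pending frames. The fix makes each stream count the frames it queues: `poll_flush` sends the connection a `StreamCommand::Flush { num_frames, waker }`, and the connection wakes the flushing task only once at least `num_frames` frames of that stream have been moved into its outgoing `pending_frames` queue.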
Changes from 3 commits
1 change: 0 additions & 1 deletion test-harness/Cargo.toml
@@ -16,4 +16,3 @@ log = "0.4.17"
[dev-dependencies]
env_logger = "0.10"
constrained-connection = "0.1"

1 change: 1 addition & 0 deletions yamux/Cargo.toml
@@ -26,6 +26,7 @@ quickcheck = "1.0"
tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
constrained-connection = "0.1"
futures_ringbuf = "0.3.1"

[[bench]]
name = "concurrent"
328 changes: 321 additions & 7 deletions yamux/src/connection.rs
@@ -102,8 +102,9 @@ use cleanup::Cleanup;
use closing::Closing;
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
use nohash_hasher::IntMap;
use std::collections::hash_map::Entry;
use std::collections::VecDeque;
use std::task::{Context, Waker};
use std::{fmt, sync::Arc, task::Poll};

pub use stream::{Packet, State, Stream};
@@ -348,6 +349,8 @@ struct Active<T> {
socket: Fuse<frame::Io<T>>,
next_id: u32,
streams: IntMap<StreamId, Stream>,
/// Stores the "marks" at which we need to notify a waiting flush task of a [`Stream`].
flush_marks: IntMap<StreamId, (u64, Waker)>,
stream_sender: mpsc::Sender<StreamCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
dropped_streams: Vec<StreamId>,
@@ -359,6 +362,13 @@ struct Active<T> {
pub(crate) enum StreamCommand {
/// A new frame should be sent to the remote.
SendFrame(Frame<Either<Data, WindowUpdate>>),
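/// The stream wants to be flushed: wake `waker` once `num_frames` frames have been queued for sending.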
Flush {
id: StreamId,
/// How many frames we've queued for sending at the time the flush was requested.
num_frames: u64,
/// The waker to wake once the flush is complete.
waker: Waker,
},
/// Close a stream.
CloseStream { id: StreamId, ack: bool },
}
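For orientation, here is a rough, self-contained sketch of the stream half of this handshake, which is not part of this diff view. The names (`Command`, `FlushSketch`, `num_queued`, `num_sent`) are illustrative only and do not match yamux's actual `Stream` internals:

use futures::channel::mpsc;
use futures::ready;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};

enum Command {
    // Mirrors `StreamCommand::Flush` from the diff above.
    Flush { num_frames: u64, waker: Waker },
}

struct FlushSketch {
    /// Frames this stream has handed to the connection task so far.
    num_queued: u64,
    /// Bumped by the connection task whenever it queues one of our frames.
    num_sent: Arc<AtomicU64>,
    sender: mpsc::Sender<Command>,
}

impl FlushSketch {
    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), mpsc::SendError>> {
        // Already caught up: everything we queued has been picked up for sending.
        if self.num_sent.load(Ordering::SeqCst) >= self.num_queued {
            return Poll::Ready(Ok(()));
        }

        // Wait until the bounded command channel can accept another message.
        ready!(self.sender.poll_ready(cx))?;

        // Ask the connection task to wake us once `num_queued` frames were sent.
        self.sender.start_send(Command::Flush {
            num_frames: self.num_queued,
            waker: cx.waker().clone(),
        })?;

        Poll::Pending
    }
}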
@@ -416,6 +426,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
config: Arc::new(cfg),
socket,
streams: IntMap::default(),
flush_marks: Default::default(),
stream_sender,
stream_receiver,
next_id: match mode {
@@ -466,6 +477,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
self.on_close_stream(id, ack);
continue;
}
Poll::Ready(Some(StreamCommand::Flush {
id,
num_frames,
waker,
})) => {
self.on_flush_stream(id, num_frames, waker);
continue;
}
Poll::Ready(None) => {
debug_assert!(false, "Only closed during shutdown")
}
@@ -526,13 +545,36 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

fn on_send_frame(&mut self, frame: Frame<Either<Data, WindowUpdate>>) {
let stream_id = frame.header().stream_id();

log::trace!("{}/{}: sending: {}", self.id, stream_id, frame.header());
self.pending_frames.push_back(frame.into());

if let Some(stream) = self.streams.get(&stream_id) {
let mut shared = stream.shared();

shared.inc_sent();

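// A task may be waiting for this stream to be flushed: if our sent-counter
// has reached the requested mark, remove the mark and wake the flusher.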
if let Entry::Occupied(entry) = self.flush_marks.entry(stream_id) {
if shared.num_sent() >= entry.get().0 {
entry.remove().1.wake();
}
}
}
}

fn on_flush_stream(&mut self, id: StreamId, new_flush_mark: u64, waker: Waker) {
if let Some(stream) = self.streams.get(&id) {
let shared = stream.shared();

// Check if we have already reached the requested flush mark:
if shared.num_sent() >= new_flush_mark {
waker.wake();
return;
}

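// Mark not yet reached: store it so `on_send_frame` can wake the flushing
// task once enough of this stream's frames have been queued for sending.
// Note that this overwrites any previously stored mark for the stream.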
self.flush_marks.insert(id, (new_flush_mark, waker));
}
}

fn on_close_stream(&mut self, id: StreamId, ack: bool) {
@@ -934,3 +976,275 @@ impl<T> Active<T> {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::AsyncReadExt;
use futures_ringbuf::Endpoint;
use std::mem;
use std::pin::Pin;

#[tokio::test]
async fn poll_flush_on_stream_only_returns_ok_if_frame_is_queued_for_sending() {
let (client, server) = Endpoint::pair(1000, 1000);

let client = Client::new(Connection::new(client, Config::default(), Mode::Client));
let server = EchoServer::new(Connection::new(server, Config::default(), Mode::Server));

let ((), processed) = futures::future::try_join(client, server).await.unwrap();

assert_eq!(processed, 1);
}
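
    // To run only this test (assuming a standard cargo workspace layout):
    //
    //     cargo test -p yamux poll_flush_on_stream_only_returns_ok_if_frame_is_queued_for_sending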

/// Our testing client.
///
/// This struct opens a single outbound stream, sends a message, attempts to flush it and asserts the internal state of [`Connection`] afterwards.
enum Client {
Initial {
connection: Connection<Endpoint>,
},
Testing {
connection: Connection<Endpoint>,
worker_stream: StreamState,
},
Closing {
connection: Connection<Endpoint>,
},
Poisoned,
}

enum StreamState {
Sending(Stream),
Flushing(Stream),
Receiving(Stream),
Closing(Stream),
}

impl Client {
fn new(connection: Connection<Endpoint>) -> Self {
Self::Initial { connection }
}
}

impl Future for Client {
type Output = Result<()>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();

loop {
match mem::replace(this, Client::Poisoned) {
// This state matching is out of order to have the interesting one at the top.
Client::Testing {
worker_stream: StreamState::Flushing(mut stream),
mut connection,
} => {
match Pin::new(&mut stream).poll_flush(cx)? {
Poll::Ready(()) => {
// Here is the actual test:
// If the stream reports that it successfully flushed, we expect the connection to have queued the frames for sending.
// Because we only have a single stream, this means we can simply assert that there are no pending frames in the channel.

let ConnectionState::Active(active) = &mut connection.inner else {
panic!("Connection is not active")
};

active.stream_receiver.try_next().expect_err(
"expected no pending frames in the channel after flushing",
);

*this = Client::Testing {
worker_stream: StreamState::Receiving(stream),
connection,
};
continue;
}
Poll::Pending => {}
}

drive_connection(this, connection, StreamState::Flushing(stream), cx);
return Poll::Pending;
}
Client::Testing {
worker_stream: StreamState::Receiving(mut stream),
connection,
} => {
let mut buffer = [0u8; 5];

match Pin::new(&mut stream).poll_read(cx, &mut buffer)? {
Poll::Ready(num_bytes) => {
assert_eq!(num_bytes, 5);
assert_eq!(&buffer, b"hello");

*this = Client::Testing {
worker_stream: StreamState::Closing(stream),
connection,
};
continue;
}
Poll::Pending => {}
}

drive_connection(this, connection, StreamState::Receiving(stream), cx);
return Poll::Pending;
}
Client::Testing {
worker_stream: StreamState::Closing(mut stream),
connection,
} => {
match Pin::new(&mut stream).poll_close(cx)? {
Poll::Ready(()) => {
*this = Client::Closing { connection };
continue;
}
Poll::Pending => {}
}

drive_connection(this, connection, StreamState::Closing(stream), cx);
return Poll::Pending;
}
Client::Initial { mut connection } => {
match connection.poll_new_outbound(cx)? {
Poll::Ready(stream) => {
*this = Client::Testing {
connection,
worker_stream: StreamState::Sending(stream),
};
continue;
}
Poll::Pending => {
*this = Client::Initial { connection };
return Poll::Pending;
}
}
}
Client::Testing {
worker_stream: StreamState::Sending(mut stream),
connection,
} => {
match Pin::new(&mut stream).poll_write(cx, b"hello")? {
Poll::Ready(written) => {
assert_eq!(written, 5);
*this = Client::Testing {
worker_stream: StreamState::Flushing(stream),
connection,
};
continue;
}
Poll::Pending => {}
}

drive_connection(this, connection, StreamState::Sending(stream), cx);
return Poll::Pending;
}
Client::Closing { mut connection } => match connection.poll_close(cx)? {
Poll::Ready(()) => {
return Poll::Ready(Ok(()));
}
Poll::Pending => {
*this = Client::Closing { connection };
return Poll::Pending;
}
},
Client::Poisoned => {
unreachable!()
}
}
}
}
}

fn drive_connection(
this: &mut Client,
mut connection: Connection<futures_ringbuf::Endpoint>,
state: StreamState,
cx: &mut Context,
) {
match connection.poll_next_inbound(cx) {
Poll::Ready(Some(_)) => {
panic!("Unexpected inbound stream")
}
Poll::Ready(None) => {
panic!("Unexpected connection close")
}
Poll::Pending => {
*this = Client::Testing {
worker_stream: state,
connection,
};
}
}
}

struct EchoServer {
connection: Connection<Endpoint>,
worker_streams: FuturesUnordered<BoxFuture<'static, Result<()>>>,
streams_processed: usize,
connection_closed: bool,
}

impl EchoServer {
fn new(connection: Connection<Endpoint>) -> Self {
Self {
connection,
worker_streams: FuturesUnordered::default(),
streams_processed: 0,
connection_closed: false,
}
}
}

impl Future for EchoServer {
type Output = Result<usize>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();

loop {
match this.worker_streams.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(()))) => {
this.streams_processed += 1;
continue;
}
Poll::Ready(Some(Err(e))) => {
eprintln!("A stream failed: {}", e);
continue;
}
Poll::Ready(None) => {
if this.connection_closed {
return Poll::Ready(Ok(this.streams_processed));
}
}
Poll::Pending => {}
}

match this.connection.poll_next_inbound(cx) {
Poll::Ready(Some(Ok(mut stream))) => {
this.worker_streams.push(
async move {
{
let (mut r, mut w) = AsyncReadExt::split(&mut stream);
futures::io::copy(&mut r, &mut w).await?;
}
stream.close().await?;
Ok(())
}
.boxed(),
);
continue;
}
Poll::Ready(None) | Poll::Ready(Some(Err(_))) => {
this.connection_closed = true;
continue;
}
Poll::Pending => {}
}

return Poll::Pending;
}
}
}
}