Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid race condition between pending frames and closing stream #156

Merged
merged 15 commits into from May 24, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion test-harness/Cargo.toml
Expand Up @@ -16,4 +16,3 @@ log = "0.4.17"
[dev-dependencies]
env_logger = "0.10"
constrained-connection = "0.1"

1 change: 1 addition & 0 deletions yamux/Cargo.toml
Expand Up @@ -26,6 +26,7 @@ quickcheck = "1.0"
tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
constrained-connection = "0.1"
futures_ringbuf = "0.3.1"

[[bench]]
name = "concurrent"
Expand Down
263 changes: 145 additions & 118 deletions yamux/src/connection.rs
Expand Up @@ -96,14 +96,16 @@ use crate::{
error::ConnectionError,
frame::header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID},
frame::{self, Frame},
Config, WindowUpdateMode, DEFAULT_CREDIT, MAX_COMMAND_BACKLOG,
Config, WindowUpdateMode, DEFAULT_CREDIT,
};
use cleanup::Cleanup;
use closing::Closing;
use futures::stream::SelectAll;
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
use nohash_hasher::IntMap;
use std::collections::VecDeque;
use std::task::Context;
use std::iter::FromIterator;
use std::task::{Context, Waker};
use std::{fmt, sync::Arc, task::Poll};

pub use stream::{Packet, State, Stream};
Expand Down Expand Up @@ -347,10 +349,11 @@ struct Active<T> {
config: Arc<Config>,
socket: Fuse<frame::Io<T>>,
next_id: u32,

streams: IntMap<StreamId, Stream>,
stream_sender: mpsc::Sender<StreamCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
dropped_streams: Vec<StreamId>,
stream_receivers: Vec<(StreamId, mpsc::Receiver<StreamCommand>)>,
no_streams_waker: Option<Waker>,

pending_frames: VecDeque<Frame<()>>,
}

Expand Down Expand Up @@ -408,28 +411,34 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn new(socket: T, cfg: Config, mode: Mode) -> Self {
let id = Id::random();
log::debug!("new connection: {} ({:?})", id, mode);
let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
Active {
id,
mode,
config: Arc::new(cfg),
socket,
streams: IntMap::default(),
stream_sender,
stream_receiver,
stream_receivers: Vec::default(),
no_streams_waker: None,
next_id: match mode {
Mode::Client => 1,
Mode::Server => 2,
},
dropped_streams: Vec::new(),
pending_frames: VecDeque::default(),
}
}

/// Gracefully close the connection to the remote.
///
/// Consumes the active state and hands the remaining per-stream command
/// receivers (merged into a single `SelectAll`), any queued frames and the
/// socket over to the `Closing` state machine.
fn close(self) -> Closing<T> {
    let receivers = self
        .stream_receivers
        .into_iter()
        .map(|(_, receiver)| receiver);

    Closing::new(
        SelectAll::from_iter(receivers),
        self.pending_frames,
        self.socket,
    )
}

/// Cleanup all our resources.
Expand All @@ -438,13 +447,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn cleanup(mut self, error: ConnectionError) -> Cleanup {
self.drop_all_streams();

Cleanup::new(self.stream_receiver, error)
Cleanup::new(
SelectAll::from_iter(
self.stream_receivers
.into_iter()
.map(|(_, receiver)| receiver),
),
error,
)
}

fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
loop {
self.garbage_collect();

if self.socket.poll_ready_unpin(cx).is_ready() {
if let Some(frame) = self.pending_frames.pop_front() {
self.socket.start_send_unpin(frame)?;
Expand All @@ -457,18 +471,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Poll::Pending => {}
}

match self.stream_receiver.poll_next_unpin(cx) {
Poll::Ready(Some(StreamCommand::SendFrame(frame))) => {
self.on_send_frame(frame);
match self.poll_stream_receivers(cx) {
Poll::Ready(StreamCommand::SendFrame(frame)) => {
self.on_send_frame(frame.into());
continue;
}
Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => {
Poll::Ready(StreamCommand::CloseStream { id, ack }) => {
self.on_close_stream(id, ack);
continue;
}
Poll::Ready(None) => {
debug_assert!(false, "Only closed during shutdown")
}
Poll::Pending => {}
}

Expand All @@ -490,6 +501,32 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}
}

/// Poll every per-stream command receiver for the next `StreamCommand`.
///
/// Iterates in reverse so `swap_remove` stays O(1); receivers that are
/// still live are pushed back onto the list. A receiver yielding `None`
/// means the corresponding `Stream` handle was dropped by the user, so we
/// clean the stream up via `on_drop_stream`.
///
/// NOTE(review): the scraped source had review-UI text interleaved in this
/// body; this is the reconstructed function.
fn poll_stream_receivers(&mut self, cx: &mut Context) -> Poll<StreamCommand> {
    for i in (0..self.stream_receivers.len()).rev() {
        let (id, mut receiver) = self.stream_receivers.swap_remove(i);

        match receiver.poll_next_unpin(cx) {
            Poll::Ready(Some(command)) => {
                self.stream_receivers.push((id, receiver));
                return Poll::Ready(command);
            }
            Poll::Ready(None) => {
                // Channel closed: the user dropped the `Stream` handle.
                self.on_drop_stream(id);
            }
            Poll::Pending => {
                self.stream_receivers.push((id, receiver));
            }
        }
    }

    if self.stream_receivers.is_empty() {
        // With no receivers registered, nothing above scheduled a wakeup.
        // Park our waker so `make_new_stream` can wake us when the next
        // stream (and thus receiver) is created.
        self.no_streams_waker = Some(cx.waker().clone());
        return Poll::Pending;
    }

    Poll::Pending
}

fn new_outbound(&mut self) -> Result<Stream> {
if self.streams.len() >= self.config.max_num_streams {
log::error!("{}: maximum number of streams reached", self.id);
Expand All @@ -508,16 +545,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
self.pending_frames.push_back(frame.into());
}

let stream = {
let config = self.config.clone();
let sender = self.stream_sender.clone();
let window = self.config.receive_window;
let mut stream = Stream::new(id, self.id, config, window, DEFAULT_CREDIT, sender);
if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}
stream
};
let mut stream = self.make_new_stream(id, self.config.receive_window, DEFAULT_CREDIT);

if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}

log::debug!("{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id, stream.clone());
Expand All @@ -541,6 +573,71 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
.push_back(Frame::close_stream(id, ack).into());
}

fn on_drop_stream(&mut self, id: StreamId) {
let stream = self.streams.remove(&id).expect("stream not found");

log::trace!("{}: removing dropped {}", self.id, stream);
let stream_id = stream.id();
let frame = {
let mut shared = stream.shared();
let frame = match shared.update_state(self.id, stream_id, State::Closed) {
// The stream was dropped without calling `poll_close`.
// We reset the stream to inform the remote of the closure.
State::Open => {
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
}
// The stream was dropped without calling `poll_close`.
// We have already received a FIN from remote and send one
// back which closes the stream for good.
State::RecvClosed => {
let mut header = Header::data(stream_id, 0);
header.fin();
Some(Frame::new(header))
}
// The stream was properly closed. We either already have
// or will at some later point send our FIN frame.
thomaseizinger marked this conversation as resolved.
Show resolved Hide resolved
// The remote may be out of credit though and blocked on
// writing more data. We may need to reset the stream.
State::SendClosed => {
if self.config.window_update_mode == WindowUpdateMode::OnRead
&& shared.window == 0
{
// The remote may be waiting for a window update
// which we will never send, so reset the stream now.
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
} else {
// The remote has either still credit or will be given more
// (due to an enqueued window update or because the update
// mode is `OnReceive`) or we already have inbound frames in
// the socket buffer which will be processed later. In any
// case we will reply with an RST in `Connection::on_data`
// because the stream will no longer be known.
None
}
}
// The stream was properly closed. We either already have
// or will at some later point send our FIN frame. The
// remote end has already done so in the past.
State::Closed => None,
};
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
w.wake()
}
frame
};
if let Some(f) = frame {
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
self.pending_frames.push_back(f.into());
}
}

/// Process the result of reading from the socket.
///
/// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed
Expand Down Expand Up @@ -628,12 +725,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
log::error!("{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::internal_error());
}
let mut stream = {
let config = self.config.clone();
let credit = DEFAULT_CREDIT;
let sender = self.stream_sender.clone();
Stream::new(stream_id, self.id, config, credit, credit, sender)
};
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, DEFAULT_CREDIT);
let mut window_update = None;
{
let mut shared = stream.shared();
Expand Down Expand Up @@ -748,15 +840,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
log::error!("{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::protocol_error());
}
let stream = {
let credit = frame.header().credit() + DEFAULT_CREDIT;
let config = self.config.clone();
let sender = self.stream_sender.clone();
let mut stream =
Stream::new(stream_id, self.id, config, DEFAULT_CREDIT, credit, sender);
stream.set_flag(stream::Flag::Ack);
stream
};

let credit = frame.header().credit() + DEFAULT_CREDIT;
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, credit);
stream.set_flag(stream::Flag::Ack);

if is_finish {
stream
.shared()
Expand Down Expand Up @@ -821,6 +909,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Action::None
}

fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream {
let config = self.config.clone();

let (sender, receiver) = mpsc::channel(10);
thomaseizinger marked this conversation as resolved.
Show resolved Hide resolved
self.stream_receivers.push((id, receiver));
if let Some(waker) = self.no_streams_waker.take() {
waker.wake();
}

Stream::new(id, self.id, config, window, credit, sender)
}

fn next_stream_id(&mut self) -> Result<StreamId> {
let proposed = StreamId::new(self.next_id);
self.next_id = self
Expand All @@ -844,79 +944,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Mode::Server => id.is_client(),
}
}

/// Remove stale streams and create necessary messages to be sent to the remote.
///
/// A stream is considered stale when our map holds the only remaining strong
/// reference to it, i.e. every user-facing `Stream` handle has been dropped.
fn garbage_collect(&mut self) {
    let conn_id = self.id;
    let win_update_mode = self.config.window_update_mode;
    for stream in self.streams.values_mut() {
        // A strong count above 1 means a user still holds a `Stream`
        // handle; only collect streams whose last external handle is gone.
        if stream.strong_count() > 1 {
            continue;
        }
        log::trace!("{}: removing dropped {}", conn_id, stream);
        let stream_id = stream.id();
        let frame = {
            let mut shared = stream.shared();
            // Transition the stream to `Closed` and decide, based on the
            // *previous* state, whether the remote must be notified.
            let frame = match shared.update_state(conn_id, stream_id, State::Closed) {
                // The stream was dropped without calling `poll_close`.
                // We reset the stream to inform the remote of the closure.
                State::Open => {
                    let mut header = Header::data(stream_id, 0);
                    header.rst();
                    Some(Frame::new(header))
                }
                // The stream was dropped without calling `poll_close`.
                // We have already received a FIN from remote and send one
                // back which closes the stream for good.
                State::RecvClosed => {
                    let mut header = Header::data(stream_id, 0);
                    header.fin();
                    Some(Frame::new(header))
                }
                // The stream was properly closed. We either already have
                // or will at some later point send our FIN frame.
                // The remote may be out of credit though and blocked on
                // writing more data. We may need to reset the stream.
                State::SendClosed => {
                    if win_update_mode == WindowUpdateMode::OnRead && shared.window == 0 {
                        // The remote may be waiting for a window update
                        // which we will never send, so reset the stream now.
                        let mut header = Header::data(stream_id, 0);
                        header.rst();
                        Some(Frame::new(header))
                    } else {
                        // The remote has either still credit or will be given more
                        // (due to an enqueued window update or because the update
                        // mode is `OnReceive`) or we already have inbound frames in
                        // the socket buffer which will be processed later. In any
                        // case we will reply with an RST in `Connection::on_data`
                        // because the stream will no longer be known.
                        None
                    }
                }
                // The stream was properly closed. We either already have
                // or will at some later point send our FIN frame. The
                // remote end has already done so in the past.
                State::Closed => None,
            };
            // Wake any parked reader/writer so pending tasks observe the
            // closed state instead of waiting forever.
            if let Some(w) = shared.reader.take() {
                w.wake()
            }
            if let Some(w) = shared.writer.take() {
                w.wake()
            }
            frame
        };
        if let Some(f) = frame {
            log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
            self.pending_frames.push_back(f.into());
        }
        // Deferred removal: we cannot mutate `self.streams` while
        // iterating over it, so record the id and remove below.
        self.dropped_streams.push(stream_id)
    }
    for id in self.dropped_streams.drain(..) {
        self.streams.remove(&id);
    }
}
}

impl<T> Active<T> {
Expand Down