Skip to content

Commit

Permalink
fix(http2): Fix race condition in client dispatcher (#3041)
Browse files Browse the repository at this point in the history
There exists a race condition in ClientTask::poll() when the request
that is sent via h2::client::send_request() is pending open. A task will
be spawned to wait for send capacity on the sendstream. Because this
same stream is also stored in the pending member of
h2::client::SendRequest the next iteration of the poll() loop can call
poll_ready() and call wait_send() on the same stream passed into the
spawned task.

Fix this by always calling poll_ready() after send_request(). If this
call to poll_ready() returns Pending save the necessary context in
ClientTask and only spawn the task that will eventually resolve to the
response after poll_ready() returns Ok.
  • Loading branch information
jfourie1 authored and seanmonstar committed Nov 7, 2022
1 parent 75aac9f commit f202230
Showing 1 changed file with 143 additions and 81 deletions.
224 changes: 143 additions & 81 deletions src/proto/h2/client.rs
Expand Up @@ -6,20 +6,23 @@ use futures_channel::{mpsc, oneshot};
use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _};
use futures_util::stream::StreamExt as _;
use h2::client::{Builder, SendRequest};
use h2::SendStream;
use http::{Method, StatusCode};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, trace, warn};

use super::{ping, H2Upgraded, PipeToSendStream, SendBuf};
use crate::body::{Body, Incoming as IncomingBody};
use crate::common::time::Time;
use crate::client::dispatch::Callback;
use crate::common::{exec::Exec, task, Future, Never, Pin, Poll};
use crate::ext::Protocol;
use crate::headers;
use crate::proto::h2::UpgradedSendStream;
use crate::proto::Dispatched;
use crate::upgrade::Upgraded;
use crate::{Request, Response};
use h2::client::ResponseFuture;

type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<IncomingBody>>;

Expand Down Expand Up @@ -161,6 +164,7 @@ where
executor: exec,
h2_tx,
req_rx,
fut_ctx: None,
})
}

Expand All @@ -184,6 +188,20 @@ where
}
}

struct FutCtx<B>
where
B: Body,
{
is_connect: bool,
eos: bool,
fut: ResponseFuture,
body_tx: SendStream<SendBuf<B::Data>>,
body: B,
cb: Callback<Request<B>, Response<IncomingBody>>,
}

impl<B: Body> Unpin for FutCtx<B> {}

pub(crate) struct ClientTask<B>
where
B: Body,
Expand All @@ -194,6 +212,7 @@ where
executor: Exec,
h2_tx: SendRequest<SendBuf<B::Data>>,
req_rx: ClientRx<B>,
fut_ctx: Option<FutCtx<B>>,
}

impl<B> ClientTask<B>
Expand All @@ -205,6 +224,99 @@ where
}
}

impl<B> ClientTask<B>
where
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
fn poll_pipe(&mut self, f: FutCtx<B>, cx: &mut task::Context<'_>) {
let ping = self.ping.clone();
let send_stream = if !f.is_connect {
if !f.eos {
let mut pipe = Box::pin(PipeToSendStream::new(f.body, f.body_tx)).map(|res| {
if let Err(e) = res {
debug!("client request body error: {}", e);
}
});

// eagerly see if the body pipe is ready and
// can thus skip allocating in the executor
match Pin::new(&mut pipe).poll(cx) {
Poll::Ready(_) => (),
Poll::Pending => {
let conn_drop_ref = self.conn_drop_ref.clone();
// keep the ping recorder's knowledge of an
// "open stream" alive while this body is
// still sending...
let ping = ping.clone();
let pipe = pipe.map(move |x| {
drop(conn_drop_ref);
drop(ping);
x
});
// Clear send task
self.executor.execute(pipe);
}
}
}

None
} else {
Some(f.body_tx)
};

let fut = f.fut.map(move |result| match result {
Ok(res) => {
// record that we got the response headers
ping.record_non_data();

let content_length = headers::content_length_parse_all(res.headers());
if let (Some(mut send_stream), StatusCode::OK) = (send_stream, res.status()) {
if content_length.map_or(false, |len| len != 0) {
warn!("h2 connect response with non-zero body not supported");

send_stream.send_reset(h2::Reason::INTERNAL_ERROR);
return Err((
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
None,
));
}
let (parts, recv_stream) = res.into_parts();
let mut res = Response::from_parts(parts, IncomingBody::empty());

let (pending, on_upgrade) = crate::upgrade::pending();
let io = H2Upgraded {
ping,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
recv_stream,
buf: Bytes::new(),
};
let upgraded = Upgraded::new(io, Bytes::new());

pending.fulfill(upgraded);
res.extensions_mut().insert(on_upgrade);

Ok(res)
} else {
let res = res.map(|stream| {
let ping = ping.for_stream(&stream);
IncomingBody::h2(stream, content_length.into(), ping)
});
Ok(res)
}
}
Err(err) => {
ping.ensure_not_timed_out().map_err(|e| (e, None))?;

debug!("client response error: {}", err);
Err((crate::Error::new_h2(err), None))
}
});
self.executor.execute(f.cb.send_when(fut));
}
}

impl<B> Future for ClientTask<B>
where
B: Body + Send + 'static,
Expand All @@ -228,6 +340,16 @@ where
}
};

match self.fut_ctx.take() {
// If we were waiting on pending open
// continue where we left off.
Some(f) => {
self.poll_pipe(f, cx);
continue;
}
None => (),
}

match self.req_rx.poll_recv(cx) {
Poll::Ready(Some((req, cb))) => {
// check that future hasn't been canceled already
Expand All @@ -246,7 +368,6 @@ where

let is_connect = req.method() == Method::CONNECT;
let eos = body.is_end_stream();
let ping = self.ping.clone();

if is_connect {
if headers::content_length_parse_all(req.headers())
Expand Down Expand Up @@ -274,90 +395,31 @@ where
}
};

let send_stream = if !is_connect {
if !eos {
let mut pipe =
Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| {
if let Err(e) = res {
debug!("client request body error: {}", e);
}
});

// eagerly see if the body pipe is ready and
// can thus skip allocating in the executor
match Pin::new(&mut pipe).poll(cx) {
Poll::Ready(_) => (),
Poll::Pending => {
let conn_drop_ref = self.conn_drop_ref.clone();
// keep the ping recorder's knowledge of an
// "open stream" alive while this body is
// still sending...
let ping = ping.clone();
let pipe = pipe.map(move |x| {
drop(conn_drop_ref);
drop(ping);
x
});
self.executor.execute(pipe);
}
}
}

None
} else {
Some(body_tx)
let f = FutCtx {
is_connect,
eos,
fut,
body_tx,
body,
cb,
};

let fut = fut.map(move |result| match result {
Ok(res) => {
// record that we got the response headers
ping.record_non_data();

let content_length = headers::content_length_parse_all(res.headers());
if let (Some(mut send_stream), StatusCode::OK) =
(send_stream, res.status())
{
if content_length.map_or(false, |len| len != 0) {
warn!("h2 connect response with non-zero body not supported");

send_stream.send_reset(h2::Reason::INTERNAL_ERROR);
return Err((
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
None,
));
}
let (parts, recv_stream) = res.into_parts();
let mut res = Response::from_parts(parts, IncomingBody::empty());

let (pending, on_upgrade) = crate::upgrade::pending();
let io = H2Upgraded {
ping,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
recv_stream,
buf: Bytes::new(),
};
let upgraded = Upgraded::new(io, Bytes::new());

pending.fulfill(upgraded);
res.extensions_mut().insert(on_upgrade);

Ok(res)
} else {
let res = res.map(|stream| {
let ping = ping.for_stream(&stream);
IncomingBody::h2(stream, content_length.into(), ping)
});
Ok(res)
}
// Check poll_ready() again.
// If the call to send_request() resulted in the new stream being pending open
// we have to wait for the open to complete before accepting new requests.
match self.h2_tx.poll_ready(cx) {
Poll::Pending => {
// Save Context
self.fut_ctx = Some(f);
return Poll::Pending;
}
Err(err) => {
ping.ensure_not_timed_out().map_err(|e| (e, None))?;

debug!("client response error: {}", err);
Err((crate::Error::new_h2(err), None))
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => {
f.cb.send(Err((crate::Error::new_h2(err), None)));
continue;
}
});
self.executor.execute(cb.send_when(fut));
}
self.poll_pipe(f, cx);
continue;
}

Expand Down

0 comments on commit f202230

Please sign in to comment.