From 1b2ca47a58bb609879869909a6436eeecba4e8b3 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 10 Mar 2023 21:16:02 -0500 Subject: [PATCH] Support body health detection --- Cargo.toml | 3 + src/proto/h1/dispatch.rs | 20 +++++- src/proto/h2/mod.rs | 29 ++++---- tests/server.rs | 145 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 182 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 16d8e585be..5ecad99cae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -219,3 +219,6 @@ required-features = ["full"] name = "server" path = "tests/server.rs" required-features = ["full"] + +[patch.crates-io] +http-body = { git = "https://github.com/sfackler/http-body", branch = "body-poll-alive" } diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index cd494581b9..75e6cbf02c 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -28,7 +28,8 @@ pub(crate) trait Dispatch { self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll>>; - fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>) -> crate::Result<()>; + fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>) + -> crate::Result<()>; fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll>; fn should_poll(&self) -> bool; } @@ -249,7 +250,8 @@ where let body = match body_len { DecodedLength::ZERO => IncomingBody::empty(), other => { - let (tx, rx) = IncomingBody::new_channel(other, wants.contains(Wants::EXPECT)); + let (tx, rx) = + IncomingBody::new_channel(other, wants.contains(Wants::EXPECT)); self.body_tx = Some(tx); rx } @@ -317,7 +319,19 @@ where return Poll::Ready(Ok(())); } } else if !self.conn.can_buffer_body() { - ready!(self.poll_flush(cx))?; + if self.poll_flush(cx)?.is_pending() { + // If we're not able to make progress, check the body health + if let (Some(body), clear_body) = + OptGuard::new(self.body_rx.as_mut()).guard_mut() + { + body.poll_healthy(cx).map_err(|e| { + *clear_body = true; + crate::Error::new_user_body(e) + })?; + } + + return Poll::Pending; + } } else { // A new scope is needed :( if let (Some(mut body), clear_body) = diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index a1cbd25813..8b869180d7 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -126,13 +126,13 @@ where if me.body_tx.capacity() == 0 { loop { - match ready!(me.body_tx.poll_capacity(cx)) { - Some(Ok(0)) => {} - Some(Ok(_)) => break, - Some(Err(e)) => { + match me.body_tx.poll_capacity(cx) { + Poll::Ready(Some(Ok(0))) => {} + Poll::Ready(Some(Ok(_))) => break, + Poll::Ready(Some(Err(e))) => { return Poll::Ready(Err(crate::Error::new_body_write(e))) } - None => { + Poll::Ready(None) => { // None means the stream is no longer in a // streaming state, we either finished it // somehow, or the remote reset us. @@ -140,6 +140,15 @@ where "send stream capacity unexpectedly closed", ))); } + Poll::Pending => { + // If we're not able to make progress, check if the body is healthy + me.stream + .as_mut() + .poll_healthy(cx) + .map_err(|e| me.body_tx.on_user_err(e))?; + + return Poll::Pending; + } } } } else if let Poll::Ready(reason) = me @@ -148,9 +157,7 @@ where .map_err(crate::Error::new_body_write)? { debug!("stream received RST_STREAM: {:?}", reason); - return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( - reason, - )))); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(reason)))); } match ready!(me.stream.as_mut().poll_frame(cx)) { @@ -365,14 +372,12 @@ where cx: &mut Context<'_>, ) -> Poll> { if self.send_stream.write(&[], true).is_ok() { - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } Poll::Ready(Err(h2_to_io_error( match ready!(self.send_stream.poll_reset(cx)) { - Ok(Reason::NO_ERROR) => { - return Poll::Ready(Ok(())) - } + Ok(Reason::NO_ERROR) => return Poll::Ready(Ok(())), Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) } diff --git a/tests/server.rs b/tests/server.rs index 632ce4839a..c37d3c1d4d 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -1737,6 +1737,151 @@ async fn http_connect_new() { assert_eq!(s(&vec), "bar=foo"); } +struct UnhealthyBody { + rx: oneshot::Receiver<()>, + tx: Option>, +} + +impl Body for UnhealthyBody { + type Data = Bytes; + + type Error = &'static str; + + fn poll_frame( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Poll::Ready(Some(Ok(http_body::Frame::data(Bytes::from_static( + &[0; 1024], + ))))) + } + + fn poll_healthy(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), Self::Error> { + if Pin::new(&mut self.rx).poll(cx).is_pending() { + return Ok(()); + } + + let _ = self.tx.take().unwrap().send(()); + Err("blammo") + } +} + +#[tokio::test] +async fn h1_unhealthy_body() { + let (listener, addr) = setup_tcp_listener(); + let (unhealthy_tx, unhealthy_rx) = oneshot::channel(); + let (read_body_tx, read_body_rx) = oneshot::channel(); + + let client = tokio::spawn(async move { + let mut tcp = connect_async(addr).await; + tcp.write_all( + b"\ + GET / HTTP/1.1\r\n\ + \r\n\ + Host: localhost\r\n\ + \r\n + ", + ) + .await + .expect("write 1"); + + let mut buf = [0; 1024]; + loop { + let nread = tcp.read(&mut buf).await.expect("read 1"); + if buf[..nread].contains(&0) { + break; + } + } + + read_body_tx.send(()).unwrap(); + unhealthy_rx.await.expect("rx"); + + while tcp.read(&mut buf).await.expect("read") > 0 {} + }); + + let mut read_body_rx = Some(read_body_rx); + let mut unhealthy_tx = Some(unhealthy_tx); + let svc = service_fn(move |_: Request| { + future::ok::<_, &'static str>( + Response::builder() + .status(200) + .body(UnhealthyBody { + rx: read_body_rx.take().unwrap(), + tx: unhealthy_tx.take(), + }) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let err = http1::Builder::new() + .serve_connection(socket, svc) + .await + .err() + .unwrap(); + assert!(err.to_string().contains("blammo")); + + client.await.unwrap(); +} + +#[tokio::test] +async fn h2_unhealthy_body() { + let (listener, addr) = setup_tcp_listener(); + let (unhealthy_tx, unhealthy_rx) = oneshot::channel(); + let (read_body_tx, read_body_rx) = oneshot::channel(); + + let client = tokio::spawn(async move { + let tcp = connect_async(addr).await; + let (h2, connection) = h2::client::handshake(tcp).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + let request = Request::get("/").body(()).unwrap(); + let (response, _) = h2.send_request(request, true).unwrap(); + + let mut body = response.await.unwrap().into_body(); + + let bytes = body.data().await.unwrap().unwrap(); + let _ = body.flow_control().release_capacity(bytes.len()); + + read_body_tx.send(()).unwrap(); + unhealthy_rx.await.unwrap(); + + loop { + let bytes = match body.data().await.transpose() { + Ok(Some(bytes)) => bytes, + Ok(None) => panic!(), + Err(_) => break, + }; + let _ = body.flow_control().release_capacity(bytes.len()); + } + }); + + let mut read_body_rx = Some(read_body_rx); + let mut unhealthy_tx = Some(unhealthy_tx); + let svc = service_fn(move |_: Request| { + future::ok::<_, &'static str>( + Response::builder() + .status(200) + .body(UnhealthyBody { + rx: read_body_rx.take().unwrap(), + tx: unhealthy_tx.take(), + }) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await + .unwrap(); + + client.await.unwrap(); +} + #[tokio::test] async fn h2_connect() { let (listener, addr) = setup_tcp_listener();