From 38a3074f67d8a0d64c127662dbc6868c58b880aa Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 14 Apr 2021 10:44:40 +0200 Subject: [PATCH 01/11] transport: Support timeouts with "grpc-timeout" header --- tests/integration_tests/Cargo.toml | 1 + tests/integration_tests/tests/timeout.rs | 94 ++++++++ tonic/Cargo.toml | 3 +- tonic/src/metadata/map.rs | 3 +- tonic/src/transport/service/connection.rs | 7 +- tonic/src/transport/service/mod.rs | 1 + tonic/src/transport/service/timeout.rs | 251 ++++++++++++++++++++++ 7 files changed, 354 insertions(+), 6 deletions(-) create mode 100644 tests/integration_tests/tests/timeout.rs create mode 100644 tonic/src/transport/service/timeout.rs diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 37ef59a39..2e6836862 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -16,6 +16,7 @@ bytes = "1.0" [dev-dependencies] tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] } +tokio-stream = { version = "0.1.5", features = ["net"] } [build-dependencies] tonic-build = { path = "../../tonic-build" } diff --git a/tests/integration_tests/tests/timeout.rs b/tests/integration_tests/tests/timeout.rs new file mode 100644 index 000000000..83e457dc7 --- /dev/null +++ b/tests/integration_tests/tests/timeout.rs @@ -0,0 +1,94 @@ +use integration_tests::pb::{test_client, test_server, Input, Output}; +use std::time::Duration; +use tokio::net::TcpListener; +use tonic::{transport::Server, Code, Request, Response, Status}; + +#[tokio::test] +#[ignore] +async fn cancelation_on_timeout() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + // Wait for a time longer than the timeout + tokio::time::sleep(Duration::from_millis(1_000)).await; + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let mut client = test_client::TestClient::connect(format!("http://{}", addr)) + .await + .unwrap(); + + let mut req = Request::new(Input {}); + req.metadata_mut() + .insert("grpc-timeout", "500m".parse().unwrap()); + + let res = client.unary_call(req).await; + dbg!(&res); + + let err = res.unwrap_err(); + assert!(err.message().contains("Timeout expired")); + + // TODO(david): make this work. Will require mapping `TimeoutExpired` errors into + // `Code::Cancelled` but can't quite figure out how to do that. + assert_eq!(err.code(), Code::Cancelled); +} + +#[tokio::test] +async fn picks_the_shortest_timeout() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + // Wait for a time longer than the timeout + tokio::time::sleep(Duration::from_secs(1)).await; + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .timeout(Duration::from_millis(100)) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let mut client = test_client::TestClient::connect(format!("http://{}", addr)) + .await + .unwrap(); + + let mut req = Request::new(Input {}); + req.metadata_mut() + // 10 hours + .insert("grpc-timeout", "10H".parse().unwrap()); + + // TODO(david): for some reason this fails with "h2 protocol error: protocol error: unexpected + // internal error encountered". Seems to be happening on `master` as well. Bug? + let res = client.unary_call(req).await; + dbg!(&res); + let err = res.unwrap_err(); + assert!(err.message().contains("Timeout expired")); +} diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index c92d51106..86733e062 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -31,7 +31,8 @@ transport = [ "tokio", "tower", "tracing-futures", - "tokio/macros" + "tokio/macros", + "tokio/time", ] tls = ["transport", "tokio-rustls"] tls-roots = ["tls", "rustls-native-certs"] diff --git a/tonic/src/metadata/map.rs b/tonic/src/metadata/map.rs index 177bb3c96..b1886173d 100644 --- a/tonic/src/metadata/map.rs +++ b/tonic/src/metadata/map.rs @@ -198,11 +198,10 @@ pub struct OccupiedEntry<'a, VE: ValueEncoding> { impl MetadataMap { // Headers reserved by the gRPC protocol. - pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 8] = [ + pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 7] = [ "te", "user-agent", "content-type", - "grpc-timeout", "grpc-message", "grpc-encoding", "grpc-message-type", diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index ed23aac45..5b19f34ca 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,5 +1,7 @@ use super::super::BoxFuture; -use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin, UserAgent}; +use super::{ + layer::ServiceBuilderExt, reconnect::Reconnect, timeout::Timeout, AddOrigin, UserAgent, +}; use crate::{body::BoxBody, transport::Endpoint}; use http::Uri; use hyper::client::conn::Builder; @@ -14,7 +16,6 @@ use tower::load::Load; use tower::{ layer::Layer, limit::{concurrency::ConcurrencyLimitLayer, rate::RateLimitLayer}, - timeout::TimeoutLayer, util::BoxService, ServiceBuilder, ServiceExt, }; @@ -53,7 +54,7 @@ impl Connection { let stack = ServiceBuilder::new() .layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone())) .layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone())) - .optional_layer(endpoint.timeout.map(TimeoutLayer::new)) + .layer_fn(|s| Timeout::new(s, endpoint.timeout)) .optional_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new)) .optional_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index eab3b40ef..bfb005684 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -6,6 +6,7 @@ mod io; mod layer; mod reconnect; mod router; +mod timeout; #[cfg(feature = "tls")] mod tls; mod user_agent; diff --git a/tonic/src/transport/service/timeout.rs b/tonic/src/transport/service/timeout.rs new file mode 100644 index 000000000..871c4474b --- /dev/null +++ b/tonic/src/transport/service/timeout.rs @@ -0,0 +1,251 @@ +use http::Request; +use http::{HeaderMap, HeaderValue}; +use pin_project::pin_project; +use std::{fmt, pin::Pin}; +use std::{ + future::Future, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::Sleep; +use tower_service::Service; + +#[derive(Debug, Clone)] +pub(crate) struct Timeout { + inner: S, + server_timeout: Option, +} + +impl Timeout { + pub(crate) fn new(inner: S, server_timeout: Option) -> Self { + Self { + inner, + server_timeout, + } + } +} + +impl Service> for Timeout +where + S: Service>, + S::Error: Into, +{ + type Response = S::Response; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| { + tracing::trace!("Error parsing `grpc-timeout` header {:?}", e); + None + }); + + // Use the shorter of the two durations, if either are set + let timeout_duration = match (client_timeout, self.server_timeout) { + (None, None) => None, + (Some(dur), None) => Some(dur), + (None, Some(dur)) => Some(dur), + (Some(header), Some(server)) => { + let shorter_duration = std::cmp::min(header, server); + Some(shorter_duration) + } + }; + + ResponseFuture { + inner: self.inner.call(req), + sleep: timeout_duration + .map(tokio::time::sleep) + .map(OptionPin::Some) + .unwrap_or(OptionPin::None), + } + } +} + +#[pin_project] +pub(crate) struct ResponseFuture { + #[pin] + inner: F, + #[pin] + sleep: OptionPin, +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Poll::Ready(result) = this.inner.poll(cx) { + return Poll::Ready(result.map_err(Into::into)); + } + + if let OptionPinProj::Some(sleep) = this.sleep.project() { + futures_util::ready!(sleep.poll(cx)); + return Poll::Ready(Err(TimeoutExpired.into())); + } + + Poll::Pending + } +} + +#[pin_project(project = OptionPinProj)] +enum OptionPin { + Some(#[pin] T), + None, +} + +const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; + +const SECONDS_IN_HOUR: u64 = 60 * 60; +const SECONDS_IN_MINUTE: u64 = 60; + +/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns +/// the value we attempted to parse. +/// +/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md). +fn try_parse_grpc_timeout( + headers: &HeaderMap, +) -> Result, &HeaderValue> { + match headers.get(GRPC_TIMEOUT_HEADER) { + Some(val) => { + let (timeout_value, timeout_unit) = val + .to_str() + .map_err(|_| val) + .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })? + // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this + // `split_at` will never panic from trying to split in the middle of a character. + // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str + // + // `len - 1` also wont panic since we just checked `s.is_empty`. + .split_at(val.len() - 1); + + // gRPC spec specifies `TimeoutValue` will be at most 8 digits + // Caping this at 8 digits also prevents integer overflow from ever occurring + if timeout_value.len() > 8 { + return Err(val); + } + + let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?; + + let duration = match timeout_unit { + // Hours + "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR), + // Minutes + "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE), + // Seconds + "S" => Duration::from_secs(timeout_value), + // Milliseconds + "m" => Duration::from_millis(timeout_value), + // Microseconds + "u" => Duration::from_micros(timeout_value), + // Nanoseconds + "n" => Duration::from_nanos(timeout_value), + _ => return Err(val), + }; + + Ok(Some(duration)) + } + None => Ok(None), + } +} + +// Note: The wrapped Duration should only be used for logging purposes. It is **not** the +// actual duration that elapsed, resulting in a timeout, instead it is a close approximation +#[derive(Debug)] +struct TimeoutExpired; + +impl fmt::Display for TimeoutExpired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Timeout expired") + } +} + +// std::error::Error only requires a type to impl Debug and Display +impl std::error::Error for TimeoutExpired {} + +#[cfg(test)] +mod tests { + use super::*; + + // Helper function to reduce the boiler plate of our test cases + fn setup_map_try_parse(val: Option<&'static str>) -> Result, HeaderValue> { + let mut hm = HeaderMap::new(); + if let Some(v) = val { + let hv = HeaderValue::from_static(v); + hm.insert(GRPC_TIMEOUT_HEADER, hv); + }; + + try_parse_grpc_timeout(&hm).map_err(|e| e.clone()) + } + + #[test] + fn test_hours() { + let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration); + } + + #[test] + fn test_minutes() { + let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(60), parsed_duration); + } + + #[test] + fn test_seconds() { + let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(42), parsed_duration); + } + + #[test] + fn test_milliseconds() { + let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap(); + assert_eq!(Duration::from_millis(13), parsed_duration); + } + + #[test] + fn test_microseconds() { + let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap(); + assert_eq!(Duration::from_micros(2), parsed_duration); + } + + #[test] + fn test_nanoseconds() { + let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap(); + assert_eq!(Duration::from_nanos(82), parsed_duration); + } + + #[test] + fn test_header_not_present() { + let parsed_duration = setup_map_try_parse(None).unwrap(); + assert!(parsed_duration.is_none()); + } + + #[test] + #[should_panic(expected = "82f")] + fn test_invalid_unit() { + // "f" is not a valid TimeoutUnit + setup_map_try_parse(Some("82f")).unwrap().unwrap(); + } + + #[test] + #[should_panic(expected = "123456789H")] + fn test_too_many_digits() { + // gRPC spec states TimeoutValue will be at most 8 digits + setup_map_try_parse(Some("123456789H")).unwrap().unwrap(); + } + + #[test] + #[should_panic(expected = "oneH")] + fn test_invalid_digits() { + // gRPC spec states TimeoutValue will be at most 8 digits + setup_map_try_parse(Some("oneH")).unwrap().unwrap(); + } +} From 83098f93b8cb0143c8fcffea9847f55eb942eb87 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 27 Apr 2021 20:47:18 +0200 Subject: [PATCH 02/11] Apply suggestions from code review Co-authored-by: Lucio Franco --- tonic/src/transport/service/timeout.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tonic/src/transport/service/timeout.rs b/tonic/src/transport/service/timeout.rs index 871c4474b..4e5000b35 100644 --- a/tonic/src/transport/service/timeout.rs +++ b/tonic/src/transport/service/timeout.rs @@ -11,7 +11,7 @@ use tokio::time::Sleep; use tower_service::Service; #[derive(Debug, Clone)] -pub(crate) struct Timeout { +pub(crate) struct GrpcTimeout { inner: S, server_timeout: Option, } @@ -40,7 +40,7 @@ where fn call(&mut self, req: Request) -> Self::Future { let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| { - tracing::trace!("Error parsing `grpc-timeout` header {:?}", e); + tracing::trace!("Error parsing `grpc-timeout` header {}", e); None }); @@ -160,7 +160,7 @@ fn try_parse_grpc_timeout( // Note: The wrapped Duration should only be used for logging purposes. It is **not** the // actual duration that elapsed, resulting in a timeout, instead it is a close approximation #[derive(Debug)] -struct TimeoutExpired; +struct TimeoutExpired(()); impl fmt::Display for TimeoutExpired { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { From bc177d3eae3564a712def5af634ca90c3221ed2e Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 27 Apr 2021 21:05:09 +0200 Subject: [PATCH 03/11] Timeout -> GrpcTimeout and export TimeoutExpired --- tonic/src/transport/mod.rs | 2 ++ tonic/src/transport/service/connection.rs | 4 ++-- .../service/{timeout.rs => grpc_timeout.rs} | 20 ++++++++++++------- tonic/src/transport/service/mod.rs | 4 +++- 4 files changed, 20 insertions(+), 10 deletions(-) rename tonic/src/transport/service/{timeout.rs => grpc_timeout.rs} (92%) diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index b91767aa3..6cc4c2890 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -98,6 +98,8 @@ pub use self::channel::{Channel, Endpoint}; pub use self::error::Error; #[doc(inline)] pub use self::server::{NamedService, Server}; +#[doc(inline)] +pub use self::service::TimeoutExpired; pub use self::tls::{Certificate, Identity}; pub use hyper::{Body, Uri}; diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 5b19f34ca..7a3790b02 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,6 +1,6 @@ use super::super::BoxFuture; use super::{ - layer::ServiceBuilderExt, reconnect::Reconnect, timeout::Timeout, AddOrigin, UserAgent, + layer::ServiceBuilderExt, reconnect::Reconnect, grpc_timeout::GrpcTimeout, AddOrigin, UserAgent, }; use crate::{body::BoxBody, transport::Endpoint}; use http::Uri; @@ -54,7 +54,7 @@ impl Connection { let stack = ServiceBuilder::new() .layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone())) .layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone())) - .layer_fn(|s| Timeout::new(s, endpoint.timeout)) + .layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout)) .optional_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new)) .optional_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); diff --git a/tonic/src/transport/service/timeout.rs b/tonic/src/transport/service/grpc_timeout.rs similarity index 92% rename from tonic/src/transport/service/timeout.rs rename to tonic/src/transport/service/grpc_timeout.rs index 4e5000b35..009372f92 100644 --- a/tonic/src/transport/service/timeout.rs +++ b/tonic/src/transport/service/grpc_timeout.rs @@ -16,7 +16,7 @@ pub(crate) struct GrpcTimeout { server_timeout: Option, } -impl Timeout { +impl GrpcTimeout { pub(crate) fn new(inner: S, server_timeout: Option) -> Self { Self { inner, @@ -25,7 +25,7 @@ impl Timeout { } } -impl Service> for Timeout +impl Service> for GrpcTimeout where S: Service>, S::Error: Into, @@ -40,7 +40,7 @@ where fn call(&mut self, req: Request) -> Self::Future { let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| { - tracing::trace!("Error parsing `grpc-timeout` header {}", e); + tracing::trace!("Error parsing `grpc-timeout` header {:?}", e); None }); @@ -89,7 +89,7 @@ where if let OptionPinProj::Some(sleep) = this.sleep.project() { futures_util::ready!(sleep.poll(cx)); - return Poll::Ready(Err(TimeoutExpired.into())); + return Poll::Ready(Err(TimeoutExpired(()).into())); } Poll::Pending @@ -157,10 +157,16 @@ fn try_parse_grpc_timeout( } } -// Note: The wrapped Duration should only be used for logging purposes. It is **not** the -// actual duration that elapsed, resulting in a timeout, instead it is a close approximation +/// Error returned if a request didn't complete within the configured timeout. +/// +/// Timeouts can be configured either with [`Endpoint::timeout`], [`Server::timeout`], or by +/// setting the [`grpc-timeout` metadata value][spec]. +/// +/// [`Endpoint::timeout`]: crate::transport::server::Server::timeout +/// [`Server::timeout`]: crate::transport::channel::Endpoint::timeout +/// [spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md #[derive(Debug)] -struct TimeoutExpired(()); +pub struct TimeoutExpired(()); impl fmt::Display for TimeoutExpired { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index bfb005684..aa2be006e 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -2,11 +2,11 @@ mod add_origin; mod connection; mod connector; mod discover; +mod grpc_timeout; mod io; mod layer; mod reconnect; mod router; -mod timeout; #[cfg(feature = "tls")] mod tls; mod user_agent; @@ -21,3 +21,5 @@ pub(crate) use self::router::{Or, Routes}; #[cfg(feature = "tls")] pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; pub(crate) use self::user_agent::UserAgent; + +pub use self::grpc_timeout::TimeoutExpired; From 1d893161d60c03101de9e44ada8450b864e0805d Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 27 Apr 2021 21:13:04 +0200 Subject: [PATCH 04/11] Clean up imports --- tonic/src/transport/service/grpc_timeout.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tonic/src/transport/service/grpc_timeout.rs b/tonic/src/transport/service/grpc_timeout.rs index 009372f92..12842b34d 100644 --- a/tonic/src/transport/service/grpc_timeout.rs +++ b/tonic/src/transport/service/grpc_timeout.rs @@ -1,9 +1,9 @@ -use http::Request; -use http::{HeaderMap, HeaderValue}; +use http::{HeaderMap, HeaderValue, Request}; use pin_project::pin_project; -use std::{fmt, pin::Pin}; use std::{ + fmt, future::Future, + pin::Pin, task::{Context, Poll}, time::Duration, }; From 5602072a77340bdc784638c9861f3ec81233ad24 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 27 Apr 2021 21:15:04 +0200 Subject: [PATCH 05/11] Give header name a more proper home --- tonic/src/metadata/map.rs | 2 ++ tonic/src/metadata/mod.rs | 2 ++ tonic/src/transport/service/connection.rs | 2 +- tonic/src/transport/service/grpc_timeout.rs | 3 +-- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tonic/src/metadata/map.rs b/tonic/src/metadata/map.rs index b1886173d..cbef3e101 100644 --- a/tonic/src/metadata/map.rs +++ b/tonic/src/metadata/map.rs @@ -194,6 +194,8 @@ pub struct OccupiedEntry<'a, VE: ValueEncoding> { phantom: PhantomData, } +pub(crate) const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; + // ===== impl MetadataMap ===== impl MetadataMap { diff --git a/tonic/src/metadata/mod.rs b/tonic/src/metadata/mod.rs index 8389681b9..4e796748f 100644 --- a/tonic/src/metadata/mod.rs +++ b/tonic/src/metadata/mod.rs @@ -29,6 +29,8 @@ pub use self::value::AsciiMetadataValue; pub use self::value::BinaryMetadataValue; pub use self::value::MetadataValue; +pub(crate) use self::map::GRPC_TIMEOUT_HEADER; + /// The metadata::errors module contains types for errors that can occur /// while handling gRPC custom metadata. pub mod errors { diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 7a3790b02..5ea711f4f 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,6 +1,6 @@ use super::super::BoxFuture; use super::{ - layer::ServiceBuilderExt, reconnect::Reconnect, grpc_timeout::GrpcTimeout, AddOrigin, UserAgent, + grpc_timeout::GrpcTimeout, layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin, UserAgent, }; use crate::{body::BoxBody, transport::Endpoint}; use http::Uri; diff --git a/tonic/src/transport/service/grpc_timeout.rs b/tonic/src/transport/service/grpc_timeout.rs index 12842b34d..f48c191b8 100644 --- a/tonic/src/transport/service/grpc_timeout.rs +++ b/tonic/src/transport/service/grpc_timeout.rs @@ -1,3 +1,4 @@ +use crate::metadata::GRPC_TIMEOUT_HEADER; use http::{HeaderMap, HeaderValue, Request}; use pin_project::pin_project; use std::{ @@ -102,8 +103,6 @@ enum OptionPin { None, } -const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; - const SECONDS_IN_HOUR: u64 = 60 * 60; const SECONDS_IN_MINUTE: u64 = 60; From 3452d3077653f9b8c635476848dd3a0ea8f07292 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 27 Apr 2021 21:45:10 +0200 Subject: [PATCH 06/11] Add fuzz tests for parsing header value into `grpc-timeout` --- .github/workflows/CI.yml | 2 + tonic/Cargo.toml | 2 + tonic/src/transport/service/grpc_timeout.rs | 41 ++++++++++++++++++++- 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 12e225b49..b14a0f1d9 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -49,6 +49,8 @@ jobs: env: RUSTFLAGS: "-D warnings" + # run a lot of quickcheck iterations + QUICKCHECK_TESTS: 1000 steps: - uses: hecrj/setup-rust-action@master diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 86733e062..0d417e1e6 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -81,6 +81,8 @@ tokio = { version = "1.0", features = ["rt", "macros"] } static_assertions = "1.0" rand = "0.8" bencher = "0.1.5" +quickcheck = "1.0" +quickcheck_macros = "1.0" [package.metadata.docs.rs] all-features = true diff --git a/tonic/src/transport/service/grpc_timeout.rs b/tonic/src/transport/service/grpc_timeout.rs index f48c191b8..580addbac 100644 --- a/tonic/src/transport/service/grpc_timeout.rs +++ b/tonic/src/transport/service/grpc_timeout.rs @@ -179,12 +179,14 @@ impl std::error::Error for TimeoutExpired {} #[cfg(test)] mod tests { use super::*; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; // Helper function to reduce the boiler plate of our test cases - fn setup_map_try_parse(val: Option<&'static str>) -> Result, HeaderValue> { + fn setup_map_try_parse(val: Option<&str>) -> Result, HeaderValue> { let mut hm = HeaderMap::new(); if let Some(v) = val { - let hv = HeaderValue::from_static(v); + let hv = HeaderValue::from_str(v).unwrap(); hm.insert(GRPC_TIMEOUT_HEADER, hv); }; @@ -253,4 +255,39 @@ mod tests { // gRPC spec states TimeoutValue will be at most 8 digits setup_map_try_parse(Some("oneH")).unwrap().unwrap(); } + + #[quickcheck] + fn fuzz(header_value: HeaderValueGen) -> bool { + let header_value = header_value.0; + + // this just shouldn't panic + let _ = setup_map_try_parse(Some(&header_value)); + + true + } + + /// Newtype to implement `Arbitrary` for generating `String`s that are valid `HeaderValue`s. + #[derive(Clone, Debug)] + struct HeaderValueGen(String); + + impl Arbitrary for HeaderValueGen { + fn arbitrary(g: &mut Gen) -> Self { + let max = g.choose(&(1..70).collect::>()).copied().unwrap(); + Self(gen_string(g, 0, max)) + } + } + + // copied from https://github.com/hyperium/http/blob/master/tests/header_map_fuzz.rs + fn gen_string(g: &mut Gen, min: usize, max: usize) -> String { + let bytes: Vec<_> = (min..max) + .map(|_| { + // Chars to pick from + g.choose(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----") + .copied() + .unwrap() + }) + .collect(); + + String::from_utf8(bytes).unwrap() + } } From 3c5d8c57926ba47d9845c814a06fdad55267cd60 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 28 Apr 2021 09:19:18 +0200 Subject: [PATCH 07/11] Map `TimeoutExpired` to `cancelled` status --- tests/integration_tests/tests/timeout.rs | 5 ----- tonic/src/status.rs | 4 ++++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/integration_tests/tests/timeout.rs b/tests/integration_tests/tests/timeout.rs index 83e457dc7..2ceca4abd 100644 --- a/tests/integration_tests/tests/timeout.rs +++ b/tests/integration_tests/tests/timeout.rs @@ -4,7 +4,6 @@ use tokio::net::TcpListener; use tonic::{transport::Server, Code, Request, Response, Status}; #[tokio::test] -#[ignore] async fn cancelation_on_timeout() { struct Svc; @@ -39,13 +38,9 @@ async fn cancelation_on_timeout() { .insert("grpc-timeout", "500m".parse().unwrap()); let res = client.unary_call(req).await; - dbg!(&res); let err = res.unwrap_err(); assert!(err.message().contains("Timeout expired")); - - // TODO(david): make this work. Will require mapping `TimeoutExpired` errors into - // `Code::Cancelled` but can't quite figure out how to do that. assert_eq!(err.code(), Code::Cancelled); } diff --git a/tonic/src/status.rs b/tonic/src/status.rs index dcf381d5b..f1d735c9d 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -331,6 +331,10 @@ impl Status { if let Some(h2) = err.downcast_ref::() { return Some(Status::from_h2_error(h2)); } + + if let Some(timeout) = err.downcast_ref::() { + return Some(Status::cancelled(timeout.to_string())); + } } cause = err.source(); From a7e8f61e10deb82d3888f404c8b5c1f513b2606c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 28 Apr 2021 11:37:20 +0200 Subject: [PATCH 08/11] Recover from timeout errors in the service --- tests/integration_tests/tests/timeout.rs | 48 +++++++++++-- tonic/Cargo.toml | 2 +- tonic/src/status.rs | 2 +- tonic/src/transport/server/mod.rs | 12 ++-- tonic/src/transport/server/recover_error.rs | 75 +++++++++++++++++++++ tonic/src/transport/service/mod.rs | 1 + 6 files changed, 128 insertions(+), 12 deletions(-) create mode 100644 tonic/src/transport/server/recover_error.rs diff --git a/tests/integration_tests/tests/timeout.rs b/tests/integration_tests/tests/timeout.rs index 2ceca4abd..7a4b94e6b 100644 --- a/tests/integration_tests/tests/timeout.rs +++ b/tests/integration_tests/tests/timeout.rs @@ -45,7 +45,7 @@ async fn cancelation_on_timeout() { } #[tokio::test] -async fn picks_the_shortest_timeout() { +async fn picks_server_timeout_if_thats_sorter() { struct Svc; #[tonic::async_trait] @@ -80,10 +80,50 @@ async fn picks_the_shortest_timeout() { // 10 hours .insert("grpc-timeout", "10H".parse().unwrap()); - // TODO(david): for some reason this fails with "h2 protocol error: protocol error: unexpected - // internal error encountered". Seems to be happening on `master` as well. Bug? let res = client.unary_call(req).await; - dbg!(&res); let err = res.unwrap_err(); assert!(err.message().contains("Timeout expired")); + assert_eq!(err.code(), Code::Cancelled); +} + +#[tokio::test] +async fn picks_client_timeout_if_thats_sorter() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + // Wait for a time longer than the timeout + tokio::time::sleep(Duration::from_secs(1)).await; + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .timeout(Duration::from_secs(9001)) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let mut client = test_client::TestClient::connect(format!("http://{}", addr)) + .await + .unwrap(); + + let mut req = Request::new(Input {}); + req.metadata_mut() + // 100 ms + .insert("grpc-timeout", "100m".parse().unwrap()); + + let res = client.unary_call(req).await; + let err = res.unwrap_err(); + assert!(err.message().contains("Timeout expired")); + assert_eq!(err.code(), Code::Cancelled); } diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 0d417e1e6..5e494d5c9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -69,7 +69,7 @@ h2 = { version = "0.3", optional = true } hyper = { version = "0.14.2", features = ["full"], optional = true } tokio = { version = "1.0.1", features = ["net"], optional = true } tokio-stream = "0.1" -tower = { version = "0.4.4", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } +tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } tracing-futures = { version = "0.2", optional = true } # rustls diff --git a/tonic/src/status.rs b/tonic/src/status.rs index f1d735c9d..071e1f9cd 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -313,7 +313,7 @@ impl Status { Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string())) } - fn try_from_error(err: &(dyn Error + 'static)) -> Option { + pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option { let mut cause = Some(err); while let Some(err) = cause { diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 6f2729d09..c2ad4d5ef 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -2,6 +2,7 @@ mod conn; mod incoming; +mod recover_error; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; @@ -21,8 +22,9 @@ pub(crate) use tokio_rustls::server::TlsStream; #[cfg(feature = "tls")] use crate::transport::Error; +use self::recover_error::RecoverError; use super::{ - service::{Or, Routes, ServerIo}, + service::{GrpcTimeout, Or, Routes, ServerIo}, BoxFuture, }; use crate::{body::BoxBody, request::ConnectionInfo}; @@ -42,10 +44,7 @@ use std::{ time::Duration, }; use tokio::io::{AsyncRead, AsyncWrite}; -use tower::{ - limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service, - ServiceBuilder, -}; +use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder}; use tracing_futures::{Instrument, Instrumented}; type BoxService = tower::util::BoxService, Response, crate::Error>; @@ -655,8 +654,9 @@ where Box::pin(async move { let svc = ServiceBuilder::new() + .layer_fn(RecoverError::new) .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new)) - .option_layer(timeout.map(TimeoutLayer::new)) + .layer_fn(|s| GrpcTimeout::new(s, timeout)) .service(svc); let svc = BoxService::new(Svc { diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs new file mode 100644 index 000000000..9b4ff2e67 --- /dev/null +++ b/tonic/src/transport/server/recover_error.rs @@ -0,0 +1,75 @@ +use crate::{body::BoxBody, Status}; +use futures_util::ready; +use http::Response; +use pin_project::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::Service; + +/// Middleware that attempts to recover from service errors by turning them into a response built +/// from the `Status`. +#[derive(Debug, Clone)] +pub(crate) struct RecoverError { + inner: S, +} + +impl RecoverError { + pub(crate) fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service for RecoverError +where + S: Service>, + S::Error: Into, +{ + type Response = Response; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: R) -> Self::Future { + ResponseFuture { + inner: self.inner.call(req), + } + } +} + +#[pin_project] +pub(crate) struct ResponseFuture { + #[pin] + inner: F, +} + +impl Future for ResponseFuture +where + F: Future, E>>, + E: Into, +{ + type Output = Result, crate::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result: Result, crate::Error> = + ready!(self.project().inner.poll(cx)).map_err(Into::into); + + match result { + Ok(res) => Poll::Ready(Ok(res)), + Err(err) => { + if let Some(status) = Status::try_from_error(&*err) { + let mut res = Response::new(BoxBody::empty()); + status.add_header(res.headers_mut()).unwrap(); + Poll::Ready(Ok(res)) + } else { + Poll::Ready(Err(err)) + } + } + } + } +} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index b8386561f..4e1d89c0c 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -14,6 +14,7 @@ pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; pub(crate) use self::connector::connector; pub(crate) use self::discover::DynamicServiceStream; +pub(crate) use self::grpc_timeout::GrpcTimeout; pub(crate) use self::io::ServerIo; pub(crate) use self::router::{Or, Routes}; #[cfg(feature = "tls")] From 19132a276988287f6a909bf1222bb5f76c7ac740 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 28 Apr 2021 11:44:06 +0200 Subject: [PATCH 09/11] Refactor tests --- tests/integration_tests/tests/timeout.rs | 95 ++++++++---------------- 1 file changed, 29 insertions(+), 66 deletions(-) diff --git a/tests/integration_tests/tests/timeout.rs b/tests/integration_tests/tests/timeout.rs index 7a4b94e6b..450a67d21 100644 --- a/tests/integration_tests/tests/timeout.rs +++ b/tests/integration_tests/tests/timeout.rs @@ -1,33 +1,11 @@ use integration_tests::pb::{test_client, test_server, Input, Output}; -use std::time::Duration; +use std::{net::SocketAddr, time::Duration}; use tokio::net::TcpListener; use tonic::{transport::Server, Code, Request, Response, Status}; #[tokio::test] async fn cancelation_on_timeout() { - struct Svc; - - #[tonic::async_trait] - impl test_server::Test for Svc { - async fn unary_call(&self, _req: Request) -> Result, Status> { - // Wait for a time longer than the timeout - tokio::time::sleep(Duration::from_millis(1_000)).await; - Ok(Response::new(Output {})) - } - } - - let svc = test_server::TestServer::new(Svc); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - tokio::spawn(async move { - Server::builder() - .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) - .await - .unwrap(); - }); + let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await; let mut client = test_client::TestClient::connect(format!("http://{}", addr)) .await @@ -35,6 +13,7 @@ async fn cancelation_on_timeout() { let mut req = Request::new(Input {}); req.metadata_mut() + // 500 ms .insert("grpc-timeout", "500m".parse().unwrap()); let res = client.unary_call(req).await; @@ -46,30 +25,7 @@ async fn cancelation_on_timeout() { #[tokio::test] async fn picks_server_timeout_if_thats_sorter() { - struct Svc; - - #[tonic::async_trait] - impl test_server::Test for Svc { - async fn unary_call(&self, _req: Request) -> Result, Status> { - // Wait for a time longer than the timeout - tokio::time::sleep(Duration::from_secs(1)).await; - Ok(Response::new(Output {})) - } - } - - let svc = test_server::TestServer::new(Svc); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - tokio::spawn(async move { - Server::builder() - .timeout(Duration::from_millis(100)) - .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) - .await - .unwrap(); - }); + let addr = run_service_in_background(Duration::from_secs(1), Duration::from_millis(100)).await; let mut client = test_client::TestClient::connect(format!("http://{}", addr)) .await @@ -88,42 +44,49 @@ async fn picks_server_timeout_if_thats_sorter() { #[tokio::test] async fn picks_client_timeout_if_thats_sorter() { - struct Svc; + let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await; + + let mut client = test_client::TestClient::connect(format!("http://{}", addr)) + .await + .unwrap(); + + let mut req = Request::new(Input {}); + req.metadata_mut() + // 100 ms + .insert("grpc-timeout", "100m".parse().unwrap()); + + let res = client.unary_call(req).await; + let err = res.unwrap_err(); + assert!(err.message().contains("Timeout expired")); + assert_eq!(err.code(), Code::Cancelled); +} + +async fn run_service_in_background(latency: Duration, server_timeout: Duration) -> SocketAddr { + struct Svc { + latency: Duration, + } #[tonic::async_trait] impl test_server::Test for Svc { async fn unary_call(&self, _req: Request) -> Result, Status> { - // Wait for a time longer than the timeout - tokio::time::sleep(Duration::from_secs(1)).await; + tokio::time::sleep(self.latency).await; Ok(Response::new(Output {})) } } - let svc = test_server::TestServer::new(Svc); + let svc = test_server::TestServer::new(Svc { latency }); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { Server::builder() - .timeout(Duration::from_secs(9001)) + .timeout(server_timeout) .add_service(svc) .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) .await .unwrap(); }); - let mut client = test_client::TestClient::connect(format!("http://{}", addr)) - .await - .unwrap(); - - let mut req = Request::new(Input {}); - req.metadata_mut() - // 100 ms - .insert("grpc-timeout", "100m".parse().unwrap()); - - let res = client.unary_call(req).await; - let err = res.unwrap_err(); - assert!(err.message().contains("Timeout expired")); - assert_eq!(err.code(), Code::Cancelled); + addr } From d1e153b001e02fed20d120157ce210551eaf190e Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 29 Apr 2021 09:25:23 +0200 Subject: [PATCH 10/11] Fix CI --- tonic/src/metadata/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tonic/src/metadata/mod.rs b/tonic/src/metadata/mod.rs index 4e796748f..50bfb49e4 100644 --- a/tonic/src/metadata/mod.rs +++ b/tonic/src/metadata/mod.rs @@ -29,6 +29,7 @@ pub use self::value::AsciiMetadataValue; pub use self::value::BinaryMetadataValue; pub use self::value::MetadataValue; +#[cfg(feature = "transport")] pub(crate) use self::map::GRPC_TIMEOUT_HEADER; /// The metadata::errors module contains types for errors that can occur From 08cb752b60298f7e9e2096c21d0d2aebd4d2a40c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 29 Apr 2021 09:54:21 +0200 Subject: [PATCH 11/11] Fix CI, again --- tonic/src/metadata/map.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tonic/src/metadata/map.rs b/tonic/src/metadata/map.rs index cbef3e101..8ddccf194 100644 --- a/tonic/src/metadata/map.rs +++ b/tonic/src/metadata/map.rs @@ -194,6 +194,7 @@ pub struct OccupiedEntry<'a, VE: ValueEncoding> { phantom: PhantomData, } +#[cfg(feature = "transport")] pub(crate) const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; // ===== impl MetadataMap =====