diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 12e225b49..b14a0f1d9 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -49,6 +49,8 @@ jobs: env: RUSTFLAGS: "-D warnings" + # run a lot of quickcheck iterations + QUICKCHECK_TESTS: 1000 steps: - uses: hecrj/setup-rust-action@master diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 37ef59a39..2e6836862 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -16,6 +16,7 @@ bytes = "1.0" [dev-dependencies] tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] } +tokio-stream = { version = "0.1.5", features = ["net"] } [build-dependencies] tonic-build = { path = "../../tonic-build" } diff --git a/tests/integration_tests/tests/timeout.rs b/tests/integration_tests/tests/timeout.rs new file mode 100644 index 000000000..450a67d21 --- /dev/null +++ b/tests/integration_tests/tests/timeout.rs @@ -0,0 +1,92 @@ +use integration_tests::pb::{test_client, test_server, Input, Output}; +use std::{net::SocketAddr, time::Duration}; +use tokio::net::TcpListener; +use tonic::{transport::Server, Code, Request, Response, Status}; + +#[tokio::test] +async fn cancelation_on_timeout() { + let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await; + + let mut client = test_client::TestClient::connect(format!("http://{}", addr)) + .await + .unwrap(); + + let mut req = Request::new(Input {}); + req.metadata_mut() + // 500 ms + .insert("grpc-timeout", "500m".parse().unwrap()); + + let res = client.unary_call(req).await; + + let err = res.unwrap_err(); + assert!(err.message().contains("Timeout expired")); + assert_eq!(err.code(), Code::Cancelled); +} + +#[tokio::test] +async fn picks_server_timeout_if_thats_sorter() { + let addr = run_service_in_background(Duration::from_secs(1), Duration::from_millis(100)).await; + + let mut client = 
test_client::TestClient::connect(format!("http://{}", addr)) + .await + .unwrap(); + + let mut req = Request::new(Input {}); + req.metadata_mut() + // 10 hours + .insert("grpc-timeout", "10H".parse().unwrap()); + + let res = client.unary_call(req).await; + let err = res.unwrap_err(); + assert!(err.message().contains("Timeout expired")); + assert_eq!(err.code(), Code::Cancelled); +} + +#[tokio::test] +async fn picks_client_timeout_if_thats_sorter() { + let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await; + + let mut client = test_client::TestClient::connect(format!("http://{}", addr)) + .await + .unwrap(); + + let mut req = Request::new(Input {}); + req.metadata_mut() + // 100 ms + .insert("grpc-timeout", "100m".parse().unwrap()); + + let res = client.unary_call(req).await; + let err = res.unwrap_err(); + assert!(err.message().contains("Timeout expired")); + assert_eq!(err.code(), Code::Cancelled); +} + +async fn run_service_in_background(latency: Duration, server_timeout: Duration) -> SocketAddr { + struct Svc { + latency: Duration, + } + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + tokio::time::sleep(self.latency).await; + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc { latency }); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .timeout(server_timeout) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + addr +} diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index c92d51106..5e494d5c9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -31,7 +31,8 @@ transport = [ "tokio", "tower", "tracing-futures", - "tokio/macros" + "tokio/macros", + "tokio/time", ] tls = ["transport", "tokio-rustls"] tls-roots = ["tls", 
"rustls-native-certs"] @@ -68,7 +69,7 @@ h2 = { version = "0.3", optional = true } hyper = { version = "0.14.2", features = ["full"], optional = true } tokio = { version = "1.0.1", features = ["net"], optional = true } tokio-stream = "0.1" -tower = { version = "0.4.4", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } +tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } tracing-futures = { version = "0.2", optional = true } # rustls @@ -80,6 +81,8 @@ tokio = { version = "1.0", features = ["rt", "macros"] } static_assertions = "1.0" rand = "0.8" bencher = "0.1.5" +quickcheck = "1.0" +quickcheck_macros = "1.0" [package.metadata.docs.rs] all-features = true diff --git a/tonic/src/metadata/map.rs b/tonic/src/metadata/map.rs index 177bb3c96..8ddccf194 100644 --- a/tonic/src/metadata/map.rs +++ b/tonic/src/metadata/map.rs @@ -194,15 +194,17 @@ pub struct OccupiedEntry<'a, VE: ValueEncoding> { phantom: PhantomData, } +#[cfg(feature = "transport")] +pub(crate) const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; + // ===== impl MetadataMap ===== impl MetadataMap { // Headers reserved by the gRPC protocol. - pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 8] = [ + pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 7] = [ "te", "user-agent", "content-type", - "grpc-timeout", "grpc-message", "grpc-encoding", "grpc-message-type", diff --git a/tonic/src/metadata/mod.rs b/tonic/src/metadata/mod.rs index 8389681b9..50bfb49e4 100644 --- a/tonic/src/metadata/mod.rs +++ b/tonic/src/metadata/mod.rs @@ -29,6 +29,9 @@ pub use self::value::AsciiMetadataValue; pub use self::value::BinaryMetadataValue; pub use self::value::MetadataValue; +#[cfg(feature = "transport")] +pub(crate) use self::map::GRPC_TIMEOUT_HEADER; + /// The metadata::errors module contains types for errors that can occur /// while handling gRPC custom metadata. 
pub mod errors { diff --git a/tonic/src/status.rs b/tonic/src/status.rs index dcf381d5b..071e1f9cd 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -313,7 +313,7 @@ impl Status { Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string())) } - fn try_from_error(err: &(dyn Error + 'static)) -> Option { + pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option { let mut cause = Some(err); while let Some(err) = cause { @@ -331,6 +331,10 @@ impl Status { if let Some(h2) = err.downcast_ref::() { return Some(Status::from_h2_error(h2)); } + + if let Some(timeout) = err.downcast_ref::() { + return Some(Status::cancelled(timeout.to_string())); + } } cause = err.source(); diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index b91767aa3..6cc4c2890 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -98,6 +98,8 @@ pub use self::channel::{Channel, Endpoint}; pub use self::error::Error; #[doc(inline)] pub use self::server::{NamedService, Server}; +#[doc(inline)] +pub use self::service::TimeoutExpired; pub use self::tls::{Certificate, Identity}; pub use hyper::{Body, Uri}; diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 6f2729d09..c2ad4d5ef 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -2,6 +2,7 @@ mod conn; mod incoming; +mod recover_error; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; @@ -21,8 +22,9 @@ pub(crate) use tokio_rustls::server::TlsStream; #[cfg(feature = "tls")] use crate::transport::Error; +use self::recover_error::RecoverError; use super::{ - service::{Or, Routes, ServerIo}, + service::{GrpcTimeout, Or, Routes, ServerIo}, BoxFuture, }; use crate::{body::BoxBody, request::ConnectionInfo}; @@ -42,10 +44,7 @@ use std::{ time::Duration, }; use tokio::io::{AsyncRead, AsyncWrite}; -use tower::{ - limit::concurrency::ConcurrencyLimitLayer, 
timeout::TimeoutLayer, util::Either, Service, - ServiceBuilder, -}; +use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder}; use tracing_futures::{Instrument, Instrumented}; type BoxService = tower::util::BoxService, Response, crate::Error>; @@ -655,8 +654,9 @@ where Box::pin(async move { let svc = ServiceBuilder::new() + .layer_fn(RecoverError::new) .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new)) - .option_layer(timeout.map(TimeoutLayer::new)) + .layer_fn(|s| GrpcTimeout::new(s, timeout)) .service(svc); let svc = BoxService::new(Svc { diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs new file mode 100644 index 000000000..9b4ff2e67 --- /dev/null +++ b/tonic/src/transport/server/recover_error.rs @@ -0,0 +1,75 @@ +use crate::{body::BoxBody, Status}; +use futures_util::ready; +use http::Response; +use pin_project::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::Service; + +/// Middleware that attempts to recover from service errors by turning them into a response built +/// from the `Status`. 
+#[derive(Debug, Clone)] +pub(crate) struct RecoverError { + inner: S, +} + +impl RecoverError { + pub(crate) fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service for RecoverError +where + S: Service>, + S::Error: Into, +{ + type Response = Response; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: R) -> Self::Future { + ResponseFuture { + inner: self.inner.call(req), + } + } +} + +#[pin_project] +pub(crate) struct ResponseFuture { + #[pin] + inner: F, +} + +impl Future for ResponseFuture +where + F: Future, E>>, + E: Into, +{ + type Output = Result, crate::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result: Result, crate::Error> = + ready!(self.project().inner.poll(cx)).map_err(Into::into); + + match result { + Ok(res) => Poll::Ready(Ok(res)), + Err(err) => { + if let Some(status) = Status::try_from_error(&*err) { + let mut res = Response::new(BoxBody::empty()); + status.add_header(res.headers_mut()).unwrap(); + Poll::Ready(Ok(res)) + } else { + Poll::Ready(Err(err)) + } + } + } + } +} diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index db3f17e72..6a365417e 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,5 +1,5 @@ use super::super::BoxFuture; -use super::{reconnect::Reconnect, AddOrigin, UserAgent}; +use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; use crate::{body::BoxBody, transport::Endpoint}; use http::Uri; use hyper::client::conn::Builder; @@ -14,7 +14,6 @@ use tower::load::Load; use tower::{ layer::Layer, limit::{concurrency::ConcurrencyLimitLayer, rate::RateLimitLayer}, - timeout::TimeoutLayer, util::BoxService, ServiceBuilder, ServiceExt, }; @@ -53,7 +52,7 @@ impl Connection { let stack = ServiceBuilder::new() 
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone())) .layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone())) - .option_layer(endpoint.timeout.map(TimeoutLayer::new)) + .layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout)) .option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new)) .option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); diff --git a/tonic/src/transport/service/grpc_timeout.rs b/tonic/src/transport/service/grpc_timeout.rs new file mode 100644 index 000000000..580addbac --- /dev/null +++ b/tonic/src/transport/service/grpc_timeout.rs @@ -0,0 +1,293 @@ +use crate::metadata::GRPC_TIMEOUT_HEADER; +use http::{HeaderMap, HeaderValue, Request}; +use pin_project::pin_project; +use std::{ + fmt, + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::Sleep; +use tower_service::Service; + +#[derive(Debug, Clone)] +pub(crate) struct GrpcTimeout { + inner: S, + server_timeout: Option, +} + +impl GrpcTimeout { + pub(crate) fn new(inner: S, server_timeout: Option) -> Self { + Self { + inner, + server_timeout, + } + } +} + +impl Service> for GrpcTimeout +where + S: Service>, + S::Error: Into, +{ + type Response = S::Response; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| { + tracing::trace!("Error parsing `grpc-timeout` header {:?}", e); + None + }); + + // Use the shorter of the two durations, if either are set + let timeout_duration = match (client_timeout, self.server_timeout) { + (None, None) => None, + (Some(dur), None) => Some(dur), + (None, Some(dur)) => Some(dur), + (Some(header), Some(server)) => { + let shorter_duration = std::cmp::min(header, server); + Some(shorter_duration) + } + }; + + ResponseFuture 
{ + inner: self.inner.call(req), + sleep: timeout_duration + .map(tokio::time::sleep) + .map(OptionPin::Some) + .unwrap_or(OptionPin::None), + } + } +} + +#[pin_project] +pub(crate) struct ResponseFuture { + #[pin] + inner: F, + #[pin] + sleep: OptionPin, +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Poll::Ready(result) = this.inner.poll(cx) { + return Poll::Ready(result.map_err(Into::into)); + } + + if let OptionPinProj::Some(sleep) = this.sleep.project() { + futures_util::ready!(sleep.poll(cx)); + return Poll::Ready(Err(TimeoutExpired(()).into())); + } + + Poll::Pending + } +} + +#[pin_project(project = OptionPinProj)] +enum OptionPin { + Some(#[pin] T), + None, +} + +const SECONDS_IN_HOUR: u64 = 60 * 60; +const SECONDS_IN_MINUTE: u64 = 60; + +/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns +/// the value we attempted to parse. +/// +/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md). +fn try_parse_grpc_timeout( + headers: &HeaderMap, +) -> Result, &HeaderValue> { + match headers.get(GRPC_TIMEOUT_HEADER) { + Some(val) => { + let (timeout_value, timeout_unit) = val + .to_str() + .map_err(|_| val) + .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })? + // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this + // `split_at` will never panic from trying to split in the middle of a character. + // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str + // + // `len - 1` also wont panic since we just checked `s.is_empty`. 
+ .split_at(val.len() - 1); + + // gRPC spec specifies `TimeoutValue` will be at most 8 digits + // Capping this at 8 digits also prevents integer overflow from ever occurring + if timeout_value.len() > 8 { + return Err(val); + } + + let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?; + + let duration = match timeout_unit { + // Hours + "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR), + // Minutes + "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE), + // Seconds + "S" => Duration::from_secs(timeout_value), + // Milliseconds + "m" => Duration::from_millis(timeout_value), + // Microseconds + "u" => Duration::from_micros(timeout_value), + // Nanoseconds + "n" => Duration::from_nanos(timeout_value), + _ => return Err(val), + }; + + Ok(Some(duration)) + } + None => Ok(None), + } +} + +/// Error returned if a request didn't complete within the configured timeout. +/// +/// Timeouts can be configured either with [`Endpoint::timeout`], [`Server::timeout`], or by +/// setting the [`grpc-timeout` metadata value][spec]. 
+/// +/// [`Endpoint::timeout`]: crate::transport::channel::Endpoint::timeout +/// [`Server::timeout`]: crate::transport::server::Server::timeout +/// [spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md +#[derive(Debug)] +pub struct TimeoutExpired(()); + +impl fmt::Display for TimeoutExpired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Timeout expired") + } +} + +// std::error::Error only requires a type to impl Debug and Display +impl std::error::Error for TimeoutExpired {} + +#[cfg(test)] +mod tests { + use super::*; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; + + // Helper function to reduce the boiler plate of our test cases + fn setup_map_try_parse(val: Option<&str>) -> Result, HeaderValue> { + let mut hm = HeaderMap::new(); + if let Some(v) = val { + let hv = HeaderValue::from_str(v).unwrap(); + hm.insert(GRPC_TIMEOUT_HEADER, hv); + }; + + try_parse_grpc_timeout(&hm).map_err(|e| e.clone()) + } + + #[test] + fn test_hours() { + let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration); + } + + #[test] + fn test_minutes() { + let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(60), parsed_duration); + } + + #[test] + fn test_seconds() { + let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(42), parsed_duration); + } + + #[test] + fn test_milliseconds() { + let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap(); + assert_eq!(Duration::from_millis(13), parsed_duration); + } + + #[test] + fn test_microseconds() { + let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap(); + assert_eq!(Duration::from_micros(2), parsed_duration); + } + + #[test] + fn test_nanoseconds() { + let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap(); + 
assert_eq!(Duration::from_nanos(82), parsed_duration); + } + + #[test] + fn test_header_not_present() { + let parsed_duration = setup_map_try_parse(None).unwrap(); + assert!(parsed_duration.is_none()); + } + + #[test] + #[should_panic(expected = "82f")] + fn test_invalid_unit() { + // "f" is not a valid TimeoutUnit + setup_map_try_parse(Some("82f")).unwrap().unwrap(); + } + + #[test] + #[should_panic(expected = "123456789H")] + fn test_too_many_digits() { + // gRPC spec states TimeoutValue will be at most 8 digits + setup_map_try_parse(Some("123456789H")).unwrap().unwrap(); + } + + #[test] + #[should_panic(expected = "oneH")] + fn test_invalid_digits() { + // TimeoutValue must be an ASCII unsigned integer; "one" is not numeric + setup_map_try_parse(Some("oneH")).unwrap().unwrap(); + } + + #[quickcheck] + fn fuzz(header_value: HeaderValueGen) -> bool { + let header_value = header_value.0; + + // this just shouldn't panic + let _ = setup_map_try_parse(Some(&header_value)); + + true + } + + /// Newtype to implement `Arbitrary` for generating `String`s that are valid `HeaderValue`s. 
+ #[derive(Clone, Debug)] + struct HeaderValueGen(String); + + impl Arbitrary for HeaderValueGen { + fn arbitrary(g: &mut Gen) -> Self { + let max = g.choose(&(1..70).collect::>()).copied().unwrap(); + Self(gen_string(g, 0, max)) + } + } + + // copied from https://github.com/hyperium/http/blob/master/tests/header_map_fuzz.rs + fn gen_string(g: &mut Gen, min: usize, max: usize) -> String { + let bytes: Vec<_> = (min..max) + .map(|_| { + // Chars to pick from + g.choose(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----") + .copied() + .unwrap() + }) + .collect(); + + String::from_utf8(bytes).unwrap() + } +} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index e39883076..4e1d89c0c 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -2,6 +2,7 @@ mod add_origin; mod connection; mod connector; mod discover; +mod grpc_timeout; mod io; mod reconnect; mod router; @@ -13,8 +14,11 @@ pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; pub(crate) use self::connector::connector; pub(crate) use self::discover::DynamicServiceStream; +pub(crate) use self::grpc_timeout::GrpcTimeout; pub(crate) use self::io::ServerIo; pub(crate) use self::router::{Or, Routes}; #[cfg(feature = "tls")] pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; pub(crate) use self::user_agent::UserAgent; + +pub use self::grpc_timeout::TimeoutExpired;