diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 12e225b49..b14a0f1d9 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -49,6 +49,8 @@ jobs:
env:
RUSTFLAGS: "-D warnings"
+ # run a lot of quickcheck iterations
+ QUICKCHECK_TESTS: 1000
steps:
- uses: hecrj/setup-rust-action@master
diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml
index 37ef59a39..2e6836862 100644
--- a/tests/integration_tests/Cargo.toml
+++ b/tests/integration_tests/Cargo.toml
@@ -16,6 +16,7 @@ bytes = "1.0"
[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] }
+tokio-stream = { version = "0.1.5", features = ["net"] }
[build-dependencies]
tonic-build = { path = "../../tonic-build" }
diff --git a/tests/integration_tests/tests/timeout.rs b/tests/integration_tests/tests/timeout.rs
new file mode 100644
index 000000000..450a67d21
--- /dev/null
+++ b/tests/integration_tests/tests/timeout.rs
@@ -0,0 +1,92 @@
+use integration_tests::pb::{test_client, test_server, Input, Output};
+use std::{net::SocketAddr, time::Duration};
+use tokio::net::TcpListener;
+use tonic::{transport::Server, Code, Request, Response, Status};
+
+#[tokio::test]
+async fn cancelation_on_timeout() {
+ let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await;
+
+ let mut client = test_client::TestClient::connect(format!("http://{}", addr))
+ .await
+ .unwrap();
+
+ let mut req = Request::new(Input {});
+ req.metadata_mut()
+ // 500 ms
+ .insert("grpc-timeout", "500m".parse().unwrap());
+
+ let res = client.unary_call(req).await;
+
+ let err = res.unwrap_err();
+ assert!(err.message().contains("Timeout expired"));
+ assert_eq!(err.code(), Code::Cancelled);
+}
+
+#[tokio::test]
+async fn picks_server_timeout_if_thats_sorter() {
+ let addr = run_service_in_background(Duration::from_secs(1), Duration::from_millis(100)).await;
+
+ let mut client = test_client::TestClient::connect(format!("http://{}", addr))
+ .await
+ .unwrap();
+
+ let mut req = Request::new(Input {});
+ req.metadata_mut()
+ // 10 hours
+ .insert("grpc-timeout", "10H".parse().unwrap());
+
+ let res = client.unary_call(req).await;
+ let err = res.unwrap_err();
+ assert!(err.message().contains("Timeout expired"));
+ assert_eq!(err.code(), Code::Cancelled);
+}
+
+#[tokio::test]
+async fn picks_client_timeout_if_thats_sorter() {
+ let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await;
+
+ let mut client = test_client::TestClient::connect(format!("http://{}", addr))
+ .await
+ .unwrap();
+
+ let mut req = Request::new(Input {});
+ req.metadata_mut()
+ // 100 ms
+ .insert("grpc-timeout", "100m".parse().unwrap());
+
+ let res = client.unary_call(req).await;
+ let err = res.unwrap_err();
+ assert!(err.message().contains("Timeout expired"));
+ assert_eq!(err.code(), Code::Cancelled);
+}
+
+async fn run_service_in_background(latency: Duration, server_timeout: Duration) -> SocketAddr {
+ struct Svc {
+ latency: Duration,
+ }
+
+ #[tonic::async_trait]
+ impl test_server::Test for Svc {
+ async fn unary_call(&self, _req: Request) -> Result, Status> {
+ tokio::time::sleep(self.latency).await;
+ Ok(Response::new(Output {}))
+ }
+ }
+
+ let svc = test_server::TestServer::new(Svc { latency });
+
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let addr = listener.local_addr().unwrap();
+
+ tokio::spawn(async move {
+ Server::builder()
+ .timeout(server_timeout)
+ .add_service(svc)
+ .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
+ .await
+ .unwrap();
+ });
+
+ addr
+}
diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml
index c92d51106..5e494d5c9 100644
--- a/tonic/Cargo.toml
+++ b/tonic/Cargo.toml
@@ -31,7 +31,8 @@ transport = [
"tokio",
"tower",
"tracing-futures",
- "tokio/macros"
+ "tokio/macros",
+ "tokio/time",
]
tls = ["transport", "tokio-rustls"]
tls-roots = ["tls", "rustls-native-certs"]
@@ -68,7 +69,7 @@ h2 = { version = "0.3", optional = true }
hyper = { version = "0.14.2", features = ["full"], optional = true }
tokio = { version = "1.0.1", features = ["net"], optional = true }
tokio-stream = "0.1"
-tower = { version = "0.4.4", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
+tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
tracing-futures = { version = "0.2", optional = true }
# rustls
@@ -80,6 +81,8 @@ tokio = { version = "1.0", features = ["rt", "macros"] }
static_assertions = "1.0"
rand = "0.8"
bencher = "0.1.5"
+quickcheck = "1.0"
+quickcheck_macros = "1.0"
[package.metadata.docs.rs]
all-features = true
diff --git a/tonic/src/metadata/map.rs b/tonic/src/metadata/map.rs
index 177bb3c96..8ddccf194 100644
--- a/tonic/src/metadata/map.rs
+++ b/tonic/src/metadata/map.rs
@@ -194,15 +194,17 @@ pub struct OccupiedEntry<'a, VE: ValueEncoding> {
phantom: PhantomData,
}
+#[cfg(feature = "transport")]
+pub(crate) const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout";
+
// ===== impl MetadataMap =====
impl MetadataMap {
// Headers reserved by the gRPC protocol.
- pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 8] = [
+ pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 7] = [
"te",
"user-agent",
"content-type",
- "grpc-timeout",
"grpc-message",
"grpc-encoding",
"grpc-message-type",
diff --git a/tonic/src/metadata/mod.rs b/tonic/src/metadata/mod.rs
index 8389681b9..50bfb49e4 100644
--- a/tonic/src/metadata/mod.rs
+++ b/tonic/src/metadata/mod.rs
@@ -29,6 +29,9 @@ pub use self::value::AsciiMetadataValue;
pub use self::value::BinaryMetadataValue;
pub use self::value::MetadataValue;
+#[cfg(feature = "transport")]
+pub(crate) use self::map::GRPC_TIMEOUT_HEADER;
+
/// The metadata::errors module contains types for errors that can occur
/// while handling gRPC custom metadata.
pub mod errors {
diff --git a/tonic/src/status.rs b/tonic/src/status.rs
index dcf381d5b..071e1f9cd 100644
--- a/tonic/src/status.rs
+++ b/tonic/src/status.rs
@@ -313,7 +313,7 @@ impl Status {
Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string()))
}
- fn try_from_error(err: &(dyn Error + 'static)) -> Option {
+ pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option {
let mut cause = Some(err);
while let Some(err) = cause {
@@ -331,6 +331,10 @@ impl Status {
if let Some(h2) = err.downcast_ref::() {
return Some(Status::from_h2_error(h2));
}
+
+ if let Some(timeout) = err.downcast_ref::() {
+ return Some(Status::cancelled(timeout.to_string()));
+ }
}
cause = err.source();
diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs
index b91767aa3..6cc4c2890 100644
--- a/tonic/src/transport/mod.rs
+++ b/tonic/src/transport/mod.rs
@@ -98,6 +98,8 @@ pub use self::channel::{Channel, Endpoint};
pub use self::error::Error;
#[doc(inline)]
pub use self::server::{NamedService, Server};
+#[doc(inline)]
+pub use self::service::TimeoutExpired;
pub use self::tls::{Certificate, Identity};
pub use hyper::{Body, Uri};
diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs
index 6f2729d09..c2ad4d5ef 100644
--- a/tonic/src/transport/server/mod.rs
+++ b/tonic/src/transport/server/mod.rs
@@ -2,6 +2,7 @@
mod conn;
mod incoming;
+mod recover_error;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
@@ -21,8 +22,9 @@ pub(crate) use tokio_rustls::server::TlsStream;
#[cfg(feature = "tls")]
use crate::transport::Error;
+use self::recover_error::RecoverError;
use super::{
- service::{Or, Routes, ServerIo},
+ service::{GrpcTimeout, Or, Routes, ServerIo},
BoxFuture,
};
use crate::{body::BoxBody, request::ConnectionInfo};
@@ -42,10 +44,7 @@ use std::{
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
-use tower::{
- limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service,
- ServiceBuilder,
-};
+use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder};
use tracing_futures::{Instrument, Instrumented};
type BoxService = tower::util::BoxService, Response, crate::Error>;
@@ -655,8 +654,9 @@ where
Box::pin(async move {
let svc = ServiceBuilder::new()
+ .layer_fn(RecoverError::new)
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
- .option_layer(timeout.map(TimeoutLayer::new))
+ .layer_fn(|s| GrpcTimeout::new(s, timeout))
.service(svc);
let svc = BoxService::new(Svc {
diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs
new file mode 100644
index 000000000..9b4ff2e67
--- /dev/null
+++ b/tonic/src/transport/server/recover_error.rs
@@ -0,0 +1,75 @@
+use crate::{body::BoxBody, Status};
+use futures_util::ready;
+use http::Response;
+use pin_project::pin_project;
+use std::{
+ future::Future,
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tower::Service;
+
+/// Middleware that attempts to recover from service errors by turning them into a response built
+/// from the `Status`.
+#[derive(Debug, Clone)]
+pub(crate) struct RecoverError {
+ inner: S,
+}
+
+impl RecoverError {
+ pub(crate) fn new(inner: S) -> Self {
+ Self { inner }
+ }
+}
+
+impl Service for RecoverError
+where
+ S: Service>,
+ S::Error: Into,
+{
+ type Response = Response;
+ type Error = crate::Error;
+ type Future = ResponseFuture;
+
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.inner.poll_ready(cx).map_err(Into::into)
+ }
+
+ fn call(&mut self, req: R) -> Self::Future {
+ ResponseFuture {
+ inner: self.inner.call(req),
+ }
+ }
+}
+
+#[pin_project]
+pub(crate) struct ResponseFuture {
+ #[pin]
+ inner: F,
+}
+
+impl Future for ResponseFuture
+where
+ F: Future