From 4054d61e14b9794a72b48de1a051c26129ec36b1 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Tue, 19 Oct 2021 11:52:59 -0400 Subject: [PATCH] fix(tonic): Status code to set correct source on unkown error (#799) --- tests/integration_tests/tests/status.rs | 74 +++++++++++++++++++++++++ tonic/src/status.rs | 7 ++- 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index ef57d5086..4f395e68c 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -1,12 +1,16 @@ use bytes::Bytes; use futures_util::FutureExt; +use http::Uri; use integration_tests::pb::{ test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output, OutputStream, }; +use std::convert::TryFrom; +use std::error::Error; use std::time::Duration; use tokio::sync::oneshot; use tonic::metadata::{MetadataMap, MetadataValue}; +use tonic::transport::Endpoint; use tonic::{transport::Server, Code, Request, Response, Status}; #[tokio::test] @@ -173,8 +177,78 @@ async fn status_from_server_stream() { assert_eq!(stream.message().await.unwrap(), None); } +#[tokio::test] +async fn status_from_server_stream_with_source() { + trace_init(); + + let channel = Endpoint::try_from("http://[::]:50051") + .unwrap() + .connect_with_connector_lazy(tower::service_fn(move |_: Uri| async move { + Err::(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) + })) + .unwrap(); + + let mut client = test_stream_client::TestStreamClient::new(channel); + + let error = client.stream_call(InputStream {}).await.unwrap_err(); + + let source = error.source().unwrap(); + source.downcast_ref::().unwrap(); +} + fn trace_init() { let _ = tracing_subscriber::FmtSubscriber::builder() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .try_init(); } + +mod mock { + use std::{ + pin::Pin, + task::{Context, Poll}, + }; + + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tonic::transport::server::Connected; + + #[derive(Debug)] + pub struct MockStream(pub tokio::io::DuplexStream); + + impl Connected for MockStream { + type ConnectInfo = (); + + /// Create type holding information about the connection. + fn connect_info(&self) -> Self::ConnectInfo {} + } + + impl AsyncRead for MockStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } + } + + impl AsyncWrite for MockStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + } +} diff --git a/tonic/src/status.rs b/tonic/src/status.rs index c44f88a19..e4e1b32a2 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -305,8 +305,11 @@ impl Status { #[cfg_attr(not(feature = "transport"), allow(dead_code))] pub(crate) fn from_error(err: Box) -> Status { - Status::try_from_error(err) - .unwrap_or_else(|err| Status::new(Code::Unknown, err.to_string())) + Status::try_from_error(err).unwrap_or_else(|err| { + let mut status = Status::new(Code::Unknown, err.to_string()); + status.source = Some(err); + status + }) } pub(crate) fn try_from_error(