diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index f335a6140..bd48ee47b 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -11,7 +11,10 @@ use std::{ }; use tokio_stream::{Stream, StreamExt}; +use fuse::Fuse; + pub(super) const BUFFER_SIZE: usize = 8 * 1024; +const YIELD_THRESHOLD: usize = 32 * 1024; pub(crate) fn encode_server( encoder: T, @@ -24,7 +27,7 @@ where T: Encoder, U: Stream>, { - let stream = encode( + let stream = EncodedBytes::new( encoder, source, compression_encoding, @@ -45,7 +48,7 @@ where T: Encoder, U: Stream, { - let stream = encode( + let stream = EncodedBytes::new( encoder, source.map(Ok), compression_encoding, @@ -55,44 +58,115 @@ where EncodeBody::new_client(stream) } -fn encode( - mut encoder: T, - source: U, +/// Combinator for efficient encoding of messages into reasonably sized buffers. +/// EncodedBytes encodes ready messages from its delegate stream into a BytesMut, +/// splitting off and yielding a buffer when either: +/// * The delegate stream polls as not ready, or +/// * The encoded buffer surpasses YIELD_THRESHOLD. +#[pin_project(project = EncodedBytesProj)] +#[derive(Debug)] +pub(crate) struct EncodedBytes +where + T: Encoder, + U: Stream>, +{ + #[pin] + source: Fuse, + encoder: T, compression_encoding: Option, - compression_override: SingleMessageCompressionOverride, max_message_size: Option, -) -> impl Stream> + buf: BytesMut, + uncompression_buf: BytesMut, +} + +impl EncodedBytes where T: Encoder, U: Stream>, { - let mut buf = BytesMut::with_capacity(BUFFER_SIZE); + fn new( + encoder: T, + source: U, + compression_encoding: Option, + compression_override: SingleMessageCompressionOverride, + max_message_size: Option, + ) -> Self { + let buf = BytesMut::with_capacity(BUFFER_SIZE); - let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable - { - None - } else { - compression_encoding - }; + let compression_encoding = + if compression_override == SingleMessageCompressionOverride::Disable { + None + } else { + compression_encoding + }; - let mut uncompression_buf = if compression_encoding.is_some() { - BytesMut::with_capacity(BUFFER_SIZE) - } else { - BytesMut::new() - }; + let uncompression_buf = if compression_encoding.is_some() { + BytesMut::with_capacity(BUFFER_SIZE) + } else { + BytesMut::new() + }; - source.map(move |result| { - let item = result?; + return EncodedBytes { + source: Fuse::new(source), + encoder, + compression_encoding, + max_message_size, + buf, + uncompression_buf, + }; + } +} - encode_item( - &mut encoder, - &mut buf, - &mut uncompression_buf, +impl Stream for EncodedBytes +where + T: Encoder, + U: Stream>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let EncodedBytesProj { + mut source, + encoder, compression_encoding, max_message_size, - item, - ) - }) + buf, + uncompression_buf, + } = self.project(); + + loop { + match source.as_mut().poll_next(cx) { + Poll::Pending if buf.is_empty() => { + return Poll::Pending; + } + Poll::Ready(None) if buf.is_empty() => { + return Poll::Ready(None); + } + Poll::Pending | Poll::Ready(None) => { + return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); + } + Poll::Ready(Some(Ok(item))) => { + if let Err(status) = encode_item( + encoder, + buf, + uncompression_buf, + *compression_encoding, + *max_message_size, + item, + ) { + return Poll::Ready(Some(Err(status))); + } + + if buf.len() >= YIELD_THRESHOLD { + return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); + } + } + Poll::Ready(Some(Err(status))) => { + return Poll::Ready(Some(Err(status))); + } + } + } + } } fn encode_item( @@ -102,10 +176,12 @@ fn encode_item( compression_encoding: Option, max_message_size: Option, item: T::Item, -) -> Result +) -> Result<(), Status> where T: Encoder, { + let offset = buf.len(); + buf.reserve(HEADER_SIZE); unsafe { buf.advance_mut(HEADER_SIZE); @@ -129,14 +205,14 @@ where } // now that we know length, we can write the header - finish_encoding(compression_encoding, max_message_size, buf) + finish_encoding(compression_encoding, max_message_size, &mut buf[offset..]) } fn finish_encoding( compression_encoding: Option, max_message_size: Option, - buf: &mut BytesMut, -) -> Result { + buf: &mut [u8], +) -> Result<(), Status> { let len = buf.len() - HEADER_SIZE; let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE); if len > limit { @@ -160,7 +236,7 @@ fn finish_encoding( buf.put_u32(len as u32); } - Ok(buf.split_to(len + HEADER_SIZE).freeze()) + Ok(()) } #[derive(Debug)] @@ -269,3 +345,57 @@ where Poll::Ready(self.project().state.trailers()) } } + +mod fuse { + use std::{ + pin::Pin, + task::{ready, Context, Poll}, + }; + + use tokio_stream::Stream; + + /// Stream for the [`fuse`](super::StreamExt::fuse) method. + #[derive(Debug)] + #[pin_project::pin_project] + #[must_use = "streams do nothing unless polled"] + pub(crate) struct Fuse { + #[pin] + stream: St, + done: bool, + } + + impl Fuse { + pub(crate) fn new(stream: St) -> Self { + Self { + stream, + done: false, + } + } + } + + impl Stream for Fuse { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + if *this.done { + return Poll::Ready(None); + } + + let item = ready!(this.stream.poll_next(cx)); + if item.is_none() { + *this.done = true; + } + Poll::Ready(item) + } + + fn size_hint(&self) -> (usize, Option) { + if self.done { + (0, Some(0)) + } else { + self.stream.size_hint() + } + } + } +}