diff --git a/src/lib.rs b/src/lib.rs index b63924f..9ffc8ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ mod empty; mod full; +mod limited; mod next; mod size_hint; @@ -22,6 +23,7 @@ pub mod combinators; pub use self::empty::Empty; pub use self::full::Full; +pub use self::limited::{LengthLimitError, Limited}; pub use self::next::{Data, Trailers}; pub use self::size_hint::SizeHint; diff --git a/src/limited.rs b/src/limited.rs new file mode 100644 index 0000000..e5303fe --- /dev/null +++ b/src/limited.rs @@ -0,0 +1,209 @@ +//! Body types. + +use crate::{Body, SizeHint}; +use bytes::Buf; +use pin_project_lite::pin_project; +use std::{ + fmt, + pin::Pin, + task::{Context, Poll}, +}; + +type BoxError = Box; + +pin_project! { + /// Body wrapper that returns error when limit is exceeded. + #[derive(Clone, Copy, Debug)] + pub struct Limited { + #[pin] + inner: B, + remaining: usize, + } +} + +impl Limited { + /// Crate a new [`Limited`]. + pub fn new(inner: B, limit: usize) -> Self { + Self { + inner, + remaining: limit, + } + } +} + +impl Body for Limited +where + B: Body, + B::Error: Into, +{ + type Data = B::Data; + type Error = BoxError; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.project(); + + let res = match this.inner.poll_data(cx) { + Poll::Ready(Some(Ok(data))) => { + if data.remaining() > *this.remaining { + *this.remaining = 0; + Some(Err(LengthLimitError::new().into())) + } else { + *this.remaining -= data.remaining(); + Some(Ok(data)) + } + } + Poll::Ready(Some(Err(e))) => Some(Err(e.into())), + Poll::Ready(None) => None, + Poll::Pending => return Poll::Pending, + }; + + Poll::Ready(res) + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + self.project().inner.poll_trailers(cx).map_err(Into::into) + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + use std::convert::TryFrom; + match u64::try_from(self.remaining) { + Ok(n) => { + let mut hint = self.inner.size_hint(); + if hint.lower() >= n { + hint.set_exact(n) + } else if let Some(max) = hint.upper() { + hint.set_upper(n.min(max)) + } else { + hint.set_upper(n) + } + hint + } + Err(_) => self.inner.size_hint(), + } + } +} + +/// An error returned when reading from a [`Limited`] body. +#[derive(Debug)] +pub struct LengthLimitError {} + +impl LengthLimitError { + pub(crate) fn new() -> Self { + Self {} + } +} + +impl fmt::Display for LengthLimitError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("length limit exceeded") + } +} + +impl std::error::Error for LengthLimitError {} + +#[cfg(test)] +mod tests { + use super::Limited; + use crate::{Body, Full, SizeHint}; + use bytes::{BufMut, Bytes, BytesMut}; + use std::{ + convert::Infallible, + pin::Pin, + task::{Context, Poll}, + }; + + #[tokio::test] + async fn over_limit() { + let body = Full::new(Bytes::from(vec![0u8; 4096])); + let limited_body = Limited::new(body, 2048); + + assert!(to_bytes(limited_body).await.is_err()); + } + + #[tokio::test] + async fn under_limit() { + let body = Full::new(Bytes::from(vec![0u8; 4096])); + let limited_body = Limited::new(body, 8192); + + assert!(to_bytes(limited_body).await.is_ok()); + } + + #[tokio::test] + async fn size_hint() { + const CHUNK: [u8; 8] = [0u8; 8]; + + enum TestBody { + Empty, + Half, + Full, + } + + impl Body for TestBody { + type Data = Bytes; + type Error = Infallible; + + fn poll_data( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + match *self { + Self::Empty => self.set(Self::Half), + Self::Half => self.set(Self::Full), + Self::Full => return Poll::Ready(None), + } + + Poll::Ready(Some(Ok(CHUNK.to_vec().into()))) + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + unimplemented!() + } + + fn size_hint(&self) -> SizeHint { + match self { + Self::Empty => SizeHint::with_exact((CHUNK.len() * 2) as u64), + Self::Half => SizeHint::with_exact(CHUNK.len() as u64), + Self::Full => SizeHint::with_exact(0), + } + } + } + + let mut body = TestBody::Empty; + + assert_eq!(body.size_hint().upper().unwrap(), (CHUNK.len() * 2) as u64); + + let data = body.data().await.unwrap().unwrap(); + assert_eq!(data, CHUNK.to_vec()); + + assert_eq!(body.size_hint().upper().unwrap(), CHUNK.len() as u64); + + let data = body.data().await.unwrap().unwrap(); + assert_eq!(data, CHUNK.to_vec()); + + assert_eq!(body.size_hint().upper().unwrap(), 0); + } + + async fn to_bytes(body: B) -> Result { + tokio::pin!(body); + + let mut bytes = BytesMut::new(); + while let Some(result) = body.data().await { + bytes.put(result?); + } + + Ok(bytes.freeze()) + } +}