diff --git a/Cargo.toml b/Cargo.toml index a073486..bdf1458 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ appveyor = { repository = "lipanski/mockito", branch = "master", service = "gith assert-json-diff = "2.0" bytes = "1" colored = { version = "2.0", optional = true } -futures-core = "0.3" +futures-util = { version = "0.3", default-features = false } http = "1" http-body = "1" http-body-util = "0.1" diff --git a/src/response.rs b/src/response.rs index d41e012..218cb87 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,9 +1,8 @@ use crate::error::Error; use crate::Request; use bytes::Bytes; -use futures_core::stream::Stream; +use futures_util::Stream; use http::{HeaderMap, StatusCode}; -use http_body::Frame; use std::fmt; use std::io; use std::sync::Arc; @@ -117,7 +116,7 @@ impl Drop for ChunkedStream { } impl Stream for ChunkedStream { - type Item = io::Result>; + type Item = io::Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, @@ -126,9 +125,9 @@ impl Stream for ChunkedStream { self.receiver .as_mut() .map(move |receiver| { - receiver.poll_recv(cx).map(|received| { - received.map(|result| result.map(|data| Frame::data(Bytes::from(data)))) - }) + receiver + .poll_recv(cx) + .map(|received| received.map(|result| result.map(Into::into))) }) .unwrap_or(Poll::Ready(None)) } diff --git a/src/server.rs b/src/server.rs index 43cf848..4b04409 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,8 +3,11 @@ use crate::request::Request; use crate::response::{Body as ResponseBody, ChunkedStream}; use crate::ServerGuard; use crate::{Error, ErrorKind, Matcher, Mock}; +use bytes::Bytes; +use futures_util::{TryStream, TryStreamExt}; use http::{Request as HttpRequest, Response, StatusCode}; -use http_body_util::{BodyExt, Empty, Full, StreamBody}; +use http_body::{Body as HttpBody, Frame, SizeHint}; +use http_body_util::{BodyExt, StreamBody}; use hyper::body::Incoming; use hyper::service::service_fn; use hyper_util::rt::{TokioExecutor, TokioIo}; @@ -14,8 +17,10 @@ use std::error::Error as StdError; use std::fmt; use std::net::{IpAddr, SocketAddr}; use std::ops::Drop; +use std::pin::Pin; use std::str::FromStr; use std::sync::{mpsc, Arc, RwLock}; +use std::task::{ready, Context, Poll}; use std::thread; use tokio::net::TcpListener; use tokio::runtime; @@ -446,26 +451,72 @@ impl fmt::Display for Server { } type BoxError = Box; -type BoxBody = http_body_util::combinators::UnsyncBoxBody; -trait IntoBoxBody { - fn into_box_body(self) -> BoxBody; +enum Body { + Once(Option), + Wrap(http_body_util::combinators::UnsyncBoxBody), } -impl IntoBoxBody for B -where - B: http_body::Body + Send + 'static, - B::Error: Into, -{ - fn into_box_body(self) -> BoxBody { - self.map_err(Into::into).boxed_unsync() +impl Body { + fn empty() -> Self { + Self::Once(None) + } + + fn from_data_stream(stream: S) -> Self + where + S: TryStream + Send + 'static, + S::Error: Into, + { + let body = StreamBody::new(stream.map_ok(Frame::data).map_err(Into::into)).boxed_unsync(); + Self::Wrap(body) + } +} + +impl From for Body { + fn from(bytes: Bytes) -> Self { + if bytes.is_empty() { + Self::empty() + } else { + Self::Once(Some(bytes)) + } + } +} + +impl HttpBody for Body { + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + match self.as_mut().get_mut() { + Self::Once(val) => Poll::Ready(Ok(val.take().map(Frame::data)).transpose()), + Self::Wrap(body) => Poll::Ready(ready!(Pin::new(body).poll_frame(cx))), + } + } + + fn size_hint(&self) -> SizeHint { + match self { + Self::Once(None) => SizeHint::with_exact(0), + Self::Once(Some(bytes)) => SizeHint::with_exact(bytes.len() as u64), + Self::Wrap(body) => body.size_hint(), + } + } + + fn is_end_stream(&self) -> bool { + match self { + Self::Once(None) => true, + Self::Once(Some(bytes)) => bytes.is_empty(), + Self::Wrap(body) => body.is_end_stream(), + } } } async fn handle_request( hyper_request: HttpRequest, state: Arc>, -) -> Result, Error> { +) -> Result, Error> { let mut request = Request::new(hyper_request); request.read_body().await; log::debug!("Request received: {}", request.formatted()); @@ -498,7 +549,7 @@ async fn handle_request( } } -fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result, Error> { +fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result, Error> { let status: StatusCode = mock.inner.response.status; let mut response = Response::builder().status(status); @@ -512,32 +563,32 @@ fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result { let stream = ChunkedStream::new(Arc::clone(body_fn))?; - StreamBody::new(stream).into_box_body() + Body::from_data_stream(stream) } ResponseBody::FnWithRequest(body_fn) => { let bytes = body_fn(&request); - Full::new(bytes.to_owned()).into_box_body() + Body::from(bytes) } } } else { - Empty::new().into_box_body() + Body::empty() }; - let response: Response = response + let response = response .body(body) .map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?; Ok(response) } -fn respond_with_mock_not_found() -> Result, Error> { - let response: Response = Response::builder() +fn respond_with_mock_not_found() -> Result, Error> { + let response = Response::builder() .status(StatusCode::NOT_IMPLEMENTED) - .body(Empty::new().into_box_body()) + .body(Body::empty()) .map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?; Ok(response)