Skip to content

Commit

Permalink
Support streaming body
Browse files Browse the repository at this point in the history
  • Loading branch information
kornelski committed Sep 13, 2023
1 parent 8b7575c commit 339bffc
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 42 deletions.
85 changes: 51 additions & 34 deletions src/response.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::error::Error;
use crate::Request;
use core::task::Poll;
use futures::stream::Stream;
use hyper::StatusCode;
use std::fmt;
use std::io;
use std::mem;
use std::sync::Arc;
use std::task::Poll;
use std::thread;
use tokio::sync::mpsc;

#[derive(Clone, Debug, PartialEq)]
pub(crate) struct Response {
Expand Down Expand Up @@ -61,50 +63,65 @@ impl Default for Response {
}
}

pub(crate) struct Chunked {
buffer: Vec<u8>,
finished: bool,
struct ChunkedStreamWriter {
sender: mpsc::Sender<io::Result<Box<[u8]>>>,
}

impl Chunked {
pub fn new() -> Self {
Self {
buffer: vec![],
finished: false,
}
impl io::Write for ChunkedStreamWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.sender
.blocking_send(Ok(buf.into()))
.map_err(|_| io::ErrorKind::BrokenPipe)?;
Ok(buf.len())
}

pub fn finish(&mut self) {
self.finished = true;
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}

impl Stream for Chunked {
type Item = Result<Vec<u8>, String>;
pub(crate) struct ChunkedStream {
receiver: Option<mpsc::Receiver<io::Result<Box<[u8]>>>>,
thread: Option<thread::JoinHandle<()>>,
}

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
if !self.buffer.is_empty() {
let data = mem::take(&mut self.buffer);
Poll::Ready(Some(Ok(data)))
} else if !self.finished {
Poll::Pending
} else {
Poll::Ready(None)
}
impl ChunkedStream {
pub fn new(body_fn: Arc<BodyFnWithWriter>) -> Result<Self, Error> {
let (sender, receiver) = mpsc::channel(1);
let join = thread::Builder::new()
.name(format!("mockito::body_fn_{:p}", body_fn))
.spawn(move || {
let mut writer = ChunkedStreamWriter { sender };
if let Err(e) = body_fn(&mut writer) {
let _ = writer.sender.blocking_send(Err(e));
}
})
.map_err(|e| Error::new_with_context(crate::ErrorKind::ResponseFailure, e))?;
Ok(Self {
receiver: Some(receiver),
thread: Some(join),
})
}
}

impl io::Write for Chunked {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.buffer.append(&mut buf.to_vec());
Ok(buf.len())
impl Drop for ChunkedStream {
fn drop(&mut self) {
// must close the channel first
let _ = self.receiver.take();
let _ = self.thread.take().map(|t| t.join());
}
}

fn flush(&mut self) -> io::Result<()> {
self.finished = true;
Ok(())
impl Stream for ChunkedStream {
type Item = io::Result<Box<[u8]>>;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
self.receiver
.as_mut()
.map(move |r| r.poll_recv(cx))
.unwrap_or(Poll::Ready(None))
}
}
11 changes: 3 additions & 8 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::mock::InnerMock;
use crate::request::Request;
use crate::response::{Body as ResponseBody, Chunked as ResponseChunked};
use crate::response::{Body as ResponseBody, ChunkedStream};
use crate::ServerGuard;
use crate::{Error, ErrorKind, Matcher, Mock};
use hyper::server::conn::Http;
Expand Down Expand Up @@ -401,13 +401,8 @@ fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Bod
Body::from(bytes.clone())
}
ResponseBody::FnWithWriter(body_fn) => {
let mut chunked = ResponseChunked::new();
body_fn(&mut chunked)
.map_err(|_| Error::new(ErrorKind::ResponseBodyFailure))
.unwrap();
chunked.finish();

Body::wrap_stream(chunked)
let stream = ChunkedStream::new(Arc::clone(body_fn))?;
Body::wrap_stream(stream)
}
ResponseBody::FnWithRequest(body_fn) => {
let bytes = body_fn(&request);
Expand Down
14 changes: 14 additions & 0 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,20 @@ fn test_mock_with_fn_body() {
assert_eq!("hello", body);
}

#[test]
fn test_mock_with_fn_body_streamed_forever() {
let mut s = Server::new();
s.mock("GET", "/")
.with_chunked_body(|w| loop {
w.write_all(b"spam")?
})
.create();

let stream = request_stream("1.1", s.host_with_port(), "GET /", "", "");
let (status_line, _, _) = parse_stream(stream, true);
assert_eq!("HTTP/1.1 200 OK\r\n", status_line);
}

#[test]
fn test_mock_with_body_from_request() {
let mut s = Server::new();
Expand Down

0 comments on commit 339bffc

Please sign in to comment.