Skip to content

Commit

Permalink
Add async-byte-channel and use it to avoid needing Unix sockets in te…
Browse files Browse the repository at this point in the history
…sts.
  • Loading branch information
dwrensha committed Oct 6, 2020
1 parent 3ab5512 commit 92093a3
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 166 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ members = [
"capnp-rpc",

# testing and examples
"async-byte-channel",
"benchmark",
"capnpc/test",
"capnp-futures/test",
Expand Down
14 changes: 14 additions & 0 deletions async-byte-channel/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "async-byte-channel"
version = "0.0.1"
license = "MIT"
description = "Helper library for writing async tests"

repository = "https://github.com/dwrensha/capnproto-rust"
edition = "2018"

[dependencies.futures]
version = "0.3.0"
default-features = false
features = ["std", "executor"]

158 changes: 158 additions & 0 deletions async-byte-channel/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Simple in-memory byte stream.

use std::pin::Pin;
use std::sync::{Arc, Mutex};

use futures::{AsyncRead, AsyncWrite};
use std::task::{Poll, Waker};

#[derive(Debug)]
struct Inner {
buffer: Vec<u8>,
write_cursor: usize,
read_cursor: usize,
write_end_closed: bool,
read_waker: Option<Waker>,
write_waker: Option<Waker>,
}

impl Inner {
fn new() -> Inner {
Inner {
buffer: vec![0; 8096],
write_cursor: 0,
read_cursor: 0,
write_end_closed: false,
read_waker: None,
write_waker: None,
}
}
}

pub struct Sender {
inner: Arc<Mutex<Inner>>,
}

impl Drop for Sender {
fn drop(&mut self) {
let mut inner = self.inner.lock().unwrap();
inner.write_end_closed = true;
if let Some(read_waker) = inner.read_waker.take() {
read_waker.wake();
}
}
}

pub struct Receiver {
inner: Arc<Mutex<Inner>>,
}

pub fn channel() -> (Sender, Receiver) {
let inner = Arc::new(Mutex::new(Inner::new()));
let sender = Sender { inner: inner.clone() };
let receiver = Receiver { inner: inner };
(sender, receiver)
}

impl AsyncRead for Receiver {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut futures::task::Context,
buf: &mut [u8]) -> futures::task::Poll<Result<usize, futures::io::Error>>
{
let mut inner = self.inner.lock().unwrap();
if inner.read_cursor == inner.write_cursor {
if inner.write_end_closed {
Poll::Ready(Ok(0))
} else {
inner.read_waker = Some(cx.waker().clone());
Poll::Pending
}
} else {
assert!(inner.read_cursor < inner.write_cursor);
let copy_len = std::cmp::min(buf.len(), inner.write_cursor - inner.read_cursor);
(&mut buf[0..copy_len]).copy_from_slice(&inner.buffer[inner.read_cursor .. inner.read_cursor + copy_len]);
inner.read_cursor += copy_len;
if let Some(write_waker) = inner.write_waker.take() {
write_waker.wake();
}
Poll::Ready(Ok(copy_len))
}
}
}

impl AsyncWrite for Sender {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut futures::task::Context,
buf: &[u8]) -> futures::task::Poll<Result<usize, futures::io::Error>>
{
let mut inner = self.inner.lock().unwrap();
if inner.write_cursor == inner.buffer.len() {
if inner.read_cursor == inner.buffer.len() {
inner.write_cursor = 0;
inner.read_cursor = 0;
} else {
inner.write_waker = Some(cx.waker().clone());
return Poll::Pending
}
}

assert!(inner.write_cursor < inner.buffer.len());

let copy_len = std::cmp::min(buf.len(), inner.buffer.len() - inner.write_cursor);
let dest_range = inner.write_cursor..inner.write_cursor + copy_len;
(&mut inner.buffer[dest_range]).copy_from_slice(&buf[0..copy_len]);
inner.write_cursor += copy_len;
if let Some(read_waker) = inner.read_waker.take() {
read_waker.wake();
}
Poll::Ready(Ok(copy_len))
}

fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut futures::task::Context)
-> Poll<Result<(), futures::io::Error>>
{
Poll::Ready(Ok(()))
}

fn poll_close(
self: Pin<&mut Self>,
_cx: &mut futures::task::Context)
-> Poll<Result<(), futures::io::Error>>
{
let mut inner = self.inner.lock().unwrap();
inner.write_end_closed = true;
if let Some(read_waker) = inner.read_waker.take() {
read_waker.wake();
}
Poll::Ready(Ok(()))
}
}

#[cfg(test)]
pub mod test {
use futures::{AsyncReadExt, AsyncWriteExt};
use futures::task::{LocalSpawnExt};

#[test]
fn basic() {
let (mut sender, mut receiver) = crate::channel();
let buf: Vec<u8> = vec![1,2,3,4,5].into_iter().cycle().take(20000).collect();
let mut pool = futures::executor::LocalPool::new();

let buf2 = buf.clone();
pool.spawner().spawn_local(async move {
sender.write_all(&buf2).await.unwrap();
()
}).unwrap();

let mut buf3 = vec![];
pool.run_until(receiver.read_to_end(&mut buf3)).unwrap();

assert_eq!(buf.len(), buf3.len());
}
}

3 changes: 1 addition & 2 deletions capnp-futures/test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,4 @@ capnpc = { path = "../../capnpc" }
capnp-futures = {path = "./../"}
capnp = { path = "../../capnp" }
futures = "0.3.0"
tokio = { version = "0.2.6", features = ["net", "rt-util", "time", "uds"]}
tokio-util = { version = "0.3.0", features = ["compat"] }
async-byte-channel = {path = "./../../async-byte-channel"}
122 changes: 51 additions & 71 deletions capnp-futures/test/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ pub mod addressbook_capnp {

#[cfg(test)]
mod tests {
use tokio_util::compat::{Tokio02AsyncReadCompatExt, Tokio02AsyncWriteCompatExt};

use futures::task::{LocalSpawnExt};
use capnp::message;
use capnp_futures::serialize;
use crate::addressbook_capnp::{address_book, person};

fn populate_address_book(address_book: address_book::Builder) {
Expand Down Expand Up @@ -83,44 +84,39 @@ mod tests {
use std::cell::Cell;
use std::rc::Rc;

let mut rt = tokio::runtime::Runtime::new().unwrap();
tokio::task::LocalSet::new().block_on(&mut rt, async move {

let (s1, s2) = tokio::net::UnixStream::pair().expect("socket pair");

let (mut sender, write_queue) = capnp_futures::write_queue(s1.compat_write());

let read_stream = capnp_futures::ReadStream::new(s2.compat(), Default::default());

let messages_read = Rc::new(Cell::new(0u32));
let messages_read1 = messages_read.clone();

let done_reading = read_stream.for_each(|m| {
match m {
Err(e) => panic!("read error: {:?}", e),
Ok(msg) => {
let address_book = msg.get_root::<address_book::Reader>().unwrap();
read_address_book(address_book);
messages_read.set(messages_read.get() + 1);
futures::future::ready(())
}
let mut pool = futures::executor::LocalPool::new();
let spawner = pool.spawner();

let (writer, reader) = async_byte_channel::channel();
let (mut sender, write_queue) = capnp_futures::write_queue(writer);
let read_stream = capnp_futures::ReadStream::new(reader, Default::default());
let messages_read = Rc::new(Cell::new(0u32));
let messages_read1 = messages_read.clone();

let done_reading = read_stream.for_each(|m| {
match m {
Err(e) => panic!("read error: {:?}", e),
Ok(msg) => {
let address_book = msg.get_root::<address_book::Reader>().unwrap();
read_address_book(address_book);
messages_read.set(messages_read.get() + 1);
futures::future::ready(())
}
});
}
});

let io = futures::future::join(done_reading, write_queue.map(|_| ()));
let io = futures::future::join(done_reading, write_queue.map(|_| ()));

let mut m = capnp::message::Builder::new_default();
populate_address_book(m.init_root());
let mut m = capnp::message::Builder::new_default();
populate_address_book(m.init_root());

tokio::task::spawn_local(Box::pin(sender.send(m).map(|_|())));
drop(sender);
io.await;
assert_eq!(messages_read1.get(), 1);
})
spawner.spawn_local(sender.send(m).map(|_|())).unwrap();
drop(sender);
pool.run_until(io);
assert_eq!(messages_read1.get(), 1);
}

fn fill_and_send_message(mut message: capnp::message::Builder<capnp::message::HeapAllocator>) {
use capnp_futures::serialize;
use futures::{FutureExt, TryFutureExt};

{
Expand All @@ -129,29 +125,27 @@ mod tests {
read_address_book(address_book.reborrow_as_reader());
}

let mut rt = tokio::runtime::Runtime::new().unwrap();
tokio::task::LocalSet::new().block_on(&mut rt, async move {
let (stream0, stream1) = tokio::net::UnixStream::pair().expect("socket pair");

let f0 = serialize::write_message(stream0.compat_write(), message)
let mut pool = futures::executor::LocalPool::new();
let (stream0, stream1) = async_byte_channel::channel();
let f0 = serialize::write_message(stream0, message)
.map_err(|e| panic!("write error {:?}", e)).map(|_|());
let f1 =
serialize::read_message(stream1.compat(), capnp::message::ReaderOptions::new()).and_then(|maybe_message_reader| {
match maybe_message_reader {
None => panic!("did not get message"),
Some(m) => {
let address_book = m.get_root::<address_book::Reader>().unwrap();
read_address_book(address_book);
futures::future::ready(Ok::<(),capnp::Error>(()))
}
let f1 =
serialize::read_message(stream1, capnp::message::ReaderOptions::new()).and_then(|maybe_message_reader| {
match maybe_message_reader {
None => panic!("did not get message"),
Some(m) => {
let address_book = m.get_root::<address_book::Reader>().unwrap();
read_address_book(address_book);
futures::future::ready(Ok::<(),capnp::Error>(()))
}
});
}
});

tokio::task::spawn_local(Box::pin(f0));
f1.await
}).expect("fill_and_send_message");
pool.spawner().spawn_local(f0).unwrap();
pool.run_until(f1).unwrap();
}


#[test]
fn single_segment() {
fill_and_send_message(capnp::message::Builder::new_default());
Expand All @@ -166,29 +160,15 @@ mod tests {

#[test]
fn static_lifetime_not_required_funcs() {
use capnp::message;
use capnp_futures::serialize;

let mut rt = tokio::runtime::Runtime::new().unwrap();
tokio::task::LocalSet::new().block_on(&mut rt, async move {
let (write, read) = tokio::net::UnixStream::pair().expect("socket pair");

let _ = serialize::read_message(&mut read.compat(), message::ReaderOptions::default());
let _ = serialize::write_message(&mut write.compat_write(), message::Builder::new_default());
});
let (mut write, mut read) = async_byte_channel::channel();
let _ = serialize::read_message(&mut read, message::ReaderOptions::default());
let _ = serialize::write_message(&mut write, message::Builder::new_default());
}

#[test]
fn static_lifetime_not_required_on_highlevel() {
use capnp::message;
use capnp_futures;

let mut rt = tokio::runtime::Runtime::new().unwrap();
tokio::task::LocalSet::new().block_on(&mut rt, async move {
let (write, read) = tokio::net::UnixStream::pair().expect("socket pair");
let _ = capnp_futures::ReadStream::new(&mut read.compat(), message::ReaderOptions::default());
let _ = capnp_futures::write_queue::<_, message::Builder<message::HeapAllocator>>(&mut write.compat_write());
});
let (mut write, mut read) = async_byte_channel::channel();
let _ = capnp_futures::ReadStream::new(&mut read, message::ReaderOptions::default());
let _ = capnp_futures::write_queue::<_, message::Builder<message::HeapAllocator>>(&mut write);
}

}
4 changes: 1 addition & 3 deletions capnp-rpc/test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,4 @@ path = "../"
[dependencies]
capnp = { path = "../../capnp" }
futures = "0.3.0"
tokio = { version = "0.2.6", features = ["net", "rt-util", "time", "uds"]}
tokio-util = { version = "0.3.0", features = ["compat"] }

async-byte-channel = {path = "./../../async-byte-channel"}
Loading

0 comments on commit 92093a3

Please sign in to comment.