diff --git a/Cargo.toml b/Cargo.toml
index 141ed02..ec6904f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,13 +9,17 @@ edition = "2018"
 [dependencies]
 clap = "2"
 futures = "0.3"
+futures-lite = "0.1.10"
+libc = "0.2.74"
+once_cell = "1.4.0"
+vec-arena = "0.5"
 
 # bwlim-test-tokio
 tokio = { version = "0.2", features = ["full"] }
 tokio-util = { version = "0.3", features = ["compat"] }
 
 # bwlim-test-async
-async-io = "0.1"
+async-io = "0.2"
 smol = "0.3"
 
 [[bin]]
diff --git a/src/bin/bwlim-test-async.rs b/src/bin/bwlim-test-async.rs
index 3a98110..c88c53e 100644
--- a/src/bin/bwlim-test-async.rs
+++ b/src/bin/bwlim-test-async.rs
@@ -1,4 +1,6 @@
 use bwlim::testing::*;
+use bwlim::RLAsync;
+use bwlim::util::RW;
 
 use std::net::{SocketAddr, TcpStream, TcpListener};
 use std::process::{Command, Stdio};
@@ -12,7 +14,7 @@
 use async_io::Async;
 use smol::Task;
 
 async fn async_main() {
-    let (test_bytes, listen, connect, host) = get_args();
+    let (test_bytes, listen, connect, host, rate_limit) = get_args();
     let listen_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), listen);
     let connect_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), connect);
@@ -39,7 +41,12 @@
         futures::select! {
             val = listener.accept().fuse() => {
                 let (socket, _) = val.unwrap();
-                workers.push(Task::spawn(server_thread(socket)));
+                let thread = if rate_limit {
+                    Task::spawn(server_thread(RLAsync::new(RW(socket.into_inner().unwrap())).unwrap()))
+                } else {
+                    Task::spawn(server_thread(socket))
+                };
+                workers.push(thread);
             },
             _ = is_shutdown.next() => {
                 break;
@@ -53,12 +60,20 @@
     println!("server: shutting down");
   });
 
+    thread::sleep(Duration::from_millis(1000)); // TODO: get rid of this
+
     // client threads
     let clients = [0; 2].iter().map(|_| {
         Task::spawn(async move {
             let mut stream = Async::<TcpStream>::connect(connect_addr).await.unwrap();
-            client_thread(&mut stream, test_bytes).await;
-            stream.close().await.unwrap();
+            if rate_limit {
+                let mut stream = RLAsync::new(RW(stream.into_inner().unwrap())).unwrap();
+                client_thread(&mut stream, test_bytes).await;
+                stream.close().await.unwrap();
+            } else {
+                client_thread(&mut stream, test_bytes).await;
+                stream.close().await.unwrap();
+            }
         })
     }).collect::<Vec<_>>();
     future::join_all(clients).await;
diff --git a/src/bin/bwlim-test-tokio.rs b/src/bin/bwlim-test-tokio.rs
index d512ab0..932ccad 100644
--- a/src/bin/bwlim-test-tokio.rs
+++ b/src/bin/bwlim-test-tokio.rs
@@ -13,7 +13,7 @@
 use tokio_util::compat::*;
 
 #[tokio::main]
 async fn main() {
-    let (test_bytes, listen, connect, host) = get_args();
+    let (test_bytes, listen, connect, host, _) = get_args();
     let listen_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), listen);
     let connect_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), connect);
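
To make the new API concrete: a minimal caller-side sketch, not part of the patch, of the wrapping dance the test binaries perform. The helper name and address argument are illustrative; the sketch relies only on items the patch itself exercises: `Async::<TcpStream>::connect`, `into_inner`, `RLAsync::new`, `RW`, and the futures `AsyncWriteExt` methods.

    use async_io::Async;
    use bwlim::RLAsync;
    use bwlim::util::RW;
    use futures::io::AsyncWriteExt;
    use std::net::{SocketAddr, TcpStream};

    // Hypothetical helper: connect via async-io (which puts the socket into
    // non-blocking mode, as RateLimited requires), then unwrap the raw
    // TcpStream and re-wrap it for rate limiting.
    async fn send_hello(addr: SocketAddr) -> std::io::Result<()> {
        let stream = Async::<TcpStream>::connect(addr).await?;
        let mut rl = RLAsync::new(RW(stream.into_inner()?))?;
        rl.write_all(b"hello").await?;
        rl.flush().await?; // flush before drop, see the RateLimited docs
        rl.close().await
    }
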
diff --git a/src/lib.rs b/src/lib.rs
index d5818d3..2c35469 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,30 +1,181 @@
+pub mod limit;
+pub mod reactor;
+pub mod stats;
 pub mod testing;
+pub mod util;
+pub mod sys;
+
+use std::fmt::Debug;
+use std::future::Future;
+use std::io::{self, IoSlice, IoSliceMut, Read, Write};
+use std::pin::Pin;
+use std::mem::ManuallyDrop;
+use std::net::{Shutdown, TcpStream};
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use futures_lite::io::{AsyncRead, AsyncWrite};
+use futures_lite::{future, pin};
+
+use crate::limit::RateLimited;
+use crate::util::RorW;
+use crate::reactor::{Reactor, Source};
+use crate::sys::*;
 
-use std::time::Instant;
 
 #[derive(Debug)]
-pub struct BWStats {
-    start: Instant,
-    /** cumulative time, cumulative space */
-    stats: Vec<(u128, u128)>,
-}
-
-impl BWStats {
-    pub fn new() -> BWStats {
-        let mut stats = Vec::new();
-        stats.push((0, 0));
-        BWStats {
-            start: Instant::now(),
-            stats: stats,
-        }
-    }
-
-    pub fn add(self: &mut Self, n: usize) {
-        let prev = self.stats.last().unwrap().1;
-        self.stats.push((self.start.elapsed().as_micros(), prev + (n as u128)));
-    }
-
-    pub fn last(self: &Self) -> &(u128, u128) {
-        self.stats.last().unwrap()
-    }
+pub struct RLAsync<T> {
+    source: Arc<Source<T>>,
+}
+
+impl<T> RLAsync<T> {
+    pub fn new(io: T) -> io::Result<RLAsync<T>> where T: RorW + Send + Sync + 'static {
+        Ok(RLAsync {
+            source: Reactor::get().insert_io(io)?,
+        })
+    }
+}
+
+impl<T> Drop for RLAsync<T> {
+    fn drop(&mut self) {
+        // Deregister and ignore errors because destructors should not panic.
+        let _ = Reactor::get().remove_io(&self.source);
+    }
+}
+
+// copied from async-io, except:
+//   self.get_mut() replaced with lock() / RateLimited
+// TODO: figure out a way to de-duplicate with them
+impl<T> RLAsync<T> {
+    pub async fn readable(&self) -> io::Result<()> {
+        self.source.readable().await
+    }
+
+    pub async fn writable(&self) -> io::Result<()> {
+        self.source.writable().await
+    }
+
+    pub async fn read_with_mut<R>(
+        &mut self,
+        op: impl FnMut(&mut RateLimited<T>) -> io::Result<R>,
+    ) -> io::Result<R> {
+        let mut op = op;
+        loop {
+            // If there are no blocked readers, attempt the read operation.
+            if !self.source.readers_registered() {
+                let mut inner = self.source.inner.lock().unwrap();
+                match op(&mut inner) {
+                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
+                    res => return res,
+                }
+            }
+            // Wait until the I/O handle becomes readable.
+            optimistic(self.readable()).await?;
+        }
+    }
+
+    pub async fn write_with_mut<R>(
+        &mut self,
+        op: impl FnMut(&mut RateLimited<T>) -> io::Result<R>,
+    ) -> io::Result<R> {
+        let mut op = op;
+        loop {
+            // If there are no blocked writers, attempt the write operation.
+            if !self.source.writers_registered() {
+                let mut inner = self.source.inner.lock().unwrap();
+                match op(&mut inner) {
+                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
+                    res => return res,
+                }
+            }
+            // Wait until the I/O handle becomes writable.
+            optimistic(self.writable()).await?;
+        }
+    }
+}
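
A caller-side sketch of `read_with_mut` (illustrative, not in the patch): the closure receives the buffered `RateLimited<T>` rather than the raw stream, so it keeps returning `WouldBlock` until the reactor has granted read allowance, and the loop above retries after each wakeup.

    use std::io::{self, Read};
    use bwlim::RLAsync;
    use bwlim::limit::RateLimited;
    use bwlim::util::RorW;

    // Hypothetical helper; this mirrors what poll_read does internally.
    async fn read_some<T: RorW>(rl: &mut RLAsync<T>, buf: &mut [u8]) -> io::Result<usize> {
        rl.read_with_mut(|io: &mut RateLimited<T>| io.read(buf)).await
    }
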
+// copied from async-io
+impl<T: Read> AsyncRead for RLAsync<T> {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut [u8],
+    ) -> Poll<io::Result<usize>> {
+        poll_future(cx, self.read_with_mut(|io| io.read(buf)))
+    }
+
+    fn poll_read_vectored(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        bufs: &mut [IoSliceMut<'_>],
+    ) -> Poll<io::Result<usize>> {
+        poll_future(cx, self.read_with_mut(|io| io.read_vectored(bufs)))
+    }
+}
+
+// copied from async-io
+impl<T: Write> AsyncWrite for RLAsync<T>
+where
+    T: AsRawSource
+{
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        poll_future(cx, self.write_with_mut(|io| io.write(buf)))
+    }
+
+    fn poll_write_vectored(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        bufs: &[IoSlice<'_>],
+    ) -> Poll<io::Result<usize>> {
+        poll_future(cx, self.write_with_mut(|io| io.write_vectored(bufs)))
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        poll_future(cx, self.write_with_mut(|io| io.flush()))
+    }
+
+    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+        let inner = self.source.inner.lock().unwrap();
+        Poll::Ready(shutdown_write(inner.inner.as_raw_source()))
+    }
+}
+
+// copied from async-io
+fn poll_future<T>(cx: &mut Context<'_>, fut: impl Future<Output = T>) -> Poll<T> {
+    pin!(fut);
+    fut.poll(cx)
+}
+
+// copied from async-io
+async fn optimistic(fut: impl Future<Output = io::Result<()>>) -> io::Result<()> {
+    let mut polled = false;
+    pin!(fut);
+
+    future::poll_fn(|cx| {
+        if !polled {
+            polled = true;
+            fut.as_mut().poll(cx)
+        } else {
+            Poll::Ready(Ok(()))
+        }
+    })
+    .await
+}
+
+// copied from async-io
+fn shutdown_write(raw: RawSource) -> io::Result<()> {
+    // This may not be a TCP stream, but that's okay. All we do is attempt a `shutdown()` on the
+    // raw descriptor and ignore errors.
+    let stream = unsafe {
+        ManuallyDrop::new(
+            TcpStream::from_raw_source(raw),
+        )
+    };
+
+    // If the socket is a TCP stream, the only actual error can be ENOTCONN.
+    match stream.shutdown(Shutdown::Write) {
+        Err(err) if err.kind() == io::ErrorKind::NotConnected => Err(err),
+        _ => Ok(()),
+    }
 }
diff --git a/src/limit.rs b/src/limit.rs
new file mode 100644
index 0000000..4977ce5
--- /dev/null
+++ b/src/limit.rs
@@ -0,0 +1,494 @@
+//! Data structures to help perform rate limiting.
+
+use std::fmt::Debug;
+use std::result::Result;
+use std::cmp;
+use std::io::{self, Read, Write, ErrorKind};
+use std::collections::VecDeque;
+
+use crate::util::RorW;
+use self::Status::*;
+
+/// Generic buffer for rate-limiting, both reading and writing.
+#[derive(Debug)]
+pub struct RLBuf {
+    /// Buffer to help determine demand, for rate-limiting.
+    buf: VecDeque<u8>,
+    /// Index into `buf`, of the first data not allowed to be used. Everything
+    /// before it will be used upon request.
+    ///
+    /// "Used" means `read` by a higher layer, or `write` by a lower layer.
+    allowance: usize,
+    /// Amount of data used since the last call to `reset_usage`.
+    last_used: usize,
+}
+
+impl RLBuf {
+    /** Create a new `RLBuf` with the given lower bound on the initial capacity.
+
+    The actual capacity can be got later with `get_demand_cap`.
+    */
+    pub fn new_lb(init: usize) -> RLBuf {
+        RLBuf {
+            buf: VecDeque::with_capacity(init),
+            allowance: 0,
+            last_used: 0,
+        }
+    }
+
+    /** Get the current demand.
+
+    For higher-level rate-limiting logic, to determine how to rate-limit.
+    */
+    pub fn get_demand(&self) -> usize {
+        self.buf.len()
+    }
+
+    /** Get the current buffer capacity, i.e. allocated memory.
+
+    For higher-level rate-limiting logic, to monitor resource usage and analyse
+    how efficient the rate-limiting is.
+    */
+    pub fn get_demand_cap(&self) -> usize {
+        self.buf.capacity()
+    }
+
+    pub fn get_demand_remaining(&self) -> usize {
+        self.get_demand_cap() - self.get_demand()
+    }
+
+    /** Add to the allowance. The resulting total allowance must not exceed the demand.
+
+    For higher-level rate-limiting logic, as it performs the rate-limiting.
+    */
+    pub fn add_allowance(&mut self, allowance: usize) {
+        if self.allowance + allowance > self.get_demand() {
+            panic!("allowance > demand");
+        }
+        self.allowance += allowance
+    }
+
+    /** Return the latest usage figures & reset them back to zero.
+
+    The first number is the number of allowed bytes that were unused.
+    The second number is the number of allowed bytes that were used.
+
+    For higher-level rate-limiting logic, before rate-limiting is performed, to
+    detect consumers that consumed even more slowly than the rate limit in the
+    previous cycle. In response to this, the higher-level logic should give less
+    allowance to such a consumer, to avoid waste.
+    */
+    pub fn reset_usage(&mut self) -> (usize, usize) {
+        let wasted = self.allowance;
+        let used = self.last_used;
+        self.allowance = 0;
+        self.last_used = 0;
+        (wasted, used)
+    }
+
+    fn record_demand(&mut self, buf: &[u8]) {
+        for &i in buf {
+            self.buf.push_back(i);
+        }
+    }
+
+    fn take_allowance(&mut self, taken: usize) {
+        if taken > self.allowance {
+            panic!("taken > allowance");
+        }
+        self.allowance -= taken;
+        self.last_used += taken;
+    }
+
+    fn consume_read(&mut self, buf: &mut [u8]) -> usize {
+        let to_drain = cmp::min(buf.len(), self.allowance);
+        let bb = self.buf.drain(..to_drain).collect::<Vec<_>>();
+        for (i, b) in bb.into_iter().enumerate() {
+            buf[i] = b;
+        }
+        self.take_allowance(to_drain);
+        to_drain
+    }
+
+    fn consume_write<E, F>(&mut self, sz: usize, mut write: F) -> Result<usize, E>
+        where F: FnMut(&[u8]) -> Result<usize, E> {
+        let mut used = 0;
+        let mut res = Ok(());
+        let (a, b) = self.buf.as_slices();
+        let to_drain = cmp::min(a.len(), sz);
+        match write(&a[..to_drain]) {
+            Ok(n) => {
+                used += n;
+                if n == a.len() {
+                    let to_drain = cmp::min(b.len(), sz - used);
+                    match write(&b[..to_drain]) {
+                        Ok(n) => {
+                            used += n;
+                        },
+                        Err(e) => {
+                            res = Err(e);
+                        }
+                    }
+                }
+            },
+            Err(e) => {
+                res = Err(e);
+            },
+        }
+        self.buf.drain(..used);
+        self.take_allowance(used);
+        match res {
+            Ok(()) => Ok(used),
+            Err(e) => Err(e),
+        }
+    }
+}
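
Everything above is bookkeeping; the policy lives elsewhere. A hypothetical per-tick grant policy, not in this patch (the reactor currently grants the full demand, see src/reactor.rs), might look like the following; it is safe with respect to `add_allowance`'s panic because the allowance is zero right after `reset_usage` and the grant is capped at the demand.

    use bwlim::limit::RLBuf;

    // Illustrative only: give a buffer up to `quota` bytes this tick, backing
    // off by whatever the consumer wasted out of last tick's allowance.
    fn grant(buf: &mut RLBuf, quota: usize) {
        let (wasted, _used) = buf.reset_usage();
        let adjusted = quota.saturating_sub(wasted);
        buf.add_allowance(std::cmp::min(adjusted, buf.get_demand()));
    }
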
+ */ + pub fn new_lb(inner: T, init: usize) -> RateLimited { + RateLimited { + inner: inner, + rstatus: SOpen, + rbuf: RLBuf::new_lb(init), + wstatus: SOpen, + wbuf: RLBuf::new_lb(init), + } + } +} + +impl RateLimited where T: RorW + ?Sized { + /** Do a pre-read. + + That is, do a non-blocking read from the underlying handle, filling up the + remaining part of `rbuf`. + + This is to be used by higher-level code, before it performs the rate-limiting. + */ + pub fn pre_read(&mut self) { + match self.rstatus { + SOpen => { + // TODO: if allowance is 0, then automatically grow the buffer capacity + let remain = self.rbuf.get_demand_remaining(); + let mut buf = [0].repeat(remain); // TODO: optimise with uninit + match self.inner.read(&mut buf) { // TODO: assert non-blocking + Ok(0) => { + self.rstatus = SOk; + }, + Ok(n) => { + self.rbuf.record_demand(&buf[..n]); + }, + Err(e) => match e.kind() { + ErrorKind::WouldBlock => (), + ErrorKind::Interrupted => (), + _ => { + // println!("pre_read: {:?}", e); + self.rstatus = SErr; + } + }, + } + }, + _ => (), // already finished + } + } + + pub fn is_readable(&self) -> bool { + self.rstatus != SOpen || self.rbuf.allowance > 0 + } + + /** Do a post-write. + + That is, do a non-blocking write to the underlying handle, up to the current + allowance of `wbuf`. + + This is to be used by higher-level code, after it performs the rate-limiting. + */ + pub fn post_write(&mut self) -> bool { + match self.post_write_exact(self.wbuf.allowance) { + None => false, + Some(n) => n > 0, + } + } + + pub fn is_writable(&self) -> bool { + self.wstatus == SOpen && self.wbuf.get_demand_remaining() > 0 + } + + // extra param is exposed for testing only + fn post_write_exact(&mut self, sz: usize) -> Option { + match self.wbuf.get_demand() { + 0 => None, + _ => match self.wbuf.allowance { + 0 => None, + _ => { + let w = &mut self.inner; + match self.wbuf.consume_write(sz, |b| w.write(b)) { + Ok(n) => Some(n), + Err(_) => { + self.wstatus = SErr; + None + } + } + } + } + } + } +} + +impl Read for RateLimited where T: Read { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self.rbuf.get_demand() { + 0 => match self.rstatus { + SOpen => Err(io::Error::new(ErrorKind::WouldBlock, "")), + SOk => Ok(0), + SErr => Err(unwrap_err_or(self.inner.read(&mut []), io::Error::new(ErrorKind::Other, "Ok after Err"))), + }, + _ => match self.rbuf.allowance { + 0 => Err(io::Error::new(ErrorKind::WouldBlock, "")), + _ => Ok(self.rbuf.consume_read(buf)), + } + } + } +} + +impl Write for RateLimited where T: Write { + fn write(&mut self, buf: &[u8]) -> io::Result { + match self.wstatus { + SOpen => { + // TODO: if allowance is 0, then automatically grow the buffer capacity + let remain = self.wbuf.get_demand_remaining(); + match remain { + 0 => Err(io::Error::new(ErrorKind::WouldBlock, "")), + _ => { + let n = cmp::min(buf.len(), remain); + self.wbuf.record_demand(&buf[..n]); + Ok(n) + } + } + }, + SOk => Ok(0), + SErr => Err(unwrap_err_or(self.inner.write(&mut []), io::Error::new(ErrorKind::Other, "Ok after Err"))), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self.wstatus { + SErr => + // if there was an error, wbuf might not have been consumed, so output error even if wbuf is non-empty + Err(unwrap_err_or(self.inner.write(&mut []), io::Error::new(ErrorKind::Other, "Ok after Err"))), + _ => match self.wbuf.get_demand() { + 0 => Ok(()), + _ => Err(io::Error::new(ErrorKind::WouldBlock, "")), // something else is responsible for calling post_write + } + } + } +} + 
+#[cfg(test)]
+mod tests {
+    use std::fs::*;
+    use std::fmt::Debug;
+    use std::io;
+    use std::io::*;
+    use std::assert;
+
+    use crate::sys::*;
+    use crate::util::*;
+
+    use super::*;
+
+    fn assert_would_block<T>(res: io::Result<T>) where T: Debug {
+        match res {
+            Err(e) => assert_eq!(e.kind(), ErrorKind::WouldBlock),
+            x => {
+                println!("{:?}", x);
+                assert!(false);
+            },
+        }
+    }
+
+    fn assert_error<T>(res: io::Result<T>) where T: Debug {
+        match res {
+            Err(e) => match e.kind() {
+                ErrorKind::WouldBlock => assert!(false),
+                ErrorKind::Interrupted => assert!(false),
+                _ => (),
+            },
+            x => {
+                println!("{:?}", x);
+                assert!(false);
+            },
+        }
+    }
+
+    fn assert_num_bytes(res: io::Result<usize>, s: usize) {
+        match res {
+            Ok(n) => assert_eq!(n, s),
+            x => {
+                println!("{:?}", x);
+                assert!(false);
+            },
+        }
+    }
+
+    // TODO: /dev/null etc is not a RawSocket in windows
+
+    #[test]
+    fn read_eof_ok() -> io::Result<()> {
+        let file = File::open("/dev/null")?;
+        set_non_blocking(file.as_raw_source())?;
+        let mut bf = RateLimited::new_lb(RO(file), 1);
+        let mut buf = [0].repeat(1);
+        assert_would_block(bf.read(&mut buf));
+        bf.pre_read();
+        assert_num_bytes(bf.read(&mut buf), 0); // eof
+        Ok(())
+    }
+
+    #[test]
+    fn read_zero_err() -> io::Result<()> {
+        let file = File::open("/dev/zero")?;
+        set_non_blocking(file.as_raw_source())?;
+        let unsafe_f = unsafe { File::from_raw_source(file.as_raw_source()) };
+
+        let sd = 4095; // in case VecDeque changes implementation, this needs to be changed
+        let sx = 1024;
+        let sy = 1024;
+        let mut bf = RateLimited::new_lb(RO(file), sd);
+        assert_eq!(sd, bf.rbuf.get_demand_cap());
+        assert_eq!(0, bf.rbuf.get_demand());
+        let mut buf = [0].repeat(sx);
+
+        assert_would_block(bf.read(&mut buf));
+        bf.pre_read();
+        assert_eq!(sd, bf.rbuf.get_demand());
+        assert_would_block(bf.read(&mut buf));
+
+        bf.rbuf.add_allowance(sx);
+        assert_num_bytes(bf.read(&mut buf), sx);
+        assert_eq!(sd - sx, bf.rbuf.get_demand());
+
+        bf.rbuf.add_allowance(sx + sy);
+        assert_num_bytes(bf.read(&mut buf), sx);
+        assert_eq!(sd - sx - sx, bf.rbuf.get_demand());
+
+        assert_eq!(bf.rbuf.reset_usage(), (sy, sx + sy));
+        // sy bytes of allowance were wasted
+        assert_would_block(bf.read(&mut buf));
+
+        assert_eq!(bf.rbuf.reset_usage(), (0, 0));
+        assert_eq!(sd - sx - sx, bf.rbuf.get_demand());
+        assert_eq!(SOpen, bf.rstatus);
+
+        drop(unsafe_f); // close f, to force an error on the underlying stream
+        bf.pre_read();
+        assert_eq!(sd - sx - sx, bf.rbuf.get_demand());
+        assert_eq!(SErr, bf.rstatus);
+        bf.rbuf.add_allowance(sd - sx - sx);
+        assert_num_bytes(bf.read(&mut buf), sx);
+        assert!(sd - sx - sx - sx <= sx); // otherwise next step fails
+        assert_num_bytes(bf.read(&mut buf), sd - sx - sx - sx);
+        assert_error(bf.read(&mut buf));
+        assert_error(bf.read(&mut buf));
+        assert_error(bf.read(&mut buf));
+
+        Ok(())
+    }
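
The magic `sd = 4095` encodes a `VecDeque` implementation detail: at the time of writing, `with_capacity` rounds the backing allocation up to a power of two and keeps one slot free, so 4095 is exactly what `capacity()` reports back. A standalone check of that assumption, which is all the in-test comment warns about:

    use std::collections::VecDeque;

    fn main() {
        let v: VecDeque<u8> = VecDeque::with_capacity(4095);
        assert_eq!(v.capacity(), 4095); // may fail on a future std
    }
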
+    #[test]
+    fn write_eof_err() -> io::Result<()> {
+        let file = File::open("/dev/zero")?;
+        set_non_blocking(file.as_raw_source())?;
+        let mut bf = RateLimited::new_lb(WO(file), 1);
+        let buf = [0].repeat(1);
+        assert_num_bytes(bf.write(&buf), 1);
+        bf.post_write();
+        assert_eq!(bf.wstatus, SOpen);
+        bf.wbuf.add_allowance(1);
+        bf.post_write();
+        assert_eq!(bf.wstatus, SErr);
+        assert_error(bf.flush());
+        assert_error(bf.flush());
+        assert_error(bf.flush());
+        Ok(())
+    }
+
+    #[test]
+    fn write_null_ok() -> io::Result<()> {
+        let file = OpenOptions::new().write(true).open("/dev/null")?;
+        set_non_blocking(file.as_raw_source())?;
+
+        let sd = 4095; // in case VecDeque changes implementation, this needs to be changed
+        let sx = 1024;
+        let sy = 1024;
+        let mut bf = RateLimited::new_lb(WO(file), sd);
+        assert_eq!(sd, bf.wbuf.get_demand_cap());
+        assert_eq!(0, bf.wbuf.get_demand());
+        let buf = [0].repeat(sd + sx);
+
+        bf.flush()?;
+        assert_num_bytes(bf.write(&buf), sd);
+        assert_eq!(sd, bf.wbuf.get_demand());
+        assert_would_block(bf.write(&buf[sd..]));
+
+        bf.wbuf.add_allowance(sx);
+        bf.post_write();
+        assert_eq!(sd - sx, bf.wbuf.get_demand());
+
+        bf.wbuf.add_allowance(sx + sy);
+        bf.post_write_exact(sx);
+        assert_eq!(sd - sx - sx, bf.wbuf.get_demand());
+
+        assert_eq!(bf.wbuf.reset_usage(), (sy, sx + sy));
+        // sy bytes of allowance were wasted
+        assert_eq!(bf.post_write_exact(0), None);
+
+        assert_eq!(bf.wbuf.reset_usage(), (0, 0));
+        assert_eq!(sd - sx - sx, bf.wbuf.get_demand());
+        assert_eq!(SOpen, bf.wstatus);
+
+        assert_num_bytes(bf.write(&buf), sx + sx);
+        assert_eq!(sd, bf.wbuf.get_demand());
+        assert_eq!(SOpen, bf.wstatus);
+        bf.wbuf.add_allowance(sd);
+        assert_would_block(bf.flush());
+        assert_would_block(bf.flush());
+        assert_would_block(bf.flush());
+        bf.post_write();
+        assert_eq!(0, bf.wbuf.get_demand());
+        bf.flush()
+    }
+}
diff --git a/src/reactor.rs b/src/reactor.rs
new file mode 100644
index 0000000..6522c76
--- /dev/null
+++ b/src/reactor.rs
@@ -0,0 +1,296 @@
+/*! Reactor for rate-limited streams.
+
+This is only a partial reactor to add rate-limiting to byte streams; it does
+not cover other types of events like connect/listen. For those, use a "normal"
+reactor like the one in async-io.
+
+Currently this reactor runs as an asynchronous task inside the full reactor of
+async-io. In theory we could instead run the main loop as a standalone thread
+with blocking sleeps, but it's not clear that this would give a great benefit,
+so the current solution is kept for now.
+*/
+
+use std::panic;
+use std::sync::{atomic::*, Arc, Mutex};
+use std::task::{Poll, Waker};
+use std::time::{Duration, Instant};
+use std::io;
+
+use once_cell::sync::Lazy;
+use vec_arena::Arena;
+
+use futures_lite::*;
+use async_io::Timer;
+use smol::Task;
+
+use crate::limit::RateLimited;
+use crate::util::RorW;
+
+
+#[derive(Debug)]
+pub(crate) struct Reactor {
+    /// Last tick that we rate-limited on.
+    last_tick: Mutex<Instant>,
+
+    /// Ticker bumped before polling.
+    ticker: AtomicUsize,
+
+    /// Registered sources.
+    sources: Mutex<Arena<Arc<Source<dyn RorW + Send + Sync>>>>,
+}
+
+impl Reactor {
+    pub(crate) fn get() -> &'static Reactor {
+        static REACTOR: Lazy<(Reactor, Task<()>)> = Lazy::new(|| {
+            let reactor = Reactor {
+                last_tick: Mutex::new(Instant::now()),
+                ticker: AtomicUsize::new(0),
+                sources: Mutex::new(Arena::new()),
+            };
+
+            let task = Task::spawn(async {
+                Reactor::get().main_loop_async().await
+            });
+
+            (reactor, task)
+        });
+        &(REACTOR.0)
+    }
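
Storing the `Task<()>` alongside the `Reactor` in the static is load-bearing: a `smol::Task` is cancelled when dropped, so keeping it inside the `Lazy` keeps the main loop alive for the life of the program. The same pattern in isolation (an illustrative sketch, not part of the patch):

    use once_cell::sync::Lazy;
    use smol::Task;
    use std::time::Duration;

    // Spawned lazily on first access, never dropped, hence never cancelled.
    static TICKER: Lazy<Task<()>> = Lazy::new(|| {
        Task::spawn(async {
            loop {
                // periodic work would go here
                async_io::Timer::after(Duration::from_millis(1)).await;
            }
        })
    });
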
+
+    /// Registers an I/O source in the reactor.
+    pub(crate) fn insert_io<T>(
+        &self,
+        inner: T,
+    ) -> io::Result<Arc<Source<T>>>
+        where T: RorW + Send + Sync + 'static
+    {
+        let mut sources = self.sources.lock().unwrap();
+        let key = sources.next_vacant();
+        let source = Arc::new(Source {
+            inner: Mutex::new(RateLimited::new_lb(inner, 65536)),
+            key,
+            wakers: Mutex::new(Wakers {
+                tick_readable: 0,
+                tick_writable: 0,
+                readers: Vec::new(),
+                writers: Vec::new(),
+            }),
+            wakers_registered: AtomicU8::new(0),
+        });
+        sources.insert(source.clone());
+        Ok(source)
+    }
+
+    /// Deregisters an I/O source from the reactor.
+    pub(crate) fn remove_io<T>(&self, source: &Source<T>) -> io::Result<()> where T: ?Sized {
+        let mut sources = self.sources.lock().unwrap();
+        sources.remove(source.key);
+        Ok(())
+    }
+
+    pub(crate) async fn main_loop_async(&self) {
+        loop {
+            let mut wakers = Vec::new();
+
+            let tick_length = Duration::from_millis(1); // rate-limit every 1 ms
+            let target = *self.last_tick.lock().unwrap() + tick_length;
+            let now = Instant::now();
+            if target > now {
+                Timer::after(target - now).await;
+            } else {
+                println!("bwlim reactor running slow: {:?} {:?} {:?}", tick_length, now, target);
+            }
+
+            let mut last_tick = self.last_tick.lock().unwrap();
+            if Instant::now() >= target {
+                let tick = self
+                    .ticker
+                    .fetch_add(1, Ordering::SeqCst)
+                    .wrapping_add(1);
+
+                for (_key, source) in self.sources.lock().unwrap().iter_mut() {
+                    let rl = &mut *source.inner.lock().unwrap();
+                    if rl.inner.can_read() {
+                        rl.pre_read();
+                        rl.rbuf.reset_usage();
+                        // TODO: actually perform rate-limiting. the current code ought not
+                        // to be (but is) much slower than the async-io version.
+                        rl.rbuf.add_allowance(rl.rbuf.get_demand());
+                        if rl.is_readable() {
+                            self.react_evt(&mut wakers, &**source, true, false, tick);
+                        }
+                    }
+                    if rl.inner.can_write() {
+                        rl.wbuf.reset_usage();
+                        // TODO: actually perform rate-limiting. the current code ought not
+                        // to be (but is) much slower than the async-io version.
+                        rl.wbuf.add_allowance(rl.wbuf.get_demand());
+                        rl.post_write();
+                        if rl.is_writable() {
+                            self.react_evt(&mut wakers, &**source, false, true, tick);
+                        }
+                    }
+                }
+                *last_tick = Instant::now();
+            }
+            drop(last_tick);
+
+            // Wake up ready tasks.
+            for waker in wakers {
+                // Don't let a panicking waker blow everything up.
+                let _ = panic::catch_unwind(|| waker.wake());
+            }
+        }
+    }
+
+    // copied from async-io Reactor.react, except references to poller removed
+    fn react_evt<T>(&self, wakers: &mut Vec<Waker>, source: &Source<T>, ev_readable: bool, ev_writable: bool, tick: usize) where T: ?Sized {
+        let mut w = source.wakers.lock().unwrap();
+
+        // Wake readers if a readability event was emitted.
+        if ev_readable {
+            w.tick_readable = tick;
+            wakers.append(&mut w.readers);
+            source
+                .wakers_registered
+                .fetch_and(!READERS_REGISTERED, Ordering::SeqCst);
+        }
+
+        // Wake writers if a writability event was emitted.
+        if ev_writable {
+            w.tick_writable = tick;
+            wakers.append(&mut w.writers);
+            source
+                .wakers_registered
+                .fetch_and(!WRITERS_REGISTERED, Ordering::SeqCst);
+        }
+    }
+}
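
The two TODOs above currently grant each source its full demand, so no limiting actually happens yet. A hypothetical token-bucket policy that could replace those grants (names and fields are illustrative, not part of the patch):

    // One bucket per source per direction, refilled once per tick; `grant`
    // would feed RLBuf::add_allowance in place of the full-demand grant.
    struct Bucket {
        tokens: usize,
        rate_per_tick: usize,
        burst: usize,
    }

    impl Bucket {
        fn refill(&mut self) {
            self.tokens = std::cmp::min(self.tokens + self.rate_per_tick, self.burst);
        }
        fn grant(&mut self, demand: usize) -> usize {
            let n = std::cmp::min(self.tokens, demand);
            self.tokens -= n;
            n
        }
    }
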
+
+// copied from async-io, except inner field
+#[derive(Debug)]
+pub struct Source<T> where T: ?Sized {
+    /// The key of this source obtained during registration.
+    key: usize,
+
+    /// Tasks interested in events on this source.
+    wakers: Mutex<Wakers>,
+
+    /// Whether there are wakers interested in events on this source.
+    wakers_registered: AtomicU8,
+
+    pub(crate) inner: Mutex<RateLimited<T>>,
+}
+
+// copied from async-io. TODO: figure out a way to deduplicate
+/// Tasks interested in events on a source.
+#[derive(Debug)]
+struct Wakers {
+    /// Last reactor tick that delivered a readability event.
+    tick_readable: usize,
+
+    /// Last reactor tick that delivered a writability event.
+    tick_writable: usize,
+
+    /// Tasks waiting for the next readability event.
+    readers: Vec<Waker>,
+
+    /// Tasks waiting for the next writability event.
+    writers: Vec<Waker>,
+}
+
+const READERS_REGISTERED: u8 = 1 << 0;
+const WRITERS_REGISTERED: u8 = 1 << 1;
+
+// copied from async-io, except references to reactor.poller removed
+// TODO: figure out a way to deduplicate
+impl<T> Source<T> where T: ?Sized {
+    /// Waits until the I/O source is readable.
+    pub(crate) async fn readable(&self) -> io::Result<()> {
+        let mut ticks = None;
+
+        future::poll_fn(|cx| {
+            let mut w = self.wakers.lock().unwrap();
+
+            // Check if the reactor has delivered a readability event.
+            if let Some((a, b)) = ticks {
+                // If `tick_readable` has changed to a value other than the old reactor tick, that
+                // means a newer reactor tick has delivered a readability event.
+                if w.tick_readable != a && w.tick_readable != b {
+                    return Poll::Ready(Ok(()));
+                }
+            }
+
+            // If there are no other readers, re-register in the reactor.
+            if w.readers.is_empty() {
+                self.wakers_registered
+                    .fetch_or(READERS_REGISTERED, Ordering::SeqCst);
+            }
+
+            // Register the current task's waker if not present already.
+            if w.readers.iter().all(|w| !w.will_wake(cx.waker())) {
+                w.readers.push(cx.waker().clone());
+            }
+
+            // Remember the current ticks.
+            if ticks.is_none() {
+                ticks = Some((
+                    Reactor::get().ticker.load(Ordering::SeqCst),
+                    w.tick_readable,
+                ));
+            }
+
+            Poll::Pending
+        })
+        .await
+    }
+
+    pub(crate) fn readers_registered(&self) -> bool {
+        self.wakers_registered.load(Ordering::SeqCst) & READERS_REGISTERED
+            == READERS_REGISTERED
+    }
+
+    /// Waits until the I/O source is writable.
+    pub(crate) async fn writable(&self) -> io::Result<()> {
+        let mut ticks = None;
+
+        future::poll_fn(|cx| {
+            let mut w = self.wakers.lock().unwrap();
+
+            // Check if the reactor has delivered a writability event.
+            if let Some((a, b)) = ticks {
+                // If `tick_writable` has changed to a value other than the old reactor tick, that
+                // means a newer reactor tick has delivered a writability event.
+                if w.tick_writable != a && w.tick_writable != b {
+                    return Poll::Ready(Ok(()));
+                }
+            }
+
+            // If there are no other writers, re-register in the reactor.
+            if w.writers.is_empty() {
+                self.wakers_registered
+                    .fetch_or(WRITERS_REGISTERED, Ordering::SeqCst);
+            }
+
+            // Register the current task's waker if not present already.
+            if w.writers.iter().all(|w| !w.will_wake(cx.waker())) {
+                w.writers.push(cx.waker().clone());
+            }
+
+            // Remember the current ticks.
+            if ticks.is_none() {
+                ticks = Some((
+                    Reactor::get().ticker.load(Ordering::SeqCst),
+                    w.tick_writable,
+                ));
+            }
+
+            Poll::Pending
+        })
+        .await
+    }
+
+    pub(crate) fn writers_registered(&self) -> bool {
+        self.wakers_registered.load(Ordering::SeqCst) & WRITERS_REGISTERED
+            == WRITERS_REGISTERED
+    }
+}
diff --git a/src/stats.rs b/src/stats.rs
new file mode 100644
index 0000000..8c50e9e
--- /dev/null
+++ b/src/stats.rs
@@ -0,0 +1,30 @@
+//! Data structures to measure bandwidth stats.
+
+use std::time::Instant;
+
+#[derive(Debug)]
+pub struct BWStats {
+    start: Instant,
+    /** cumulative time, cumulative space */
+    stats: Vec<(u128, u128)>,
+}
+
+impl BWStats {
+    pub fn new() -> BWStats {
+        let mut stats = Vec::new();
+        stats.push((0, 0));
+        BWStats {
+            start: Instant::now(),
+            stats: stats,
+        }
+    }
+
+    pub fn add(self: &mut Self, n: usize) {
+        let prev = self.stats.last().unwrap().1;
+        self.stats.push((self.start.elapsed().as_micros(), prev + (n as u128)));
+    }
+
+    pub fn last(self: &Self) -> &(u128, u128) {
+        self.stats.last().unwrap()
+    }
+}
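
Usage sketch for the relocated stats type (this mirrors how src/testing.rs already uses it; the function name is illustrative):

    use bwlim::stats::BWStats;

    fn demo() {
        let mut stats = BWStats::new();
        stats.add(4096); // record a 4 KiB transfer
        let (micros, total) = *stats.last();
        println!("{} B in {} us", total, micros);
    }
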
diff --git a/src/sys.rs b/src/sys.rs
new file mode 100644
index 0000000..ac3dd8f
--- /dev/null
+++ b/src/sys.rs
@@ -0,0 +1,96 @@
+//! Cross-platform type and trait aliases.
+
+pub(crate) use self::sys::*;
+
+use std::io;
+
+/// Cross-platform alias to `AsRawFd` (Unix) or `AsRawSocket` (Windows).
+///
+/// Note: this is a slight hack around the Rust type system. You should not
+/// implement this trait directly, e.g. for a wrapper type; that will not work.
+/// Instead you have to implement both `AsRawFd` and `AsRawSocket` separately.
+pub trait AsRawSource {
+    /// Cross-platform alias to `AsRawFd::as_raw_fd` (Unix) or `AsRawSocket::as_raw_socket` (Windows).
+    fn as_raw_source(&self) -> RawSource;
+}
+
+/// Cross-platform alias to `FromRawFd` (Unix) or `FromRawSocket` (Windows).
+pub trait FromRawSource {
+    /// Cross-platform alias to `FromRawFd::from_raw_fd` (Unix) or `FromRawSocket::from_raw_socket` (Windows).
+    unsafe fn from_raw_source(h: RawSource) -> Self;
+}
+
+#[cfg(unix)]
+mod sys {
+    use super::*;
+    use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+
+    pub(crate) type RawSource = RawFd;
+
+    impl<T> AsRawSource for T where T: AsRawFd {
+        fn as_raw_source(&self) -> RawSource {
+            self.as_raw_fd()
+        }
+    }
+
+    impl<T> FromRawSource for T where T: FromRawFd {
+        unsafe fn from_raw_source(h: RawSource) -> Self {
+            Self::from_raw_fd(h)
+        }
+    }
+
+    /// Calls a libc function and results in `io::Result`.
+    macro_rules! syscall {
+        ($fn:ident $args:tt) => {{
+            let res = unsafe { libc::$fn $args };
+            if res == -1 {
+                Err(std::io::Error::last_os_error())
+            } else {
+                Ok(res)
+            }
+        }};
+    }
+
+    pub fn set_non_blocking(fd: RawSource) -> io::Result<()> {
+        // Put the file descriptor in non-blocking mode.
+        let flags = syscall!(fcntl(fd, libc::F_GETFL))?;
+        syscall!(fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK))?;
+        Ok(())
+    }
+}
+
+#[cfg(windows)]
+mod sys {
+    use super::*;
+    use std::os::windows::io::{AsRawSocket, FromRawSocket, RawSocket};
+    use winapi::um::winsock2;
+
+    pub(crate) type RawSource = RawSocket;
+
+    impl<T> AsRawSource for T where T: AsRawSocket {
+        fn as_raw_source(&self) -> RawSource {
+            self.as_raw_socket()
+        }
+    }
+
+    impl<T> FromRawSource for T where T: FromRawSocket {
+        unsafe fn from_raw_source(h: RawSource) -> Self {
+            Self::from_raw_socket(h)
+        }
+    }
+
+    pub fn set_non_blocking(sock: RawSource) -> io::Result<()> {
+        unsafe {
+            let mut nonblocking = true as libc::c_ulong;
+            let res = winsock2::ioctlsocket(
+                sock as winsock2::SOCKET,
+                winsock2::FIONBIO,
+                &mut nonblocking,
+            );
+            if res != 0 {
+                return Err(io::Error::last_os_error());
+            }
+        }
+
+        Ok(())
+    }
+}
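
A sketch of the alias trait in use (illustrative, not in the patch): any type that is `AsRawFd` on Unix or `AsRawSocket` on Windows picks up `AsRawSource` through the blanket impls above, so platform-neutral code can obtain the raw handle uniformly. Note that the `pub(crate) use` re-export keeps `set_non_blocking` and `RawSource` crate-internal; only the two alias traits are public.

    use bwlim::sys::AsRawSource;
    use std::net::TcpStream;

    fn raw_of(stream: &TcpStream) -> impl std::fmt::Debug {
        stream.as_raw_source() // RawFd on Unix, RawSocket on Windows
    }
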
diff --git a/src/testing.rs b/src/testing.rs
index cf39224..2d1dc32 100644
--- a/src/testing.rs
+++ b/src/testing.rs
@@ -1,6 +1,6 @@
 // TODO: measure latency between client and corresponding server-worker
 
-use crate::BWStats;
+use crate::stats::BWStats;
 
 use futures::io::{AsyncReadExt, AsyncWriteExt};
 use clap::{App, Arg};
@@ -37,6 +37,7 @@ pub async fn client_thread<W>(stream: &mut W, len: usize)
     while remain > 0 {
         let bytes_to_write = if remain >= TEST_CHUNK.len() { TEST_CHUNK.len() } else { remain };
         stream.write_all(&TEST_CHUNK[0..bytes_to_write]).await.unwrap();
+        stream.flush().await.unwrap();
         //println!("client: wrote to stream; success={:?}", result.is_ok());
         remain -= bytes_to_write;
         stats.add(bytes_to_write);
@@ -47,7 +48,7 @@
     println!("client: {} B sent in {} us: ~{} MBps", s, t, if *t == 0 { 0 } else { s / t });
 }
 
-pub fn get_args() -> (usize, u16, u16, Option<String>) {
+pub fn get_args() -> (usize, u16, u16, Option<String>, bool) {
     let m = App::new("Bandwidth limit tester")
         .version("0.0")
         .author("Ximin Luo ")
         .arg(Arg::with_name("port")
             .short("p")
             .long("port")
             .value_name("PORT")
-            .help("Local port for listening")
-            .default_value("6397"))
+            .default_value("6397")
+            .help("Local port for listening"))
         .arg(Arg::with_name("bytes")
             .short("b")
             .long("bytes")
             .value_name("NUM")
-            .help("Number of bytes to send per client")
-            .default_value("16777216"))
+            .default_value("16777216")
+            .help("Number of bytes to send per client"))
         .arg(Arg::with_name("connect")
             .short("c")
             .long("connect")
@@ -75,6 +76,12 @@
             .value_name("HOST")
             .help("Remote ssh host for setting up an ssh loopback proxy. \
                    Use ssh_config if you need to configure more things."))
+        .arg(Arg::with_name("rate_limit")
+            .short("r")
+            .long("rate-limit")
+            .value_name("BOOL")
+            .default_value("true")
+            .help("Whether to perform rate-limiting"))
         .get_matches();
 
     let test_bytes = m.value_of("bytes").unwrap().parse::<usize>().unwrap();
@@ -87,7 +94,7 @@
         },
         _ => panic!("--host and --connect must be both set or unset"),
     };
-    (test_bytes, listen, connect, host.map(str::to_string))
+    (test_bytes, listen, connect, host.map(str::to_string), m.value_of("rate_limit").unwrap().parse::<bool>().unwrap())
 }
 
 pub fn get_ssh_args(host: String, connect: u16, listen: u16) -> Vec<String> {
@@ -99,4 +106,5 @@
         format!("-Rlocalhost:{}:localhost:{}", remote, listen),
         "cat".to_string() // so it responds properly to EOF on stdin; -N ignores it*/
     ]
+    // FIXME: for some reason ssh fails to connect with --bytes < ~2.8MB, for both tokio/asyncio
 }
diff --git a/src/util.rs b/src/util.rs
new file mode 100644
index 0000000..31ea98b
--- /dev/null
+++ b/src/util.rs
@@ -0,0 +1,107 @@
+//! Various utils.
+
+use std::fmt::Debug;
+use std::io::{self, Read, Write};
+
+use crate::sys::{AsRawSource, RawSource};
+
+/// Trait that unifies `Read` + `Write` to make code easier to write.
+pub trait RorW: Read + Write + Debug {
+    fn can_read(&self) -> bool;
+    fn can_write(&self) -> bool;
+}
+
+#[derive(Debug)]
+#[repr(transparent)]
+pub struct RW<T>(pub T);
+
+impl<T> AsRawSource for RW<T> where T: AsRawSource {
+    fn as_raw_source(&self) -> RawSource {
+        self.0.as_raw_source()
+    }
+}
+
+pub fn as_rw_ref<T>(x: &T) -> &RW<T> {
+    unsafe {
+        std::mem::transmute(x)
+    }
+}
+
+impl<T> Read for RW<T> where T: Read {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        self.0.read(buf)
+    }
+}
+
+impl<T> Write for RW<T> where T: Write {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        self.0.write(buf)
+    }
+    fn flush(&mut self) -> io::Result<()> {
+        self.0.flush()
+    }
+}
+
+impl<T> RorW for RW<T> where T: Read + Write + Debug {
+    fn can_read(&self) -> bool {
+        true
+    }
+    fn can_write(&self) -> bool {
+        true
+    }
+}
+
+#[derive(Debug)]
+#[repr(transparent)]
+pub struct RO<T>(pub T);
+
+impl<T> Read for RO<T> where T: Read {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        self.0.read(buf)
+    }
+}
+
+impl<T> Write for RO<T> where T: Write {
+    fn write(&mut self, _: &[u8]) -> io::Result<usize> {
+        panic!("tried to write a RO")
+    }
+    fn flush(&mut self) -> io::Result<()> {
+        panic!("tried to flush a RO")
+    }
+}
+
+impl<T> RorW for RO<T> where T: Read + Write + Debug {
+    fn can_read(&self) -> bool {
+        true
+    }
+    fn can_write(&self) -> bool {
+        false
+    }
+}
+
+#[derive(Debug)]
+#[repr(transparent)]
+pub struct WO<T>(pub T);
+
+impl<T> Read for WO<T> where T: Read {
+    fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
+        panic!("tried to read a WO")
+    }
+}
+
+impl<T> Write for WO<T> where T: Write {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        self.0.write(buf)
+    }
+    fn flush(&mut self) -> io::Result<()> {
+        self.0.flush()
+    }
+}
+
+impl<T> RorW for WO<T> where T: Read + Write + Debug {
+    fn can_read(&self) -> bool {
+        false
+ } + fn can_write(&self) -> bool { + true + } +}