Skip to content

Commit

Permalink
Initial version with no-op rate-limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
infinity0 committed Aug 27, 2020
1 parent 386601d commit 49cfa4b
Show file tree
Hide file tree
Showing 10 changed files with 1,239 additions and 38 deletions.
6 changes: 5 additions & 1 deletion Cargo.toml
Expand Up @@ -9,13 +9,17 @@ edition = "2018"
[dependencies]
clap = "2"
futures = "0.3"
futures-lite = "0.1.10"
libc = "0.2.74"
once_cell = "1.4.0"
vec-arena = "0.5"

# bwlim-test-tokio
tokio = { version = "0.2", features = ["full"] }
tokio-util = { version = "0.3", features = ["compat"] }

# bwlim-test-async
async-io = "0.1"
async-io = "0.2"
smol = "0.3"

[[bin]]
Expand Down
23 changes: 19 additions & 4 deletions src/bin/bwlim-test-async.rs
@@ -1,4 +1,6 @@
use bwlim::testing::*;
use bwlim::RLAsync;
use bwlim::util::RW;

use std::net::{SocketAddr, TcpStream, TcpListener};
use std::process::{Command, Stdio};
Expand All @@ -12,7 +14,7 @@ use async_io::Async;
use smol::Task;

async fn async_main() {
let (test_bytes, listen, connect, host) = get_args();
let (test_bytes, listen, connect, host, rate_limit) = get_args();

let listen_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), listen);
let connect_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), connect);
Expand All @@ -39,7 +41,12 @@ async fn async_main() {
futures::select! {
val = listener.accept().fuse() => {
let (socket, _) = val.unwrap();
workers.push(Task::spawn(server_thread(socket)));
let thread = if rate_limit {
Task::spawn(server_thread(RLAsync::new(RW(socket.into_inner().unwrap())).unwrap()))
} else {
Task::spawn(server_thread(socket))
};
workers.push(thread);
},
_ = is_shutdown.next() => {
break;
Expand All @@ -53,12 +60,20 @@ async fn async_main() {
println!("server: shutting down");
});

thread::sleep(Duration::from_millis(1000)); // TODO: get rid of this

// client threads
let clients = [0; 2].iter().map(|_| {
Task::spawn(async move {
let mut stream = Async::<TcpStream>::connect(connect_addr).await.unwrap();
client_thread(&mut stream, test_bytes).await;
stream.close().await.unwrap();
if rate_limit {
let mut stream = RLAsync::new(RW(stream.into_inner().unwrap())).unwrap();
client_thread(&mut stream, test_bytes).await;
stream.close().await.unwrap();
} else {
client_thread(&mut stream, test_bytes).await;
stream.close().await.unwrap();
}
})
}).collect::<Vec<_>>();
future::join_all(clients).await;
Expand Down
2 changes: 1 addition & 1 deletion src/bin/bwlim-test-tokio.rs
Expand Up @@ -13,7 +13,7 @@ use tokio_util::compat::*;

#[tokio::main]
async fn main() {
let (test_bytes, listen, connect, host) = get_args();
let (test_bytes, listen, connect, host, _) = get_args();

let listen_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), listen);
let connect_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), connect);
Expand Down
201 changes: 176 additions & 25 deletions src/lib.rs
@@ -1,30 +1,181 @@
pub mod limit;
pub mod reactor;
pub mod stats;
pub mod testing;
pub mod util;
pub mod sys;

use std::fmt::Debug;
use std::future::Future;
use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::pin::Pin;
use std::mem::ManuallyDrop;
use std::net::{Shutdown, TcpStream};
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_lite::io::{AsyncRead, AsyncWrite};
use futures_lite::{future, pin};

use crate::limit::RateLimited;
use crate::util::RorW;
use crate::reactor::{Reactor, Source};
use crate::sys::*;

use std::time::Instant;

#[derive(Debug)]
pub struct BWStats {
start: Instant,
/** cumulative time, cumulative space */
stats: Vec<(u128, u128)>,
}

impl BWStats {
pub fn new() -> BWStats {
let mut stats = Vec::new();
stats.push((0, 0));
BWStats {
start: Instant::now(),
stats: stats,
}
}

pub fn add(self: &mut Self, n: usize) {
let prev = self.stats.last().unwrap().1;
self.stats.push((self.start.elapsed().as_micros(), prev + (n as u128)));
}

pub fn last(self: &Self) -> &(u128, u128) {
self.stats.last().unwrap()
}
pub struct RLAsync<T> {
source: Arc<Source<T>>,
}

impl<T> RLAsync<T> {
pub fn new(io: T) -> io::Result<RLAsync<T>> where T: RorW + Send + Sync + 'static {
Ok(RLAsync {
source: Reactor::get().insert_io(io)?,
})
}
}

impl<T> Drop for RLAsync<T> {
fn drop(&mut self) {
// Deregister and ignore errors because destructors should not panic.
let _ = Reactor::get().remove_io(&self.source);
}
}

// copied from async-io, except:
// self.get_mut() replaced with lock() / RateLimited
// TODO: figure out a way to de-duplicate with them
impl<T> RLAsync<T> {
pub async fn readable(&self) -> io::Result<()> {
self.source.readable().await
}
pub async fn writable(&self) -> io::Result<()> {
self.source.writable().await
}
pub async fn read_with_mut<R>(
&mut self,
op: impl FnMut(&mut RateLimited<T>) -> io::Result<R>,
) -> io::Result<R> {
let mut op = op;
loop {
// If there are no blocked readers, attempt the read operation.
if !self.source.readers_registered() {
let mut inner = self.source.inner.lock().unwrap();
match op(&mut inner) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return res,
}
}
// Wait until the I/O handle becomes readable.
optimistic(self.readable()).await?;
}
}
pub async fn write_with_mut<R>(
&mut self,
op: impl FnMut(&mut RateLimited<T>) -> io::Result<R>,
) -> io::Result<R> {
let mut op = op;
loop {
// If there are no blocked readers, attempt the write operation.
if !self.source.writers_registered() {
let mut inner = self.source.inner.lock().unwrap();
match op(&mut inner) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return res,
}
}
// Wait until the I/O handle becomes writable.
optimistic(self.writable()).await?;
}
}
}

// copied from async-io
impl<T: Read> AsyncRead for RLAsync<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
poll_future(cx, self.read_with_mut(|io| io.read(buf)))
}

fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
poll_future(cx, self.read_with_mut(|io| io.read_vectored(bufs)))
}
}

// copied from async-io
impl<T: Write> AsyncWrite for RLAsync<T>
where
T: AsRawSource
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
poll_future(cx, self.write_with_mut(|io| io.write(buf)))
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
poll_future(cx, self.write_with_mut(|io| io.write_vectored(bufs)))
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
poll_future(cx, self.write_with_mut(|io| io.flush()))
}

fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
let inner = self.source.inner.lock().unwrap();
Poll::Ready(shutdown_write(inner.inner.as_raw_source()))
}
}

// copied from async-io
fn poll_future<T>(cx: &mut Context<'_>, fut: impl Future<Output = T>) -> Poll<T> {
pin!(fut);
fut.poll(cx)
}

// copied from async-io
async fn optimistic(fut: impl Future<Output = io::Result<()>>) -> io::Result<()> {
let mut polled = false;
pin!(fut);

future::poll_fn(|cx| {
if !polled {
polled = true;
fut.as_mut().poll(cx)
} else {
Poll::Ready(Ok(()))
}
})
.await
}

// copied from async-io
fn shutdown_write(raw: RawSource) -> io::Result<()> {
// This may not be a TCP stream, but that's okay. All we do is attempt a `shutdown()` on the
// raw descriptor and ignore errors.
let stream = unsafe {
ManuallyDrop::new(
TcpStream::from_raw_source(raw),
)
};

// If the socket is a TCP stream, the only actual error can be ENOTCONN.
match stream.shutdown(Shutdown::Write) {
Err(err) if err.kind() == io::ErrorKind::NotConnected => Err(err),
_ => Ok(()),
}
}

0 comments on commit 49cfa4b

Please sign in to comment.