Skip to content

Commit

Permalink
Merge pull request #244 from George-Miao/feat/concurrent-dispatcher
Browse files Browse the repository at this point in the history
feat(dispatcher): Make concurrency
  • Loading branch information
George-Miao committed Apr 30, 2024
2 parents c307b71 + 3828594 commit 1f8497a
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 74 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -3,3 +3,5 @@
/.vscode
/.cargo
/.idea
/.direnv
.envrc
1 change: 1 addition & 0 deletions Cargo.toml
Expand Up @@ -35,6 +35,7 @@ compio-dispatcher = { path = "./compio-dispatcher", version = "0.2.0" }
compio-log = { path = "./compio-log", version = "0.1.0" }
compio-tls = { path = "./compio-tls", version = "0.2.0", default-features = false }

flume = "0.11.0"
cfg-if = "1.0.0"
criterion = "0.5.1"
crossbeam-channel = "0.5.8"
Expand Down
4 changes: 2 additions & 2 deletions compio-dispatcher/Cargo.toml
Expand Up @@ -13,9 +13,9 @@ repository = { workspace = true }
[dependencies]
# Workspace dependencies
compio-driver = { workspace = true }
compio-runtime = { workspace = true, features = ["event"] }
compio-runtime = { workspace = true, features = ["event", "time"] }

crossbeam-channel = { workspace = true }
flume = { workspace = true }
futures-util = { workspace = true }

[dev-dependencies]
Expand Down
180 changes: 118 additions & 62 deletions compio-dispatcher/src/lib.rs
Expand Up @@ -3,25 +3,23 @@
#![warn(missing_docs)]

use std::{
any::Any,
future::Future,
io,
num::NonZeroUsize,
panic::{resume_unwind, UnwindSafe},
panic::resume_unwind,
pin::Pin,
sync::{Arc, Mutex},
thread::{available_parallelism, JoinHandle},
};

use compio_driver::{AsyncifyPool, ProactorBuilder};
use compio_runtime::{
event::{Event, EventHandle},
Runtime,
};
use crossbeam_channel::{unbounded, Sender};
use futures_util::{future::LocalBoxFuture, FutureExt};
use compio_runtime::{event::Event, Runtime};
use flume::{unbounded, SendError, Sender};

/// The dispatcher. It manages the threads and dispatches the tasks.
pub struct Dispatcher {
sender: Sender<DispatcherClosure>,
sender: Sender<Box<Closure>>,
threads: Vec<JoinHandle<()>>,
pool: AsyncifyPool,
}
Expand All @@ -32,13 +30,12 @@ impl Dispatcher {
let mut proactor_builder = builder.proactor_builder;
proactor_builder.force_reuse_thread_pool();
let pool = proactor_builder.create_or_get_thread_pool();
let (sender, receiver) = unbounded::<Box<Closure>>();

let (sender, receiver) = unbounded::<DispatcherClosure>();
let threads = (0..builder.nthreads)
.map({
|index| {
let proactor_builder = proactor_builder.clone();

let receiver = receiver.clone();

let thread_builder = std::thread::Builder::new();
Expand All @@ -54,17 +51,21 @@ impl Dispatcher {
};

thread_builder.spawn(move || {
let runtime = Runtime::builder()
Runtime::builder()
.with_proactor(proactor_builder)
.build()
.expect("cannot create compio runtime");
let _guard = runtime.enter();
while let Ok(f) = receiver.recv() {
*f.result.lock().unwrap() = Some(std::panic::catch_unwind(|| {
Runtime::current().block_on((f.func)());
}));
f.handle.notify();
}
.expect("cannot create compio runtime")
.block_on(async move {
let rt = Runtime::current();
while let Ok(f) = receiver.recv_async().await {
let fut = (f)();
if builder.concurrent {
rt.spawn(fut).detach()
} else {
fut.await
}
}
})
})
}
})
Expand All @@ -86,30 +87,75 @@ impl Dispatcher {
DispatcherBuilder::default()
}

/// Dispatch a task to the threads.
fn prepare<Fut, Fn, R>(&self, f: Fn) -> (Executing<R>, Box<Closure>)
where
Fn: (FnOnce() -> Fut) + Send + 'static,
Fut: Future<Output = R> + 'static,
R: Any + Send + 'static,
{
let event = Event::new();
let handle = event.handle();
let res = Arc::new(Mutex::new(None));
let dispatched = Executing {
event,
result: res.clone(),
};
let closure = Box::new(|| {
Box::pin(async move {
*res.lock().unwrap() = Some(f().await);
handle.notify();
}) as BoxFuture<()>
});
(dispatched, closure)
}

/// Spawn a boxed closure to the threads.
///
/// If all threads have panicked, this method will return an error with the
/// sent closure.
pub fn spawn(&self, closure: Box<Closure>) -> Result<(), SendError<Box<Closure>>> {
self.sender.send(closure)
}

/// Dispatch a task to the threads
///
/// The provided `f` should be [`Send`] because it will be send to another
/// thread before calling. The return [`Future`] need not to be [`Send`]
/// because it will be executed on only one thread.
pub fn dispatch<
F: Future<Output = ()> + 'static,
Fn: (FnOnce() -> F) + Send + UnwindSafe + 'static,
>(
&self,
f: Fn,
) -> io::Result<DispatcherJoinHandle> {
let event = Event::new();
let handle = event.handle();
let join_handle = DispatcherJoinHandle::new(event);
let closure = DispatcherClosure {
handle,
result: join_handle.result.clone(),
func: Box::new(|| f().boxed_local()),
};
self.sender
.send(closure)
.expect("the channel should not be disconnected");
Ok(join_handle)
///
/// # Error
///
/// If all threads have panicked, this method will return an error with the
/// sent closure. Notice that the returned closure is not the same as the
/// argument and cannot be simply transmuted back to `Fn`.
pub fn dispatch<Fut, Fn>(&self, f: Fn) -> Result<(), SendError<Box<Closure>>>
where
Fn: (FnOnce() -> Fut) + Send + 'static,
Fut: Future<Output = ()> + 'static,
{
self.spawn(Box::new(|| Box::pin(f()) as BoxFuture<()>))
}

/// Execute a task on the threads and retrieve its returned value.
///
/// The provided `f` should be [`Send`] because it will be send to another
/// thread before calling. The return [`Future`] need not to be [`Send`]
/// because it will be executed on only one thread.
///
/// # Error
///
/// If all threads have panicked, this method will return an error with the
/// sent closure. Notice that the returned closure is not the same as the
/// argument and cannot be simply transmuted back to `Fn`.
pub fn execute<Fut, Fn, R>(&self, f: Fn) -> Result<Executing<R>, SendError<Box<Closure>>>
where
Fn: (FnOnce() -> Fut) + Send + 'static,
Fut: Future<Output = R> + 'static,
R: Any + Send + 'static,
{
let (dispatched, closure) = self.prepare(f);
self.spawn(closure)?;
Ok(dispatched)
}

/// Stop the dispatcher and wait for the threads to complete. If there is a
Expand All @@ -135,7 +181,6 @@ impl Dispatcher {
event.wait().await;
let mut guard = results.lock().unwrap();
for res in std::mem::take::<Vec<std::thread::Result<()>>>(guard.as_mut()) {
// The thread should not panic.
res.unwrap_or_else(|e| resume_unwind(e));
}
Ok(())
Expand All @@ -145,6 +190,7 @@ impl Dispatcher {
/// A builder for [`Dispatcher`].
pub struct DispatcherBuilder {
nthreads: usize,
concurrent: bool,
stack_size: Option<usize>,
names: Option<Box<dyn FnMut(usize) -> String>>,
proactor_builder: ProactorBuilder,
Expand All @@ -155,12 +201,22 @@ impl DispatcherBuilder {
pub fn new() -> Self {
Self {
nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
concurrent: true,
stack_size: None,
names: None,
proactor_builder: ProactorBuilder::new(),
}
}

/// If execute tasks concurrently. Default to be `true`.
///
/// When set to `false`, tasks are executed sequentially without any
/// concurrency within the thread.
pub fn concurrent(mut self, concurrent: bool) -> Self {
self.concurrent = concurrent;
self
}

/// Set the number of worker threads of the dispatcher. The default value is
/// the CPU number. If the CPU number could not be retrieved, the
/// default value is 1.
Expand Down Expand Up @@ -199,36 +255,36 @@ impl Default for DispatcherBuilder {
}
}

type Closure<'a> = dyn (FnOnce() -> LocalBoxFuture<'a, ()>) + Send + UnwindSafe;
type BoxFuture<T> = Pin<Box<dyn Future<Output = T>>>;
type Closure = dyn (FnOnce() -> BoxFuture<()>) + Send;

struct DispatcherClosure {
handle: EventHandle,
result: Arc<Mutex<Option<std::thread::Result<()>>>>,
func: Box<Closure<'static>>,
}

/// The join handle for dispatched task.
pub struct DispatcherJoinHandle {
/// The join handle for an executing task. It can be used to wait for the
/// task's returned value.
pub struct Executing<R> {
event: Event,
result: Arc<Mutex<Option<std::thread::Result<()>>>>,
result: Arc<Mutex<Option<R>>>,
}

impl DispatcherJoinHandle {
pub(crate) fn new(event: Event) -> Self {
Self {
event,
result: Arc::new(Mutex::new(None)),
impl<R: 'static> Executing<R> {
fn take(val: &Mutex<Option<R>>) -> R {
val.lock()
.unwrap()
.take()
.expect("the result should be set")
}

/// Try to wait for the task to complete without blocking.
pub fn try_join(self) -> Result<R, Self> {
if self.event.notified() {
Ok(Self::take(&self.result))
} else {
Err(self)
}
}

/// Wait for the task to complete.
pub async fn join(self) -> io::Result<std::thread::Result<()>> {
pub async fn join(self) -> R {
self.event.wait().await;
Ok(self
.result
.lock()
.unwrap()
.take()
.expect("the result should be set"))
Self::take(&self.result)
}
}
8 changes: 3 additions & 5 deletions compio-dispatcher/tests/listener.rs
@@ -1,4 +1,4 @@
use std::{num::NonZeroUsize, panic::resume_unwind};
use std::num::NonZeroUsize;

use compio_buf::arrayvec::ArrayVec;
use compio_dispatcher::Dispatcher;
Expand Down Expand Up @@ -29,16 +29,14 @@ async fn listener_dispatch() {
for _i in 0..CLIENT_NUM {
let (mut srv, _) = listener.accept().await.unwrap();
let handle = dispatcher
.dispatch(move || async move {
.execute(move || async move {
let (_, buf) = srv.read_exact(ArrayVec::<u8, 12>::new()).await.unwrap();
assert_eq!(buf.as_slice(), b"Hello world!");
})
.unwrap();
handles.push(handle.join());
}
while let Some(res) = handles.next().await {
res.unwrap().unwrap_or_else(|e| resume_unwind(e));
}
while handles.next().await.is_some() {}
let (_, results) = futures_util::join!(task, dispatcher.join());
results.unwrap();
}
8 changes: 3 additions & 5 deletions compio/examples/dispatcher.rs
@@ -1,4 +1,4 @@
use std::{num::NonZeroUsize, panic::resume_unwind};
use std::num::NonZeroUsize;

use compio::{
dispatcher::Dispatcher,
Expand Down Expand Up @@ -35,16 +35,14 @@ async fn main() {
for _i in 0..CLIENT_NUM {
let (mut srv, _) = listener.accept().await.unwrap();
let handle = dispatcher
.dispatch(move || async move {
.execute(move || async move {
let BufResult(res, buf) = srv.read(Vec::with_capacity(20)).await;
res.unwrap();
println!("{}", std::str::from_utf8(&buf).unwrap());
})
.unwrap();
handles.push(handle.join());
}
while let Some(res) = handles.next().await {
res.unwrap().unwrap_or_else(|e| resume_unwind(e));
}
while handles.next().await.is_some() {}
dispatcher.join().await.unwrap();
}

0 comments on commit 1f8497a

Please sign in to comment.