diff --git a/.gitignore b/.gitignore index c01c8216..d9f80abc 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ /.vscode /.cargo /.idea +/.direnv +.envrc \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 7bbc3f9a..0c62fe4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ compio-dispatcher = { path = "./compio-dispatcher", version = "0.2.0" } compio-log = { path = "./compio-log", version = "0.1.0" } compio-tls = { path = "./compio-tls", version = "0.2.0", default-features = false } +flume = "0.11.0" cfg-if = "1.0.0" criterion = "0.5.1" crossbeam-channel = "0.5.8" diff --git a/compio-dispatcher/Cargo.toml b/compio-dispatcher/Cargo.toml index 881d9b6b..9679fabe 100644 --- a/compio-dispatcher/Cargo.toml +++ b/compio-dispatcher/Cargo.toml @@ -13,9 +13,9 @@ repository = { workspace = true } [dependencies] # Workspace dependencies compio-driver = { workspace = true } -compio-runtime = { workspace = true, features = ["event"] } +compio-runtime = { workspace = true, features = ["event", "time"] } -crossbeam-channel = { workspace = true } +flume = { workspace = true } futures-util = { workspace = true } [dev-dependencies] diff --git a/compio-dispatcher/src/lib.rs b/compio-dispatcher/src/lib.rs index cd395caf..acf2e383 100644 --- a/compio-dispatcher/src/lib.rs +++ b/compio-dispatcher/src/lib.rs @@ -3,25 +3,23 @@ #![warn(missing_docs)] use std::{ + any::Any, future::Future, io, num::NonZeroUsize, - panic::{resume_unwind, UnwindSafe}, + panic::resume_unwind, + pin::Pin, sync::{Arc, Mutex}, thread::{available_parallelism, JoinHandle}, }; use compio_driver::{AsyncifyPool, ProactorBuilder}; -use compio_runtime::{ - event::{Event, EventHandle}, - Runtime, -}; -use crossbeam_channel::{unbounded, Sender}; -use futures_util::{future::LocalBoxFuture, FutureExt}; +use compio_runtime::{event::Event, Runtime}; +use flume::{unbounded, SendError, Sender}; /// The dispatcher. It manages the threads and dispatches the tasks. pub struct Dispatcher { - sender: Sender, + sender: Sender>, threads: Vec>, pool: AsyncifyPool, } @@ -32,13 +30,12 @@ impl Dispatcher { let mut proactor_builder = builder.proactor_builder; proactor_builder.force_reuse_thread_pool(); let pool = proactor_builder.create_or_get_thread_pool(); + let (sender, receiver) = unbounded::>(); - let (sender, receiver) = unbounded::(); let threads = (0..builder.nthreads) .map({ |index| { let proactor_builder = proactor_builder.clone(); - let receiver = receiver.clone(); let thread_builder = std::thread::Builder::new(); @@ -54,17 +51,21 @@ impl Dispatcher { }; thread_builder.spawn(move || { - let runtime = Runtime::builder() + Runtime::builder() .with_proactor(proactor_builder) .build() - .expect("cannot create compio runtime"); - let _guard = runtime.enter(); - while let Ok(f) = receiver.recv() { - *f.result.lock().unwrap() = Some(std::panic::catch_unwind(|| { - Runtime::current().block_on((f.func)()); - })); - f.handle.notify(); - } + .expect("cannot create compio runtime") + .block_on(async move { + let rt = Runtime::current(); + while let Ok(f) = receiver.recv_async().await { + let fut = (f)(); + if builder.concurrent { + rt.spawn(fut).detach() + } else { + fut.await + } + } + }) }) } }) @@ -86,30 +87,75 @@ impl Dispatcher { DispatcherBuilder::default() } - /// Dispatch a task to the threads. + fn prepare(&self, f: Fn) -> (Executing, Box) + where + Fn: (FnOnce() -> Fut) + Send + 'static, + Fut: Future + 'static, + R: Any + Send + 'static, + { + let event = Event::new(); + let handle = event.handle(); + let res = Arc::new(Mutex::new(None)); + let dispatched = Executing { + event, + result: res.clone(), + }; + let closure = Box::new(|| { + Box::pin(async move { + *res.lock().unwrap() = Some(f().await); + handle.notify(); + }) as BoxFuture<()> + }); + (dispatched, closure) + } + + /// Spawn a boxed closure to the threads. + /// + /// If all threads have panicked, this method will return an error with the + /// sent closure. + pub fn spawn(&self, closure: Box) -> Result<(), SendError>> { + self.sender.send(closure) + } + + /// Dispatch a task to the threads /// /// The provided `f` should be [`Send`] because it will be send to another /// thread before calling. The return [`Future`] need not to be [`Send`] /// because it will be executed on only one thread. - pub fn dispatch< - F: Future + 'static, - Fn: (FnOnce() -> F) + Send + UnwindSafe + 'static, - >( - &self, - f: Fn, - ) -> io::Result { - let event = Event::new(); - let handle = event.handle(); - let join_handle = DispatcherJoinHandle::new(event); - let closure = DispatcherClosure { - handle, - result: join_handle.result.clone(), - func: Box::new(|| f().boxed_local()), - }; - self.sender - .send(closure) - .expect("the channel should not be disconnected"); - Ok(join_handle) + /// + /// # Error + /// + /// If all threads have panicked, this method will return an error with the + /// sent closure. Notice that the returned closure is not the same as the + /// argument and cannot be simply transmuted back to `Fn`. + pub fn dispatch(&self, f: Fn) -> Result<(), SendError>> + where + Fn: (FnOnce() -> Fut) + Send + 'static, + Fut: Future + 'static, + { + self.spawn(Box::new(|| Box::pin(f()) as BoxFuture<()>)) + } + + /// Execute a task on the threads and retrieve its returned value. + /// + /// The provided `f` should be [`Send`] because it will be send to another + /// thread before calling. The return [`Future`] need not to be [`Send`] + /// because it will be executed on only one thread. + /// + /// # Error + /// + /// If all threads have panicked, this method will return an error with the + /// sent closure. Notice that the returned closure is not the same as the + /// argument and cannot be simply transmuted back to `Fn`. + pub fn execute(&self, f: Fn) -> Result, SendError>> + where + Fn: (FnOnce() -> Fut) + Send + 'static, + Fut: Future + 'static, + R: Any + Send + 'static, + { + let (dispatched, closure) = self.prepare(f); + self.spawn(closure)?; + Ok(dispatched) } /// Stop the dispatcher and wait for the threads to complete. If there is a @@ -135,7 +181,6 @@ impl Dispatcher { event.wait().await; let mut guard = results.lock().unwrap(); for res in std::mem::take::>>(guard.as_mut()) { - // The thread should not panic. res.unwrap_or_else(|e| resume_unwind(e)); } Ok(()) @@ -145,6 +190,7 @@ impl Dispatcher { /// A builder for [`Dispatcher`]. pub struct DispatcherBuilder { nthreads: usize, + concurrent: bool, stack_size: Option, names: Option String>>, proactor_builder: ProactorBuilder, @@ -155,12 +201,22 @@ impl DispatcherBuilder { pub fn new() -> Self { Self { nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1), + concurrent: true, stack_size: None, names: None, proactor_builder: ProactorBuilder::new(), } } + /// If execute tasks concurrently. Default to be `true`. + /// + /// When set to `false`, tasks are executed sequentially without any + /// concurrency within the thread. + pub fn concurrent(mut self, concurrent: bool) -> Self { + self.concurrent = concurrent; + self + } + /// Set the number of worker threads of the dispatcher. The default value is /// the CPU number. If the CPU number could not be retrieved, the /// default value is 1. @@ -199,36 +255,36 @@ impl Default for DispatcherBuilder { } } -type Closure<'a> = dyn (FnOnce() -> LocalBoxFuture<'a, ()>) + Send + UnwindSafe; +type BoxFuture = Pin>>; +type Closure = dyn (FnOnce() -> BoxFuture<()>) + Send; -struct DispatcherClosure { - handle: EventHandle, - result: Arc>>>, - func: Box>, -} - -/// The join handle for dispatched task. -pub struct DispatcherJoinHandle { +/// The join handle for an executing task. It can be used to wait for the +/// task's returned value. +pub struct Executing { event: Event, - result: Arc>>>, + result: Arc>>, } -impl DispatcherJoinHandle { - pub(crate) fn new(event: Event) -> Self { - Self { - event, - result: Arc::new(Mutex::new(None)), +impl Executing { + fn take(val: &Mutex>) -> R { + val.lock() + .unwrap() + .take() + .expect("the result should be set") + } + + /// Try to wait for the task to complete without blocking. + pub fn try_join(self) -> Result { + if self.event.notified() { + Ok(Self::take(&self.result)) + } else { + Err(self) } } /// Wait for the task to complete. - pub async fn join(self) -> io::Result> { + pub async fn join(self) -> R { self.event.wait().await; - Ok(self - .result - .lock() - .unwrap() - .take() - .expect("the result should be set")) + Self::take(&self.result) } } diff --git a/compio-dispatcher/tests/listener.rs b/compio-dispatcher/tests/listener.rs index 4be7cdb0..7745b5d9 100644 --- a/compio-dispatcher/tests/listener.rs +++ b/compio-dispatcher/tests/listener.rs @@ -1,4 +1,4 @@ -use std::{num::NonZeroUsize, panic::resume_unwind}; +use std::num::NonZeroUsize; use compio_buf::arrayvec::ArrayVec; use compio_dispatcher::Dispatcher; @@ -29,16 +29,14 @@ async fn listener_dispatch() { for _i in 0..CLIENT_NUM { let (mut srv, _) = listener.accept().await.unwrap(); let handle = dispatcher - .dispatch(move || async move { + .execute(move || async move { let (_, buf) = srv.read_exact(ArrayVec::::new()).await.unwrap(); assert_eq!(buf.as_slice(), b"Hello world!"); }) .unwrap(); handles.push(handle.join()); } - while let Some(res) = handles.next().await { - res.unwrap().unwrap_or_else(|e| resume_unwind(e)); - } + while handles.next().await.is_some() {} let (_, results) = futures_util::join!(task, dispatcher.join()); results.unwrap(); } diff --git a/compio/examples/dispatcher.rs b/compio/examples/dispatcher.rs index dc5c6903..bc0af42e 100644 --- a/compio/examples/dispatcher.rs +++ b/compio/examples/dispatcher.rs @@ -1,4 +1,4 @@ -use std::{num::NonZeroUsize, panic::resume_unwind}; +use std::num::NonZeroUsize; use compio::{ dispatcher::Dispatcher, @@ -35,7 +35,7 @@ async fn main() { for _i in 0..CLIENT_NUM { let (mut srv, _) = listener.accept().await.unwrap(); let handle = dispatcher - .dispatch(move || async move { + .execute(move || async move { let BufResult(res, buf) = srv.read(Vec::with_capacity(20)).await; res.unwrap(); println!("{}", std::str::from_utf8(&buf).unwrap()); @@ -43,8 +43,6 @@ async fn main() { .unwrap(); handles.push(handle.join()); } - while let Some(res) = handles.next().await { - res.unwrap().unwrap_or_else(|e| resume_unwind(e)); - } + while handles.next().await.is_some() {} dispatcher.join().await.unwrap(); }