Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: accept new thread pool impl #167

Merged
merged 1 commit into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions monoio/src/blocking.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Blocking tasks related.

use std::{future::Future, sync::Arc, task::Poll};
use std::{future::Future, task::Poll};

use threadpool::{Builder as ThreadPoolBuilder, ThreadPool as ThreadPoolImpl};

Expand Down Expand Up @@ -121,6 +121,7 @@ where
/// DefaultThreadPool is a simple wrapped `threadpool::ThreadPool` that implememt
/// `monoio::blocking::ThreadPool`. You may use this implementation, or you can use your own thread
/// pool implementation.
#[derive(Clone)]
pub struct DefaultThreadPool {
pool: ThreadPoolImpl,
}
Expand Down Expand Up @@ -153,9 +154,8 @@ impl crate::task::Schedule for NoopScheduler {
}
}

#[derive(Clone)]
pub(crate) enum BlockingHandle {
Attached(Arc<dyn ThreadPool>),
Attached(Box<dyn crate::blocking::ThreadPool + Send + 'static>),
Empty(BlockingStrategy),
}

Expand Down Expand Up @@ -188,8 +188,6 @@ where

#[cfg(test)]
mod tests {
use std::sync::Arc;

use super::DefaultThreadPool;

/// NaiveThreadPool always create a new thread on executing tasks.
Expand All @@ -212,7 +210,7 @@ mod tests {

#[test]
fn hello_blocking() {
let shared_pool = Arc::new(NaiveThreadPool);
let shared_pool = Box::new(NaiveThreadPool);
let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
.attach_thread_pool(shared_pool)
.enable_timer()
Expand Down Expand Up @@ -272,7 +270,7 @@ mod tests {

#[test]
fn drop_task() {
let shared_pool = Arc::new(FakeThreadPool);
let shared_pool = Box::new(FakeThreadPool);
let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
.attach_thread_pool(shared_pool)
.enable_timer()
Expand All @@ -286,7 +284,7 @@ mod tests {

#[test]
fn default_pool() {
let shared_pool = Arc::new(DefaultThreadPool::new(3));
let shared_pool = Box::new(DefaultThreadPool::new(3));
let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
.attach_thread_pool(shared_pool)
.enable_timer()
Expand Down
66 changes: 34 additions & 32 deletions monoio/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ impl<T> RuntimeBuilder<T> {
/// Buildable trait.
pub trait Buildable: Sized {
/// Build the runtime.
fn build(this: &RuntimeBuilder<Self>) -> io::Result<Runtime<Self>>;
fn build(this: RuntimeBuilder<Self>) -> io::Result<Runtime<Self>>;
}

#[allow(unused)]
macro_rules! direct_build {
($ty: ty) => {
impl RuntimeBuilder<$ty> {
/// Build the runtime.
pub fn build(&self) -> io::Result<Runtime<$ty>> {
pub fn build(self) -> io::Result<Runtime<$ty>> {
Buildable::build(self)
}
}
Expand All @@ -88,10 +88,10 @@ direct_build!(TimeDriver<LegacyDriver>);

#[cfg(all(unix, feature = "legacy"))]
impl Buildable for LegacyDriver {
fn build(this: &RuntimeBuilder<Self>) -> io::Result<Runtime<LegacyDriver>> {
fn build(this: RuntimeBuilder<Self>) -> io::Result<Runtime<LegacyDriver>> {
let thread_id = gen_id();
#[cfg(feature = "sync")]
let blocking_handle = this.blocking_handle.clone();
let blocking_handle = this.blocking_handle;

BUILD_THREAD_ID.set(&thread_id, || {
let driver = match this.entries {
Expand All @@ -109,10 +109,10 @@ impl Buildable for LegacyDriver {

#[cfg(all(target_os = "linux", feature = "iouring"))]
impl Buildable for IoUringDriver {
fn build(this: &RuntimeBuilder<Self>) -> io::Result<Runtime<IoUringDriver>> {
fn build(this: RuntimeBuilder<Self>) -> io::Result<Runtime<IoUringDriver>> {
let thread_id = gen_id();
#[cfg(feature = "sync")]
let blocking_handle = this.blocking_handle.clone();
let blocking_handle = this.blocking_handle;

BUILD_THREAD_ID.set(&thread_id, || {
let driver = match this.entries {
Expand Down Expand Up @@ -150,8 +150,8 @@ impl<D> RuntimeBuilder<D> {

#[cfg(all(target_os = "linux", feature = "iouring"))]
#[must_use]
pub fn uring_builder(mut self, urb: &io_uring::Builder) -> Self {
self.urb = urb.clone();
pub fn uring_builder(mut self, urb: io_uring::Builder) -> Self {
self.urb = urb;
self
}
}
Expand All @@ -166,23 +166,23 @@ pub struct FusionDriver;
impl RuntimeBuilder<FusionDriver> {
/// Build the runtime.
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
pub fn build(&self) -> io::Result<crate::FusionRuntime<IoUringDriver, LegacyDriver>> {
pub fn build(self) -> io::Result<crate::FusionRuntime<IoUringDriver, LegacyDriver>> {
if crate::utils::detect_uring() {
let builder = RuntimeBuilder::<IoUringDriver> {
entries: self.entries,
urb: self.urb.clone(),
urb: self.urb,
#[cfg(feature = "sync")]
blocking_handle: self.blocking_handle.clone(),
blocking_handle: self.blocking_handle,
_mark: PhantomData,
};
info!("io_uring driver built");
Ok(builder.build()?.into())
} else {
let builder = RuntimeBuilder::<LegacyDriver> {
entries: self.entries,
urb: self.urb.clone(),
urb: self.urb,
#[cfg(feature = "sync")]
blocking_handle: self.blocking_handle.clone(),
blocking_handle: self.blocking_handle,
_mark: PhantomData,
};
info!("legacy driver built");
Expand All @@ -192,24 +192,24 @@ impl RuntimeBuilder<FusionDriver> {

/// Build the runtime.
#[cfg(all(unix, not(all(target_os = "linux", feature = "iouring"))))]
pub fn build(&self) -> io::Result<crate::FusionRuntime<LegacyDriver>> {
pub fn build(self) -> io::Result<crate::FusionRuntime<LegacyDriver>> {
let builder = RuntimeBuilder::<LegacyDriver> {
entries: self.entries,
#[cfg(feature = "sync")]
blocking_handle: self.blocking_handle.clone(),
blocking_handle: self.blocking_handle,
_mark: PhantomData,
};
Ok(builder.build()?.into())
}

/// Build the runtime.
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
pub fn build(&self) -> io::Result<crate::FusionRuntime<IoUringDriver>> {
pub fn build(self) -> io::Result<crate::FusionRuntime<IoUringDriver>> {
let builder = RuntimeBuilder::<IoUringDriver> {
entries: self.entries,
urb: self.urb.clone(),
urb: self.urb,
#[cfg(feature = "sync")]
blocking_handle: self.blocking_handle.clone(),
blocking_handle: self.blocking_handle,
_mark: PhantomData,
};
Ok(builder.build()?.into())
Expand All @@ -221,24 +221,24 @@ impl RuntimeBuilder<TimeDriver<FusionDriver>> {
/// Build the runtime.
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
pub fn build(
&self,
self,
) -> io::Result<crate::FusionRuntime<TimeDriver<IoUringDriver>, TimeDriver<LegacyDriver>>> {
if crate::utils::detect_uring() {
let builder = RuntimeBuilder::<TimeDriver<IoUringDriver>> {
entries: self.entries,
urb: self.urb.clone(),
urb: self.urb,
#[cfg(feature = "sync")]
blocking_handle: self.blocking_handle.clone(),
blocking_handle: self.blocking_handle,
_mark: PhantomData,
};
info!("io_uring driver with timer built");
Ok(builder.build()?.into())
} else {
let builder = RuntimeBuilder::<TimeDriver<LegacyDriver>> {
entries: self.entries,
urb: self.urb.clone(),
urb: self.urb,
#[cfg(feature = "sync")]
blocking_handle: self.blocking_handle.clone(),
blocking_handle: self.blocking_handle,
_mark: PhantomData,
};
info!("legacy driver with timer built");
Expand All @@ -248,11 +248,11 @@ impl RuntimeBuilder<TimeDriver<FusionDriver>> {

/// Build the runtime.
#[cfg(all(unix, not(all(target_os = "linux", feature = "iouring"))))]
pub fn build(&self) -> io::Result<crate::FusionRuntime<TimeDriver<LegacyDriver>>> {
pub fn build(self) -> io::Result<crate::FusionRuntime<TimeDriver<LegacyDriver>>> {
let builder = RuntimeBuilder::<TimeDriver<LegacyDriver>> {
entries: self.entries,
#[cfg(feature = "sync")]
blocking_handle: self.blocking_handle.clone(),
blocking_handle: self.blocking_handle,
_mark: PhantomData,
};
Ok(builder.build()?.into())
Expand All @@ -263,9 +263,9 @@ impl RuntimeBuilder<TimeDriver<FusionDriver>> {
pub fn build(&self) -> io::Result<crate::FusionRuntime<TimeDriver<IoUringDriver>>> {
let builder = RuntimeBuilder::<TimeDriver<IoUringDriver>> {
entries: self.entries,
urb: self.urb.clone(),
urb: self.urb,
#[cfg(feature = "sync")]
blocking_handle: self.blocking_handle.clone(),
blocking_handle: self.blocking_handle,
_mark: PhantomData,
};
Ok(builder.build()?.into())
Expand All @@ -289,16 +289,16 @@ where
D: Buildable,
{
/// Build the runtime
fn build(this: &RuntimeBuilder<Self>) -> io::Result<Runtime<TimeDriver<D>>> {
fn build(this: RuntimeBuilder<Self>) -> io::Result<Runtime<TimeDriver<D>>> {
let Runtime {
driver,
mut context,
} = Buildable::build(&RuntimeBuilder::<D> {
} = Buildable::build(RuntimeBuilder::<D> {
entries: this.entries,
#[cfg(all(target_os = "linux", feature = "iouring"))]
urb: this.urb.clone(),
urb: this.urb,
#[cfg(feature = "sync")]
blocking_handle: this.blocking_handle.clone(),
blocking_handle: this.blocking_handle,
_mark: PhantomData,
})?;

Expand Down Expand Up @@ -338,14 +338,16 @@ impl<D: time_wrap::TimeWrapable> RuntimeBuilder<D> {
_mark: PhantomData,
}
}
}

impl<D> RuntimeBuilder<D> {
/// Attach thread pool, this will overwrite blocking strategy.
/// All `spawn_blocking` will be executed on given thread pool.
#[cfg(feature = "sync")]
#[must_use]
pub fn attach_thread_pool(
mut self,
tp: std::sync::Arc<dyn crate::blocking::ThreadPool>,
tp: Box<dyn crate::blocking::ThreadPool + Send + 'static>,
) -> Self {
self.blocking_handle = crate::blocking::BlockingHandle::Attached(tp);
self
Expand Down
2 changes: 1 addition & 1 deletion monoio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ where
F::Output: 'static,
D: Buildable + Driver,
{
let mut rt = builder::Buildable::build(&builder::RuntimeBuilder::<D>::new())
let mut rt = builder::Buildable::build(builder::RuntimeBuilder::<D>::new())
.expect("Unable to build runtime.");
rt.block_on(future)
}
Expand Down
Loading