diff --git a/Cargo.lock b/Cargo.lock index c4bb721..b9f9150 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,6 +83,15 @@ version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +[[package]] +name = "blocking_genserver" +version = "0.1.0" +dependencies = [ + "spawned-concurrency", + "spawned-rt", + "tracing", +] + [[package]] name = "bumpalo" version = "3.17.0" diff --git a/Cargo.toml b/Cargo.toml index a6f68bf..725d92b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ "examples/ping_pong", "examples/ping_pong_threads", "examples/updater", - "examples/updater_threads", + "examples/updater_threads", "examples/blocking_genserver", ] [workspace.dependencies] diff --git a/concurrency/src/tasks/gen_server.rs b/concurrency/src/tasks/gen_server.rs index 196fb5c..465832e 100644 --- a/concurrency/src/tasks/gen_server.rs +++ b/concurrency/src/tasks/gen_server.rs @@ -38,6 +38,26 @@ impl GenServerHandle { handle_clone } + pub(crate) fn new_blocking(mut initial_state: G::State) -> Self { + let (tx, mut rx) = mpsc::channel::>(); + let handle = GenServerHandle { tx }; + let mut gen_server: G = GenServer::new(); + let handle_clone = handle.clone(); + // Ignore the JoinHandle for now. Maybe we'll use it in the future + let _join_handle = rt::spawn_blocking(|| { + rt::block_on(async move { + if gen_server + .run(&handle, &mut rx, &mut initial_state) + .await + .is_err() + { + tracing::trace!("GenServer crashed") + }; + }) + }); + handle_clone + } + pub fn sender(&self) -> mpsc::Sender> { self.tx.clone() } @@ -98,6 +118,15 @@ where GenServerHandle::new(initial_state) } + /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't + /// happen if the task is blocking the thread. As such, for sync compute task + /// or other blocking tasks need to be in their own separate thread, and the OS + /// will manage them through hardware interrupts. + /// Start blocking provides such thread. + fn start_blocking(initial_state: Self::State) -> GenServerHandle { + GenServerHandle::new_blocking(initial_state) + } + fn run( &mut self, handle: &GenServerHandle, @@ -201,3 +230,136 @@ where state: &mut Self::State, ) -> impl std::future::Future + Send; } + +#[cfg(test)] +mod tests { + use super::*; + use crate::tasks::send_after; + use std::{process::exit, thread, time::Duration}; + struct BadlyBehavedTask; + + #[derive(Clone)] + pub enum InMessage { + GetCount, + Stop, + } + #[derive(Clone)] + pub enum OutMsg { + Count(u64), + } + + impl GenServer for BadlyBehavedTask { + type CallMsg = InMessage; + type CastMsg = (); + type OutMsg = (); + type State = (); + type Error = (); + + fn new() -> Self { + Self {} + } + + async fn handle_call( + &mut self, + _: Self::CallMsg, + _: &GenServerHandle, + _: &mut Self::State, + ) -> CallResponse { + CallResponse::Stop(()) + } + + async fn handle_cast( + &mut self, + _: Self::CastMsg, + _: &GenServerHandle, + _: &mut Self::State, + ) -> CastResponse { + rt::sleep(Duration::from_millis(20)).await; + thread::sleep(Duration::from_secs(2)); + CastResponse::Stop + } + } + + struct WellBehavedTask; + + #[derive(Clone)] + struct CountState { + pub count: u64, + } + + impl GenServer for WellBehavedTask { + type CallMsg = InMessage; + type CastMsg = (); + type OutMsg = OutMsg; + type State = CountState; + type Error = (); + + fn new() -> Self { + Self {} + } + + async fn handle_call( + &mut self, + message: Self::CallMsg, + _: &GenServerHandle, + state: &mut Self::State, + ) -> CallResponse { + match message { + InMessage::GetCount => CallResponse::Reply(OutMsg::Count(state.count)), + InMessage::Stop => CallResponse::Stop(OutMsg::Count(state.count)), + } + } + + async fn handle_cast( + &mut self, + _: Self::CastMsg, + handle: &GenServerHandle, + state: &mut Self::State, + ) -> CastResponse { + state.count += 1; + println!("{:?}: good still alive", thread::current().id()); + send_after(Duration::from_millis(100), handle.to_owned(), ()); + CastResponse::NoReply + } + } + + #[test] + pub fn badly_behaved_thread_non_blocking() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let mut badboy = BadlyBehavedTask::start(()); + let _ = badboy.cast(()).await; + let mut goodboy = WellBehavedTask::start(CountState { count: 0 }); + let _ = goodboy.cast(()).await; + rt::sleep(Duration::from_secs(1)).await; + let count = goodboy.call(InMessage::GetCount).await.unwrap(); + + match count { + OutMsg::Count(num) => { + assert_ne!(num, 10); + } + } + goodboy.call(InMessage::Stop).await.unwrap(); + }); + } + + #[test] + pub fn badly_behaved_thread() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let mut badboy = BadlyBehavedTask::start_blocking(()); + let _ = badboy.cast(()).await; + let mut goodboy = WellBehavedTask::start(CountState { count: 0 }); + let _ = goodboy.cast(()).await; + rt::sleep(Duration::from_secs(1)).await; + let count = goodboy.call(InMessage::GetCount).await.unwrap(); + + match count { + OutMsg::Count(num) => { + assert_eq!(num, 10); + } + } + goodboy.call(InMessage::Stop).await.unwrap(); + }); + } +} diff --git a/concurrency/src/threads/gen_server.rs b/concurrency/src/threads/gen_server.rs index 6ce3a64..1541c43 100644 --- a/concurrency/src/threads/gen_server.rs +++ b/concurrency/src/threads/gen_server.rs @@ -98,6 +98,12 @@ where GenServerHandle::new(initial_state) } + /// We copy the same interface as tasks, but all threads can work + /// while blocking by default + fn start_blocking(initial_state: Self::State) -> GenServerHandle { + GenServerHandle::new(initial_state) + } + fn run( &mut self, handle: &GenServerHandle, diff --git a/examples/blocking_genserver/Cargo.toml b/examples/blocking_genserver/Cargo.toml new file mode 100644 index 0000000..e09f82a --- /dev/null +++ b/examples/blocking_genserver/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "blocking_genserver" +version = "0.1.0" +edition = "2024" + +[dependencies] +spawned-rt = { workspace = true } +spawned-concurrency = { workspace = true } +tracing = { workspace = true } + +[[bin]] +name = "blocking_genserver" +path = "main.rs" diff --git a/examples/blocking_genserver/main.rs b/examples/blocking_genserver/main.rs new file mode 100644 index 0000000..9a2b832 --- /dev/null +++ b/examples/blocking_genserver/main.rs @@ -0,0 +1,121 @@ +use spawned_rt::tasks as rt; +use std::time::Duration; +use std::{process::exit, thread}; + +use spawned_concurrency::tasks::{ + CallResponse, CastResponse, GenServer, GenServerHandle, send_after, +}; + +// We test a scenario with a badly behaved task +struct BadlyBehavedTask; + +#[derive(Clone)] +pub enum InMessage { + GetCount, + Stop, +} +#[derive(Clone)] +pub enum OutMsg { + Count(u64), +} + +impl GenServer for BadlyBehavedTask { + type CallMsg = InMessage; + type CastMsg = (); + type OutMsg = (); + type State = (); + type Error = (); + + fn new() -> Self { + Self {} + } + + async fn handle_call( + &mut self, + _: Self::CallMsg, + _: &GenServerHandle, + _: &mut Self::State, + ) -> CallResponse { + CallResponse::Stop(()) + } + + async fn handle_cast( + &mut self, + _: Self::CastMsg, + _: &GenServerHandle, + _: &mut Self::State, + ) -> CastResponse { + rt::sleep(Duration::from_millis(20)).await; + loop { + println!("{:?}: bad still alive", thread::current().id()); + thread::sleep(Duration::from_millis(50)); + } + } +} + +struct WellBehavedTask; + +#[derive(Clone)] +struct CountState { + pub count: u64, +} + +impl GenServer for WellBehavedTask { + type CallMsg = InMessage; + type CastMsg = (); + type OutMsg = OutMsg; + type State = CountState; + type Error = (); + + fn new() -> Self { + Self {} + } + + async fn handle_call( + &mut self, + message: Self::CallMsg, + _: &GenServerHandle, + state: &mut Self::State, + ) -> CallResponse { + match message { + InMessage::GetCount => CallResponse::Reply(OutMsg::Count(state.count)), + InMessage::Stop => CallResponse::Stop(OutMsg::Count(state.count)), + } + } + + async fn handle_cast( + &mut self, + _: Self::CastMsg, + handle: &GenServerHandle, + state: &mut Self::State, + ) -> CastResponse { + state.count += 1; + println!("{:?}: good still alive", thread::current().id()); + send_after(Duration::from_millis(100), handle.to_owned(), ()); + CastResponse::NoReply + } +} + +/// Example of start_blocking to fix issues #8 https://github.com/lambdaclass/spawned/issues/8 +/// Tasks that block can block the entire tokio runtime (and other cooperative multitasking models) +/// To fix this we implement start_blocking, which under the hood launches a new thread to deal with the issue +pub fn main() { + rt::run(async move { + // If we change BadlyBehavedTask to start instead, it can stop the entire program + let mut badboy = BadlyBehavedTask::start_blocking(()); + let _ = badboy.cast(()).await; + let mut goodboy = WellBehavedTask::start(CountState { count: 0 }); + let _ = goodboy.cast(()).await; + rt::sleep(Duration::from_secs(1)).await; + let count = goodboy.call(InMessage::GetCount).await.unwrap(); + + match count { + OutMsg::Count(num) => { + assert!(num == 10); + } + } + + goodboy.call(InMessage::Stop).await.unwrap(); + exit(0); + }) +} diff --git a/rt/src/tasks/mod.rs b/rt/src/tasks/mod.rs index 5cb9a41..7cf9373 100644 --- a/rt/src/tasks/mod.rs +++ b/rt/src/tasks/mod.rs @@ -9,12 +9,14 @@ mod tokio; +use ::tokio::runtime::Handle; + use crate::tracing::init_tracing; pub use crate::tasks::tokio::mpsc; pub use crate::tasks::tokio::oneshot; pub use crate::tasks::tokio::sleep; -pub use crate::tasks::tokio::{spawn, JoinHandle, Runtime}; +pub use crate::tasks::tokio::{spawn, spawn_blocking, JoinHandle, Runtime}; use std::future::Future; pub fn run(future: F) -> F::Output { @@ -23,3 +25,7 @@ pub fn run(future: F) -> F::Output { let rt = Runtime::new().unwrap(); rt.block_on(future) } + +pub fn block_on(future: F) -> F::Output { + Handle::current().block_on(future) +} diff --git a/rt/src/tasks/tokio/mod.rs b/rt/src/tasks/tokio/mod.rs index 7d7ba9a..51a3877 100644 --- a/rt/src/tasks/tokio/mod.rs +++ b/rt/src/tasks/tokio/mod.rs @@ -4,6 +4,7 @@ pub mod oneshot; pub use tokio::{ runtime::Runtime, - task::{spawn, JoinHandle}, + task::{spawn, spawn_blocking, JoinHandle}, time::sleep, + test, }; diff --git a/rt/src/threads/mod.rs b/rt/src/threads/mod.rs index cd8b543..29dc6c0 100644 --- a/rt/src/threads/mod.rs +++ b/rt/src/threads/mod.rs @@ -20,3 +20,12 @@ pub fn block_on(future: F) -> F::Output { let rt = Runtime::new().unwrap(); rt.block_on(future) } + +/// Spawn blocking is the same as spawn for pure threaded usage. +pub fn spawn_blocking(f: F) -> JoinHandle +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + spawn(f) +}