diff --git a/Cargo.lock b/Cargo.lock index b9f9150..9d6c50d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1210,6 +1210,7 @@ version = "0.1.0" dependencies = [ "crossbeam", "tokio", + "tokio-util", "tracing", "tracing-subscriber", ] diff --git a/Cargo.toml b/Cargo.toml index 725d92b..33c635a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,8 @@ members = [ "examples/ping_pong", "examples/ping_pong_threads", "examples/updater", - "examples/updater_threads", "examples/blocking_genserver", + "examples/updater_threads", + "examples/blocking_genserver", ] [workspace.dependencies] diff --git a/concurrency/src/tasks/gen_server.rs b/concurrency/src/tasks/gen_server.rs index f800a2f..a72a52a 100644 --- a/concurrency/src/tasks/gen_server.rs +++ b/concurrency/src/tasks/gen_server.rs @@ -105,8 +105,8 @@ pub trait GenServer where Self: Send + Sized, { - type CallMsg: Send + Sized; - type CastMsg: Send + Sized; + type CallMsg: Clone + Send + Sized + Sync; + type CastMsg: Clone + Send + Sized + Sync; type OutMsg: Send + Sized; type State: Clone + Send; type Error: Debug + Send; @@ -259,6 +259,7 @@ where #[cfg(test)] mod tests { + use super::*; use crate::tasks::send_after; use std::{thread, time::Duration}; diff --git a/concurrency/src/tasks/mod.rs b/concurrency/src/tasks/mod.rs index 93bcedb..9cc76e6 100644 --- a/concurrency/src/tasks/mod.rs +++ b/concurrency/src/tasks/mod.rs @@ -6,7 +6,10 @@ mod gen_server; mod process; mod time; +#[cfg(test)] +mod timer_tests; + pub use error::GenServerError; pub use gen_server::{CallResponse, CastResponse, GenServer, GenServerHandle, GenServerInMsg}; pub use process::{send, Process, ProcessInfo}; -pub use time::send_after; +pub use time::{send_after, send_interval}; diff --git a/concurrency/src/tasks/time.rs b/concurrency/src/tasks/time.rs index 5979323..f26118b 100644 --- a/concurrency/src/tasks/time.rs +++ b/concurrency/src/tasks/time.rs @@ -1,21 +1,72 @@ +use futures::future::select; use std::time::Duration; -use spawned_rt::tasks::{self as rt, JoinHandle}; +use spawned_rt::tasks::{self as rt, CancellationToken, JoinHandle}; use super::{GenServer, GenServerHandle}; +pub struct TimerHandle { + pub join_handle: JoinHandle<()>, + pub cancellation_token: CancellationToken, +} + // Sends a message after a given period to the specified GenServer. The task terminates // once the send has completed pub fn send_after( period: Duration, mut handle: GenServerHandle, message: T::CastMsg, -) -> JoinHandle<()> +) -> TimerHandle +where + T: GenServer + 'static, +{ + let cancellation_token = CancellationToken::new(); + let cloned_token = cancellation_token.clone(); + let join_handle = rt::spawn(async move { + let _ = select( + Box::pin(cloned_token.cancelled()), + Box::pin(async { + rt::sleep(period).await; + let _ = handle.cast(message.clone()).await; + }), + ) + .await; + }); + TimerHandle { + join_handle, + cancellation_token, + } +} + +// Sends a message to the specified GenServe repeatedly after `Time` milliseconds. +pub fn send_interval( + period: Duration, + mut handle: GenServerHandle, + message: T::CastMsg, +) -> TimerHandle where T: GenServer + 'static, { - rt::spawn(async move { - rt::sleep(period).await; - let _ = handle.cast(message).await; - }) + let cancellation_token = CancellationToken::new(); + let cloned_token = cancellation_token.clone(); + let join_handle = rt::spawn(async move { + loop { + let result = select( + Box::pin(cloned_token.cancelled()), + Box::pin(async { + rt::sleep(period).await; + let _ = handle.cast(message.clone()).await; + }), + ) + .await; + match result { + futures::future::Either::Left(_) => break, + futures::future::Either::Right(_) => (), + } + } + }); + TimerHandle { + join_handle, + cancellation_token, + } } diff --git a/concurrency/src/tasks/timer_tests.rs b/concurrency/src/tasks/timer_tests.rs new file mode 100644 index 0000000..d805c82 --- /dev/null +++ b/concurrency/src/tasks/timer_tests.rs @@ -0,0 +1,248 @@ +use crate::tasks::{send_interval, CallResponse, CastResponse, GenServer, GenServerHandle}; +use spawned_rt::tasks::{self as rt, CancellationToken}; +use std::time::Duration; + +use super::send_after; + +type RepeaterHandle = GenServerHandle; + +#[derive(Clone)] +struct RepeaterState { + pub(crate) count: i32, + pub(crate) cancellation_token: Option, +} + +#[derive(Clone)] +enum RepeaterCastMessage { + Inc, + StopTimer, +} + +#[derive(Clone)] +enum RepeaterCallMessage { + GetCount, +} + +#[derive(PartialEq, Debug)] +enum RepeaterOutMessage { + Count(i32), +} + +struct Repeater; + +impl Repeater { + pub async fn stop_timer(server: &mut RepeaterHandle) -> Result<(), ()> { + server + .cast(RepeaterCastMessage::StopTimer) + .await + .map_err(|_| ()) + } + + pub async fn get_count(server: &mut RepeaterHandle) -> Result { + server + .call(RepeaterCallMessage::GetCount) + .await + .map_err(|_| ()) + } +} + +impl GenServer for Repeater { + type CallMsg = RepeaterCallMessage; + type CastMsg = RepeaterCastMessage; + type OutMsg = RepeaterOutMessage; + type State = RepeaterState; + type Error = (); + + fn new() -> Self { + Self + } + + async fn init( + &mut self, + handle: &RepeaterHandle, + mut state: Self::State, + ) -> Result { + let timer = send_interval( + Duration::from_millis(100), + handle.clone(), + RepeaterCastMessage::Inc, + ); + state.cancellation_token = Some(timer.cancellation_token); + Ok(state) + } + + async fn handle_call( + &mut self, + _message: Self::CallMsg, + _handle: &RepeaterHandle, + state: Self::State, + ) -> CallResponse { + let count = state.count; + CallResponse::Reply(state, RepeaterOutMessage::Count(count)) + } + + async fn handle_cast( + &mut self, + message: Self::CastMsg, + _handle: &GenServerHandle, + mut state: Self::State, + ) -> CastResponse { + match message { + RepeaterCastMessage::Inc => { + state.count += 1; + } + RepeaterCastMessage::StopTimer => { + if let Some(ct) = state.cancellation_token.clone() { + ct.cancel() + }; + } + }; + CastResponse::NoReply(state) + } +} + +#[test] +pub fn test_send_interval_and_cancellation() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + // Start a Repeater + let mut repeater = Repeater::start(RepeaterState { + count: 0, + cancellation_token: None, + }); + + // Wait for 1 second + rt::sleep(Duration::from_secs(1)).await; + + // Check count + let count = Repeater::get_count(&mut repeater).await.unwrap(); + + // 9 messages in 1 second (after first 100 milliseconds sleep) + assert_eq!(RepeaterOutMessage::Count(9), count); + + // Pause timer + Repeater::stop_timer(&mut repeater).await.unwrap(); + + // Wait another second + rt::sleep(Duration::from_secs(1)).await; + + // Check count again + let count2 = Repeater::get_count(&mut repeater).await.unwrap(); + + // As timer was paused, count should remain at 9 + assert_eq!(RepeaterOutMessage::Count(9), count2); + }); +} + +type DelayedHandle = GenServerHandle; + +#[derive(Clone)] +struct DelayedState { + pub(crate) count: i32, +} + +#[derive(Clone)] +enum DelayedCastMessage { + Inc, +} + +#[derive(Clone)] +enum DelayedCallMessage { + GetCount, +} + +#[derive(PartialEq, Debug)] +enum DelayedOutMessage { + Count(i32), +} + +struct Delayed; + +impl Delayed { + pub async fn get_count(server: &mut DelayedHandle) -> Result { + server + .call(DelayedCallMessage::GetCount) + .await + .map_err(|_| ()) + } +} + +impl GenServer for Delayed { + type CallMsg = DelayedCallMessage; + type CastMsg = DelayedCastMessage; + type OutMsg = DelayedOutMessage; + type State = DelayedState; + type Error = (); + + fn new() -> Self { + Self + } + + async fn handle_call( + &mut self, + _message: Self::CallMsg, + _handle: &DelayedHandle, + state: Self::State, + ) -> CallResponse { + let count = state.count; + CallResponse::Reply(state, DelayedOutMessage::Count(count)) + } + + async fn handle_cast( + &mut self, + message: Self::CastMsg, + _handle: &DelayedHandle, + mut state: Self::State, + ) -> CastResponse { + match message { + DelayedCastMessage::Inc => { + state.count += 1; + } + }; + CastResponse::NoReply(state) + } +} + +#[test] +pub fn test_send_after_and_cancellation() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + // Start a Delayed + let mut repeater = Delayed::start(DelayedState { count: 0 }); + + // Set a just once timed message + let _ = send_after( + Duration::from_millis(100), + repeater.clone(), + DelayedCastMessage::Inc, + ); + + // Wait for 200 milliseconds + rt::sleep(Duration::from_millis(200)).await; + + // Check count + let count = Delayed::get_count(&mut repeater).await.unwrap(); + + // Only one message (no repetition) + assert_eq!(DelayedOutMessage::Count(1), count); + + // New timer + let timer = send_after( + Duration::from_millis(100), + repeater.clone(), + DelayedCastMessage::Inc, + ); + + // Cancel the new timer before timeout + timer.cancellation_token.cancel(); + + // Wait another 200 milliseconds + rt::sleep(Duration::from_millis(200)).await; + + // Check count again + let count2 = Delayed::get_count(&mut repeater).await.unwrap(); + + // As timer was cancelled, count should remain at 1 + assert_eq!(DelayedOutMessage::Count(1), count2); + }); +} diff --git a/concurrency/src/threads/gen_server.rs b/concurrency/src/threads/gen_server.rs index 912067b..9d58754 100644 --- a/concurrency/src/threads/gen_server.rs +++ b/concurrency/src/threads/gen_server.rs @@ -83,8 +83,8 @@ pub trait GenServer where Self: Send + Sized, { - type CallMsg: Send + Sized; - type CastMsg: Send + Sized; + type CallMsg: Clone + Send + Sized; + type CastMsg: Clone + Send + Sized; type OutMsg: Send + Sized; type State: Clone + Send; type Error: Debug; diff --git a/concurrency/src/threads/mod.rs b/concurrency/src/threads/mod.rs index 858dd52..0b8f4b2 100644 --- a/concurrency/src/threads/mod.rs +++ b/concurrency/src/threads/mod.rs @@ -6,6 +6,9 @@ mod gen_server; mod process; mod time; +#[cfg(test)] +mod timer_tests; + pub use gen_server::{CallResponse, CastResponse, GenServer, GenServerHandle, GenServerInMsg}; pub use process::{send, Process, ProcessInfo}; -pub use time::send_after; +pub use time::{send_after, send_interval}; diff --git a/concurrency/src/threads/time.rs b/concurrency/src/threads/time.rs index ec42b1b..3d47c05 100644 --- a/concurrency/src/threads/time.rs +++ b/concurrency/src/threads/time.rs @@ -1,21 +1,59 @@ use std::time::Duration; -use spawned_rt::threads::{self as rt, JoinHandle}; +use spawned_rt::threads::{self as rt, CancellationToken, JoinHandle}; use super::{GenServer, GenServerHandle}; +pub struct TimerHandle { + pub join_handle: JoinHandle<()>, + pub cancellation_token: CancellationToken, +} + // Sends a message after a given period to the specified GenServer. The task terminates // once the send has completed pub fn send_after( period: Duration, mut handle: GenServerHandle, message: T::CastMsg, -) -> JoinHandle<()> +) -> TimerHandle +where + T: GenServer + 'static, +{ + let cancellation_token = CancellationToken::new(); + let mut cloned_token = cancellation_token.clone(); + let join_handle = rt::spawn(move || { + rt::sleep(period); + if !cloned_token.is_cancelled() { + let _ = handle.cast(message); + }; + }); + TimerHandle { + join_handle, + cancellation_token, + } +} + +// Sends a message to the specified GenServe repeatedly after `Time` milliseconds. +pub fn send_interval( + period: Duration, + mut handle: GenServerHandle, + message: T::CastMsg, +) -> TimerHandle where T: GenServer + 'static, { - rt::spawn(move || { + let cancellation_token = CancellationToken::new(); + let mut cloned_token = cancellation_token.clone(); + let join_handle = rt::spawn(move || loop { rt::sleep(period); - let _ = handle.cast(message); - }) + if cloned_token.is_cancelled() { + break; + } else { + let _ = handle.cast(message.clone()); + }; + }); + TimerHandle { + join_handle, + cancellation_token, + } } diff --git a/concurrency/src/threads/timer_tests.rs b/concurrency/src/threads/timer_tests.rs new file mode 100644 index 0000000..6b3b8a4 --- /dev/null +++ b/concurrency/src/threads/timer_tests.rs @@ -0,0 +1,233 @@ +use crate::threads::{send_interval, CallResponse, CastResponse, GenServer, GenServerHandle}; +use spawned_rt::threads::{self as rt, CancellationToken}; +use std::time::Duration; + +use super::send_after; + +type RepeaterHandle = GenServerHandle; + +#[derive(Clone)] +struct RepeaterState { + pub(crate) count: i32, + pub(crate) cancellation_token: Option, +} + +#[derive(Clone)] +enum RepeaterCastMessage { + Inc, + StopTimer, +} + +#[derive(Clone)] +enum RepeaterCallMessage { + GetCount, +} + +#[derive(PartialEq, Debug)] +enum RepeaterOutMessage { + Count(i32), +} + +struct Repeater; + +impl Repeater { + pub fn stop_timer(server: &mut RepeaterHandle) -> Result<(), ()> { + server.cast(RepeaterCastMessage::StopTimer).map_err(|_| ()) + } + + pub fn get_count(server: &mut RepeaterHandle) -> Result { + server.call(RepeaterCallMessage::GetCount).map_err(|_| ()) + } +} + +impl GenServer for Repeater { + type CallMsg = RepeaterCallMessage; + type CastMsg = RepeaterCastMessage; + type OutMsg = RepeaterOutMessage; + type State = RepeaterState; + type Error = (); + + fn new() -> Self { + Self + } + + fn init( + &mut self, + handle: &RepeaterHandle, + mut state: Self::State, + ) -> Result { + let timer = send_interval( + Duration::from_millis(100), + handle.clone(), + RepeaterCastMessage::Inc, + ); + state.cancellation_token = Some(timer.cancellation_token); + Ok(state) + } + + fn handle_call( + &mut self, + _message: Self::CallMsg, + _handle: &RepeaterHandle, + state: Self::State, + ) -> CallResponse { + let count = state.count; + CallResponse::Reply(state, RepeaterOutMessage::Count(count)) + } + + fn handle_cast( + &mut self, + message: Self::CastMsg, + _handle: &GenServerHandle, + mut state: Self::State, + ) -> CastResponse { + match message { + RepeaterCastMessage::Inc => { + state.count += 1; + } + RepeaterCastMessage::StopTimer => { + if let Some(mut ct) = state.cancellation_token.clone() { + ct.cancel() + }; + } + }; + CastResponse::NoReply(state) + } +} + +#[test] +pub fn test_send_interval_and_cancellation() { + // Start a Repeater + let mut repeater = Repeater::start(RepeaterState { + count: 0, + cancellation_token: None, + }); + + // Wait for 1 second + rt::sleep(Duration::from_secs(1)); + + // Check count + let count = Repeater::get_count(&mut repeater).unwrap(); + + // 9 messages in 1 second (after first 100 milliseconds sleep) + assert_eq!(RepeaterOutMessage::Count(9), count); + + // Pause timer + Repeater::stop_timer(&mut repeater).unwrap(); + + // Wait another second + rt::sleep(Duration::from_secs(1)); + + // Check count again + let count2 = Repeater::get_count(&mut repeater).unwrap(); + + // As timer was paused, count should remain at 9 + assert_eq!(RepeaterOutMessage::Count(9), count2); +} + +type DelayedHandle = GenServerHandle; + +#[derive(Clone)] +struct DelayedState { + pub(crate) count: i32, +} + +#[derive(Clone)] +enum DelayedCastMessage { + Inc, +} + +#[derive(Clone)] +enum DelayedCallMessage { + GetCount, +} + +#[derive(PartialEq, Debug)] +enum DelayedOutMessage { + Count(i32), +} + +struct Delayed; + +impl Delayed { + pub fn get_count(server: &mut DelayedHandle) -> Result { + server.call(DelayedCallMessage::GetCount).map_err(|_| ()) + } +} + +impl GenServer for Delayed { + type CallMsg = DelayedCallMessage; + type CastMsg = DelayedCastMessage; + type OutMsg = DelayedOutMessage; + type State = DelayedState; + type Error = (); + + fn new() -> Self { + Self + } + + fn handle_call( + &mut self, + _message: Self::CallMsg, + _handle: &DelayedHandle, + state: Self::State, + ) -> CallResponse { + let count = state.count; + CallResponse::Reply(state, DelayedOutMessage::Count(count)) + } + + fn handle_cast( + &mut self, + message: Self::CastMsg, + _handle: &DelayedHandle, + mut state: Self::State, + ) -> CastResponse { + match message { + DelayedCastMessage::Inc => { + state.count += 1; + } + }; + CastResponse::NoReply(state) + } +} + +#[test] +pub fn test_send_after_and_cancellation() { + // Start a Delayed + let mut repeater = Delayed::start(DelayedState { count: 0 }); + + // Set a just once timed message + let _ = send_after( + Duration::from_millis(100), + repeater.clone(), + DelayedCastMessage::Inc, + ); + + // Wait for 200 milliseconds + rt::sleep(Duration::from_millis(200)); + + // Check count + let count = Delayed::get_count(&mut repeater).unwrap(); + + // Only one message (no repetition) + assert_eq!(DelayedOutMessage::Count(1), count); + + // New timer + let mut timer = send_after( + Duration::from_millis(100), + repeater.clone(), + DelayedCastMessage::Inc, + ); + + // Cancel the new timer before timeout + timer.cancellation_token.cancel(); + + // Wait another 200 milliseconds + rt::sleep(Duration::from_millis(200)); + + // Check count again + let count2 = Delayed::get_count(&mut repeater).unwrap(); + + // As timer was cancelled, count should remain at 1 + assert_eq!(DelayedOutMessage::Count(1), count2); +} diff --git a/concurrency/src/time.rs b/concurrency/src/time.rs deleted file mode 100644 index be1e2a3..0000000 --- a/concurrency/src/time.rs +++ /dev/null @@ -1,21 +0,0 @@ -use std::time::Duration; - -use spawned_rt::{self as rt, JoinHandle, mpsc::Sender}; - -use crate::{GenServer, GenServerInMsg}; - -// Sends a message after a given period to the specified GenServer. The task terminates -// once the send has completed -pub fn send_after( - period: Duration, - tx: Sender>, - message: T::InMsg, -) -> JoinHandle<()> -where - T: GenServer + 'static, -{ - rt::spawn(async move { - rt::sleep(period).await; - let _ = tx.send(GenServerInMsg::Cast { message }); - }) -} diff --git a/examples/updater/src/main.rs b/examples/updater/src/main.rs index c01f3c6..f04b3d4 100644 --- a/examples/updater/src/main.rs +++ b/examples/updater/src/main.rs @@ -14,12 +14,15 @@ use spawned_rt::tasks as rt; fn main() { rt::run(async { + tracing::info!("Starting Updater"); UpdaterServer::start(UpdateServerState { url: "https://httpbin.org/ip".to_string(), periodicity: Duration::from_millis(1000), + timer_token: None, }); // giving it some time before ending thread::sleep(Duration::from_secs(10)); + tracing::info!("Updater stopped"); }) } diff --git a/examples/updater/src/server.rs b/examples/updater/src/server.rs index 5c610a2..ad6566c 100644 --- a/examples/updater/src/server.rs +++ b/examples/updater/src/server.rs @@ -1,8 +1,9 @@ use std::time::Duration; use spawned_concurrency::tasks::{ - send_after, CallResponse, CastResponse, GenServer, GenServerHandle, + send_interval, CallResponse, CastResponse, GenServer, GenServerHandle, }; +use spawned_rt::tasks::CancellationToken; use crate::messages::{UpdaterInMessage as InMessage, UpdaterOutMessage as OutMessage}; @@ -12,6 +13,7 @@ type UpdateServerHandle = GenServerHandle; pub struct UpdateServerState { pub url: String, pub periodicity: Duration, + pub timer_token: Option, } pub struct UpdaterServer {} @@ -26,13 +28,14 @@ impl GenServer for UpdaterServer { Self {} } - // Initializing GenServer to start periodic checks + // Initializing GenServer to start periodic checks. async fn init( &mut self, handle: &GenServerHandle, - state: Self::State, + mut state: Self::State, ) -> Result { - send_after(state.periodicity, handle.clone(), InMessage::Check); + let timer = send_interval(state.periodicity, handle.clone(), InMessage::Check); + state.timer_token = Some(timer.cancellation_token); Ok(state) } @@ -48,18 +51,16 @@ impl GenServer for UpdaterServer { async fn handle_cast( &mut self, message: Self::CastMsg, - handle: &UpdateServerHandle, + _handle: &UpdateServerHandle, state: Self::State, ) -> CastResponse { match message { Self::CastMsg::Check => { - send_after(state.periodicity, handle.clone(), InMessage::Check); + //send_after(state.periodicity, handle.clone(), InMessage::Check); let url = state.url.clone(); tracing::info!("Fetching: {url}"); let resp = req(url).await; - tracing::info!("Response: {resp:?}"); - CastResponse::NoReply(state) } } diff --git a/rt/Cargo.toml b/rt/Cargo.toml index e514397..b4fbbcf 100644 --- a/rt/Cargo.toml +++ b/rt/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0.7.15" } crossbeam = { version = "0.7.3" } tracing = { workspace = true } tracing-subscriber = { workspace = true } diff --git a/rt/src/tasks/mod.rs b/rt/src/tasks/mod.rs index 7cf9373..7508d43 100644 --- a/rt/src/tasks/mod.rs +++ b/rt/src/tasks/mod.rs @@ -16,6 +16,7 @@ use crate::tracing::init_tracing; pub use crate::tasks::tokio::mpsc; pub use crate::tasks::tokio::oneshot; pub use crate::tasks::tokio::sleep; +pub use crate::tasks::tokio::CancellationToken; pub use crate::tasks::tokio::{spawn, spawn_blocking, JoinHandle, Runtime}; use std::future::Future; diff --git a/rt/src/tasks/tokio/mod.rs b/rt/src/tasks/tokio/mod.rs index 8131b27..aaf679d 100644 --- a/rt/src/tasks/tokio/mod.rs +++ b/rt/src/tasks/tokio/mod.rs @@ -7,3 +7,4 @@ pub use tokio::{ task::{spawn, spawn_blocking, JoinHandle}, time::sleep, }; +pub use tokio_util::sync::CancellationToken; diff --git a/rt/src/threads/mod.rs b/rt/src/threads/mod.rs index 29dc6c0..adcea5f 100644 --- a/rt/src/threads/mod.rs +++ b/rt/src/threads/mod.rs @@ -3,6 +3,10 @@ pub mod mpsc; pub mod oneshot; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; pub use std::{ future::Future, thread::{sleep, spawn, JoinHandle}, @@ -29,3 +33,24 @@ where { spawn(f) } + +#[derive(Clone, Default)] +pub struct CancellationToken { + is_cancelled: Arc, +} + +impl CancellationToken { + pub fn new() -> Self { + CancellationToken { + is_cancelled: Arc::new(false.into()), + } + } + + pub fn is_cancelled(&mut self) -> bool { + self.is_cancelled.fetch_and(false, Ordering::SeqCst) + } + + pub fn cancel(&mut self) { + self.is_cancelled.fetch_or(true, Ordering::SeqCst); + } +}