diff --git a/openraft/src/async_runtime.rs b/openraft/src/async_runtime.rs index 15ae881fa..5e9c73e2a 100644 --- a/openraft/src/async_runtime.rs +++ b/openraft/src/async_runtime.rs @@ -65,9 +65,6 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static /// Check if the [`Self::JoinError`] is `panic`. fn is_panic(join_error: &Self::JoinError) -> bool; - /// Abort the task associated with the supplied join handle. - fn abort(join_handle: &Self::JoinHandle); - /// Get the random number generator to use for generating random numbers. /// /// # Note @@ -131,11 +128,6 @@ impl AsyncRuntime for TokioRuntime { join_error.is_panic() } - #[inline] - fn abort(join_handle: &Self::JoinHandle) { - join_handle.abort(); - } - #[inline] fn thread_rng() -> Self::ThreadLocalRng { rand::thread_rng() diff --git a/openraft/src/core/tick.rs b/openraft/src/core/tick.rs index e25c5af62..7970a746a 100644 --- a/openraft/src/core/tick.rs +++ b/openraft/src/core/tick.rs @@ -3,14 +3,19 @@ use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::Arc; +use std::sync::Mutex; use std::time::Duration; +use futures::future::Either; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tracing::Instrument; use tracing::Level; use tracing::Span; use crate::core::notify::Notify; +use crate::type_config::alias::AsyncRuntimeOf; +use crate::type_config::alias::JoinHandleOf; use crate::AsyncRuntime; use crate::Instant; use crate::RaftTypeConfig; @@ -31,7 +36,9 @@ pub(crate) struct TickHandle where C: RaftTypeConfig { enabled: Arc, - join_handle: ::JoinHandle<()>, + cancel: Mutex>>, + #[allow(dead_code)] + join_handle: JoinHandleOf, } impl Tick @@ -44,21 +51,43 @@ where C: RaftTypeConfig enabled: enabled.clone(), tx, }; - let join_handle = C::AsyncRuntime::spawn(this.tick_loop().instrument(tracing::span!( + + let (cancel, cancel_rx) = oneshot::channel(); + + let join_handle = AsyncRuntimeOf::::spawn(this.tick_loop(cancel_rx).instrument(tracing::span!( parent: &Span::current(), Level::DEBUG, "tick" ))); - TickHandle { enabled, join_handle } + TickHandle { + enabled, + cancel: Mutex::new(Some(cancel)), + join_handle, + } } - pub(crate) async fn tick_loop(self) { + pub(crate) async fn tick_loop(self, mut cancel_rx: oneshot::Receiver<()>) { let mut i = 0; + + let mut cancel = std::pin::pin!(cancel_rx); + loop { i += 1; let at = ::Instant::now() + self.interval; - C::AsyncRuntime::sleep_until(at).await; + let mut sleep_fut = AsyncRuntimeOf::::sleep_until(at); + let sleep_fut = std::pin::pin!(sleep_fut); + let cancel_fut = cancel.as_mut(); + + match futures::future::select(cancel_fut, sleep_fut).await { + Either::Left((_canceled, _)) => { + tracing::info!("TickLoop received cancel signal, quit"); + return; + } + Either::Right((_, _)) => { + // sleep done + } + } if !self.enabled.load(Ordering::Relaxed) { i -= 1; @@ -84,6 +113,69 @@ where C: RaftTypeConfig } pub(crate) async fn shutdown(&self) { - C::AsyncRuntime::abort(&self.join_handle); + let got = { + let mut x = self.cancel.lock().unwrap(); + x.take() + }; + + if let Some(cancel) = got { + let send_res = cancel.send(()); + tracing::info!("Timer cancel signal is sent, result is ok: {}", send_res.is_ok()); + } else { + tracing::info!("Timer cancel signal is already sent"); + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use tokio::time::Duration; + + use crate::core::Tick; + use crate::type_config::alias::AsyncRuntimeOf; + use crate::AsyncRuntime; + use crate::RaftTypeConfig; + use crate::TokioRuntime; + + #[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd)] + #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] + pub(crate) struct TickUTConfig {} + impl RaftTypeConfig for TickUTConfig { + type D = (); + type R = (); + type NodeId = u64; + type Node = (); + type Entry = crate::Entry; + type SnapshotData = Cursor>; + type AsyncRuntime = TokioRuntime; + } + + // AsyncRuntime::spawn is `spawn_local` with singlethreaded enabled. + // It will result in a panic: + // `spawn_local` called from outside of a `task::LocalSet`. + #[cfg(not(feature = "singlethreaded"))] + #[tokio::test] + async fn test_shutdown() -> anyhow::Result<()> { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let th = Tick::::spawn(Duration::from_millis(100), tx, true); + + AsyncRuntimeOf::::sleep(Duration::from_millis(500)).await; + th.shutdown().await; + AsyncRuntimeOf::::sleep(Duration::from_millis(500)).await; + + let mut received = vec![]; + while let Some(x) = rx.recv().await { + received.push(x); + } + + assert!( + received.len() < 10, + "no more tick will be received after shutdown: {}", + received.len() + ); + + Ok(()) } }