diff --git a/compio-runtime/src/runtime/mod.rs b/compio-runtime/src/runtime/mod.rs index 717995e5..e9dcba61 100644 --- a/compio-runtime/src/runtime/mod.rs +++ b/compio-runtime/src/runtime/mod.rs @@ -33,7 +33,7 @@ mod send_wrapper; use send_wrapper::SendWrapper; #[cfg(feature = "time")] -use crate::runtime::time::{TimerFuture, TimerRuntime}; +use crate::runtime::time::{TimerFuture, TimerKey, TimerRuntime}; use crate::{BufResult, affinity::bind_to_cpu_set, runtime::op::OpFuture}; scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime); @@ -313,7 +313,7 @@ impl Runtime { } #[cfg(feature = "time")] - pub(crate) fn cancel_timer(&self, key: usize) { + pub(crate) fn cancel_timer(&self, key: &TimerKey) { self.timer_runtime.borrow_mut().cancel(key); } @@ -331,16 +331,16 @@ impl Runtime { } #[cfg(feature = "time")] - pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> { + pub(crate) fn poll_timer(&self, cx: &mut Context, key: &TimerKey) -> Poll<()> { instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key); let mut timer_runtime = self.timer_runtime.borrow_mut(); - if !timer_runtime.is_completed(key) { + if timer_runtime.remove_completed(key) { + debug!("ready"); + Poll::Ready(()) + } else { debug!("pending"); timer_runtime.update_waker(key, cx.waker().clone()); Poll::Pending - } else { - debug!("ready"); - Poll::Ready(()) } } diff --git a/compio-runtime/src/runtime/time.rs b/compio-runtime/src/runtime/time.rs index e96b6b44..021513e3 100644 --- a/compio-runtime/src/runtime/time.rs +++ b/compio-runtime/src/runtime/time.rs @@ -1,15 +1,13 @@ use std::{ - cmp::Reverse, - collections::BinaryHeap, + collections::BTreeMap, future::Future, marker::PhantomData, + mem::replace, pin::Pin, task::{Context, Poll, Waker}, time::{Duration, Instant}, }; -use slab::Slab; - use crate::runtime::Runtime; pub(crate) enum FutureState { @@ -23,118 +21,101 @@ impl Default for FutureState { } } -#[derive(Debug)] -struct TimerEntry { - key: usize, - delay: Duration, -} - -impl PartialEq for TimerEntry { - fn eq(&self, other: &Self) -> bool { - self.delay == other.delay - } -} - -impl Eq for TimerEntry {} - -impl PartialOrd for TimerEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for TimerEntry { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.delay.cmp(&other.delay) - } +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct TimerKey { + deadline: Instant, + key: u64, + _local_marker: PhantomData<*const ()>, } pub struct TimerRuntime { - time: Instant, - tasks: Slab, - wheel: BinaryHeap>, + key: u64, + wheel: BTreeMap, } impl TimerRuntime { pub fn new() -> Self { Self { - time: Instant::now(), - tasks: Slab::default(), - wheel: BinaryHeap::default(), + key: 0, + wheel: BTreeMap::default(), } } - pub fn is_completed(&self, key: usize) -> bool { - self.tasks + /// If the timer is completed, remove it and return `true`. Otherwise return + /// `false` and keep it. + pub fn remove_completed(&mut self, key: &TimerKey) -> bool { + let completed = self + .wheel .get(key) .map(|state| matches!(state, FutureState::Completed)) - .unwrap_or_default() + .unwrap_or_default(); + if completed { + self.wheel.remove(key); + } + completed } - pub fn insert(&mut self, instant: Instant) -> Option { - let delay = instant - self.time; - if delay <= self.time.elapsed() { + /// Insert a new timer. If the deadline is in the past, return `None`. + pub fn insert(&mut self, deadline: Instant) -> Option { + if deadline <= Instant::now() { return None; } - let key = self.tasks.insert(FutureState::Active(None)); - let entry = TimerEntry { key, delay }; - self.wheel.push(Reverse(entry)); + let key = TimerKey { + key: self.key, + deadline, + _local_marker: PhantomData, + }; + self.wheel.insert(key, FutureState::default()); + + self.key += 1; + Some(key) } - pub fn update_waker(&mut self, key: usize, waker: Waker) { - if let Some(w) = self.tasks.get_mut(key) { + /// Update the waker for a timer. + pub fn update_waker(&mut self, key: &TimerKey, waker: Waker) { + if let Some(w) = self.wheel.get_mut(key) { *w = FutureState::Active(Some(waker)); } } - pub fn cancel(&mut self, key: usize) { - self.tasks.remove(key); + /// Cancel a timer. + pub fn cancel(&mut self, key: &TimerKey) { + self.wheel.remove(key); } + /// Get the minimum timeout duration for the next poll. pub fn min_timeout(&self) -> Option { - self.wheel.peek().map(|entry| { - let elapsed = self.time.elapsed(); - if entry.0.delay > elapsed { - entry.0.delay - elapsed - } else { - Duration::ZERO - } + self.wheel.first_key_value().map(|(key, _)| { + let now = Instant::now(); + key.deadline.saturating_duration_since(now) }) } + /// Wake all the timer futures that have reached their deadline. pub fn wake(&mut self) { if self.wheel.is_empty() { return; } - let elapsed = self.time.elapsed(); - while let Some(entry) = self.wheel.pop() { - if entry.0.delay <= elapsed { - if let Some(state) = self.tasks.get_mut(entry.0.key) { - let old_state = std::mem::replace(state, FutureState::Completed); - if let FutureState::Active(Some(waker)) = old_state { - waker.wake(); - } + + let now = Instant::now(); + + self.wheel + .iter_mut() + .take_while(|(k, _)| k.deadline <= now) + .for_each(|(_, v)| { + if let FutureState::Active(Some(waker)) = replace(v, FutureState::Completed) { + waker.wake(); } - } else { - self.wheel.push(entry); - break; - } - } + }); } } -pub struct TimerFuture { - key: usize, - _local_marker: PhantomData<*const ()>, -} +pub struct TimerFuture(TimerKey); impl TimerFuture { - pub fn new(key: usize) -> Self { - Self { - key, - _local_marker: PhantomData, - } + pub fn new(key: TimerKey) -> Self { + Self(key) } } @@ -142,13 +123,13 @@ impl Future for TimerFuture { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Runtime::with_current(|r| r.poll_timer(cx, self.key)) + Runtime::with_current(|r| r.poll_timer(cx, &self.0)) } } impl Drop for TimerFuture { fn drop(&mut self) { - Runtime::with_current(|r| r.cancel_timer(self.key)); + Runtime::with_current(|r| r.cancel_timer(&self.0)); } }