Skip to content

Commit

Permalink
Per-entry expiration
Browse files Browse the repository at this point in the history
- Fix a use-after-free bug in unsafe code caused by a race condition between timer
  wheel's `advance()` operation and client's `invalidate` calls.
- To make unsafe code less error prone, wrap `DeqNodes` in `ValueEntry` with
  `triomphe::Arc` to share between threads, instead of cloning them.
- Addressed an ENHANCEME TODO in the `TimerWheel`.
  • Loading branch information
tatsuya6502 committed Apr 12, 2023
1 parent dc5fd49 commit 0842409
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 87 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"ahash",
"armv",
"benmanes",
"CHECKME",
"circleci",
"CLFU",
"clippy",
Expand Down
32 changes: 17 additions & 15 deletions src/common/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,6 @@ pub(crate) struct DeqNodes<K> {
timer_node: Option<DeqNodeTimer<K>>,
}

// Hand-written `Clone` with no bound on `K`.
impl<K> Clone for DeqNodes<K> {
// Copies each of the three node fields by value. The clone therefore holds
// the same node handles as `self`; nothing the handles refer to is
// duplicated.
fn clone(&self) -> Self {
Self {
access_order_q_node: self.access_order_q_node,
write_order_q_node: self.write_order_q_node,
timer_node: self.timer_node,
}
}
}

impl<K> Default for DeqNodes<K> {
fn default() -> Self {
Self {
Expand All @@ -212,10 +202,16 @@ impl<K> Default for DeqNodes<K> {
// We need this `unsafe impl` as DeqNodes have NonNull pointers.
unsafe impl<K> Send for DeqNodes<K> {}

impl<K> DeqNodes<K> {
/// Overwrites the stored timer node handle with `timer_node`.
///
/// Passing `None` clears the handle. The previously stored value is
/// discarded, not returned.
pub(crate) fn set_timer_node(&mut self, timer_node: Option<DeqNodeTimer<K>>) {
self.timer_node = timer_node;
}
}

pub(crate) struct ValueEntry<K, V> {
pub(crate) value: V,
info: TrioArc<EntryInfo<K>>,
nodes: Mutex<DeqNodes<K>>,
nodes: TrioArc<Mutex<DeqNodes<K>>>,
}

impl<K, V> ValueEntry<K, V> {
Expand All @@ -226,19 +222,17 @@ impl<K, V> ValueEntry<K, V> {
Self {
value,
info: entry_info,
nodes: Mutex::new(DeqNodes::default()),
nodes: TrioArc::new(Mutex::new(DeqNodes::default())),
}
}

pub(crate) fn new_from(value: V, entry_info: TrioArc<EntryInfo<K>>, other: &Self) -> Self {
#[cfg(feature = "unstable-debug-counters")]
self::debug_counters::InternalGlobalDebugCounters::value_entry_created();

let nodes = (*other.nodes.lock()).clone();
Self {
value,
info: entry_info,
nodes: Mutex::new(nodes),
nodes: TrioArc::clone(&other.nodes),
}
}

Expand Down Expand Up @@ -267,6 +261,10 @@ impl<K, V> ValueEntry<K, V> {
self.info.policy_weight()
}

/// Returns a borrow of the `TrioArc` that holds this entry's
/// `Mutex<DeqNodes<K>>`.
///
/// This neither takes the lock nor clones the `TrioArc`; callers that need
/// their own handle must clone the returned reference themselves.
pub(crate) fn deq_nodes(&self) -> &TrioArc<Mutex<DeqNodes<K>>> {
&self.nodes
}

/// Locks `nodes` and returns the current access-order queue node handle by
/// value, or `None` if it is not set.
pub(crate) fn access_order_q_node(&self) -> Option<KeyDeqNodeAo<K>> {
self.nodes.lock().access_order_q_node
}
Expand Down Expand Up @@ -299,6 +297,10 @@ impl<K, V> ValueEntry<K, V> {
self.nodes.lock().timer_node = node;
}

/// Locks `nodes`, moves the timer node handle out, and leaves `None` in its
/// place.
///
/// Returns the handle that was stored, or `None` if there was none.
pub(crate) fn take_timer_node(&self) -> Option<DeqNodeTimer<K>> {
self.nodes.lock().timer_node.take()
}

pub(crate) fn unset_q_nodes(&self) {
let mut nodes = self.nodes.lock();
nodes.access_order_q_node = None;
Expand Down
75 changes: 59 additions & 16 deletions src/common/timer_wheel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ use std::{
};

use super::{
concurrent::entry_info::EntryInfo,
concurrent::{entry_info::EntryInfo, DeqNodes},
deque::{DeqNode, Deque},
time::{CheckedTimeOps, Instant},
};

use parking_lot::Mutex;
use triomphe::Arc as TrioArc;

const BUCKET_COUNTS: &[u64] = &[
Expand Down Expand Up @@ -67,16 +68,25 @@ pub(crate) enum TimerNode<K> {
Entry {
level: AtomicU8, // When unset, we use `u8::MAX`.
index: AtomicU8, // When unset, we use `u8::MAX`.
/// The `EntryInfo` of the cache entry.
entry_info: TrioArc<EntryInfo<K>>,
/// The `DeqNodes` in the `ValueEntry` of the cache entry.
deq_nodes: TrioArc<Mutex<DeqNodes<K>>>,
},
}

impl<K> TimerNode<K> {
fn new(entry_info: TrioArc<EntryInfo<K>>, level: usize, index: usize) -> Self {
fn new(
entry_info: TrioArc<EntryInfo<K>>,
deq_nodes: TrioArc<Mutex<DeqNodes<K>>>,
level: usize,
index: usize,
) -> Self {
Self::Entry {
level: AtomicU8::new(level as u8),
index: AtomicU8::new(index as u8),
entry_info,
deq_nodes,
}
}

Expand All @@ -91,6 +101,14 @@ impl<K> TimerNode<K> {
unreachable!()
}
}

/// Clears the timer node handle stored in the `DeqNodes` that this timer
/// node shares with its cache entry.
///
/// The `deq_nodes` mutex is locked while the handle is reset to `None`.
///
/// # Panics
///
/// Hits `unreachable!` when called on anything other than the `Entry`
/// variant.
pub(crate) fn unset_timer_node_in_deq_nodes(&self) {
    match self {
        Self::Entry { deq_nodes, .. } => {
            deq_nodes.lock().set_timer_node(None);
        }
        _ => {
            unreachable!();
        }
    }
}
}

/// A single slot of the timer wheel: a deque of `TimerNode`s.
type Bucket<K> = Deque<TimerNode<K>>;
Expand Down Expand Up @@ -167,12 +185,15 @@ impl<K> TimerWheel<K> {
pub(crate) fn schedule(
&mut self,
entry_info: TrioArc<EntryInfo<K>>,
deq_nodes: TrioArc<Mutex<DeqNodes<K>>>,
) -> Option<NonNull<DeqNode<TimerNode<K>>>> {
debug_assert!(self.is_enabled());

if let Some(t) = entry_info.expiration_time() {
let (level, index) = self.bucket_indices(t);
let node = Box::new(DeqNode::new(TimerNode::new(entry_info, level, index)));
let node = Box::new(DeqNode::new(TimerNode::new(
entry_info, deq_nodes, level, index,
)));
let node = self.wheels[level][index].push_back(node);
Some(node)
} else {
Expand All @@ -192,6 +213,7 @@ impl<K> TimerWheel<K> {
level,
index,
entry_info,
deq_nodes,
} = &unsafe { node.as_ref() }.element
{
if let Some(t) = entry_info.expiration_time() {
Expand All @@ -205,6 +227,7 @@ impl<K> TimerWheel<K> {
// Unset the level and index.
level.store(u8::MAX, Ordering::Release);
index.store(u8::MAX, Ordering::Release);
deq_nodes.lock().set_timer_node(None);
ReschedulingResult::Removed(unsafe { Box::from_raw(node.as_ptr()) })
}
} else {
Expand Down Expand Up @@ -237,11 +260,13 @@ impl<K> TimerWheel<K> {
unsafe fn unlink_timer(&mut self, node: NonNull<DeqNode<TimerNode<K>>>) {
let p = node.as_ref();
if let TimerNode::Entry { level, index, .. } = &p.element {
let level = level.load(Ordering::Acquire);
let index = index.load(Ordering::Acquire);
if level != u8::MAX && index != u8::MAX {
self.wheels[level as usize][index as usize].unlink(node);
let lev = level.load(Ordering::Acquire);
let idx = index.load(Ordering::Acquire);
if lev != u8::MAX && idx != u8::MAX {
self.wheels[lev as usize][idx as usize].unlink(node);
}
level.store(u8::MAX, Ordering::Release);
index.store(u8::MAX, Ordering::Release);
} else {
unreachable!();
}
Expand Down Expand Up @@ -328,17 +353,29 @@ impl<K> TimerWheel<K> {
.as_nanos() as u64
}

// Returns nano-seconds between the given `time` and the time when this timer
// wheel was created. If the `time` is earlier than other, returns zero.
// Returns nano-seconds between the given `time` and `self.origin`, the time when
// this timer wheel was created.
//
// - If the `time` is earlier than `self.origin`, returns zero.
// - If the `time` is later than `self.origin + u64::MAX`, returns `u64::MAX`,
// which is ~584 years in nanoseconds.
//
fn time_nanos(&self, time: Instant) -> u64 {
time.checked_duration_since(self.origin)
// `TryInto` will be in the prelude starting in Rust 2021 Edition.
use std::convert::TryInto;

let nanos_u128 = time
.checked_duration_since(self.origin)
// If `time` is earlier than `self.origin`, use zero. This should never
// happen in practice, as there is always some delay between the creation
// of the timer wheel and the scheduling of the first timer event. But we
// will do this just in case.
.unwrap_or_default() // Assuming `Duration::default()` returns `ZERO`.
// TODO ENHANCEME: Check overflow? (u128 -> u64)
.as_nanos() as u64
.as_nanos();

// Convert the `u128` into a `u64`. If the value is too large, use `u64::MAX`
// (~584 years)
nanos_u128.try_into().unwrap_or(u64::MAX)
}
}

Expand Down Expand Up @@ -405,6 +442,8 @@ impl<'iter, K> Drop for TimerEventsIter<'iter, K> {
impl<'iter, K> Iterator for TimerEventsIter<'iter, K> {
type Item = TimerEvent<K>;

/// NOTE: When necessary, the iterator returned from advance() will unset the
/// timer node in the `ValueEntry`.
fn next(&mut self) -> Option<Self::Item> {
if self.is_done {
return None;
Expand Down Expand Up @@ -454,7 +493,9 @@ impl<'iter, K> Iterator for TimerEventsIter<'iter, K> {
let expiration_time = node.as_ref().element.entry_info().expiration_time();
if let Some(t) = expiration_time {
if t <= self.current_time {
// The cache entry has expired. Return it.
// The cache entry has expired. Unset the timer node from
// the ValueEntry and return the node.
node.as_ref().element.unset_timer_node_in_deq_nodes();
return Some(TimerEvent::Expired(node));
} else {
// The cache entry has not expired. Reschedule it.
Expand All @@ -469,8 +510,8 @@ impl<'iter, K> Iterator for TimerEventsIter<'iter, K> {
}
ReschedulingResult::Removed(node) => {
// The timer event has been removed from the timer
// wheel. Return it, so that the caller can remove the
// pointer to the node from a `ValueEntry`.
// wheel. Unset the timer node from the ValueEntry.
node.as_ref().element.unset_timer_node_in_deq_nodes();
return Some(TimerEvent::Descheduled(node));
}
}
Expand Down Expand Up @@ -600,7 +641,9 @@ mod tests {
let policy_weight = 0;
let entry_info = TrioArc::new(EntryInfo::new(key_hash, now, policy_weight));
entry_info.set_expiration_time(Some(now.checked_add(ttl).unwrap()));
timer.schedule(entry_info);
let deq_nodes = Default::default();
let timer_node = timer.schedule(entry_info, TrioArc::clone(&deq_nodes));
deq_nodes.lock().set_timer_node(timer_node);
}

fn expired_key(maybe_entry: Option<TimerEvent<u32>>) -> u32 {
Expand Down
Loading

0 comments on commit 0842409

Please sign in to comment.