diff --git a/src/iter.rs b/src/iter.rs index 6e8bb019..a7d6dcf5 100644 --- a/src/iter.rs +++ b/src/iter.rs @@ -82,7 +82,7 @@ where let key = self.next_unexpired(now)?; self.list.push_back(key); let key = self.list.back()?; - let mut value = self.map.get_mut(&key)?; + let mut value = self.map.get_mut(key)?; value.1 = now; unsafe { @@ -225,7 +225,7 @@ where let now = Instant::now(); self.next_unexpired(now)?; let key = &self.list[self.item_index]; - let value = self.map.get(&key)?; + let value = self.map.get(key)?; unsafe { let key = std::mem::transmute::<&Key, &'a Key>(key); diff --git a/src/lib.rs b/src/lib.rs index d5a22a08..936e4e86 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -85,6 +85,7 @@ #[cfg(feature = "sn_fake_clock")] use sn_fake_clock::FakeClock as Instant; use std::borrow::Borrow; +use std::collections::btree_map::Entry as BTreeMapEntry; use std::collections::{BTreeMap, VecDeque}; use std::time::Duration; #[cfg(not(feature = "sn_fake_clock"))] @@ -325,6 +326,32 @@ where PeekIter::new(&self.map, &self.list, self.time_to_live) } + /// Retains only the elements specified by the predicate. Also removes expired elements + /// before passing them to the predicate. + pub fn retain(&mut self, mut f: F) + where + F: FnMut((&'_ Key, &'_ Value)) -> bool, + { + let (map, list) = (&mut self.map, &mut self.list); + + let now = Instant::now(); + let ttl = self.time_to_live; + + list.retain(|key| match map.entry(key.clone()) { + BTreeMapEntry::Occupied(entry) => { + if matches!(ttl, Some(ttl) if entry.get().1 + ttl < now) + || !f((key, &entry.get().0)) + { + let _ = entry.remove(); + false + } else { + true + } + } + BTreeMapEntry::Vacant(_) => false, + }); + } + // Move `key` in the ordered list to the last fn update_key(list: &mut VecDeque, key: &Q) where @@ -511,7 +538,7 @@ mod test { #[test] fn size_only() { let size = 10usize; - let mut lru_cache = super::LruCache::::with_capacity(size); + let mut lru_cache = LruCache::::with_capacity(size); for i in 0..10 { assert_eq!(lru_cache.len(), i); @@ -534,7 +561,7 @@ mod test { #[test] fn time_only() { let time_to_live = Duration::from_millis(100); - let mut lru_cache = super::LruCache::::with_expiry_duration(time_to_live); + let mut lru_cache = LruCache::::with_expiry_duration(time_to_live); for i in 0..10 { assert_eq!(lru_cache.len(), i); @@ -562,7 +589,7 @@ mod test { #[test] fn time_only_check() { let time_to_live = Duration::from_millis(50); - let mut lru_cache = super::LruCache::::with_expiry_duration(time_to_live); + let mut lru_cache = LruCache::::with_expiry_duration(time_to_live); assert_eq!(lru_cache.len(), 0); let _ = lru_cache.insert(0, 0); @@ -579,7 +606,7 @@ mod test { let size = 10usize; let time_to_live = Duration::from_millis(100); let mut lru_cache = - super::LruCache::::with_expiry_duration_and_capacity(time_to_live, size); + LruCache::::with_expiry_duration_and_capacity(time_to_live, size); for i in 0..1000 { if i < size { @@ -612,7 +639,7 @@ mod test { let time_to_live = Duration::from_millis(100); let mut lru_cache = - super::LruCache::::with_expiry_duration_and_capacity(time_to_live, size); + LruCache::::with_expiry_duration_and_capacity(time_to_live, size); for i in 0..1000 { if i < size { @@ -753,6 +780,25 @@ mod test { } } + mod retain { + use super::*; + + #[test] + fn it_removes_all_invalid_entries() { + let mut lru_cache = LruCache::::with_capacity(4); + let _ = lru_cache.insert(2, 2); + let _ = lru_cache.insert(0, 0); + let _ = lru_cache.insert(3, 3); + let _ = lru_cache.insert(1, 1); + + lru_cache.retain(|(_, &value)| value > 1); + + let cached = lru_cache.peek_iter().collect::>(); + + assert_eq!(cached, vec![(&3, &3), (&2, &2)]); + } + } + mod notify_iter { use super::*; @@ -839,7 +885,7 @@ mod test { #[test] fn it_yields_cached_entries_in_most_recently_used_order() { let time_to_live = Duration::from_millis(500); - let mut lru_cache = super::LruCache::::with_expiry_duration(time_to_live); + let mut lru_cache = LruCache::::with_expiry_duration(time_to_live); let _ = lru_cache.insert(1, 1); let _ = lru_cache.insert(2, 2); @@ -854,7 +900,7 @@ mod test { #[test] fn it_yields_only_unexpired_entries() { let time_to_live = Duration::from_millis(500); - let mut lru_cache = super::LruCache::::with_expiry_duration(time_to_live); + let mut lru_cache = LruCache::::with_expiry_duration(time_to_live); let _ = lru_cache.insert(1, 1); let _ = lru_cache.insert(2, 2); @@ -870,7 +916,7 @@ mod test { #[test] fn it_doesnt_modify_entry_update_time() { let time_to_live = Duration::from_millis(500); - let mut lru_cache = super::LruCache::::with_expiry_duration(time_to_live); + let mut lru_cache = LruCache::::with_expiry_duration(time_to_live); let _ = lru_cache.insert(1, 1); let expected_time = lru_cache @@ -895,7 +941,7 @@ mod test { #[test] fn update_time_check() { let time_to_live = Duration::from_millis(500); - let mut lru_cache = super::LruCache::::with_expiry_duration(time_to_live); + let mut lru_cache = LruCache::::with_expiry_duration(time_to_live); assert_eq!(lru_cache.len(), 0); let _ = lru_cache.insert(0, 0); @@ -911,9 +957,9 @@ mod test { #[test] fn deref_coercions() { - let mut lru_cache = super::LruCache::::with_capacity(1); + let mut lru_cache = LruCache::::with_capacity(1); let _ = lru_cache.insert("foo".to_string(), 0); - assert_eq!(true, lru_cache.contains_key("foo")); + assert!(lru_cache.contains_key("foo")); assert_eq!(Some(&0), lru_cache.get("foo")); assert_eq!(Some(&mut 0), lru_cache.get_mut("foo")); assert_eq!(Some(&0), lru_cache.peek("foo"));