diff --git a/src/lru_cache.rs b/src/lru_cache.rs index 9ef36b5..0c433d7 100644 --- a/src/lru_cache.rs +++ b/src/lru_cache.rs @@ -14,12 +14,19 @@ pub use crate::linked_hash_map::{ RawOccupiedEntryMut, RawVacantEntryMut, VacantEntry, }; -pub struct LruCache { - map: LinkedHashMap, - max_size: usize, +pub struct LruCache +where + K: Eq + Hash, + S: BuildHasher + Default, +{ + pub(crate) map: LinkedHashMap, + pub(crate) max_size: usize, } -impl LruCache { +impl LruCache +where + K: Eq + Hash, +{ #[inline] pub fn new(capacity: usize) -> Self { LruCache { @@ -37,7 +44,31 @@ impl LruCache { } } -impl LruCache { +impl PartialEq for LruCache +where + K: Eq + Hash, + V: Eq, + S: BuildHasher + Default, +{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.len() == other.len() && self.capacity() == other.capacity() && self.iter().eq(other) + } +} + +impl Eq for LruCache +where + K: Eq + Hash, + V: Eq, + S: BuildHasher + Default, +{ +} + +impl LruCache +where + K: Eq + Hash, + S: BuildHasher + Default, +{ #[inline] pub fn with_hasher(capacity: usize, hash_builder: S) -> Self { LruCache { @@ -82,9 +113,10 @@ impl LruCache { } } -impl LruCache +impl LruCache where - S: BuildHasher, + K: Eq + Hash, + S: BuildHasher + Default, { #[inline] pub fn contains_key(&mut self, key: &Q) -> bool @@ -107,6 +139,26 @@ where old_val } + /// Insert a new value into the `LruCache` with LRU callback. + /// + /// If necessary, will remove the value at the front of the LRU list to make room. + /// Calls a callback if there was an LRU removed value + #[inline] + pub fn insert_with_callback( + &mut self, + k: K, + v: V, + remove_lru_callback: F, + ) -> Option { + let old_val = self.map.insert(k, v); + if self.len() > self.capacity() { + if let Some(x) = self.remove_lru() { + remove_lru_callback(x.0, x.1) + } + } + old_val + } + /// Get the value for the given key, *without* marking the value as recently used and moving it /// to the back of the LRU list. #[inline] @@ -171,6 +223,29 @@ where self.map.entry(key) } + /// Like `entry` but with LRU callback. + /// If the returned entry is vacant, it will always have room to insert a single value. By + /// using the entry API, you can exceed the configured capacity by 1. + /// + /// The returned entry is not automatically moved to the back of the LRU list. By calling + /// `Entry::to_back` / `Entry::to_front` you can manually control the position of this entry in + /// the LRU list. + /// Calls a callback if there was an LRU removed value + + #[inline] + pub fn entry_with_callback( + &mut self, + key: K, + remove_lru_callback: F, + ) -> Entry<'_, K, V, S> { + if self.len() > self.capacity() { + if let Some(x) = self.remove_lru() { + remove_lru_callback(x.0, x.1) + } + } + self.map.entry(key) + } + /// The constructed raw entry is never automatically moved to the back of the LRU list. By /// calling `Entry::to_back` / `Entry::to_front` you can manually control the position of this /// entry in the LRU list. @@ -185,6 +260,8 @@ where /// The constructed raw entry is never automatically moved to the back of the LRU list. By /// calling `Entry::to_back` / `Entry::to_front` you can manually control the position of this /// entry in the LRU list. + /// Calls a callback if there was an LRU removed value + #[inline] pub fn raw_entry_mut(&mut self) -> RawEntryBuilderMut<'_, K, V, S> { if self.len() > self.capacity() { @@ -193,6 +270,26 @@ where self.map.raw_entry_mut() } + /// Like `raw_entry` but with LRU callback. + /// If the constructed raw entry is vacant, it will always have room to insert a single value. + /// By using the raw entry API, you can exceed the configured capacity by 1. + /// + /// The constructed raw entry is never automatically moved to the back of the LRU list. By + /// calling `Entry::to_back` / `Entry::to_front` you can manually control the position of this + /// entry in the LRU list. + #[inline] + pub fn raw_entry_mut_with_callback( + &mut self, + remove_lru_callback: F, + ) -> RawEntryBuilderMut<'_, K, V, S> { + if self.len() > self.capacity() { + if let Some(x) = self.remove_lru() { + remove_lru_callback(x.0, x.1) + } + } + self.map.raw_entry_mut() + } + #[inline] pub fn remove(&mut self, k: &Q) -> Option where @@ -223,6 +320,26 @@ where self.max_size = capacity; } + /// Like `set_capacity` but with LRU callback. + /// Set the new cache capacity for the `LruCache` with an LRU callback. + /// + /// If there are more entries in the `LruCache` than the new capacity will allow, they are + /// removed. + /// Calls a callback if there was an LRU removed value + #[inline] + pub fn set_capacity_with_callback( + &mut self, + capacity: usize, + remove_lru_callback: F, + ) { + for _ in capacity..self.len() { + if let Some(x) = self.remove_lru() { + remove_lru_callback(x.0, x.1) + } + } + self.max_size = capacity; + } + /// Remove the least recently used entry and return it. /// /// If the `LruCache` is empty this will return None. @@ -230,9 +347,17 @@ where pub fn remove_lru(&mut self) -> Option<(K, V)> { self.map.pop_front() } + + /// Peek at the least recently used entry and return a reference to it. + /// + /// If the `LruCache` is empty this will return None. + #[inline] + pub fn peek_lru(&mut self) -> Option<(&K, &V)> { + self.map.front() + } } -impl Clone for LruCache { +impl Clone for LruCache { #[inline] fn clone(&self) -> Self { LruCache { @@ -242,16 +367,24 @@ impl Clone for LruCache< } } -impl Extend<(K, V)> for LruCache { +impl Extend<(K, V)> for LruCache { #[inline] fn extend>(&mut self, iter: I) { for (k, v) in iter { - self.insert(k, v); + //self.insert(k, v); + self.map.insert(k, v); + if self.len() > self.capacity() { + self.remove_lru(); + } } } } -impl IntoIterator for LruCache { +impl IntoIterator for LruCache +where + K: Eq + Hash, + S: BuildHasher + Default, +{ type Item = (K, V); type IntoIter = IntoIter; @@ -261,7 +394,11 @@ impl IntoIterator for LruCache { } } -impl<'a, K, V, S> IntoIterator for &'a LruCache { +impl<'a, K, V, S> IntoIterator for &'a LruCache +where + K: Eq + Hash, + S: BuildHasher + Default, +{ type Item = (&'a K, &'a V); type IntoIter = Iter<'a, K, V>; @@ -271,7 +408,11 @@ impl<'a, K, V, S> IntoIterator for &'a LruCache { } } -impl<'a, K, V, S> IntoIterator for &'a mut LruCache { +impl<'a, K, V, S> IntoIterator for &'a mut LruCache +where + K: Eq + Hash, + S: BuildHasher + Default, +{ type Item = (&'a K, &'a mut V); type IntoIter = IterMut<'a, K, V>; @@ -283,8 +424,9 @@ impl<'a, K, V, S> IntoIterator for &'a mut LruCache { impl fmt::Debug for LruCache where - K: fmt::Debug, + K: Eq + Hash + fmt::Debug, V: fmt::Debug, + S: BuildHasher + Default, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_map().entries(self.iter().rev()).finish() diff --git a/src/serde.rs b/src/serde.rs index 57c3b16..dc000ab 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -5,12 +5,12 @@ use core::{ }; use serde::{ - de::{MapAccess, SeqAccess, Visitor}, - ser::{SerializeMap, SerializeSeq}, + de::{self, MapAccess, SeqAccess, Visitor}, + ser::{SerializeMap, SerializeSeq, SerializeStruct}, Deserialize, Deserializer, Serialize, Serializer, }; -use crate::{LinkedHashMap, LinkedHashSet}; +use crate::{LinkedHashMap, LinkedHashSet, LruCache}; // LinkedHashMap impls @@ -159,3 +159,154 @@ where deserializer.deserialize_seq(LinkedHashSetVisitor::default()) } } + +// LruCache impls + +impl Serialize for LruCache +where + K: Serialize + Eq + Hash, + V: Serialize, + S: BuildHasher + Default, +{ + #[inline] + fn serialize(&self, serializer: T) -> Result { + let mut state = serializer.serialize_struct("LruCache", 2)?; + state.serialize_field("map", &self.map)?; + state.serialize_field("max_size", &self.max_size)?; + state.end() + } +} + +impl<'de, K, V, S> Deserialize<'de> for LruCache +where + K: Deserialize<'de> + Eq + Hash, + V: Deserialize<'de>, + S: BuildHasher + Default, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + enum Field { + Map, + MaxSize, + } + + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldVisitor; + + impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("`map` or `max_size`") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "map" => Ok(Field::Map), + "max_size" => Ok(Field::MaxSize), + _ => Err(de::Error::unknown_field(value, FIELDS)), + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } + } + + #[derive(Debug)] + struct LruCacheVisitor + where + K: Eq + Hash, + S: BuildHasher + Default, + { + marker: PhantomData>, + } + + impl LruCacheVisitor + where + K: Eq + Hash, + S: BuildHasher + Default, + { + fn new() -> Self { + LruCacheVisitor { + marker: PhantomData, + } + } + } + + impl Default for LruCacheVisitor + where + K: Eq + Hash, + S: BuildHasher + Default, + { + fn default() -> Self { + Self::new() + } + } + + impl<'de, K, V, S> Visitor<'de> for LruCacheVisitor + where + K: Deserialize<'de> + Eq + Hash, + V: Deserialize<'de>, + S: BuildHasher + Default, + { + type Value = LruCache; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct LruCache") + } + + fn visit_seq(self, mut outseq: M) -> Result, M::Error> + where + M: SeqAccess<'de>, + { + let map = outseq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let max_size = outseq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + Ok(LruCache:: { map, max_size }) + } + + fn visit_map(self, mut outmap: M) -> Result, M::Error> + where + M: MapAccess<'de>, + { + let mut map = None; + let mut max_size = None; + while let Some(key) = outmap.next_key()? { + match key { + Field::Map => { + if map.is_some() { + return Err(de::Error::duplicate_field("map")); + } + map = Some(outmap.next_value()?); + } + Field::MaxSize => { + if max_size.is_some() { + return Err(de::Error::duplicate_field("max_size")); + } + max_size = Some(outmap.next_value()?); + } + } + } + let map = map.ok_or_else(|| de::Error::missing_field("map"))?; + let max_size = max_size.ok_or_else(|| de::Error::missing_field("max_size"))?; + Ok(LruCache:: { map, max_size }) + } + } + + const FIELDS: &'static [&'static str] = &["map", "max_size"]; + deserializer.deserialize_struct("LruCache", FIELDS, LruCacheVisitor::default()) + } +} diff --git a/tests/serde.rs b/tests/serde.rs index 2cf4a3e..4c6aab0 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -1,10 +1,22 @@ #![cfg(feature = "serde_impl")] -use std::hash::BuildHasherDefault; - -use hashlink::{LinkedHashMap, LinkedHashSet}; +use hashlink::{LinkedHashMap, LinkedHashSet, LruCache}; use rustc_hash::FxHasher; use serde_test::{assert_tokens, Token}; +use std::hash::BuildHasherDefault; + +#[cfg(target_pointer_width = "64")] +fn token_usize(t: usize) -> Token { + Token::U64(t as u64) +} +#[cfg(target_pointer_width = "32")] +fn token_usize(t: usize) -> Token { + Token::U32(t as u32) +} +#[cfg(target_pointer_width = "16")] +fn token_usize(t: usize) -> Token { + Token::U16(t as u16) +} #[test] fn map_serde_tokens_empty() { @@ -108,3 +120,108 @@ fn set_serde_tokens_generic() { ], ); } + +#[test] +fn lru_serde_tokens_empty() { + let map = LruCache::::new(16); + + assert_tokens( + &map, + &[ + Token::Struct { + name: "LruCache", + len: 2, + }, + Token::Str("map"), + Token::Map { len: Some(0) }, + Token::MapEnd, + Token::Str("max_size"), + token_usize(16), + Token::StructEnd, + ], + ); +} + +#[test] +fn lru_serde_tokens() { + let mut map = LruCache::new(16); + map.insert('a', 10); + map.insert('b', 20); + map.insert('c', 30); + + assert_tokens( + &map, + &[ + Token::Struct { + name: "LruCache", + len: 2, + }, + Token::Str("map"), + Token::Map { len: Some(3) }, + Token::Char('a'), + Token::I32(10), + Token::Char('b'), + Token::I32(20), + Token::Char('c'), + Token::I32(30), + Token::MapEnd, + Token::Str("max_size"), + token_usize(16), + Token::StructEnd, + ], + ); +} + +#[test] +fn lru_serde_tokens_empty_generic() { + let map = LruCache::>::with_hasher( + 16, + BuildHasherDefault::::default(), + ); + + assert_tokens( + &map, + &[ + Token::Struct { + name: "LruCache", + len: 2, + }, + Token::Str("map"), + Token::Map { len: Some(0) }, + Token::MapEnd, + Token::Str("max_size"), + token_usize(16), + Token::StructEnd, + ], + ); +} + +#[test] +fn lru_serde_tokens_generic() { + let mut map = LruCache::with_hasher(16, BuildHasherDefault::::default()); + map.insert('a', 10); + map.insert('b', 20); + map.insert('c', 30); + + assert_tokens( + &map, + &[ + Token::Struct { + name: "LruCache", + len: 2, + }, + Token::Str("map"), + Token::Map { len: Some(3) }, + Token::Char('a'), + Token::I32(10), + Token::Char('b'), + Token::I32(20), + Token::Char('c'), + Token::I32(30), + Token::MapEnd, + Token::Str("max_size"), + token_usize(16), + Token::StructEnd, + ], + ); +}