diff --git a/src/lib.rs b/src/lib.rs index 0a558b2e..28cd7e4c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -204,6 +204,7 @@ #![warn(rust_2018_idioms)] mod map; +mod map_ref; mod node; mod raw; @@ -211,6 +212,7 @@ mod raw; pub mod iter; pub use map::HashMap; +pub use map_ref::HashMapRef; /// Default hasher for [`HashMap`]. pub type DefaultHashBuilder = ahash::RandomState; diff --git a/src/map.rs b/src/map.rs index aa56f6e6..fb4ebe9c 100644 --- a/src/map.rs +++ b/src/map.rs @@ -1587,7 +1587,7 @@ where { self.check_guard(guard); // removed selected keys - for (k, v) in self.iter(&guard) { + for (k, v) in self.iter(guard) { if !f(k, v) { let old_value: Shared<'_, V> = Shared::from(v as *const V); self.replace_node(k, None, Some(old_value), guard); @@ -1607,7 +1607,7 @@ where { self.check_guard(guard); // removed selected keys - for (k, v) in self.iter(&guard) { + for (k, v) in self.iter(guard) { if !f(k, v) { self.replace_node(k, None, None, guard); } @@ -1674,6 +1674,18 @@ where pub fn is_empty(&self) -> bool { self.len() == 0 } + + pub(crate) fn guarded_eq(&self, other: &Self, our_guard: &Guard, their_guard: &Guard) -> bool + where + V: PartialEq, + { + if self.len() != other.len() { + return false; + } + + self.iter(our_guard) + .all(|(key, value)| other.get(key, their_guard).map_or(false, |v| *value == *v)) + } } impl PartialEq for HashMap @@ -1686,11 +1698,7 @@ where if self.len() != other.len() { return false; } - - let our_guard = self.collector.register().pin(); - let their_guard = other.collector.register().pin(); - self.iter(&our_guard) - .all(|(key, value)| other.get(key, &their_guard).map_or(false, |v| *value == *v)) + self.guarded_eq(other, &self.guard(), &other.guard()) } } diff --git a/src/map_ref.rs b/src/map_ref.rs new file mode 100644 index 00000000..c1867de6 --- /dev/null +++ b/src/map_ref.rs @@ -0,0 +1,272 @@ +use crate::iter::*; +use crate::HashMap; +use crossbeam_epoch::Guard; +use std::borrow::Borrow; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{BuildHasher, Hash}; +use std::ops::{Deref, Index}; + +/// A reference to a [`HashMap`], constructed with [`HashMap::pin`] or [`HashMap::with_guard`]. +/// +/// The current thread will be pinned for the duration of this reference. +/// Keep in mind that this prevents the collection of garbage generated by the map. +pub struct HashMapRef<'map, K: 'static, V: 'static, S = crate::DefaultHashBuilder> { + map: &'map HashMap, + guard: GuardRef<'map>, +} + +enum GuardRef<'g> { + Owned(Guard), + Ref(&'g Guard), +} + +impl Deref for GuardRef<'_> { + type Target = Guard; + + #[inline] + fn deref(&self) -> &Guard { + match *self { + GuardRef::Owned(ref guard) | GuardRef::Ref(&ref guard) => guard, + } + } +} + +impl HashMap +where + K: Sync + Send + Clone + Hash + Eq, + V: Sync + Send, + S: BuildHasher, +{ + /// Get a reference to this map with the current thread pinned. + /// + /// Keep in mind that for as long as you hold onto this, you are preventing the collection of + /// garbage generated by the map. + pub fn pin(&self) -> HashMapRef<'_, K, V, S> { + HashMapRef { + guard: GuardRef::Owned(self.guard()), + map: &self, + } + } + + /// Get a reference to this map with the given guard. + pub fn with_guard<'g>(&'g self, guard: &'g Guard) -> HashMapRef<'g, K, V, S> { + HashMapRef { + map: &self, + guard: GuardRef::Ref(guard), + } + } +} + +impl HashMapRef<'_, K, V, S> +where + K: Sync + Send + Clone + Hash + Eq, + V: Sync + Send, + S: BuildHasher, +{ + /// Tests if `key` is a key in this table. + /// See also [`HashMap::contains_key`]. + pub fn contains_key(&self, key: &Q) -> bool + where + K: Borrow, + Q: ?Sized + Hash + Eq, + { + self.map.contains_key(key, &self.guard) + } + + /// Returns the value to which `key` is mapped. + /// See also [`HashMap::get`]. + pub fn get<'g, Q>(&'g self, key: &Q) -> Option<&'g V> + where + K: Borrow, + Q: ?Sized + Hash + Eq, + { + self.map.get(key, &self.guard) + } + + /// Returns the key-value pair corresponding to `key`. + /// See also [`HashMap::get_key_value`]. + pub fn get_key_value<'g, Q>(&'g self, key: &Q) -> Option<(&'g K, &'g V)> + where + K: Borrow, + Q: ?Sized + Hash + Eq, + { + self.map.get_key_value(key, &self.guard) + } + + /// Maps `key` to `value` in this table. + /// See also [`HashMap::insert`]. + pub fn insert<'g>(&'g self, key: K, value: V) -> Option<&'g V> { + self.map.insert(key, value, &self.guard) + } + + /// If the value for the specified `key` is present, attempts to + /// compute a new mapping given the key and its current mapped value. + /// See also [`HashMap::compute_if_present`]. + pub fn compute_if_present<'g, Q, F>(&'g self, key: &Q, remapping_function: F) -> Option<&'g V> + where + K: Borrow, + Q: ?Sized + Hash + Eq, + F: FnOnce(&K, &V) -> Option, + { + self.map + .compute_if_present(key, remapping_function, &self.guard) + } + + /// Tries to reserve capacity for at least additional more elements. + /// See also [`HashMap::reserve`]. + pub fn reserve(&self, additional: usize) { + self.map.reserve(additional, &self.guard) + } + + /// Removes the key (and its corresponding value) from this map. + /// See also [`HashMap::remove`]. + pub fn remove<'g, Q>(&'g self, key: &Q) -> Option<&'g V> + where + K: Borrow, + Q: ?Sized + Hash + Eq, + { + self.map.remove(key, &self.guard) + } + + /// Retains only the elements specified by the predicate. + /// See also [`HashMap::retain`]. + pub fn retain(&self, f: F) + where + F: FnMut(&K, &V) -> bool, + { + self.map.retain(f, &self.guard); + } + + /// Retains only the elements specified by the predicate. + /// See also [`HashMap::retain_force`]. + pub fn retain_force(&self, f: F) + where + F: FnMut(&K, &V) -> bool, + { + self.map.retain_force(f, &self.guard); + } + + /// An iterator visiting all key-value pairs in arbitrary order. + /// The iterator element type is `(&'g K, &'g V)`. + /// See also [`HashMap::iter`]. + pub fn iter<'g>(&'g self) -> Iter<'g, K, V> { + self.map.iter(&self.guard) + } + + /// An iterator visiting all keys in arbitrary order. + /// The iterator element type is `&'g K`. + /// See also [`HashMap::keys`]. + pub fn keys<'g>(&'g self) -> Keys<'g, K, V> { + self.map.keys(&self.guard) + } + + /// An iterator visiting all values in arbitrary order. + /// The iterator element type is `&'g V`. + /// See also [`HashMap::values`]. + pub fn values<'g>(&'g self) -> Values<'g, K, V> { + self.map.values(&self.guard) + } + + /// Returns the number of entries in the map. + /// See also [`HashMap::len`]. + pub fn len(&self) -> usize { + self.map.len() + } + + /// Returns `true` if the map is empty. Otherwise returns `false`. + /// See also [`HashMap::is_empty`]. + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } +} + +impl<'g, K, V, S> IntoIterator for &'g HashMapRef<'_, K, V, S> +where + K: Sync + Send + Clone + Hash + Eq, + V: Sync + Send, + S: BuildHasher, +{ + type IntoIter = Iter<'g, K, V>; + type Item = (&'g K, &'g V); + + fn into_iter(self) -> Self::IntoIter { + self.map.iter(&self.guard) + } +} + +impl Debug for HashMapRef<'_, K, V, S> +where + K: Sync + Send + Clone + Hash + Eq + Debug, + V: Sync + Send + Debug, + S: BuildHasher, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_map().entries(self).finish() + } +} + +impl Clone for HashMapRef<'_, K, V, S> +where + K: Sync + Send + Clone + Hash + Eq, + V: Sync + Send, + S: BuildHasher, +{ + fn clone(&self) -> Self { + self.map.pin() + } +} + +impl PartialEq for HashMapRef<'_, K, V, S> +where + K: Sync + Send + Clone + Hash + Eq, + V: Sync + Send + PartialEq, + S: BuildHasher, +{ + fn eq(&self, other: &Self) -> bool { + self.map.guarded_eq(&other.map, &self.guard, &other.guard) + } +} + +impl PartialEq> for HashMapRef<'_, K, V, S> +where + K: Sync + Send + Clone + Hash + Eq, + V: Sync + Send + PartialEq, + S: BuildHasher, +{ + fn eq(&self, other: &HashMap) -> bool { + self.map.guarded_eq(&other, &self.guard, &other.guard()) + } +} + +impl PartialEq> for HashMap +where + K: Sync + Send + Clone + Hash + Eq, + V: Sync + Send + PartialEq, + S: BuildHasher, +{ + fn eq(&self, other: &HashMapRef<'_, K, V, S>) -> bool { + self.guarded_eq(&other.map, &self.guard(), &other.guard) + } +} + +impl Eq for HashMapRef<'_, K, V, S> +where + K: Sync + Send + Clone + Hash + Eq, + V: Sync + Send + Eq, + S: BuildHasher, +{ +} + +impl Index<&'_ Q> for HashMapRef<'_, K, V, S> +where + K: Sync + Send + Clone + Hash + Eq + Borrow, + Q: ?Sized + Hash + Eq, + V: Sync + Send, + S: BuildHasher, +{ + type Output = V; + + fn index(&self, key: &Q) -> &V { + self.get(key).expect("no entry found for key") + } +} diff --git a/tests/basic_ref.rs b/tests/basic_ref.rs new file mode 100644 index 00000000..933c877c --- /dev/null +++ b/tests/basic_ref.rs @@ -0,0 +1,442 @@ +use flurry::HashMap; +use std::sync::Arc; + +#[test] +fn insert() { + let map = HashMap::::new(); + let map = map.pin(); + + let old = map.insert(42, 0); + assert!(old.is_none()); +} + +#[test] +fn get_empty() { + let map = HashMap::::new(); + + { + let map = map.pin(); + let e = map.get(&42); + assert!(e.is_none()); + } +} + +#[test] +fn get_key_value_empty() { + let map = HashMap::::new(); + + { + let map = map.pin(); + let e = map.get_key_value(&42); + assert!(e.is_none()); + } +} + +#[test] +fn remove_empty() { + let map = HashMap::::new(); + + { + let map = map.pin(); + let old = map.remove(&42); + assert!(old.is_none()); + } +} + +#[test] +fn insert_and_remove() { + let map = HashMap::::new(); + + { + let map = map.pin(); + map.insert(42, 0); + let old = map.remove(&42).unwrap(); + assert_eq!(old, &0); + assert!(map.get(&42).is_none()); + } +} + +#[test] +fn insert_and_get() { + let map = HashMap::::new(); + + map.pin().insert(42, 0); + { + let map = map.pin(); + let e = map.get(&42).unwrap(); + assert_eq!(e, &0); + } +} + +#[test] +fn insert_and_get_key_value() { + let map = HashMap::::new(); + + map.pin().insert(42, 0); + { + let map = map.pin(); + let e = map.get_key_value(&42).unwrap(); + assert_eq!(e, (&42, &0)); + } +} + +#[test] +fn update() { + let map = HashMap::::new(); + + let map1 = map.pin(); + map1.insert(42, 0); + let old = map1.insert(42, 1); + assert_eq!(old, Some(&0)); + { + let map2 = map.pin(); + let e = map2.get(&42).unwrap(); + assert_eq!(e, &1); + } +} + +#[test] +fn compute_if_present() { + let map = HashMap::::new(); + + let map1 = map.pin(); + map1.insert(42, 0); + let new = map1.compute_if_present(&42, |_, v| Some(v + 1)); + assert_eq!(new, Some(&1)); + { + let map2 = map.pin(); + let e = map2.get(&42).unwrap(); + assert_eq!(e, &1); + } +} + +#[test] +fn compute_if_present_empty() { + let map = HashMap::::new(); + + let map1 = map.pin(); + let new = map1.compute_if_present(&42, |_, v| Some(v + 1)); + assert!(new.is_none()); + { + let map2 = map.pin(); + assert!(map2.get(&42).is_none()); + } +} + +#[test] +fn compute_if_present_remove() { + let map = HashMap::::new(); + + let map1 = map.pin(); + map1.insert(42, 0); + let new = map1.compute_if_present(&42, |_, _| None); + assert!(new.is_none()); + { + let map2 = map.pin(); + assert!(map2.get(&42).is_none()); + } +} + +#[test] +fn concurrent_insert() { + let map = Arc::new(HashMap::::new()); + + let map1 = map.clone(); + let t1 = std::thread::spawn(move || { + for i in 0..64 { + map1.pin().insert(i, 0); + } + }); + let map2 = map.clone(); + let t2 = std::thread::spawn(move || { + for i in 0..64 { + map2.pin().insert(i, 0); + } + }); + + t1.join().unwrap(); + t2.join().unwrap(); + + let map = map.pin(); + for i in 0..64 { + let v = map.get(&i).unwrap(); + assert!(v == &0 || v == &1); + + let kv = map.get_key_value(&i).unwrap(); + assert!(kv == (&i, &0) || kv == (&i, &1)); + } +} + +#[test] +fn concurrent_remove() { + let map = Arc::new(HashMap::::new()); + + { + let map = map.pin(); + for i in 0..64 { + map.insert(i, i); + } + } + + let map1 = map.clone(); + let t1 = std::thread::spawn(move || { + let map1 = map1.pin(); + for i in 0..64 { + if let Some(v) = map1.remove(&i) { + assert_eq!(v, &i); + } + } + }); + let map2 = map.clone(); + let t2 = std::thread::spawn(move || { + let map2 = map2.pin(); + for i in 0..64 { + if let Some(v) = map2.remove(&i) { + assert_eq!(v, &i); + } + } + }); + + t1.join().unwrap(); + t2.join().unwrap(); + + // after joining the threads, the map should be empty + let map = map.pin(); + for i in 0..64 { + assert!(map.get(&i).is_none()); + } +} + +#[test] +fn concurrent_compute_if_present() { + let map = Arc::new(HashMap::::new()); + + { + let map = map.pin(); + for i in 0..64 { + map.insert(i, i); + } + } + + let map1 = map.clone(); + let t1 = std::thread::spawn(move || { + let map1 = map1.pin(); + for i in 0..64 { + let new = map1.compute_if_present(&i, |_, _| None); + assert!(new.is_none()); + } + }); + let map2 = map.clone(); + let t2 = std::thread::spawn(move || { + let map2 = map2.pin(); + for i in 0..64 { + let new = map2.compute_if_present(&i, |_, _| None); + assert!(new.is_none()); + } + }); + + t1.join().unwrap(); + t2.join().unwrap(); + + // after joining the threads, the map should be empty + let map = map.pin(); + for i in 0..64 { + assert!(map.get(&i).is_none()); + } +} + +#[test] +fn current_kv_dropped() { + let dropped1 = Arc::new(0); + let dropped2 = Arc::new(0); + + let map = HashMap::, Arc>::new(); + + map.pin().insert(dropped1.clone(), dropped2.clone()); + assert_eq!(Arc::strong_count(&dropped1), 2); + assert_eq!(Arc::strong_count(&dropped2), 2); + + drop(map); + + // dropping the map should immediately drop (not deferred) all keys and values + assert_eq!(Arc::strong_count(&dropped1), 1); + assert_eq!(Arc::strong_count(&dropped2), 1); +} + +#[test] +fn empty_maps_equal() { + let map1 = HashMap::::new(); + let map2 = HashMap::::new(); + assert_eq!(map1, map2.pin()); + assert_eq!(map1.pin(), map2); + assert_eq!(map1.pin(), map2.pin()); + assert_eq!(map2.pin(), map1.pin()); +} + +#[test] +fn different_size_maps_not_equal() { + let map1 = HashMap::::new(); + let map2 = HashMap::::new(); + { + let map1 = map1.pin(); + let map2 = map2.pin(); + map1.insert(1, 0); + map1.insert(2, 0); + map2.insert(1, 0); + } + + assert_ne!(map1, map2.pin()); + assert_ne!(map1.pin(), map2); + assert_ne!(map1.pin(), map2.pin()); + assert_ne!(map2.pin(), map1.pin()); +} + +#[test] +fn same_values_equal() { + let map1 = HashMap::::new(); + let map2 = HashMap::::new(); + { + let map1 = map1.pin(); + let map2 = map2.pin(); + map1.insert(1, 0); + map2.insert(1, 0); + } + + assert_eq!(map1, map2.pin()); + assert_eq!(map1.pin(), map2); + assert_eq!(map1.pin(), map2.pin()); + assert_eq!(map2.pin(), map1.pin()); +} + +#[test] +fn different_values_not_equal() { + let map1 = HashMap::::new(); + let map2 = HashMap::::new(); + { + let map1 = map1.pin(); + let map2 = map2.pin(); + map1.insert(1, 0); + map2.insert(1, 1); + } + + assert_ne!(map1, map2.pin()); + assert_ne!(map1.pin(), map2); + assert_ne!(map1.pin(), map2.pin()); + assert_ne!(map2.pin(), map1.pin()); +} + +#[test] +#[ignore] +// ignored because we cannot control when destructors run +fn drop_value() { + let dropped1 = Arc::new(0); + let dropped2 = Arc::new(1); + + let map = HashMap::>::new(); + + map.pin().insert(42, dropped1.clone()); + assert_eq!(Arc::strong_count(&dropped1), 2); + assert_eq!(Arc::strong_count(&dropped2), 1); + + map.pin().insert(42, dropped2.clone()); + assert_eq!(Arc::strong_count(&dropped2), 2); + + drop(map); + + // First NotifyOnDrop was dropped when it was replaced by the second + assert_eq!(Arc::strong_count(&dropped1), 1); + // Second NotifyOnDrop was dropped when the map was dropped + assert_eq!(Arc::strong_count(&dropped2), 1); +} + +#[test] +fn clone_map_empty() { + let map = HashMap::<&'static str, u32>::new(); + let map = map.pin(); + let cloned_map = map.clone(); // another ref to the same map + assert_eq!(map.len(), cloned_map.len()); + assert_eq!(map, cloned_map); + assert_eq!(cloned_map.len(), 0); +} + +#[test] +// Test that same values exists in both refs (original and cloned) +fn clone_map_filled() { + let map = HashMap::<&'static str, u32>::new(); + let map = map.pin(); + map.insert("FooKey", 0); + map.insert("BarKey", 10); + let cloned_map = map.clone(); // another ref to the same map + assert_eq!(map.len(), cloned_map.len()); + assert_eq!(map, cloned_map); + + // test that we are mapping the same tables + map.insert("NewItem", 100); + assert_eq!(map, cloned_map); +} + +#[test] +fn debug() { + let map: HashMap = HashMap::new(); + let map = map.pin(); + + map.insert(42, 0); + map.insert(16, 8); + + let formatted = format!("{:?}", map); + + assert!(formatted == "{42: 0, 16: 8}" || formatted == "{16: 8, 42: 0}"); +} + +#[test] +fn retain_empty() { + let map = HashMap::<&'static str, u32>::new(); + let map = map.pin(); + map.retain(|_, _| false); + assert_eq!(map.len(), 0); +} + +#[test] +fn retain_all_false() { + let map: HashMap = (0..10 as u32).map(|x| (x, x)).collect(); + let map = map.pin(); + map.retain(|_, _| false); + assert_eq!(map.len(), 0); +} + +#[test] +fn retain_all_true() { + let size = 10usize; + let map: HashMap = (0..size).map(|x| (x, x)).collect(); + let map = map.pin(); + map.retain(|_, _| true); + assert_eq!(map.len(), size); +} + +#[test] +fn retain_some() { + let map: HashMap = (0..10).map(|x| (x, x)).collect(); + let map = map.pin(); + let expected_map: HashMap = (5..10).map(|x| (x, x)).collect(); + map.retain(|_, v| *v >= 5); + assert_eq!(map.len(), 5); + assert_eq!(map, expected_map); +} + +#[test] +fn retain_force_empty() { + let map = HashMap::<&'static str, u32>::new(); + let map = map.pin(); + map.retain_force(|_, _| false); + assert_eq!(map.len(), 0); +} + +#[test] +fn retain_force_some() { + let map: HashMap = (0..10).map(|x| (x, x)).collect(); + let map = map.pin(); + let expected_map: HashMap = (5..10).map(|x| (x, x)).collect(); + map.retain_force(|_, v| *v >= 5); + assert_eq!(map.len(), 5); + assert_eq!(map, expected_map); +}