From c42867cc0adeb0aa3099cd2416252f0fb7c845c6 Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Fri, 24 Jul 2020 12:53:26 -0700 Subject: [PATCH] Expand drain to accept any RangeBounds Rather than just `RangeFull`, this lets us accept any `Range*` type or even arbitrary `Bound` pairs, matching the signature of `Vec::drain`. --- src/map.rs | 22 ++++++++++++--- src/map/core.rs | 12 ++++++--- src/map/core/raw.rs | 65 +++++++++++++++++++++++++++++++++++++++++++++ src/set.rs | 22 ++++++++++++--- src/util.rs | 27 +++++++++++++++++++ tests/quick.rs | 30 ++++++++++++++++++++- 6 files changed, 165 insertions(+), 13 deletions(-) diff --git a/src/map.rs b/src/map.rs index 6c2c466d..bba4af98 100644 --- a/src/map.rs +++ b/src/map.rs @@ -13,7 +13,7 @@ use ::core::cmp::Ordering; use ::core::fmt; use ::core::hash::{BuildHasher, Hash, Hasher}; use ::core::iter::FromIterator; -use ::core::ops::{Index, IndexMut, RangeFull}; +use ::core::ops::{Index, IndexMut, RangeBounds}; use ::core::slice::{Iter as SliceIter, IterMut as SliceIterMut}; #[cfg(has_std)] @@ -252,9 +252,23 @@ impl IndexMap { self.core.clear(); } - /// Clears the `IndexMap`, returning all key-value pairs as a drain iterator. - /// Keeps the allocated memory for reuse. - pub fn drain(&mut self, range: RangeFull) -> Drain<'_, K, V> { + /// Clears the `IndexMap` in the given index range, returning those + /// key-value pairs as a drain iterator. + /// + /// The range may be any type that implements `RangeBounds<usize>`, + /// including all of the `std::ops::Range*` types, or even a tuple pair of + /// `Bound` start and end values. To drain the map entirely, use `RangeFull` + /// like `map.drain(..)`. + /// + /// This shifts down all entries following the drained range to fill the + /// gap, and keeps the allocated memory for reuse. + /// + /// ***Panics*** if the starting point is greater than the end point or if + /// the end point is greater than the length of the map. 
+ pub fn drain<R>(&mut self, range: R) -> Drain<'_, K, V> + where + R: RangeBounds<usize>, + { Drain { iter: self.core.drain(range), } diff --git a/src/map/core.rs b/src/map/core.rs index 81f41125..7adb69a1 100644 --- a/src/map/core.rs +++ b/src/map/core.rs @@ -15,10 +15,10 @@ use crate::vec::{Drain, Vec}; use core::cmp; use core::fmt; use core::mem::replace; -use core::ops::RangeFull; +use core::ops::RangeBounds; use crate::equivalent::Equivalent; -use crate::util::enumerate; +use crate::util::{enumerate, simplify_range}; use crate::{Bucket, Entries, HashValue}; /// Core of the map that does not depend on S @@ -129,8 +129,12 @@ impl IndexMapCore { self.entries.clear(); } - pub(crate) fn drain(&mut self, range: RangeFull) -> Drain<'_, Bucket> { - self.indices.clear(); + pub(crate) fn drain<R>(&mut self, range: R) -> Drain<'_, Bucket> + where + R: RangeBounds<usize>, + { + let range = simplify_range(range, self.entries.len()); + self.erase_indices(range.start, range.end); self.entries.drain(range) } diff --git a/src/map/core/raw.rs b/src/map/core/raw.rs index fa075f01..fdfadb91 100644 --- a/src/map/core/raw.rs +++ b/src/map/core/raw.rs @@ -3,6 +3,7 @@ //! mostly in dealing with its bucket "pointers". use super::{Entry, Equivalent, HashValue, IndexMapCore, VacantEntry}; +use crate::util::enumerate; use core::fmt; use core::mem::replace; use hashbrown::raw::RawTable; @@ -44,11 +45,75 @@ impl IndexMapCore { } } + /// Erase the given index from `indices`. + /// + /// The index doesn't need to be valid in `entries` while calling this. No other index + /// adjustments are made -- this is only used by `pop` for the greatest index. 
pub(super) fn erase_index(&mut self, hash: HashValue, index: usize) { + debug_assert_eq!(index, self.indices.len() - 1); let raw_bucket = self.find_index(hash, index).unwrap(); unsafe { self.indices.erase(raw_bucket) }; } + /// Erase `start..end` from `indices`, and shift `end..` indices down to `start..` + /// + /// All of these items should still be at their original location in `entries`. + /// This is used by `drain`, which will let `Vec::drain` do the work on `entries`. + pub(super) fn erase_indices(&mut self, start: usize, end: usize) { + let (init, shifted_entries) = self.entries.split_at(end); + let (start_entries, erased_entries) = init.split_at(start); + + let erased = erased_entries.len(); + let shifted = shifted_entries.len(); + let half_capacity = self.indices.buckets() / 2; + + // Use a heuristic between different strategies + if erased == 0 { + // Degenerate case, nothing to do + } else if start + shifted < half_capacity && start < erased { + // Reinsert everything, as there are few kept indices + self.indices.clear(); + + // Reinsert stable indices + for (i, entry) in enumerate(start_entries) { + self.indices.insert_no_grow(entry.hash.get(), i); + } + + // Reinsert shifted indices + for (i, entry) in (start..).zip(shifted_entries) { + self.indices.insert_no_grow(entry.hash.get(), i); + } + } else if erased + shifted < half_capacity { + // Find each affected index, as there are few to adjust + + // Find erased indices + for (i, entry) in (start..).zip(erased_entries) { + let bucket = self.find_index(entry.hash, i).unwrap(); + unsafe { self.indices.erase(bucket) }; + } + + // Find shifted indices + for ((new, old), entry) in (start..).zip(end..).zip(shifted_entries) { + let bucket = self.find_index(entry.hash, old).unwrap(); + unsafe { bucket.write(new) }; + } + } else { + // Sweep the whole table for adjustments + unsafe { + for bucket in self.indices.iter() { + let i = bucket.read(); + if i >= end { + bucket.write(i - erased); + } else if i >= start 
{ + self.indices.erase(bucket); + } + } + } + } + + debug_assert_eq!(self.indices.len(), start + shifted); + } + pub(crate) fn entry(&mut self, hash: HashValue, key: K) -> Entry<'_, K, V> where K: Eq, diff --git a/src/set.rs b/src/set.rs index c51c8daf..4560caa9 100644 --- a/src/set.rs +++ b/src/set.rs @@ -11,7 +11,7 @@ use core::cmp::Ordering; use core::fmt; use core::hash::{BuildHasher, Hash}; use core::iter::{Chain, FromIterator}; -use core::ops::{BitAnd, BitOr, BitXor, RangeFull, Sub}; +use core::ops::{BitAnd, BitOr, BitXor, RangeBounds, Sub}; use core::slice; use super::{Entries, Equivalent, IndexMap}; @@ -200,9 +200,23 @@ impl IndexSet { self.map.clear(); } - /// Clears the `IndexSet`, returning all values as a drain iterator. - /// Keeps the allocated memory for reuse. - pub fn drain(&mut self, range: RangeFull) -> Drain<'_, T> { + /// Clears the `IndexSet` in the given index range, returning those values + /// as a drain iterator. + /// + /// The range may be any type that implements `RangeBounds<usize>`, + /// including all of the `std::ops::Range*` types, or even a tuple pair of + /// `Bound` start and end values. To drain the set entirely, use `RangeFull` + /// like `set.drain(..)`. + /// + /// This shifts down all entries following the drained range to fill the + /// gap, and keeps the allocated memory for reuse. + /// + /// ***Panics*** if the starting point is greater than the end point or if + /// the end point is greater than the length of the set. 
+ pub fn drain<R>(&mut self, range: R) -> Drain<'_, T> + where + R: RangeBounds<usize>, + { Drain { iter: self.map.drain(range).iter, } diff --git a/src/util.rs b/src/util.rs index e7bb0e1a..5388f470 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,5 @@ use core::iter::Enumerate; +use core::ops::{Bound, Range, RangeBounds}; pub(crate) fn third(t: (A, B, C)) -> C { t.2 @@ -10,3 +11,29 @@ where { iterable.into_iter().enumerate() } + +pub(crate) fn simplify_range<R>(range: R, len: usize) -> Range<usize> +where + R: RangeBounds<usize>, +{ + let start = match range.start_bound() { + Bound::Unbounded => 0, + Bound::Included(&i) if i <= len => i, + Bound::Excluded(&i) if i < len => i + 1, + bound => panic!("range start {:?} should be <= length {}", bound, len), + }; + let end = match range.end_bound() { + Bound::Unbounded => len, + Bound::Excluded(&i) if i <= len => i, + Bound::Included(&i) if i < len => i + 1, + bound => panic!("range end {:?} should be <= length {}", bound, len), + }; + if start > end { + panic!( + "range start {:?} should be <= range end {:?}", + range.start_bound(), + range.end_bound() + ); + } + start..end +} diff --git a/tests/quick.rs b/tests/quick.rs index 4f8f9e58..f17fff86 100644 --- a/tests/quick.rs +++ b/tests/quick.rs @@ -4,6 +4,7 @@ use itertools::Itertools; use quickcheck::quickcheck; use quickcheck::Arbitrary; use quickcheck::Gen; +use quickcheck::TestResult; use rand::Rng; @@ -18,6 +19,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::hash::Hash; use std::iter::FromIterator; +use std::ops::Bound; use std::ops::Deref; use indexmap::map::Entry as OEntry; @@ -100,7 +102,7 @@ quickcheck! { map.capacity() >= cap } - fn drain(insert: Vec) -> bool { + fn drain_full(insert: Vec) -> bool { let mut map = IndexMap::new(); for &key in &insert { map.insert(key, ()); @@ -113,6 +115,32 @@ quickcheck! 
{ map.is_empty() } + fn drain_bounds(insert: Vec, range: (Bound<usize>, Bound<usize>)) -> TestResult { + let mut map = IndexMap::new(); + for &key in &insert { + map.insert(key, ()); + } + + // First see if `Vec::drain` is happy with this range. + let result = std::panic::catch_unwind(|| { + let mut keys: Vec = map.keys().cloned().collect(); + keys.drain(range); + keys + }); + + if let Ok(keys) = result { + map.drain(range); + // Check that our `drain` matches the same key order. + assert!(map.keys().eq(&keys)); + // Check that hash lookups all work too. + assert!(keys.iter().all(|key| map.contains_key(key))); + TestResult::passed() + } else { + // If `Vec::drain` panicked, so should we. + TestResult::must_fail(move || { map.drain(range); }) + } + } + fn shift_remove(insert: Vec, remove: Vec) -> bool { let mut map = IndexMap::new(); for &key in &insert {