diff --git a/src/k_smallest.rs b/src/k_smallest.rs index ab7f19c86..ca99a41c7 100644 --- a/src/k_smallest.rs +++ b/src/k_smallest.rs @@ -1,19 +1,18 @@ use alloc::vec::IntoIter; use core::cmp::{Ord, Ordering, Reverse}; -use core::mem::{replace, transmute, MaybeUninit}; -use core::ops::Range; +use core::mem::replace; -fn k_smallest_dynamic>( +/// Consumes a given iterator, leaving the minimum elements in the provided storage in **ascending** order. +fn k_smallest_general>( iter: I, k: usize, order: impl Fn(&T, &T) -> Ordering, ) -> IntoIter { - let mut storage = Vec::new(); - storage.resize_with(k, MaybeUninit::uninit); - let Range { end, .. } = capped_heapsort(iter, &mut storage, order); - storage.truncate(end); - let initialized: Vec<_> = unsafe { transmute(storage) }; - initialized.into_iter() + if k == 0 { + return Vec::new().into_iter(); + } + let heap = MaxHeap::from_iter(k, order, iter); + heap.unwrap_sorted().into_iter() } pub(crate) fn reverse_cmp(cmp: F) -> impl Fn(&T, &T) -> Ordering @@ -35,7 +34,7 @@ where T: Ord, I: Iterator, { - k_smallest_dynamic(iter, k, T::cmp) + k_smallest_general(iter, k, T::cmp) } pub(crate) fn k_smallest_by(iter: I, k: usize, cmp: F) -> IntoIter @@ -43,7 +42,7 @@ where I: Iterator, F: Fn(&T, &T) -> Ordering, { - k_smallest_dynamic(iter, k, cmp) + k_smallest_general(iter, k, cmp) } pub(crate) fn k_smallest_by_key(iter: I, k: usize, key: F) -> IntoIter @@ -54,75 +53,54 @@ where { let iter = iter.map(|v| Pair(key(&v), v)); - let results: Vec<_> = k_smallest_dynamic(iter, k, Ord::cmp) + let results: Vec<_> = k_smallest_general(iter, k, Ord::cmp) .map(|Pair(_, t)| t) .collect(); results.into_iter() } -/// Consumes a given iterator, leaving the minimum elements in the provided storage in **ascending** order. -/// Returns the range of initialized elements -fn capped_heapsort>( - iter: I, - storage: &mut [MaybeUninit], - order: impl Fn(&T, &T) -> Ordering, -) -> Range { - if storage.is_empty() { - return 0..0; - } - let mut heap = MaxHeap::from_iter(storage, order, iter); - - let valid_elements = 0..heap.len; - while heap.len > 0 { - heap.pop(); - } - valid_elements -} - /// An efficient heapsort requires that the heap ordering is the inverse of the desired sort order /// So the basic case of retrieving minimum elements requires a max heap /// /// This type does not attempt to reproduce all the functionality of [std::collections::BinaryHeap] and instead only implements what is needed for iter operations, /// e.g. we do not need to insert single elements. /// Additionally, some minor optimizations used in the std BinaryHeap are not used here, e.g. elements are actually swapped rather than managing a "hole" -/// -/// To be generic over the underlying storage, it takes a mutable reference to avoid having to define a storage trait. -struct MaxHeap<'a, T, C> { +struct MaxHeap { // It may be not be possible to shrink the storage for smaller sequencess // so manually manage the initialization // This is **assumed not to be empty** - storage: &'a mut [MaybeUninit], + storage: alloc::vec::Vec, comparator: C, - // SAFETY: this must always be less or equal to the count of actually initialized elements + // this is always less or equal to the count of actual elements + // allowing it to be less means the heap property can cover only a subset of the vec + // while reusing the storage len: usize, } -impl<'a, T, C> MaxHeap<'a, T, C> +impl MaxHeap where C: Fn(&T, &T) -> Ordering, { - fn from_iter(storage: &'a mut [MaybeUninit], comparator: C, mut iter: I) -> Self + fn from_iter(k: usize, comparator: C, mut iter: I) -> Self where I: Iterator, { + let storage: Vec = iter.by_ref().take(k).collect(); + let mut heap = Self { + len: storage.len(), storage, comparator, - len: 0, }; - for (i, initial_item) in iter.by_ref().take(heap.storage.len()).enumerate() { - heap.storage[i] = MaybeUninit::new(initial_item); - heap.len += 1; - } // Filling up the storage and only afterwards rearranging to form a valid heap is slightly more efficient // (But only by a factor of lg(k) and I'd love to hear of a usecase where that matters) heap.heapify(); - if heap.len == heap.storage.len() { + if k == heap.storage.len() { // Nothing else needs done if we didn't fill the storage in the first place // Also avoids unexpected behaviour with restartable iterators for val in iter { - let _ = heap.push_pop(val); + heap.push_pop(val); } } heap @@ -133,11 +111,7 @@ where /// element ordering. fn get(&self, index: usize) -> Option<&T> { if index < self.len { - let ptr = unsafe { - // There might be a better way to do this but assume_init_ref doesn't exist on MSRV - self.storage[index].as_ptr().as_ref().unwrap() - }; - Some(ptr) + self.storage.get(index) } else { None } @@ -168,7 +142,6 @@ where let (original_item, replacement_item) = (self.get(origin), self.get(replacement_idx)); let cmp = self.compare(original_item, replacement_item); - // If the left item also doesn't exist, this comparison will fall through if Some(Ordering::Less) == cmp { self.storage.swap(origin, replacement_idx); self.sift_down(replacement_idx); @@ -188,17 +161,13 @@ where /// Insert the given element into the heap without changing its size /// The displaced element is returned, i.e. either the input or previous max - fn push_pop(&mut self, val: T) -> Option { + fn push_pop(&mut self, val: T) -> T { if self.compare(self.get(0), Some(&val)) == Some(Ordering::Greater) { - let out = replace(&mut self.storage[0], MaybeUninit::new(val)); + let out = replace(&mut self.storage[0], val); self.sift_down(0); - // SAFETY: This has been moved out of storage[0] - // storage[0] will be uninitialized if and only if self.len == 0 - // In that case, self.get(0) above will return None, and the comparison will fall through to None - // So to get here, self.len > 0 and therefore this element was initialized - unsafe { Some(out.assume_init()) } + out } else { - Some(val) + val } } @@ -213,6 +182,14 @@ where fn compare(&self, a: Option<&T>, b: Option<&T>) -> Option { (self.comparator)(a?, b?).into() } + + // Totally orders the elements and returns the raw storage + fn unwrap_sorted(mut self) -> Vec { + while self.len > 1 { + self.pop(); + } + self.storage + } } struct Pair(K, T);