MaxHeap now uses Vec
Being generic over storage at the algorithm level is a
neater strategy (see rust-itertools#654)
ejmount committed Oct 23, 2022
1 parent 887e946 commit b195cbc
Showing 1 changed file with 36 additions and 59 deletions.
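For orientation before the diff: the new code keeps the k smallest elements seen so far in a max-heap of size k, so the heap's maximum is always the current "worst" candidate and any smaller incoming element replaces it. Below is a minimal sketch of the same idea written against std's BinaryHeap; the helper name is illustrative only, and the patch instead hand-rolls the heap on a `Vec<T>` so it can take an arbitrary comparator rather than requiring `T: Ord`.

```rust
use std::collections::BinaryHeap;

// Sketch only: keep the k smallest items via a size-k max-heap.
fn k_smallest_sketch<T: Ord>(iter: impl IntoIterator<Item = T>, k: usize) -> Vec<T> {
    let mut iter = iter.into_iter();
    if k == 0 {
        return Vec::new();
    }
    // Seed the heap with (up to) the first k elements.
    let mut heap: BinaryHeap<T> = iter.by_ref().take(k).collect();
    // Only keep filtering if the heap actually reached size k,
    // mirroring the `k == heap.storage.len()` check in the patch.
    if heap.len() == k {
        for item in iter {
            if item < *heap.peek().expect("heap is non-empty") {
                // Replace the current maximum, like MaxHeap::push_pop.
                heap.pop();
                heap.push(item);
            }
        }
    }
    // Extract in ascending order, like MaxHeap::unwrap_sorted.
    heap.into_sorted_vec()
}
```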
95 changes: 36 additions & 59 deletions src/k_smallest.rs
@@ -1,19 +1,18 @@
use alloc::vec::IntoIter;
use core::cmp::{Ord, Ordering, Reverse};
use core::mem::{replace, transmute, MaybeUninit};
use core::ops::Range;
use core::mem::replace;

fn k_smallest_dynamic<T, I: Iterator<Item = T>>(
/// Consumes a given iterator, returning the minimum elements in **ascending** order.
fn k_smallest_general<T, I: Iterator<Item = T>>(
iter: I,
k: usize,
order: impl Fn(&T, &T) -> Ordering,
) -> IntoIter<T> {
let mut storage = Vec::new();
storage.resize_with(k, MaybeUninit::uninit);
let Range { end, .. } = capped_heapsort(iter, &mut storage, order);
storage.truncate(end);
let initialized: Vec<_> = unsafe { transmute(storage) };
initialized.into_iter()
if k == 0 {
return Vec::new().into_iter();
}
let heap = MaxHeap::from_iter(k, order, iter);
heap.unwrap_sorted().into_iter()
}

pub(crate) fn reverse_cmp<T, F>(cmp: F) -> impl Fn(&T, &T) -> Ordering
@@ -35,15 +34,15 @@ where
T: Ord,
I: Iterator<Item = T>,
{
k_smallest_dynamic(iter, k, T::cmp)
k_smallest_general(iter, k, T::cmp)
}

pub(crate) fn k_smallest_by<T, I, F>(iter: I, k: usize, cmp: F) -> IntoIter<T>
where
I: Iterator<Item = T>,
F: Fn(&T, &T) -> Ordering,
{
k_smallest_dynamic(iter, k, cmp)
k_smallest_general(iter, k, cmp)
}

pub(crate) fn k_smallest_by_key<T, I, F, K>(iter: I, k: usize, key: F) -> IntoIter<T>
@@ -54,75 +53,54 @@ where
{
let iter = iter.map(|v| Pair(key(&v), v));

let results: Vec<_> = k_smallest_dynamic(iter, k, Ord::cmp)
let results: Vec<_> = k_smallest_general(iter, k, Ord::cmp)
.map(|Pair(_, t)| t)
.collect();
results.into_iter()
}

/// Consumes a given iterator, leaving the minimum elements in the provided storage in **ascending** order.
/// Returns the range of initialized elements
fn capped_heapsort<T, I: Iterator<Item = T>>(
iter: I,
storage: &mut [MaybeUninit<T>],
order: impl Fn(&T, &T) -> Ordering,
) -> Range<usize> {
if storage.is_empty() {
return 0..0;
}
let mut heap = MaxHeap::from_iter(storage, order, iter);

let valid_elements = 0..heap.len;
while heap.len > 0 {
heap.pop();
}
valid_elements
}

/// An efficient heapsort requires that the heap ordering is the inverse of the desired sort order
/// So the basic case of retrieving minimum elements requires a max heap
///
/// This type does not attempt to reproduce all the functionality of [std::collections::BinaryHeap] and instead only implements what is needed for iter operations,
/// e.g. we do not need to insert single elements.
/// Additionally, some minor optimizations used in the std BinaryHeap are not used here, e.g. elements are actually swapped rather than managing a "hole"
///
/// To be generic over the underlying storage, it takes a mutable reference to avoid having to define a storage trait.
struct MaxHeap<'a, T, C> {
struct MaxHeap<T, C> {
// It may not be possible to shrink the storage for smaller sequences
// so manually manage the initialization
// This is **assumed not to be empty**
storage: &'a mut [MaybeUninit<T>],
storage: alloc::vec::Vec<T>,
comparator: C,
// SAFETY: this must always be less or equal to the count of actually initialized elements
// this is always less than or equal to the count of actual elements
// allowing it to be less means the heap property can cover only a subset of the vec
// while reusing the storage
len: usize,
}

impl<'a, T, C> MaxHeap<'a, T, C>
impl<T, C> MaxHeap<T, C>
where
C: Fn(&T, &T) -> Ordering,
{
fn from_iter<I>(storage: &'a mut [MaybeUninit<T>], comparator: C, mut iter: I) -> Self
fn from_iter<I>(k: usize, comparator: C, mut iter: I) -> Self
where
I: Iterator<Item = T>,
{
let storage: Vec<T> = iter.by_ref().take(k).collect();

let mut heap = Self {
len: storage.len(),
storage,
comparator,
len: 0,
};
for (i, initial_item) in iter.by_ref().take(heap.storage.len()).enumerate() {
heap.storage[i] = MaybeUninit::new(initial_item);
heap.len += 1;
}
// Filling up the storage and only afterwards rearranging to form a valid heap is slightly more efficient
// (But only by a factor of lg(k) and I'd love to hear of a use case where that matters)
heap.heapify();

if heap.len == heap.storage.len() {
if k == heap.storage.len() {
// Nothing else needs to be done if we didn't fill the storage in the first place
// Also avoids unexpected behaviour with restartable iterators
for val in iter {
let _ = heap.push_pop(val);
heap.push_pop(val);
}
}
heap
@@ -133,11 +111,7 @@ where
/// element ordering.
fn get(&self, index: usize) -> Option<&T> {
if index < self.len {
let ptr = unsafe {
// There might be a better way to do this but assume_init_ref doesn't exist on MSRV
self.storage[index].as_ptr().as_ref().unwrap()
};
Some(ptr)
self.storage.get(index)
} else {
None
}
@@ -168,7 +142,6 @@ where
let (original_item, replacement_item) = (self.get(origin), self.get(replacement_idx));

let cmp = self.compare(original_item, replacement_item);
// If the left item also doesn't exist, this comparison will fall through
if Some(Ordering::Less) == cmp {
self.storage.swap(origin, replacement_idx);
self.sift_down(replacement_idx);
@@ -188,17 +161,13 @@

/// Insert the given element into the heap without changing its size
/// The displaced element is returned, i.e. either the input or previous max
fn push_pop(&mut self, val: T) -> Option<T> {
fn push_pop(&mut self, val: T) -> T {
if self.compare(self.get(0), Some(&val)) == Some(Ordering::Greater) {
let out = replace(&mut self.storage[0], MaybeUninit::new(val));
let out = replace(&mut self.storage[0], val);
self.sift_down(0);
// SAFETY: This has been moved out of storage[0]
// storage[0] will be uninitialized if and only if self.len == 0
// In that case, self.get(0) above will return None, and the comparison will fall through to None
// So to get here, self.len > 0 and therefore this element was initialized
unsafe { Some(out.assume_init()) }
out
} else {
Some(val)
val
}
}

@@ -213,6 +182,14 @@ where
fn compare(&self, a: Option<&T>, b: Option<&T>) -> Option<Ordering> {
(self.comparator)(a?, b?).into()
}

// Totally orders the elements and returns the raw storage
fn unwrap_sorted(mut self) -> Vec<T> {
while self.len > 1 {
self.pop();
}
self.storage
}
}

struct Pair<K, T>(K, T);
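For completeness, a minimal usage sketch of the public `Itertools::k_smallest` adaptor that these `pub(crate)` helpers back; the input values here are illustrative only, and the `_by`/`_by_key` variants follow the same pattern with a comparator or key function.

```rust
use itertools::Itertools;

fn main() {
    let numbers = vec![5, 9, 1, 4, 7, 2, 8];

    // The k smallest elements are yielded in ascending order,
    // as documented on k_smallest_general above.
    let smallest: Vec<_> = numbers.into_iter().k_smallest(3).collect();
    assert_eq!(smallest, vec![1, 2, 4]);
}
```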
