diff --git a/src/static_aabb2d_index.rs b/src/static_aabb2d_index.rs index 1bf0eaf..7fdbbac 100644 --- a/src/static_aabb2d_index.rs +++ b/src/static_aabb2d_index.rs @@ -1,10 +1,14 @@ +use fmt::Debug; use num_traits::{Bounded, Num, NumCast}; -use std::cmp::{max, min}; use std::fmt; +use std::{ + cmp::{max, min}, + collections::BinaryHeap, +}; /// Trait used by the [StaticAABB2DIndex] that is required to be implemented for type T. /// It is blanket implemented for all primitive numeric types. -pub trait IndexableNum: Copy + Num + PartialOrd + Default + Bounded + NumCast { +pub trait IndexableNum: Debug + Copy + Num + PartialOrd + Default + Bounded + NumCast { /// Simple default min implementation for [PartialOrd] types. #[inline] fn min(self, other: Self) -> Self { @@ -863,6 +867,64 @@ where } } +/// Type alias for priority queue used for nearest neighbor searches. +/// +/// See: [StaticAABB2DIndex::visit_neighbors_with_queue]. +pub type NeighborPriorityQueue = BinaryHeap>; + +/// Holds state for priority queue used in nearest neighbors query. +/// +/// Note this type is public for use in passing in an existing priority queue but +/// all fields and constructor are private for internal use only. +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct NeighborsState +where + T: IndexableNum, +{ + index: usize, + is_leaf_node: bool, + dist: T, +} + +impl NeighborsState +where + T: IndexableNum, +{ + fn new(index: usize, is_leaf_node: bool, dist: T) -> Self { + NeighborsState { + index, + is_leaf_node, + dist, + } + } +} + +impl Eq for NeighborsState where T: IndexableNum {} + +impl Ord for NeighborsState +where + T: IndexableNum, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + if let Some(ord) = self.partial_cmp(other) { + ord + } else { + // if ordering not possible (due to NAN) then just consider equal + std::cmp::Ordering::Equal + } + } +} + +impl PartialOrd for NeighborsState +where + T: IndexableNum, +{ + fn partial_cmp(&self, other: &Self) -> Option { + // flip ordering (compare other to self rather than self to other) to prioritize minimum dist in priority queue + other.dist.partial_cmp(&self.dist) + } +} + impl StaticAABB2DIndex where T: IndexableNum, @@ -1079,4 +1141,97 @@ where } } } + + /// Visit all neighboring items in order of minimum euclidean distance to the point defined by `x` and `y` until `visitor` returns false. + /// + /// ## Notes + /// * The visitor function must return false to stop visiting items or all items will be visited. + /// * The visitor function receives the index of the item being visited and the squared euclidean distance to that item from the point given. + /// * Because distances are squared (`dx * dx + dy * dy`) be cautious of smaller numeric types overflowing (e.g. it's easy to overflow an i32 with squared distances). + /// * If the point is inside of an item's bounding box then the euclidean distance is defined as 0. + /// * If repeatedly calling this method then [StaticAABB2DIndex::visit_neighbors_with_queue] can be used to avoid repeated allocations for the priority queue used internally. + #[inline] + pub fn visit_neighbors(&self, x: T, y: T, visitor: &mut F) + where + F: FnMut(usize, T) -> bool, + { + let mut queue = NeighborPriorityQueue::with_capacity(self.boxes.len()); + self.visit_neighbors_with_queue(x, y, visitor, &mut queue); + } + + /// Works the same as [StaticAABB2DIndex::visit_neighbors] but accepts an existing binary heap to be used as a priority queue to avoid allocations. + pub fn visit_neighbors_with_queue( + &self, + x: T, + y: T, + visitor: &mut F, + queue: &mut NeighborPriorityQueue, + ) where + F: FnMut(usize, T) -> bool, + { + // small helper function to compute axis distance between point and bounding box axis + fn axis_dist(k: U, min: U, max: U) -> U + where + U: IndexableNum, + { + if k < min { + min - k + } else if k > max { + k - max + } else { + U::zero() + } + } + + let mut node_index = self.boxes.len() - 1; + queue.clear(); + + 'search_loop: loop { + let upper_bound_level_index = match self.level_bounds.binary_search(&node_index) { + // level bound found, add one to get upper bound + Ok(i) => i + 1, + // level bound not found (node_index is between bounds, do not need to add one to get upper bound) + Err(i) => i, + }; + + // end index of the node + let end = min( + node_index + self.node_size, + self.level_bounds[upper_bound_level_index], + ); + + // add nodes to queue + for pos in node_index..end { + let aabb = get_at_index!(self.boxes, pos); + let dx = axis_dist(x, aabb.min_x, aabb.max_x); + let dy = axis_dist(y, aabb.min_y, aabb.max_y); + let dist = dx * dx + dy * dy; + let index = *get_at_index!(self.indices, pos); + let is_leaf_node = node_index < self.num_items; + queue.push(NeighborsState::new(index, is_leaf_node, dist)); + } + + let mut continue_search = false; + // pop and visit items in queue + while let Some(state) = queue.pop() { + if state.is_leaf_node { + // visit leaf node + if !visitor(state.index, state.dist) { + // stop visiting if visitor returns false + break 'search_loop; + } + } else { + // update node index for next iteration + node_index = state.index; + // set flag to continue search + continue_search = true; + break; + } + } + + if !continue_search { + break 'search_loop; + } + } + } } diff --git a/tests/test.rs b/tests/test.rs index 003cbba..0450917 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -360,3 +360,75 @@ fn visit_query_stops_early() { let expected_superset_indexes: HashSet = [6, 29, 31, 75].iter().cloned().collect(); assert!(results.is_subset(&expected_superset_indexes)); } + +#[test] +fn visit_neighbors_max_results() { + let index = create_test_index(); + let mut results = Vec::new(); + let max_results = 3; + let mut visitor = |i, _| { + results.push(i); + results.len() < max_results + }; + + index.visit_neighbors(50, 50, &mut visitor); + results.sort(); + let expected_indexes = vec![6, 31, 75]; + assert_eq!(results, expected_indexes); +} + +#[test] +fn visit_neighbors_max_distance() { + let index = create_test_index(); + let mut results = Vec::new(); + let max_distance = 12.0; + let max_distance_squared = max_distance * max_distance; + let mut visitor = |i, d| { + if (d as f64) < max_distance_squared { + results.push(i); + return true; + } + false + }; + + index.visit_neighbors(50, 50, &mut visitor); + results.sort(); + let expected_indexes = vec![6, 29, 31, 75, 85]; + assert_eq!(results, expected_indexes); +} + +#[test] +fn visit_neighbors_max_results_filtered() { + let index = create_test_index(); + let mut results = Vec::new(); + let max_results = 6; + let mut visitor = |i, _| { + // filtering by only collecting indexes which are even + if i % 2 == 0 { + results.push(i); + return results.len() < max_results; + } + true + }; + + index.visit_neighbors(50, 50, &mut visitor); + results.sort(); + let expected_indexes = vec![6, 16, 18, 24, 54, 80]; + assert_eq!(results, expected_indexes); +} + +#[test] +fn visit_neighbors_all_items() { + let index = create_test_index(); + let mut results = Vec::new(); + let mut visitor = |i, _| { + results.push(i); + // visit all items by always returning true + true + }; + + index.visit_neighbors(50, 50, &mut visitor); + results.sort(); + let expected_indexes = (0..index.count()).collect::>(); + assert_eq!(results, expected_indexes); +}