Skip to content

Commit

Permalink
Added method to visit nearest neighboring items to a point
Browse files Browse the repository at this point in the history
- Added visit_neighbors and visit_neighbors_with_queue methods
- Added tests for visit_neighbors
- Added fmt::Debug trait bound to IndexableNum
- Added NeighborPriorityQueue type alias
- Added NeighborState struct for use with BinaryHeap for
  NeighborPriorityQueue
  • Loading branch information
jbuckmccready committed Feb 16, 2021
1 parent fd2204e commit f2ea39f
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 2 deletions.
159 changes: 157 additions & 2 deletions src/static_aabb2d_index.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use fmt::Debug;
use num_traits::{Bounded, Num, NumCast};
use std::cmp::{max, min};
use std::fmt;
use std::{
cmp::{max, min},
collections::BinaryHeap,
};

/// Trait used by the [StaticAABB2DIndex] that is required to be implemented for type T.
/// It is blanket implemented for all primitive numeric types.
pub trait IndexableNum: Copy + Num + PartialOrd + Default + Bounded + NumCast {
pub trait IndexableNum: Debug + Copy + Num + PartialOrd + Default + Bounded + NumCast {
/// Simple default min implementation for [PartialOrd] types.
#[inline]
fn min(self, other: Self) -> Self {
Expand Down Expand Up @@ -863,6 +867,64 @@ where
}
}

/// Type alias for priority queue used for nearest neighbor searches.
///
/// See: [StaticAABB2DIndex::visit_neighbors_with_queue].
pub type NeighborPriorityQueue<T> = BinaryHeap<NeighborsState<T>>;

/// Holds state for priority queue used in nearest neighbors query.
///
/// Note this type is public for use in passing in an existing priority queue but
/// all fields and constructor are private for internal use only.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct NeighborsState<T>
where
T: IndexableNum,
{
index: usize,
is_leaf_node: bool,
dist: T,
}

impl<T> NeighborsState<T>
where
T: IndexableNum,
{
fn new(index: usize, is_leaf_node: bool, dist: T) -> Self {
NeighborsState {
index,
is_leaf_node,
dist,
}
}
}

impl<T> Eq for NeighborsState<T> where T: IndexableNum {}

impl<T> Ord for NeighborsState<T>
where
T: IndexableNum,
{
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
if let Some(ord) = self.partial_cmp(other) {
ord
} else {
// if ordering not possible (due to NAN) then just consider equal
std::cmp::Ordering::Equal
}
}
}

impl<T> PartialOrd for NeighborsState<T>
where
T: IndexableNum,
{
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
// flip ordering (compare other to self rather than self to other) to prioritize minimum dist in priority queue
other.dist.partial_cmp(&self.dist)
}
}

impl<T> StaticAABB2DIndex<T>
where
T: IndexableNum,
Expand Down Expand Up @@ -1079,4 +1141,97 @@ where
}
}
}

/// Visit all neighboring items in order of minimum euclidean distance to the point defined by `x` and `y` until `visitor` returns false.
///
/// ## Notes
/// * The visitor function must return false to stop visiting items or all items will be visited.
/// * The visitor function receives the index of the item being visited and the squared euclidean distance to that item from the point given.
/// * Because distances are squared (`dx * dx + dy * dy`) be cautious of smaller numeric types overflowing (e.g. it's easy to overflow an i32 with squared distances).
/// * If the point is inside of an item's bounding box then the euclidean distance is defined as 0.
/// * If repeatedly calling this method then [StaticAABB2DIndex::visit_neighbors_with_queue] can be used to avoid repeated allocations for the priority queue used internally.
#[inline]
pub fn visit_neighbors<F>(&self, x: T, y: T, visitor: &mut F)
where
F: FnMut(usize, T) -> bool,
{
let mut queue = NeighborPriorityQueue::with_capacity(self.boxes.len());
self.visit_neighbors_with_queue(x, y, visitor, &mut queue);
}

/// Works the same as [StaticAABB2DIndex::visit_neighbors] but accepts an existing binary heap to be used as a priority queue to avoid allocations.
pub fn visit_neighbors_with_queue<F>(
&self,
x: T,
y: T,
visitor: &mut F,
queue: &mut NeighborPriorityQueue<T>,
) where
F: FnMut(usize, T) -> bool,
{
// small helper function to compute axis distance between point and bounding box axis
fn axis_dist<U>(k: U, min: U, max: U) -> U
where
U: IndexableNum,
{
if k < min {
min - k
} else if k > max {
k - max
} else {
U::zero()
}
}

let mut node_index = self.boxes.len() - 1;
queue.clear();

'search_loop: loop {
let upper_bound_level_index = match self.level_bounds.binary_search(&node_index) {
// level bound found, add one to get upper bound
Ok(i) => i + 1,
// level bound not found (node_index is between bounds, do not need to add one to get upper bound)
Err(i) => i,
};

// end index of the node
let end = min(
node_index + self.node_size,
self.level_bounds[upper_bound_level_index],
);

// add nodes to queue
for pos in node_index..end {
let aabb = get_at_index!(self.boxes, pos);
let dx = axis_dist(x, aabb.min_x, aabb.max_x);
let dy = axis_dist(y, aabb.min_y, aabb.max_y);
let dist = dx * dx + dy * dy;
let index = *get_at_index!(self.indices, pos);
let is_leaf_node = node_index < self.num_items;
queue.push(NeighborsState::new(index, is_leaf_node, dist));
}

let mut continue_search = false;
// pop and visit items in queue
while let Some(state) = queue.pop() {
if state.is_leaf_node {
// visit leaf node
if !visitor(state.index, state.dist) {
// stop visiting if visitor returns false
break 'search_loop;
}
} else {
// update node index for next iteration
node_index = state.index;
// set flag to continue search
continue_search = true;
break;
}
}

if !continue_search {
break 'search_loop;
}
}
}
}
72 changes: 72 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,75 @@ fn visit_query_stops_early() {
let expected_superset_indexes: HashSet<usize> = [6, 29, 31, 75].iter().cloned().collect();
assert!(results.is_subset(&expected_superset_indexes));
}

#[test]
fn visit_neighbors_max_results() {
let index = create_test_index();
let mut results = Vec::new();
let max_results = 3;
let mut visitor = |i, _| {
results.push(i);
results.len() < max_results
};

index.visit_neighbors(50, 50, &mut visitor);
results.sort();
let expected_indexes = vec![6, 31, 75];
assert_eq!(results, expected_indexes);
}

#[test]
fn visit_neighbors_max_distance() {
let index = create_test_index();
let mut results = Vec::new();
let max_distance = 12.0;
let max_distance_squared = max_distance * max_distance;
let mut visitor = |i, d| {
if (d as f64) < max_distance_squared {
results.push(i);
return true;
}
false
};

index.visit_neighbors(50, 50, &mut visitor);
results.sort();
let expected_indexes = vec![6, 29, 31, 75, 85];
assert_eq!(results, expected_indexes);
}

#[test]
fn visit_neighbors_max_results_filtered() {
let index = create_test_index();
let mut results = Vec::new();
let max_results = 6;
let mut visitor = |i, _| {
// filtering by only collecting indexes which are even
if i % 2 == 0 {
results.push(i);
return results.len() < max_results;
}
true
};

index.visit_neighbors(50, 50, &mut visitor);
results.sort();
let expected_indexes = vec![6, 16, 18, 24, 54, 80];
assert_eq!(results, expected_indexes);
}

#[test]
fn visit_neighbors_all_items() {
let index = create_test_index();
let mut results = Vec::new();
let mut visitor = |i, _| {
results.push(i);
// visit all items by always returning true
true
};

index.visit_neighbors(50, 50, &mut visitor);
results.sort();
let expected_indexes = (0..index.count()).collect::<Vec<_>>();
assert_eq!(results, expected_indexes);
}

0 comments on commit f2ea39f

Please sign in to comment.