Skip to content

Commit

Permalink
Bugfixes, add launder methods and cupoint from array
Browse files Browse the repository at this point in the history
Fixed cupoint sqdist could underflow subtraction, fixed k_closest keeping ties, fixed find_r with identical coordinates
  • Loading branch information
hacatu committed Mar 5, 2024
1 parent 95e9bfc commit cecf561
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
12 changes: 11 additions & 1 deletion src/cuboid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ where T: Ord + Clone + NumRef {
fn sqdist(&self, other: &Self) -> Self::Distance {
let mut a = T::zero();
for i in 0..N {
let d = self.buf[i].clone() - &other.buf[i];
let (x, y) = (&self.buf[i], &other.buf[i]);
// compute absolute difference between x and y in a really annoying way because generic
// math is annoying and we can't just call x.abs_diff(y)
let d = if x > y { x.clone() - y } else { y.clone() - x };
a = a + d.clone()*&d;
}
a
Expand Down Expand Up @@ -85,6 +88,13 @@ where T: Ord + Clone + NumRef {
}
}

impl<T, const N: usize> From<[T; N]> for CuPoint<T, N>
where T: Ord + Clone + NumRef {
fn from(buf: [T; N]) -> Self {
Self{buf}
}
}

impl<T, const N: usize> KdRegion for CuRegion<T, N>
where T: Ord + Clone + NumRef {
type Point = CuPoint<T, N>;
Expand Down
33 changes: 23 additions & 10 deletions src/kdtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,17 @@ impl<R: KdRegion, V> KdTree<R, V> {
match point.sqdist(pt).cmp(max_sqdist.as_ref().unwrap()) {
Ordering::Greater => (),
Ordering::Equal => tied_points.push((pt, v)),
Ordering::Less => if res.len() + 1 == k {
tied_points.clear();
tied_points.push(res.pushpop_max_by((pt, v), cmp_fn));
max_sqdist = Some(point.sqdist(tied_points[0].0));
while res.peek_max_by(cmp_fn).is_some_and(
|&(p, _)|point.sqdist(p) == *max_sqdist.as_ref().unwrap()
) {
tied_points.push(res.pop_max_by(cmp_fn).unwrap())
Ordering::Less => {
res.push_by((pt, v), cmp_fn);
if res.len() >= k {
tied_points.clear();
tied_points.push(res.pop_max_by(cmp_fn).unwrap());
max_sqdist = Some(point.sqdist(tied_points[0].0));
while res.peek_max_by(cmp_fn).is_some_and(
|&(p, _)|point.sqdist(p) == *max_sqdist.as_ref().unwrap()
) {
tied_points.push(res.pop_max_by(cmp_fn).unwrap())
}
}
}
}
Expand Down Expand Up @@ -286,12 +289,22 @@ impl<R: KdRegion, V> KdTree<R, V> {
).offset_from(self.points.as_ptr()) as usize
}

/// Convert an internal index into a reference to a point in the tree.
/// The internal index must have come from `launder_point_ref` or `launder_value_ref`
/// called on the same tree.
/// The intent of this function is to allow finding the points corresponding to values
/// given a value reference, like for example if some of the values are made into
/// an intrusive linked data structure.
pub unsafe fn launder_idx_point(&self, idx: usize) -> &R::Point {
&self.points[idx].0
}

/// Convert an internal index into a mutable reference to a value in the tree.
/// The internal index must have come from `launder_point_ref` or `launder_value_ref`
/// called on the same tree.
/// The intent of this function is to allow mutating the values of the points in the
/// result set of `k_closest` etc.
pub unsafe fn launder_idx(&mut self, idx: usize) -> &mut V {
pub unsafe fn launder_idx_mut(&mut self, idx: usize) -> &mut V {
&mut self.points[idx].1
}

Expand All @@ -308,7 +321,7 @@ impl<R: KdRegion, V> KdTree<R, V> {
Ordering::Greater => a = mid_idx + 1,
Ordering::Equal => {
if point == p { return mid_idx }
a = self.find_r(point, a, mid_idx, layer + a);
a = self.find_r(point, a, mid_idx, layer + 1);
if a != self.len() { return a }
a = mid_idx + 1
}
Expand Down

0 comments on commit cecf561

Please sign in to comment.