diff --git a/pytorch3d/csrc/ball_query/ball_query.cu b/pytorch3d/csrc/ball_query/ball_query.cu index 586701c18..b5115a3ae 100644 --- a/pytorch3d/csrc/ball_query/ball_query.cu +++ b/pytorch3d/csrc/ball_query/ball_query.cu @@ -32,7 +32,9 @@ __global__ void BallQueryKernel( at::PackedTensorAccessor64 idxs, at::PackedTensorAccessor64 dists, const int64_t K, - const float radius2) { + const float radius, + const float radius2, + const bool skip_points_outside_cube) { const int64_t N = p1.size(0); const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x); const int64_t chunks_to_do = N * chunks_per_cloud; @@ -51,7 +53,19 @@ __global__ void BallQueryKernel( // Iterate over points in p2 until desired count is reached or // all points have been considered for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) { - // Calculate the distance between the points + if (skip_points_outside_cube) { + bool is_within_radius = true; + // Filter when any one coordinate is already outside the radius + for (int d = 0; is_within_radius && d < D; ++d) { + scalar_t abs_diff = fabs(p1[n][i][d] - p2[n][j][d]); + is_within_radius = (abs_diff < radius); + } + if (!is_within_radius) { + continue; + } + } + + // Else, calculate the distance between the points and compare scalar_t dist2 = 0.0; for (int d = 0; d < D; ++d) { scalar_t diff = p1[n][i][d] - p2[n][j][d]; @@ -77,7 +91,8 @@ std::tuple BallQueryCuda( const at::Tensor& lengths1, // (N,) const at::Tensor& lengths2, // (N,) int K, - float radius) { + float radius, + bool skip_points_outside_cube) { // Check inputs are on the same device at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}; @@ -120,7 +135,9 @@ std::tuple BallQueryCuda( idxs.packed_accessor64(), dists.packed_accessor64(), K_64, - radius2); + radius, + radius2, + skip_points_outside_cube); })); AT_CUDA_CHECK(cudaGetLastError()); diff --git a/pytorch3d/csrc/ball_query/ball_query.h b/pytorch3d/csrc/ball_query/ball_query.h index eb8f54da2..dc7a7851d 100644 --- a/pytorch3d/csrc/ball_query/ball_query.h +++ b/pytorch3d/csrc/ball_query/ball_query.h @@ -25,6 +25,9 @@ // within the radius // radius: the radius around each point within which the neighbors need to be // located +// skip_points_outside_cube: If true, reduce multiplications of float values +// by not explicitly calculating distances to points that fall outside the +// D-cube with side length (2*radius) centered at each point in p1. // // Returns: // p1_neighbor_idx: LongTensor of shape (N, P1, K), where @@ -46,7 +49,8 @@ std::tuple BallQueryCpu( const at::Tensor& lengths1, const at::Tensor& lengths2, const int K, - const float radius); + const float radius, + const bool skip_points_outside_cube); // CUDA implementation std::tuple BallQueryCuda( @@ -55,7 +59,8 @@ std::tuple BallQueryCuda( const at::Tensor& lengths1, const at::Tensor& lengths2, const int K, - const float radius); + const float radius, + const bool skip_points_outside_cube); // Implementation which is exposed // Note: the backward pass reuses the KNearestNeighborBackward kernel @@ -65,7 +70,8 @@ inline std::tuple BallQuery( const at::Tensor& lengths1, const at::Tensor& lengths2, int K, - float radius) { + float radius, + bool skip_points_outside_cube) { if (p1.is_cuda() || p2.is_cuda()) { #ifdef WITH_CUDA CHECK_CUDA(p1); @@ -76,7 +82,8 @@ inline std::tuple BallQuery( lengths1.contiguous(), lengths2.contiguous(), K, - radius); + radius, + skip_points_outside_cube); #else AT_ERROR("Not compiled with GPU support."); #endif @@ -89,5 +96,6 @@ inline std::tuple BallQuery( lengths1.contiguous(), lengths2.contiguous(), K, - radius); + radius, + skip_points_outside_cube); } diff --git a/pytorch3d/csrc/ball_query/ball_query_cpu.cpp b/pytorch3d/csrc/ball_query/ball_query_cpu.cpp index 24cdf388f..0e2d4b799 100644 --- a/pytorch3d/csrc/ball_query/ball_query_cpu.cpp +++ b/pytorch3d/csrc/ball_query/ball_query_cpu.cpp @@ -7,6 +7,7 @@ */ #include +#include #include std::tuple BallQueryCpu( @@ -15,7 +16,8 @@ std::tuple BallQueryCpu( const at::Tensor& lengths1, const at::Tensor& lengths2, int K, - float radius) { + float radius, + bool skip_points_outside_cube) { const int N = p1.size(0); const int P1 = p1.size(1); const int D = p1.size(2); @@ -37,6 +39,16 @@ std::tuple BallQueryCpu( const int64_t length2 = lengths2_a[n]; for (int64_t i = 0; i < length1; ++i) { for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) { + if (skip_points_outside_cube) { + bool is_within_radius = true; + for (int d = 0; is_within_radius && d < D; ++d) { + float abs_diff = fabs(p1_a[n][i][d] - p2_a[n][j][d]); + is_within_radius = (abs_diff < radius); + } + if (!is_within_radius) { + continue; + } + } float dist2 = 0; for (int d = 0; d < D; ++d) { float diff = p1_a[n][i][d] - p2_a[n][j][d]; diff --git a/pytorch3d/ops/ball_query.py b/pytorch3d/ops/ball_query.py index 31266c4d2..a597db5c9 100644 --- a/pytorch3d/ops/ball_query.py +++ b/pytorch3d/ops/ball_query.py @@ -23,11 +23,12 @@ class _ball_query(Function): """ @staticmethod - def forward(ctx, p1, p2, lengths1, lengths2, K, radius): + def forward(ctx, p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube): """ Arguments defintions the same as in the ball_query function """ - idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius) + idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius, + skip_points_outside_cube) ctx.save_for_backward(p1, p2, lengths1, lengths2, idx) ctx.mark_non_differentiable(idx) return dists, idx @@ -49,7 +50,7 @@ def backward(ctx, grad_dists, grad_idx): grad_p1, grad_p2 = _C.knn_points_backward( p1, p2, lengths1, lengths2, idx, 2, grad_dists ) - return grad_p1, grad_p2, None, None, None, None + return grad_p1, grad_p2, None, None, None, None, None def ball_query( @@ -60,6 +61,7 @@ def ball_query( K: int = 500, radius: float = 0.2, return_nn: bool = True, + skip_points_outside_cube: bool = False, ): """ Ball Query is an alternative to KNN. It can be @@ -98,6 +100,9 @@ def ball_query( within the radius radius: the radius around each point within which the neighbors need to be located return_nn: If set to True returns the K neighbor points in p2 for each point in p1. + skip_points_outside_cube: If set to True, reduce multiplications of float values + by not explicitly calculating distances to points that fall outside the + D-cube with side length (2*radius) centered at each point in p1. Returns: dists: Tensor of shape (N, P1, K) giving the squared distances to @@ -134,7 +139,8 @@ def ball_query( if lengths2 is None: lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device) - dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius) + dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius, + skip_points_outside_cube) # Gather the neighbors if needed points_nn = masked_gather(p2, idx) if return_nn else None diff --git a/tests/benchmarks/bm_ball_query_large.py b/tests/benchmarks/bm_ball_query_large.py new file mode 100644 index 000000000..dd1c185b8 --- /dev/null +++ b/tests/benchmarks/bm_ball_query_large.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product + +import torch +from fvcore.common.benchmark import benchmark + +from pytorch3d.ops.ball_query import ball_query + + +def ball_query_square( + N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str +): + device = torch.device(device) + pts1 = torch.rand(N, P1, D, device=device) + pts2 = torch.rand(N, P2, D, device=device) + torch.cuda.synchronize() + + def output(): + out = ball_query(pts1, pts2, K=K, radius=radius, skip_points_outside_cube=True) + torch.cuda.synchronize() + + return output + + +def bm_ball_query() -> None: + backends = ["cpu", "cuda:0"] + + kwargs_list = [] + Ns = [32] + P1s = [256] + P2s = [2**p for p in range(9, 20, 2)] + Ds = [3, 10] + Ks = [500] + Rs = [0.01, 0.1] + test_cases = product(Ns, P1s, P2s, Ds, Ks, Rs, backends) + for case in test_cases: + N, P1, P2, D, K, R, b = case + kwargs_list.append( + {"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "radius": R, "device": b} + ) + benchmark( + ball_query_square, "BALLQUERY_SQUARE", kwargs_list, + num_iters=30, warmup_iters=1, + ) + + +if __name__ == "__main__": + bm_ball_query()