From 01b5f7b228378b6d12eaa78b86fb5215d6b4eec7 Mon Sep 17 00:00:00 2001
From: Jeremy Reizenstein
Date: Tue, 7 Apr 2020 01:45:43 -0700
Subject: [PATCH] heterogeneous KNN

Summary: Interface and working implementation of ragged KNN. The existing
benchmarks (which aren't ragged) have not slowed. A new benchmark shows that
ragged KNN is faster than non-ragged KNN of the same shape.

Reviewed By: jcjohnson

Differential Revision: D20696507

fbshipit-source-id: 21b80f71343a3475c8d3ee0ce2680f92f0fae4de
---
 pytorch3d/csrc/knn/knn.cu      | 115 ++++++++++++++++++---------
 pytorch3d/csrc/knn/knn.h       |  29 +++++--
 pytorch3d/csrc/knn/knn_cpu.cpp |  18 +++--
 pytorch3d/ops/knn.py           | 139 ++++++++++++++++++++++++++-------
 tests/bm_knn.py                |  49 +++++++++++-
 tests/test_knn.py              |  66 ++++++++++++++--
 6 files changed, 332 insertions(+), 84 deletions(-)

diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu
index a6d53951b..26201dda7 100644
--- a/pytorch3d/csrc/knn/knn.cu
+++ b/pytorch3d/csrc/knn/knn.cu
@@ -8,10 +8,20 @@
 #include "dispatch.cuh"
 #include "mink.cuh"
 
+// A chunk of work is blocksize-many points of P1.
+// The number of potential chunks to do is N*(1+(P1-1)/blocksize),
+// where we call (1+(P1-1)/blocksize) chunks_per_cloud.
+// These chunks are divided among the gridSize-many blocks.
+// In block b, we work on chunks b, b+gridSize, b+2*gridSize, etc.
+// In chunk i, we work on cloud i/chunks_per_cloud, on points starting from
+// blocksize*(i%chunks_per_cloud).
+
 template <typename scalar_t>
 __global__ void KNearestNeighborKernelV0(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
+    const int64_t* __restrict__ lengths1,
+    const int64_t* __restrict__ lengths2,
     scalar_t* __restrict__ dists,
     int64_t* __restrict__ idxs,
     const size_t N,
@@ -19,18 +29,19 @@ __global__ void KNearestNeighborKernelV0(
     const size_t P2,
     const size_t D,
     const size_t K) {
-  // Stupid version: Make each thread handle one query point and loop over
-  // all P2 target points. There are N * P1 input points to handle, so
-  // do a trivial parallelization over threads.
   // Store both dists and indices for knn in global memory.
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
-  for (int np = tid; np < N * P1; np += num_threads) {
-    int n = np / P1;
-    int p1 = np % P1;
+  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+  const int64_t chunks_to_do = N * chunks_per_cloud;
+  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+    const int64_t n = chunk / chunks_per_cloud;
+    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+    int64_t p1 = start_point + threadIdx.x;
+    if (p1 >= lengths1[n])
+      continue;
     int offset = n * P1 * K + p1 * K;
+    int64_t length2 = lengths2[n];
     MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
-    for (int p2 = 0; p2 < P2; ++p2) {
+    for (int p2 = 0; p2 < length2; ++p2) {
       // Find the distance between points1[n, p1] and points2[n, p2]
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
@@ -48,6 +59,8 @@ template <typename scalar_t, int64_t D>
 __global__ void KNearestNeighborKernelV1(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
+    const int64_t* __restrict__ lengths1,
+    const int64_t* __restrict__ lengths2,
     scalar_t* __restrict__ dists,
     int64_t* __restrict__ idxs,
     const size_t N,
@@ -58,18 +71,22 @@ __global__ void KNearestNeighborKernelV1(
   // so we can cache the current point in a thread-local array. We still store
   // the current best K dists and indices in global memory, so this should work
   // for very large K and fairly large D.
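A minimal Python sketch of the chunk-to-work mapping described in the comment at
the top of knn.cu (the names here are illustrative, not part of the patch):

    # One CUDA block handles chunk c; recover which cloud n it belongs to and
    # the first point of the chunk. Thread t then handles point start + t,
    # skipping it when start + t >= lengths1[n].
    def chunk_to_work(c, P1, blocksize):
        chunks_per_cloud = 1 + (P1 - 1) // blocksize  # ceil(P1 / blocksize)
        n = c // chunks_per_cloud
        start = blocksize * (c % chunks_per_cloud)
        return n, start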
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
   scalar_t cur_point[D];
-  for (int np = tid; np < N * P1; np += num_threads) {
-    int n = np / P1;
-    int p1 = np % P1;
+  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+  const int64_t chunks_to_do = N * chunks_per_cloud;
+  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+    const int64_t n = chunk / chunks_per_cloud;
+    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+    int64_t p1 = start_point + threadIdx.x;
+    if (p1 >= lengths1[n])
+      continue;
     for (int d = 0; d < D; ++d) {
       cur_point[d] = points1[n * P1 * D + p1 * D + d];
     }
     int offset = n * P1 * K + p1 * K;
+    int64_t length2 = lengths2[n];
     MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
-    for (int p2 = 0; p2 < P2; ++p2) {
+    for (int p2 = 0; p2 < length2; ++p2) {
       // Find the distance between cur_point and points[n, p2]
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
@@ -89,14 +106,16 @@ struct KNearestNeighborV1Functor {
       size_t threads,
       const scalar_t* __restrict__ points1,
       const scalar_t* __restrict__ points2,
+      const int64_t* __restrict__ lengths1,
+      const int64_t* __restrict__ lengths2,
       scalar_t* __restrict__ dists,
       int64_t* __restrict__ idxs,
       const size_t N,
       const size_t P1,
       const size_t P2,
       const size_t K) {
-    KNearestNeighborKernelV1<scalar_t, D>
-        <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
+    KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
   }
 };
 
@@ -104,25 +123,31 @@ template <typename scalar_t, int64_t D, int64_t K>
 __global__ void KNearestNeighborKernelV2(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
+    const int64_t* __restrict__ lengths1,
+    const int64_t* __restrict__ lengths2,
     scalar_t* __restrict__ dists,
     int64_t* __restrict__ idxs,
     const int64_t N,
     const int64_t P1,
     const int64_t P2) {
   // Same general implementation as V1, but also hoist K into a template arg.
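For intuition, MinK (from mink.cuh) maintains the K smallest (distance, index)
pairs seen so far in the buffers handed to it. A rough Python model of that
behaviour, assuming the unsorted mode used in these kernels:

    # Track the K smallest distances; once full, replace the current maximum
    # whenever a smaller distance arrives.
    def mink_add(dists, idxs, K, dist, idx):
        if len(dists) < K:
            dists.append(dist)
            idxs.append(idx)
        else:
            j = max(range(K), key=lambda i: dists[i])  # slot of current max
            if dist < dists[j]:
                dists[j] = dist
                idxs[j] = idx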
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
   scalar_t cur_point[D];
   scalar_t min_dists[K];
   int min_idxs[K];
-  for (int np = tid; np < N * P1; np += num_threads) {
-    int n = np / P1;
-    int p1 = np % P1;
+  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+  const int64_t chunks_to_do = N * chunks_per_cloud;
+  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+    const int64_t n = chunk / chunks_per_cloud;
+    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+    int64_t p1 = start_point + threadIdx.x;
+    if (p1 >= lengths1[n])
+      continue;
     for (int d = 0; d < D; ++d) {
       cur_point[d] = points1[n * P1 * D + p1 * D + d];
     }
+    int64_t length2 = lengths2[n];
     MinK<scalar_t, int> mink(min_dists, min_idxs, K);
-    for (int p2 = 0; p2 < P2; ++p2) {
+    for (int p2 = 0; p2 < length2; ++p2) {
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
@@ -146,13 +171,15 @@ struct KNearestNeighborKernelV2Functor {
       size_t threads,
       const scalar_t* __restrict__ points1,
       const scalar_t* __restrict__ points2,
+      const int64_t* __restrict__ lengths1,
+      const int64_t* __restrict__ lengths2,
       scalar_t* __restrict__ dists,
       int64_t* __restrict__ idxs,
       const int64_t N,
       const int64_t P1,
       const int64_t P2) {
-    KNearestNeighborKernelV2<scalar_t, D, K>
-        <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
+    KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
   }
 };
 
@@ -160,6 +187,8 @@ template <typename scalar_t, int64_t D, int64_t K>
 __global__ void KNearestNeighborKernelV3(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
+    const int64_t* __restrict__ lengths1,
+    const int64_t* __restrict__ lengths2,
     scalar_t* __restrict__ dists,
     int64_t* __restrict__ idxs,
     const size_t N,
@@ -169,19 +198,23 @@ __global__ void KNearestNeighborKernelV3(
   // Enabling sorting for this version leads to huge slowdowns; I suspect
   // that it forces min_dists into local memory rather than registers.
   // As a result this version is always unsorted.
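Since V3 always produces unsorted results, ordering (when requested) is applied
afterwards; in Python terms it is essentially the sorted=True branch of
pytorch3d/ops/knn.py shown later in this patch:

    # Reorder each query point's K results by ascending distance.
    dists, sort_idx = dists.sort(dim=2)
    idx = idx.gather(2, sort_idx)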
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
   scalar_t cur_point[D];
   scalar_t min_dists[K];
   int min_idxs[K];
-  for (int np = tid; np < N * P1; np += num_threads) {
-    int n = np / P1;
-    int p1 = np % P1;
+  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+  const int64_t chunks_to_do = N * chunks_per_cloud;
+  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+    const int64_t n = chunk / chunks_per_cloud;
+    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+    int64_t p1 = start_point + threadIdx.x;
+    if (p1 >= lengths1[n])
+      continue;
     for (int d = 0; d < D; ++d) {
       cur_point[d] = points1[n * P1 * D + p1 * D + d];
     }
+    int64_t length2 = lengths2[n];
     RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
-    for (int p2 = 0; p2 < P2; ++p2) {
+    for (int p2 = 0; p2 < length2; ++p2) {
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
@@ -205,13 +238,15 @@ struct KNearestNeighborKernelV3Functor {
       size_t threads,
       const scalar_t* __restrict__ points1,
       const scalar_t* __restrict__ points2,
+      const int64_t* __restrict__ lengths1,
+      const int64_t* __restrict__ lengths2,
       scalar_t* __restrict__ dists,
       int64_t* __restrict__ idxs,
       const size_t N,
       const size_t P1,
       const size_t P2) {
-    KNearestNeighborKernelV3<scalar_t, D, K>
-        <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
+    KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
   }
 };
 
@@ -257,6 +292,8 @@ int ChooseVersion(const int64_t D, const int64_t K) {
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
     const at::Tensor& p1,
     const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
     int K,
     int version) {
   const auto N = p1.size(0);
@@ -267,8 +304,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
   AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
   auto long_dtype = p1.options().dtype(at::kLong);
-  auto idxs = at::full({N, P1, K}, -1, long_dtype);
-  auto dists = at::full({N, P1, K}, -1, p1.options());
+  auto idxs = at::zeros({N, P1, K}, long_dtype);
+  auto dists = at::zeros({N, P1, K}, p1.options());
 
   if (version < 0) {
     version = ChooseVersion(D, K);
@@ -294,6 +331,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
         <<<blocks, threads>>>(
             p1.data_ptr<scalar_t>(),
             p2.data_ptr<scalar_t>(),
+            lengths1.data_ptr<int64_t>(),
+            lengths2.data_ptr<int64_t>(),
             dists.data_ptr<scalar_t>(),
             idxs.data_ptr<int64_t>(),
             N,
@@ -314,6 +353,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           threads,
           p1.data_ptr<scalar_t>(),
           p2.data_ptr<scalar_t>(),
+          lengths1.data_ptr<int64_t>(),
+          lengths2.data_ptr<int64_t>(),
          dists.data_ptr<scalar_t>(),
           idxs.data_ptr<int64_t>(),
           N,
@@ -336,6 +377,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           threads,
           p1.data_ptr<scalar_t>(),
           p2.data_ptr<scalar_t>(),
+          lengths1.data_ptr<int64_t>(),
+          lengths2.data_ptr<int64_t>(),
           dists.data_ptr<scalar_t>(),
           idxs.data_ptr<int64_t>(),
           N,
@@ -357,6 +400,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           threads,
           p1.data_ptr<scalar_t>(),
           p2.data_ptr<scalar_t>(),
+          lengths1.data_ptr<int64_t>(),
+          lengths2.data_ptr<int64_t>(),
           dists.data_ptr<scalar_t>(),
           idxs.data_ptr<int64_t>(),
           N,
diff --git a/pytorch3d/csrc/knn/knn.h b/pytorch3d/csrc/knn/knn.h
index 65c3732b5..de30d2e1e 100644
--- a/pytorch3d/csrc/knn/knn.h
+++ b/pytorch3d/csrc/knn/knn.h
@@ -13,25 +13,38 @@
 //        containing P1 points of dimension D.
 //    p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
 //        containing P2 points of dimension D.
+//    lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
+//    lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
 //    K: int giving the number of nearest points to return.
 //    sorted: bool telling whether to sort the K returned points by their
 //        distance.
 //    version: Integer telling which implementation to use.
-//        TODO(jcjohns): Document this more, or maybe remove it before landing.
 //
 // Returns:
 //    p1_neighbor_idx: LongTensor of shape (N, P1, K), where
-//        p1_neighbor_idx[n, i, k] = j means that the kth nearest
-//        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+//        p1_neighbor_idx[n, i, k] = j means that the kth nearest
+//        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+//        It is padded with zeros so that it can be used easily in a later
+//        gather() operation.
+//
+//    p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
+//        distance from each point p1[n, p, :] to its K neighbors
+//        p2[n, p1_neighbor_idx[n, p, k], :].
 
 // CPU implementation.
-std::tuple<at::Tensor, at::Tensor>
-KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K);
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    int K);
 
 // CUDA implementation
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
     const at::Tensor& p1,
     const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
     int K,
     int version);
 
@@ -39,16 +52,18 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
     const at::Tensor& p1,
     const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
     int K,
     int version) {
   if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CONTIGUOUS_CUDA(p1);
     CHECK_CONTIGUOUS_CUDA(p2);
-    return KNearestNeighborIdxCuda(p1, p2, K, version);
+    return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
-  return KNearestNeighborIdxCpu(p1, p2, K);
+  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
 }
diff --git a/pytorch3d/csrc/knn/knn_cpu.cpp b/pytorch3d/csrc/knn/knn_cpu.cpp
index a2a55d2c0..84d18a657 100644
--- a/pytorch3d/csrc/knn/knn_cpu.cpp
+++ b/pytorch3d/csrc/knn/knn_cpu.cpp
@@ -4,27 +4,35 @@
 #include <queue>
 #include <tuple>
 
-std::tuple<at::Tensor, at::Tensor>
-KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    int K) {
   const int N = p1.size(0);
   const int P1 = p1.size(1);
   const int D = p1.size(2);
   const int P2 = p2.size(1);
 
   auto long_opts = p1.options().dtype(torch::kInt64);
-  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
+  torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
   torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
 
   auto p1_a = p1.accessor<float, 3>();
   auto p2_a = p2.accessor<float, 3>();
+  auto lengths1_a = lengths1.accessor<int64_t, 1>();
+  auto lengths2_a = lengths2.accessor<int64_t, 1>();
   auto idxs_a = idxs.accessor<int64_t, 3>();
   auto dists_a = dists.accessor<float, 3>();
 
   for (int n = 0; n < N; ++n) {
-    for (int i1 = 0; i1 < P1; ++i1) {
+    const int64_t length1 = lengths1_a[n];
+    const int64_t length2 = lengths2_a[n];
+    for (int64_t i1 = 0; i1 < length1; ++i1) {
       // Use a priority queue to store (distance, index) tuples.
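For intuition, the CPU loop that follows is equivalent to this Python sketch for
a single query point (illustrative only; heapq is a min-heap, so values are
negated to emulate the max-heap std::priority_queue used in the C++ code):

    import heapq

    def knn_one_point(query, cloud, K):
        # Keep the K smallest squared distances in a size-K max-heap.
        heap = []  # stores (-dist, index); the largest distance pops first
        for i, pt in enumerate(cloud):
            d = sum((a - b) ** 2 for a, b in zip(query, pt))
            if len(heap) < K:
                heapq.heappush(heap, (-d, i))
            elif -d > heap[0][0]:
                heapq.heapreplace(heap, (-d, i))
        pairs = sorted((-nd, i) for nd, i in heap)  # ascending by distance
        return [i for _, i in pairs], [d for d, _ in pairs]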
       std::priority_queue<std::tuple<float, int64_t>> q;
-      for (int i2 = 0; i2 < P2; ++i2) {
+      for (int64_t i2 = 0; i2 < length2; ++i2) {
         float dist = 0;
         for (int d = 0; d < D; ++d) {
           float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
diff --git a/pytorch3d/ops/knn.py b/pytorch3d/ops/knn.py
index 2ec35992b..122336dfe 100644
--- a/pytorch3d/ops/knn.py
+++ b/pytorch3d/ops/knn.py
@@ -4,37 +4,80 @@
 from pytorch3d import _C
 
 
-def knn_points_idx(p1, p2, K, sorted=False, version=-1):
+def knn_points_idx(
+    p1,
+    p2,
+    K: int,
+    lengths1=None,
+    lengths2=None,
+    sorted: bool = False,
+    version: int = -1,
+):
     """
     K-Nearest neighbors on point clouds.
 
     Args:
         p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
-            containing P1 points of dimension D.
+            containing up to P1 points of dimension D.
         p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
-            containing P2 points of dimension D.
-        K: Integer giving the number of nearest neighbors to return
+            containing up to P2 points of dimension D.
+        K: Integer giving the number of nearest neighbors to return.
+        lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
+            length of each pointcloud in p1. Or None to indicate that every cloud has
+            length P1.
+        lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
+            length of each pointcloud in p2. Or None to indicate that every cloud has
+            length P2.
         sorted: Whether to sort the resulting points.
         version: Which KNN implementation to use in the backend. If version=-1,
-            the correct implementation is selected based on the shapes of
-            the inputs.
+            the correct implementation is selected based on the shapes of the inputs.
 
     Returns:
-        idx: LongTensor of shape (N, P1, K) giving the indices of the
-            K nearest neighbors from points in p1 to points in p2.
-            Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
-            nearest neighbor to p1[n, i] in p2[n]. If sorted=True, then
-            p2[n, j] is the kth nearest neighbor to p1[n, i].
+        p1_neighbor_idx: LongTensor of shape (N, P1, K) giving the indices of the
+            K nearest neighbors from points in p1 to points in p2.
+            Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K nearest
+            neighbors to p1[n, i] in p2[n]. If sorted=True, then p2[n, j] is the kth
+            nearest neighbor to p1[n, i]. This is padded with zeros both where a cloud
+            in p2 has fewer than K points and where a cloud in p1 has fewer than P1
+            points.
+            If you want an (N, P1, K, D) tensor of the actual points, you can get it
+            using
+                p2[:, :, None].expand(-1, -1, K, -1).gather(1,
+                    x_idx[:, :, :, None].expand(-1, -1, -1, D)
+                )
+            If K=1 and you want an (N, P1, D) tensor of the actual points, use
+                p2.gather(1, x_idx.expand(-1, -1, D))
+
+        p1_neighbor_dists: Tensor of shape (N, P1, K) giving the squared distances to
+            the nearest neighbors. This is padded with zeros both where a cloud in p2
+            has fewer than K points and where a cloud in p1 has fewer than P1 points.
+            Warning: this is calculated outside of the autograd framework.
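To make the gather() recipe above concrete, an illustrative example that could
accompany this docstring (the shapes here are made up):

+
+        Example:
+            .. code-block:: python
+
+                N, P1, P2, D, K = 2, 8, 16, 3, 4
+                p1 = torch.randn(N, P1, D)
+                p2 = torch.randn(N, P2, D)
+                idx, dists = knn_points_idx(p1, p2, K=K, sorted=True)
+                # Neighbor coordinates, with shape (N, P1, K, D):
+                nbrs = p2[:, :, None].expand(-1, -1, K, -1).gather(
+                    1, idx[:, :, :, None].expand(-1, -1, -1, D)
+                )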
""" - idx, dists = _C.knn_points_idx(p1, p2, K, version) + P1 = p1.shape[1] + P2 = p2.shape[1] + if lengths1 is None: + lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device) + if lengths2 is None: + lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device) + idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version) if sorted: - dists, sort_idx = dists.sort(dim=2) + if lengths2.min() < K: + device = dists.device + mask1 = lengths2[:, None] <= torch.arange(K, device=device)[None] + # mask1 has shape [N, K], true where dists irrelevant + mask2 = mask1[:, None].expand(-1, P1, -1) + # mask2 has shape [N, P1, K], true where dists irrelevant + dists[mask2] = float("inf") + dists, sort_idx = dists.sort(dim=2) + dists[mask2] = 0 + else: + dists, sort_idx = dists.sort(dim=2) idx = idx.gather(2, sort_idx) return idx, dists @torch.no_grad() -def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor: +def _knn_points_idx_naive(p1, p2, K: int, lengths1, lengths2) -> torch.Tensor: """ Naive PyTorch implementation of K-Nearest Neighbors. @@ -43,25 +86,67 @@ def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor: Args: p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each - containing P1 points of dimension D. + containing up to P1 points of dimension D. p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each - containing P2 points of dimension D. - K: Integer giving the number of nearest neighbors to return - sorted: Whether to sort the resulting points. + containing up to P2 points of dimension D. + K: Integer giving the number of nearest neighbors to return. + lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the + length of each pointcloud in p1. Or None to indicate that every cloud has + length P1. + lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the + length of each pointcloud in p2. Or None to indicate that every cloud has + length P2. Returns: idx: LongTensor of shape (N, P1, K) giving the indices of the - K nearest neighbors from points in p1 to points in p2. - Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K - nearest neighbor to p1[n, i] in p2[n]. If sorted=True, then - p2[n, j] is the kth nearest neighbor to p1[n, i]. - dists: Tensor of shape (N, P1, K) giving the distances to the nearest - neighbors. + K nearest neighbors from points in p1 to points in p2. + Concretely, if idx[n, i, k] = j then p2[n, j] is the kth nearest neighbor + to p1[n, i]. This is padded with zeros both where a cloud in p2 has fewer + than K points and where a cloud in p1 has fewer than P1 points. + dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest + neighbors. This is padded with zeros both where a cloud in p2 has fewer than + K points and where a cloud in p1 has fewer than P1 points. """ N, P1, D = p1.shape _N, P2, _D = p2.shape + assert N == _N and D == _D - diffs = p1.view(N, P1, 1, D) - p2.view(N, 1, P2, D) + + if lengths1 is None: + lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device) + if lengths2 is None: + lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device) + + p1_copy = p1.clone() + p2_copy = p2.clone() + + # We pad the values with infinities so that the smallest differences are + # among actual points. 
+    inf = float("inf")
+    p1_mask = torch.arange(P1, device=p1.device)[None] >= lengths1[:, None]
+    p1_copy[p1_mask] = inf
+    p2_copy[torch.arange(P2, device=p1.device)[None] >= lengths2[:, None]] = -inf
+
+    # view is safe here: we are merely adding extra dimensions of length 1
+    diffs = p1_copy.view(N, P1, 1, D) - p2_copy.view(N, 1, P2, D)
     dists2 = (diffs * diffs).sum(dim=3)
-    out = dists2.topk(K, dim=2, largest=False, sorted=sorted)
-    return out.indices, out.values
+
+    # We always sort, because this works well with padding.
+    out = dists2.topk(min(K, P2), dim=2, largest=False, sorted=True)
+
+    out_indices = out.indices
+    out_values = out.values
+
+    if P2 < K:
+        # Need to add padding
+        pad_shape = (N, P1, K - P2)
+        out_indices = torch.cat([out_indices, out_indices.new_zeros(pad_shape)], 2)
+        out_values = torch.cat([out_values, out_values.new_zeros(pad_shape)], 2)
+
+    K_mask = torch.arange(K, device=p1.device)[None] >= lengths2[:, None]
+    # Create a combined mask for where the points in p1 are padded
+    # or the corresponding p2 has fewer than K points.
+    p1_K_mask = p1_mask[:, :, None] | K_mask[:, None, :]
+    out_indices[p1_K_mask] = 0
+    out_values[p1_K_mask] = 0
+    return out_indices, out_values
diff --git a/tests/bm_knn.py b/tests/bm_knn.py
index d0041fcb8..38685811e 100644
--- a/tests/bm_knn.py
+++ b/tests/bm_knn.py
@@ -13,6 +13,7 @@ def bm_knn() -> None:
     benchmark_knn_cpu()
     benchmark_knn_cuda_vs_naive()
     benchmark_knn_cuda_versions()
+    benchmark_knn_cuda_versions_ragged()
 
 
 def benchmark_knn_cuda_versions() -> None:
@@ -36,6 +37,25 @@ def benchmark_knn_cuda_versions() -> None:
     benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
 
 
+def benchmark_knn_cuda_versions_ragged() -> None:
+    # Compare our different KNN implementations on ragged inputs.
+    Ns = [8]
+    Ps = [4096, 16384]
+    Ds = [3]
+    Ks = [1, 4, 16, 64]
+    versions = [0, 1, 2, 3]
+    knn_kwargs = []
+    for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
+        if version == 2 and K > 32:
+            continue
+        if version == 3 and K > 4:
+            continue
+        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
+    benchmark(knn_cuda_with_init, "KNN_CUDA_COMPARISON", knn_kwargs, warmup_iters=1)
+    benchmark(knn_cuda_ragged, "KNN_CUDA_RAGGED", knn_kwargs, warmup_iters=1)
+
+
 def benchmark_knn_cuda_vs_naive() -> None:
     # Compare against naive pytorch version of KNN
     Ns = [1, 2, 4]
@@ -72,10 +92,27 @@ def knn_cuda_with_init(N, D, P, K, v=-1):
     device = torch.device("cuda:0")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
+    lengths = torch.full((N,), P, dtype=torch.int64, device=device)
+    torch.cuda.synchronize()
 
     def knn():
-        _C.knn_points_idx(x, y, K, v)
+        _C.knn_points_idx(x, y, lengths, lengths, K, v)
+        torch.cuda.synchronize()
 
     return knn
 
 
+def knn_cuda_ragged(N, D, P, K, v=-1):
+    device = torch.device("cuda:0")
+    x = torch.randn(N, P, D, device=device)
+    y = torch.randn(N, P, D, device=device)
+    lengths1 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
+    lengths2 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
+    torch.cuda.synchronize()
+
+    def knn():
+        _C.knn_points_idx(x, y, lengths1, lengths2, K, v)
+        torch.cuda.synchronize()
+
+    return knn
+
+
@@ -85,9 +122,10 @@ def knn_cpu_with_init(N, D, P, K):
     device = torch.device("cpu")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
+    lengths = torch.full((N,), P, dtype=torch.int64, device=device)
 
     def knn():
-        _C.knn_points_idx(x, y, K, 0)
+        _C.knn_points_idx(x, y, lengths, lengths, K, -1)
 
     return knn
 
 
@@ -96,10 +134,12 @@ def knn_python_cuda_with_init(N, D, P, K):
     device = torch.device("cuda")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
+    lengths = torch.full((N,), P, dtype=torch.int64, device=device)
+    torch.cuda.synchronize()
 
     def knn():
-        _knn_points_idx_naive(x, y, K)
+        _knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
         torch.cuda.synchronize()
 
     return knn
 
 
@@ -109,9 +149,10 @@ def knn_python_cpu_with_init(N, D, P, K):
     device = torch.device("cpu")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
+    lengths = torch.full((N,), P, dtype=torch.int64, device=device)
 
     def knn():
-        _knn_points_idx_naive(x, y, K)
+        _knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
 
     return knn
 
 
diff --git a/tests/test_knn.py b/tests/test_knn.py
index d7fda8687..5fe3698a5 100644
--- a/tests/test_knn.py
+++ b/tests/test_knn.py
@@ -8,6 +8,10 @@
 
 
 class TestKNN(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        torch.manual_seed(1)
+
     def _check_knn_result(self, out1, out2, sorted):
         # When sorted=True, points should be sorted by distance and should
         # match between implementations. When sorted=False we only want to
@@ -26,7 +30,7 @@ def _check_knn_result(self, out1, out2, sorted):
             self.assertTrue(torch.all(idx1 == idx2))
             self.assertTrue(torch.allclose(dist1, dist2))
 
-    def test_knn_vs_python_cpu(self):
+    def test_knn_vs_python_cpu_square(self):
         """ Test CPU output vs PyTorch implementation """
         device = torch.device("cpu")
         Ns = [1, 4]
@@ -37,13 +41,19 @@
         sorts = [True, False]
         factors = [Ns, Ds, P1s, P2s, Ks, sorts]
         for N, D, P1, P2, K, sort in product(*factors):
+            lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
+            lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
             x = torch.randn(N, P1, D, device=device)
             y = torch.randn(N, P2, D, device=device)
-            out1 = _knn_points_idx_naive(x, y, K, sort)
-            out2 = knn_points_idx(x, y, K, sort)
+            out1 = _knn_points_idx_naive(
+                x, y, lengths1=lengths1, lengths2=lengths2, K=K
+            )
+            out2 = knn_points_idx(
+                x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
+            )
             self._check_knn_result(out1, out2, sort)
 
-    def test_knn_vs_python_cuda(self):
+    def test_knn_vs_python_cuda_square(self):
         """ Test CUDA output vs PyTorch implementation """
         device = torch.device("cuda")
         Ns = [1, 4]
@@ -57,9 +67,53 @@
         for N, D, P1, P2, K, sort in product(*factors):
             x = torch.randn(N, P1, D, device=device)
             y = torch.randn(N, P2, D, device=device)
-            out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
+            out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
             for version in versions:
                 if version == 3 and K > 4:
                     continue
+                out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
+                self._check_knn_result(out1, out2, sort)
+
+    def test_knn_vs_python_cpu_ragged(self):
+        device = torch.device("cpu")
+        lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
+        lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
+        N = 4
+        D = 3
+        Ks = [1, 9, 10, 11, 101]
+        sorts = [False, True]
+        factors = [Ks, sorts]
+        for K, sort in product(*factors):
+            x = torch.randn(N, lengths1.max(), D, device=device)
+            y = torch.randn(N, lengths2.max(), D, device=device)
+            out1 = _knn_points_idx_naive(
+                x, y, lengths1=lengths1, lengths2=lengths2, K=K
+            )
+            out2 = knn_points_idx(
+                x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
+            )
+            self._check_knn_result(out1, out2, sort)
+
+    def test_knn_vs_python_cuda_ragged(self):
+        device = torch.device("cuda")
+        lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
+        lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
+        N = 4
+        D = 3
+        Ks = [1, 9, 10, 11, 101]
+        sorts = [True, False]
+        versions = [0, 1, 2, 3]
+        factors = [Ks, sorts]
+        for K, sort in product(*factors):
+            x = torch.randn(N, lengths1.max(), D, device=device)
+            y = torch.randn(N, lengths2.max(), D, device=device)
+            out1 = _knn_points_idx_naive(
+                x, y, lengths1=lengths1, lengths2=lengths2, K=K
+            )
             for version in versions:
                 if version == 3 and K > 4:
                     continue
-                out2 = knn_points_idx(x, y, K, sort, version)
+                out2 = knn_points_idx(
+                    x, y, lengths1=lengths1, lengths2=lengths2, K=K,
+                    sorted=sort, version=version
+                )
                 self._check_knn_result(out1, out2, sort)
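Usage sketch for the new interface (illustrative, reusing the shapes from the
ragged tests above):

    import torch
    from pytorch3d.ops.knn import knn_points_idx

    device = torch.device("cuda")
    # Four clouds padded to their maximum lengths; the lengths tensors mark how
    # many leading rows of each cloud are real points.
    lengths1 = torch.tensor([10, 100, 10, 100], device=device)
    lengths2 = torch.tensor([10, 10, 100, 100], device=device)
    p1 = torch.randn(4, 100, 3, device=device)
    p2 = torch.randn(4, 100, 3, device=device)
    idx, dists = knn_points_idx(
        p1, p2, K=5, lengths1=lengths1, lengths2=lengths2, sorted=True
    )
    # idx and dists have shape (4, 100, 5); entries past a cloud's true length
    # in p1, or past lengths2[n] available neighbors, are zero-padded.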