Skip to content

Commit

Permalink
Ball Query
Browse files Browse the repository at this point in the history
Summary:
Implementation of ball query from PointNet++.  This function is similar to KNN (find the neighbors in p2 for all points in p1). These are the key differences:
-  It will return the **first** K neighbors within a specified radius as opposed to the **closest** K neighbors.
- As all the points in p2 do not need to be considered to find the closest K, the algorithm is much faster than KNN when p2 has a large number of points.
- The neighbors are not sorted
- Due to the radius threshold it is not guaranteed that there will be K neighbors even if there are more than K points in p2.
- The padding value for `idx` is -1 instead of 0.

# Note:
- Some of the code is very similar to KNN so it could be possible to modify the KNN forward kernels to support ball query.
- Some users might want to use kNN with ball query - for this we could provide a wrapper function around the current `knn_points` which enables applying the radius threshold afterwards as an alternative. This could be called `ball_query_knn`.

Reviewed By: jcjohnson

Differential Revision: D30261362

fbshipit-source-id: 66b6a7e0114beff7164daf7eba21546ff41ec450
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Aug 12, 2021
1 parent e5c58a8 commit 103da63
Show file tree
Hide file tree
Showing 10 changed files with 709 additions and 1 deletion.
130 changes: 130 additions & 0 deletions pytorch3d/csrc/ball_query/ball_query.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "utils/pytorch3d_cutils.h"

// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
// call (1+(P1-1)/blocksize) chunks_per_cloud
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
// blocksize*(i%chunks_per_cloud).

template <typename scalar_t>
__global__ void BallQueryKernel(
const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p1,
const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p2,
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
lengths1,
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
lengths2,
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
const int64_t K,
const float radius2) {
const int64_t N = p1.size(0);
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
const int D = p1.size(2);

for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud; // batch_index
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t i = start_point + threadIdx.x;

// Check if point is valid in heterogeneous tensor
if (i >= lengths1[n]) {
continue;
}

// Iterate over points in p2 until desired count is reached or
// all points have been considered
for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
// Calculate the distance between the points
scalar_t dist2 = 0.0;
for (int d = 0; d < D; ++d) {
scalar_t diff = p1[n][i][d] - p2[n][j][d];
dist2 += (diff * diff);
}

if (dist2 < radius2) {
// If the point is within the radius
// Set the value of the index to the point index
idxs[n][i][count] = j;
dists[n][i][count] = dist2;

// increment the number of selected samples for the point i
++count;
}
}
}
}

std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
const at::Tensor& p1, // (N, P1, 3)
const at::Tensor& p2, // (N, P2, 3)
const at::Tensor& lengths1, // (N,)
const at::Tensor& lengths2, // (N,)
int K,
float radius) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
at::CheckedFrom c = "BallQueryCuda";
at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
at::checkAllSameType(c, {p1_t, p2_t});

// Set the device for the kernel launch based on the device of p1
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

TORCH_CHECK(
p2.size(2) == p1.size(2), "Point sets must have the same last dimension");

const int N = p1.size(0);
const int P1 = p1.size(1);
const int64_t K_64 = K;
const float radius2 = radius * radius;

// Output tensor with indices of neighbors for each point in p1
auto long_dtype = lengths1.options().dtype(at::kLong);
auto idxs = at::full({N, P1, K}, -1, long_dtype);
auto dists = at::zeros({N, P1, K}, p1.options());

if (idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(idxs, dists);
}

const size_t blocks = 256;
const size_t threads = 256;

AT_DISPATCH_FLOATING_TYPES(
p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
BallQueryKernel<<<blocks, threads, 0, stream>>>(
p1.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
p2.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
K_64,
radius2);
}));

AT_CUDA_CHECK(cudaGetLastError());

return std::make_tuple(idxs, dists);
}
91 changes: 91 additions & 0 deletions pytorch3d/csrc/ball_query/ball_query.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// Compute indices of K neighbors in pointcloud p2 to points
// in pointcloud p1 which fall within a specified radius
//
// Args:
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// K: Integer giving the upper bound on the number of samples to take
// within the radius
// radius: the radius around each point within which the neighbors need to be
// located
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
// p1_neighbor_idx[n, i, k] = j means that the kth
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// This is padded with -1s both where a cloud in p2 has fewer than
// S points and where a cloud in p1 has fewer than P1 points and
// also if there are fewer than K points which satisfy the radius
// threshold.
//
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
// distance from each point p1[n, p, :] to its K neighbors
// p2[n, p1_neighbor_idx[n, p, k], :].

// CPU implementation
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int K,
const float radius);

// CUDA implementation
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int K,
const float radius);

// Implementation which is exposed
// Note: the backward pass reuses the KNearestNeighborBackward kernel
inline std::tuple<at::Tensor, at::Tensor> BallQuery(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
float radius) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(p1);
CHECK_CUDA(p2);
return BallQueryCuda(
p1.contiguous(),
p2.contiguous(),
lengths1.contiguous(),
lengths2.contiguous(),
K,
radius);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return BallQueryCpu(
p1.contiguous(),
p2.contiguous(),
lengths1.contiguous(),
lengths2.contiguous(),
K,
radius);
}
55 changes: 55 additions & 0 deletions pytorch3d/csrc/ball_query/ball_query_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <torch/extension.h>
#include <queue>
#include <tuple>

std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
float radius) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);

auto long_opts = lengths1.options().dtype(torch::kInt64);
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
const float radius2 = radius * radius;

auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto lengths1_a = lengths1.accessor<int64_t, 1>();
auto lengths2_a = lengths2.accessor<int64_t, 1>();
auto idxs_a = idxs.accessor<int64_t, 3>();
auto dists_a = dists.accessor<float, 3>();

for (int n = 0; n < N; ++n) {
const int64_t length1 = lengths1_a[n];
const int64_t length2 = lengths2_a[n];
for (int64_t i = 0; i < length1; ++i) {
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
float dist2 = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i][d] - p2_a[n][j][d];
dist2 += diff * diff;
}
if (dist2 < radius2) {
dists_a[n][i][count] = dist2;
idxs_a[n][i][count] = j;
++count;
}
}
}
}
return std::make_tuple(idxs, dists);
}
4 changes: 4 additions & 0 deletions pytorch3d/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// clang-format on
#include "./pulsar/pytorch/renderer.h"
#include "./pulsar/pytorch/tensor_util.h"
#include "ball_query/ball_query.h"
#include "blending/sigmoid_alpha_blend.h"
#include "compositing/alpha_composite.h"
#include "compositing/norm_weighted_sum.h"
Expand All @@ -38,6 +39,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);

// Ball Query
m.def("ball_query", &BallQuery);
m.def(
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
m.def("gather_scatter", &GatherScatter);
Expand Down
4 changes: 4 additions & 0 deletions pytorch3d/csrc/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,10 @@ __global__ void KNearestNeighborBackwardKernel(
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
// index of point in p2 corresponding to the k-th nearest neighbor
const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
// If the index is the pad value of -1 then ignore it
if (p2_idx == -1) {
continue;
}
const float diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
Expand Down
4 changes: 4 additions & 0 deletions pytorch3d/csrc/knn/knn_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
for (int64_t i1 = 0; i1 < length1; ++i1) {
for (int64_t k = 0; k < length2; ++k) {
const int64_t i2 = idxs_a[n][i1][k];
// If the index is the pad value of -1 then ignore it
if (i2 == -1) {
continue;
}
for (int64_t d = 0; d < D; ++d) {
const float diff =
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .ball_query import ball_query
from .cameras_alignment import corresponding_cameras_alignment
from .cubify import cubify
from .graph_conv import GraphConv
Expand Down Expand Up @@ -34,5 +35,4 @@
)
from .vert_align import vert_align


__all__ = [k for k in globals().keys() if not k.startswith("_")]
Loading

0 comments on commit 103da63

Please sign in to comment.