Skip to content

Commit

Permalink
heterogenous KNN
Browse files Browse the repository at this point in the history
Summary: Interface and working implementation of ragged KNN. Benchmarks (which aren't ragged) haven't slowed. New benchmark shows that ragged is faster than non-ragged of the same shape.

Reviewed By: jcjohnson

Differential Revision: D20696507

fbshipit-source-id: 21b80f71343a3475c8d3ee0ce2680f92f0fae4de
  • Loading branch information
bottler authored and facebook-github-bot committed Apr 7, 2020
1 parent 29b9c44 commit 01b5f7b
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 84 deletions.
115 changes: 80 additions & 35 deletions pytorch3d/csrc/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,40 @@
#include "dispatch.cuh"
#include "mink.cuh"

// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
// call (1+(P1-1)/blocksize) chunks_per_cloud
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
// blocksize*(i%chunks_per_cloud).

template <typename scalar_t>
__global__ void KNearestNeighborKernelV0(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t D,
const size_t K) {
// Stupid version: Make each thread handle one query point and loop over
// all P2 target points. There are N * P1 input points to handle, so
// do a trivial parallelization over threads.
// Store both dists and indices for knn in global memory.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
int offset = n * P1 * K + p1 * K;
int64_t length2 = lengths2[n];
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < P2; ++p2) {
for (int p2 = 0; p2 < length2; ++p2) {
// Find the distance between points1[n, p1] and points[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
Expand All @@ -48,6 +59,8 @@ template <typename scalar_t, int64_t D>
__global__ void KNearestNeighborKernelV1(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
Expand All @@ -58,18 +71,22 @@ __global__ void KNearestNeighborKernelV1(
// so we can cache the current point in a thread-local array. We still store
// the current best K dists and indices in global memory, so this should work
// for very large K and fairly large D.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
scalar_t cur_point[D];
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int offset = n * P1 * K + p1 * K;
int64_t length2 = lengths2[n];
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < P2; ++p2) {
for (int p2 = 0; p2 < length2; ++p2) {
// Find the distance between cur_point and points[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
Expand All @@ -89,40 +106,48 @@ struct KNearestNeighborV1Functor {
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t K) {
KNearestNeighborKernelV1<scalar_t, D>
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
}
};

template <typename scalar_t, int64_t D, int64_t K>
__global__ void KNearestNeighborKernelV2(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
// Same general implementation as V2, but also hoist K into a template arg.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int64_t length2 = lengths2[n];
MinK<scalar_t, int> mink(min_dists, min_idxs, K);
for (int p2 = 0; p2 < P2; ++p2) {
for (int p2 = 0; p2 < length2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
Expand All @@ -146,20 +171,24 @@ struct KNearestNeighborKernelV2Functor {
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
KNearestNeighborKernelV2<scalar_t, D, K>
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
}
};

template <typename scalar_t, int D, int K>
__global__ void KNearestNeighborKernelV3(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
Expand All @@ -169,19 +198,23 @@ __global__ void KNearestNeighborKernelV3(
// Enabling sorting for this version leads to huge slowdowns; I suspect
// that it forces min_dists into local memory rather than registers.
// As a result this version is always unsorted.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int64_t length2 = lengths2[n];
RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
for (int p2 = 0; p2 < P2; ++p2) {
for (int p2 = 0; p2 < length2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
Expand All @@ -205,13 +238,15 @@ struct KNearestNeighborKernelV3Functor {
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2) {
KNearestNeighborKernelV3<scalar_t, D, K>
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
}
};

Expand Down Expand Up @@ -257,6 +292,8 @@ int ChooseVersion(const int64_t D, const int64_t K) {
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version) {
const auto N = p1.size(0);
Expand All @@ -267,8 +304,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(

AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = p1.options().dtype(at::kLong);
auto idxs = at::full({N, P1, K}, -1, long_dtype);
auto dists = at::full({N, P1, K}, -1, p1.options());
auto idxs = at::zeros({N, P1, K}, long_dtype);
auto dists = at::zeros({N, P1, K}, p1.options());

if (version < 0) {
version = ChooseVersion(D, K);
Expand All @@ -294,6 +331,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
<<<blocks, threads>>>(
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
lengths2.data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
Expand All @@ -314,6 +353,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
threads,
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
lengths2.data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
Expand All @@ -336,6 +377,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
threads,
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
lengths2.data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
Expand All @@ -357,6 +400,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
threads,
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
lengths2.data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
Expand Down
29 changes: 22 additions & 7 deletions pytorch3d/csrc/knn/knn.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,57 @@
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// K: int giving the number of nearest points to return.
// sorted: bool telling whether to sort the K returned points by their
// distance.
// version: Integer telling which implementation to use.
// TODO(jcjohns): Document this more, or maybe remove it before landing.
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// It is padded with zeros so that it can be used easily in a later
// gather() operation.
//
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
// distance from each point p1[n, p, :] to its K neighbors
// p2[n, p1_neighbor_idx[n, p, k], :].

// CPU implementation.
std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K);
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K);

// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version);

// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(p1);
CHECK_CONTIGUOUS_CUDA(p2);
return KNearestNeighborIdxCuda(p1, p2, K, version);
return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborIdxCpu(p1, p2, K);
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
}
18 changes: 13 additions & 5 deletions pytorch3d/csrc/knn/knn_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,35 @@
#include <queue>
#include <tuple>

std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
const int P2 = p2.size(1);

auto long_opts = p1.options().dtype(torch::kInt64);
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());

auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto lengths1_a = lengths1.accessor<int64_t, 1>();
auto lengths2_a = lengths2.accessor<int64_t, 1>();
auto idxs_a = idxs.accessor<int64_t, 3>();
auto dists_a = dists.accessor<float, 3>();

for (int n = 0; n < N; ++n) {
for (int i1 = 0; i1 < P1; ++i1) {
const int64_t length1 = lengths1_a[n];
const int64_t length2 = lengths2_a[n];
for (int64_t i1 = 0; i1 < length1; ++i1) {
// Use a priority queue to store (distance, index) tuples.
std::priority_queue<std::tuple<float, int>> q;
for (int i2 = 0; i2 < P2; ++i2) {
for (int64_t i2 = 0; i2 < length2; ++i2) {
float dist = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
Expand Down

0 comments on commit 01b5f7b

Please sign in to comment.