Farthest point sampling CUDA

Summary:
CUDA implementation of farthest point sampling algorithm.
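
Farthest point sampling (FPS) greedily picks points one at a time: each step selects the point whose minimum distance to the already-selected set is largest. Below is a minimal single-cloud sketch of the algorithm, mirroring the CPU implementation updated in this diff; the function name and container types are illustrative, not the library API:

```cuda
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Select K indices from P points, starting from start_idx.
std::vector<int64_t> SampleFarthestPoints(
    const std::vector<std::vector<float>>& points, // (P, D)
    int64_t K,
    int64_t start_idx) {
  const int64_t P = static_cast<int64_t>(points.size());
  // dists[p] = min squared distance from point p to the selected set.
  std::vector<float> dists(P, std::numeric_limits<float>::max());
  std::vector<int64_t> selected = {start_idx};
  int64_t last = start_idx;
  for (int64_t k = 1; k < std::min(K, P); ++k) {
    int64_t best_idx = 0;
    float best_dist = -1.0f;
    for (int64_t p = 0; p < P; ++p) {
      // Squared distance from p to the most recently selected point.
      float dist2 = 0.0f;
      for (size_t d = 0; d < points[p].size(); ++d) {
        const float diff = points[last][d] - points[p][d];
        dist2 += diff * diff;
      }
      // Track the min distance to the whole selected set; already-selected
      // points have distance 0, so they are never picked again.
      dists[p] = std::min(dists[p], dist2);
      if (dists[p] > best_dist) {
        best_dist = dists[p];
        best_idx = p;
      }
    }
    last = best_idx;
    selected.push_back(last);
  }
  return selected;
}
```

The CUDA kernel in this diff parallelizes the inner loop over points across the threads of a block (one block per batch element), then reduces the per-thread maxima.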

## Visual comparison

Compared to random sampling, farthest point sampling gives better coverage of the shape.

{F658631262}

## Reduction

The parallelized block reduction that finds the max value at each iteration proceeds as follows (the loop form of this pattern is sketched after the list):

1. First split the values into two equal-sized halves (e.g. for a list with 8 values):
`[20, 27, 6, 8 | 11, 10, 2, 33]`
2. Use half of the threads (4 threads) to compare pairs of elements, one from each half (e.g. elements [0, 4], [1, 5], etc.), and store the max of each pair in the first half of the list:
`[20, 27, 6, 33 | 11, 10, 2, 33]`
3. The second half no longer matters; again divide the first half into two and use 2 threads to compare the remaining 4 elements:
`[20, 27 | 6, 33 | -, -, -, -]`
4. We are now down to a single pair; use 1 thread to compare the remaining two elements:
`[20 | 33 | -, - | -, -, -, -]`
5. The max is now at thread id = 0:
`[33 | - | -, - | -, -, -, -]`

The reduction gives the farthest point for the selected batch index at this iteration.
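
In code, this is the standard shared-memory tree reduction, and it is the loop the kernel in this diff uses. A minimal sketch, assuming `dists`/`dists_idx` are shared-memory arrays of `block_size` entries and `block_size` is a power of two:

```cuda
// Each pass halves the number of active threads: thread tid compares its
// slot against slot tid + s and keeps the larger value in the first half.
for (int s = block_size / 2; s > 0; s >>= 1) {
  if (tid < s) {
    if (dists[tid] < dists[tid + s]) {
      dists[tid] = dists[tid + s];
      dists_idx[tid] = dists_idx[tid + s];
    }
  }
  // Every level must finish before the next halving.
  __syncthreads();
}
// The block-wide max distance and its point index end up at tid = 0.
```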

Reviewed By: bottler, jcjohnson

Differential Revision: D30401803

fbshipit-source-id: 525bd5ae27c4b13b501812cfe62306bb003827d2
nikhilaravi authored and facebook-github-bot committed Sep 15, 2021
1 parent d9f7611 commit bd04ffa
Showing 11 changed files with 441 additions and 33 deletions.
2 changes: 1 addition & 1 deletion pytorch3d/csrc/ext.cpp
@@ -26,8 +26,8 @@
#include "point_mesh/point_mesh_cuda.h"
#include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h"
#include "sample_pdf/sample_pdf.h"
#include "sample_farthest_points/sample_farthest_points.h"
#include "sample_pdf/sample_pdf.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
2 changes: 1 addition & 1 deletion pytorch3d/csrc/point_mesh/point_mesh_cuda.cu
@@ -121,7 +121,7 @@ __global__ void DistanceForwardKernel(
// Unroll the last 6 iterations of the loop since they will happen
// synchronized within a single warp.
if (tid < 32)
WarpReduce<float>(min_dists, min_idxs, tid);
WarpReduceMin<float>(min_dists, min_idxs, tid);

// Finally thread 0 writes the result to the output buffer.
if (tid == 0) {
12 changes: 4 additions & 8 deletions pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp
@@ -15,7 +15,7 @@ at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point) {
const at::Tensor& start_idxs) {
// Get constants
const int64_t N = points.size(0);
const int64_t P = points.size(1);
@@ -32,6 +32,7 @@ at::Tensor FarthestPointSamplingCpu(
auto lengths_a = lengths.accessor<int64_t, 1>();
auto k_a = K.accessor<int64_t, 1>();
auto sampled_indices_a = sampled_indices.accessor<int64_t, 2>();
auto start_idxs_a = start_idxs.accessor<int64_t, 1>();

// Initialize a mask to prevent duplicates
// If true, the point has already been selected.
@@ -41,20 +41,15 @@
// distances from each point to any of the previously selected points
std::vector<float> dists(P, std::numeric_limits<float>::max());

// Initialize random number generation for random starting points
std::random_device rd;
std::default_random_engine eng(rd());

for (int64_t n = 0; n < N; ++n) {
// Resize and reset points mask and distances for each batch
selected_points_mask.resize(lengths_a[n]);
dists.resize(lengths_a[n]);
std::fill(selected_points_mask.begin(), selected_points_mask.end(), false);
std::fill(dists.begin(), dists.end(), std::numeric_limits<float>::max());

// Select a starting point index and save it
std::uniform_int_distribution<int> distr(0, lengths_a[n] - 1);
int64_t last_idx = random_start_point ? distr(eng) : 0;
// Get the starting point index and save it
int64_t last_idx = start_idxs_a[n];
sampled_indices_a[n][0] = last_idx;

// Set the value of the mask at this point to false
252 changes: 252 additions & 0 deletions pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu
@@ -0,0 +1,252 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "utils/pytorch3d_cutils.h"
#include "utils/warp_reduce.cuh"

template <unsigned int block_size>
__global__ void FarthestPointSamplingKernel(
// clang-format off
const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> points,
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> lengths,
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> K,
at::PackedTensorAccessor64<int64_t, 2, at::RestrictPtrTraits> idxs,
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> min_point_dist,
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> start_idxs
// clang-format on
) {
// Get constants
const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t D = points.size(2);

// Create single shared memory buffer which is split and cast to different
// types: dists/dists_idx are used to save the maximum distances seen by the
// points processed by any one thread and the associated point indices.
// These values only need to be accessed by other threads in this block which
// are processing the same batch and not by other blocks.
extern __shared__ char shared_buf[];
float* dists = (float*)shared_buf; // block_size floats
int64_t* dists_idx = (int64_t*)&dists[block_size]; // block_size int64_t

// Get batch index and thread index
const int64_t batch_idx = blockIdx.x;
const size_t tid = threadIdx.x;

// If K is greater than the number of points in the pointcloud
// we only need to iterate until the smaller value is reached.
const int64_t k_n = min(K[batch_idx], lengths[batch_idx]);

// Write the first selected point to global memory in the first thread
int64_t selected = start_idxs[batch_idx];
if (tid == 0)
idxs[batch_idx][0] = selected;

// Iterate to find k_n sampled points
for (int64_t k = 1; k < k_n; ++k) {
// Keep track of the maximum of the minimum distance to previously selected
// points seen by this thread
int64_t max_dist_idx = 0;
float max_dist = -1.0;

// Iterate through all the points in this pointcloud. For already selected
// points, the minimum distance to the set of previously selected points
// will be 0.0 so they won't be selected again.
for (int64_t p = tid; p < lengths[batch_idx]; p += block_size) {
// Calculate the distance to the last selected point
float dist2 = 0.0;
for (int64_t d = 0; d < D; ++d) {
float diff = points[batch_idx][selected][d] - points[batch_idx][p][d];
dist2 += (diff * diff);
}

// If the distance of point p to the last selected point is
// less than the previous minimum distance of p to the set of selected
// points, then update the corresponding value in min_point_dist
// so it always contains the min distance.
const float p_min_dist = min(dist2, min_point_dist[batch_idx][p]);
min_point_dist[batch_idx][p] = p_min_dist;

// Update the max distance and point idx for this thread.
max_dist_idx = (p_min_dist > max_dist) ? p : max_dist_idx;
max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist;
}

// After going through all points for this thread, save the max
// point and idx seen by this thread. Each thread sees P/block_size points.
dists[tid] = max_dist;
dists_idx[tid] = max_dist_idx;
// Sync to ensure all threads in the block have updated their max point.
__syncthreads();

// Parallelized block reduction to find the max point seen by
// all the threads in this block for iteration k.
// Each block represents one batch element so we can use a divide/conquer
// approach to find the max, syncing all threads after each step.

for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
// Compare the best point seen by two threads and update the shared
// memory at the location of the first thread index with the max out
// of the two threads.
if (dists[tid] < dists[tid + s]) {
dists[tid] = dists[tid + s];
dists_idx[tid] = dists_idx[tid + s];
}
}
__syncthreads();
}

// TODO(nikhilar): As reduction proceeds, the number of "active" threads
// decreases. When tid < 32, there should only be one warp left which could
// be unrolled.

// The overall max after reducing will be saved
// at the location of tid = 0.
selected = dists_idx[0];

if (tid == 0) {
// Write the farthest point for iteration k to global memory
idxs[batch_idx][k] = selected;
}
}
}

at::Tensor FarthestPointSamplingCuda(
const at::Tensor& points, // (N, P, 3)
const at::Tensor& lengths, // (N,)
const at::Tensor& K, // (N,)
const at::Tensor& start_idxs) {
// Check inputs are on the same device
at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
at::CheckedFrom c = "FarthestPointSamplingCuda";
at::checkAllSameGPU(c, {p_t, lengths_t, k_t, start_idxs_t});
at::checkAllSameType(c, {lengths_t, k_t, start_idxs_t});

// Set the device for the kernel launch based on the device of points
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

TORCH_CHECK(
points.size(0) == lengths.size(0),
"Point and lengths must have the same batch dimension");

TORCH_CHECK(
points.size(0) == K.size(0),
"Points and K must have the same batch dimension");

const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t max_K = at::max(K).item<int64_t>();

// Initialize the output tensor with the sampled indices
auto idxs = at::full({N, max_K}, -1, lengths.options());
auto min_point_dist = at::full({N, P}, 1e10, points.options());

if (N == 0 || P == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return idxs;
}

// Set the number of blocks to the batch size so that the
// block reduction step can be done for each pointcloud
// to find the max distance point in the pointcloud at each iteration.
const size_t blocks = N;

// Set the threads to the nearest power of 2 of the number of
// points in the pointcloud (up to the max threads in a block).
// This will ensure each thread processes the minimum necessary number of
// points (P/threads).
const int points_pow_2 = std::log(static_cast<double>(P)) / std::log(2.0);
const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 1);

// Create the accessors
auto points_a = points.packed_accessor64<float, 3, at::RestrictPtrTraits>();
auto lengths_a =
lengths.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
auto K_a = K.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
auto idxs_a = idxs.packed_accessor64<int64_t, 2, at::RestrictPtrTraits>();
auto start_idxs_a =
start_idxs.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
auto min_point_dist_a =
min_point_dist.packed_accessor64<float, 2, at::RestrictPtrTraits>();

// Initialize the shared memory which will be used to store the
// distance/index of the best point seen by each thread.
size_t shared_mem = threads * sizeof(float) + threads * sizeof(int64_t);
// TODO: using shared memory for min_point_dist gives an ~2x speed up
// compared to using a global (N, P) shaped tensor, however for
// larger pointclouds this may exceed the shared memory limit per block.
// If a speed up is required for smaller pointclouds, then the storage
// could be switched to shared memory if the required total shared memory is
// within the memory limit per block.

// Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible per
// block.
switch (threads) {
case 1024:
FarthestPointSamplingKernel<1024>
<<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 512:
FarthestPointSamplingKernel<512><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 256:
FarthestPointSamplingKernel<256><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 128:
FarthestPointSamplingKernel<128><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 64:
FarthestPointSamplingKernel<64><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 32:
FarthestPointSamplingKernel<32><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 16:
FarthestPointSamplingKernel<16><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 8:
FarthestPointSamplingKernel<8><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 4:
FarthestPointSamplingKernel<4><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 2:
FarthestPointSamplingKernel<2><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 1:
FarthestPointSamplingKernel<1><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
default:
FarthestPointSamplingKernel<1024>
<<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
}

AT_CUDA_CHECK(cudaGetLastError());
return idxs;
}
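
For intuition, here is the launch configuration the host function above computes for a hypothetical batch of N = 32 clouds with P = 10,000 points each (sizes are made up for illustration):

```cuda
#include <algorithm>
#include <cmath>

const int64_t N = 32;    // batch size -> one block per cloud
const int64_t P = 10000; // points per cloud
const int points_pow_2 =
    std::log(static_cast<double>(P)) / std::log(2.0); // 13
const size_t threads =
    std::max(std::min(1 << points_pow_2, 1024), 1);   // min(8192, 1024) = 1024
const size_t blocks = N;                              // 32
// Each thread scans roughly P / threads ~= 10 points per FPS iteration,
// and the tree reduction runs log2(1024) = 10 comparison levels.
```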
28 changes: 22 additions & 6 deletions pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
@@ -29,28 +29,44 @@
// K: a tensor of length (N,) giving the number of
// samples to select for each element in the batch.
// The number of samples is typically << P.
// random_start_point: bool, if True, a random point is selected as the
// starting point for iterative sampling.
// start_idxs: (N,) long Tensor giving the index of the first point to
// sample. Default is all 0. When a random start point is required,
// start_idxs should be set to a random value in [0, lengths[n])
// for batch element n.
// Returns:
// selected_indices: (N, K) array of selected indices. If the values in
// K are not all the same, then the shape will be (N, max(K)), and
// padded with -1 for batch elements where k_i < max(K). The selected
// points are gathered in the pytorch autograd wrapper.

at::Tensor FarthestPointSamplingCuda(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const at::Tensor& start_idxs);

at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point);
const at::Tensor& start_idxs);

// Exposed implementation.
at::Tensor FarthestPointSampling(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point) {
const at::Tensor& start_idxs) {
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
AT_ERROR("CUDA implementation not yet supported");
#ifdef WITH_CUDA
CHECK_CUDA(points);
CHECK_CUDA(lengths);
CHECK_CUDA(K);
CHECK_CUDA(start_idxs);
return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return FarthestPointSamplingCpu(points, lengths, K, random_start_point);
return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
}
3 changes: 3 additions & 0 deletions pytorch3d/csrc/utils/pytorch3d_cutils.h
@@ -15,3 +15,6 @@
#define CHECK_CONTIGUOUS_CUDA(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)

// Max possible threads per block
const int MAX_THREADS_PER_BLOCK = 1024;