Unify coarse rasterization for points and meshes
Summary:
There has historically been a lot of duplication between the coarse rasterization logic for point clouds and meshes. This diff factors out that shared logic, so coarse rasterization of point clouds and meshes now goes through the same core code path.

Previously the only difference between the coarse rasterization kernels for points and meshes was the logic for checking whether a {point / triangle} intersects a tile in the image. We implement a generic coarse rasterization kernel that takes a set of 2D bounding boxes rather than geometric primitives; we then implement separate kernels that compute 2D bounding boxes for points and triangles.
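
As an illustration of the bounding-box kernels, here is a minimal sketch of what the triangle version could look like. The real TriangleBoundingBoxKernel already lives near the top of rasterize_coarse.cu and may differ in its details; the (F, 3, 3) packed vertex layout, the sqrt(blur_radius) expansion, and the 1e-8 depth threshold standing in for kEpsilon are assumptions made for this sketch only.

__global__ void TriangleBoundingBoxKernelSketch(
    const float* face_verts, // (F, 3, 3) packed xyz per vertex, in NDC
    const int F,
    const float blur_radius, // assumed to be in squared NDC units
    float* bboxes, // (4, F): xmin, xmax, ymin, ymax
    bool* skip_face) {
  const float expand = sqrtf(blur_radius);
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int f = tid; f < F; f += num_threads) {
    // Initialize the box from vertex 0, then grow it with vertices 1 and 2.
    float xmin = face_verts[f * 9 + 0], xmax = xmin;
    float ymin = face_verts[f * 9 + 1], ymax = ymin;
    bool behind = face_verts[f * 9 + 2] < 1e-8f; // 1e-8f stands in for kEpsilon
    for (int v = 1; v < 3; ++v) {
      const float x = face_verts[f * 9 + v * 3 + 0];
      const float y = face_verts[f * 9 + v * 3 + 1];
      const float z = face_verts[f * 9 + v * 3 + 2];
      xmin = fminf(xmin, x);
      xmax = fmaxf(xmax, x);
      ymin = fminf(ymin, y);
      ymax = fmaxf(ymax, y);
      behind = behind && (z < 1e-8f);
    }
    bboxes[0 * F + f] = xmin - expand;
    bboxes[1 * F + f] = xmax + expand;
    bboxes[2 * F + f] = ymin - expand;
    bboxes[3 * F + f] = ymax + expand;
    // Skip faces whose vertices are all behind the camera.
    skip_face[f] = behind;
  }
}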

This change does not affect the Python API at all. It also should not change any rasterization behavior, since this diff is just a refactoring of the existing logic.

I see this diff as the first of several pieces of rasterizer refactoring. Follow-up diffs should do the following:
- Add a check for bin overflow in the generic coarse rasterization kernel: allocate a global scalar flag which worker threads can set when they detect that a bin has overflowed. The C++ launcher function can then check this flag after the kernel returns and issue a warning to the user (see the first sketch after this list).
- As a slightly more involved mechanism, if bin overflow is detected the coarse kernel can keep running in order to count how many elements fall into each bin, without writing their indices to the coarse output tensor. The true per-bin counts can then be used to re-allocate the output tensor and re-run the coarse rasterization kernel, so that bin overflow is avoided automatically.
- Unifying the point and mesh coarse rasterization kernels also lets us insert an extra CUDA kernel before coarse rasterization that filters out primitives outside the view frustum (see the second sketch after this list). This would be helpful for rendering full scenes (e.g. Matterport data) where only a small piece of the mesh is actually visible at any one time.
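
For the bin-overflow flag, a minimal sketch of the mechanism is below. The kernel name BinElemsWithOverflowFlag, its arguments, and the precomputed per-element bin index input are hypothetical; they are only meant to show the flag-and-warn pattern, not the actual kernel in this file.

__global__ void BinElemsWithOverflowFlag(
    const int* bin_of_elem, // (E,) precomputed bin index per element
    const int E,
    const int max_elems_per_bin,
    int* elems_per_bin, // (num_bins,) zero-initialized counters
    int* bin_elems, // (num_bins, max_elems_per_bin) output indices
    int* overflow) { // single global scalar, zero-initialized by the launcher
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int e = tid; e < E; e += num_threads) {
    const int b = bin_of_elem[e];
    const int slot = atomicAdd(&elems_per_bin[b], 1);
    if (slot < max_elems_per_bin) {
      bin_elems[b * max_elems_per_bin + slot] = e;
    } else {
      atomicExch(overflow, 1); // record the overflow instead of writing out of bounds
    }
  }
}

After the kernel returns, the C++ launcher can copy the scalar back (e.g. with cudaMemcpy, or by allocating it as a one-element tensor) and issue a warning to the user if it is nonzero.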
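
For the frustum filter, a minimal sketch for the point path is below, assuming the points are already in NDC; the kernel name FrustumCullPointsSketch and the xy_bound parameter are hypothetical.

__global__ void FrustumCullPointsSketch(
    const float* points, // (P, 3) in NDC
    const float* radius, // (P,)
    const int P,
    const float xy_bound, // NDC half-range of the longer image side
    bool* skip_points) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int p = tid; p < P; p += num_threads) {
    const float x = points[p * 3 + 0];
    const float y = points[p * 3 + 1];
    const float z = points[p * 3 + 2];
    const float r = radius[p];
    const bool outside_xy =
        (x + r < -xy_bound) || (x - r > xy_bound) ||
        (y + r < -xy_bound) || (y - r > xy_bound);
    // Behind the camera or entirely outside the image in x/y.
    skip_points[p] = (z < 0) || outside_xy;
  }
}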

Reviewed By: bottler

Differential Revision: D25710361

fbshipit-source-id: 9c9dea512cb339c42adb3c92e7733fedd586ce1b
jcjohnson authored and facebook-github-bot committed Sep 8, 2021
1 parent eed68f4 commit bbc7573
Showing 1 changed file with 45 additions and 190 deletions: pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
@@ -50,6 +50,29 @@ __global__ void TriangleBoundingBoxKernel(
}
}

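// Compute a screen-space bounding box for each point by expanding its xy
// position by its radius; points behind the camera (z < 0) are flagged in
// skip_points so the generic coarse rasterizer can ignore them.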
__global__ void PointBoundingBoxKernel(
const float* points, // (P, 3)
const float* radius, // (P,)
const int P,
float* bboxes, // (4, P)
bool* skip_points) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = blockDim.x * gridDim.x;
for (int p = tid; p < P; p += num_threads) {
const float x = points[p * 3 + 0];
const float y = points[p * 3 + 1];
const float z = points[p * 3 + 2];
const float r = radius[p];
// TODO: change to kEpsilon to match triangles?
const bool skip = z < 0;
bboxes[0 * P + p] = x - r;
bboxes[1 * P + p] = x + r;
bboxes[2 * P + p] = y - r;
bboxes[3 * P + p] = y + r;
skip_points[p] = skip;
}
}

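// Generic coarse rasterization kernel shared by the point and mesh paths:
// it bins 2D bounding boxes into image tiles rather than operating on the
// geometric primitives directly.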
__global__ void RasterizeCoarseCudaKernel(
const float* bboxes, // (4, E) (xmin, xmax, ymin, ymax)
const bool* should_skip, // (E,)
@@ -242,150 +265,6 @@ at::Tensor RasterizeCoarseCuda(
return bin_elems;
}

__global__ void RasterizePointsCoarseCudaKernel(
const float* points, // (P, 3)
const int64_t* cloud_to_packed_first_idx, // (N)
const int64_t* num_points_per_cloud, // (N)
const float* radius,
const int N,
const int P,
const int H,
const int W,
const int bin_size,
const int chunk_size,
const int max_points_per_bin,
int* points_per_bin,
int* bin_points) {
extern __shared__ char sbuf[];
const int M = max_points_per_bin;

// Integer divide round up
const int num_bins_x = 1 + (W - 1) / bin_size;
const int num_bins_y = 1 + (H - 1) / bin_size;

// NDC range depends on the ratio of W/H
// The shorter side from (H, W) is given an NDC range of 2.0 and
// the other side is scaled by the ratio of H:W.
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;

// Size of half a pixel in NDC units is the NDC half range
// divided by the corresponding image dimension
const float half_pix_x = NDC_x_half_range / W;
const float half_pix_y = NDC_y_half_range / H;

// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
// stored in shared memory that will track whether each point in the chunk
// falls into each bin of the image.
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);

// Have each block handle a chunk of points and build a 3D bitmask in
// shared memory to mark which points hit which bins. In this first phase,
// each thread processes one point at a time. After processing the chunk,
// one thread is assigned per bin, and the thread counts and writes the
// points for the bin out to global memory.
const int chunks_per_batch = 1 + (P - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch;
const int chunk_idx = chunk % chunks_per_batch;
const int point_start_idx = chunk_idx * chunk_size;

binmask.block_clear();

// Using the batch index of the thread get the start and stop
// indices for the points.
const int64_t cloud_point_start_idx = cloud_to_packed_first_idx[batch_idx];
const int64_t cloud_point_stop_idx =
cloud_point_start_idx + num_points_per_cloud[batch_idx];

// Have each thread handle a different point within the chunk
for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) {
const int p_idx = point_start_idx + p;

// Check if point index corresponds to the cloud in the batch given by
// batch_idx.
if (p_idx >= cloud_point_stop_idx || p_idx < cloud_point_start_idx) {
continue;
}

const float px = points[p_idx * 3 + 0];
const float py = points[p_idx * 3 + 1];
const float pz = points[p_idx * 3 + 2];
const float p_radius = radius[p_idx];
if (pz < 0)
continue; // Don't render points behind the camera.
const float px0 = px - p_radius;
const float px1 = px + p_radius;
const float py0 = py - p_radius;
const float py1 = py + p_radius;

// Brute-force search over all bins; TODO something smarter?
// For example we could compute the exact bin where the point falls,
// then check neighboring bins. This way we wouldn't have to check
// all bins (however then we might have more warp divergence?)
for (int by = 0; by < num_bins_y; ++by) {
// Get y extent for the bin. PixToNonSquareNdc gives us the location of
// the center of each pixel, so we need to add/subtract a half
// pixel to get the true extent of the bin.
const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
const float by1 =
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
const bool y_overlap = (py0 <= by1) && (by0 <= py1);

if (!y_overlap) {
continue;
}
for (int bx = 0; bx < num_bins_x; ++bx) {
// Get x extent for the bin; again we need to adjust the
// output of PixToNonSquareNdc by half a pixel.
const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
const float bx1 =
PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);

if (x_overlap) {
binmask.set(by, bx, p);
}
}
}
}
__syncthreads();
// Now we have processed every point in the current chunk. We need to
// count the number of points in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin.
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
byx += blockDim.x) {
const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;
const int count = binmask.count(by, bx);
const int points_per_bin_idx =
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;

// This atomically increments the (global) number of points found
// in the current bin, and gets the previous value of the counter;
// this effectively allocates space in the bin_points array for the
// points in the current chunk that fall into this bin.
const int start = atomicAdd(points_per_bin + points_per_bin_idx, count);

// Now loop over the binmask and write the active bits for this bin
// out to bin_points.
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
by * num_bins_x * M + bx * M + start;
for (int p = 0; p < chunk_size; ++p) {
if (binmask.get(by, bx, p)) {
// TODO: Throw an error if next_idx >= M -- this means that
// we got more than max_points_per_bin in this bin
// TODO: check if atomicAdd is needed in line 265.
bin_points[next_idx] = point_start_idx + p;
next_idx++;
}
}
}
__syncthreads();
}
}

at::Tensor RasterizeMeshesCoarseCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
@@ -442,8 +321,8 @@ at::Tensor RasterizeMeshesCoarseCuda(

at::Tensor RasterizePointsCoarseCuda(
const at::Tensor& points, // (P, 3)
const at::Tensor& cloud_to_packed_first_idx, // (N)
const at::Tensor& num_points_per_cloud, // (N)
const at::Tensor& cloud_to_packed_first_idx, // (N,)
const at::Tensor& num_points_per_cloud, // (N,)
const std::tuple<int, int> image_size,
const at::Tensor& radius,
const int bin_size,
@@ -465,54 +344,30 @@ at::Tensor RasterizePointsCoarseCuda(
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);

// Allocate tensors for bboxes and should_skip
const int P = points.size(0);
const int N = num_points_per_cloud.size(0);
const int M = max_points_per_bin;

// Integer divide round up.
const int num_bins_y = 1 + (H - 1) / bin_size;
const int num_bins_x = 1 + (W - 1) / bin_size;
auto float_opts = points.options().dtype(at::kFloat);
auto bool_opts = points.options().dtype(at::kBool);
at::Tensor bboxes = at::empty({4, P}, float_opts);
at::Tensor should_skip = at::empty({P}, bool_opts);

if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
// Make sure we do not use too much shared memory.
std::stringstream ss;
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x << ", "
<< "; that's too many!";
AT_ERROR(ss.str());
}
auto opts = num_points_per_cloud.options().dtype(at::kInt);
at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);

if (bin_points.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_points;
}

const int chunk_size = 512;
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;

RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
// Launch kernel to compute point bboxes
const size_t blocks = 128;
const size_t threads = 256;
PointBoundingBoxKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
radius.contiguous().data_ptr<float>(),
N,
P,
H,
W,
bin_size,
chunk_size,
M,
points_per_bin.contiguous().data_ptr<int32_t>(),
bin_points.contiguous().data_ptr<int32_t>());

bboxes.contiguous().data_ptr<float>(),
should_skip.contiguous().data_ptr<bool>());
AT_CUDA_CHECK(cudaGetLastError());
return bin_points;

return RasterizeCoarseCuda(
bboxes,
should_skip,
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size,
bin_size,
max_points_per_bin);
}
