
Commit

Move coarse rasterization to new file
Summary: In preparation for sharing coarse rasterization between point clouds and meshes, move the functions to a new file. No code changes.

Reviewed By: bottler

Differential Revision: D30367812

fbshipit-source-id: 9e73835a26c4ac91f5c9f61ff682bc8218e36c6a
jcjohnson authored and facebook-github-bot committed Sep 8, 2021
Parent: f2c44e3, commit: 62dbf37
Showing 8 changed files with 534 additions and 480 deletions.
File renamed without changes.
481 changes: 481 additions & 0 deletions pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions pytorch3d/csrc/rasterize_coarse/rasterize_coarse.h
@@ -0,0 +1,38 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <torch/extension.h>
#include <tuple>

// Arguments are the same as RasterizeMeshesCoarse from
// rasterize_meshes/rasterize_meshes.h
#ifdef WITH_CUDA
torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin);
#endif

// Arguments are the same as RasterizePointsCoarse from
// rasterize_points/rasterize_points.h
#ifdef WITH_CUDA
torch::Tensor RasterizePointsCoarseCuda(
const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const std::tuple<int, int> image_size,
const torch::Tensor& radius,
const int bin_size,
const int max_points_per_bin);
#endif
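
For context, here is a minimal sketch (not part of this commit) of how a host-side wrapper might dispatch between the CUDA entry point declared above and the CPU path declared in rasterize_meshes/rasterize_meshes.h; the wrapper name RasterizeMeshesCoarseExample is hypothetical.

// Hypothetical wrapper, illustrative only: dispatch to the CUDA kernel when the
// input lives on the GPU, otherwise fall back to the CPU implementation.
torch::Tensor RasterizeMeshesCoarseExample(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    const std::tuple<int, int> image_size,
    const float blur_radius,
    const int bin_size,
    const int max_faces_per_bin) {
  if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
    return RasterizeMeshesCoarseCuda(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        bin_size,
        max_faces_per_bin);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  // RasterizeMeshesCoarseCpu is declared in rasterize_meshes/rasterize_meshes.h.
  return RasterizeMeshesCoarseCpu(
      face_verts,
      mesh_to_face_first_idx,
      num_faces_per_mesh,
      image_size,
      blur_radius,
      bin_size,
      max_faces_per_bin);
}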
233 changes: 0 additions & 233 deletions pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
@@ -14,7 +14,6 @@
#include <thrust/tuple.h>
#include <cstdio>
#include <tuple>
#include "rasterize_points/bitmask.cuh"
#include "rasterize_points/rasterization_utils.cuh"
#include "utils/float_math.cuh"
#include "utils/geometry_utils.cuh"
@@ -32,14 +31,6 @@ __device__ bool operator<(const Pixel& a, const Pixel& b) {
return a.z < b.z;
}

__device__ float FloatMin3(const float p1, const float p2, const float p3) {
return fminf(p1, fminf(p2, p3));
}

__device__ float FloatMax3(const float p1, const float p2, const float p3) {
return fmaxf(p1, fmaxf(p2, p3));
}

// Get the xyz coordinates of the three vertices for the face given by the
// index face_idx into face_verts.
__device__ thrust::tuple<float3, float3, float3> GetSingleFaceVerts(
@@ -630,230 +621,6 @@ at::Tensor RasterizeMeshesBackwardCuda(
return grad_face_verts;
}

// ****************************************************************************
// * COARSE RASTERIZATION *
// ****************************************************************************

__global__ void RasterizeMeshesCoarseCudaKernel(
const float* face_verts,
const int64_t* mesh_to_face_first_idx,
const int64_t* num_faces_per_mesh,
const float blur_radius,
const int N,
const int F,
const int H,
const int W,
const int bin_size,
const int chunk_size,
const int max_faces_per_bin,
int* faces_per_bin,
int* bin_faces) {
extern __shared__ char sbuf[];
const int M = max_faces_per_bin;
// Integer divide round up
const int num_bins_x = 1 + (W - 1) / bin_size;
const int num_bins_y = 1 + (H - 1) / bin_size;

// NDC range depends on the ratio of W/H
// The shorter side from (H, W) is given an NDC range of 2.0 and
// the other side is scaled by the ratio of H:W.
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;

// Size of half a pixel in NDC units is the NDC half range
// divided by the corresponding image dimension
const float half_pix_x = NDC_x_half_range / W;
const float half_pix_y = NDC_y_half_range / H;

// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
// stored in shared memory that will track whether each point in the chunk
// falls into each bin of the image.
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);

// Have each block handle a chunk of faces
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;

for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch;
const int face_start_idx = chunk_idx * chunk_size;

binmask.block_clear();
const int64_t mesh_face_start_idx = mesh_to_face_first_idx[batch_idx];
const int64_t mesh_face_stop_idx =
mesh_face_start_idx + num_faces_per_mesh[batch_idx];

// Have each thread handle a different face within the chunk
for (int f = threadIdx.x; f < chunk_size; f += blockDim.x) {
const int f_idx = face_start_idx + f;

// Check if face index corresponds to the mesh in the batch given by
// batch_idx
if (f_idx >= mesh_face_stop_idx || f_idx < mesh_face_start_idx) {
continue;
}

// Get xyz coordinates of the three face vertices.
const auto v012 = GetSingleFaceVerts(face_verts, f_idx);
const float3 v0 = thrust::get<0>(v012);
const float3 v1 = thrust::get<1>(v012);
const float3 v2 = thrust::get<2>(v012);

// Compute screen-space bbox for the triangle expanded by blur.
float xmin = FloatMin3(v0.x, v1.x, v2.x) - sqrt(blur_radius);
float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius);
float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius);
float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius);
float zmin = FloatMin3(v0.z, v1.z, v2.z);

// Faces with at least one vertex behind the camera won't render
// correctly and should be removed or clipped before calling the
// rasterizer
if (zmin < kEpsilon) {
continue;
}

// Brute-force search over all bins; TODO(T54294966) something smarter.
for (int by = 0; by < num_bins_y; ++by) {
// Y coordinate of the top and bottom of the bin.
// PixToNdc gives the location of the center of each pixel, so we
// need to add/subtract a half pixel to get the true extent of the bin.
// Reverse ordering of Y axis so that +Y is upwards in the image.
const float bin_y_min =
PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
const float bin_y_max =
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);

for (int bx = 0; bx < num_bins_x; ++bx) {
// X coordinate of the left and right of the bin.
// Reverse ordering of x axis so that +X is left.
const float bin_x_max =
PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
const float bin_x_min =
PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;

const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) {
binmask.set(by, bx, f);
}
}
}
}
__syncthreads();
// Now we have processed every face in the current chunk. We need to
// count the number of faces in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin.
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
byx += blockDim.x) {
const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;
const int count = binmask.count(by, bx);
const int faces_per_bin_idx =
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;

// This atomically increments the (global) number of faces found
// in the current bin, and gets the previous value of the counter;
// this effectively allocates space in the bin_faces array for the
// faces in the current chunk that fall into this bin.
const int start = atomicAdd(faces_per_bin + faces_per_bin_idx, count);

// Now loop over the binmask and write the active bits for this bin
// out to bin_faces.
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
by * num_bins_x * M + bx * M + start;
for (int f = 0; f < chunk_size; ++f) {
if (binmask.get(by, bx, f)) {
// TODO(T54296346) find the correct method for handling errors in
// CUDA. Throw an error if num_faces_per_bin > max_faces_per_bin.
// Either decrease bin size or increase max_faces_per_bin
bin_faces[next_idx] = face_start_idx + f;
next_idx++;
}
}
}
__syncthreads();
}
}

at::Tensor RasterizeMeshesCoarseCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
face_verts.size(2) == 3,
"face_verts must have dimensions (num_faces, 3, 3)");

// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
mesh_to_face_first_idx_t{
mesh_to_face_first_idx, "mesh_to_face_first_idx", 2},
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
at::CheckedFrom c = "RasterizeMeshesCoarseCuda";
at::checkAllSameGPU(
c, {face_verts_t, mesh_to_face_first_idx_t, num_faces_per_mesh_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);

const int F = face_verts.size(0);
const int N = num_faces_per_mesh.size(0);
const int M = max_faces_per_bin;

// Integer divide round up.
const int num_bins_y = 1 + (H - 1) / bin_size;
const int num_bins_x = 1 + (W - 1) / bin_size;

if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
std::stringstream ss;
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x << ", "
<< "; that's too many!";
AT_ERROR(ss.str());
}
auto opts = num_faces_per_mesh.options().dtype(at::kInt);
at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);

if (bin_faces.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
}

const int chunk_size = 512;
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;

RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
face_verts.contiguous().data_ptr<float>(),
mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
blur_radius,
N,
F,
H,
W,
bin_size,
chunk_size,
M,
faces_per_bin.data_ptr<int32_t>(),
bin_faces.data_ptr<int32_t>());

AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
}

// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
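For intuition about the coarse-rasterization launch configuration moved out of this file (the bin counts and the BitMask shared-memory allocation), here is a small standalone C++ sketch. The 512x512 image and bin_size of 64 are hypothetical values for illustration, not taken from the diff.

#include <cstdio>

int main() {
  // Hypothetical inputs; in the kernel these come from image_size and bin_size.
  const int H = 512, W = 512;
  const int bin_size = 64;
  const int chunk_size = 512;

  // Integer divide, rounding up (same formula as in the kernel).
  const int num_bins_y = 1 + (H - 1) / bin_size; // 8
  const int num_bins_x = 1 + (W - 1) / bin_size; // 8

  // The BitMask stores one bit per (bin_y, bin_x, face-in-chunk) triple, so the
  // shared-memory allocation is num_bins_y * num_bins_x * chunk_size / 8 bytes.
  const size_t shared_bytes =
      (size_t)num_bins_y * num_bins_x * chunk_size / 8; // 4096 bytes

  printf("bins: %d x %d, shared memory per block: %zu bytes\n",
         num_bins_x, num_bins_y, shared_bytes);
  return 0;
}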
15 changes: 4 additions & 11 deletions pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h
@@ -10,6 +10,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "rasterize_coarse/rasterize_coarse.h"
#include "utils/pytorch3d_cutils.h"

// ****************************************************************************
@@ -236,6 +237,8 @@ torch::Tensor RasterizeMeshesBackward(
// * COARSE RASTERIZATION *
// ****************************************************************************

// RasterizeMeshesCoarseCuda in rasterize_coarse/rasterize_coarse.h

torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
@@ -245,16 +248,6 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const int bin_size,
const int max_faces_per_bin);

#ifdef WITH_CUDA
torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin);
#endif
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
// faces in all the meshes in the batch. Concretely,
@@ -499,7 +492,7 @@ RasterizeMeshes(
const bool cull_backfaces) {
if (bin_size > 0 && max_faces_per_bin > 0) {
// Use coarse-to-fine rasterization
auto bin_faces = RasterizeMeshesCoarse(
at::Tensor bin_faces = RasterizeMeshesCoarse(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
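As a usage note, here is a hedged sketch of how a caller might inspect the bin_faces tensor produced by the coarse stage in the hunk above. The shape and -1 padding follow from RasterizeMeshesCoarseCuda earlier in this diff; the helper function InspectBinFaces itself is illustrative, not part of the commit.

#include <torch/extension.h>

// bin_faces has shape (N, num_bins_y, num_bins_x, max_faces_per_bin), padded
// with -1 where a bin holds fewer than max_faces_per_bin faces.
void InspectBinFaces(const at::Tensor& bin_faces) {
  at::Tensor bins_cpu = bin_faces.cpu(); // accessors require a CPU tensor
  auto bins = bins_cpu.accessor<int32_t, 4>();
  for (int n = 0; n < bins.size(0); ++n) {
    for (int by = 0; by < bins.size(1); ++by) {
      for (int bx = 0; bx < bins.size(2); ++bx) {
        for (int m = 0; m < bins.size(3); ++m) {
          const int f = bins[n][by][bx][m];
          if (f < 0) {
            break; // -1 padding marks the end of valid face indices in this bin
          }
          // f indexes into the packed face_verts tensor.
        }
      }
    }
  }
}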
