Refactor mesh coarse rasterization
Summary: Rename parts of the mesh coarse rasterization and separate the bounding box calculation into its own kernel, all in preparation for sharing code with point rasterization.
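
To make the intent concrete, here is a rough sketch (not part of this commit; the kernel name PointBoundingBoxKernel and the per-point radius parameter are hypothetical) of how a point-cloud path could later feed the shared coarse rasterizer introduced below, by producing per-point bounding boxes in the same (4, E) layout that TriangleBoundingBoxKernel produces for faces:

// Hypothetical sketch only -- not part of this commit. Each point is treated
// as an axis-aligned square of half-side radius[p] in NDC, and points behind
// the camera (z < kEpsilon) are flagged so the coarse kernel can skip them,
// mirroring what TriangleBoundingBoxKernel does for triangles.
__global__ void PointBoundingBoxKernel(
    const float* points, // (P, 3)
    const float* radius, // (P,)
    const int P,
    float* bboxes, // (4, P)
    bool* skip_points) { // (P,)
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int p = tid; p < P; p += num_threads) {
    const float x = points[p * 3 + 0];
    const float y = points[p * 3 + 1];
    const float z = points[p * 3 + 2];
    const float r = radius[p];
    bboxes[0 * P + p] = x - r; // xmin
    bboxes[1 * P + p] = x + r; // xmax
    bboxes[2 * P + p] = y - r; // ymin
    bboxes[3 * P + p] = y + r; // ymax
    skip_points[p] = z < kEpsilon;
  }
}
// The resulting (4, P) bboxes and (P,) skip flags could then be handed to
// RasterizeCoarseCuda (added below) exactly as the mesh wrapper does with its
// per-face boxes.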

Reviewed By: bottler

Differential Revision: D30369112

fbshipit-source-id: 3508c0b1239b355030cfa4038d5f3d6a945ebbf4
jcjohnson authored and facebook-github-bot committed Sep 8, 2021
1 parent 62dbf37 commit eed68f4
Showing 1 changed file with 151 additions and 114 deletions: pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
@@ -17,44 +17,55 @@
#include "utils/float_math.cuh"
#include "utils/geometry_utils.cuh" // For kEpsilon -- gross

// Get the xyz coordinates of the three vertices for the face given by the
// index face_idx into face_verts.
__device__ thrust::tuple<float3, float3, float3> GetSingleFaceVerts(
const float* face_verts,
int face_idx) {
const float x0 = face_verts[face_idx * 9 + 0];
const float y0 = face_verts[face_idx * 9 + 1];
const float z0 = face_verts[face_idx * 9 + 2];
const float x1 = face_verts[face_idx * 9 + 3];
const float y1 = face_verts[face_idx * 9 + 4];
const float z1 = face_verts[face_idx * 9 + 5];
const float x2 = face_verts[face_idx * 9 + 6];
const float y2 = face_verts[face_idx * 9 + 7];
const float z2 = face_verts[face_idx * 9 + 8];

const float3 v0xyz = make_float3(x0, y0, z0);
const float3 v1xyz = make_float3(x1, y1, z1);
const float3 v2xyz = make_float3(x2, y2, z2);

return thrust::make_tuple(v0xyz, v1xyz, v2xyz);
__global__ void TriangleBoundingBoxKernel(
const float* face_verts, // (F, 3, 3)
const int F,
const float blur_radius,
float* bboxes, // (4, F)
bool* skip_face) { // (F,)
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = blockDim.x * gridDim.x;
const float sqrt_radius = sqrt(blur_radius);
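// Each face's bounding box is grown by sqrt(blur_radius) on every side,
// the same expansion the old coarse kernel applied inline before this change.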
for (int f = tid; f < F; f += num_threads) {
const float v0x = face_verts[f * 9 + 0 * 3 + 0];
const float v0y = face_verts[f * 9 + 0 * 3 + 1];
const float v0z = face_verts[f * 9 + 0 * 3 + 2];
const float v1x = face_verts[f * 9 + 1 * 3 + 0];
const float v1y = face_verts[f * 9 + 1 * 3 + 1];
const float v1z = face_verts[f * 9 + 1 * 3 + 2];
const float v2x = face_verts[f * 9 + 2 * 3 + 0];
const float v2y = face_verts[f * 9 + 2 * 3 + 1];
const float v2z = face_verts[f * 9 + 2 * 3 + 2];
const float xmin = FloatMin3(v0x, v1x, v2x) - sqrt_radius;
const float xmax = FloatMax3(v0x, v1x, v2x) + sqrt_radius;
const float ymin = FloatMin3(v0y, v1y, v2y) - sqrt_radius;
const float ymax = FloatMax3(v0y, v1y, v2y) + sqrt_radius;
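// Faces with at least one vertex behind the camera won't render correctly and
// should be clipped or culled before rasterization, so flag them to be skipped
// by the coarse kernel (this mirrors the check removed from the old kernel).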
const float zmin = FloatMin3(v0z, v1z, v2z);
const bool skip = zmin < kEpsilon;
bboxes[0 * F + f] = xmin;
bboxes[1 * F + f] = xmax;
bboxes[2 * F + f] = ymin;
bboxes[3 * F + f] = ymax;
skip_face[f] = skip;
}
}

__global__ void RasterizeMeshesCoarseCudaKernel(
const float* face_verts,
const int64_t* mesh_to_face_first_idx,
const int64_t* num_faces_per_mesh,
const float blur_radius,
__global__ void RasterizeCoarseCudaKernel(
const float* bboxes, // (4, E) (xmin, xmax, ymin, ymax)
const bool* should_skip, // (E,)
const int64_t* elem_first_idxs,
const int64_t* elems_per_batch,
const int N,
const int F,
const int E,
const int H,
const int W,
const int bin_size,
const int chunk_size,
const int max_faces_per_bin,
int* faces_per_bin,
int* bin_faces) {
const int max_elem_per_bin,
int* elems_per_bin,
int* bin_elems) {
extern __shared__ char sbuf[];
const int M = max_faces_per_bin;
const int M = max_elem_per_bin;
// Integer divide round up
const int num_bins_x = 1 + (W - 1) / bin_size;
const int num_bins_y = 1 + (H - 1) / bin_size;
@@ -71,53 +82,39 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
const float half_pix_y = NDC_y_half_range / H;

// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
// stored in shared memory that will track whether each point in the chunk
// stored in shared memory that will track whether each elem in the chunk
// falls into each bin of the image.
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
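// Illustrative sizing (assumed example, not fixed anywhere in the code): with
// bin_size = 32 on a 512x512 image there are 16 x 16 bins, so a chunk_size of
// 512 needs 16 * 16 * 512 bits = 16 KiB of shared memory; this matches the
// shared_size computed on the host before the kernel launch.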

// Have each block handle a chunk of faces
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
// Have each block handle a chunk of elements
const int chunks_per_batch = 1 + (E - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;

for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch;
const int face_start_idx = chunk_idx * chunk_size;
const int elem_chunk_start_idx = chunk_idx * chunk_size;

binmask.block_clear();
const int64_t mesh_face_start_idx = mesh_to_face_first_idx[batch_idx];
const int64_t mesh_face_stop_idx =
mesh_face_start_idx + num_faces_per_mesh[batch_idx];
const int64_t elem_start_idx = elem_first_idxs[batch_idx];
const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];

// Have each thread handle a different face within the chunk
for (int f = threadIdx.x; f < chunk_size; f += blockDim.x) {
const int f_idx = face_start_idx + f;
for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) {
const int e_idx = elem_chunk_start_idx + e;

// Check if face index corresponds to the mesh in the batch given by
// batch_idx
if (f_idx >= mesh_face_stop_idx || f_idx < mesh_face_start_idx) {
// Check that we are still within the same element of the batch
if (e_idx >= elem_stop_idx || e_idx < elem_start_idx) {
continue;
}

// Get xyz coordinates of the three face vertices.
const auto v012 = GetSingleFaceVerts(face_verts, f_idx);
const float3 v0 = thrust::get<0>(v012);
const float3 v1 = thrust::get<1>(v012);
const float3 v2 = thrust::get<2>(v012);

// Compute screen-space bbox for the triangle expanded by blur.
float xmin = FloatMin3(v0.x, v1.x, v2.x) - sqrt(blur_radius);
float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius);
float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius);
float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius);
float zmin = FloatMin3(v0.z, v1.z, v2.z);

// Faces with at least one vertex behind the camera won't render
// correctly and should be removed or clipped before calling the
// rasterizer
if (zmin < kEpsilon) {
if (should_skip[e_idx]) {
continue;
}
const float xmin = bboxes[0 * E + e_idx];
const float xmax = bboxes[1 * E + e_idx];
const float ymin = bboxes[2 * E + e_idx];
const float ymax = bboxes[3 * E + e_idx];

// Brute-force search over all bins; TODO(T54294966) something smarter.
for (int by = 0; by < num_bins_y; ++by) {
@@ -141,39 +138,39 @@

const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) {
binmask.set(by, bx, f);
binmask.set(by, bx, e);
}
}
}
}
__syncthreads();
// Now we have processed every face in the current chunk. We need to
// count the number of faces in each bin so we can write the indices
// Now we have processed every elem in the current chunk. We need to
// count the number of elems in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin.
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
byx += blockDim.x) {
const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;
const int count = binmask.count(by, bx);
const int faces_per_bin_idx =
const int elems_per_bin_idx =
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;

// This atomically increments the (global) number of faces found
// This atomically increments the (global) number of elems found
// in the current bin, and gets the previous value of the counter;
// this effectively allocates space in the bin_faces array for the
// faces in the current chunk that fall into this bin.
const int start = atomicAdd(faces_per_bin + faces_per_bin_idx, count);
// elems in the current chunk that fall into this bin.
const int start = atomicAdd(elems_per_bin + elems_per_bin_idx, count);

// Now loop over the binmask and write the active bits for this bin
// out to bin_faces.
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
by * num_bins_x * M + bx * M + start;
for (int f = 0; f < chunk_size; ++f) {
if (binmask.get(by, bx, f)) {
for (int e = 0; e < chunk_size; ++e) {
if (binmask.get(by, bx, e)) {
// TODO(T54296346) find the correct method for handling errors in
// CUDA. Throw an error if elems_per_bin > max_elem_per_bin.
// Either decrease bin size or increase max_elem_per_bin
bin_faces[next_idx] = face_start_idx + f;
bin_elems[next_idx] = elem_chunk_start_idx + e;
next_idx++;
}
}
@@ -182,6 +179,69 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
}
}

at::Tensor RasterizeCoarseCuda(
const at::Tensor& bboxes,
const at::Tensor& should_skip,
const at::Tensor& elem_first_idxs,
const at::Tensor& elems_per_batch,
const std::tuple<int, int> image_size,
const int bin_size,
const int max_elems_per_bin) {
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(bboxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);

const int E = bboxes.size(1);
const int N = elems_per_batch.size(0);
const int M = max_elems_per_bin;

// Integer divide round up
const int num_bins_y = 1 + (H - 1) / bin_size;
const int num_bins_x = 1 + (W - 1) / bin_size;

if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
std::stringstream ss;
ss << "In RasterizeCoarseCuda got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x << ", "
<< "; that's too many!";
AT_ERROR(ss.str());
}
auto opts = elems_per_batch.options().dtype(at::kInt);
at::Tensor elems_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
at::Tensor bin_elems = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);

if (bin_elems.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_elems;
}

const int chunk_size = 512;
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;

RasterizeCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
bboxes.contiguous().data_ptr<float>(),
should_skip.contiguous().data_ptr<bool>(),
elem_first_idxs.contiguous().data_ptr<int64_t>(),
elems_per_batch.contiguous().data_ptr<int64_t>(),
N,
E,
H,
W,
bin_size,
chunk_size,
M,
elems_per_bin.data_ptr<int32_t>(),
bin_elems.data_ptr<int32_t>());

AT_CUDA_CHECK(cudaGetLastError());
return bin_elems;
}

__global__ void RasterizePointsCoarseCudaKernel(
const float* points, // (P, 3)
const int64_t* cloud_to_packed_first_idx, // (N)
@@ -352,55 +412,32 @@ at::Tensor RasterizeMeshesCoarseCuda(
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);

// Allocate tensors for bboxes and should_skip
const int F = face_verts.size(0);
const int N = num_faces_per_mesh.size(0);
const int M = max_faces_per_bin;

// Integer divide round up.
const int num_bins_y = 1 + (H - 1) / bin_size;
const int num_bins_x = 1 + (W - 1) / bin_size;

if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
std::stringstream ss;
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x << ", "
<< "; that's too many!";
AT_ERROR(ss.str());
}
auto opts = num_faces_per_mesh.options().dtype(at::kInt);
at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);

if (bin_faces.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
}

const int chunk_size = 512;
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;

RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
auto float_opts = face_verts.options().dtype(at::kFloat);
auto bool_opts = face_verts.options().dtype(at::kBool);
at::Tensor bboxes = at::empty({4, F}, float_opts);
at::Tensor should_skip = at::empty({F}, bool_opts);

// Launch kernel to compute triangle bboxes
const size_t blocks = 128;
const size_t threads = 256;
TriangleBoundingBoxKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(),
mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
blur_radius,
N,
F,
H,
W,
bin_size,
chunk_size,
M,
faces_per_bin.data_ptr<int32_t>(),
bin_faces.data_ptr<int32_t>());

blur_radius,
bboxes.contiguous().data_ptr<float>(),
should_skip.contiguous().data_ptr<bool>());
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;

return RasterizeCoarseCuda(
bboxes,
should_skip,
mesh_to_face_first_idx,
num_faces_per_mesh,
image_size,
bin_size,
max_faces_per_bin);
}

at::Tensor RasterizePointsCoarseCuda(