Skip to content

Commit

Permalink
Multithread CPU naive mesh rasterization
Browse files Browse the repository at this point in the history
Summary:
Threaded the for loop:
```
for (int yi = 0; yi < H; ++yi) {...}
```
in function `RasterizeMeshesNaiveCpu()`.
Chunk size is approx equal.

Reviewed By: bottler

Differential Revision: D40063604

fbshipit-source-id: 09150269405538119b0f1b029892179501421e68
  • Loading branch information
Gavin Peng authored and facebook-github-bot committed Oct 6, 2022
1 parent 37bd280 commit 6471893
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 47 deletions.
150 changes: 106 additions & 44 deletions pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#include <algorithm>
#include <list>
#include <queue>
#include <thread>
#include <tuple>
#include "ATen/core/TensorAccessor.h"
#include "rasterize_points/rasterization_utils.h"
#include "utils/geometry_utils.h"
#include "utils/vec2.h"
Expand Down Expand Up @@ -117,54 +119,28 @@ struct IsNeighbor {
int neighbor_idx;
};

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
namespace {
void RasterizeMeshesNaiveCpu_worker(
const int start_yi,
const int end_yi,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
}
if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
AT_ERROR(
"num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
}

const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
const int K = faces_per_pixel;

auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
auto float_opts = face_verts.options().dtype(torch::kFloat32);

// Initialize output tensors.
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor barycentric_coords =
torch::full({N, H, W, K, 3}, -1, float_opts);

auto face_verts_a = face_verts.accessor<float, 3>();
auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
auto zbuf_a = zbuf.accessor<float, 4>();
auto pix_dists_a = pix_dists.accessor<float, 4>();
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();

auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
auto face_areas = ComputeFaceAreas(face_verts);
auto face_areas_a = face_areas.accessor<float, 1>();

const bool cull_backfaces,
const int32_t N,
const int H,
const int W,
const int K,
at::TensorAccessor<float, 3>& face_verts_a,
at::TensorAccessor<float, 1>& face_areas_a,
at::TensorAccessor<float, 2>& face_bboxes_a,
at::TensorAccessor<int64_t, 1>& neighbor_idx_a,
at::TensorAccessor<float, 4>& zbuf_a,
at::TensorAccessor<int64_t, 4>& face_idxs_a,
at::TensorAccessor<float, 4>& pix_dists_a,
at::TensorAccessor<float, 5>& barycentric_coords_a) {
for (int n = 0; n < N; ++n) {
// Loop through each mesh in the batch.
// Get the start index of the faces in faces_packed and the num faces
Expand All @@ -174,7 +150,7 @@ RasterizeMeshesNaiveCpu(
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());

// Iterate through the horizontal lines of the image from top to bottom.
for (int yi = 0; yi < H; ++yi) {
for (int yi = start_yi; yi < end_yi; ++yi) {
// Reverse the order of yi so that +Y is pointing upwards in the image.
const int yidx = H - 1 - yi;

Expand Down Expand Up @@ -324,6 +300,92 @@ RasterizeMeshesNaiveCpu(
}
}
}
}
} // namespace

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
}
if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
AT_ERROR(
"num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
}

const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
const int K = faces_per_pixel;

auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
auto float_opts = face_verts.options().dtype(torch::kFloat32);

// Initialize output tensors.
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor barycentric_coords =
torch::full({N, H, W, K, 3}, -1, float_opts);

auto face_verts_a = face_verts.accessor<float, 3>();
auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
auto zbuf_a = zbuf.accessor<float, 4>();
auto pix_dists_a = pix_dists.accessor<float, 4>();
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();

auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
auto face_areas = ComputeFaceAreas(face_verts);
auto face_areas_a = face_areas.accessor<float, 1>();

const int64_t n_threads = at::get_num_threads();
std::vector<std::thread> threads;
threads.reserve(n_threads);
const int chunk_size = 1 + (H - 1) / n_threads;
int start_yi = 0;
for (int iThread = 0; iThread < n_threads; ++iThread) {
const int64_t end_yi = std::min(start_yi + chunk_size, H);
threads.emplace_back(
RasterizeMeshesNaiveCpu_worker,
start_yi,
end_yi,
mesh_to_face_first_idx,
num_faces_per_mesh,
blur_radius,
perspective_correct,
clip_barycentric_coords,
cull_backfaces,
N,
H,
W,
K,
std::ref(face_verts_a),
std::ref(face_areas_a),
std::ref(face_bboxes_a),
std::ref(neighbor_idx_a),
std::ref(zbuf_a),
std::ref(face_idxs_a),
std::ref(pix_dists_a),
std::ref(barycentric_coords_a));
start_yi += chunk_size;
}
for (auto&& thread : threads) {
thread.join();
}

return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists);
}

Expand Down
6 changes: 4 additions & 2 deletions tests/benchmarks/bm_rasterize_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import os
from itertools import product

import torch
from fvcore.common.benchmark import benchmark
from tests.test_rasterize_meshes import TestRasterizeMeshes

BM_RASTERIZE_MESHES_N_THREADS = os.getenv("BM_RASTERIZE_MESHES_N_THREADS", 1)
torch.set_num_threads(int(BM_RASTERIZE_MESHES_N_THREADS))

# ico levels:
# 0: (12 verts, 20 faces)
Expand Down Expand Up @@ -41,7 +43,7 @@ def bm_rasterize_meshes() -> None:
kwargs_list = []
num_meshes = [1]
ico_level = [1]
image_size = [64, 128]
image_size = [64, 128, 512]
blur = [1e-6]
faces_per_pixel = [3, 50]
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
Expand Down
12 changes: 11 additions & 1 deletion tests/test_rasterize_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,24 @@ def test_simple_python(self):
self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)

def test_simple_cpu_naive(self):
def _test_simple_cpu_naive_instance(self):
device = torch.device("cpu")
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
self._test_behind_camera(rasterize_meshes, device, bin_size=0)
self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
self._test_back_face_culling(rasterize_meshes, device, bin_size=0)

def test_simple_cpu_naive(self):
n_threads = torch.get_num_threads()
torch.set_num_threads(1) # single threaded
self._test_simple_cpu_naive_instance()
torch.set_num_threads(4) # even (divisible) number of threads
self._test_simple_cpu_naive_instance()
torch.set_num_threads(5) # odd (nondivisible) number of threads
self._test_simple_cpu_naive_instance()
torch.set_num_threads(n_threads)

def test_simple_cuda_naive(self):
device = get_random_cuda_device()
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
Expand Down

0 comments on commit 6471893

Please sign in to comment.