From 6471893f59f2c844c844be949bb64cccdcc5fdaa Mon Sep 17 00:00:00 2001 From: Gavin Peng Date: Thu, 6 Oct 2022 06:42:58 -0700 Subject: [PATCH] Multithread CPU naive mesh rasterization Summary: Threaded the for loop: ``` for (int yi = 0; yi < H; ++yi) {...} ``` in function `RasterizeMeshesNaiveCpu()`. Chunk size is approx equal. Reviewed By: bottler Differential Revision: D40063604 fbshipit-source-id: 09150269405538119b0f1b029892179501421e68 --- .../rasterize_meshes/rasterize_meshes_cpu.cpp | 150 +++++++++++++----- tests/benchmarks/bm_rasterize_meshes.py | 6 +- tests/test_rasterize_meshes.py | 12 +- 3 files changed, 121 insertions(+), 47 deletions(-) diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp index 5a20df126..210df55e4 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp @@ -10,7 +10,9 @@ #include #include #include +#include #include +#include "ATen/core/TensorAccessor.h" #include "rasterize_points/rasterization_utils.h" #include "utils/geometry_utils.h" #include "utils/vec2.h" @@ -117,54 +119,28 @@ struct IsNeighbor { int neighbor_idx; }; -std::tuple -RasterizeMeshesNaiveCpu( - const torch::Tensor& face_verts, +namespace { +void RasterizeMeshesNaiveCpu_worker( + const int start_yi, + const int end_yi, const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& num_faces_per_mesh, - const torch::Tensor& clipped_faces_neighbor_idx, - const std::tuple image_size, const float blur_radius, - const int faces_per_pixel, const bool perspective_correct, const bool clip_barycentric_coords, - const bool cull_backfaces) { - if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || - face_verts.size(2) != 3) { - AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); - } - if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) { - AT_ERROR( - "num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx"); - } - - const int32_t N = mesh_to_face_first_idx.size(0); // batch_size. - const int H = std::get<0>(image_size); - const int W = std::get<1>(image_size); - const int K = faces_per_pixel; - - auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64); - auto float_opts = face_verts.options().dtype(torch::kFloat32); - - // Initialize output tensors. - torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts); - torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts); - torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts); - torch::Tensor barycentric_coords = - torch::full({N, H, W, K, 3}, -1, float_opts); - - auto face_verts_a = face_verts.accessor(); - auto face_idxs_a = face_idxs.accessor(); - auto zbuf_a = zbuf.accessor(); - auto pix_dists_a = pix_dists.accessor(); - auto barycentric_coords_a = barycentric_coords.accessor(); - auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor(); - - auto face_bboxes = ComputeFaceBoundingBoxes(face_verts); - auto face_bboxes_a = face_bboxes.accessor(); - auto face_areas = ComputeFaceAreas(face_verts); - auto face_areas_a = face_areas.accessor(); - + const bool cull_backfaces, + const int32_t N, + const int H, + const int W, + const int K, + at::TensorAccessor& face_verts_a, + at::TensorAccessor& face_areas_a, + at::TensorAccessor& face_bboxes_a, + at::TensorAccessor& neighbor_idx_a, + at::TensorAccessor& zbuf_a, + at::TensorAccessor& face_idxs_a, + at::TensorAccessor& pix_dists_a, + at::TensorAccessor& barycentric_coords_a) { for (int n = 0; n < N; ++n) { // Loop through each mesh in the batch. // Get the start index of the faces in faces_packed and the num faces @@ -174,7 +150,7 @@ RasterizeMeshesNaiveCpu( (face_start_idx + num_faces_per_mesh[n].item().to()); // Iterate through the horizontal lines of the image from top to bottom. - for (int yi = 0; yi < H; ++yi) { + for (int yi = start_yi; yi < end_yi; ++yi) { // Reverse the order of yi so that +Y is pointing upwards in the image. const int yidx = H - 1 - yi; @@ -324,6 +300,92 @@ RasterizeMeshesNaiveCpu( } } } +} +} // namespace + +std::tuple +RasterizeMeshesNaiveCpu( + const torch::Tensor& face_verts, + const torch::Tensor& mesh_to_face_first_idx, + const torch::Tensor& num_faces_per_mesh, + const torch::Tensor& clipped_faces_neighbor_idx, + const std::tuple image_size, + const float blur_radius, + const int faces_per_pixel, + const bool perspective_correct, + const bool clip_barycentric_coords, + const bool cull_backfaces) { + if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || + face_verts.size(2) != 3) { + AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); + } + if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) { + AT_ERROR( + "num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx"); + } + + const int32_t N = mesh_to_face_first_idx.size(0); // batch_size. + const int H = std::get<0>(image_size); + const int W = std::get<1>(image_size); + const int K = faces_per_pixel; + + auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64); + auto float_opts = face_verts.options().dtype(torch::kFloat32); + + // Initialize output tensors. + torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts); + torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts); + torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts); + torch::Tensor barycentric_coords = + torch::full({N, H, W, K, 3}, -1, float_opts); + + auto face_verts_a = face_verts.accessor(); + auto face_idxs_a = face_idxs.accessor(); + auto zbuf_a = zbuf.accessor(); + auto pix_dists_a = pix_dists.accessor(); + auto barycentric_coords_a = barycentric_coords.accessor(); + auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor(); + + auto face_bboxes = ComputeFaceBoundingBoxes(face_verts); + auto face_bboxes_a = face_bboxes.accessor(); + auto face_areas = ComputeFaceAreas(face_verts); + auto face_areas_a = face_areas.accessor(); + + const int64_t n_threads = at::get_num_threads(); + std::vector threads; + threads.reserve(n_threads); + const int chunk_size = 1 + (H - 1) / n_threads; + int start_yi = 0; + for (int iThread = 0; iThread < n_threads; ++iThread) { + const int64_t end_yi = std::min(start_yi + chunk_size, H); + threads.emplace_back( + RasterizeMeshesNaiveCpu_worker, + start_yi, + end_yi, + mesh_to_face_first_idx, + num_faces_per_mesh, + blur_radius, + perspective_correct, + clip_barycentric_coords, + cull_backfaces, + N, + H, + W, + K, + std::ref(face_verts_a), + std::ref(face_areas_a), + std::ref(face_bboxes_a), + std::ref(neighbor_idx_a), + std::ref(zbuf_a), + std::ref(face_idxs_a), + std::ref(pix_dists_a), + std::ref(barycentric_coords_a)); + start_yi += chunk_size; + } + for (auto&& thread : threads) { + thread.join(); + } + return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists); } diff --git a/tests/benchmarks/bm_rasterize_meshes.py b/tests/benchmarks/bm_rasterize_meshes.py index da9817116..0a6531a38 100644 --- a/tests/benchmarks/bm_rasterize_meshes.py +++ b/tests/benchmarks/bm_rasterize_meshes.py @@ -4,13 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - +import os from itertools import product import torch from fvcore.common.benchmark import benchmark from tests.test_rasterize_meshes import TestRasterizeMeshes +BM_RASTERIZE_MESHES_N_THREADS = os.getenv("BM_RASTERIZE_MESHES_N_THREADS", 1) +torch.set_num_threads(int(BM_RASTERIZE_MESHES_N_THREADS)) # ico levels: # 0: (12 verts, 20 faces) @@ -41,7 +43,7 @@ def bm_rasterize_meshes() -> None: kwargs_list = [] num_meshes = [1] ico_level = [1] - image_size = [64, 128] + image_size = [64, 128, 512] blur = [1e-6] faces_per_pixel = [3, 50] test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel) diff --git a/tests/test_rasterize_meshes.py b/tests/test_rasterize_meshes.py index 2e858961b..48738ca44 100644 --- a/tests/test_rasterize_meshes.py +++ b/tests/test_rasterize_meshes.py @@ -35,7 +35,7 @@ def test_simple_python(self): self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1) self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1) - def test_simple_cpu_naive(self): + def _test_simple_cpu_naive_instance(self): device = torch.device("cpu") self._simple_triangle_raster(rasterize_meshes, device, bin_size=0) self._simple_blurry_raster(rasterize_meshes, device, bin_size=0) @@ -43,6 +43,16 @@ def test_simple_cpu_naive(self): self._test_perspective_correct(rasterize_meshes, device, bin_size=0) self._test_back_face_culling(rasterize_meshes, device, bin_size=0) + def test_simple_cpu_naive(self): + n_threads = torch.get_num_threads() + torch.set_num_threads(1) # single threaded + self._test_simple_cpu_naive_instance() + torch.set_num_threads(4) # even (divisible) number of threads + self._test_simple_cpu_naive_instance() + torch.set_num_threads(5) # odd (nondivisible) number of threads + self._test_simple_cpu_naive_instance() + torch.set_num_threads(n_threads) + def test_simple_cuda_naive(self): device = get_random_cuda_device() self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)