diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index 47f0664ea..067826d46 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -111,7 +111,8 @@ __device__ void CheckPixelInsideFace( const float blur_radius, const float2 pxy, // Coordinates of the pixel const int K, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { const auto v012 = GetSingleFaceVerts(face_verts, face_idx); const float3 v0 = thrust::get<0>(v012); const float3 v1 = thrust::get<1>(v012); @@ -124,16 +125,20 @@ __device__ void CheckPixelInsideFace( // Perform checks and skip if: // 1. the face is behind the camera - // 2. the face has very small face area - // 3. the pixel is outside the face bbox + // 2. the face is facing away from the camera + // 3. the face has very small face area + // 4. the pixel is outside the face bbox const float zmax = FloatMax3(v0.z, v1.z, v2.z); const bool outside_bbox = CheckPointOutsideBoundingBox( v0, v1, v2, sqrt(blur_radius), pxy); // use sqrt of blur for bbox const float face_area = EdgeFunctionForward(v0xy, v1xy, v2xy); + // Check if the face is visible to the camera. + const bool back_face = face_area < 0.0; const bool zero_face_area = (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon); - if (zmax < 0 || outside_bbox || zero_face_area) { + if (zmax < 0 || cull_backfaces && back_face || outside_bbox || + zero_face_area) { return; } @@ -191,6 +196,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel( const int64_t* num_faces_per_mesh, const float blur_radius, const bool perspective_correct, + const bool cull_backfaces, const int N, const int H, const int W, @@ -251,7 +257,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel( blur_radius, pxy, K, - perspective_correct); + perspective_correct, + cull_backfaces); } // TODO: make sorting an option as only top k is needed, not sorted values. @@ -276,7 +283,8 @@ RasterizeMeshesNaiveCuda( const int image_size, const float blur_radius, const int num_closest, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || face_verts.size(2) != 3) { AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); @@ -314,6 +322,7 @@ RasterizeMeshesNaiveCuda( num_faces_per_mesh.contiguous().data_ptr(), blur_radius, perspective_correct, + cull_backfaces, N, H, W, @@ -667,6 +676,7 @@ __global__ void RasterizeMeshesFineCudaKernel( const float blur_radius, const int bin_size, const bool perspective_correct, + const bool cull_backfaces, const int N, const int B, const int M, @@ -730,7 +740,8 @@ __global__ void RasterizeMeshesFineCudaKernel( blur_radius, pxy, K, - perspective_correct); + perspective_correct, + cull_backfaces); } // Now we've looked at all the faces for this bin, so we can write @@ -762,7 +773,8 @@ RasterizeMeshesFineCuda( const float blur_radius, const int bin_size, const int faces_per_pixel, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || face_verts.size(2) != 3) { AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); @@ -797,6 +809,7 @@ RasterizeMeshesFineCuda( blur_radius, bin_size, perspective_correct, + cull_backfaces, N, B, M, diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h index 1131d9863..210cecbfb 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h @@ -17,7 +17,8 @@ RasterizeMeshesNaiveCpu( const int image_size, const float blur_radius, const int faces_per_pixel, - const bool perspective_correct); + const bool perspective_correct, + const bool cull_backfaces); #ifdef WITH_CUDA std::tuple @@ -28,7 +29,8 @@ RasterizeMeshesNaiveCuda( const int image_size, const float blur_radius, const int num_closest, - const bool perspective_correct); + const bool perspective_correct, + const bool cull_backfaces); #endif // Forward pass for rasterizing a batch of meshes. // @@ -55,6 +57,14 @@ RasterizeMeshesNaiveCuda( // coordinates for each pixel; if this is False then // this function instead returns screen-space // barycentric coordinates for each pixel. +// cull_backfaces: Bool, Whether to only rasterize mesh faces which are +// visible to the camera. This assumes that vertices of +// front-facing triangles are ordered in an anti-clockwise +// fashion, and triangles that face away from the camera are +// in a clockwise order relative to the current view +// direction. NOTE: This will only work if the mesh faces are +// consistently defined with counter-clockwise ordering when +// viewed from the outside. // // Returns: // A 4 element tuple of: @@ -80,7 +90,8 @@ RasterizeMeshesNaive( const int image_size, const float blur_radius, const int faces_per_pixel, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { // TODO: Better type checking. if (face_verts.is_cuda()) { #ifdef WITH_CUDA @@ -91,7 +102,8 @@ RasterizeMeshesNaive( image_size, blur_radius, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); #else AT_ERROR("Not compiled with GPU support"); #endif @@ -103,7 +115,8 @@ RasterizeMeshesNaive( image_size, blur_radius, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); } } @@ -274,7 +287,8 @@ RasterizeMeshesFineCuda( const float blur_radius, const int bin_size, const int faces_per_pixel, - const bool perspective_correct); + const bool perspective_correct, + const bool cull_backfaces); #endif // Args: // face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for @@ -296,6 +310,14 @@ RasterizeMeshesFineCuda( // coordinates for each pixel; if this is False then // this function instead returns screen-space // barycentric coordinates for each pixel. +// cull_backfaces: Bool, Whether to only rasterize mesh faces which are +// visible to the camera. This assumes that vertices of +// front-facing triangles are ordered in an anti-clockwise +// fashion, and triangles that face away from the camera are +// in a clockwise order relative to the current view +// direction. NOTE: This will only work if the mesh faces are +// consistently defined with counter-clockwise ordering when +// viewed from the outside. // // Returns (same as rasterize_meshes): // A 4 element tuple of: @@ -321,7 +343,8 @@ RasterizeMeshesFine( const float blur_radius, const int bin_size, const int faces_per_pixel, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (face_verts.is_cuda()) { #ifdef WITH_CUDA return RasterizeMeshesFineCuda( @@ -331,7 +354,8 @@ RasterizeMeshesFine( blur_radius, bin_size, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); #else AT_ERROR("Not compiled with GPU support"); #endif @@ -372,7 +396,14 @@ RasterizeMeshesFine( // coordinates for each pixel; if this is False then // this function instead returns screen-space // barycentric coordinates for each pixel. -// +// cull_backfaces: Bool, Whether to only rasterize mesh faces which are +// visible to the camera. This assumes that vertices of +// front-facing triangles are ordered in an anti-clockwise +// fashion, and triangles that face away from the camera are +// in a clockwise order relative to the current view +// direction. NOTE: This will only work if the mesh faces are +// consistently defined with counter-clockwise ordering when +// viewed from the outside. // // Returns: // A 4 element tuple of: @@ -400,7 +431,8 @@ RasterizeMeshes( const int faces_per_pixel, const int bin_size, const int max_faces_per_bin, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (bin_size > 0 && max_faces_per_bin > 0) { // Use coarse-to-fine rasterization auto bin_faces = RasterizeMeshesCoarse( @@ -418,7 +450,8 @@ RasterizeMeshes( blur_radius, bin_size, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); } else { // Use the naive per-pixel implementation return RasterizeMeshesNaive( @@ -428,6 +461,7 @@ RasterizeMeshes( image_size, blur_radius, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); } } diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp index dd810d355..a4d28afd2 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp @@ -107,7 +107,8 @@ RasterizeMeshesNaiveCpu( int image_size, const float blur_radius, const int faces_per_pixel, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || face_verts.size(2) != 3) { AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); @@ -184,8 +185,13 @@ RasterizeMeshesNaiveCpu( const vec2 v1(x1, y1); const vec2 v2(x2, y2); - // Skip faces with zero area. const float face_area = face_areas_a[f]; + const bool back_face = face_area < 0.0; + // Check if the face is visible to the camera. + if (cull_backfaces && back_face) { + continue; + } + // Skip faces with zero area. if (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon) { continue; } diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index cdf401b3d..e5fda294d 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -140,16 +140,16 @@ def load_obj(f_obj, load_textures=True): If there are faces with more than 3 vertices they are subdivided into triangles. Polygonal faces are assummed to have vertices ordered counter-clockwise so the (right-handed) normal points - into the screen e.g. a proper rectangular face would be specified like this: + out of the screen e.g. a proper rectangular face would be specified like this: :: 0_________1 | | | | 3 ________2 - The face would be split into two triangles: (0, 1, 2) and (0, 2, 3), - both of which are also oriented clockwise and have normals - pointing into the screen. + The face would be split into two triangles: (0, 2, 1) and (0, 3, 2), + both of which are also oriented counter-clockwise and have normals + pointing out of the screen. Args: f: A file-like object (with methods read, readline, tell, and seek), diff --git a/pytorch3d/renderer/mesh/rasterize_meshes.py b/pytorch3d/renderer/mesh/rasterize_meshes.py index e72a9596a..67cb8a79a 100644 --- a/pytorch3d/renderer/mesh/rasterize_meshes.py +++ b/pytorch3d/renderer/mesh/rasterize_meshes.py @@ -20,6 +20,7 @@ def rasterize_meshes( bin_size: Optional[int] = None, max_faces_per_bin: Optional[int] = None, perspective_correct: bool = False, + cull_backfaces: bool = False, ): """ Rasterize a batch of meshes given the shape of the desired output image. @@ -45,8 +46,16 @@ def rasterize_meshes( bin. If more than this many faces actually fall into a bin, an error will be raised. This should not affect the output values, but can affect the memory usage in the forward pass. - perspective_correct: Whether to apply perspective correction when computing + perspective_correct: Bool, Whether to apply perspective correction when computing barycentric coordinates for pixels. + cull_backfaces: Bool, Whether to only rasterize mesh faces which are + visible to the camera. This assumes that vertices of + front-facing triangles are ordered in an anti-clockwise + fashion, and triangles that face away from the camera are + in a clockwise order relative to the current view + direction. NOTE: This will only work if the mesh faces are + consistently defined with counter-clockwise ordering when + viewed from the outside. Returns: 4-element tuple containing @@ -118,6 +127,7 @@ def rasterize_meshes( bin_size, max_faces_per_bin, perspective_correct, + cull_backfaces, ) @@ -139,6 +149,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): for each mesh in the batch. image_size, blur_radius, faces_per_pixel: same as rasterize_meshes. perspective_correct: same as rasterize_meshes. + cull_backfaces: same as rasterize_meshes. Returns: same as rasterize_meshes function. @@ -156,6 +167,7 @@ def forward( bin_size: int = 0, max_faces_per_bin: int = 0, perspective_correct: bool = False, + cull_backfaces: bool = False, ): pix_to_face, zbuf, barycentric_coords, dists = _C.rasterize_meshes( face_verts, @@ -167,6 +179,7 @@ def forward( bin_size, max_faces_per_bin, perspective_correct, + cull_backfaces, ) ctx.save_for_backward(face_verts, pix_to_face) ctx.perspective_correct = perspective_correct @@ -183,6 +196,7 @@ def backward(ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dis grad_bin_size = None grad_max_faces_per_bin = None grad_perspective_correct = None + grad_cull_backfaces = None face_verts, pix_to_face = ctx.saved_tensors grad_face_verts = _C.rasterize_meshes_backward( face_verts, @@ -202,6 +216,7 @@ def backward(ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dis grad_bin_size, grad_max_faces_per_bin, grad_perspective_correct, + grad_cull_backfaces, ) return grads @@ -217,6 +232,7 @@ def rasterize_meshes_python( blur_radius: float = 0.0, faces_per_pixel: int = 8, perspective_correct: bool = False, + cull_backfaces: bool = False, ): """ Naive PyTorch implementation of mesh rasterization with the same inputs and @@ -287,7 +303,12 @@ def rasterize_meshes_python( face = faces_verts[f].squeeze() v0, v1, v2 = face.unbind(0) - face_area = edge_function(v2, v0, v1) + face_area = edge_function(v0, v1, v2) + + # Ignore triangles facing away from the camera. + back_face = face_area < 0 + if cull_backfaces and back_face: + continue # Ignore faces which have zero area. if face_area == 0.0: @@ -365,8 +386,8 @@ def edge_function(p, v0, v1): .. code-block:: python - A = p - v0 - B = v1 - v0 + B = p - v0 + A = v1 - v0 v1 ________ /\ / diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index b995bf27e..a5c9a2e89 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -26,6 +26,7 @@ class RasterizationSettings: "bin_size", "max_faces_per_bin", "perspective_correct", + "cull_backfaces", ] def __init__( @@ -36,6 +37,7 @@ def __init__( bin_size: Optional[int] = None, max_faces_per_bin: Optional[int] = None, perspective_correct: bool = False, + cull_backfaces: bool = False, ): self.image_size = image_size self.blur_radius = blur_radius @@ -43,6 +45,7 @@ def __init__( self.bin_size = bin_size self.max_faces_per_bin = max_faces_per_bin self.perspective_correct = perspective_correct + self.cull_backfaces = cull_backfaces class MeshRasterizer(nn.Module): @@ -122,6 +125,7 @@ def forward(self, meshes_world, **kwargs) -> Fragments: bin_size=raster_settings.bin_size, max_faces_per_bin=raster_settings.max_faces_per_bin, perspective_correct=raster_settings.perspective_correct, + cull_backfaces=raster_settings.cull_backfaces, ) return Fragments( pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists diff --git a/tests/test_rasterize_meshes.py b/tests/test_rasterize_meshes.py index 69844d93a..9979a0419 100644 --- a/tests/test_rasterize_meshes.py +++ b/tests/test_rasterize_meshes.py @@ -21,6 +21,7 @@ def test_simple_python(self): self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1) self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1) self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1) + self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1) def test_simple_cpu_naive(self): device = torch.device("cpu") @@ -28,6 +29,7 @@ def test_simple_cpu_naive(self): self._simple_blurry_raster(rasterize_meshes, device, bin_size=0) self._test_behind_camera(rasterize_meshes, device, bin_size=0) self._test_perspective_correct(rasterize_meshes, device, bin_size=0) + self._test_back_face_culling(rasterize_meshes, device, bin_size=0) def test_simple_cuda_naive(self): device = torch.device("cuda:0") @@ -35,6 +37,7 @@ def test_simple_cuda_naive(self): self._simple_blurry_raster(rasterize_meshes, device, bin_size=0) self._test_behind_camera(rasterize_meshes, device, bin_size=0) self._test_perspective_correct(rasterize_meshes, device, bin_size=0) + self._test_back_face_culling(rasterize_meshes, device, bin_size=0) def test_simple_cuda_binned(self): device = torch.device("cuda:0") @@ -42,6 +45,7 @@ def test_simple_cuda_binned(self): self._simple_blurry_raster(rasterize_meshes, device, bin_size=5) self._test_behind_camera(rasterize_meshes, device, bin_size=5) self._test_perspective_correct(rasterize_meshes, device, bin_size=5) + self._test_back_face_culling(rasterize_meshes, device, bin_size=5) def test_python_vs_cpu_vs_cuda(self): torch.manual_seed(231) @@ -377,6 +381,81 @@ def test_cuda_naive_vs_binned_perspective_correct(self): args = () self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True) + def _test_back_face_culling(self, rasterize_meshes_fn, device, bin_size): + # Square based pyramid mesh. + # fmt: off + verts = torch.tensor([ + [-0.5, 0.0, 0.5], # noqa: E241 E201 Front right + [ 0.5, 0.0, 0.5], # noqa: E241 E201 Front left + [ 0.5, 0.0, 1.5], # noqa: E241 E201 Back left + [-0.5, 0.0, 1.5], # noqa: E241 E201 Back right + [ 0.0, 1.0, 1.0] # noqa: E241 E201 Top point of pyramid + ], dtype=torch.float32, device=device) + + faces = torch.tensor([ + [2, 1, 0], # noqa: E241 E201 Square base + [3, 2, 0], # noqa: E241 E201 Square base + [1, 0, 4], # noqa: E241 E201 Triangle on front + [2, 4, 3], # noqa: E241 E201 Triangle on back + [3, 4, 0], # noqa: E241 E201 Triangle on left side + [1, 4, 2] # noqa: E241 E201 Triangle on right side + ], dtype=torch.int64, device=device) + # fmt: on + mesh = Meshes(verts=[verts], faces=[faces]) + kwargs = { + "meshes": mesh, + "image_size": 10, + "faces_per_pixel": 2, + "blur_radius": 0.0, + "perspective_correct": False, + "cull_backfaces": False, + } + if bin_size != -1: + kwargs["bin_size"] = bin_size + + # fmt: off + pix_to_face_frontface = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, 2, 2, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, 2, 2, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201 + ], dtype=torch.int64, device=device) + pix_to_face_backface = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, 3, 3, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, 3, 3, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 3, 3, 3, 3, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 3, 3, 3, 3, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201 + ], dtype=torch.int64, device=device) + # fmt: on + + pix_to_face_padded = -torch.ones_like(pix_to_face_frontface) + # Run with and without culling + # Without culling, for k=0, the front face (i.e. face 2) is + # rasterized and for k=1, the back face (i.e. face 3) is + # rasterized. + idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs) + self.assertTrue(torch.all(idx_f[..., 0].squeeze() == pix_to_face_frontface)) + self.assertTrue(torch.all(idx_f[..., 1].squeeze() == pix_to_face_backface)) + + # With culling, for k=0, the front face (i.e. face 2) is + # rasterized and for k=1, there are no faces rasterized + kwargs["cull_backfaces"] = True + idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs) + self.assertTrue(torch.all(idx_t[..., 0].squeeze() == pix_to_face_frontface)) + self.assertTrue(torch.all(idx_t[..., 1].squeeze() == pix_to_face_padded)) + def _compare_impls( self, fn1,