From 4bf30593ffc5488a03c75a16dd118013f5d0eb5e Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Wed, 22 Apr 2020 08:20:16 -0700 Subject: [PATCH] back face culling in rasterization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Added backface culling as an option to the `raster_settings`. This is needed for the full forward rendering of shapenet meshes with texture (some meshes contain multiple overlapping segments which have different textures). For a triangle (v0, v1, v2) define the vectors A = (v1 - v0) and B = (v2 − v0) and use this to calculate the area of the triangle as: ``` area = 0.5 * A x B area = 0.5 * ((x1 − x0)(y2 − y0) − (x2 − x0)(y1 − y0)) ``` The area will be positive if (v0, v1, v2) are oriented counterclockwise (a front face), and negative if (v0, v1, v2) are oriented clockwise (a back face). We can reuse the `edge_function` as it already calculates the triangle area. Reviewed By: jcjohnson Differential Revision: D20960115 fbshipit-source-id: 2d8a4b9ccfb653df18e79aed8d05c7ec0f057ab1 --- .../csrc/rasterize_meshes/rasterize_meshes.cu | 29 +++++-- .../csrc/rasterize_meshes/rasterize_meshes.h | 58 +++++++++++--- .../rasterize_meshes/rasterize_meshes_cpu.cpp | 10 ++- pytorch3d/io/obj_io.py | 8 +- pytorch3d/renderer/mesh/rasterize_meshes.py | 29 ++++++- pytorch3d/renderer/mesh/rasterizer.py | 4 + tests/test_rasterize_meshes.py | 79 +++++++++++++++++++ 7 files changed, 187 insertions(+), 30 deletions(-) diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index 47f0664ea..067826d46 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -111,7 +111,8 @@ __device__ void CheckPixelInsideFace( const float blur_radius, const float2 pxy, // Coordinates of the pixel const int K, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { const auto v012 = GetSingleFaceVerts(face_verts, face_idx); const float3 v0 = thrust::get<0>(v012); const float3 v1 = thrust::get<1>(v012); @@ -124,16 +125,20 @@ __device__ void CheckPixelInsideFace( // Perform checks and skip if: // 1. the face is behind the camera - // 2. the face has very small face area - // 3. the pixel is outside the face bbox + // 2. the face is facing away from the camera + // 3. the face has very small face area + // 4. the pixel is outside the face bbox const float zmax = FloatMax3(v0.z, v1.z, v2.z); const bool outside_bbox = CheckPointOutsideBoundingBox( v0, v1, v2, sqrt(blur_radius), pxy); // use sqrt of blur for bbox const float face_area = EdgeFunctionForward(v0xy, v1xy, v2xy); + // Check if the face is visible to the camera. + const bool back_face = face_area < 0.0; const bool zero_face_area = (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon); - if (zmax < 0 || outside_bbox || zero_face_area) { + if (zmax < 0 || cull_backfaces && back_face || outside_bbox || + zero_face_area) { return; } @@ -191,6 +196,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel( const int64_t* num_faces_per_mesh, const float blur_radius, const bool perspective_correct, + const bool cull_backfaces, const int N, const int H, const int W, @@ -251,7 +257,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel( blur_radius, pxy, K, - perspective_correct); + perspective_correct, + cull_backfaces); } // TODO: make sorting an option as only top k is needed, not sorted values. @@ -276,7 +283,8 @@ RasterizeMeshesNaiveCuda( const int image_size, const float blur_radius, const int num_closest, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || face_verts.size(2) != 3) { AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); @@ -314,6 +322,7 @@ RasterizeMeshesNaiveCuda( num_faces_per_mesh.contiguous().data_ptr(), blur_radius, perspective_correct, + cull_backfaces, N, H, W, @@ -667,6 +676,7 @@ __global__ void RasterizeMeshesFineCudaKernel( const float blur_radius, const int bin_size, const bool perspective_correct, + const bool cull_backfaces, const int N, const int B, const int M, @@ -730,7 +740,8 @@ __global__ void RasterizeMeshesFineCudaKernel( blur_radius, pxy, K, - perspective_correct); + perspective_correct, + cull_backfaces); } // Now we've looked at all the faces for this bin, so we can write @@ -762,7 +773,8 @@ RasterizeMeshesFineCuda( const float blur_radius, const int bin_size, const int faces_per_pixel, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || face_verts.size(2) != 3) { AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); @@ -797,6 +809,7 @@ RasterizeMeshesFineCuda( blur_radius, bin_size, perspective_correct, + cull_backfaces, N, B, M, diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h index 1131d9863..210cecbfb 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h @@ -17,7 +17,8 @@ RasterizeMeshesNaiveCpu( const int image_size, const float blur_radius, const int faces_per_pixel, - const bool perspective_correct); + const bool perspective_correct, + const bool cull_backfaces); #ifdef WITH_CUDA std::tuple @@ -28,7 +29,8 @@ RasterizeMeshesNaiveCuda( const int image_size, const float blur_radius, const int num_closest, - const bool perspective_correct); + const bool perspective_correct, + const bool cull_backfaces); #endif // Forward pass for rasterizing a batch of meshes. // @@ -55,6 +57,14 @@ RasterizeMeshesNaiveCuda( // coordinates for each pixel; if this is False then // this function instead returns screen-space // barycentric coordinates for each pixel. +// cull_backfaces: Bool, Whether to only rasterize mesh faces which are +// visible to the camera. This assumes that vertices of +// front-facing triangles are ordered in an anti-clockwise +// fashion, and triangles that face away from the camera are +// in a clockwise order relative to the current view +// direction. NOTE: This will only work if the mesh faces are +// consistently defined with counter-clockwise ordering when +// viewed from the outside. // // Returns: // A 4 element tuple of: @@ -80,7 +90,8 @@ RasterizeMeshesNaive( const int image_size, const float blur_radius, const int faces_per_pixel, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { // TODO: Better type checking. if (face_verts.is_cuda()) { #ifdef WITH_CUDA @@ -91,7 +102,8 @@ RasterizeMeshesNaive( image_size, blur_radius, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); #else AT_ERROR("Not compiled with GPU support"); #endif @@ -103,7 +115,8 @@ RasterizeMeshesNaive( image_size, blur_radius, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); } } @@ -274,7 +287,8 @@ RasterizeMeshesFineCuda( const float blur_radius, const int bin_size, const int faces_per_pixel, - const bool perspective_correct); + const bool perspective_correct, + const bool cull_backfaces); #endif // Args: // face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for @@ -296,6 +310,14 @@ RasterizeMeshesFineCuda( // coordinates for each pixel; if this is False then // this function instead returns screen-space // barycentric coordinates for each pixel. +// cull_backfaces: Bool, Whether to only rasterize mesh faces which are +// visible to the camera. This assumes that vertices of +// front-facing triangles are ordered in an anti-clockwise +// fashion, and triangles that face away from the camera are +// in a clockwise order relative to the current view +// direction. NOTE: This will only work if the mesh faces are +// consistently defined with counter-clockwise ordering when +// viewed from the outside. // // Returns (same as rasterize_meshes): // A 4 element tuple of: @@ -321,7 +343,8 @@ RasterizeMeshesFine( const float blur_radius, const int bin_size, const int faces_per_pixel, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (face_verts.is_cuda()) { #ifdef WITH_CUDA return RasterizeMeshesFineCuda( @@ -331,7 +354,8 @@ RasterizeMeshesFine( blur_radius, bin_size, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); #else AT_ERROR("Not compiled with GPU support"); #endif @@ -372,7 +396,14 @@ RasterizeMeshesFine( // coordinates for each pixel; if this is False then // this function instead returns screen-space // barycentric coordinates for each pixel. -// +// cull_backfaces: Bool, Whether to only rasterize mesh faces which are +// visible to the camera. This assumes that vertices of +// front-facing triangles are ordered in an anti-clockwise +// fashion, and triangles that face away from the camera are +// in a clockwise order relative to the current view +// direction. NOTE: This will only work if the mesh faces are +// consistently defined with counter-clockwise ordering when +// viewed from the outside. // // Returns: // A 4 element tuple of: @@ -400,7 +431,8 @@ RasterizeMeshes( const int faces_per_pixel, const int bin_size, const int max_faces_per_bin, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (bin_size > 0 && max_faces_per_bin > 0) { // Use coarse-to-fine rasterization auto bin_faces = RasterizeMeshesCoarse( @@ -418,7 +450,8 @@ RasterizeMeshes( blur_radius, bin_size, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); } else { // Use the naive per-pixel implementation return RasterizeMeshesNaive( @@ -428,6 +461,7 @@ RasterizeMeshes( image_size, blur_radius, faces_per_pixel, - perspective_correct); + perspective_correct, + cull_backfaces); } } diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp index dd810d355..a4d28afd2 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp @@ -107,7 +107,8 @@ RasterizeMeshesNaiveCpu( int image_size, const float blur_radius, const int faces_per_pixel, - const bool perspective_correct) { + const bool perspective_correct, + const bool cull_backfaces) { if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || face_verts.size(2) != 3) { AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); @@ -184,8 +185,13 @@ RasterizeMeshesNaiveCpu( const vec2 v1(x1, y1); const vec2 v2(x2, y2); - // Skip faces with zero area. const float face_area = face_areas_a[f]; + const bool back_face = face_area < 0.0; + // Check if the face is visible to the camera. + if (cull_backfaces && back_face) { + continue; + } + // Skip faces with zero area. if (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon) { continue; } diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index cdf401b3d..e5fda294d 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -140,16 +140,16 @@ def load_obj(f_obj, load_textures=True): If there are faces with more than 3 vertices they are subdivided into triangles. Polygonal faces are assummed to have vertices ordered counter-clockwise so the (right-handed) normal points - into the screen e.g. a proper rectangular face would be specified like this: + out of the screen e.g. a proper rectangular face would be specified like this: :: 0_________1 | | | | 3 ________2 - The face would be split into two triangles: (0, 1, 2) and (0, 2, 3), - both of which are also oriented clockwise and have normals - pointing into the screen. + The face would be split into two triangles: (0, 2, 1) and (0, 3, 2), + both of which are also oriented counter-clockwise and have normals + pointing out of the screen. Args: f: A file-like object (with methods read, readline, tell, and seek), diff --git a/pytorch3d/renderer/mesh/rasterize_meshes.py b/pytorch3d/renderer/mesh/rasterize_meshes.py index e72a9596a..67cb8a79a 100644 --- a/pytorch3d/renderer/mesh/rasterize_meshes.py +++ b/pytorch3d/renderer/mesh/rasterize_meshes.py @@ -20,6 +20,7 @@ def rasterize_meshes( bin_size: Optional[int] = None, max_faces_per_bin: Optional[int] = None, perspective_correct: bool = False, + cull_backfaces: bool = False, ): """ Rasterize a batch of meshes given the shape of the desired output image. @@ -45,8 +46,16 @@ def rasterize_meshes( bin. If more than this many faces actually fall into a bin, an error will be raised. This should not affect the output values, but can affect the memory usage in the forward pass. - perspective_correct: Whether to apply perspective correction when computing + perspective_correct: Bool, Whether to apply perspective correction when computing barycentric coordinates for pixels. + cull_backfaces: Bool, Whether to only rasterize mesh faces which are + visible to the camera. This assumes that vertices of + front-facing triangles are ordered in an anti-clockwise + fashion, and triangles that face away from the camera are + in a clockwise order relative to the current view + direction. NOTE: This will only work if the mesh faces are + consistently defined with counter-clockwise ordering when + viewed from the outside. Returns: 4-element tuple containing @@ -118,6 +127,7 @@ def rasterize_meshes( bin_size, max_faces_per_bin, perspective_correct, + cull_backfaces, ) @@ -139,6 +149,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): for each mesh in the batch. image_size, blur_radius, faces_per_pixel: same as rasterize_meshes. perspective_correct: same as rasterize_meshes. + cull_backfaces: same as rasterize_meshes. Returns: same as rasterize_meshes function. @@ -156,6 +167,7 @@ def forward( bin_size: int = 0, max_faces_per_bin: int = 0, perspective_correct: bool = False, + cull_backfaces: bool = False, ): pix_to_face, zbuf, barycentric_coords, dists = _C.rasterize_meshes( face_verts, @@ -167,6 +179,7 @@ def forward( bin_size, max_faces_per_bin, perspective_correct, + cull_backfaces, ) ctx.save_for_backward(face_verts, pix_to_face) ctx.perspective_correct = perspective_correct @@ -183,6 +196,7 @@ def backward(ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dis grad_bin_size = None grad_max_faces_per_bin = None grad_perspective_correct = None + grad_cull_backfaces = None face_verts, pix_to_face = ctx.saved_tensors grad_face_verts = _C.rasterize_meshes_backward( face_verts, @@ -202,6 +216,7 @@ def backward(ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dis grad_bin_size, grad_max_faces_per_bin, grad_perspective_correct, + grad_cull_backfaces, ) return grads @@ -217,6 +232,7 @@ def rasterize_meshes_python( blur_radius: float = 0.0, faces_per_pixel: int = 8, perspective_correct: bool = False, + cull_backfaces: bool = False, ): """ Naive PyTorch implementation of mesh rasterization with the same inputs and @@ -287,7 +303,12 @@ def rasterize_meshes_python( face = faces_verts[f].squeeze() v0, v1, v2 = face.unbind(0) - face_area = edge_function(v2, v0, v1) + face_area = edge_function(v0, v1, v2) + + # Ignore triangles facing away from the camera. + back_face = face_area < 0 + if cull_backfaces and back_face: + continue # Ignore faces which have zero area. if face_area == 0.0: @@ -365,8 +386,8 @@ def edge_function(p, v0, v1): .. code-block:: python - A = p - v0 - B = v1 - v0 + B = p - v0 + A = v1 - v0 v1 ________ /\ / diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index b995bf27e..a5c9a2e89 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -26,6 +26,7 @@ class RasterizationSettings: "bin_size", "max_faces_per_bin", "perspective_correct", + "cull_backfaces", ] def __init__( @@ -36,6 +37,7 @@ def __init__( bin_size: Optional[int] = None, max_faces_per_bin: Optional[int] = None, perspective_correct: bool = False, + cull_backfaces: bool = False, ): self.image_size = image_size self.blur_radius = blur_radius @@ -43,6 +45,7 @@ def __init__( self.bin_size = bin_size self.max_faces_per_bin = max_faces_per_bin self.perspective_correct = perspective_correct + self.cull_backfaces = cull_backfaces class MeshRasterizer(nn.Module): @@ -122,6 +125,7 @@ def forward(self, meshes_world, **kwargs) -> Fragments: bin_size=raster_settings.bin_size, max_faces_per_bin=raster_settings.max_faces_per_bin, perspective_correct=raster_settings.perspective_correct, + cull_backfaces=raster_settings.cull_backfaces, ) return Fragments( pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists diff --git a/tests/test_rasterize_meshes.py b/tests/test_rasterize_meshes.py index 69844d93a..9979a0419 100644 --- a/tests/test_rasterize_meshes.py +++ b/tests/test_rasterize_meshes.py @@ -21,6 +21,7 @@ def test_simple_python(self): self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1) self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1) self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1) + self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1) def test_simple_cpu_naive(self): device = torch.device("cpu") @@ -28,6 +29,7 @@ def test_simple_cpu_naive(self): self._simple_blurry_raster(rasterize_meshes, device, bin_size=0) self._test_behind_camera(rasterize_meshes, device, bin_size=0) self._test_perspective_correct(rasterize_meshes, device, bin_size=0) + self._test_back_face_culling(rasterize_meshes, device, bin_size=0) def test_simple_cuda_naive(self): device = torch.device("cuda:0") @@ -35,6 +37,7 @@ def test_simple_cuda_naive(self): self._simple_blurry_raster(rasterize_meshes, device, bin_size=0) self._test_behind_camera(rasterize_meshes, device, bin_size=0) self._test_perspective_correct(rasterize_meshes, device, bin_size=0) + self._test_back_face_culling(rasterize_meshes, device, bin_size=0) def test_simple_cuda_binned(self): device = torch.device("cuda:0") @@ -42,6 +45,7 @@ def test_simple_cuda_binned(self): self._simple_blurry_raster(rasterize_meshes, device, bin_size=5) self._test_behind_camera(rasterize_meshes, device, bin_size=5) self._test_perspective_correct(rasterize_meshes, device, bin_size=5) + self._test_back_face_culling(rasterize_meshes, device, bin_size=5) def test_python_vs_cpu_vs_cuda(self): torch.manual_seed(231) @@ -377,6 +381,81 @@ def test_cuda_naive_vs_binned_perspective_correct(self): args = () self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True) + def _test_back_face_culling(self, rasterize_meshes_fn, device, bin_size): + # Square based pyramid mesh. + # fmt: off + verts = torch.tensor([ + [-0.5, 0.0, 0.5], # noqa: E241 E201 Front right + [ 0.5, 0.0, 0.5], # noqa: E241 E201 Front left + [ 0.5, 0.0, 1.5], # noqa: E241 E201 Back left + [-0.5, 0.0, 1.5], # noqa: E241 E201 Back right + [ 0.0, 1.0, 1.0] # noqa: E241 E201 Top point of pyramid + ], dtype=torch.float32, device=device) + + faces = torch.tensor([ + [2, 1, 0], # noqa: E241 E201 Square base + [3, 2, 0], # noqa: E241 E201 Square base + [1, 0, 4], # noqa: E241 E201 Triangle on front + [2, 4, 3], # noqa: E241 E201 Triangle on back + [3, 4, 0], # noqa: E241 E201 Triangle on left side + [1, 4, 2] # noqa: E241 E201 Triangle on right side + ], dtype=torch.int64, device=device) + # fmt: on + mesh = Meshes(verts=[verts], faces=[faces]) + kwargs = { + "meshes": mesh, + "image_size": 10, + "faces_per_pixel": 2, + "blur_radius": 0.0, + "perspective_correct": False, + "cull_backfaces": False, + } + if bin_size != -1: + kwargs["bin_size"] = bin_size + + # fmt: off + pix_to_face_frontface = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, 2, 2, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, 2, 2, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201 + ], dtype=torch.int64, device=device) + pix_to_face_backface = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, 3, 3, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, 3, 3, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 3, 3, 3, 3, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 3, 3, 3, 3, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201 + ], dtype=torch.int64, device=device) + # fmt: on + + pix_to_face_padded = -torch.ones_like(pix_to_face_frontface) + # Run with and without culling + # Without culling, for k=0, the front face (i.e. face 2) is + # rasterized and for k=1, the back face (i.e. face 3) is + # rasterized. + idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs) + self.assertTrue(torch.all(idx_f[..., 0].squeeze() == pix_to_face_frontface)) + self.assertTrue(torch.all(idx_f[..., 1].squeeze() == pix_to_face_backface)) + + # With culling, for k=0, the front face (i.e. face 2) is + # rasterized and for k=1, there are no faces rasterized + kwargs["cull_backfaces"] = True + idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs) + self.assertTrue(torch.all(idx_t[..., 0].squeeze() == pix_to_face_frontface)) + self.assertTrue(torch.all(idx_t[..., 1].squeeze() == pix_to_face_padded)) + def _compare_impls( self, fn1,