Skip to content

Commit

Permalink
Temporary fix for mesh rasterization bug for traingles partially behi…
Browse files Browse the repository at this point in the history
…nd the camera

Summary: A triangle is culled if any vertex in a triangle is behind the camera.  This fixes incorrect rendering of triangles that are partially behind the camera, where screen coordinate calculations are strange.  It doesn't work for triangles that are partially behind the camera but still intersect with the view frustum.

Reviewed By: nikhilaravi

Differential Revision: D22856181

fbshipit-source-id: a9cbaa1327d89601b83d0dfd3e4a04f934a4a213
  • Loading branch information
Steve Branson authored and facebook-github-bot committed Aug 21, 2020
1 parent 57a22e7 commit 9aaba04
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 8 deletions.
17 changes: 13 additions & 4 deletions pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,14 @@ __device__ bool CheckPointOutsideBoundingBox(
const float x_max = xlims.y + blur_radius;
const float y_max = ylims.y + blur_radius;

// Faces with at least one vertex behind the camera won't render correctly
// and should be removed or clipped before calling the rasterizer
const bool z_invalid = zlims.x < kEpsilon;

// Check if the current point is oustside the triangle bounding box.
return (pxy.x > x_max || pxy.x < x_min || pxy.y > y_max || pxy.y < y_min);
return (
pxy.x > x_max || pxy.x < x_min || pxy.y > y_max || pxy.y < y_min ||
z_invalid);
}

// This function checks if a pixel given by xy location pxy lies within the
Expand Down Expand Up @@ -625,10 +631,13 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius);
float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius);
float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius);
float zmax = FloatMax3(v0.z, v1.z, v2.z);
float zmin = FloatMin3(v0.z, v1.z, v2.z);

if (zmax < 0) {
continue; // Face is behind the camera.
// Faces with at least one vertex behind the camera won't render
// correctly and should be removed or clipped before calling the
// rasterizer
if (zmin < kEpsilon) {
continue;
}

// Brute-force search over all bins; TODO(T54294966) something smarter.
Expand Down
15 changes: 11 additions & 4 deletions pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ bool CheckPointOutsideBoundingBox(
float x_max = face_bbox[2] + blur_radius;
float y_max = face_bbox[3] + blur_radius;

// Faces with at least one vertex behind the camera won't render correctly
// and should be removed or clipped before calling the rasterizer
const bool z_invalid = face_bbox[4] < kEpsilon;

// Check if the current point is within the triangle bounding box.
return (px > x_max || px < x_min || py > y_max || py < y_min);
return (px > x_max || px < x_min || py > y_max || py < y_min || z_invalid);
}

// Calculate areas of all faces. Returns a tensor of shape (total_faces, 1)
Expand Down Expand Up @@ -468,10 +472,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
float face_y_min = face_bboxes_a[f][1] - std::sqrt(blur_radius);
float face_x_max = face_bboxes_a[f][2] + std::sqrt(blur_radius);
float face_y_max = face_bboxes_a[f][3] + std::sqrt(blur_radius);
float face_z_max = face_bboxes_a[f][5];
float face_z_min = face_bboxes_a[f][4];

if (face_z_max < 0) {
continue; // Face is behind the camera.
// Faces with at least one vertex behind the camera won't render
// correctly and should be removed or clipped before calling the
// rasterizer
if (face_z_min < kEpsilon) {
continue;
}

// Use a half-open interval so that faces exactly on the
Expand Down
7 changes: 7 additions & 0 deletions pytorch3d/renderer/mesh/rasterize_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def rasterize_meshes_python(
x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
y_mins = torch.min(faces_verts[:, :, 1], dim=1, keepdim=True).values
y_maxs = torch.max(faces_verts[:, :, 1], dim=1, keepdim=True).values
z_mins = torch.min(faces_verts[:, :, 2], dim=1, keepdim=True).values

# Expand by blur radius.
x_mins = x_mins - np.sqrt(blur_radius) - kEpsilon
Expand Down Expand Up @@ -351,6 +352,12 @@ def rasterize_meshes_python(
or yf > y_maxs[f]
)

# Faces with at least one vertex behind the camera won't
# render correctly and should be removed or clipped before
# calling the rasterizer
if z_mins[f] < kEpsilon:
continue

# Check if pixel is outside of face bbox.
if outside_bbox:
continue
Expand Down
8 changes: 8 additions & 0 deletions tests/test_rasterize_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,10 @@ def _compare_impls(
+ (zbuf1 * grad_zbuf).sum()
+ (bary1 * grad_bary).sum()
)

# avoid gradient error if rasterize_meshes_python() culls all triangles
loss1 += grad_var1.sum() * 0.0

loss1.backward()
grad_verts1 = grad_var1.grad.data.clone().cpu()

Expand All @@ -563,6 +567,10 @@ def _compare_impls(
+ (zbuf2 * grad_zbuf).sum()
+ (bary2 * grad_bary).sum()
)

# avoid gradient error if rasterize_meshes_python() culls all triangles
loss2 += grad_var2.sum() * 0.0

grad_var1.grad.data.zero_()
loss2.backward()
grad_verts2 = grad_var2.grad.data.clone().cpu()
Expand Down

0 comments on commit 9aaba04

Please sign in to comment.