Skip to content

Commit

Permalink
shader: fix HardDepthShader sizes + tests (#1252)
Browse files Browse the repository at this point in the history
Summary:
This fixes a indexing bug in HardDepthShader and adds proper unit tests for both of the DepthShaders. This bug was introduced when updating the shader sizes and discovered when I switched my local model onto pytorch3d trunk instead of the patched copy.

Pull Request resolved: #1252

Test Plan:
Unit test + custom model code

```
pytest tests/test_shader.py
```

![image](https://user-images.githubusercontent.com/909104/178397456-f478d0e0-9f6c-467a-a85b-adb4c47adfee.png)

Reviewed By: bottler

Differential Revision: D37775767

Pulled By: d4l3k

fbshipit-source-id: 5f001903985976d7067d1fa0a3102d602790e3e8
  • Loading branch information
d4l3k authored and facebook-github-bot committed Jul 12, 2022
1 parent 8d10ba5 commit 4ecc9ea
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pytorch3d/renderer/mesh/shader.py
Expand Up @@ -374,11 +374,11 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
cameras = super()._get_cameras(**kwargs)

zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
mask = fragments.pix_to_face < 0
mask = fragments.pix_to_face[..., 0:1] < 0

zbuf = fragments.zbuf[..., 0].clone()
zbuf = fragments.zbuf[..., 0:1].clone()
zbuf[mask] = zfar
return zbuf.unsqueeze(3)
return zbuf


class SoftDepthShader(ShaderBase):
Expand Down
32 changes: 32 additions & 0 deletions tests/test_shader.py
Expand Up @@ -91,3 +91,35 @@ def test_cameras_check(self):

with self.assertRaises(ValueError):
shader(fragments, meshes)

def test_depth_shader(self):
shader_classes = [
HardDepthShader,
SoftDepthShader,
]

verts = torch.tensor(
[[-1, -1, 0], [1, -1, 1], [1, 1, 0], [-1, 1, 1]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [2, 3, 0]], dtype=torch.int64)
meshes = Meshes(verts=[verts], faces=[faces])

pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
for faces_per_pixel in [1, 2]:
fragments = Fragments(
pix_to_face=pix_to_face[:, :, :, :faces_per_pixel],
bary_coords=barycentric_coords[:, :, :, :faces_per_pixel],
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
R, T = look_at_view_transform()
cameras = PerspectiveCameras(R=R, T=T)

for shader_class in shader_classes:
shader = shader_class()

out = shader(fragments, meshes, cameras=cameras)
self.assertEqual(out.shape, (1, 1, 1, 1))

0 comments on commit 4ecc9ea

Please sign in to comment.