Skip to content

Commit

Permalink
rasterizer.to without cameras
Browse files Browse the repository at this point in the history
Summary: As reported in #1100, a rasterizer couldn't be moved if it was missing the optional cameras member. Fix that. This matters because the renderer.to calls rasterizer.to, so this to() could be called even by a user who never sets a cameras member.

Reviewed By: nikhilaravi

Differential Revision: D34643841

fbshipit-source-id: 7e26e32e8bc585eb1ee533052754a7b59bc7467a
  • Loading branch information
bottler authored and facebook-github-bot committed Mar 9, 2022
1 parent 4a1f176 commit c371a9a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pytorch3d/renderer/mesh/rasterizer.py
Expand Up @@ -110,7 +110,8 @@ def __init__(self, cameras=None, raster_settings=None) -> None:

def to(self, device):
# Manually move to device cameras as it is not a subclass of nn.Module
self.cameras = self.cameras.to(device)
if self.cameras is not None:
self.cameras = self.cameras.to(device)
return self

def transform(self, meshes_world, **kwargs) -> torch.Tensor:
Expand Down
3 changes: 2 additions & 1 deletion pytorch3d/renderer/points/rasterizer.py
Expand Up @@ -115,7 +115,8 @@ def transform(self, point_clouds, **kwargs) -> torch.Tensor:

def to(self, device):
# Manually move to device cameras as it is not a subclass of nn.Module
self.cameras = self.cameras.to(device)
if self.cameras is not None:
self.cameras = self.cameras.to(device)
return self

def forward(self, point_clouds, **kwargs) -> PointFragments:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_rasterizer.py
Expand Up @@ -134,6 +134,12 @@ def test_simple_sphere(self):

self.assertTrue(torch.allclose(image, image_ref))

def test_simple_to(self):
# Check that to() works without a cameras object.
device = torch.device("cuda:0")
rasterizer = MeshRasterizer()
rasterizer.to(device)


class TestPointRasterizer(unittest.TestCase):
def test_simple_sphere(self):
Expand Down Expand Up @@ -203,3 +209,9 @@ def test_simple_sphere(self):
image[image >= 0] = 1.0
image[image < 0] = 0.0
self.assertTrue(torch.allclose(image, image_ref[..., 0]))

def test_simple_to(self):
# Check that to() works without a cameras object.
device = torch.device("cuda:0")
rasterizer = PointsRasterizer()
rasterizer.to(device)

0 comments on commit c371a9a

Please sign in to comment.