Skip to content

Commit

Permalink
ambient lights batching #1043
Browse files Browse the repository at this point in the history
Summary:
convert_to_tensors_and_broadcast had a special case for a single input, which is not used anywhere except fails to do the right thing if a TensorProperties has only one kwarg. At the moment AmbientLights may be the only way to hit the problem. Fix by removing the special case.

Fixes #1043

Reviewed By: nikhilaravi

Differential Revision: D33638345

fbshipit-source-id: 7a6695f44242e650504320f73b6da74254d49ac7
  • Loading branch information
bottler authored and facebook-github-bot committed Jan 20, 2022
1 parent fddd6a7 commit 9e2bc3a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
3 changes: 0 additions & 3 deletions pytorch3d/renderer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,4 @@ def convert_to_tensors_and_broadcast(
expand_sizes = (N,) + (-1,) * len(c.shape[1:])
args_Nd.append(c.expand(*expand_sizes))

if len(args) == 1:
args_Nd = args_Nd[0] # Return the first element

return args_Nd
13 changes: 12 additions & 1 deletion tests/test_lighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.lighting import DirectionalLights, PointLights
from pytorch3d.renderer.lighting import AmbientLights, DirectionalLights, PointLights
from pytorch3d.transforms import RotateAxisAngle


Expand Down Expand Up @@ -121,6 +121,17 @@ def test_initialize_lights_dimensions_fail(self):
with self.assertRaises(ValueError):
PointLights(location=torch.randn(10, 4))

def test_initialize_ambient(self):
N = 13
color = 0.8 * torch.ones((N, 3))
lights = AmbientLights(ambient_color=color)
self.assertEqual(len(lights), N)
self.assertClose(lights.ambient_color, color)

lights = AmbientLights(ambient_color=color[:1])
self.assertEqual(len(lights), 1)
self.assertClose(lights.ambient_color, color[:1])


class TestDiffuseLighting(TestCaseMixin, unittest.TestCase):
def test_diffuse_directional_lights(self):
Expand Down

0 comments on commit 9e2bc3a

Please sign in to comment.