Skip to content

Commit

Permalink
flexible background color for point compositing
Browse files Browse the repository at this point in the history
Summary:
Modified the compositor background color tests to account for either a 3rd or 4th channel. Also replaced hard coding of channel value with C.

Implemented changes to alpha channel appending logic, and cleaned up extraneous warnings and checks, per task instructions.

Fixes #1048

Reviewed By: bottler

Differential Revision: D34305312

fbshipit-source-id: 2176c3bdd897d1a2ba6ff4c6fa801fea889e4f02
  • Loading branch information
Alex Greene authored and facebook-github-bot committed Feb 18, 2022
1 parent c8f3d6b commit 59972b1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
22 changes: 11 additions & 11 deletions pytorch3d/renderer/points/compositor.py
Expand Up @@ -35,7 +35,7 @@ def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:

# images are of shape (N, C, H, W)
# check for background color & feature size C (C=4 indicates rgba)
if background_color is not None and images.shape[1] == 4:
if background_color is not None:
return _add_background_color_to_images(fragments, images, background_color)
return images

Expand All @@ -57,7 +57,7 @@ def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:

# images are of shape (N, C, H, W)
# check for background color & feature size C (C=4 indicates rgba)
if background_color is not None and images.shape[1] == 4:
if background_color is not None:
return _add_background_color_to_images(fragments, images, background_color)
return images

Expand Down Expand Up @@ -85,22 +85,22 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
if not torch.is_tensor(background_color):
background_color = images.new_tensor(background_color)

background_shape = background_color.shape

if len(background_shape) != 1 or background_shape[0] not in (3, 4):
warnings.warn(
"Background color should be size (3) or (4), but is size %s instead"
% (background_shape,)
)
return images
if len(background_color.shape) != 1:
raise ValueError("Wrong shape of background_color")

background_color = background_color.to(images)

# add alpha channel
if background_shape[0] == 3:
if background_color.shape[0] == 3 and images.shape[1] == 4:
# special case to allow giving RGB background for RGBA
alpha = images.new_ones(1)
background_color = torch.cat([background_color, alpha])

if images.shape[1] != background_color.shape[0]:
raise ValueError(
f"background color has {background_color.shape[0] } channels not {images.shape[1]}"
)

num_background_pixels = background_mask.sum()

# permute so that features are the last dimension for masked_scatter to work
Expand Down
54 changes: 50 additions & 4 deletions tests/test_render_points.py
Expand Up @@ -326,7 +326,7 @@ def test_simple_sphere_batched(self):
)
self.assertClose(rgb, image_ref)

def test_compositor_background_color(self):
def test_compositor_background_color_rgba(self):

N, H, W, K, C, P = 1, 15, 15, 20, 4, 225
ptclds = torch.randn((C, P))
Expand Down Expand Up @@ -357,7 +357,7 @@ def test_compositor_background_color(self):
torch.masked_select(images, is_foreground[:, None]),
)

is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4)
is_background = ~is_foreground[..., None].expand(-1, -1, -1, C)

# permute masked_images to correctly get rgb values
masked_images = masked_images.permute(0, 2, 3, 1)
Expand All @@ -367,12 +367,58 @@ def test_compositor_background_color(self):
# check if background colors are properly changed
self.assertTrue(
masked_images[is_background]
.view(-1, 4)[..., i]
.view(-1, C)[..., i]
.eq(channel_color)
.all()
)

# check background color alpha values
self.assertTrue(
masked_images[is_background].view(-1, 4)[..., 3].eq(1).all()
masked_images[is_background].view(-1, C)[..., 3].eq(1).all()
)

def test_compositor_background_color_rgb(self):

N, H, W, K, C, P = 1, 15, 15, 20, 3, 225
ptclds = torch.randn((C, P))
alphas = torch.rand((N, K, H, W))
pix_idxs = torch.randint(-1, 20, (N, K, H, W)) # 20 < P, large amount of -1
background_color = [0.5, 0, 1]

compositor_funcs = [
(NormWeightedCompositor, norm_weighted_sum),
(AlphaCompositor, alpha_composite),
]

for (compositor_class, composite_func) in compositor_funcs:

compositor = compositor_class(background_color)

# run the forward method to generate masked images
masked_images = compositor.forward(pix_idxs, alphas, ptclds)

# generate unmasked images for testing purposes
images = composite_func(pix_idxs, alphas, ptclds)

is_foreground = pix_idxs[:, 0] >= 0

# make sure foreground values are unchanged
self.assertClose(
torch.masked_select(masked_images, is_foreground[:, None]),
torch.masked_select(images, is_foreground[:, None]),
)

is_background = ~is_foreground[..., None].expand(-1, -1, -1, C)

# permute masked_images to correctly get rgb values
masked_images = masked_images.permute(0, 2, 3, 1)
for i in range(3):
channel_color = background_color[i]

# check if background colors are properly changed
self.assertTrue(
masked_images[is_background]
.view(-1, C)[..., i]
.eq(channel_color)
.all()
)

0 comments on commit 59972b1

Please sign in to comment.