Skip to content

Commit

Permalink
In blending, pull common functionality into get_background_color
Browse files Browse the repository at this point in the history
Summary: A small refactor, originally intended for use with the splatter.

Reviewed By: bottler

Differential Revision: D36210393

fbshipit-source-id: b3372f7cc7690ee45dd3059b2d4be1c8dfa63180
  • Loading branch information
Krzysztof Chalupka authored and facebook-github-bot committed May 17, 2022
1 parent 4372001 commit ea5df60
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions pytorch3d/renderer/blending.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
from pytorch3d import _C
from pytorch3d.common.datatypes import Device


# Example functions for blending the top K colors per pixel using the outputs
Expand Down Expand Up @@ -37,6 +38,17 @@ class BlendParams(NamedTuple):
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)


def _get_background_color(
blend_params: BlendParams, device: Device, dtype=torch.float32
) -> torch.Tensor:
background_color_ = blend_params.background_color
if isinstance(background_color_, torch.Tensor):
background_color = background_color_.to(device)
else:
background_color = torch.tensor(background_color_, dtype=dtype, device=device)
return background_color


def hard_rgb_blend(
colors: torch.Tensor, fragments, blend_params: BlendParams
) -> torch.Tensor:
Expand All @@ -57,18 +69,11 @@ def hard_rgb_blend(
Returns:
RGBA pixel_colors: (N, H, W, 4)
"""
N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
background_color = _get_background_color(blend_params, fragments.pix_to_face.device)

# Mask for the background.
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)

background_color_ = blend_params.background_color
if isinstance(background_color_, torch.Tensor):
background_color = background_color_.to(device)
else:
background_color = colors.new_tensor(background_color_)

# Find out how much background_color needs to be expanded to be used for masked_scatter.
num_background_pixels = is_background.sum()

Expand Down Expand Up @@ -182,13 +187,8 @@ def softmax_rgb_blend(
"""

N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
background_ = blend_params.background_color
if not isinstance(background_, torch.Tensor):
background = torch.tensor(background_, dtype=torch.float32, device=device)
else:
background = background_.to(device)
background_color = _get_background_color(blend_params, fragments.pix_to_face.device)

# Weight for background color
eps = 1e-10
Expand Down Expand Up @@ -233,7 +233,7 @@ def softmax_rgb_blend(

# Sum: weights * textures + background color
weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
weighted_background = delta * background
weighted_background = delta * background_color
pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
pixel_colors[..., 3] = 1.0 - alpha

Expand Down

0 comments on commit ea5df60

Please sign in to comment.