Skip to content

Commit

Permalink
make points2volumes feature rescaling optional
Browse files Browse the repository at this point in the history
Summary: Add option to not rescale the features, giving more control. #1137

Reviewed By: nikhilaravi

Differential Revision: D35219577

fbshipit-source-id: cbbb643b91b71bc908cedc6dac0f63f6d1355c85
  • Loading branch information
bottler authored and facebook-github-bot committed Apr 13, 2022
1 parent 0a7c354 commit 78fd5af
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 8 deletions.
22 changes: 17 additions & 5 deletions pytorch3d/ops/points_to_volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def add_pointclouds_to_volumes(
initial_volumes: "Volumes",
mode: str = "trilinear",
min_weight: float = 1e-4,
rescale_features: bool = True,
_python: bool = False,
) -> "Volumes":
"""
Expand Down Expand Up @@ -250,6 +251,10 @@ def add_pointclouds_to_volumes(
min_weight: A scalar controlling the lowest possible total per-voxel
weight used to normalize the features accumulated in a voxel.
Only active for `mode==trilinear`.
rescale_features: If False, output features are just the sum of input and
added points. If True, they are averaged. In both cases,
output densities are just summed without rescaling, so
you may need to rescale them afterwards.
_python: Set to True to use a pure Python implementation, e.g. for test
purposes, which requires more memory and may be slower.
Expand Down Expand Up @@ -286,6 +291,7 @@ def add_pointclouds_to_volumes(
grid_sizes=initial_volumes.get_grid_sizes(),
mask=mask,
mode=mode,
rescale_features=rescale_features,
_python=_python,
)

Expand All @@ -303,6 +309,7 @@ def add_points_features_to_volume_densities_features(
min_weight: float = 1e-4,
mask: Optional[torch.Tensor] = None,
grid_sizes: Optional[torch.LongTensor] = None,
rescale_features: bool = True,
_python: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Expand Down Expand Up @@ -345,6 +352,10 @@ def add_points_features_to_volume_densities_features(
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
spatial resolutions of each of the the non-flattened `volumes` tensors,
or None to indicate the whole volume is used for every batch element.
rescale_features: If False, output features are just the sum of input and
added points. If True, they are averaged. In both cases,
output densities are just summed without rescaling, so
you may need to rescale them afterwards.
_python: Set to True to use a pure Python implementation.
Returns:
volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)`
Expand Down Expand Up @@ -401,12 +412,13 @@ def add_points_features_to_volume_densities_features(
True, # align_corners
splat,
)
if splat:
# divide each feature by the total weight of the votes
volume_features = volume_features / volume_densities.clamp(min_weight)
else:

if rescale_features:
# divide each feature by the total weight of the votes
volume_features = volume_features / volume_densities.clamp(1.0)
if splat:
volume_features = volume_features / volume_densities.clamp(min_weight)
else:
volume_features = volume_features / volume_densities.clamp(1.0)

return volume_features, volume_densities

Expand Down
16 changes: 13 additions & 3 deletions tests/common_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,22 @@ def assertClose(
self.fail(f"{msg} {err}")
self.fail(err)

def assertConstant(self, input: TensorOrArray, value: Real) -> None:
def assertConstant(
self, input: TensorOrArray, value: Real, *, atol: float = 0
) -> None:
"""
Asserts input is entirely filled with value.
Args:
input: tensor or array
value: expected value
atol: tolerance
"""
self.assertEqual(input.min(), value)
self.assertEqual(input.max(), value)
mn, mx = input.min(), input.max()
msg = f"values in range [{mn}, {mx}], not {value}, shape {input.shape}"
if atol == 0:
self.assertEqual(input.min(), value, msg=msg)
self.assertEqual(input.max(), value, msg=msg)
else:
self.assertGreater(input.min(), value - atol, msg=msg)
self.assertLess(input.max(), value + atol, msg=msg)
17 changes: 17 additions & 0 deletions tests/test_points_to_volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,23 @@ def test_defaulted_arguments(self):
)
self.assertClose(torch.sum(densities), torch.tensor(30 * 1000.0), atol=0.1)

def test_unscaled(self):
D = 5
P = 1000
B, C, H, W = 2, 3, D, D
densities = torch.zeros(B, 1, D, H, W)
features = torch.zeros(B, C, D, H, W)
volumes = Volumes(densities=densities, features=features)
points = torch.rand(B, 1000, 3) * (D - 1) - ((D - 1) * 0.5)
point_features = torch.rand(B, 1000, C)
pointclouds = Pointclouds(points=points, features=point_features)

volumes2 = add_pointclouds_to_volumes(
pointclouds, volumes, rescale_features=False
)
self.assertConstant(volumes2.densities().sum([2, 3, 4]) / P, 1, atol=1e-5)
self.assertConstant(volumes2.features().sum([2, 3, 4]) / P, 0.5, atol=0.03)

def _check_volume_slice_color_density(
self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
):
Expand Down

0 comments on commit 78fd5af

Please sign in to comment.