Skip to content

Commit

Permalink
validate lengths in chamfer and farthest_points
Browse files Browse the repository at this point in the history
Summary: Fixes #1326

Reviewed By: kjchalup

Differential Revision: D39259697

fbshipit-source-id: 51392f4cc4a956165a62901cb115fcefe0e17277
  • Loading branch information
bottler authored and facebook-github-bot committed Sep 8, 2022
1 parent 6e25fe8 commit cb7bd33
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
9 changes: 5 additions & 4 deletions pytorch3d/loss/chamfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ def _handle_pointcloud_input(
if points.ndim != 3:
raise ValueError("Expected points to be of shape (N, P, D)")
X = points
if lengths is not None and (
lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
):
raise ValueError("Expected lengths to be of shape (N,)")
if lengths is not None:
if lengths.ndim != 1 or lengths.shape[0] != X.shape[0]:
raise ValueError("Expected lengths to be of shape (N,)")
if lengths.max() > X.shape[1]:
raise ValueError("A length value was too long")
if lengths is None:
lengths = torch.full(
(X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
Expand Down
16 changes: 10 additions & 6 deletions pytorch3d/ops/sample_farthest_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ def sample_farthest_points(
# Validate inputs
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)

if lengths.shape != (N,):
raise ValueError("points and lengths must have same batch dimension.")
else:
if lengths.shape != (N,):
raise ValueError("points and lengths must have same batch dimension.")
if lengths.max() > P:
raise ValueError("A value in lengths was too large.")

# TODO: support providing K as a ratio of the total number of points instead of as an int
if isinstance(K, int):
Expand Down Expand Up @@ -107,9 +109,11 @@ def sample_farthest_points_naive(
# Validate inputs
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)

if lengths.shape[0] != N:
raise ValueError("points and lengths must have same batch dimension.")
else:
if lengths.shape != (N,):
raise ValueError("points and lengths must have same batch dimension.")
if lengths.max() > P:
raise ValueError("Invalid lengths.")

# TODO: support providing K as a ratio of the total number of points instead of as an int
if isinstance(K, int):
Expand Down

0 comments on commit cb7bd33

Please sign in to comment.