Skip to content

Commit

Permalink
Fix: Pointclouds.inside_box reducing over spatial dimensions.
Browse files Browse the repository at this point in the history
Summary: As subj. Tests corrected accordingly. Also changed the test to provide a bit better diagnostics.

Reviewed By: bottler

Differential Revision: D32879498

fbshipit-source-id: 0a852e4a13dcb4ca3e54d71c6b263c5d2eeaf4eb
  • Loading branch information
shapovalov authored and facebook-github-bot committed Dec 6, 2021
1 parent d9f7095 commit a6508ac
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
4 changes: 2 additions & 2 deletions pytorch3d/structures/pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,5 +1176,5 @@ def inside_box(self, box):
]
box = torch.cat(box, 0)

idx = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
return idx
coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
return coord_inside.all(dim=-1)
13 changes: 11 additions & 2 deletions tests/common_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,17 @@ def assertClose(
if close:
return

diff = backend.abs(input + 0.0 - other)
ratio = diff / backend.abs(other)
# handle bool case
if backend == torch and input.dtype == torch.bool:
diff = (input != other).float()
ratio = diff
if backend == np and input.dtype == bool:
diff = (input != other).astype(float)
ratio = diff
else:
diff = backend.abs(input + 0.0 - other)
ratio = diff / backend.abs(other)

try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0))
if try_relative.all():
if backend == np:
Expand Down
10 changes: 6 additions & 4 deletions tests/test_pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,9 @@ def test_update_padded(self):

def test_inside_box(self):
def inside_box_naive(cloud, box_min, box_max):
return (cloud >= box_min.view(1, 3)) * (cloud <= box_max.view(1, 3))
return ((cloud >= box_min.view(1, 3)) * (cloud <= box_max.view(1, 3))).all(
dim=-1
)

N, P, C = 5, 100, 4

Expand All @@ -994,7 +996,7 @@ def inside_box_naive(cloud, box_min, box_max):
for i, cloud in enumerate(clouds.points_list()):
within_box_naive.append(inside_box_naive(cloud, box[i, 0], box[i, 1]))
within_box_naive = torch.cat(within_box_naive, 0)
self.assertTrue(within_box.eq(within_box_naive).all())
self.assertClose(within_box, within_box_naive)

# box of shape 2x3
box2 = box[0, :]
Expand All @@ -1005,13 +1007,13 @@ def inside_box_naive(cloud, box_min, box_max):
for cloud in clouds.points_list():
within_box_naive2.append(inside_box_naive(cloud, box2[0], box2[1]))
within_box_naive2 = torch.cat(within_box_naive2, 0)
self.assertTrue(within_box2.eq(within_box_naive2).all())
self.assertClose(within_box2, within_box_naive2)

# box of shape 1x2x3
box3 = box2.expand(1, 2, 3)

within_box3 = clouds.inside_box(box3)
self.assertTrue(within_box2.eq(within_box3).all())
self.assertClose(within_box2, within_box3)

# invalid box
invalid_box = torch.cat(
Expand Down

0 comments on commit a6508ac

Please sign in to comment.