Skip to content

Commit

Permalink
Boolean indexing of cameras
Browse files Browse the repository at this point in the history
Summary: Reasonable to expect bool indexing.

Reviewed By: bottler, kjchalup

Differential Revision: D38741446

fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c
  • Loading branch information
shapovalov authored and facebook-github-bot committed Aug 16, 2022
1 parent 6080897 commit b7c826b
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 18 deletions.
32 changes: 23 additions & 9 deletions pytorch3d/renderer/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,31 +385,45 @@ def get_image_size(self):
return self.image_size if hasattr(self, "image_size") else None

def __getitem__(
self, index: Union[int, List[int], torch.LongTensor]
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
) -> "CamerasBase":
"""
Override for the __getitem__ method in TensorProperties which needs to be
refactored.
Args:
index: an int/list/long tensor used to index all the fields in the cameras given by
self._FIELDS.
index: an integer index, list/tensor of integer indices, or tensor of boolean
indicators used to filter all the fields in the cameras given by self._FIELDS.
Returns:
if `index` is an index int/list/long tensor return an instance of the current
cameras class with only the values at the selected index.
an instance of the current cameras class with only the values at the selected index.
"""

kwargs = {}

# pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)):
msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
tensor_types = {
"bool": (torch.BoolTensor, torch.cuda.BoolTensor),
"long": (torch.LongTensor, torch.cuda.LongTensor),
}
if not isinstance(
index, (int, list, *tensor_types["bool"], *tensor_types["long"])
) or (
isinstance(index, list)
and not all(isinstance(i, int) and not isinstance(i, bool) for i in index)
):
msg = (
"Invalid index type, expected int, List[int] or Bool/LongTensor; got %r"
)
raise ValueError(msg % type(index))

if isinstance(index, int):
index = [index]

if max(index) >= len(self):
if isinstance(index, tensor_types["bool"]):
if index.ndim != 1 or index.shape[0] != len(self):
raise ValueError(
f"Boolean index of shape {index.shape} does not match cameras"
)
elif max(index) >= len(self):
raise ValueError(f"Index {max(index)} is out of bounds for select cameras")

for field in self._FIELDS:
Expand Down
4 changes: 3 additions & 1 deletion pytorch3d/structures/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,9 @@ def _set_verts_normals(self, verts_normals) -> None:
def __len__(self) -> int:
return self._N

def __getitem__(self, index) -> "Meshes":
def __getitem__(
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Meshes":
"""
Args:
index: Specifying the index of the mesh to retrieve.
Expand Down
5 changes: 4 additions & 1 deletion pytorch3d/structures/pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,10 @@ def _parse_auxiliary_input_list(
def __len__(self) -> int:
return self._N

def __getitem__(self, index) -> "Pointclouds":
def __getitem__(
self,
index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor],
) -> "Pointclouds":
"""
Args:
index: Specifying the index of the cloud to retrieve.
Expand Down
5 changes: 4 additions & 1 deletion pytorch3d/structures/volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,10 @@ def __len__(self) -> int:
return self._densities.shape[0]

def __getitem__(
self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor]
self,
index: Union[
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
],
) -> "Volumes":
"""
Args:
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def __len__(self) -> int:
return self.get_matrix().shape[0]

def __getitem__(
self, index: Union[int, List[int], slice, torch.Tensor]
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Transform3d":
"""
Args:
Expand Down
28 changes: 23 additions & 5 deletions tests/test_cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,8 @@ def test_camera_class_init(self):
self.assertTrue(new_cam.device == device)

def test_getitem(self):
R_matrix = torch.randn((6, 3, 3))
N_CAMERAS = 6
R_matrix = torch.randn((N_CAMERAS, 3, 3))
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)

# Check get item returns an instance of the same class
Expand All @@ -908,22 +909,39 @@ def test_getitem(self):
self.assertClose(c012.R, R_matrix[0:3, ...])

# Check torch.LongTensor index
index = torch.tensor([1, 3, 5], dtype=torch.int64)
SLICE = [1, 3, 5]
index = torch.tensor(SLICE, dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
self.assertClose(c135.R, R_matrix[SLICE, ...])

# Check torch.BoolTensor index
bool_slice = [i in SLICE for i in range(N_CAMERAS)]
index = torch.tensor(bool_slice, dtype=torch.bool)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[SLICE, ...])

# Check errors with get item
with self.assertRaisesRegex(ValueError, "out of bounds"):
cam[6]
cam[N_CAMERAS]

with self.assertRaisesRegex(ValueError, "does not match cameras"):
index = torch.tensor([1, 0, 1], dtype=torch.bool)
cam[index]

with self.assertRaisesRegex(ValueError, "Invalid index type"):
cam[slice(0, 1)]

with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor([1, 3, 5], dtype=torch.float32)
cam[[True, False]]

with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor(SLICE, dtype=torch.float32)
cam[index]

def test_get_full_transform(self):
Expand Down

0 comments on commit b7c826b

Please sign in to comment.