Skip to content

Commit

Permalink
Pointclouds, Meshes and Textures self-references
Browse files Browse the repository at this point in the history
Summary: Use `self.__class__` when creating new instances, to slightly accommodate inheritance.

Reviewed By: nikhilaravi

Differential Revision: D21504476

fbshipit-source-id: b4600d15462fc1985da95a4cf761c7d794cfb0bb
  • Loading branch information
bottler authored and facebook-github-bot committed May 11, 2020
1 parent 34a0df0 commit 6a365d2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions pytorch3d/structures/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,9 @@ def __getitem__(self, index):
textures = None if self.textures is None else self.textures[index]

if torch.is_tensor(verts) and torch.is_tensor(faces):
return Meshes(verts=[verts], faces=[faces], textures=textures)
return self.__class__(verts=[verts], faces=[faces], textures=textures)
elif isinstance(verts, list) and isinstance(faces, list):
return Meshes(verts=verts, faces=faces, textures=textures)
return self.__class__(verts=verts, faces=faces, textures=textures)
else:
raise ValueError("(verts, faces) not defined correctly")

Expand Down Expand Up @@ -1127,7 +1127,7 @@ def clone(self):
faces_list = self.faces_list()
new_verts_list = [v.clone() for v in verts_list]
new_faces_list = [f.clone() for f in faces_list]
other = Meshes(verts=new_verts_list, faces=new_faces_list)
other = self.__class__(verts=new_verts_list, faces=new_faces_list)
for k in self._INTERNAL_TENSORS:
v = getattr(self, k)
if torch.is_tensor(v):
Expand Down Expand Up @@ -1370,7 +1370,7 @@ def extend(self, N: int):
if self.textures is not None:
tex = self.textures.extend(N)

return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex)
return self.__class__(verts=new_verts_list, faces=new_faces_list, textures=tex)


def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
Expand Down
8 changes: 4 additions & 4 deletions pytorch3d/structures/pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def __getitem__(self, index):
else:
raise IndexError(index)

return Pointclouds(points=points, normals=normals, features=features)
return self.__class__(points=points, normals=normals, features=features)

def isempty(self) -> bool:
"""
Expand Down Expand Up @@ -647,7 +647,7 @@ def clone(self):
new_normals = self.normals_padded().clone()
if features_padded is not None:
new_features = self.features_padded().clone()
other = Pointclouds(
other = self.__class__(
points=new_points, normals=new_normals, features=new_features
)
for k in self._INTERNAL_TENSORS:
Expand Down Expand Up @@ -920,7 +920,7 @@ def extend(self, N: int):
new_features_list = []
for features in self.features_list():
new_features_list.extend(features.clone() for _ in range(N))
return Pointclouds(
return self.__class__(
points=new_points_list, normals=new_normals_list, features=new_features_list
)

Expand Down Expand Up @@ -959,7 +959,7 @@ def check_shapes(x, size):
if new_features_padded is not None:
check_shapes(new_features_padded, [self._N, self._P, self._C])

new = Pointclouds(
new = self.__class__(
points=new_points_padded,
normals=new_normals_padded,
features=new_features_padded,
Expand Down
8 changes: 4 additions & 4 deletions pytorch3d/structures/textures.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
self._num_verts_per_mesh = None

def clone(self):
other = Textures()
other = self.__class__()
for k in dir(self):
v = getattr(self, k)
if torch.is_tensor(v):
Expand All @@ -144,7 +144,7 @@ def to(self, device):
return self

def __getitem__(self, index):
other = Textures()
other = self.__class__()
for key in dir(self):
value = getattr(self, key)
if torch.is_tensor(value):
Expand Down Expand Up @@ -237,12 +237,12 @@ def extend(self, N: int) -> "Textures":
new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N)
new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N)
new_maps = _extend_tensor(self._maps_padded, N)
return Textures(
return self.__class__(
verts_uvs=new_verts_uvs, faces_uvs=new_faces_uvs, maps=new_maps
)
elif self._verts_rgb_padded is not None:
new_verts_rgb = _extend_tensor(self._verts_rgb_padded, N)
return Textures(verts_rgb=new_verts_rgb)
return self.__class__(verts_rgb=new_verts_rgb)
else:
msg = "Either vertex colors or texture maps are required."
raise ValueError(msg)

0 comments on commit 6a365d2

Please sign in to comment.