Skip to content

Commit

Permalink
Fix Transform3d.stack of compositions
Browse files Browse the repository at this point in the history
Summary:
Add a test for Transform3d.stack, and make it work with composed transformations.

Fixes #1072 .

Reviewed By: patricklabatut

Differential Revision: D34211920

fbshipit-source-id: bfbd0895494ca2ad3d08a61bc82ba23637e168cc
  • Loading branch information
bottler authored and facebook-github-bot committed Feb 15, 2022
1 parent 2a1de3b commit c8f3d6b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 26 deletions.
2 changes: 1 addition & 1 deletion pytorch3d/renderer/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ def look_at_view_transform(
elev=0.0,
azim=0.0,
degrees: bool = True,
eye: Optional[Sequence] = None,
eye: Optional[Union[Sequence, torch.Tensor]] = None,
at=((0, 0, 0),), # (1, 3)
up=((0, 1, 0),), # (1, 3)
device: Device = "cpu",
Expand Down
62 changes: 37 additions & 25 deletions pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ def __getitem__(
index = [index]
return self.__class__(matrix=self.get_matrix()[index])

def compose(self, *others):
def compose(self, *others: "Transform3d") -> "Transform3d":
"""
Return a new Transform3d with the transforms to compose stored as
an internal list.
Return a new Transform3d representing the composition of self with the
given other transforms, which will be stored as an internal list.
Args:
*others: Any number of Transform3d objects
Expand All @@ -216,7 +216,7 @@ def compose(self, *others):
out._transforms = self._transforms + list(others)
return out

def get_matrix(self):
def get_matrix(self) -> torch.Tensor:
"""
Return a matrix which is the result of composing this transform
with others stored in self.transforms. Where necessary transforms
Expand All @@ -240,13 +240,13 @@ def get_matrix(self):
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
return composed_matrix

def _get_matrix_inverse(self):
def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
return torch.inverse(self._matrix)

def inverse(self, invert_composed: bool = False):
def inverse(self, invert_composed: bool = False) -> "Transform3d":
"""
Returns a new Transform3d object that represents an inverse of the
current transformation.
Expand Down Expand Up @@ -295,14 +295,24 @@ def inverse(self, invert_composed: bool = False):

return tinv

def stack(self, *others):
def stack(self, *others: "Transform3d") -> "Transform3d":
"""
Return a new batched Transform3d representing the batch elements from
self and all the given other transforms all batched together.
Args:
*others: Any number of Transform3d objects
Returns:
A new Transform3d.
"""
transforms = [self] + list(others)
matrix = torch.cat([t._matrix for t in transforms], dim=0)
matrix = torch.cat([t.get_matrix() for t in transforms], dim=0)
out = Transform3d(dtype=self.dtype, device=self.device)
out._matrix = matrix
return out

def transform_points(self, points, eps: Optional[float] = None):
def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor:
"""
Use this transform to transform a set of 3D points. Assumes row major
ordering of the input points.
Expand Down Expand Up @@ -347,7 +357,7 @@ def transform_points(self, points, eps: Optional[float] = None):

return points_out

def transform_normals(self, normals):
def transform_normals(self, normals) -> torch.Tensor:
"""
Use this transform to transform a set of normal vectors.
Expand Down Expand Up @@ -379,19 +389,19 @@ def transform_normals(self, normals):

return normals_out

def translate(self, *args, **kwargs):
def translate(self, *args, **kwargs) -> "Transform3d":
return self.compose(Translate(device=self.device, *args, **kwargs))

def scale(self, *args, **kwargs):
def scale(self, *args, **kwargs) -> "Transform3d":
return self.compose(Scale(device=self.device, *args, **kwargs))

def rotate(self, *args, **kwargs):
def rotate(self, *args, **kwargs) -> "Transform3d":
return self.compose(Rotate(device=self.device, *args, **kwargs))

def rotate_axis_angle(self, *args, **kwargs):
def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))

def clone(self):
def clone(self) -> "Transform3d":
"""
Deep copy of Transforms object. All internal tensors are cloned
individually.
Expand All @@ -411,7 +421,7 @@ def to(
device: Device,
copy: bool = False,
dtype: Optional[torch.dtype] = None,
):
) -> "Transform3d":
"""
Match functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the
Expand Down Expand Up @@ -448,10 +458,10 @@ def to(
]
return other

def cpu(self):
def cpu(self) -> "Transform3d":
return self.to("cpu")

def cuda(self):
def cuda(self) -> "Transform3d":
return self.to("cuda")


Expand Down Expand Up @@ -486,7 +496,7 @@ def __init__(
mat[:, 3, :3] = xyz
self._matrix = mat

def _get_matrix_inverse(self):
def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
Expand Down Expand Up @@ -533,7 +543,7 @@ def __init__(
mat[:, 2, 2] = xyz[:, 2]
self._matrix = mat

def _get_matrix_inverse(self):
def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
Expand Down Expand Up @@ -575,7 +585,7 @@ def __init__(
mat[:, :3, :3] = R
self._matrix = mat

def _get_matrix_inverse(self):
def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
Expand Down Expand Up @@ -622,7 +632,7 @@ def __init__(
super().__init__(device=angle.device, R=R)


def _handle_coord(c, dtype: torch.dtype, device: torch.device):
def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
"""
Helper function for _handle_input.
Expand All @@ -649,7 +659,7 @@ def _handle_input(
device: Optional[Device],
name: str,
allow_singleton: bool = False,
):
) -> torch.Tensor:
"""
Helper function to handle parsing logic for building transforms. The output
is always a tensor of shape (N, 3), but there are several types of allowed
Expand Down Expand Up @@ -707,7 +717,9 @@ def _handle_input(
return xyz


def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
def _handle_angle_input(
x, dtype: torch.dtype, device: Optional[Device], name: str
) -> torch.Tensor:
"""
Helper function for building a rotation function using angles.
The output is always of shape (N,).
Expand All @@ -725,7 +737,7 @@ def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: s
return _handle_coord(x, dtype, device_)


def _broadcast_bmm(a, b):
def _broadcast_bmm(a, b) -> torch.Tensor:
"""
Batch multiply two matrices and broadcast if necessary.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms import random_rotations
from pytorch3d.transforms.so3 import so3_exp_map
from pytorch3d.transforms.transform3d import (
Rotate,
Expand All @@ -21,6 +22,9 @@


class TestTransform(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)

def test_to(self):
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
Expand Down Expand Up @@ -406,6 +410,28 @@ def test_get_item(self, batch_size=5):
with self.assertRaises(IndexError):
t3d_selected = t3d[invalid_index]

def test_stack(self):
rotations = random_rotations(3)
transform3 = Transform3d().rotate(rotations).translate(torch.full((3, 3), 0.3))
transform1 = Scale(37)
transform4 = transform1.stack(transform3)
self.assertEqual(len(transform1), 1)
self.assertEqual(len(transform3), 3)
self.assertEqual(len(transform4), 4)
self.assertClose(
transform4.get_matrix(),
torch.cat([transform1.get_matrix(), transform3.get_matrix()]),
)
points = torch.rand(4, 5, 3)
new_points_expect = torch.cat(
[
transform1.transform_points(points[:1]),
transform3.transform_points(points[1:]),
]
)
new_points = transform4.transform_points(points)
self.assertClose(new_points, new_points_expect)


class TestTranslate(unittest.TestCase):
def test_python_scalar(self):
Expand Down

0 comments on commit c8f3d6b

Please sign in to comment.