def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    The returned matrices follow the column-vector convention: a point ``p``
    of shape (3, 1) is rotated with ``R @ p``. Transpose the last two
    dimensions to apply the same rotation to row vectors.

    Args:
        axis: Axis label, one of "X", "Y" or "Z" (upper case).
        angle: Tensor of any shape holding the rotation angles in radians.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: If ``axis`` is not one of "X", "Y" or "Z".
    """
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    # The nine entries are listed row by row; they are stacked on a new
    # trailing dimension and reshaped to (..., 3, 3) below.
    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        # Fail with a clear message instead of leaving R_flat unbound,
        # which would surface as an UnboundLocalError further down.
        raise ValueError(
            f"Invalid axis {axis!r}; expected one of 'X', 'Y' or 'Z'."
        )

    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
code-block:: python + + M = [ + [Rxx, Ryx, Rzx, 0], + [Rxy, Ryy, Rzy, 0], + [Rxz, Ryz, Rzz, 0], + [Tx, Ty, Tz, 1], + ] + + To apply the transformation to points which are row vectors, the M matrix + can be pre multiplied by the points: + + .. code-block:: python + + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * M + """ def __init__(self, dtype=torch.float32, device="cpu"): - """ - This class assumes a row major ordering for all matrices. - """ self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4) self._transforms = [] # store transforms to compose self._lu = None @@ -493,9 +518,12 @@ def __init__( Create a new Transform3d representing 3D rotation about an axis by an angle. + Assuming a right-hand coordinate system, positive rotation angles result + in a counter clockwise rotation. + Args: angle: - - A torch tensor of shape (N, 1) + - A torch tensor of shape (N,) - A python scalar - A torch scalar axis: @@ -509,21 +537,11 @@ def __init__( raise ValueError(msg % axis) angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle") angle = (angle / 180.0 * math.pi) if degrees else angle - N = angle.shape[0] - - cos = torch.cos(angle) - sin = torch.sin(angle) - one = torch.ones_like(angle) - zero = torch.zeros_like(angle) - - if axis == "X": - R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) - if axis == "Y": - R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) - if axis == "Z": - R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) - - R = torch.stack(R_flat, -1).reshape((N, 3, 3)) + # We assume the points on which this transformation will be applied + # are row vectors. The rotation matrix returned from _axis_angle_rotation + # is for transforming column vectors. Therefore we transpose this matrix. 
def _handle_angle_input(x, dtype, device: str, name: str):
    """
    Validate and normalize the angle argument of a rotation transform.

    Tensors with more than one dimension are rejected. Every other input is
    forwarded unchanged to ``_handle_coord`` together with ``dtype`` and
    ``device``; that helper is expected to return a tensor of shape (N,)
    (NOTE(review): confirm against ``_handle_coord``).

    Accepted inputs:
        - Torch tensor of shape (N,)
        - Python scalar
        - Torch scalar

    Args:
        x: The angle value(s) supplied by the caller.
        dtype: Passed through to ``_handle_coord``.
        device: Passed through to ``_handle_coord``.
        name: Name of the calling transform, used only in the error message.

    Raises:
        ValueError: If ``x`` is a tensor with more than one dimension.
    """
    has_extra_dims = torch.is_tensor(x) and x.dim() > 1
    if not has_extra_dims:
        return _handle_coord(x, dtype, device)

    raise ValueError(
        "Expected tensor of shape (N,); got %r (in %s)" % (x.shape, name)
    )
Transform3d().rotate_axis_angle(90.0, axis="Z") points = torch.tensor( [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]] ).view(1, 3, 3) @@ -737,15 +737,23 @@ def test_rotate_x_python_scalar(self): matrix = torch.tensor( [ [ - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 - [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, ) # fmt: on + points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3) + transformed_points = t.transform_points(points) + expected_points = torch.tensor([0.0, 0.0, 1.0]) + self.assertTrue( + torch.allclose( + transformed_points.squeeze(), expected_points, atol=1e-7 + ) + ) self.assertTrue(torch.allclose(t._matrix, matrix)) def test_rotate_x_torch_scalar(self): @@ -755,15 +763,23 @@ def test_rotate_x_torch_scalar(self): matrix = torch.tensor( [ [ - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 - [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, ) # fmt: on + points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3) + transformed_points = t.transform_points(points) + expected_points = torch.tensor([0.0, 0.0, 1.0]) + self.assertTrue( + torch.allclose( + transformed_points.squeeze(), expected_points, atol=1e-7 + ) + ) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) def test_rotate_x_torch_tensor(self): @@ -781,23 +797,23 @@ def test_rotate_x_torch_tensor(self): [0.0, 0.0, 0.0, 1.0], ], [ - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, r2_2, -r2_i, 0.0], # noqa: E241, E201 - [0.0, 
r2_i, r2_2, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, r2_2, r2_i, 0.0], # noqa: E241, E201 + [0.0, -r2_i, r2_2, 0.0], # noqa: E241, E201 + [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ], [ - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 - [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, ) # fmt: on self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) - angle = angle[..., None] # (N, 1) + angle = angle t = RotateAxisAngle(angle=angle, axis="X") self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) @@ -807,33 +823,54 @@ def test_rotate_y_python_scalar(self): matrix = torch.tensor( [ [ - [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 - [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 + [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, ) # fmt: on + points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3) + transformed_points = t.transform_points(points) + expected_points = torch.tensor([0.0, 0.0, -1.0]) + self.assertTrue( + torch.allclose( + transformed_points.squeeze(), expected_points, atol=1e-7 + ) + ) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) def test_rotate_y_torch_scalar(self): + """ + Test rotation about Y axis. With a right hand coordinate system this + should result in a vector pointing along the x-axis being rotated to + point along the negative z axis. 
+ """ angle = torch.tensor(90.0) t = RotateAxisAngle(angle=angle, axis="Y") # fmt: off matrix = torch.tensor( [ [ - [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 - [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 + [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, ) # fmt: on + points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3) + transformed_points = t.transform_points(points) + expected_points = torch.tensor([0.0, 0.0, -1.0]) + self.assertTrue( + torch.allclose( + transformed_points.squeeze(), expected_points, atol=1e-7 + ) + ) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) def test_rotate_y_torch_tensor(self): @@ -851,16 +888,16 @@ def test_rotate_y_torch_tensor(self): [0.0, 0.0, 0.0, 1.0], ], [ - [ r2_2, 0.0, r2_i, 0.0], # noqa: E241, E201 - [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 - [-r2_i, 0.0, r2_2, 0.0], # noqa: E241, E201 - [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [r2_2, 0.0, -r2_i, 0.0], # noqa: E241, E201 + [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [r2_i, 0.0, r2_2, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ], [ - [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 - [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 + [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, @@ -874,15 +911,23 @@ def test_rotate_z_python_scalar(self): matrix = torch.tensor( [ [ - [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, 
E201 + [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, ) # fmt: on + points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3) + transformed_points = t.transform_points(points) + expected_points = torch.tensor([0.0, 1.0, 0.0]) + self.assertTrue( + torch.allclose( + transformed_points.squeeze(), expected_points, atol=1e-7 + ) + ) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) def test_rotate_z_torch_scalar(self): @@ -892,15 +937,23 @@ def test_rotate_z_torch_scalar(self): matrix = torch.tensor( [ [ - [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, ) # fmt: on + points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3) + transformed_points = t.transform_points(points) + expected_points = torch.tensor([0.0, 1.0, 0.0]) + self.assertTrue( + torch.allclose( + transformed_points.squeeze(), expected_points, atol=1e-7 + ) + ) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) def test_rotate_z_torch_tensor(self): @@ -918,16 +971,16 @@ def test_rotate_z_torch_tensor(self): [0.0, 0.0, 0.0, 1.0], ], [ - [r2_2, -r2_i, 0.0, 0.0], # noqa: E241, E201 - [r2_i, r2_2, 0.0, 0.0], # noqa: E241, E201 - [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [ r2_2, r2_i, 0.0, 0.0], # noqa: E241, E201 + [-r2_i, r2_2, 0.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ], [ - [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # 
noqa: E241, E201 + [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, @@ -945,10 +998,10 @@ def test_rotate_compose_x_y_z(self): matrix1 = torch.tensor( [ [ - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 - [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, @@ -956,10 +1009,10 @@ def test_rotate_compose_x_y_z(self): matrix2 = torch.tensor( [ [ - [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 - [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 + [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, @@ -967,10 +1020,10 @@ def test_rotate_compose_x_y_z(self): matrix3 = torch.tensor( [ [ - [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, @@ -987,10 +1040,10 @@ def test_rotate_angle_radians(self): matrix = torch.tensor( [ [ - [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [ 
0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32, @@ -1004,10 +1057,10 @@ def test_lower_case_axis(self): matrix = torch.tensor( [ [ - [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 - [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 - [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 + [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 + [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 + [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 ] ], dtype=torch.float32,