Skip to content

Commit

Permalink
transforms 3d convention fix
Browse files Browse the repository at this point in the history
Summary: Fixed the rotation matrices generated by the RotateAxisAngle class and updated the tests. Added documentation for Transforms3d to clarify the conventions.

Reviewed By: gkioxari

Differential Revision: D19912903

fbshipit-source-id: c64926ce4e1381b145811557c32b73663d6d92d1
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Feb 19, 2020
1 parent bdc2bb5 commit 8301163
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 104 deletions.
43 changes: 37 additions & 6 deletions pytorch3d/transforms/rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,32 @@
import torch


"""
The transformation matrices returned from the functions in this file assume
the points on which the transformation will be applied are column vectors.
i.e. the R matrix is structured as
R = [
[Rxx, Rxy, Rxz],
[Ryx, Ryy, Ryz],
[Rzx, Rzy, Rzz],
] # (3, 3)
This matrix can be applied to column vectors by post multiplication
by the points e.g.
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
transformed_points = R * points
To apply the same matrix to points which are row vectors, the R matrix
can be transposed and pre multiplied by the points:
e.g.
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
transformed_points = points * R.transpose(1, 0)
"""


def quaternion_to_matrix(quaternions):
"""
Convert rotations given as quaternions to rotation matrices.
Expand Down Expand Up @@ -80,7 +106,7 @@ def matrix_to_quaternion(matrix):
return torch.stack((o0, o1, o2, o3), -1)


def _primary_matrix(axis: str, angle):
def _axis_angle_rotation(axis: str, angle):
"""
Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given.
Expand All @@ -92,17 +118,20 @@ def _primary_matrix(axis: str, angle):
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""

cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)

if axis == "X":
o = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
o = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
o = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
return torch.stack(o, -1).reshape(angle.shape + (3, 3))
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)

return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))


def euler_angles_to_matrix(euler_angles, convention: str):
Expand All @@ -126,7 +155,9 @@ def euler_angles_to_matrix(euler_angles, convention: str):
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = map(_primary_matrix, convention, torch.unbind(euler_angles, -1))
matrices = map(
_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)
)
return functools.reduce(torch.matmul, matrices)


Expand Down
69 changes: 42 additions & 27 deletions pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import warnings
import torch

from .rotation_conversions import _axis_angle_rotation


class Transform3d:
"""
Expand Down Expand Up @@ -103,12 +105,35 @@ class Transform3d:
s1_params -= lr * s1_params.grad
t_params -= lr * t_params.grad
s2_params -= lr * s2_params.grad
CONVENTIONS
We adopt a right-hand coordinate system, meaning that rotation about an axis
with a positive angle results in a counter clockwise rotation.
This class assumes that transformations are applied on inputs which
are row vectors. The internal representation of the Nx4x4 transformation
matrix is of the form:
.. code-block:: python
M = [
[Rxx, Ryx, Rzx, 0],
[Rxy, Ryy, Rzy, 0],
[Rxz, Ryz, Rzz, 0],
[Tx, Ty, Tz, 1],
]
To apply the transformation to points which are row vectors, the M matrix
can be pre multiplied by the points:
.. code-block:: python
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
transformed_points = points * M
"""

def __init__(self, dtype=torch.float32, device="cpu"):
"""
This class assumes a row major ordering for all matrices.
"""
self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
self._transforms = [] # store transforms to compose
self._lu = None
Expand Down Expand Up @@ -493,9 +518,12 @@ def __init__(
Create a new Transform3d representing 3D rotation about an axis
by an angle.
Assuming a right-hand coordinate system, positive rotation angles result
in a counter clockwise rotation.
Args:
angle:
- A torch tensor of shape (N, 1)
- A torch tensor of shape (N,)
- A python scalar
- A torch scalar
axis:
Expand All @@ -509,21 +537,11 @@ def __init__(
raise ValueError(msg % axis)
angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle")
angle = (angle / 180.0 * math.pi) if degrees else angle
N = angle.shape[0]

cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)

if axis == "X":
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)

R = torch.stack(R_flat, -1).reshape((N, 3, 3))
# We assume the points on which this transformation will be applied
# are row vectors. The rotation matrix returned from _axis_angle_rotation
# is for transforming column vectors. Therefore we transpose this matrix.
# R will always be of shape (N, 3, 3)
R = _axis_angle_rotation(axis, angle).transpose(1, 2)
super().__init__(device=device, R=R)


Expand Down Expand Up @@ -606,19 +624,16 @@ def _handle_input(
def _handle_angle_input(x, dtype, device: str, name: str):
"""
Helper function for building a rotation function using angles.
The output is always of shape (N, 1).
The output is always of shape (N,).
The input can be one of:
- Torch tensor (N, 1) or (N)
- Torch tensor of shape (N,)
- Python scalar
- Torch scalar
"""
# If x is actually a tensor of shape (N, 1) then just return it
if torch.is_tensor(x) and x.dim() == 2:
if x.shape[1] != 1:
msg = "Expected tensor of shape (N, 1); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
return x
if torch.is_tensor(x) and x.dim() > 1:
msg = "Expected tensor of shape (N,); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
else:
return _handle_coord(x, dtype, device)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch

from pytorch3d.transforms.rotation_conversions import (
_axis_angle_rotation,
euler_angles_to_matrix,
matrix_to_euler_angles,
matrix_to_quaternion,
Expand Down Expand Up @@ -118,7 +119,6 @@ def test_from_euler(self):
def test_to_euler(self):
"""mtx -> euler -> mtx"""
data = random_rotations(13, dtype=torch.float64)

for convention in self._all_euler_angle_conventions():
euler_angles = matrix_to_euler_angles(data, convention)
mdata = euler_angles_to_matrix(euler_angles, convention)
Expand Down

0 comments on commit 8301163

Please sign in to comment.