Skip to content

Commit

Permalink
Extending the API of Transform3d with SE(3) log
Browse files Browse the repository at this point in the history
Summary:
This is quite a thin wrapper – not sure we need it. The motivation is that `Transform3d` is not as matrix-centric now, it can be converted to SE(3) logarithm equally easily.

It simplifies things like averaging cameras and getting axis-angle of camera rotation (previously, one would need to call `se3_log_map(cameras.get_world_to_camera_transform().get_matrix())`), now one fewer thing to call / discover.

Reviewed By: bottler

Differential Revision: D39928000

fbshipit-source-id: 85248d5b8af136618f1d08791af5297ea5179d19
  • Loading branch information
shapovalov authored and facebook-github-bot committed Sep 29, 2022
1 parent 74bbd6f commit 9a0f9ae
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 7 deletions.
62 changes: 55 additions & 7 deletions pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..common.datatypes import Device, get_device, make_device
from ..common.workaround import _safe_det_3x3
from .rotation_conversions import _axis_angle_rotation
from .se3 import se3_log_map


class Transform3d:
Expand Down Expand Up @@ -130,13 +131,13 @@ class Transform3d:
[Tx, Ty, Tz, 1],
]
To apply the transformation to points which are row vectors, the M matrix
can be pre multiplied by the points:
To apply the transformation to points, which are row vectors, the latter are
converted to homogeneous (4D) coordinates and right-multiplied by the M matrix:
.. code-block:: python
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
transformed_points = points * M
[transformed_points, 1] ∝ [points, 1] @ M
"""

Expand Down Expand Up @@ -218,9 +219,10 @@ def compose(self, *others: "Transform3d") -> "Transform3d":

def get_matrix(self) -> torch.Tensor:
"""
Return a matrix which is the result of composing this transform
with others stored in self.transforms. Where necessary transforms
are broadcast against each other.
Returns a 4×4 matrix corresponding to each transform in the batch.
If the transform was composed from others, the matrix for the composite
transform will be returned.
For example, if self.transforms contains transforms t1, t2, and t3, and
given a set of points x, the following should be true:
Expand All @@ -230,8 +232,11 @@ def get_matrix(self) -> torch.Tensor:
y2 = t3.transform(t2.transform(t1.transform(x)))
y1.get_matrix() == y2.get_matrix()
Where necessary, those transforms are broadcast against each other.
Returns:
A transformation matrix representing the composed inputs.
A (N, 4, 4) batch of transformation matrices representing
the stored transforms. See the class documentation for the conventions.
"""
composed_matrix = self._matrix.clone()
if len(self._transforms) > 0:
Expand All @@ -240,6 +245,49 @@ def get_matrix(self) -> torch.Tensor:
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
return composed_matrix

def get_se3_log(self, eps: float = 1e-4, cos_bound: float = 1e-4) -> torch.Tensor:
"""
Returns a 6D SE(3) log vector corresponding to each transform in the batch.
In the SE(3) logarithmic representation SE(3) matrices are
represented as 6-dimensional vectors `[log_translation | log_rotation]`,
i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.
The conversion from the 4x4 SE(3) matrix `transform` to the
6D representation `log_transform = [log_translation | log_rotation]`
is done as follows:
```
log_transform = log(transform.get_matrix())
log_translation = log_transform[3, :3]
log_rotation = inv_hat(log_transform[:3, :3])
```
where `log` is the matrix logarithm
and `inv_hat` is the inverse of the Hat operator [2].
See the docstring for `se3.se3_log_map` and [1], Sec 9.4.2. for more
detailed description.
Args:
eps: A threshold for clipping the squared norm of the rotation logarithm
to avoid division by zero in the singular case.
cos_bound: Clamps the cosine of the rotation angle to
[-1 + cos_bound, 3 - cos_bound] to avoid non-finite outputs.
The non-finite outputs can be caused by passing small rotation angles
to the `acos` function in `so3_rotation_angle` of `so3_log_map`.
Returns:
A (N, 6) tensor, rows of which represent the individual transforms
stored in the object as SE(3) logarithms.
Raises:
ValueError if the stored transform is not Euclidean (e.g. R is not a rotation
matrix or the last column has non-zeros in the first three places).
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
[2] https://en.wikipedia.org/wiki/Hat_operator
"""
return se3_log_map(self.get_matrix(), eps, cos_bound)

def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
Expand Down
11 changes: 11 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
from pytorch3d.transforms import random_rotations
from pytorch3d.transforms.se3 import se3_log_map
from pytorch3d.transforms.so3 import so3_exp_map
from pytorch3d.transforms.transform3d import (
Rotate,
Expand Down Expand Up @@ -161,6 +162,16 @@ def test_init_with_custom_matrix_errors(self):
matrix = torch.randn(*bad_shape).float()
self.assertRaises(ValueError, Transform3d, matrix=matrix)

def test_get_se3(self):
N = 16
random_rotations(N)
tr = Translate(torch.rand((N, 3)))
R = Rotate(random_rotations(N))
transform = Transform3d().compose(R, tr)
se3_log = transform.get_se3_log()
gt_se3_log = se3_log_map(transform.get_matrix())
self.assertClose(se3_log, gt_se3_log)

def test_translate(self):
t = Transform3d().translate(1, 2, 3)
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
Expand Down

0 comments on commit 9a0f9ae

Please sign in to comment.