Skip to content

Commit

Permalink
Update so3 operations for numerical stability
Browse files Browse the repository at this point in the history
Summary: Replace implementations of `so3_exp_map` and `so3_log_map` in so3.py with existing more-stable implementations.

Reviewed By: bottler

Differential Revision: D52513319

fbshipit-source-id: fbfc039643fef284d8baa11bab61651964077afe
  • Loading branch information
Abdelrahman-Khater authored and facebook-github-bot committed Jan 4, 2024
1 parent 3621a36 commit 292acc7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 54 deletions.
46 changes: 6 additions & 40 deletions pytorch3d/transforms/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Tuple

import torch
from pytorch3d.transforms import rotation_conversions

from ..transforms import acos_linear_extrapolation

Expand Down Expand Up @@ -160,19 +161,10 @@ def _so3_exp_map(
nrms = (log_rot * log_rot).sum(1)
# phis ... rotation angles
rot_angles = torch.clamp(nrms, eps).sqrt()
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
rot_angles_inv = 1.0 / rot_angles
fac1 = rot_angles_inv * rot_angles.sin()
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
skews = hat(log_rot)
skews_square = torch.bmm(skews, skews)

R = (
fac1[:, None, None] * skews
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
+ fac2[:, None, None] * skews_square
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
)
R = rotation_conversions.axis_angle_to_matrix(log_rot)

return R, rot_angles, skews, skews_square

Expand All @@ -183,49 +175,23 @@ def so3_log_map(
"""
Convert a batch of 3x3 rotation matrices `R`
to a batch of 3-dimensional matrix logarithms of rotation matrices
The conversion has a singularity around `(R=I)` which is handled
by clamping controlled with the `eps` and `cos_bound` arguments.
The conversion has a singularity around `(R=I)`.
Args:
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
eps: A float constant handling the conversion singularity.
cos_bound: Clamps the cosine of the rotation angle to
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
of the `acos` call when computing `so3_rotation_angle`.
Note that the non-finite outputs/gradients are returned when
the rotation angle is close to 0 or π.
eps: (unused, for backward compatibility)
cos_bound: (unused, for backward compatibility)
Returns:
Batch of logarithms of input rotation matrices
of shape `(minibatch, 3)`.
Raises:
ValueError if `R` is of incorrect shape.
ValueError if `R` has an unexpected trace.
"""

N, dim1, dim2 = R.shape
if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.")

phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps)

phi_sin = torch.sin(phi)

# We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
# Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
phi_factor = torch.empty_like(phi)
ok_denom = phi_sin.abs() > (0.5 * eps)
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])

log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1))

log_rot = hat_inv(log_rot_hat)

return log_rot
return rotation_conversions.matrix_to_axis_angle(R)


def hat_inv(h: torch.Tensor) -> torch.Tensor:
Expand Down
14 changes: 0 additions & 14 deletions tests/test_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,6 @@ def test_bad_so3_input_value_err(self):
so3_log_map(rot)
self.assertTrue("Input has to be a batch of 3x3 Tensors." in str(err.exception))

# trace of rot definitely bigger than 3 or smaller than -1
rot = torch.cat(
(
torch.rand(size=[5, 3, 3], device=device) + 4.0,
torch.rand(size=[5, 3, 3], device=device) - 3.0,
)
)
with self.assertRaises(ValueError) as err:
so3_log_map(rot)
self.assertTrue(
"A matrix has trace outside valid range [-1-eps,3+eps]."
in str(err.exception)
)

def test_so3_exp_singularity(self, batch_size: int = 100):
"""
Tests whether the `so3_exp_map` is robust to the input vectors
Expand Down

0 comments on commit 292acc7

Please sign in to comment.