Skip to content

Make some matrix conversion jittable #898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions pytorch3d/transforms/rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:

batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(*batch_dim, 9), dim=-1
matrix.reshape(batch_dim + (9,)), dim=-1
)

q_abs = _sqrt_positive_part(
Expand Down Expand Up @@ -142,14 +142,15 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:

# We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked.
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1)))
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)

return quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
].reshape(*batch_dim, 4)
].reshape(batch_dim + (4,))


def _axis_angle_rotation(axis: str, angle):
Expand Down Expand Up @@ -238,13 +239,14 @@ def _angle_from_tan(
return torch.atan2(data[..., i2], -data[..., i1])


def _index_from_letter(letter: str):
def _index_from_letter(letter: str) -> int:
if letter == "X":
return 0
if letter == "Y":
return 1
if letter == "Z":
return 2
raise ValueError("letter must be either X, Y or Z.")


def matrix_to_euler_angles(matrix, convention: str):
Expand Down Expand Up @@ -573,4 +575,5 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
Retrieved from http://arxiv.org/abs/1812.07035
"""
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
batch_dim = matrix.size()[:-2]
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
8 changes: 8 additions & 0 deletions tests/test_rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import itertools
import math
import unittest
from distutils.version import LooseVersion
from typing import Optional, Union

import numpy as np
Expand Down Expand Up @@ -264,6 +265,13 @@ def test_6d(self):
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
)

@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
def test_scriptable(self):
torch.jit.script(matrix_to_axis_angle),
torch.jit.script(matrix_to_euler_angles),
torch.jit.script(matrix_to_quaternion),
torch.jit.script(matrix_to_rotation_6d),

def _assert_quaternions_close(
self,
input: Union[torch.Tensor, np.ndarray],
Expand Down