Add SO3 support #46

Merged: 5 commits, Feb 1, 2022
Changes from all commits
1 change: 1 addition & 0 deletions theseus/__init__.py
@@ -15,6 +15,7 @@
from .geometry import (
SE2,
SO2,
SO3,
LieGroup,
Manifold,
Point2,
1 change: 1 addition & 0 deletions theseus/geometry/__init__.py
@@ -8,4 +8,5 @@
from .point_types import Point2, Point3
from .se2 import SE2
from .so2 import SO2
from .so3 import SO3
from .vector import Vector
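
With both __init__.py files exporting the new class, SO3 becomes importable from the top-level package. A minimal usage sketch, not part of the PR, which assumes the base LieGroup constructor falls back to _init_data() when no data is given (as the so3.py listing below suggests):

import torch
import theseus as th

# A bare SO3() should start at the identity rotation (see _init_data in so3.py).
rot = th.SO3(name="rot")
print(rot.to_matrix())  # expected: a 1 x 3 x 3 identity matrix
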
81 changes: 30 additions & 51 deletions theseus/geometry/se2.py
@@ -17,6 +17,8 @@
# If data is passed, must be x, y, cos, sin
# If x_y_theta is passed, must be tensor with shape batch_size x 3
class SE2(LieGroup):
SE2_EPS = 5e-7

def __init__(
self,
x_y_theta: Optional[torch.Tensor] = None,
@@ -92,37 +94,25 @@ def update_from_rot_and_trans(self, rotation: SO2, translation: Point2):
# From https://github.com/strasdat/Sophus/blob/master/sophus/se2.hpp#L160
def _log_map_impl(self) -> torch.Tensor:
rotation = self.rotation

theta = rotation.log_map()
half_theta = 0.5 * theta.view(-1)

theta = rotation.log_map().view(-1)
cosine, sine = rotation.to_cos_sin()
cos_minus_one = cosine - 1
halftheta_by_tan_of_halftheta = torch.zeros_like(cos_minus_one)

# Compute halftheta_by_tan_of_halftheta when theta is not near zero
idx_regular_vals = cos_minus_one.abs() > theseus.constants.EPS
halftheta_by_tan_of_halftheta[idx_regular_vals] = (
-(half_theta * sine)[idx_regular_vals] / cos_minus_one[idx_regular_vals]
# Compute the approximations when theta is near zero
small_theta = theta.abs() < SE2.SE2_EPS
non_zero = torch.ones(1, dtype=self.dtype, device=self.device)
sine_nz = torch.where(small_theta, non_zero, sine)
a = (
0.5
* (1 + cosine)
* torch.where(small_theta, 1 + sine**2 / 6, theta / sine_nz)
)
# Same as above three lines but for small values
idx_small_vals = cos_minus_one.abs() < theseus.constants.EPS
if idx_small_vals.any():
theta_sq_at_idx = theta[idx_small_vals] ** 2
halftheta_by_tan_of_halftheta[idx_small_vals] = (
-theta_sq_at_idx.view(-1) / 12 + 1
)
b = 0.5 * theta

v_inv = torch.empty(self.shape[0], 2, 2).to(
device=self.device, dtype=self.dtype
)
v_inv[:, 0, 0] = halftheta_by_tan_of_halftheta
v_inv[:, 0, 1] = half_theta
v_inv[:, 1, 0] = -half_theta
v_inv[:, 1, 1] = halftheta_by_tan_of_halftheta
tangent_translation = torch.matmul(v_inv, self[:, :2].unsqueeze(-1))
# Compute the translation
ux = a * self[:, 0] + b * self[:, 1]
uy = a * self[:, 1] - b * self[:, 0]

return torch.cat([tangent_translation.view(-1, 2), theta], dim=1)
return torch.stack((ux, uy, theta), dim=1)

# From https://github.com/strasdat/Sophus/blob/master/sophus/se2.hpp#L558
@staticmethod
@@ -133,32 +123,21 @@ def exp_map(tangent_vector: torch.Tensor) -> LieGroup:

cosine, sine = rotation.to_cos_sin()

sin_theta_by_theta = torch.zeros_like(sine)
one_minus_cos_theta_by_theta = torch.zeros_like(sine)

# Compute above quantities when theta is not near zero
idx_regular_thetas = theta.abs() > theseus.constants.EPS
if idx_regular_thetas.any():
sin_theta_by_theta[idx_regular_thetas] = (
sine[idx_regular_thetas] / theta[idx_regular_thetas]
)
one_minus_cos_theta_by_theta[idx_regular_thetas] = (
-cosine[idx_regular_thetas] + 1
) / theta[idx_regular_thetas]

# Same as above three lines but for small angles
idx_small_thetas = theta.abs() < theseus.constants.EPS
if idx_small_thetas.any():
small_theta = theta[idx_small_thetas]
small_theta_sq = small_theta**2
sin_theta_by_theta[idx_small_thetas] = -small_theta_sq / 6 + 1
one_minus_cos_theta_by_theta[idx_small_thetas] = (
0.5 * small_theta - small_theta / 24 * small_theta_sq
)
# Compute the approximations when theta is near zero
small_theta = theta.abs() < SE2.SE2_EPS
non_zero = torch.ones(
1, dtype=tangent_vector.dtype, device=tangent_vector.device
)
theta_nz = torch.where(small_theta, non_zero, theta)
a = torch.where(
small_theta, -theta / 2 + theta**3 / 24, (cosine - 1) / theta_nz
)
b = torch.where(small_theta, 1 - theta**2 / 6, sine / theta_nz)

new_x = sin_theta_by_theta * u[:, 0] - one_minus_cos_theta_by_theta * u[:, 1]
new_y = one_minus_cos_theta_by_theta * u[:, 0] + sin_theta_by_theta * u[:, 1]
translation = Point2(data=torch.stack([new_x, new_y], dim=1))
# Compute the translation
x = b * u[:, 0] + a * u[:, 1]
y = b * u[:, 1] - a * u[:, 0]
translation = Point2(data=torch.stack((x, y), dim=1))

se2 = SE2(dtype=tangent_vector.dtype)
se2.update_from_rot_and_trans(rotation, translation)
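
The rewritten SE2 log and exp maps above replace boolean-mask indexing with torch.where branches, substituting a safe non_zero value into each denominator before dividing. A small sanity sketch, not part of the PR, comparing the exact ratios against the small-angle approximations used by those branches:

import torch

theta = torch.tensor([1e-3], dtype=torch.float64)

# b branch: sin(theta) / theta vs. its approximation 1 - theta**2 / 6
exact_b = torch.sin(theta) / theta
approx_b = 1 - theta**2 / 6

# a branch: (cos(theta) - 1) / theta vs. -theta / 2 + theta**3 / 24
exact_a = (torch.cos(theta) - 1) / theta
approx_a = -theta / 2 + theta**3 / 24

# Both differences should be at round-off level for theta this small.
print((exact_b - approx_b).abs().item())
print((exact_a - approx_a).abs().item())
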
231 changes: 231 additions & 0 deletions theseus/geometry/so3.py
@@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Union, cast

import torch

import theseus.constants

from .lie_group import LieGroup
from .point_types import Point3


class SO3(LieGroup):
SO3_EPS = 5e-7

def __init__(
self,
quaternion: Optional[torch.Tensor] = None,
data: Optional[torch.Tensor] = None,
name: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
if quaternion is not None and data is not None:
raise ValueError("Please provide only one of quaternion or data.")
if quaternion is not None:
dtype = quaternion.dtype
if data is not None:
self._SO3_matrix_check(data)
super().__init__(data=data, name=name, dtype=dtype)
if quaternion is not None:
if quaternion.ndim == 1:
quaternion = quaternion.unsqueeze(0)
self.update_from_unit_quaternion(quaternion)

@staticmethod
def _init_data() -> torch.Tensor: # type: ignore
return torch.eye(3, 3).view(1, 3, 3)

def update_from_unit_quaternion(self, quaternion: torch.Tensor):
self.update(self.unit_quaternion_to_matrix(quaternion))

def dof(self) -> int:
return 3

def __repr__(self) -> str:
return f"SO3(data={self.data}, name={self.name})"

def __str__(self) -> str:
with torch.no_grad():
return f"SO3(matrix={self.data}), name={self.name})"

def _adjoint_impl(self) -> torch.Tensor:
return self.data.clone()

@staticmethod
def _SO3_matrix_check(matrix: torch.Tensor):
if matrix.ndim != 3 or matrix.shape[1:] != (3, 3):
raise ValueError("3D rotations can only be 3x3 matrices.")
_check = (
torch.matmul(matrix, matrix.transpose(1, 2))
- torch.eye(3, 3, dtype=matrix.dtype, device=matrix.device)
).abs().max().item() < SO3.SO3_EPS
_check &= (torch.linalg.det(matrix) - 1).abs().max().item() < SO3.SO3_EPS

if not _check:
raise ValueError("Not valid 3D rotations.")

@staticmethod
def _unit_quaternion_check(quaternion: torch.Tensor):
if quaternion.ndim != 2 or quaternion.shape[1] != 4:
raise ValueError("Quaternions can only be 4-D vectors.")

if (torch.linalg.norm(quaternion, dim=1) - 1).abs().max().item() >= SO3.SO3_EPS:
raise ValueError("Not unit quaternions.")

@staticmethod
def exp_map(tangent_vector: torch.Tensor) -> LieGroup:
if tangent_vector.ndim != 2 or tangent_vector.shape[1] != 3:
raise ValueError("Invalid input for SO3.exp_map.")
ret = SO3(dtype=tangent_vector.dtype)
theta = torch.linalg.norm(tangent_vector, dim=1, keepdim=True).unsqueeze(1)
theta2 = theta**2
# Compute the approximations when theta ~ 0
small_theta = theta < 0.005
non_zero = torch.ones(
1, dtype=tangent_vector.dtype, device=tangent_vector.device
)
theta_nz = torch.where(small_theta, non_zero, theta)
theta2_nz = torch.where(small_theta, non_zero, theta2)
a = torch.where(small_theta, 8 / (4 + theta2) - 1, theta.cos())
b = torch.where(small_theta, 0.5 * a + 0.5, theta.sin() / theta_nz)
c = torch.where(small_theta, 0.5 * b, (1 - a) / theta2_nz)
ret.data = c * tangent_vector.view(-1, 3, 1) @ tangent_vector.view(-1, 1, 3)
ret[:, 0, 0] += a.view(-1)
ret[:, 1, 1] += a.view(-1)
ret[:, 2, 2] += a.view(-1)
temp = b.view(-1, 1) * tangent_vector
ret[:, 0, 1] -= temp[:, 2]
ret[:, 1, 0] += temp[:, 2]
ret[:, 0, 2] += temp[:, 1]
ret[:, 2, 0] -= temp[:, 1]
ret[:, 1, 2] -= temp[:, 0]
ret[:, 2, 1] += temp[:, 0]
return ret

def _log_map_impl(self) -> torch.Tensor:
ret = torch.zeros(self.shape[0], 3, dtype=self.dtype, device=self.device)
ret[:, 0] = 0.5 * (self[:, 2, 1] - self[:, 1, 2])
ret[:, 1] = 0.5 * (self[:, 0, 2] - self[:, 2, 0])
ret[:, 2] = 0.5 * (self[:, 1, 0] - self[:, 0, 1])
cth = 0.5 * (self[:, 0, 0] + self[:, 1, 1] + self[:, 2, 2] - 1)
sth = ret.norm(dim=1)
theta = torch.atan2(sth, cth)
# theta != pi
not_near_pi = 1 + cth > 1e-7
# Compute the approximation of theta / sin(theta) when theta is near zero
small_theta = theta[not_near_pi] < 5e-3
non_zero = torch.ones(1, dtype=self.dtype, device=self.device)
sth_nz = torch.where(small_theta, non_zero, sth[not_near_pi])
scale = torch.where(
small_theta, 1 + sth[not_near_pi] ** 2 / 6, theta[not_near_pi] / sth_nz
)
ret[not_near_pi] *= scale.view(-1, 1)
# theta ~ pi
near_pi = ~not_near_pi
ddiag = torch.diagonal(self[near_pi], dim1=1, dim2=2)
# Find the index of the major column and diagonal
major = torch.logical_and(
ddiag[:, 1] > ddiag[:, 0], ddiag[:, 1] > ddiag[:, 2]
) + 2 * torch.logical_and(ddiag[:, 2] > ddiag[:, 0], ddiag[:, 2] > ddiag[:, 1])
ret[near_pi] = self[near_pi, major]
ret[near_pi, major] -= cth[near_pi]
ret[near_pi] *= (theta[near_pi] ** 2 / (1 - cth[near_pi])).view(-1, 1)
ret[near_pi] /= ret[near_pi, major].sqrt().view(-1, 1)
return ret

def _compose_impl(self, so3_2: LieGroup) -> "SO3":
raise NotImplementedError

def _inverse_impl(self, get_jacobian: bool = False) -> "SO3":
return SO3(data=self.data.transpose(1, 2).clone())

def to_matrix(self) -> torch.Tensor:
return self.data.clone()

def to_quaternion(self) -> torch.Tensor:
raise NotImplementedError

@staticmethod
def hat(tangent_vector: torch.Tensor) -> torch.Tensor:
_check = tangent_vector.ndim == 3 and tangent_vector.shape[1:] == (3, 1)
_check |= tangent_vector.ndim == 2 and tangent_vector.shape[1] == 3
if not _check:
raise ValueError("Invalid vee matrix for SO3.")
matrix = torch.zeros(tangent_vector.shape[0], 3, 3).to(
dtype=tangent_vector.dtype, device=tangent_vector.device
)
matrix[:, 0, 1] = -tangent_vector[:, 2].view(-1)
matrix[:, 0, 2] = tangent_vector[:, 1].view(-1)
matrix[:, 1, 2] = -tangent_vector[:, 0].view(-1)
matrix[:, 1, 0] = tangent_vector[:, 2].view(-1)
matrix[:, 2, 0] = -tangent_vector[:, 1].view(-1)
matrix[:, 2, 1] = tangent_vector[:, 0].view(-1)
return matrix

@staticmethod
def vee(matrix: torch.Tensor) -> torch.Tensor:
_check = matrix.ndim == 3 and matrix.shape[1:] == (3, 3)
_check &= (
matrix.transpose(1, 2) + matrix
).abs().max().item() < theseus.constants.EPS
if not _check:
raise ValueError("Invalid hat matrix for SO3.")
return torch.stack((matrix[:, 2, 1], matrix[:, 0, 2], matrix[:, 1, 0]), dim=1)

def _rotate_shape_check(self, point: Union[Point3, torch.Tensor]):
err_msg = "SO3 can only rotate 3-D vectors."
if isinstance(point, torch.Tensor):
if not point.ndim == 2 or point.shape[1] != 3:
raise ValueError(err_msg)
elif point.dof() != 3:
raise ValueError(err_msg)
if (
point.shape[0] != self.shape[0]
and point.shape[0] != 1
and self.shape[0] != 1
):
raise ValueError(
"Input point batch size is not broadcastable with group batch size."
)

@staticmethod
def unit_quaternion_to_matrix(quaternion: torch.Tensor):
SO3._unit_quaternion_check(quaternion)
q0 = quaternion[:, 0]
q1 = quaternion[:, 1]
q2 = quaternion[:, 2]
q3 = quaternion[:, 3]
q00 = q0 * q0
q01 = q0 * q1
q02 = q0 * q2
q03 = q0 * q3
q11 = q1 * q1
q12 = q1 * q2
q13 = q1 * q3
q22 = q2 * q2
q23 = q2 * q3
q33 = q3 * q3
ret = torch.zeros(quaternion.shape[0], 3, 3).to(
dtype=quaternion.dtype, device=quaternion.device
)
ret[:, 0, 0] = 2 * (q00 + q11) - 1
ret[:, 0, 1] = 2 * (q12 - q03)
ret[:, 0, 2] = 2 * (q13 + q02)
ret[:, 1, 0] = 2 * (q12 + q03)
ret[:, 1, 1] = 2 * (q00 + q22) - 1
ret[:, 1, 2] = 2 * (q23 - q01)
ret[:, 2, 0] = 2 * (q13 - q02)
ret[:, 2, 1] = 2 * (q23 + q01)
ret[:, 2, 2] = 2 * (q00 + q33) - 1
return ret

def _copy_impl(self, new_name: Optional[str] = None) -> "SO3":
return SO3(data=self.data.clone(), name=new_name)

# only added to avoid casting downstream
def copy(self, new_name: Optional[str] = None) -> "SO3":
return cast(SO3, super().copy(new_name=new_name))
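
Away from zero, exp_map above assembles the rotation as a * I + b * hat(v) + c * v v^T with a = cos(theta), b = sin(theta) / theta, and c = (1 - cos(theta)) / theta^2 (the Rodrigues formula), switching to small-angle approximations below theta = 0.005, while _log_map_impl rescales the skew-symmetric part by theta / sin(theta) and treats theta near pi as a separate branch. A cross-check sketch, not part of the PR, which assumes torch.matrix_exp is available and that LieGroup exposes a public log_map() wrapping _log_map_impl (as the SE2 code above relies on):

import torch
from theseus.geometry import SO3

v = torch.randn(4, 3, dtype=torch.float64)

# The closed-form exponential should match the generic matrix exponential of hat(v).
closed_form = SO3.exp_map(v).to_matrix()
generic = torch.matrix_exp(SO3.hat(v))
print((closed_form - generic).abs().max().item())  # round-off level

# For rotation angles strictly below pi, log_map should invert exp_map.
v_below_pi = 3.0 * v / v.norm(dim=1, keepdim=True)  # fixed angle of 3.0 rad < pi
recovered = SO3.exp_map(v_below_pi).log_map()
print((recovered - v_below_pi).abs().max().item())  # small residual expected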
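
Two more hedged sketches, not part of the PR: vee should invert hat (and hat should produce skew-symmetric matrices), and unit_quaternion_to_matrix appears to use the scalar-first (w, x, y, z) ordering implied by the q0 * q0 terms on the matrix diagonal, so a normalized random quaternion should yield a matrix that passes _SO3_matrix_check:

import torch
from theseus.geometry import SO3

# hat / vee consistency
v = torch.randn(5, 3, dtype=torch.float64)
skew = SO3.hat(v)
assert torch.allclose(skew.transpose(1, 2), -skew)  # hat(v) is skew-symmetric
assert torch.allclose(SO3.vee(skew), v)             # vee inverts hat

# quaternion -> rotation matrix
q = torch.randn(2, 4, dtype=torch.float64)
q = q / q.norm(dim=1, keepdim=True)                 # normalize to a unit quaternion
R = SO3(quaternion=q).to_matrix()
eye = torch.eye(3, dtype=torch.float64)
print((R @ R.transpose(1, 2) - eye).abs().max().item())  # orthogonality residual
print(torch.linalg.det(R))                                # ~1.0 per batch element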