Rename check functions #516

Merged · 2 commits · May 25, 2023
8 changes: 3 additions & 5 deletions theseus/labs/lie/functional/lie_group.py
@@ -307,13 +307,11 @@ def __init__(self, module):
) = UnaryOperatorFactory(module, "quaternion_to_rotation")
self.check_group_tensor: _CheckFnType = module.check_group_tensor
self.check_tangent_vector: _CheckFnType = module.check_tangent_vector
- self.check_hat_matrix: _CheckFnType = module.check_hat_matrix
+ self.check_hat_tensor: _CheckFnType = module.check_hat_tensor
if hasattr(module, "check_unit_quaternion"):
self.check_unit_quaternion: _CheckFnType = module.check_unit_quaternion
- if hasattr(module, "check_lift_matrix"):
- self.check_lift_matrix: _CheckFnType = module.check_lift_matrix
- if hasattr(module, "check_project_matrix"):
- self.check_project_matrix: _CheckFnType = module.check_project_matrix
+ self.check_lift_tensor: _CheckFnType = module.check_lift_tensor
+ self.check_project_tensor: _CheckFnType = module.check_project_tensor
self.check_left_act_tensor: _CheckFnType = module.check_left_act_tensor
self.check_left_project_tensor: _CheckFnType = module.check_left_project_tensor
self.rand: _RandFnType = module.rand
122 changes: 61 additions & 61 deletions theseus/labs/lie/functional/se3_impl.py
@@ -29,7 +29,7 @@ def _impl(t_: torch.Tensor):
checks_base(tensor, _impl)


- def check_matrix_tensor(tensor: torch.Tensor):
+ def check_group_shape(tensor: torch.Tensor):
if tensor.shape[-2:] != (3, 4):
raise ValueError(shape_err_msg("SE3 data tensors", "(..., 3, 4)", tensor.shape))

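Not part of the diff — a minimal usage sketch for the renamed shape check. The import path below mirrors the file path in this PR and is an assumption about how the module is exposed:

```python
# Hedged sketch: exercising check_group_shape (renamed from check_matrix_tensor).
import torch
from theseus.labs.lie.functional import se3_impl as SE3  # assumed import path

g = torch.eye(3, 4).unsqueeze(0)   # identity SE3 element as a (1, 3, 4) tensor
SE3.check_group_shape(g)           # passes: trailing shape is (3, 4)

try:
    SE3.check_group_shape(torch.zeros(1, 4, 4))  # wrong trailing shape
except ValueError as err:
    print(err)                     # shape error message for "SE3 data tensors"
```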
@@ -46,19 +46,19 @@ def check_tangent_vector(tangent_vector: torch.Tensor):
)


- def check_hat_matrix(matrix: torch.Tensor):
+ def check_hat_tensor(tensor: torch.Tensor):
def _impl(t_: torch.Tensor):
if t_[..., -1].abs().max() > constants._SE3_NEAR_ZERO_EPS[t_.dtype]:
- raise ValueError("The last row for hat matrices of SE3 must be zero")
+ raise ValueError("The last row for hat tensors of SE3 must be zero")

- SO3.check_hat_matrix(t_[..., :3, :3])
+ SO3.check_hat_tensor(t_[..., :3, :3])

- if matrix.shape[-2:] != (4, 4):
+ if tensor.shape[-2:] != (4, 4):
raise ValueError(
- shape_err_msg("Hat matrices of SE3", "(..., 4, 4)", matrix.shape)
+ shape_err_msg("Hat tensors of SE3", "(..., 4, 4)", tensor.shape)
)

- checks_base(matrix, _impl)
+ checks_base(tensor, _impl)


def check_transform_tensor(tensor: torch.Tensor):
@@ -72,31 +72,31 @@ def check_transform_tensor(tensor: torch.Tensor):
)


- def check_lift_matrix(matrix: torch.Tensor):
- if not matrix.shape[-1] == 6:
+ def check_lift_tensor(tensor: torch.Tensor):
+ if not tensor.shape[-1] == 6:
raise ValueError(
- shape_err_msg("Lifted matrices of SE3", "(..., 6)", matrix.shape)
+ shape_err_msg("Lifted tensors of SE3", "(..., 6)", tensor.shape)
)


- def check_project_matrix(matrix: torch.Tensor):
- if not matrix.shape[-2:] == (3, 4):
+ def check_project_tensor(tensor: torch.Tensor):
+ if not tensor.shape[-2:] == (3, 4):
raise ValueError(
- shape_err_msg("Projected tensors of SE3", "(..., 3, 4)", matrix.shape)
+ shape_err_msg("Projected tensors of SE3", "(..., 3, 4)", tensor.shape)
)


- def check_left_act_tensor(matrix: torch.Tensor):
- if matrix.shape[-2] != 3:
+ def check_left_act_tensor(tensor: torch.Tensor):
+ if tensor.shape[-2] != 3:
raise ValueError(
- shape_err_msg("Left acted tensors of SE3", "(..., 3, -1)", matrix.shape)
+ shape_err_msg("Left acted tensors of SE3", "(..., 3, -1)", tensor.shape)
)


- def check_left_project_tensor(matrix: torch.Tensor):
- if matrix.shape[-2:] != (3, 4):
+ def check_left_project_tensor(tensor: torch.Tensor):
+ if tensor.shape[-2:] != (3, 4):
raise ValueError(
- shape_err_msg("Left projected matrices of SE3", "(..., 3, 4)", matrix.shape)
+ shape_err_msg("Left projected matrices of SE3", "(..., 3, 4)", tensor.shape)
)


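Not part of the diff — a quick sketch of the trailing shapes each renamed check accepts, read off the code above (the import path is again an assumption):

```python
import torch
from theseus.labs.lie.functional import se3_impl as SE3  # assumed import path

# Each call below should pass silently; these checks only validate shapes.
SE3.check_lift_tensor(torch.randn(2, 6))             # expects (..., 6)
SE3.check_project_tensor(torch.randn(2, 3, 4))       # expects (..., 3, 4)
SE3.check_left_act_tensor(torch.randn(2, 3, 7))      # expects (..., 3, -1)
SE3.check_left_project_tensor(torch.randn(2, 3, 4))  # expects (..., 3, 4)
```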
@@ -598,11 +598,11 @@ def _hat_impl(tangent_vector: torch.Tensor) -> torch.Tensor:
check_tangent_vector(tangent_vector)
size = get_tangent_vector_size(tangent_vector)
tangent_vector = tangent_vector.view(*size, 6)
- matrix = tangent_vector.new_zeros(*size, 4, 4)
- matrix[..., :3, :3] = SO3._hat_impl(tangent_vector[..., 3:])
- matrix[..., :3, 3] = tangent_vector[..., :3]
+ tensor = tangent_vector.new_zeros(*size, 4, 4)
+ tensor[..., :3, :3] = SO3._hat_impl(tangent_vector[..., 3:])
+ tensor[..., :3, 3] = tangent_vector[..., :3]

- return matrix
+ return tensor


# NOTE: No jacobian is defined for the hat operator
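Not part of the diff — a plain-torch sketch of the layout that `_hat_impl` builds (SO3 hat of the rotational part in the 3x3 block, translation in the last column, bottom row zero), which is exactly what `check_hat_tensor` accepts:

```python
import torch

xi = torch.randn(2, 6)                 # batch of tangent vectors (v, w)
v, w = xi[..., :3], xi[..., 3:]
H = torch.zeros(2, 4, 4)
# skew-symmetric SO3 hat of w in the top-left 3x3 block
H[..., 0, 1], H[..., 0, 2] = -w[..., 2], w[..., 1]
H[..., 1, 0], H[..., 1, 2] = w[..., 2], -w[..., 0]
H[..., 2, 0], H[..., 2, 1] = -w[..., 1], w[..., 0]
H[..., :3, 3] = v                      # translation in the last column
# H[..., 3, :] stays zero, as check_hat_tensor requires
```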
@@ -639,16 +639,16 @@ def backward(cls, ctx, grad_output):
# -----------------------------------------------------------------------------
# Vee
# -----------------------------------------------------------------------------
- def _vee_impl(matrix: torch.Tensor) -> torch.Tensor:
- check_hat_matrix(matrix)
- size = matrix.shape[:-2]
- ret = matrix.new_zeros(*size, 6)
- ret[..., :3] = matrix[..., :3, 3]
+ def _vee_impl(tensor: torch.Tensor) -> torch.Tensor:
+ check_hat_tensor(tensor)
+ size = tensor.shape[:-2]
+ ret = tensor.new_zeros(*size, 6)
+ ret[..., :3] = tensor[..., :3, 3]
ret[..., 3:] = 0.5 * torch.stack(
(
- matrix[..., 2, 1] - matrix[..., 1, 2],
- matrix[..., 0, 2] - matrix[..., 2, 0],
- matrix[..., 1, 0] - matrix[..., 0, 1],
+ tensor[..., 2, 1] - tensor[..., 1, 2],
+ tensor[..., 0, 2] - tensor[..., 2, 0],
+ tensor[..., 1, 0] - tensor[..., 0, 1],
),
dim=-1,
)
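Not part of the diff — continuing the hat sketch above, the vee logic recovers the tangent vector exactly: the translation comes from the last column and the rotational part from half the antisymmetric differences:

```python
xi_rec = torch.zeros(2, 6)
xi_rec[..., :3] = H[..., :3, 3]        # translation column
xi_rec[..., 3:] = 0.5 * torch.stack(
    (H[..., 2, 1] - H[..., 1, 2],
     H[..., 0, 2] - H[..., 2, 0],
     H[..., 1, 0] - H[..., 0, 1]),
    dim=-1,
)
assert torch.allclose(xi_rec, xi)      # vee(hat(xi)) == xi
```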
@@ -793,11 +793,11 @@ def backward(cls, ctx, grad_output):
# -----------------------------------------------------------------------------
# Lift
# -----------------------------------------------------------------------------
- def _lift_impl(matrix: torch.Tensor) -> torch.Tensor:
- check_lift_matrix(matrix)
- ret = matrix.new_zeros(matrix.shape[:-1] + (3, 4))
- ret[..., :3] = SO3._lift_impl(matrix[..., 3:])
- ret[..., 3] = matrix[..., :3]
+ def _lift_impl(tensor: torch.Tensor) -> torch.Tensor:
+ check_lift_tensor(tensor)
+ ret = tensor.new_zeros(tensor.shape[:-1] + (3, 4))
+ ret[..., :3] = SO3._lift_impl(tensor[..., 3:])
+ ret[..., 3] = tensor[..., :3]

return ret

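Not part of the diff — a shape-only sketch of what lift produces per the code above: a `(..., 3, 4)` tensor whose last column holds the first three tangent components (the 3x3 block comes from `SO3._lift_impl`, which is outside this diff):

```python
import torch

x = torch.randn(5, 6)
ret = x.new_zeros(x.shape[:-1] + (3, 4))
ret[..., 3] = x[..., :3]   # translational part in the last column
print(ret.shape)           # torch.Size([5, 3, 4])
```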
@@ -808,9 +808,9 @@ def _lift_impl(matrix: torch.Tensor) -> torch.Tensor:

class Lift(lie_group.UnaryOperator):
@classmethod
- def _forward_impl(cls, matrix):
- matrix: torch.Tensor = cast(torch.Tensor, matrix)
- ret = _lift_impl(matrix)
+ def _forward_impl(cls, tensor):
+ tensor: torch.Tensor = cast(torch.Tensor, tensor)
+ ret = _lift_impl(tensor)
return ret

@classmethod
@@ -828,16 +828,16 @@ def backward(cls, ctx, grad_output):
# -----------------------------------------------------------------------------
# Project
# -----------------------------------------------------------------------------
- def _project_impl(matrix: torch.Tensor) -> torch.Tensor:
- check_project_matrix(matrix)
+ def _project_impl(tensor: torch.Tensor) -> torch.Tensor:
+ check_project_tensor(tensor)
return torch.stack(
(
- matrix[..., 0, 3],
- matrix[..., 1, 3],
- matrix[..., 2, 3],
- matrix[..., 2, 1] - matrix[..., 1, 2],
- matrix[..., 0, 2] - matrix[..., 2, 0],
- matrix[..., 1, 0] - matrix[..., 0, 1],
+ tensor[..., 0, 3],
+ tensor[..., 1, 3],
+ tensor[..., 2, 3],
+ tensor[..., 2, 1] - tensor[..., 1, 2],
+ tensor[..., 0, 2] - tensor[..., 2, 0],
+ tensor[..., 1, 0] - tensor[..., 0, 1],
),
dim=-1,
)
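Not part of the diff — a plain-torch mirror of `_project_impl`: it reads the translation column and the antisymmetric part of the 3x3 block of a `(..., 3, 4)` tensor into a `(..., 6)` result:

```python
import torch

M = torch.randn(2, 3, 4)
proj = torch.stack(
    (M[..., 0, 3], M[..., 1, 3], M[..., 2, 3],
     M[..., 2, 1] - M[..., 1, 2],
     M[..., 0, 2] - M[..., 2, 0],
     M[..., 1, 0] - M[..., 0, 1]),
    dim=-1,
)
print(proj.shape)  # torch.Size([2, 6])
```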
@@ -849,9 +849,9 @@ def _project_impl(matrix: torch.Tensor) -> torch.Tensor:

class Project(lie_group.UnaryOperator):
@classmethod
- def _forward_impl(cls, matrix):
- matrix: torch.Tensor = cast(torch.Tensor, matrix)
- ret = _project_impl(matrix)
+ def _forward_impl(cls, tensor):
+ tensor: torch.Tensor = cast(torch.Tensor, tensor)
+ ret = _project_impl(tensor)
return ret

@classmethod
@@ -981,20 +981,20 @@ def _left_project_autograd_fn(
# -----------------------------------------------------------------------------
# Normalize
# -----------------------------------------------------------------------------
- def _normalize_impl(matrix: torch.Tensor) -> torch.Tensor:
- check_matrix_tensor(matrix)
- rotation = SO3._normalize_impl_helper(matrix[..., :3])[0]
- translation = matrix[..., 3:]
+ def _normalize_impl(tensor: torch.Tensor) -> torch.Tensor:
+ check_group_shape(tensor)
+ rotation = SO3._normalize_impl_helper(tensor[..., :3])[0]
+ translation = tensor[..., 3:]
return torch.cat((rotation, translation), dim=-1)


class Normalize(lie_group.UnaryOperator):
@classmethod
- def _forward_impl(cls, matrix):
- check_matrix_tensor(matrix)
- matrix: torch.Tensor = matrix
- rotation, svd_info = SO3._normalize_impl_helper(matrix[..., :3])
- translation = matrix[..., 3:]
+ def _forward_impl(cls, tensor):
+ check_group_shape(tensor)
+ tensor: torch.Tensor = tensor
+ rotation, svd_info = SO3._normalize_impl_helper(tensor[..., :3])
+ translation = tensor[..., 3:]
output = torch.cat((rotation, translation), dim=-1)
return output, svd_info

@@ -1017,8 +1017,8 @@ def backward(cls, ctx, grad_output, _):
return grad_input, None


- def _normalize_autograd_fn(matrix: torch.Tensor):
- return Normalize.apply(matrix)[0]
+ def _normalize_autograd_fn(tensor: torch.Tensor):
+ return Normalize.apply(tensor)[0]


_jnormalize_autograd_fn = None
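Not part of the diff — a sketch of how the renamed normalize path behaves: it takes a `(..., 3, 4)` tensor and returns one whose 3x3 block is a proper rotation (the helper also returns `svd_info`, suggesting an SVD-based projection). The import path and direct use of `_normalize_autograd_fn` are assumptions:

```python
import torch
from theseus.labs.lie.functional import se3_impl as SE3  # assumed import path

noisy = torch.eye(3, 4).unsqueeze(0) + 0.1 * torch.randn(1, 3, 4)
clean = SE3._normalize_autograd_fn(noisy)
R = clean[..., :3]                                   # normalized 3x3 block
print(torch.allclose(R @ R.transpose(-2, -1), torch.eye(3), atol=1e-5))  # True
```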