Some clean up of Lie Groups code #449

Merged
merged 5 commits into from
Feb 20, 2023
1 change: 1 addition & 0 deletions tests/labs/lie_functional/common.py
@@ -6,6 +6,7 @@
import torch


BATCH_SIZES_TO_TEST = [1, 20]
TEST_EPS = 5e-7


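For context, a minimal standalone sketch of how the shared constant is meant to be consumed by the lab's test modules (the test name and body here are hypothetical; the real usage appears in the test_se3.py and test_so3.py diffs below):

import pytest
import torch

from tests.labs.lie_functional.common import BATCH_SIZES_TO_TEST


@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_some_op(batch_size: int, dtype: torch.dtype):
    # Each (batch_size, dtype) pair from the shared lists becomes its own test case.
    x = torch.randn(batch_size, 3, dtype=dtype)
    assert x.shape == (batch_size, 3)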
6 changes: 3 additions & 3 deletions tests/labs/lie_functional/test_se3.py
@@ -8,7 +8,7 @@
import torch

from tests.decorators import run_if_labs
from .common import check_lie_group_function, run_test_op, TEST_EPS
from .common import BATCH_SIZES_TO_TEST, TEST_EPS, check_lie_group_function, run_test_op


@run_if_labs()
@@ -27,7 +27,7 @@
"left_project",
],
)
@pytest.mark.parametrize("batch_size", [1, 20, 100])
@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_op(op_name, batch_size, dtype):
import theseus.labs.lie_functional.se3 as se3
@@ -38,7 +38,7 @@ def test_op(op_name, batch_size, dtype):


@run_if_labs()
@pytest.mark.parametrize("batch_size", [1, 20, 100])
@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_vee(batch_size: int, dtype: torch.dtype):
import theseus.labs.lie_functional.se3 as se3
6 changes: 3 additions & 3 deletions tests/labs/lie_functional/test_so3.py
@@ -8,7 +8,7 @@
import torch

from tests.decorators import run_if_labs
from .common import TEST_EPS, check_lie_group_function, run_test_op
from .common import BATCH_SIZES_TO_TEST, TEST_EPS, check_lie_group_function, run_test_op


@run_if_labs()
@@ -28,7 +28,7 @@
"left_project",
],
)
@pytest.mark.parametrize("batch_size", [1, 20, 100])
@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_op(op_name, batch_size, dtype):
import theseus.labs.lie_functional.so3 as so3
@@ -39,7 +39,7 @@ def test_op(op_name, batch_size, dtype):


@run_if_labs()
@pytest.mark.parametrize("batch_size", [1, 20, 100])
@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_vee(batch_size: int, dtype: torch.dtype):
import theseus.labs.lie_functional.so3 as so3
54 changes: 54 additions & 0 deletions theseus/labs/lie_functional/__init__.py
@@ -2,3 +2,57 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional, Protocol, Sequence

import torch

import theseus.labs.lie_functional.se3 as _se3_impl
import theseus.labs.lie_functional.so3 as _so3_impl
from .constants import DeviceType
from .lie_group import BinaryOperatorFactory, UnaryOperatorFactory

_CheckFnType = Callable[[torch.Tensor], None]


class _RandFnType(Protocol):
def __call__(
*size: int,
generator: Optional[torch.Generator] = None,
dtype: Optional[torch.dtype] = None,
device: DeviceType = None,
requires_grad: bool = False,
) -> torch.Tensor:
pass


# Namespace to facilitate type-checking downstream
class LieGroupFns:
def __init__(self, module):
self.exp, self.jexp = UnaryOperatorFactory(module, "exp")
self.log, self.jlog = UnaryOperatorFactory(module, "log")
self.adj = UnaryOperatorFactory(module, "adjoint")
self.inv, self.jinv = UnaryOperatorFactory(module, "inverse")
self.hat = UnaryOperatorFactory(module, "hat")
self.vee = UnaryOperatorFactory(module, "vee")
self.compose, self.jcompose = BinaryOperatorFactory(module, "compose")
self.lift = UnaryOperatorFactory(module, "lift")
self.project = UnaryOperatorFactory(module, "project")
self.left_act = BinaryOperatorFactory(module, "left_act")
self.left_project = BinaryOperatorFactory(module, "left_project")
self.check_group_tensor: _CheckFnType = module.check_group_tensor
self.check_tangent_vector: _CheckFnType = module.check_tangent_vector
self.check_hat_matrix: _CheckFnType = module.check_hat_matrix
if hasattr(module, "check_unit_quaternion"):
self.check_unit_quaternion: _CheckFnType = module.check_unit_quaternion
if hasattr(module, "check_lift_matrix"):
self.check_lift_matrix: _CheckFnType = module.check_lift_matrix
if hasattr(module, "check_project_matrix"):
self.check_project_matrix: _CheckFnType = module.check_project_matrix
self.check_left_act_matrix: _CheckFnType = module.check_left_act_matrix
self.check_left_project_matrix: _CheckFnType = module.check_left_project_matrix
self.rand: _RandFnType = module.rand
self.randn: _RandFnType = module.randn


se3_fns = LieGroupFns(_se3_impl)
so3_fns = LieGroupFns(_so3_impl)
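A minimal usage sketch of the new LieGroupFns namespace (a hypothetical snippet, assuming the labs extension is importable and that SE3 group elements are batched 3x4 tensors, as elsewhere in theseus):

import torch

from theseus.labs.lie_functional import se3_fns

tangent = torch.randn(4, 6, dtype=torch.float64)     # batch of se(3) tangent vectors
jacobians: list = []
group = se3_fns.exp(tangent, jacobians=jacobians)    # jacobian of exp appended in place
group_inv = se3_fns.inv(group)
almost_identity = se3_fns.compose(group, group_inv)  # should be close to the identity
jacs, log_map = se3_fns.jlog(group)                  # jacobian-returning variant

This is the "namespace to facilitate type-checking downstream" mentioned in the diff: the typed factory return values and the _CheckFnType / _RandFnType annotations give se3_fns and so3_fns statically typed attributes instead of untyped module-level lookups.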
138 changes: 88 additions & 50 deletions theseus/labs/lie_functional/lie_group.py
@@ -6,7 +6,7 @@
import torch
import abc

from typing import List, Tuple, Optional
from typing import Callable, List, Tuple, Optional, Protocol
from .utils import check_jacobians_list

# There are four functions associated with each Lie group operator xxx.
@@ -28,8 +28,7 @@ def _jinverse_impl(group: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:

def LeftProjectImplFactory(module):
def _left_project_impl(group: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
if not module.check_group_tensor(group):
raise ValueError("Invalid data tensor for SO3.")
module.check_group_tensor(group)
module.check_left_project_matrix(matrix)
group_inverse = module.inverse(group)

@@ -45,33 +44,57 @@ def forward(cls, ctx, input):
pass


def UnaryOperatorFactory(module, op_name):
# Get autograd.Function wrapper of op and its jacobian
op_autograd_fn = getattr(module, "_" + op_name + "_autograd_fn")
jop_autograd_fn = getattr(module, "_j" + op_name + "_autograd_fn")
class UnaryOperatorOpFnType(Protocol):
def __call__(
self, input: torch.Tensor, jacobians: Optional[List[torch.Tensor]] = None
) -> torch.Tensor:
pass


if jop_autograd_fn is not None:
class UnaryOperatorJOpFnType(Protocol):
def __call__(self, input: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
pass


def _check_jacobians_supported(
jop_autograd_fn: Optional[Callable],
module_name: str,
op_name: str,
is_kwarg: bool = True,
):
if jop_autograd_fn is None:
if is_kwarg:
msg = f"Passing jacobians= is not supported by {module_name}.{op_name}"
else:
msg = f"{module_name}.j{op_name} is not implemented."
raise NotImplementedError(msg)

def op(
input: torch.Tensor,
jacobians: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
if jacobians is not None:
check_jacobians_list(jacobians)
jacobians_op = jop_autograd_fn(input)[0]
jacobians.append(jacobians_op[0])
return op_autograd_fn(input)

def jop(input: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
return jop_autograd_fn(input)
def UnaryOperatorFactory(
module, op_name
) -> Tuple[UnaryOperatorOpFnType, UnaryOperatorJOpFnType]:
# Get autograd.Function wrapper of op and its jacobian
op_autograd_fn = getattr(module, "_" + op_name + "_autograd_fn")
jop_autograd_fn = getattr(module, "_j" + op_name + "_autograd_fn")

return op, jop
else:
def op(
input: torch.Tensor,
jacobians: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
if jacobians is not None:
_check_jacobians_supported(jop_autograd_fn, module.name, op_name)
check_jacobians_list(jacobians)
jacobians_op = jop_autograd_fn(input)[0]
jacobians.append(jacobians_op[0])
return op_autograd_fn(input)

def op_no_jop(input: torch.Tensor) -> torch.Tensor:
return op_autograd_fn(input)
def jop(input: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
_check_jacobians_supported(
jop_autograd_fn, module.name, op_name, is_kwarg=False
)
return jop_autograd_fn(input)

return op_no_jop
return op, jop


class BinaryOperator(torch.autograd.Function):
@@ -81,34 +104,49 @@ def forward(cls, ctx, input0, input1):
pass


def BinaryOperatorFactory(module, op_name):
# Get autograd.Function wrapper of op and its jacobian
op_autograd_fn = getattr(module, "_" + op_name + "_autograd_fn")
jop_autograd_fn = getattr(module, "_j" + op_name + "_autograd_fn")

if jop_autograd_fn is not None:
class BinaryOperatorOpFnType(Protocol):
def __call__(
self,
input0: torch.Tensor,
input1: torch.Tensor,
jacobians: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
pass

def op(
input0: torch.Tensor,
input1: torch.Tensor,
jacobians: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
if jacobians is not None:
check_jacobians_list(jacobians)
jacobians_op = jop_autograd_fn(input0, input1)[0]
for jacobian in jacobians_op:
jacobians.append(jacobian)
return op_autograd_fn(input0, input1)

def jop(
input0: torch.Tensor, input1: torch.Tensor
) -> Tuple[List[torch.Tensor], torch.Tensor]:
return jop_autograd_fn(input0, input1)
class BinaryOperatorJOpFnType(Protocol):
def __call__(
self, input0: torch.Tensor, input1: torch.Tensor
) -> Tuple[List[torch.Tensor], torch.Tensor]:
pass

return op, jop
else:

def op_no_jop(input0: torch.Tensor, input1: torch.Tensor) -> torch.Tensor:
return op_autograd_fn(input0, input1)
def BinaryOperatorFactory(
module, op_name
) -> Tuple[BinaryOperatorOpFnType, BinaryOperatorJOpFnType]:
# Get autograd.Function wrapper of op and its jacobian
op_autograd_fn = getattr(module, "_" + op_name + "_autograd_fn")
jop_autograd_fn = getattr(module, "_j" + op_name + "_autograd_fn")

return op_no_jop
def op(
input0: torch.Tensor,
input1: torch.Tensor,
jacobians: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
if jacobians is not None:
_check_jacobians_supported(jop_autograd_fn, module.name, op_name)
check_jacobians_list(jacobians)
jacobians_op = jop_autograd_fn(input0, input1)[0]
for jacobian in jacobians_op:
jacobians.append(jacobian)
return op_autograd_fn(input0, input1)

def jop(
input0: torch.Tensor, input1: torch.Tensor
) -> Tuple[List[torch.Tensor], torch.Tensor]:
_check_jacobians_supported(
jop_autograd_fn, module.name, op_name, is_kwarg=False
)
return jop_autograd_fn(input0, input1)

return op, jop
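A standalone sketch of the new error path (the fake module and the op name "double" are invented for illustration, and the import assumes the labs layout shown in this PR): operators whose module lacks a jacobian implementation now raise a descriptive NotImplementedError, both when jacobians= is passed to op and when jop is called directly, instead of the factory returning a single callable.

import types

import torch

from theseus.labs.lie_functional.lie_group import UnaryOperatorFactory


def _double(x: torch.Tensor) -> torch.Tensor:
    return 2.0 * x


# Fake module following the _<op>_autograd_fn / _j<op>_autograd_fn naming convention,
# with the jacobian function deliberately left out (None).
fake_module = types.SimpleNamespace(
    name="fake_group",
    _double_autograd_fn=_double,
    _jdouble_autograd_fn=None,
)

op, jop = UnaryOperatorFactory(fake_module, "double")
op(torch.ones(3))                    # plain forward still works without jacobians
try:
    op(torch.ones(3), jacobians=[])  # "Passing jacobians= is not supported by fake_group.double"
except NotImplementedError as err:
    print(err)
try:
    jop(torch.ones(3))               # "fake_group.jdouble is not implemented."
except NotImplementedError as err:
    print(err)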