Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lie Labs API example #476

Merged
merged 4 commits into from
Apr 27, 2023
Merged

Lie Labs API example #476

merged 4 commits into from
Apr 27, 2023

Conversation

luisenp
Copy link
Contributor

@luisenp luisenp commented Mar 1, 2023

Seeking feedback

This PR has a small example illustrating the current features of our new Lie group tensor library for PyTorch. We are looking for feedback on:

  1. API: Is it easy to use? Any confusing points? Any source of friction?
  2. Features: Any features you'd like to see that are not in the planned features below?
  3. Bugs: Any bugs you noticed?

Current features in alpha

  • A LieTensor subclass of torch.Tensor serving as entry point.
  • SE3 and SO3 groups.
  • Implementation of several Lie groups operators, each with custom backward pass:
    • log and exp maps
    • adjoint, vee, hat operators
    • inverse, compose
    • transform_from (convert point from pose to world coordinates)
  • Differentiable jacobians to use in optimization libraries (the jacobians are computed in closed form using torch, but w/o a custom backward pass; we rely on torch autograd for this).
  • Seamless support for optimizing LieTensor parameters using torch optimizers.
  • Overloaded operators for compose (*) and transform_from (@).

Features planned for beta

  • Support for arbitrary number of batch dimensions and broadcasting (current version assumes a single batch dimension for everything).
  • Automatic normalization of operations result (to ensure group properties).
  • SO2, SE2, and Sim group support.
  • vmap support.
  • Storage as matrix or quaternion and operators for conversions between these representations (including axis angle).
  • cpp and CUDA backends if they lead to significant gains in efficiency.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 1, 2023
@luisenp luisenp changed the base branch from main to lep.hacky_lie_example March 1, 2023 22:01
Copy link
Contributor

@fantaosha fantaosha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments about the behaviors of overriding.

examples/lie_labs.py Outdated Show resolved Hide resolved
examples/lie_labs.py Outdated Show resolved Hide resolved
examples/lie_labs.py Outdated Show resolved Hide resolved
Copy link
Contributor

@exhaustin exhaustin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the clean interface, can't wait to use it!

Some features I would love to see:

  • Conversions to/from quaternions & matrices, even though they might be specific to SE3 and maybe out of scope, but that's a really really common usecase in robotics.
  • I find the Sophus API to be more intuitive in terms of accessing the matrix data, and accessing the individual rotation and translation components. Maybe we can take some inspiration from that?

examples/lie_labs.py Outdated Show resolved Hide resolved
examples/lie_labs.py Outdated Show resolved Hide resolved
examples/lie_labs.py Show resolved Hide resolved
examples/lie_labs.py Outdated Show resolved Hide resolved
examples/lie_labs.py Outdated Show resolved Hide resolved
@luisenp luisenp requested a review from ddetone March 6, 2023 16:02
Base automatically changed from lep.hacky_lie_example to main March 8, 2023 22:11
@luisenp luisenp force-pushed the lep.labs_api_example branch 3 times, most recently from 53f8c0c to f2849fa Compare March 9, 2023 21:25
@luisenp
Copy link
Contributor Author

luisenp commented Mar 9, 2023

Made a new version of the API based on the discussion on operator overload, and also the suggestions for things like SE3.rand and SE3.exp. Feel free to ignore the implementation details, and just look at the example script. If this looks good, I'll probably move everything to a separate PR, merge it, and then rebase this PR so that it has only the script.

Copy link
Contributor

@exhaustin exhaustin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, just my subjective opinion, but the new version looks great to me!

@luisenp luisenp force-pushed the lep.labs_api_example branch 3 times, most recently from 9d0dd10 to 9ac3ab3 Compare March 22, 2023 21:56
@rmurai0610
Copy link
Contributor

The api looks great!
I've tried using the api, and I've noticed few minor details:

import theseus.labs.lie as lie
import theseus as th

g_old = th.SE3.rand(5)
print(g_old[0])

g_new = lie.SE3.rand(5)
print(g_new[0])
tensor([[-0.5758,  0.2524, -0.7777,  0.5286],
        [ 0.1701, -0.8934, -0.4159,  0.8484],
        [-0.7997, -0.3718,  0.4715,  0.3887]])
Traceback (most recent call last):
  File "/Users/riku/phd/manim/test.py", line 8, in <module>
    print(g_new[0])
  File "/Users/riku/phd/manim/theseus/theseus/labs/lie/lie_tensor.py", line 202, in __torch_function__
    return cls._torch_function_impl_lie(func, types, args, kwargs)
  File "/Users/riku/phd/manim/theseus/theseus/labs/lie/lie_tensor.py", line 189, in _torch_function_impl_lie
    raise NotImplementedError(
NotImplementedError: Tried to call a torch function not supported by LieTensor. If trying to operate on the raw tensor data, please use group._t, or run inside the context lie.as_euclidean().

I think it would be useful if indexing was implemented.
The old implementation in theseus returns tensor, but is there a reason why we don't return a LieGroup?


import theseus.labs.lie as lie
import theseus as th

g1 = lie.SE3.rand(5)
print(type(g1.shape))
print(g1.shape)
print(type(g1._t.shape))
print(g1._t.shape)

g2 = th.SE3.rand(5)
print(type(g2.shape))
print(g2.shape)

<class 'tuple'>
(5, 3, 4)
<class 'torch.Size'>
torch.Size([5, 3, 4])
<class 'torch.Size'>
torch.Size([5, 3, 4])

Maybe it would be nice to keep the type consistent for typehint etc.

examples/lie_labs.py Outdated Show resolved Hide resolved
@luisenp luisenp mentioned this pull request Apr 5, 2023
@luisenp luisenp changed the base branch from main to lep.lie_labs_v3 April 6, 2023 16:54
Base automatically changed from lep.lie_labs_v3 to main April 21, 2023 20:14
Copy link
Member

@mhmukadam mhmukadam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with all the feedback now addressed in the other PR.

Thanks @luisenp and @fantaosha for developing everything on Lie groups alpha and thanks to everyone else (on the PR and offline) for the valuable feedback!

@luisenp luisenp merged commit 40aba98 into main Apr 27, 2023
@luisenp luisenp deleted the lep.labs_api_example branch April 27, 2023 15:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants