Skip to content

Commit

Permalink
add rand and randn to geometry.functional.se3
Browse files Browse the repository at this point in the history
  • Loading branch information
fantaosha committed Jan 1, 2023
1 parent e4f91a9 commit f0d046c
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
64 changes: 63 additions & 1 deletion theseus/geometry/functional/se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from typing import cast, List, Tuple
from typing import cast, List, Tuple, Optional

from . import constants
from . import lie_group, so3
Expand Down Expand Up @@ -261,3 +261,65 @@ def backward(cls, ctx, grad_output):
_jexp_autograd_fn = _jexp_impl

exp, jexp = lie_group.UnaryOperatorFactory(_module, "exp")


# -----------------------------------------------------------------------------
# Rand
# -----------------------------------------------------------------------------
def rand(
*size: int,
generator: Optional[torch.Generator] = None,
dtype: Optional[torch.dtype] = None,
device: constants.DeviceType = None,
requires_grad: bool = False,
) -> torch.Tensor:
if len(size) != 1:
raise ValueError("The size should be 1D.")
rotation = so3.rand(
size[0],
generator=generator,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
translation = torch.rand(
size[0],
3,
1,
generator=generator,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
return torch.cat((rotation, translation), dim=2)


# -----------------------------------------------------------------------------
# Rand
# -----------------------------------------------------------------------------
def randn(
*size: int,
generator: Optional[torch.Generator] = None,
dtype: Optional[torch.dtype] = None,
device: constants.DeviceType = None,
requires_grad: bool = False,
) -> torch.Tensor:
if len(size) != 1:
raise ValueError("The size should be 1D.")
rotation = so3.randn(
size[0],
generator=generator,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
translation = torch.randn(
size[0],
3,
1,
generator=generator,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
return torch.cat((rotation, translation), dim=2)
2 changes: 0 additions & 2 deletions theseus/geometry/functional/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def check_left_project_matrix(matrix: torch.Tensor):
# -----------------------------------------------------------------------------
# Rand
# -----------------------------------------------------------------------------


def rand(
*size: int,
generator: Optional[torch.Generator] = None,
Expand Down

0 comments on commit f0d046c

Please sign in to comment.