Skip to content

Commit

Permalink
Unified requires_grad usage for all random functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
luisenp committed Jan 26, 2023
1 parent 276d11f commit 1634a0f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
27 changes: 14 additions & 13 deletions theseus/labs/lie_functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional, Sequence
from typing import Callable, Optional, Protocol, Sequence

import torch

Expand All @@ -12,16 +12,17 @@
from .lie_group import BinaryOperatorFactory, UnaryOperatorFactory

_CheckFnType = Callable[[torch.Tensor], None]
_RadnFnType = Callable[
[
Sequence[int],
Optional[torch.Generator],
Optional[torch.dtype],
DeviceType,
bool,
],
torch.Tensor,
]


class _RandFnType(Protocol):
def __call__(
*size: int,
generator: Optional[torch.Generator] = None,
dtype: Optional[torch.dtype] = None,
device: DeviceType = None,
requires_grad: bool = False,
) -> torch.Tensor:
pass


# Namespace to facilitate type-checking downstream
Expand Down Expand Up @@ -49,8 +50,8 @@ def __init__(self, module):
self.check_project_matrix: _CheckFnType = module.check_project_matrix
self.check_left_act_matrix: _CheckFnType = module.check_left_act_matrix
self.check_left_project_matrix: _CheckFnType = module.check_left_project_matrix
self.rand: _RadnFnType = module.rand
self.randn: _RadnFnType = module.randn
self.rand: _RandFnType = module.rand
self.randn: _RandFnType = module.randn


se3_fns = LieGroupFns(_se3_impl)
Expand Down
2 changes: 0 additions & 2 deletions theseus/labs/lie_functional/se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def randn(
generator=generator,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
translation = torch.randn(
size[0],
Expand All @@ -118,7 +117,6 @@ def randn(
generator=generator,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
ret = torch.cat((rotation, translation), dim=2)
ret.requires_grad_(requires_grad)
Expand Down
10 changes: 6 additions & 4 deletions theseus/labs/lie_functional/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def rand(
generator=generator,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
u1 = u[0]
u2, u3 = u[1:3] * 2 * constants.PI
Expand All @@ -114,7 +113,9 @@ def rand(
dim=1,
)
assert quaternion.shape == (size[0], 4)
return quaternion_to_rotation(quaternion)
ret = quaternion_to_rotation(quaternion)
ret.requires_grad_(requires_grad)
return ret


# -----------------------------------------------------------------------------
Expand All @@ -129,17 +130,18 @@ def randn(
) -> torch.Tensor:
if len(size) != 1:
raise ValueError("The size should be 1D.")
return exp(
ret = exp(
constants.PI
* torch.randn(
size[0],
3,
generator=generator,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
)
ret.requires_grad_(requires_grad)
return ret


# -----------------------------------------------------------------------------
Expand Down

0 comments on commit 1634a0f

Please sign in to comment.