Skip to content

Commit

Permalink
add origin() method
Browse files Browse the repository at this point in the history
  • Loading branch information
tao-harald committed Feb 2, 2021
1 parent 43e3e44 commit d6fea4d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
9 changes: 9 additions & 0 deletions geoopt/manifolds/symmetric_positive_definite.py
Expand Up @@ -271,3 +271,12 @@ def random(self, *size, dtype=None, device=None, **kwargs) -> torch.Tensor:
tens = batch_linalg.sym(tens)
tens = batch_linalg.sym_funcm(tens, torch.exp)
return tens

def origin(
self,
*size: Union[int, Tuple[int]],
dtype=None,
device=None,
seed: Optional[int] = 42
) -> torch.Tensor:
return torch.diag_embed(torch.ones(*size[:-1], dtype=dtype, device=device))
4 changes: 2 additions & 2 deletions tests/test_rsgd.py
Expand Up @@ -54,8 +54,8 @@ def test_rsgd_spd(params):
manifold = geoopt.manifolds.SymmetricPositiveDefinite(3)
torch.manual_seed(42)
with torch.no_grad():
X = geoopt.ManifoldParameter(manifold.random(2,2), manifold=manifold).proj_()
Xstar = manifold.random(2,2)
X = geoopt.ManifoldParameter(manifold.random(2, 2), manifold=manifold).proj_()
Xstar = manifold.random(2, 2)
# Xstar.set_(manifold.projx(Xstar))

def closure():
Expand Down

0 comments on commit d6fea4d

Please sign in to comment.