Skip to content

Commit

Permalink
fix random scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxim Kochurov committed Nov 18, 2019
1 parent 2b00320 commit f1b0e69
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
4 changes: 3 additions & 1 deletion geoopt/manifolds/euclidean.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union, Tuple, Optional
import torch
from .base import Manifold
from .base import Manifold, ScalingInfo
from ..utils import size2shape, broadcast_shapes
import geoopt

Expand All @@ -19,6 +19,7 @@ class Euclidean(Manifold):
as inner products, etc will respect the :attr:`ndim`.
"""

__scaling__ = Manifold.__scaling__.copy()
name = "Euclidean"
ndim = 0
reversible = True
Expand Down Expand Up @@ -106,6 +107,7 @@ def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Ten
target_shape = broadcast_shapes(x.shape, y.shape, v.shape)
return v.expand(target_shape)

@__scaling__(ScalingInfo(std=-1))
def random_normal(
self, *size, mean=0.0, std=1.0, device=None, dtype=None
) -> "geoopt.ManifoldTensor":
Expand Down
1 change: 1 addition & 0 deletions geoopt/manifolds/poincare/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def mobius_fn_apply_chain(
else:
return res

@__scaling__(ScalingInfo(std=-1))
def random_normal(
self, *size, mean=0, std=1, dtype=None, device=None
) -> "geoopt.ManifoldTensor":
Expand Down

0 comments on commit f1b0e69

Please sign in to comment.