Skip to content

Commit

Permalink
scalig
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Nov 30, 2019
1 parent f1b0e69 commit fdeab42
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion geoopt/manifolds/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ def mobius_fn_apply(
return res
"""

def __call__(self, scaling_info: ScalingInfo):
def __call__(self, scaling_info: ScalingInfo, *aliases):
def register(fn):
self[fn.__name__] = scaling_info
for alias in aliases:
self[alias] = scaling_info
return fn

return register
Expand Down
2 changes: 1 addition & 1 deletion geoopt/manifolds/euclidean.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Ten
target_shape = broadcast_shapes(x.shape, y.shape, v.shape)
return v.expand(target_shape)

@__scaling__(ScalingInfo(std=-1))
@__scaling__(ScalingInfo(std=-1), "random")
def random_normal(
self, *size, mean=0.0, std=1.0, device=None, dtype=None
) -> "geoopt.ManifoldTensor":
Expand Down
2 changes: 1 addition & 1 deletion geoopt/manifolds/poincare/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def mobius_fn_apply_chain(
else:
return res

@__scaling__(ScalingInfo(std=-1))
@__scaling__(ScalingInfo(std=-1), "random")
def random_normal(
self, *size, mean=0, std=1, dtype=None, device=None
) -> "geoopt.ManifoldTensor":
Expand Down

0 comments on commit fdeab42

Please sign in to comment.