From fdeab42dcf4cb1e4db8c80a6d27fad8229994472 Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Sat, 30 Nov 2019 16:26:38 +0300 Subject: [PATCH] scalig --- geoopt/manifolds/base.py | 4 +++- geoopt/manifolds/euclidean.py | 2 +- geoopt/manifolds/poincare/__init__.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/geoopt/manifolds/base.py b/geoopt/manifolds/base.py index f7200b47..29a40c13 100644 --- a/geoopt/manifolds/base.py +++ b/geoopt/manifolds/base.py @@ -81,9 +81,11 @@ def mobius_fn_apply( return res """ - def __call__(self, scaling_info: ScalingInfo): + def __call__(self, scaling_info: ScalingInfo, *aliases): def register(fn): self[fn.__name__] = scaling_info + for alias in aliases: + self[alias] = scaling_info return fn return register diff --git a/geoopt/manifolds/euclidean.py b/geoopt/manifolds/euclidean.py index 34471e43..a4dbb17f 100644 --- a/geoopt/manifolds/euclidean.py +++ b/geoopt/manifolds/euclidean.py @@ -107,7 +107,7 @@ def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Ten target_shape = broadcast_shapes(x.shape, y.shape, v.shape) return v.expand(target_shape) - @__scaling__(ScalingInfo(std=-1)) + @__scaling__(ScalingInfo(std=-1), "random") def random_normal( self, *size, mean=0.0, std=1.0, device=None, dtype=None ) -> "geoopt.ManifoldTensor": diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py index 531b6ea9..bf52c673 100644 --- a/geoopt/manifolds/poincare/__init__.py +++ b/geoopt/manifolds/poincare/__init__.py @@ -286,7 +286,7 @@ def mobius_fn_apply_chain( else: return res - @__scaling__(ScalingInfo(std=-1)) + @__scaling__(ScalingInfo(std=-1), "random") def random_normal( self, *size, mean=0, std=1, dtype=None, device=None ) -> "geoopt.ManifoldTensor":