From 0095f364b6ab72fdab24ddc5cad9da57cef815de Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 19 Jun 2019 01:33:43 +0300 Subject: [PATCH 1/7] add methods --- geoopt/manifolds/poincare/__init__.py | 167 +++++++++++++++++++++----- geoopt/manifolds/poincare/math.py | 27 +++++ geoopt/utils.py | 10 ++ 3 files changed, 173 insertions(+), 31 deletions(-) diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py index 36476bfc..6a8c42a9 100644 --- a/geoopt/manifolds/poincare/__init__.py +++ b/geoopt/manifolds/poincare/__init__.py @@ -1,6 +1,6 @@ import torch.nn from . import math -from ...utils import make_tuple +from ...utils import make_tuple, idx2sign from ..base import Manifold __all__ = ["PoincareBall", "PoincareBallExact"] @@ -19,6 +19,7 @@ """ +# noinspection PyMethodOverriding class PoincareBall(Manifold): __doc__ = r"""{} @@ -57,60 +58,164 @@ def _check_point_on_manifold(self, x, *, atol=1e-5, rtol=1e-5): def _check_vector_on_tangent(self, x, u, *, atol=1e-5, rtol=1e-5): return True, None - def dist(self, x, y, *, keepdim=False): - return math.dist(x, y, c=self.c, keepdim=keepdim) + def dist(self, x, y, *, keepdim=False, dim=-1): + return math.dist(x, y, c=self.c, keepdim=keepdim, dim=dim) - def egrad2rgrad(self, x, u): - return math.egrad2rgrad(x, u, c=self.c) + def egrad2rgrad(self, x, u, *, dim=-1): + return math.egrad2rgrad(x, u, c=self.c, dim=dim) - def retr(self, x, u): + def retr(self, x, u, *, dim=-1): # always assume u is scaled properly approx = x + u - return math.project(approx, c=self.c) + return math.project(approx, c=self.c, dim=dim) - def projx(self, x): - return math.project(x, c=self.c) + def projx(self, x, dim=-1): + return math.project(x, c=self.c, dim=dim) def proju(self, x, u): return u - def inner(self, x, u, v=None, *, keepdim=False): + def inner(self, x, u, v=None, *, keepdim=False, dim=-1): if v is None: v = u - return math.inner(x, u, v, c=self.c, keepdim=keepdim) + return math.inner(x, u, v, c=self.c, keepdim=keepdim, dim=dim) - def expmap(self, x, u): - return math.project(math.expmap(x, u, c=self.c), c=self.c) + def norm(self, x, u, *, keepdim=False, dim=-1): + return math.norm(x, u, keepdim=keepdim, dim=dim) - def logmap(self, x, y): - return math.logmap(x, y, c=self.c) + def expmap(self, x, u, *, project=True, dim=-1): + res = math.expmap(x, u, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def logmap(self, x, y, *, dim=-1): + return math.logmap(x, y, c=self.c, dim=dim) - def transp(self, x, y, v, *more): + def transp(self, x, y, v, *more, dim=-1): if not more: - return math.parallel_transport(x, y, v, c=self.c) + return math.parallel_transport(x, y, v, c=self.c, dim=dim) else: - vecs = torch.stack((v,) + more, dim=0) - transp = math.parallel_transport(x, y, vecs, c=self.c) + try: + vecs = torch.stack((v,) + more, dim=0) + except RuntimeError: + return tuple( + math.parallel_transport(x, y, vec, c=self.c, dim=dim) + for vec in (v, *more) + ) + transp = math.parallel_transport( + x, y, vecs, c=self.c, dim=idx2sign(dim, x.dim()) + ) return transp.unbind(0) - def transp_follow_retr(self, x, u, v, *more): - y = self.retr(x, u) - return self.transp(x, y, v, *more) + def transp_follow_retr(self, x, u, v, *more, dim=-1): + y = self.retr(x, u, dim=dim) + return self.transp(x, y, v, *more, dim=dim) - def transp_follow_expmap(self, x, u, v, *more): - y = self.expmap(x, u) - return self.transp(x, y, v, *more) + def transp_follow_expmap(self, x, u, v, *more, dim=-1, project=True): + y = self.expmap(x, u, dim=dim, project=project) + return self.transp(x, y, v, *more, dim=dim) - def expmap_transp(self, x, u, v, *more): - y = self.expmap(x, u) - vs = self.transp(x, y, v, *more) + def expmap_transp(self, x, u, v, *more, dim=-1, project=True): + y = self.expmap(x, u, dim=dim, project=project) + vs = self.transp(x, y, v, *more, dim=dim) return (y,) + make_tuple(vs) - def retr_transp(self, x, u, v, *more): - y = self.retr(x, u) - vs = self.transp(x, y, v, *more) + def retr_transp(self, x, u, v, *more, dim=-1): + y = self.retr(x, u, dim=dim) + vs = self.transp(x, y, v, *more, dim=dim) return (y,) + make_tuple(vs) + def mobius_add(self, x, y, *, dim=-1, project=True): + res = math.mobius_add(x, y, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_sub(self, x, y, *, dim=-1, project=True): + res = math.mobius_sub(x, y, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_coadd(self, x, y, *, dim=-1, project=True): + res = math.mobius_coadd(x, y, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_cosub(self, x, y, *, dim=-1, project=True): + res = math.mobius_coadd(x, y, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_scalar_mul(self, r, x, *, dim=-1, project=True): + res = math.mobius_scalar_mul(r, x, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_pointwise_mul(self, w, x, *, dim=-1, project=True): + res = math.mobius_pointwise_mul(w, x, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_matvec(self, m, x, *, dim=-1, project=True): + res = math.mobius_matvec(m, x, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def geodesic(self, t, x, y, *, dim=-1): + return math.geodesic(t, x, y, c=self.c, dim=dim) + + def geodesic_unit(self, t, x, u, *, dim=-1, project=True): + res = math.geodesic_unit(t, x, u, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def lambda_x(self, x, *, dim=-1, keepdim=False): + return math.lambda_x(x, c=self.c, dim=dim, keepdim=keepdim) + + def dist0(self, x, *, dim=-1, keepdim=False): + return math.dist0(x, c=self.c, dim=dim, keepdim=keepdim) + + def expmap0(self, u, *, dim=-1, project=True): + res = math.expmap0(u, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def logmap0(self, x, *, dim=-1): + return math.logmap0(x, c=self.c, dim=dim) + + def transp0(self, y, u, *, dim=-1): + return math.parallel_transport0(y, u, c=self.c, dim=dim) + + def transp0back(self, y, u, *, dim=-1): + return math.parallel_transport0back(y, u, c=self.c, dim=dim) + + def gyration(self, x, y, z, *, dim=-1): + return math.gyration(x, y, z, c=self.c, dim=dim) + + def dist2plane(self, x, p, a, *, dim=-1, keepdim=False, signed=False): + return math.dist2plane( + x, p, a, dim=dim, c=self.c, keepdim=keepdim, signed=signed + ) + class PoincareBallExact(PoincareBall): __doc__ = r"""{} diff --git a/geoopt/manifolds/poincare/math.py b/geoopt/manifolds/poincare/math.py index 25d6ac0c..4e650529 100644 --- a/geoopt/manifolds/poincare/math.py +++ b/geoopt/manifolds/poincare/math.py @@ -1273,6 +1273,33 @@ def _parallel_transport0(y, v, c, dim: int = -1): return v * (1 - c * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM) +def parallel_transport0back(x, v, *, c=1.0, dim: int = -1): + r""" + Special case parallel transport with last point at zero that + can be computed more efficiently and numerically stable + + Parameters + ---------- + x : tensor + target point + v : tensor + vector to be transported + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + """ + return _parallel_transport0back(x, v, c=c, dim=dim) + + +def _parallel_transport0back(x, v, c, dim: int = -1): + return v / (1 - c * x.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM) + + def egrad2rgrad(x, grad, *, c=1.0, dim=-1): r""" Translate Euclidean gradient to Riemannian gradient on tangent space of :math:`x` diff --git a/geoopt/utils.py b/geoopt/utils.py index 93307cb0..1336efed 100644 --- a/geoopt/utils.py +++ b/geoopt/utils.py @@ -36,3 +36,13 @@ def make_tuple(obj): return (obj,) else: return obj + + +def idx2sign(idx, dim, neg=True): + if neg: + if idx < 0: + return idx + else: + return (idx + 1) % -(dim + 1) + else: + return idx % dim From 289c705eb88383e0adefc0a1130fb20be5f3a3f9 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 19 Jun 2019 11:24:52 +0300 Subject: [PATCH 2/7] make parallel transport more restrictive --- geoopt/manifolds/poincare/__init__.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py index 6a8c42a9..895fd049 100644 --- a/geoopt/manifolds/poincare/__init__.py +++ b/geoopt/manifolds/poincare/__init__.py @@ -97,13 +97,7 @@ def transp(self, x, y, v, *more, dim=-1): if not more: return math.parallel_transport(x, y, v, c=self.c, dim=dim) else: - try: - vecs = torch.stack((v,) + more, dim=0) - except RuntimeError: - return tuple( - math.parallel_transport(x, y, vec, c=self.c, dim=dim) - for vec in (v, *more) - ) + vecs = torch.stack((v,) + more, dim=0) transp = math.parallel_transport( x, y, vecs, c=self.c, dim=idx2sign(dim, x.dim()) ) From 7b09e05d7429f1d0865c4e2b7c41b3fa208e18bf Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 19 Jun 2019 11:58:16 +0300 Subject: [PATCH 3/7] just use tuple --- geoopt/manifolds/poincare/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py index 895fd049..c0e6a145 100644 --- a/geoopt/manifolds/poincare/__init__.py +++ b/geoopt/manifolds/poincare/__init__.py @@ -97,11 +97,10 @@ def transp(self, x, y, v, *more, dim=-1): if not more: return math.parallel_transport(x, y, v, c=self.c, dim=dim) else: - vecs = torch.stack((v,) + more, dim=0) - transp = math.parallel_transport( - x, y, vecs, c=self.c, dim=idx2sign(dim, x.dim()) + return tuple( + math.parallel_transport(x, y, vec, c=self.c, dim=dim) + for vec in (v, *more) ) - return transp.unbind(0) def transp_follow_retr(self, x, u, v, *more, dim=-1): y = self.retr(x, u, dim=dim) From 3454302361ab7d80ed1e982dec513ae3dd713dbd Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 19 Jun 2019 12:29:20 +0300 Subject: [PATCH 4/7] remove unused fn --- geoopt/manifolds/poincare/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py index c0e6a145..8edda173 100644 --- a/geoopt/manifolds/poincare/__init__.py +++ b/geoopt/manifolds/poincare/__init__.py @@ -1,6 +1,6 @@ import torch.nn from . import math -from ...utils import make_tuple, idx2sign +from ...utils import make_tuple from ..base import Manifold __all__ = ["PoincareBall", "PoincareBallExact"] From 78e224d80acb07931b025dbff87913af0800a195 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 19 Jun 2019 12:30:10 +0300 Subject: [PATCH 5/7] update changelog --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bbfe66f4..bc1c01c7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ New Features * Added expmap implementation (#43) * Added dist, logmap implementation (#44) * Added Poincare Ball model (#45) +* Poincare Ball manifold has now new methods (#78) Maintenance ----------- From 62d5bd41254a10951902c3385547cf4196adad70 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 19 Jun 2019 13:42:45 +0300 Subject: [PATCH 6/7] remove unused --- geoopt/utils.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/geoopt/utils.py b/geoopt/utils.py index 1336efed..93307cb0 100644 --- a/geoopt/utils.py +++ b/geoopt/utils.py @@ -36,13 +36,3 @@ def make_tuple(obj): return (obj,) else: return obj - - -def idx2sign(idx, dim, neg=True): - if neg: - if idx < 0: - return idx - else: - return (idx + 1) % -(dim + 1) - else: - return idx % dim From a25b81dfcadde66ac96d81f9f9313253975fa3d5 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 20 Jun 2019 17:09:12 +0300 Subject: [PATCH 7/7] add apply chain --- geoopt/manifolds/poincare/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py index 8edda173..2700d6bb 100644 --- a/geoopt/manifolds/poincare/__init__.py +++ b/geoopt/manifolds/poincare/__init__.py @@ -209,6 +209,12 @@ def dist2plane(self, x, p, a, *, dim=-1, keepdim=False, signed=False): x, p, a, dim=dim, c=self.c, keepdim=keepdim, signed=signed ) + def mobius_fn_apply(self, fn, x, *args, dim=-1, **kwargs): + return math.mobius_fn_apply(fn, x, *args, c=self.c, dim=dim, **kwargs) + + def mobius_fn_apply_chain(self, x, *fns, dim=-1): + return math.mobius_fn_apply_chain(x, *fns, c=self.c, dim=dim) + class PoincareBallExact(PoincareBall): __doc__ = r"""{}