diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bbfe66f4..bc1c01c7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ New Features * Added expmap implementation (#43) * Added dist, logmap implementation (#44) * Added Poincare Ball model (#45) +* Poincare Ball manifold has now new methods (#78) Maintenance ----------- diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py index 36476bfc..2700d6bb 100644 --- a/geoopt/manifolds/poincare/__init__.py +++ b/geoopt/manifolds/poincare/__init__.py @@ -19,6 +19,7 @@ """ +# noinspection PyMethodOverriding class PoincareBall(Manifold): __doc__ = r"""{} @@ -57,59 +58,162 @@ def _check_point_on_manifold(self, x, *, atol=1e-5, rtol=1e-5): def _check_vector_on_tangent(self, x, u, *, atol=1e-5, rtol=1e-5): return True, None - def dist(self, x, y, *, keepdim=False): - return math.dist(x, y, c=self.c, keepdim=keepdim) + def dist(self, x, y, *, keepdim=False, dim=-1): + return math.dist(x, y, c=self.c, keepdim=keepdim, dim=dim) - def egrad2rgrad(self, x, u): - return math.egrad2rgrad(x, u, c=self.c) + def egrad2rgrad(self, x, u, *, dim=-1): + return math.egrad2rgrad(x, u, c=self.c, dim=dim) - def retr(self, x, u): + def retr(self, x, u, *, dim=-1): # always assume u is scaled properly approx = x + u - return math.project(approx, c=self.c) + return math.project(approx, c=self.c, dim=dim) - def projx(self, x): - return math.project(x, c=self.c) + def projx(self, x, dim=-1): + return math.project(x, c=self.c, dim=dim) def proju(self, x, u): return u - def inner(self, x, u, v=None, *, keepdim=False): + def inner(self, x, u, v=None, *, keepdim=False, dim=-1): if v is None: v = u - return math.inner(x, u, v, c=self.c, keepdim=keepdim) + return math.inner(x, u, v, c=self.c, keepdim=keepdim, dim=dim) - def expmap(self, x, u): - return math.project(math.expmap(x, u, c=self.c), c=self.c) + def norm(self, x, u, *, keepdim=False, dim=-1): + return math.norm(x, u, keepdim=keepdim, dim=dim) - def logmap(self, x, y): - return math.logmap(x, y, c=self.c) + def expmap(self, x, u, *, project=True, dim=-1): + res = math.expmap(x, u, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def logmap(self, x, y, *, dim=-1): + return math.logmap(x, y, c=self.c, dim=dim) - def transp(self, x, y, v, *more): + def transp(self, x, y, v, *more, dim=-1): if not more: - return math.parallel_transport(x, y, v, c=self.c) + return math.parallel_transport(x, y, v, c=self.c, dim=dim) else: - vecs = torch.stack((v,) + more, dim=0) - transp = math.parallel_transport(x, y, vecs, c=self.c) - return transp.unbind(0) + return tuple( + math.parallel_transport(x, y, vec, c=self.c, dim=dim) + for vec in (v, *more) + ) + + def transp_follow_retr(self, x, u, v, *more, dim=-1): + y = self.retr(x, u, dim=dim) + return self.transp(x, y, v, *more, dim=dim) + + def transp_follow_expmap(self, x, u, v, *more, dim=-1, project=True): + y = self.expmap(x, u, dim=dim, project=project) + return self.transp(x, y, v, *more, dim=dim) + + def expmap_transp(self, x, u, v, *more, dim=-1, project=True): + y = self.expmap(x, u, dim=dim, project=project) + vs = self.transp(x, y, v, *more, dim=dim) + return (y,) + make_tuple(vs) - def transp_follow_retr(self, x, u, v, *more): - y = self.retr(x, u) - return self.transp(x, y, v, *more) + def retr_transp(self, x, u, v, *more, dim=-1): + y = self.retr(x, u, dim=dim) + vs = self.transp(x, y, v, *more, dim=dim) + return (y,) + make_tuple(vs) - def transp_follow_expmap(self, x, u, v, *more): - y = self.expmap(x, u) - return self.transp(x, y, v, *more) + def mobius_add(self, x, y, *, dim=-1, project=True): + res = math.mobius_add(x, y, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res - def expmap_transp(self, x, u, v, *more): - y = self.expmap(x, u) - vs = self.transp(x, y, v, *more) - return (y,) + make_tuple(vs) + def mobius_sub(self, x, y, *, dim=-1, project=True): + res = math.mobius_sub(x, y, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res - def retr_transp(self, x, u, v, *more): - y = self.retr(x, u) - vs = self.transp(x, y, v, *more) - return (y,) + make_tuple(vs) + def mobius_coadd(self, x, y, *, dim=-1, project=True): + res = math.mobius_coadd(x, y, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_cosub(self, x, y, *, dim=-1, project=True): + res = math.mobius_coadd(x, y, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_scalar_mul(self, r, x, *, dim=-1, project=True): + res = math.mobius_scalar_mul(r, x, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_pointwise_mul(self, w, x, *, dim=-1, project=True): + res = math.mobius_pointwise_mul(w, x, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def mobius_matvec(self, m, x, *, dim=-1, project=True): + res = math.mobius_matvec(m, x, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def geodesic(self, t, x, y, *, dim=-1): + return math.geodesic(t, x, y, c=self.c, dim=dim) + + def geodesic_unit(self, t, x, u, *, dim=-1, project=True): + res = math.geodesic_unit(t, x, u, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def lambda_x(self, x, *, dim=-1, keepdim=False): + return math.lambda_x(x, c=self.c, dim=dim, keepdim=keepdim) + + def dist0(self, x, *, dim=-1, keepdim=False): + return math.dist0(x, c=self.c, dim=dim, keepdim=keepdim) + + def expmap0(self, u, *, dim=-1, project=True): + res = math.expmap0(u, c=self.c, dim=dim) + if project: + return math.project(res, c=self.c, dim=dim) + else: + return res + + def logmap0(self, x, *, dim=-1): + return math.logmap0(x, c=self.c, dim=dim) + + def transp0(self, y, u, *, dim=-1): + return math.parallel_transport0(y, u, c=self.c, dim=dim) + + def transp0back(self, y, u, *, dim=-1): + return math.parallel_transport0back(y, u, c=self.c, dim=dim) + + def gyration(self, x, y, z, *, dim=-1): + return math.gyration(x, y, z, c=self.c, dim=dim) + + def dist2plane(self, x, p, a, *, dim=-1, keepdim=False, signed=False): + return math.dist2plane( + x, p, a, dim=dim, c=self.c, keepdim=keepdim, signed=signed + ) + + def mobius_fn_apply(self, fn, x, *args, dim=-1, **kwargs): + return math.mobius_fn_apply(fn, x, *args, c=self.c, dim=dim, **kwargs) + + def mobius_fn_apply_chain(self, x, *fns, dim=-1): + return math.mobius_fn_apply_chain(x, *fns, c=self.c, dim=dim) class PoincareBallExact(PoincareBall): diff --git a/geoopt/manifolds/poincare/math.py b/geoopt/manifolds/poincare/math.py index 25d6ac0c..4e650529 100644 --- a/geoopt/manifolds/poincare/math.py +++ b/geoopt/manifolds/poincare/math.py @@ -1273,6 +1273,33 @@ def _parallel_transport0(y, v, c, dim: int = -1): return v * (1 - c * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM) +def parallel_transport0back(x, v, *, c=1.0, dim: int = -1): + r""" + Special case parallel transport with last point at zero that + can be computed more efficiently and numerically stable + + Parameters + ---------- + x : tensor + target point + v : tensor + vector to be transported + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + """ + return _parallel_transport0back(x, v, c=c, dim=dim) + + +def _parallel_transport0back(x, v, c, dim: int = -1): + return v / (1 - c * x.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM) + + def egrad2rgrad(x, grad, *, c=1.0, dim=-1): r""" Translate Euclidean gradient to Riemannian gradient on tangent space of :math:`x`