Skip to content

Commit

Permalink
Merge a25b81d into 89df0b6
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Jun 20, 2019
2 parents 89df0b6 + a25b81d commit 78d36ba
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -18,6 +18,7 @@ New Features
* Added expmap implementation (#43)
* Added dist, logmap implementation (#44)
* Added Poincare Ball model (#45)
* Poincare Ball manifold has now new methods (#78)

Maintenance
-----------
Expand Down
170 changes: 137 additions & 33 deletions geoopt/manifolds/poincare/__init__.py
Expand Up @@ -19,6 +19,7 @@
"""


# noinspection PyMethodOverriding
class PoincareBall(Manifold):
__doc__ = r"""{}
Expand Down Expand Up @@ -57,59 +58,162 @@ def _check_point_on_manifold(self, x, *, atol=1e-5, rtol=1e-5):
def _check_vector_on_tangent(self, x, u, *, atol=1e-5, rtol=1e-5):
return True, None

def dist(self, x, y, *, keepdim=False):
return math.dist(x, y, c=self.c, keepdim=keepdim)
def dist(self, x, y, *, keepdim=False, dim=-1):
return math.dist(x, y, c=self.c, keepdim=keepdim, dim=dim)

def egrad2rgrad(self, x, u):
return math.egrad2rgrad(x, u, c=self.c)
def egrad2rgrad(self, x, u, *, dim=-1):
return math.egrad2rgrad(x, u, c=self.c, dim=dim)

def retr(self, x, u):
def retr(self, x, u, *, dim=-1):
# always assume u is scaled properly
approx = x + u
return math.project(approx, c=self.c)
return math.project(approx, c=self.c, dim=dim)

def projx(self, x):
return math.project(x, c=self.c)
def projx(self, x, dim=-1):
return math.project(x, c=self.c, dim=dim)

def proju(self, x, u):
return u

def inner(self, x, u, v=None, *, keepdim=False):
def inner(self, x, u, v=None, *, keepdim=False, dim=-1):
if v is None:
v = u
return math.inner(x, u, v, c=self.c, keepdim=keepdim)
return math.inner(x, u, v, c=self.c, keepdim=keepdim, dim=dim)

def expmap(self, x, u):
return math.project(math.expmap(x, u, c=self.c), c=self.c)
def norm(self, x, u, *, keepdim=False, dim=-1):
return math.norm(x, u, keepdim=keepdim, dim=dim)

def logmap(self, x, y):
return math.logmap(x, y, c=self.c)
def expmap(self, x, u, *, project=True, dim=-1):
res = math.expmap(x, u, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def logmap(self, x, y, *, dim=-1):
return math.logmap(x, y, c=self.c, dim=dim)

def transp(self, x, y, v, *more):
def transp(self, x, y, v, *more, dim=-1):
if not more:
return math.parallel_transport(x, y, v, c=self.c)
return math.parallel_transport(x, y, v, c=self.c, dim=dim)
else:
vecs = torch.stack((v,) + more, dim=0)
transp = math.parallel_transport(x, y, vecs, c=self.c)
return transp.unbind(0)
return tuple(
math.parallel_transport(x, y, vec, c=self.c, dim=dim)
for vec in (v, *more)
)

def transp_follow_retr(self, x, u, v, *more, dim=-1):
y = self.retr(x, u, dim=dim)
return self.transp(x, y, v, *more, dim=dim)

def transp_follow_expmap(self, x, u, v, *more, dim=-1, project=True):
y = self.expmap(x, u, dim=dim, project=project)
return self.transp(x, y, v, *more, dim=dim)

def expmap_transp(self, x, u, v, *more, dim=-1, project=True):
y = self.expmap(x, u, dim=dim, project=project)
vs = self.transp(x, y, v, *more, dim=dim)
return (y,) + make_tuple(vs)

def transp_follow_retr(self, x, u, v, *more):
y = self.retr(x, u)
return self.transp(x, y, v, *more)
def retr_transp(self, x, u, v, *more, dim=-1):
y = self.retr(x, u, dim=dim)
vs = self.transp(x, y, v, *more, dim=dim)
return (y,) + make_tuple(vs)

def transp_follow_expmap(self, x, u, v, *more):
y = self.expmap(x, u)
return self.transp(x, y, v, *more)
def mobius_add(self, x, y, *, dim=-1, project=True):
res = math.mobius_add(x, y, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def expmap_transp(self, x, u, v, *more):
y = self.expmap(x, u)
vs = self.transp(x, y, v, *more)
return (y,) + make_tuple(vs)
def mobius_sub(self, x, y, *, dim=-1, project=True):
res = math.mobius_sub(x, y, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def retr_transp(self, x, u, v, *more):
y = self.retr(x, u)
vs = self.transp(x, y, v, *more)
return (y,) + make_tuple(vs)
def mobius_coadd(self, x, y, *, dim=-1, project=True):
res = math.mobius_coadd(x, y, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def mobius_cosub(self, x, y, *, dim=-1, project=True):
res = math.mobius_coadd(x, y, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def mobius_scalar_mul(self, r, x, *, dim=-1, project=True):
res = math.mobius_scalar_mul(r, x, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def mobius_pointwise_mul(self, w, x, *, dim=-1, project=True):
res = math.mobius_pointwise_mul(w, x, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def mobius_matvec(self, m, x, *, dim=-1, project=True):
res = math.mobius_matvec(m, x, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def geodesic(self, t, x, y, *, dim=-1):
return math.geodesic(t, x, y, c=self.c, dim=dim)

def geodesic_unit(self, t, x, u, *, dim=-1, project=True):
res = math.geodesic_unit(t, x, u, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def lambda_x(self, x, *, dim=-1, keepdim=False):
return math.lambda_x(x, c=self.c, dim=dim, keepdim=keepdim)

def dist0(self, x, *, dim=-1, keepdim=False):
return math.dist0(x, c=self.c, dim=dim, keepdim=keepdim)

def expmap0(self, u, *, dim=-1, project=True):
res = math.expmap0(u, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def logmap0(self, x, *, dim=-1):
return math.logmap0(x, c=self.c, dim=dim)

def transp0(self, y, u, *, dim=-1):
return math.parallel_transport0(y, u, c=self.c, dim=dim)

def transp0back(self, y, u, *, dim=-1):
return math.parallel_transport0back(y, u, c=self.c, dim=dim)

def gyration(self, x, y, z, *, dim=-1):
return math.gyration(x, y, z, c=self.c, dim=dim)

def dist2plane(self, x, p, a, *, dim=-1, keepdim=False, signed=False):
return math.dist2plane(
x, p, a, dim=dim, c=self.c, keepdim=keepdim, signed=signed
)

def mobius_fn_apply(self, fn, x, *args, dim=-1, **kwargs):
return math.mobius_fn_apply(fn, x, *args, c=self.c, dim=dim, **kwargs)

def mobius_fn_apply_chain(self, x, *fns, dim=-1):
return math.mobius_fn_apply_chain(x, *fns, c=self.c, dim=dim)


class PoincareBallExact(PoincareBall):
Expand Down
27 changes: 27 additions & 0 deletions geoopt/manifolds/poincare/math.py
Expand Up @@ -1273,6 +1273,33 @@ def _parallel_transport0(y, v, c, dim: int = -1):
return v * (1 - c * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM)


def parallel_transport0back(x, v, *, c=1.0, dim: int = -1):
r"""
Special case parallel transport with last point at zero that
can be computed more efficiently and numerically stable
Parameters
----------
x : tensor
target point
v : tensor
vector to be transported
c : float|tensor
ball negative curvature
dim : int
reduction dimension for operations
Returns
-------
tensor
"""
return _parallel_transport0back(x, v, c=c, dim=dim)


def _parallel_transport0back(x, v, c, dim: int = -1):
return v / (1 - c * x.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM)


def egrad2rgrad(x, grad, *, c=1.0, dim=-1):
r"""
Translate Euclidean gradient to Riemannian gradient on tangent space of :math:`x`
Expand Down

0 comments on commit 78d36ba

Please sign in to comment.