Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poincare Manifold methods #78

Merged
merged 9 commits into from Jun 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -18,6 +18,7 @@ New Features
* Added expmap implementation (#43)
* Added dist, logmap implementation (#44)
* Added Poincare Ball model (#45)
* Poincare Ball manifold has now new methods (#78)

Maintenance
-----------
Expand Down
178 changes: 145 additions & 33 deletions geoopt/manifolds/poincare/__init__.py
Expand Up @@ -19,6 +19,7 @@
"""


# noinspection PyMethodOverriding
class PoincareBall(Manifold):
__doc__ = r"""{}

Expand Down Expand Up @@ -57,59 +58,170 @@ def _check_point_on_manifold(self, x, *, atol=1e-5, rtol=1e-5):
def _check_vector_on_tangent(self, x, u, *, atol=1e-5, rtol=1e-5):
return True, None

def dist(self, x, y, *, keepdim=False):
return math.dist(x, y, c=self.c, keepdim=keepdim)
def dist(self, x, y, *, keepdim=False, dim=-1):
return math.dist(x, y, c=self.c, keepdim=keepdim, dim=dim)

def egrad2rgrad(self, x, u):
return math.egrad2rgrad(x, u, c=self.c)
def egrad2rgrad(self, x, u, *, dim=-1):
return math.egrad2rgrad(x, u, c=self.c, dim=dim)

def retr(self, x, u):
def retr(self, x, u, *, dim=-1):
# always assume u is scaled properly
approx = x + u
return math.project(approx, c=self.c)
return math.project(approx, c=self.c, dim=dim)

def projx(self, x):
return math.project(x, c=self.c)
def projx(self, x, dim=-1):
return math.project(x, c=self.c, dim=dim)

def proju(self, x, u):
return u

def inner(self, x, u, v=None, *, keepdim=False):
def inner(self, x, u, v=None, *, keepdim=False, dim=-1):
if v is None:
v = u
return math.inner(x, u, v, c=self.c, keepdim=keepdim)
return math.inner(x, u, v, c=self.c, keepdim=keepdim, dim=dim)

def expmap(self, x, u):
return math.project(math.expmap(x, u, c=self.c), c=self.c)
def norm(self, x, u, *, keepdim=False, dim=-1):
return math.norm(x, u, keepdim=keepdim, dim=dim)

def logmap(self, x, y):
return math.logmap(x, y, c=self.c)
def expmap(self, x, u, *, project=True, dim=-1):
res = math.expmap(x, u, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def logmap(self, x, y, *, dim=-1):
return math.logmap(x, y, c=self.c, dim=dim)

def transp(self, x, y, v, *more):
def transp(self, x, y, v, *more, dim=-1):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have concerns whether this method should work with non-matching v, *more. The easy way is to do this in a loop and do not care.

Copy link
Member

@rrkarim rrkarim Jun 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe check and raise the warning? (still need to loop tho)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point in *more anyway and where we will use it? I have completely missed that part.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A flag in parameters may save nerves and add more clarity in what's going on

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like stack=True

Copy link
Member Author

@ferrine ferrine Jun 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point is to be able to transport multiple vectors with one pass, as this might be much more computationally cheaper. E.g. in Stiefel manifolds this allows performing this operation with one LU decomposition

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how flag is supposed to be implemented

    def transp(self, x, y, v, *more, dim=-1, stack=True):
        if not more:
            return math.parallel_transport(x, y, v, c=self.c, dim=dim)
        else:
            if stack:
                vecs = torch.stack((v,) + more, dim=0)
                transp = math.parallel_transport(
                    x, y, vecs, c=self.c, dim=idx2sign(dim, x.dim())
                )
                return transp.unbind(0)
            else:
                return tuple(
                    math.parallel_transport(x, y, vec, c=self.c, dim=dim)
                    for vec in (v, *more)
                )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, tuple approach is not that bad

    def transp(self, x, y, v, *more, dim=-1):
        if not more:
            return math.parallel_transport(x, y, v, c=self.c, dim=dim)
        else:
            return tuple(
                math.parallel_transport(x, y, vec, c=self.c, dim=dim)
                for vec in (v, *more)
            )

if not more:
return math.parallel_transport(x, y, v, c=self.c)
return math.parallel_transport(x, y, v, c=self.c, dim=dim)
else:
vecs = torch.stack((v,) + more, dim=0)
transp = math.parallel_transport(x, y, vecs, c=self.c)
return transp.unbind(0)
return tuple(
math.parallel_transport(x, y, vec, c=self.c, dim=dim)
for vec in (v, *more)
)

def transp_follow_retr(self, x, u, v, *more, dim=-1):
y = self.retr(x, u, dim=dim)
return self.transp(x, y, v, *more, dim=dim)

def transp_follow_expmap(self, x, u, v, *more, dim=-1, project=True):
y = self.expmap(x, u, dim=dim, project=project)
return self.transp(x, y, v, *more, dim=dim)

def expmap_transp(self, x, u, v, *more, dim=-1, project=True):
y = self.expmap(x, u, dim=dim, project=project)
vs = self.transp(x, y, v, *more, dim=dim)
return (y,) + make_tuple(vs)

def transp_follow_retr(self, x, u, v, *more):
y = self.retr(x, u)
return self.transp(x, y, v, *more)
def retr_transp(self, x, u, v, *more, dim=-1):
y = self.retr(x, u, dim=dim)
vs = self.transp(x, y, v, *more, dim=dim)
return (y,) + make_tuple(vs)

def transp_follow_expmap(self, x, u, v, *more):
y = self.expmap(x, u)
return self.transp(x, y, v, *more)
def mobius_add(self, x, y, *, dim=-1, project=True):
res = math.mobius_add(x, y, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def expmap_transp(self, x, u, v, *more):
y = self.expmap(x, u)
vs = self.transp(x, y, v, *more)
return (y,) + make_tuple(vs)
def mobius_sub(self, x, y, *, dim=-1, project=True):
res = math.mobius_sub(x, y, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def retr_transp(self, x, u, v, *more):
y = self.retr(x, u)
vs = self.transp(x, y, v, *more)
return (y,) + make_tuple(vs)
def mobius_coadd(self, x, y, *, dim=-1, project=True):
res = math.mobius_coadd(x, y, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def mobius_cosub(self, x, y, *, dim=-1, project=True):
res = math.mobius_coadd(x, y, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def mobius_scalar_mul(self, r, x, *, dim=-1, project=True):
res = math.mobius_scalar_mul(r, x, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def mobius_pointwise_mul(self, w, x, *, dim=-1, project=True):
res = math.mobius_pointwise_mul(w, x, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def mobius_matvec(self, m, x, *, dim=-1, project=True):
res = math.mobius_matvec(m, x, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def geodesic(self, t, x, y, *, dim=-1):
return math.geodesic(t, x, y, c=self.c, dim=dim)

def geodesic_unit(self, t, x, u, *, dim=-1, project=True):
res = math.geodesic_unit(t, x, u, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def lambda_x(self, x, *, dim=-1, keepdim=False):
return math.lambda_x(x, c=self.c, dim=dim, keepdim=keepdim)

def dist0(self, x, *, dim=-1, keepdim=False):
return math.dist0(x, c=self.c, dim=dim, keepdim=keepdim)

def expmap0(self, u, *, dim=-1, project=True):
res = math.expmap0(u, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def logmap0(self, x, *, dim=-1):
return math.logmap0(x, c=self.c, dim=dim)

def transp0(self, y, u, *, dim=-1):
return math.parallel_transport0(y, u, c=self.c, dim=dim)

def transp0back(self, y, u, *, dim=-1):
return math.parallel_transport0back(y, u, c=self.c, dim=dim)

def gyration(self, x, y, z, *, dim=-1):
return math.gyration(x, y, z, c=self.c, dim=dim)

def dist2plane(self, x, p, a, *, dim=-1, keepdim=False, signed=False):
return math.dist2plane(
x, p, a, dim=dim, c=self.c, keepdim=keepdim, signed=signed
)

def mobius_fn_apply(self, fn, x, *args, dim=-1, project=True, **kwargs):
res = math.mobius_fn_apply(fn, x, *args, c=self.c, dim=dim, **kwargs)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res

def mobius_fn_apply_chain(self, x, *fns, project=True, dim=-1):
res = math.mobius_fn_apply_chain(x, *fns, c=self.c, dim=dim)
if project:
return math.project(res, c=self.c, dim=dim)
else:
return res


class PoincareBallExact(PoincareBall):
Expand Down
27 changes: 27 additions & 0 deletions geoopt/manifolds/poincare/math.py
Expand Up @@ -1273,6 +1273,33 @@ def _parallel_transport0(y, v, c, dim: int = -1):
return v * (1 - c * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM)


def parallel_transport0back(x, v, *, c=1.0, dim: int = -1):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps s/parallel_// for brevity?
And (later) s/\btransp\b/transport/ in PoincareBall class too so that we have just "transport" everywhere

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parallel transport is a special case of vector transport that is exact one. That's the purpose of having parallel in math. In the manifold we do not provide any different, however but could for sphere

r"""
Special case parallel transport with last point at zero that
can be computed more efficiently and numerically stable

Parameters
----------
x : tensor
target point
v : tensor
vector to be transported
c : float|tensor
ball negative curvature
dim : int
reduction dimension for operations

Returns
-------
tensor
"""
return _parallel_transport0back(x, v, c=c, dim=dim)


def _parallel_transport0back(x, v, c, dim: int = -1):
return v / (1 - c * x.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM)


def egrad2rgrad(x, grad, *, c=1.0, dim=-1):
r"""
Translate Euclidean gradient to Riemannian gradient on tangent space of :math:`x`
Expand Down
2 changes: 2 additions & 0 deletions geoopt/manifolds/sphere.py
Expand Up @@ -104,6 +104,8 @@ def _check_vector_on_tangent(self, x, u, *, atol=1e-5, rtol=1e-5):
return True, None

def inner(self, x, u, v=None, *, keepdim=False):
if v is None:
v = u
return (u * v).sum(-1, keepdim=keepdim)

def projx(self, x):
Expand Down
2 changes: 2 additions & 0 deletions geoopt/manifolds/stiefel.py
Expand Up @@ -182,6 +182,8 @@ def retr_transp(self, x, u, v, *more):
return (y,) + make_tuple(vs)

def inner(self, x, u, v=None, *, keepdim=False):
if v is None:
v = u
return (u * v).sum([-1, -2], keepdim=keepdim)

def retr(self, x, u):
Expand Down