Skip to content

Commit

Permalink
logmap
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Feb 13, 2019
1 parent b43d706 commit 45fea9d
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ New Features
* Added ``Sphere`` manifold (#25)
* Added ``SphereSubspaceIntersection``, ``SphereSubspaceComplementIntersection`` manifolds (#29)
* Added expmap implementation (#43)
* Added dist, logmap implementation

Maintenance
-----------
Expand Down
15 changes: 15 additions & 0 deletions geoopt/docutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import re


def insert_docs(doc, pattern=None, repl=None):
def wrapper(fn):
# assume wrapping
if pattern is not None:
if repl is None:
raise RuntimeError("need repl parameter")
fn.__doc__ = re.sub(pattern, repl, doc)
else:
fn.__doc__ = doc
return fn

return wrapper
107 changes: 86 additions & 21 deletions geoopt/manifolds/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def check_vector_on_tangent(self, x, u, explain=False, atol=1e-5, rtol=1e-5):
x : tensor
point on the manifold
u : tensor
vector on the tangent space to ``x``
vector on the tangent space to :math:`x`
atol: float
absolute tolerance as in :func:`numpy.allclose`
rtol: float
Expand Down Expand Up @@ -443,7 +443,7 @@ def assert_check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5):
x : tensor
point on the manifold
u : tensor
vector on the tangent space to ``x``
vector on the tangent space to :math:`x`
atol: float
absolute tolerance as in :func:`numpy.allclose`
rtol: float
Expand All @@ -464,6 +464,24 @@ def assert_check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5):
)
)

def dist(self, x, y):
"""
Compute distance between 2 points on the manifold that is the shortest path along geodesics
Parameters
----------
x : tensor
point on the manifold
y : tensor
point on the manifold
Returns
-------
scalar
distance between two points
"""
return self._dist(x, y)

def retr(self, x, u, t=1.0, order=None):
"""
Perform a retraction from point :math:`x` with
Expand All @@ -474,9 +492,9 @@ def retr(self, x, u, t=1.0, order=None):
x : tensor
point on the manifold
u : tensor
tangent vector at point x
tangent vector at point :math:`x`
t : scalar
time to go with direction u
time to go with direction :math:`u`
order : int
order of retraction approximation, by default uses the simplest that is usually a first order approximation.
Possible choices depend on a concrete manifold and -1 stays for exponential map
Expand All @@ -499,9 +517,9 @@ def expmap(self, x, u, t=1.0):
x : tensor
point on the manifold
u : tensor
tangent vector at point x
tangent vector at point :math:`x`
t : scalar
time to go with direction u
time to go with direction :math:`u`
Returns
-------
Expand All @@ -516,6 +534,29 @@ def expmap(self, x, u, t=1.0):
t = self.broadcast_scalar(t)
return self._retr_funcs[-1](self, x=x, u=u, t=t)

def logmap(self, x, y):
"""
Perform an logarithmic map for a pair of points :math:`x` and :math:`y`.
The result lies in :math:`u \in T_x\mathcal{M}` is such that:
.. math::
y = \operatorname{Exp}_x(\operatorname{Log}_{x}(y))
Parameters
----------
x : tensor
point on the manifold
y : tensor
point on the manifold
Returns
-------
tensor
tangent vector
"""
return self._logmap(x, y)

def expmap_transp(self, x, v, *more, u, t=1.0):
"""
Perform an exponential map from point :math:`x` with
Expand All @@ -526,13 +567,13 @@ def expmap_transp(self, x, v, *more, u, t=1.0):
x : tensor
point on the manifold
v : tensor
tangent vector at point x to be transported
tangent vector at point :math:`x` to be transported
more : tensors
other tangent vectors at point x to be transported
other tangent vectors at point :math:`x` to be transported
u : tensor
tangent vector at point x
tangent vector at point :math:`x`
t : scalar
time to go with direction u
time to go with direction :math:`u`
Returns
-------
Expand Down Expand Up @@ -561,13 +602,13 @@ def transp(self, x, v, *more, u=None, t=1.0, y=None, order=None):
x : tensor
point on the manifold
v : tensor
tangent vector at point x to be transported
tangent vector at point :math:`x` to be transported
more : tensors
other tangent vectors at point x to be transported
other tangent vectors at point :math:`x` to be transported
u : tensor
tangent vector at point x (required if :math:`y` is not provided)
tangent vector at point :math:`x` (required if :math:`y` is not provided)
t : scalar
time to go with direction u
time to go with direction :math:`u`
y : tensor
the target point for vector transport (required if :math:`u` is not provided)
order : int
Expand Down Expand Up @@ -599,9 +640,9 @@ def inner(self, x, u, v=None):
x : tensor
point on the manifold
u : tensor
tangent vector at point x
tangent vector at point :math:`x`
v : tensor (optional)
tangent vector at point x
tangent vector at point :math:`x`
Returns
-------
Expand Down Expand Up @@ -675,14 +716,14 @@ def retr_transp(self, x, v, *more, u, t=1.0, order=None):
----------
x : tensor
point on the manifold
tangent vector at point x
t : scalar
time to go with direction u
v : tensor
tangent vector at point x to be transported (required keyword only argument)
tangent vector at point :math:`x` to be transported
more : tensors
other tangent vector at point x to be transported
other tangent vector at point :math:`x` to be transported
u : tensor
tangent vector at point :math:`x` (required keyword only argument)
t : scalar
time to go with direction :math:`u`
order : int
order of retraction approximation, by default uses the simplest.
Possible choices depend on a concrete manifold and -1 stays for exponential map
Expand Down Expand Up @@ -834,6 +875,30 @@ def _retr(self, x, u, t):
"""
_transp2y = not_implemented

# def _logmap(self, x, y):
"""
Developer Guide
Private implementation for logarithmic map for :math:`x` and :math:`y`. Should allow broadcasting.
"""
_logmap = not_implemented

# def _expmap(self, x, y):
"""
Developer Guide
Private implementation for exponential map for :math:`x` and :math:`y`. Should allow broadcasting.
"""
_expmap = not_implemented

# def _dist(self, x, y):
"""
Developer Guide
Private implementation for computing distance between :math:`x` and :math:`y`. Should allow broadcasting.
"""
_dist = not_implemented

@abc.abstractmethod
def _inner(self, x, u, v):
"""
Expand Down
6 changes: 6 additions & 0 deletions geoopt/manifolds/euclidean.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,9 @@ def _transp2y(self, x, v, *more, y):
return v
else:
return (v,) + more

def _logmap(self, x, y):
return y - x

def _dist(self, x, y):
return (x - y).abs()
11 changes: 11 additions & 0 deletions geoopt/manifolds/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ def _expmap_transp(self, x, v, *more, u, t):
else:
return y, vs

def _logmap(self, x, y):
u = self._proju(x, y - x)
dist = self._dist(x, y).unsqueeze(-1)
# If the two points are "far apart", correct the norm.
cond = dist.gt(1e-6)
return torch.where(cond, u * dist / u.norm(dim=-1, keepdim=True), u)

def _dist(self, x, y):
inner = self._inner(None, x, y).clamp(-1, 1)
return torch.acos(inner)


class SphereSubspaceIntersection(Sphere):
r"""
Expand Down
42 changes: 42 additions & 0 deletions geoopt/tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch.nn
from .manifolds import Euclidean
from .docutils import insert_docs


__all__ = ["ManifoldTensor", "ManifoldParameter"]
Expand Down Expand Up @@ -28,31 +29,72 @@ def __new__(cls, *args, manifold=Euclidean(), requires_grad=False, **kwargs):
return instance

def proj_(self):
"""
Inplace projection to the manifold
Returns
-------
tensor
same instance
"""
with torch.no_grad():
self.set_(self.manifold.projx(self.data))
return self

@insert_docs(Euclidean.retr.__doc__, r"\s+x : .+\n.+", "")
def retr(self, u, t=1.0, order=None):
return self.manifold.retr(self, u=u, t=t, order=order)

@insert_docs(Euclidean.expmap.__doc__, r"\s+x : .+\n.+", "")
def expmap(self, u, t=1.0):
return self.manifold.expmap(self, u=u, t=t)

@insert_docs(Euclidean.inner.__doc__, r"\s+x : .+\n.+", "")
def inner(self, u, v=None):
return self.manifold.inner(self, u=u, v=v)

@insert_docs(Euclidean.proju.__doc__, r"\s+x : .+\n.+", "")
def proju(self, u):
return self.manifold.proju(self, u)

@insert_docs(Euclidean.transp.__doc__, r"\s+x : .+\n.+", "")
def transp(self, v, *more, u=None, t=1.0, y=None, order=None):
return self.manifold.transp(self, v, *more, u=u, t=t, y=y, order=order)

@insert_docs(Euclidean.retr_transp.__doc__, r"\s+x : .+\n.+", "")
def retr_transp(self, v, *more, u, t=1.0, order=None):
return self.manifold.retr_transp(self, u, *more, u=v, t=t, order=order)

@insert_docs(Euclidean.expmap_transp.__doc__, r"\s+x : .+\n.+", "")
def expmap_transp(self, v, *more, u, t=1.0):
return self.manifold.expmap_transp(self, u, *more, u=v, t=t)

def dist(self, other, p=2):
"""
Return euclidean or geodesic distance between points on the manifold. Allows broadcasting
Parameters
----------
other : tensor
p : str|int
The norm to use. The default behaviour is not changed and is just euclidean distance.
To compute geodesic distance, :attr:`p` should be set to ``"g"``
Returns
-------
scalar
"""
if p == "g":
return self.manifold.dist(self, other)
else:
return super().dist(other)

@insert_docs(Euclidean.logmap.__doc__, r"\s+x : .+\n.+", "")
def logmap(self, y):
return self.manifold.logmap(self, y)

def __repr__(self):
return "Tensor on {} containing:\n".format(
self.manifold
Expand Down
44 changes: 44 additions & 0 deletions tests/test_manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,47 @@ def test_reversibility(unary_case, t):
else:
assert not np.allclose(X1, X, atol=1e-5)
assert not np.allclose(U1, U, atol=1e-5)


def test_dist(unary_case):
if type(unary_case.manifold)._dist is geoopt.manifolds.base.not_implemented:
pytest.skip("logmap is not implemented for {}".format(unary_case.manifold))
torch.manual_seed(43)
x = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
y = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
x = unary_case.manifold.projx(x)
y = unary_case.manifold.projx(y)
dhat = unary_case.manopt_manifold.dist(x.numpy(), y.numpy())
d = unary_case.manifold.dist(x, y)
np.testing.assert_allclose(d, dhat)


def test_logmap(unary_case, t):
if type(unary_case.manifold)._logmap is geoopt.manifolds.base.not_implemented:
pytest.skip("logmap is not implemented for {}".format(unary_case.manifold))

x = unary_case.x
v = unary_case.v
y = unary_case.manopt_manifold.exp(x.numpy(), v.numpy() * t)
vman = unary_case.manopt_manifold.log(x.numpy(), y)
vhat = unary_case.manifold.logmap(x, torch.as_tensor(y))
np.testing.assert_allclose(vhat, vman)
ey = unary_case.manifold.expmap(x, vhat)
np.testing.assert_allclose(y, ey)


def test_logmap_many(unary_case, t):
if type(unary_case.manifold)._logmap is geoopt.manifolds.base.not_implemented:
pytest.skip("logmap is not implemented for {}".format(unary_case.manifold))

torch.manual_seed(43)
X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
X = unary_case.manifold.projx(X)
U = unary_case.manifold.proju(X, U)

Y = unary_case.manifold.expmap(X, U, t=t)
Uh = unary_case.manifold.logmap(X, Y)
Yh = unary_case.manifold.expmap(X, Uh)

np.testing.assert_allclose(Yh, Y)

0 comments on commit 45fea9d

Please sign in to comment.