Skip to content

Commit

Permalink
Merge 8dcdd5a into 8eb0722
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Feb 13, 2019
2 parents 8eb0722 + 8dcdd5a commit b081cb1
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 11 deletions.
2 changes: 1 addition & 1 deletion geoopt/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .batch_linalg import svd, qr, sym, extract_diag, matrix_rank, expm
from .batch_linalg import svd, qr, sym, extract_diag, matrix_rank, expm, block_matrix
10 changes: 9 additions & 1 deletion geoopt/linalg/batch_linalg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from . import _expm

__all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank", "expm"]
__all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank", "expm", "block_matrix"]


@torch.jit.script
Expand Down Expand Up @@ -114,3 +114,11 @@ def expm(x): # pragma: no cover
exp += [e]
result = torch.stack(exp).view(x.shape)
return result


def block_matrix(blocks):
# [[A, B], [C, D]] ->
# [AB]
# [CD]
blocks = tuple(torch.cat(mats, dim=-1) for mats in blocks)
return torch.cat(blocks, dim=-2)
1 change: 1 addition & 0 deletions geoopt/manifolds/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def set_default_order(self, order):
self._retr_funcs = self._retr_funcs.copy()
self._retr_funcs[None] = self._retr_funcs[order]
self._transport_follow_funcs = self._transport_follow_funcs.copy()
self._transport_follow_funcs[None] = self._transport_follow_funcs[order]
self._retr_funcs[None] = self._retr_funcs[order]
self._default_order = order
return self
Expand Down
19 changes: 17 additions & 2 deletions geoopt/manifolds/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ def _projx(self, x):
def _proju(self, x, u):
return u - (x * u).sum(dim=-1, keepdim=True) * x

def _retr(self, x, u, t):
def _expmap(self, x, u, t):
ut = u * t
norm_ut = ut.norm(dim=-1, keepdim=True)
exp = x * torch.cos(norm_ut) + ut * torch.sin(norm_ut) / norm_ut
retr = self._projx(x + ut)
cond = norm_ut < 1e-3
cond = norm_ut > 1e-3
return torch.where(cond, exp, retr)

def _retr(self, x, u, t):
return self._projx(x + u * t)

def _transp_follow(self, x, v, *more, u, t):
y = self._retr(x, u, t)
return self._transp2y(x, v, *more, y=y)
Expand All @@ -71,6 +74,18 @@ def _transp2y(self, x, v, *more, y):
else:
return self._proju(y, v)

def _transp_follow_expmap(self, x, v, *more, u, t):
y = self._expmap(x, u, t)
return self._transp2y(x, v, *more, y=y)

def _expmap_transp(self, x, v, *more, u, t):
y = self._expmap(x, u, t)
vs = self._transp2y(x, v, *more, y=y)
if more:
return (y,) + vs
else:
return y, vs


class SphereSubspaceIntersection(Sphere):
r"""
Expand Down
24 changes: 24 additions & 0 deletions geoopt/manifolds/stiefel.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,27 @@ def _retr(self, x, u, t):
unflip = linalg.batch_linalg.extract_diag(r).sign().add(0.5).sign()
q *= unflip[..., None, :]
return q

def _expmap(self, x, u, t):
u = u * t
xtu = x.transpose(-1, -2) @ u
utu = u.transpose(-1, -2) @ u
eye = torch.zeros_like(utu)
eye[..., torch.arange(utu.shape[-2]), torch.arange(utu.shape[-2])] += 1
logw = linalg.block_matrix([[xtu, -utu], [eye, xtu]])
w = linalg.expm(logw)
z = torch.cat((linalg.expm(-xtu), torch.zeros_like(utu)), dim=-2)
y = torch.cat((x, u), dim=-1) @ w @ z
return y

def _expmap_transp(self, x, v, *more, u, t):
y = self._expmap(x, u, t)
vs = self._transp2y(x, v, *more, y=y)
if more:
return (y,) + vs
else:
return y, vs

def _transp_follow_expmap(self, x, v, *more, u, t):
y = self._expmap(x, u, t)
return self._transp2y(x, v, *more, y=y)
26 changes: 19 additions & 7 deletions tests/test_manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
import pymanopt.manifolds


@pytest.fixture("module", params=[1, -1])
def retraction_order(request):
return request.param


@pytest.fixture(
"session",
"module",
params=[
# match implementation of pymanopt for stiefel
functools.partial(geoopt.manifolds.Stiefel, canonical=False),
Expand All @@ -25,8 +30,12 @@
),
],
)
def manifold(request):
return request.param()
def manifold(request, retraction_order):
man = request.param()
try:
return man.set_default_order(retraction_order)
except ValueError:
pytest.skip("not supported retraction order for {}".format(man))


mannopt = {
Expand Down Expand Up @@ -106,16 +115,19 @@ def test_vector_projection_via_assert(unary_case):
unary_case.manifold.assert_check_vector_on_tangent(x, pv)


def test_retraction(unary_case):
def test_retraction(unary_case, retraction_order):
if isinstance(unary_case.manifold, geoopt.manifolds.CanonicalStiefel):
pytest.skip("pymanopt uses euclidean Stiefel")
x = unary_case.x
v = unary_case.v

y = x.retr(v, 1.0)
y_star = unary_case.manopt_manifold.retr(x.numpy(), v.numpy())

np.testing.assert_allclose(y, y_star)
if retraction_order == 1:
y_star = unary_case.manopt_manifold.retr(x.numpy(), v.numpy())
np.testing.assert_allclose(y, y_star)
elif retraction_order == -1:
y_star = unary_case.manopt_manifold.exp(x.numpy(), v.numpy())
np.testing.assert_allclose(y, y_star)


def test_transport(unary_case):
Expand Down

0 comments on commit b081cb1

Please sign in to comment.