Skip to content

Commit

Permalink
Merge 3d0791d into 7f4f54c
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Jan 27, 2019
2 parents 7f4f54c + 3d0791d commit ad7ca8d
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 47 deletions.
2 changes: 1 addition & 1 deletion docs/manifolds.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Manifolds
.. currentmodule:: geoopt.manifolds

.. automodule:: geoopt.manifolds
:members: Euclidean, Stiefel, Sphere
:members: Euclidean, Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection

Extending ``geoopt``
--------------------
Expand Down
3 changes: 3 additions & 0 deletions geoopt/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import torch

_TORCH_LESS_THAN_ONE = tuple(map(int, torch.__version__.split(".")[:2])) < (1, 0)
6 changes: 5 additions & 1 deletion geoopt/manifolds/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .base import Manifold
from .euclidean import Euclidean
from .stiefel import Stiefel, EuclideanStiefel, CanonicalStiefel
from .sphere import Sphere
from .sphere import (
Sphere,
SphereSubspaceComplementIntersection,
SphereSubspaceIntersection,
)
89 changes: 88 additions & 1 deletion geoopt/manifolds/sphere.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import torch

from .base import Manifold
import geoopt.util.linalg

__all__ = [
"Sphere",
"SphereSubspaceIntersection",
"SphereSubspaceComplementIntersection",
]


class Sphere(Manifold):
"""
r"""
Sphere manifold induced by the following constraint
.. math::
Expand Down Expand Up @@ -67,3 +74,83 @@ def _retr_transp(self, x, u, t, v, *more):
y = self._retr(x, u, t)
vs = self._transp_many(x, u, t, v, *more, y=y)
return (y,) + vs


class SphereSubspaceIntersection(Sphere):
r"""
Sphere manifold induced by the following constraint
.. math::
\|x\|=1\\
x \in \mathbb{span}(U)
Parameters
----------
span : matrix
the subspace to intersect with
"""

name = "SphereSubspace"

def __init__(self, span):
self._configure_manifold(span)
if (geoopt.util.linalg.matrix_rank(self._projector) == 1).any():
raise ValueError(
"Manifold only consists of isolated points when "
"subspace is 1-dimensional."
)

def _check_shape(self, x, name):
ok, reason = super()._check_shape(x, name)
if ok:

ok = x.shape[-1] == self._projector.shape[-2]
if not ok:
reason = "The leftmost shape of `span` does not match `x`: {}, {}".format(
x.shape[-1], self._projector.shape[-1]
)
elif x.dim() < (self._projector.dim() - 1):
reason = "`x` should have at least {} dimensions but has {}".format(
self._projector.dim() - 1, x.dim()
)
else:
reason = None
return ok, reason

def _configure_manifold(self, span):
Q, _ = geoopt.util.linalg.qr(span)
self._projector = Q @ Q.transpose(-1, -2)

def _project_on_subspace(self, x):
return x @ self._projector.transpose(-1, -2)

def _proju(self, x, u):
u = super()._proju(x, u)
return self._project_on_subspace(u)

def _projx(self, x):
x = self._project_on_subspace(x)
return super()._projx(x)


class SphereSubspaceComplementIntersection(SphereSubspaceIntersection):
r"""
Sphere manifold induced by the following constraint
.. math::
\|x\|=1\\
x \in \mathbb{span}(U)
Parameters
----------
span : matrix
the subspace to compliment (being orthogonal to)
"""

def _configure_manifold(self, span):
Q, _ = geoopt.util.linalg.qr(span)
P = -Q @ Q.transpose(-1, -2)
P[..., torch.arange(P.shape[-2]), torch.arange(P.shape[-2])] += 1
self._projector = P
3 changes: 2 additions & 1 deletion geoopt/optim/tracing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch.jit
from .._compat import _TORCH_LESS_THAN_ONE


def _compat_trace(fn, args):
if tuple(map(int, torch.__version__.split(".")[:2])) < (1, 0):
if _TORCH_LESS_THAN_ONE:
# torch.jit here does not support inplace ops
return fn
else:
Expand Down
81 changes: 57 additions & 24 deletions geoopt/util/linalg.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,53 @@
import torch
import itertools
import warnings
from .._compat import _TORCH_LESS_THAN_ONE

__all__ = ["svd", "qr", "sym", "extract_diag"]
__all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank"]


def _warn_lost_grad(x, op):
if torch.is_grad_enabled() and x.requires_grad:
warnings.warn(
"Gradient for operation {0}(...) is lost, please use the pytorch native analog as {0} this "
"is more optimized for internal purposes which do not require gradients but vectorization".format(
op
)
)


def svd(x):
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
batches = x.shape[:-2]
if batches:
# in most cases we do not require gradients when applying svd (e.g. in projection)
assert not x.requires_grad
n, m = x.shape[-2:]
k = min(n, m)
U, d, V = x.new(*batches, n, k), x.new(*batches, k), x.new(*batches, m, k)
for idx in itertools.product(*map(range, batches)):
U[idx], d[idx], V[idx] = torch.svd(x[idx])
return U, d, V
else:
return torch.svd(x)
_warn_lost_grad(x, "geoopt.utils.linalg.svd")
with torch.no_grad():
batches = x.shape[:-2]
if batches:
# in most cases we do not require gradients when applying svd (e.g. in projection)
n, m = x.shape[-2:]
k = min(n, m)
U, d, V = x.new(*batches, n, k), x.new(*batches, k), x.new(*batches, m, k)
for idx in itertools.product(*map(range, batches)):
U[idx], d[idx], V[idx] = torch.svd(x[idx])
return U, d, V
else:
return torch.svd(x)


def qr(x):
# vectorized version as svd
batches = x.shape[:-2]
if batches:
# in most cases we do not require gradients when applying qr (e.g. in retraction)
assert not x.requires_grad
n, m = x.shape[-2:]
Q, R = x.new(*batches, n, m), x.new(*batches, m, m)
for idx in itertools.product(*map(range, batches)):
Q[idx], R[idx] = torch.qr(x[idx])
return Q, R
else:
return torch.qr(x)
_warn_lost_grad(x, "geoopt.utils.linalg.qr")
with torch.no_grad():
batches = x.shape[:-2]
if batches:
# in most cases we do not require gradients when applying qr (e.g. in retraction)
assert not x.requires_grad
n, m = x.shape[-2:]
Q, R = x.new(*batches, n, m), x.new(*batches, m, m)
for idx in itertools.product(*map(range, batches)):
Q[idx], R[idx] = torch.qr(x[idx])
return Q, R
else:
return torch.qr(x)


def sym(x):
Expand All @@ -43,3 +58,21 @@ def extract_diag(x):
n, m = x.shape[-2:]
k = min(n, m)
return x[..., torch.arange(k), torch.arange(k)]


def matrix_rank(x):
if _TORCH_LESS_THAN_ONE:
import numpy as np

return torch.from_numpy(
np.asarray(np.linalg.matrix_rank(x.detach().cpu().numpy()))
).type_as(x)
with torch.no_grad():
batches = x.shape[:-2]
if batches:
out = x.new(*batches)
for idx in itertools.product(*map(range, batches)):
out[idx] = torch.matrix_rank(x[idx])
return out
else:
return torch.matrix_rank(x)
56 changes: 37 additions & 19 deletions tests/test_manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
functools.partial(geoopt.manifolds.Stiefel, canonical=True),
geoopt.manifolds.Euclidean,
geoopt.manifolds.Sphere,
functools.partial(
geoopt.manifolds.SphereSubspaceIntersection,
torch.from_numpy(np.random.RandomState(42).randn(10, 3)),
),
functools.partial(
geoopt.manifolds.SphereSubspaceComplementIntersection,
torch.from_numpy(np.random.RandomState(42).randn(10, 3)),
),
],
)
def manifold(request):
Expand All @@ -26,6 +34,14 @@ def manifold(request):
geoopt.manifolds.CanonicalStiefel: pymanopt.manifolds.Stiefel,
geoopt.manifolds.Euclidean: pymanopt.manifolds.Euclidean,
geoopt.manifolds.Sphere: pymanopt.manifolds.Sphere,
geoopt.manifolds.SphereSubspaceIntersection: functools.partial(
pymanopt.manifolds.SphereSubspaceIntersection,
U=np.random.RandomState(42).randn(10, 3),
),
geoopt.manifolds.SphereSubspaceComplementIntersection: functools.partial(
pymanopt.manifolds.SphereSubspaceComplementIntersection,
U=np.random.RandomState(42).randn(10, 3),
),
}

# shapes to verify unary element implementation
Expand All @@ -34,6 +50,8 @@ def manifold(request):
geoopt.manifolds.CanonicalStiefel: (10, 5),
geoopt.manifolds.Euclidean: (1,),
geoopt.manifolds.Sphere: (10,),
geoopt.manifolds.SphereSubspaceIntersection: (10,),
geoopt.manifolds.SphereSubspaceComplementIntersection: (10,),
}

UnaryCase = collections.namedtuple(
Expand All @@ -46,7 +64,7 @@ def unary_case(manifold):
shape = shapes[type(manifold)]
manopt_manifold = mannopt[type(manifold)](*shape)
np.random.seed(42)
rand = manopt_manifold.rand()
rand = manopt_manifold.rand().astype("float64")
x = geoopt.ManifoldTensor(torch.from_numpy(rand), manifold=manifold)
torch.manual_seed(43)
ex = geoopt.ManifoldTensor(torch.randn_like(x), manifold=manifold)
Expand Down Expand Up @@ -117,7 +135,7 @@ def test_transport(unary_case):

def test_broadcast_projx(unary_case):
torch.manual_seed(43)
X = torch.randn(4, *unary_case.shape)
X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
pX = unary_case.manifold.projx(X)
unary_case.manifold.assert_check_point_on_manifold(pX)
for px in pX:
Expand All @@ -129,8 +147,8 @@ def test_broadcast_projx(unary_case):

def test_broadcast_proju(unary_case):
torch.manual_seed(43)
X = torch.randn(4, *unary_case.shape)
U = torch.randn(4, *unary_case.shape)
X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
pX = unary_case.manifold.projx(X)
pU = unary_case.manifold.proju(pX, U)
unary_case.manifold.assert_check_vector_on_tangent(pX, pU)
Expand All @@ -143,8 +161,8 @@ def test_broadcast_proju(unary_case):

def test_broadcast_retr(unary_case):
torch.manual_seed(43)
X = torch.randn(4, *unary_case.shape)
U = torch.randn(4, *unary_case.shape)
X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
pX = unary_case.manifold.projx(X)
pU = unary_case.manifold.proju(pX, U)
Y = unary_case.manifold.retr(pX, pU, 1.0)
Expand All @@ -158,9 +176,9 @@ def test_broadcast_retr(unary_case):

def test_broadcast_transp(unary_case):
torch.manual_seed(43)
X = torch.randn(4, *unary_case.shape)
U = torch.randn(4, *unary_case.shape)
V = torch.randn(4, *unary_case.shape)
X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
V = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
pX = unary_case.manifold.projx(X)
pU = unary_case.manifold.proju(pX, U)
pV = unary_case.manifold.proju(pX, V)
Expand All @@ -176,10 +194,10 @@ def test_broadcast_transp(unary_case):

def test_broadcast_transp_many(unary_case):
torch.manual_seed(43)
X = torch.randn(4, *unary_case.shape)
U = torch.randn(4, *unary_case.shape)
V = torch.randn(4, *unary_case.shape)
F = torch.randn(4, *unary_case.shape)
X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
V = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
F = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
pX = unary_case.manifold.projx(X)
pU = unary_case.manifold.proju(pX, U)
pV = unary_case.manifold.proju(pX, V)
Expand All @@ -199,10 +217,10 @@ def test_broadcast_transp_many(unary_case):

def test_broadcast_retr_transp_many(unary_case):
torch.manual_seed(43)
X = torch.randn(4, *unary_case.shape)
U = torch.randn(4, *unary_case.shape)
V = torch.randn(4, *unary_case.shape)
F = torch.randn(4, *unary_case.shape)
X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
V = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
F = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype)
pX = unary_case.manifold.projx(X)
pU = unary_case.manifold.proju(pX, U)
pV = unary_case.manifold.proju(pX, V)
Expand All @@ -224,8 +242,8 @@ def test_broadcast_retr_transp_many(unary_case):

def test_reversibility(unary_case):
torch.manual_seed(43)
X = torch.randn(*unary_case.shape)
U = torch.randn(*unary_case.shape)
X = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
U = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
X = unary_case.manifold.projx(X)
U = unary_case.manifold.proju(X, U)
Z, Q = unary_case.manifold.retr_transp(X, U, 1.0, U)
Expand Down

0 comments on commit ad7ca8d

Please sign in to comment.