Skip to content

Commit

Permalink
manifold .to() was not working, fix (#51)
Browse files Browse the repository at this point in the history
* sphere .to() was not working, fix

* fix tests, add testing for pickle parameter+manifold
  • Loading branch information
ferrine committed Mar 3, 2019
1 parent 5fb91db commit 02ce516
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
8 changes: 7 additions & 1 deletion geoopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
from . import linalg

from .tensor import ManifoldParameter, ManifoldTensor
from .manifolds import Stiefel, Euclidean, Sphere
from .manifolds import (
Stiefel,
Euclidean,
Sphere,
SphereSubspaceIntersection,
SphereSubspaceComplementIntersection,
)

__version__ = "0.0.1"
5 changes: 1 addition & 4 deletions geoopt/manifolds/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def expmap(self, x, u, t=1.0):
return self._retr_funcs[-1](self, x=x, u=u, t=t)

def logmap(self, x, y):
"""
r"""
Perform an logarithmic map for a pair of points :math:`x` and :math:`y`.
The result lies in :math:`u \in T_x\mathcal{M}` is such that:
Expand Down Expand Up @@ -951,6 +951,3 @@ def __repr__(self):
return self.name + "({}) manifold".format(extra)
else:
return self.name + " manifold"

def __eq__(self, other):
return type(self) is type(other)
4 changes: 2 additions & 2 deletions geoopt/manifolds/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _check_shape(self, x, name):

def _configure_manifold(self, span):
Q, _ = geoopt.linalg.batch_linalg.qr(span)
self._projector = Q @ Q.transpose(-1, -2)
self.register_buffer("_projector", Q @ Q.transpose(-1, -2))

def _project_on_subspace(self, x):
return x @ self._projector.transpose(-1, -2)
Expand Down Expand Up @@ -176,4 +176,4 @@ def _configure_manifold(self, span):
Q, _ = geoopt.linalg.batch_linalg.qr(span)
P = -Q @ Q.transpose(-1, -2)
P[..., torch.arange(P.shape[-2]), torch.arange(P.shape[-2])] += 1
self._projector = P
self.register_buffer("_projector", P)
39 changes: 36 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
import torch
import torch.nn
import numpy as np
import geoopt
import tempfile
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_pickle1():
assert p.storage_offset() == p1.storage_offset()
assert p.requires_grad == p1.requires_grad
np.testing.assert_allclose(p.detach(), p1.detach())
assert p.manifold == p1.manifold
assert isinstance(p.manifold, type(p1.manifold))


def test_pickle2():
Expand All @@ -72,4 +72,37 @@ def test_pickle2():
assert p.storage_offset() == p1.storage_offset()
assert p.requires_grad == p1.requires_grad
np.testing.assert_allclose(p.detach(), p1.detach())
assert p.manifold == p1.manifold
assert isinstance(p.manifold, type(p1.manifold))


def test_pickle3():
t = torch.ones(10)
span = torch.randn(10, 2)
sub_sphere = geoopt.manifolds.SphereSubspaceIntersection(span)
p = geoopt.ManifoldParameter(t, manifold=sub_sphere)
with tempfile.TemporaryDirectory() as path:
torch.save(p, os.path.join(path, "tens.t7"))
p1 = torch.load(os.path.join(path, "tens.t7"))
assert isinstance(p1, geoopt.ManifoldParameter)
assert p.stride() == p1.stride()
assert p.storage_offset() == p1.storage_offset()
assert p.requires_grad == p1.requires_grad
np.testing.assert_allclose(p.detach(), p1.detach())
assert isinstance(p.manifold, type(p1.manifold))
np.testing.assert_allclose(p.manifold._projector, p1.manifold._projector)


def test_manifold_to_smth():
span = torch.randn(10, 2)
sub_sphere = geoopt.manifolds.SphereSubspaceIntersection(span)
sub_sphere.to(torch.float64)
assert sub_sphere._projector.dtype == torch.float64


def test_manifold_is_submodule():
span = torch.randn(10, 2)
sub_sphere = geoopt.manifolds.SphereSubspaceIntersection(span)
sub_sphere.to(torch.float64)
container = torch.nn.ModuleDict({"sphere": sub_sphere})
container.to(torch.float64)
assert sub_sphere._projector.dtype == torch.float64

0 comments on commit 02ce516

Please sign in to comment.