Skip to content

Commit

Permalink
add pickling (#47)
Browse files Browse the repository at this point in the history
* add pickling

* unised import
  • Loading branch information
ferrine committed Feb 22, 2019
1 parent 6b8f94a commit 08a45b7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
22 changes: 19 additions & 3 deletions geoopt/tensor.py
Expand Up @@ -2,7 +2,6 @@
from .manifolds import Euclidean
from .docutils import insert_docs


__all__ = ["ManifoldTensor", "ManifoldParameter"]


Expand Down Expand Up @@ -100,6 +99,19 @@ def __repr__(self):
self.manifold
) + torch.Tensor.__repr__(self)

# noinspection PyUnresolvedReferences
def __reduce_ex__(self, proto):
proto = (
self.__class__,
self.storage(),
self.storage_offset(),
self.size(),
self.stride(),
self.requires_grad,
dict(),
)
return _rebuild_manifold_parameter, proto + (self.manifold,)


class ManifoldParameter(ManifoldTensor, torch.nn.Parameter):
"""Same as :class:`torch.nn.Parameter` that has information about its manifold.
Expand Down Expand Up @@ -131,5 +143,9 @@ def __repr__(self):
self.manifold
) + torch.Tensor.__repr__(self)

def __reduce_ex__(self, proto):
return ManifoldParameter, (super(ManifoldParameter, self), self.requires_grad)

def _rebuild_manifold_parameter(cls, *args):
import torch._utils

tensor = torch._utils._rebuild_tensor_v2(*args[:-1])
return cls(tensor, manifold=args[-1], requires_grad=args[-3])
30 changes: 30 additions & 0 deletions tests/test_utils.py
Expand Up @@ -2,6 +2,8 @@
import torch
import numpy as np
import geoopt
import tempfile
import os


@pytest.fixture
Expand Down Expand Up @@ -43,3 +45,31 @@ def test_expm(A):
expm_torch = geoopt.linalg.expm(A)
np.testing.assert_allclose(expm_torch.detach(), expm_scipy, rtol=1e-6)
expm_torch.sum().backward() # this should work


def test_pickle1():
t = torch.ones(10)
p = geoopt.ManifoldTensor(t, manifold=geoopt.Sphere())
with tempfile.TemporaryDirectory() as path:
torch.save(p, os.path.join(path, "tens.t7"))
p1 = torch.load(os.path.join(path, "tens.t7"))
assert isinstance(p1, geoopt.ManifoldTensor)
assert p.stride() == p1.stride()
assert p.storage_offset() == p1.storage_offset()
assert p.requires_grad == p1.requires_grad
np.testing.assert_allclose(p.detach(), p1.detach())
assert p.manifold == p1.manifold


def test_pickle2():
t = torch.ones(10)
p = geoopt.ManifoldParameter(t, manifold=geoopt.Sphere())
with tempfile.TemporaryDirectory() as path:
torch.save(p, os.path.join(path, "tens.t7"))
p1 = torch.load(os.path.join(path, "tens.t7"))
assert isinstance(p1, geoopt.ManifoldParameter)
assert p.stride() == p1.stride()
assert p.storage_offset() == p1.storage_offset()
assert p.requires_grad == p1.requires_grad
np.testing.assert_allclose(p.detach(), p1.detach())
assert p.manifold == p1.manifold

0 comments on commit 08a45b7

Please sign in to comment.