From a13e1b3a340a711f50b8f9e2374151a823f07bea Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 22 Feb 2019 14:03:51 +0300 Subject: [PATCH 1/2] add pickling --- geoopt/tensor.py | 22 ++++++++++++++++++++-- tests/test_utils.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/geoopt/tensor.py b/geoopt/tensor.py index 239d29ca..852a34e2 100644 --- a/geoopt/tensor.py +++ b/geoopt/tensor.py @@ -1,6 +1,7 @@ import torch.nn from .manifolds import Euclidean from .docutils import insert_docs +import copyreg __all__ = ["ManifoldTensor", "ManifoldParameter"] @@ -100,6 +101,19 @@ def __repr__(self): self.manifold ) + torch.Tensor.__repr__(self) + # noinspection PyUnresolvedReferences + def __reduce_ex__(self, proto): + proto = ( + self.__class__, + self.storage(), + self.storage_offset(), + self.size(), + self.stride(), + self.requires_grad, + dict(), + ) + return _rebuild_manifold_parameter, proto + (self.manifold,) + class ManifoldParameter(ManifoldTensor, torch.nn.Parameter): """Same as :class:`torch.nn.Parameter` that has information about its manifold. @@ -131,5 +145,9 @@ def __repr__(self): self.manifold ) + torch.Tensor.__repr__(self) - def __reduce_ex__(self, proto): - return ManifoldParameter, (super(ManifoldParameter, self), self.requires_grad) + +def _rebuild_manifold_parameter(cls, *args): + import torch._utils + + tensor = torch._utils._rebuild_tensor_v2(*args[:-1]) + return cls(tensor, manifold=args[-1], requires_grad=args[-3]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1b7ae9f6..33c3e373 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,8 @@ import torch import numpy as np import geoopt +import tempfile +import os @pytest.fixture @@ -43,3 +45,31 @@ def test_expm(A): expm_torch = geoopt.linalg.expm(A) np.testing.assert_allclose(expm_torch.detach(), expm_scipy, rtol=1e-6) expm_torch.sum().backward() # this should work + + +def test_pickle1(): + t = torch.ones(10) + p = geoopt.ManifoldTensor(t, manifold=geoopt.Sphere()) + with tempfile.TemporaryDirectory() as path: + torch.save(p, os.path.join(path, "tens.t7")) + p1 = torch.load(os.path.join(path, "tens.t7")) + assert isinstance(p1, geoopt.ManifoldTensor) + assert p.stride() == p1.stride() + assert p.storage_offset() == p1.storage_offset() + assert p.requires_grad == p1.requires_grad + np.testing.assert_allclose(p.detach(), p1.detach()) + assert p.manifold == p1.manifold + + +def test_pickle2(): + t = torch.ones(10) + p = geoopt.ManifoldParameter(t, manifold=geoopt.Sphere()) + with tempfile.TemporaryDirectory() as path: + torch.save(p, os.path.join(path, "tens.t7")) + p1 = torch.load(os.path.join(path, "tens.t7")) + assert isinstance(p1, geoopt.ManifoldParameter) + assert p.stride() == p1.stride() + assert p.storage_offset() == p1.storage_offset() + assert p.requires_grad == p1.requires_grad + np.testing.assert_allclose(p.detach(), p1.detach()) + assert p.manifold == p1.manifold From a3243ec6c81ac443c96db93fc984d2447300b086 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 22 Feb 2019 14:18:16 +0300 Subject: [PATCH 2/2] unised import --- geoopt/tensor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/geoopt/tensor.py b/geoopt/tensor.py index 852a34e2..335b28df 100644 --- a/geoopt/tensor.py +++ b/geoopt/tensor.py @@ -1,8 +1,6 @@ import torch.nn from .manifolds import Euclidean from .docutils import insert_docs -import copyreg - __all__ = ["ManifoldTensor", "ManifoldParameter"]