Resolve inconsistency with tensor strides and optimizer updates #71

Merged · 7 commits · May 21, 2019 · Changes from 6 commits
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -28,3 +28,4 @@ Deprecations
Bug Fixes
---------
* Make pickle work with ManifoldTensors (#47)
* Resolve inconsistency with tensor strides and optimizer updates (#71)
6 changes: 4 additions & 2 deletions geoopt/optim/radam.py
@@ -4,6 +4,7 @@
from .tracing import create_traced_update
from ..tensor import ManifoldParameter, ManifoldTensor
from ..manifolds import Euclidean
from ..utils import copy_or_set


class RiemannianAdam(OptimMixin, torch.optim.Adam):
@@ -189,7 +190,8 @@ def perform_step(
new_point, exp_avg_new = manifold.retr_transp(
point, exp_avg, u=direction, t=-step_size
)
point.set_(new_point)
# use copy only for user facing point
copy_or_set(point, new_point)
exp_avg.set_(exp_avg_new)

def stabilize_group(self, group):
@@ -202,7 +204,7 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
exp_avg = state["exp_avg"]
p.set_(manifold.projx(p))
copy_or_set(p, manifold.projx(p))
exp_avg.set_(manifold.proju(p, exp_avg))

def _sanitize_group(self, group):
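For context, a minimal usage sketch of what this change means for users of RiemannianAdam (the Stiefel parameter, shapes and learning rate are illustrative assumptions; the final assertion mirrors the one added in tests/test_adam.py below):

import torch
import geoopt

# Illustrative setup: a Stiefel parameter optimized with RiemannianAdam.
# After this PR the retraction result is written back with copy_or_set,
# so the parameter keeps its original (contiguous) layout across steps.
stiefel = geoopt.manifolds.Stiefel()
p = geoopt.ManifoldParameter(torch.randn(5, 3), manifold=stiefel).proj_()
optim = geoopt.optim.RiemannianAdam([p], lr=1e-2)

loss = (p ** 2).sum()
loss.backward()
optim.step()

assert p.is_contiguous()  # layout preserved, as asserted in tests/test_adam.py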
9 changes: 5 additions & 4 deletions geoopt/optim/rsgd.py
@@ -3,7 +3,7 @@
from ..tensor import ManifoldParameter, ManifoldTensor
from .mixin import OptimMixin
from .tracing import create_traced_update

from ..utils import copy_or_set

__all__ = ["RiemannianSGD"]

@@ -168,10 +168,11 @@ def perform_step(
point, momentum_buffer, u=grad, t=-lr
)
momentum_buffer.set_(new_momentum_buffer)
point.set_(new_point)
# use copy only for user facing point
copy_or_set(point, new_point)
else:
new_point = manifold.retr(point, grad, -lr)
point.set_(new_point)
copy_or_set(point, new_point)

def stabilize_group(self, group):
with torch.no_grad():
@@ -180,7 +181,7 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
momentum = group["momentum"]
p.set_(manifold.projx(p))
copy_or_set(p, manifold.projx(p))
if momentum > 0:
param_state = self.state[p]
if not param_state: # due to None grads
4 changes: 2 additions & 2 deletions geoopt/samplers/rhmc.py
@@ -6,7 +6,7 @@
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.manifolds import Euclidean
from geoopt.samplers.base import Sampler

from ..utils import copy_or_set

__all__ = ["RHMC"]

@@ -40,7 +40,7 @@ def _step(self, p, r, epsilon):

r.add_(epsilon * egrad2rgrad(p, p.grad))
p_, r_ = retr_transp(p, r, u=r, t=epsilon)
p.set_(p_)
copy_or_set(p, p_)
r.set_(r_)

def step(self, closure):
9 changes: 4 additions & 5 deletions geoopt/samplers/rsgld.py
@@ -5,7 +5,7 @@
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.manifolds import Euclidean
from geoopt.samplers.base import Sampler

from ..utils import copy_or_set

__all__ = ["RSGLD"]

@@ -50,8 +50,8 @@ def step(self, closure):

n = torch.randn_like(p).mul_(math.sqrt(epsilon))
r = egrad2rgrad(p, 0.5 * epsilon * p.grad + n)

p.set_(retr(p, r, 1.0))
# use copy only for user facing point
copy_or_set(p, retr(p, r, 1.0))
p.grad.zero_()

if not self.burnin:
@@ -66,5 +66,4 @@ def stabilize(self):
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue

p.set_(p.manifold.projx(p))
copy_or_set(p, p.manifold.projx(p))
7 changes: 3 additions & 4 deletions geoopt/samplers/sgrhmc.py
@@ -5,7 +5,7 @@
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.manifolds import Euclidean
from geoopt.samplers.base import Sampler

from ..utils import copy_or_set

__all__ = ["SGRHMC"]

@@ -76,7 +76,7 @@ def step(self, closure):
v = self.state[p]["v"]

p_, v_ = retr_transp(p, v, u=v, t=1.0)
p.set_(p_)
copy_or_set(p, p_)
v.set_(v_)

n = egrad2rgrad(p, torch.randn_like(v))
@@ -103,7 +103,6 @@ def stabilize(self):

manifold = p.manifold
v = self.state[p]["v"]

p.set_(manifold.projx(p))
copy_or_set(p, manifold.projx(p))
# proj here is ok
v.set_(manifold.proju(p, v))
6 changes: 3 additions & 3 deletions geoopt/tensor.py
@@ -1,6 +1,7 @@
import torch.nn
from .manifolds import Euclidean
from .docutils import insert_docs
from .utils import copy_or_set

__all__ = ["ManifoldTensor", "ManifoldParameter"]

@@ -27,6 +28,7 @@ def __new__(cls, *args, manifold=Euclidean(), requires_grad=False, **kwargs):
instance.manifold = manifold
return instance

@torch.no_grad()
def proj_(self):
"""
Inplace projection to the manifold
@@ -36,9 +38,7 @@ def proj_(self):
tensor
same instance
"""
with torch.no_grad():
self.set_(self.manifold.projx(self.data))
return self
return copy_or_set(self, self.manifold.projx(self))

@insert_docs(Euclidean.retr.__doc__, r"\s+x : .+\n.+", "")
def retr(self, u, t=1.0, order=None):
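A brief, hedged sketch of the reworked proj_ contract (the ManifoldTensor construction here is an assumption for illustration, not part of the diff): the projection is now wrapped in torch.no_grad() and routed through copy_or_set, and the method still returns the same instance as its docstring promises.

import torch
import geoopt

stiefel = geoopt.manifolds.Stiefel()
x = geoopt.ManifoldTensor(torch.randn(5, 3), manifold=stiefel)
y = x.proj_()             # in-place projection onto the manifold, under no_grad
assert y is x             # the "same instance" promise still holds
assert x.is_contiguous()  # strides are respected via copy_or_set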
24 changes: 24 additions & 0 deletions geoopt/utils.py
@@ -0,0 +1,24 @@
__all__ = "copy_or_set"


def copy_or_set(dest, source):
"""
A workaround to respect strides of :code:`dest` when copying :code:`source`
(https://github.com/geoopt/geoopt/issues/70)

Parameters
----------
dest : torch.Tensor
Destination tensor where to store new data
source : torch.Tensor
Source data to put in the new tensor

Returns
-------
dest
torch.Tensor
"""
if dest.stride() != source.stride():
return dest.copy_(source)
else:
return dest.set_(source)
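To make the stride issue behind copy_or_set concrete, here is a small made-up illustration (the tensors are not from the PR) of how set_ and copy_ differ on a non-contiguous destination:

import torch

dest = torch.zeros(2, 3).t()   # shape (3, 2), strides (1, 3): a non-contiguous view
source = torch.ones(3, 2)      # shape (3, 2), strides (2, 1): contiguous

dest.copy_(source)             # copy_ writes the values but keeps dest's strides
print(dest.stride())           # (1, 3)

dest.set_(source)              # set_ rebinds dest to source's storage and strides
print(dest.stride())           # (2, 1): the original layout is lost

copy_or_set therefore falls back to copy_ only when the strides differ, and keeps the cheaper set_ when the layouts already match.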
3 changes: 1 addition & 2 deletions tests/test_adam.py
@@ -26,7 +26,7 @@ def closure():
if (X - Xstar).norm() < 1e-5:
break
optim.step(closure)

assert X.is_contiguous()
np.testing.assert_allclose(X.data, Xstar, atol=1e-5, rtol=1e-5)
optim.load_state_dict(optim.state_dict())
optim.step(closure)
@@ -49,5 +49,4 @@ def closure():

for _ in range(2000):
optim.step(closure)

np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5)
2 changes: 1 addition & 1 deletion tests/test_rhmc.py
@@ -116,7 +116,7 @@ def forward(self):

points = np.asarray(points)
points = points[::20]

assert nd.x.is_contiguous()
np.testing.assert_allclose(mu.numpy(), points.mean(axis=0), atol=1e-1)
np.testing.assert_allclose(sigma.numpy(), points.std(axis=0), atol=1e-1)

3 changes: 2 additions & 1 deletion tests/test_rsgd.py
@@ -34,7 +34,7 @@ def closure():
if (X - Xstar).norm() < 1e-5:
break
optim.step(closure)

assert X.is_contiguous()
np.testing.assert_allclose(X.data, Xstar, atol=1e-5)
optim.load_state_dict(optim.state_dict())
optim.step(closure)
@@ -56,5 +56,6 @@ def test_init_manifold():
opt.zero_grad()
opt.step()
assert not np.allclose(p0.data, p0old.data)
assert p0.is_contiguous()
np.testing.assert_allclose(p1.data, p1old.data)
np.testing.assert_allclose(p0.data, stiefel.projx(p0old.data), atol=1e-4)