Skip to content

Commit

Permalink
more efficient workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed May 16, 2019
1 parent c46d278 commit d191594
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 15 deletions.
5 changes: 3 additions & 2 deletions geoopt/optim/radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .tracing import create_traced_update
from ..tensor import ManifoldParameter, ManifoldTensor
from ..manifolds import Euclidean
from ..utils import copy_or_set


class RiemannianAdam(OptimMixin, torch.optim.Adam):
Expand Down Expand Up @@ -190,7 +191,7 @@ def perform_step(
point, exp_avg, u=direction, t=-step_size
)
# use copy only for user facing point
point.copy_(new_point)
copy_or_set(point, new_point)
exp_avg.set_(exp_avg_new)

def stabilize_group(self, group):
Expand All @@ -203,7 +204,7 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
exp_avg = state["exp_avg"]
p.copy_(manifold.projx(p))
copy_or_set(p, manifold.projx(p))
exp_avg.set_(manifold.proju(p, exp_avg))

def _sanitize_group(self, group):
Expand Down
8 changes: 4 additions & 4 deletions geoopt/optim/rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ..tensor import ManifoldParameter, ManifoldTensor
from .mixin import OptimMixin
from .tracing import create_traced_update

from ..utils import copy_or_set

__all__ = ["RiemannianSGD"]

Expand Down Expand Up @@ -169,10 +169,10 @@ def perform_step(
)
momentum_buffer.set_(new_momentum_buffer)
# use copy only for user facing point
point.copy_(new_point)
copy_or_set(point, new_point)
else:
new_point = manifold.retr(point, grad, -lr)
point.copy_(new_point)
copy_or_set(point, new_point)

def stabilize_group(self, group):
with torch.no_grad():
Expand All @@ -181,7 +181,7 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
momentum = group["momentum"]
p.copy_(manifold.projx(p))
copy_or_set(p, manifold.projx(p))
if momentum > 0:
param_state = self.state[p]
if not param_state: # due to None grads
Expand Down
4 changes: 2 additions & 2 deletions geoopt/samplers/rhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.manifolds import Euclidean
from geoopt.samplers.base import Sampler

from ..utils import copy_or_set

__all__ = ["RHMC"]

Expand Down Expand Up @@ -40,7 +40,7 @@ def _step(self, p, r, epsilon):

r.add_(epsilon * egrad2rgrad(p, p.grad))
p_, r_ = retr_transp(p, r, u=r, t=epsilon)
p.copy_(p_)
copy_or_set(p, p_)
r.set_(r_)

def step(self, closure):
Expand Down
5 changes: 2 additions & 3 deletions geoopt/samplers/rsgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.manifolds import Euclidean
from geoopt.samplers.base import Sampler

from ..utils import copy_or_set

__all__ = ["RSGLD"]

Expand Down Expand Up @@ -66,5 +66,4 @@ def stabilize(self):
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue

p.copy_(p.manifold.projx(p))
copy_or_set(p, p.manifold.projx(p))
7 changes: 3 additions & 4 deletions geoopt/samplers/sgrhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.manifolds import Euclidean
from geoopt.samplers.base import Sampler

from ..utils import copy_or_set

__all__ = ["SGRHMC"]

Expand Down Expand Up @@ -76,7 +76,7 @@ def step(self, closure):
v = self.state[p]["v"]

p_, v_ = retr_transp(p, v, u=v, t=1.0)
p.copy_(p_)
copy_or_set(p, p_)
v.set_(v_)

n = egrad2rgrad(p, torch.randn_like(v))
Expand All @@ -103,7 +103,6 @@ def stabilize(self):

manifold = p.manifold
v = self.state[p]["v"]

p.copy_(manifold.projx(p))
copy_or_set(p, manifold.projx(p))
# proj here is ok
v.set_(manifold.proju(p, v))
23 changes: 23 additions & 0 deletions geoopt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ``__all__`` must be a sequence of names. A bare string is iterated character
# by character by ``from ... import *`` and fails with AttributeError.
__all__ = ["copy_or_set"]


def copy_or_set(dest, source):
    """
    Store ``source`` into ``dest`` in place, respecting the strides of ``dest``.

    A workaround to respect strides of :code:`dest` when copying :code:`source`
    (https://github.com/geoopt/geoopt/issues/70)

    When the strides of the two tensors differ, the values are written
    element-wise with :code:`dest.copy_(source)`, so :code:`dest` keeps its own
    storage and strides. When the strides are identical, :code:`dest` is
    re-pointed at the data of :code:`source` with :code:`dest.set_(source)`,
    which avoids the element-wise copy; afterwards both tensors share storage.

    Parameters
    ----------
    dest : torch.Tensor
        Destination tensor where to store new data
    source : torch.Tensor
        Source data to put in the new tensor

    Returns
    -------
    dest
        torch.Tensor, modified in-place
    """
    if dest.stride() != source.stride():
        # Layouts disagree: write values only, leaving dest's layout untouched.
        return dest.copy_(source)
    else:
        # Layouts agree: rebinding the storage is cheaper than copying.
        return dest.set_(source)

0 comments on commit d191594

Please sign in to comment.