Skip to content

Commit

Permalink
remove unsafe copy or set (#180)
Browse files Browse the repository at this point in the history
* remove unsafe copy or set

* fix linters
  • Loading branch information
ferrine committed Jul 1, 2021
1 parent b9e1b0b commit db5ce84
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 63 deletions.
9 changes: 4 additions & 5 deletions geoopt/optim/radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from .mixin import OptimMixin
from ..tensor import ManifoldParameter, ManifoldTensor
from ..utils import copy_or_set_


__all__ = ["RiemannianAdam"]
Expand Down Expand Up @@ -115,8 +114,8 @@ def step(self, closure=None):
point, -step_size * direction, exp_avg
)
# use copy only for user facing point
copy_or_set_(point, new_point)
exp_avg.set_(exp_avg_new)
point.copy_(new_point)
exp_avg.copy_(exp_avg_new)

if (
group["stabilize"] is not None
Expand All @@ -135,5 +134,5 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
exp_avg = state["exp_avg"]
copy_or_set_(p, manifold.projx(p))
exp_avg.set_(manifold.proju(p, exp_avg))
p.copy_(manifold.projx(p))
exp_avg.copy_(manifold.proju(p, exp_avg))
3 changes: 1 addition & 2 deletions geoopt/optim/rlinesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .mixin import OptimMixin
from ..tensor import ManifoldParameter, ManifoldTensor
from ..manifolds import Euclidean
from ..utils import copy_or_set_


__all__ = ["RiemannianLineSearch"]
Expand Down Expand Up @@ -543,7 +542,7 @@ def stabilize_group(self, group):
if not state: # due to None grads
continue
manifold = p.manifold
copy_or_set_(p, manifold.projx(p))
p.copy_(manifold.projx(p))


#################################################################################
Expand Down
11 changes: 5 additions & 6 deletions geoopt/optim/rsgd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch.optim.optimizer
from ..tensor import ManifoldParameter, ManifoldTensor
from .mixin import OptimMixin
from ..utils import copy_or_set_

__all__ = ["RiemannianSGD"]

Expand Down Expand Up @@ -107,12 +106,12 @@ def step(self, closure=None):
new_point, new_momentum_buffer = manifold.retr_transp(
point, -learning_rate * grad, momentum_buffer
)
momentum_buffer.set_(new_momentum_buffer)
momentum_buffer.copy_(new_momentum_buffer)
# use copy only for user facing point
copy_or_set_(point, new_point)
point.copy_(new_point)
else:
new_point = manifold.retr(point, -learning_rate * grad)
copy_or_set_(point, new_point)
point.copy_(new_point)

if (
group["stabilize"] is not None
Expand All @@ -128,11 +127,11 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
momentum = group["momentum"]
copy_or_set_(p, manifold.projx(p))
p.copy_(manifold.projx(p))
if momentum > 0:
param_state = self.state[p]
if not param_state: # due to None grads
continue
if "momentum_buffer" in param_state:
buf = param_state["momentum_buffer"]
buf.set_(manifold.proju(p, buf))
buf.copy_(manifold.proju(p, buf))
5 changes: 2 additions & 3 deletions geoopt/optim/sparse_radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from .mixin import OptimMixin, SparseMixin
from ..tensor import ManifoldParameter, ManifoldTensor
from ..utils import copy_or_set_


__all__ = ["SparseRiemannianAdam"]
Expand Down Expand Up @@ -159,5 +158,5 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
exp_avg = state["exp_avg"]
copy_or_set_(p, manifold.projx(p))
exp_avg.set_(manifold.proju(p, exp_avg))
p.copy_(manifold.projx(p))
exp_avg.copy_(manifold.proju(p, exp_avg))
5 changes: 2 additions & 3 deletions geoopt/optim/sparse_rsgd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch.optim.optimizer
from ..tensor import ManifoldParameter, ManifoldTensor
from .mixin import OptimMixin, SparseMixin
from ..utils import copy_or_set_

__all__ = ["SparseRiemannianSGD"]

Expand Down Expand Up @@ -130,11 +129,11 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
momentum = group["momentum"]
copy_or_set_(p, manifold.projx(p))
p.copy_(manifold.projx(p))
if momentum > 0:
param_state = self.state[p]
if not param_state: # due to None grads
continue
if "momentum_buffer" in param_state:
buf = param_state["momentum_buffer"]
buf.set_(manifold.proju(p, buf))
buf.copy_(manifold.proju(p, buf))
9 changes: 4 additions & 5 deletions geoopt/samplers/rhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.samplers.base import Sampler
from ..utils import copy_or_set_

__all__ = ["RHMC"]

Expand Down Expand Up @@ -40,8 +39,8 @@ def _step(self, p, r, epsilon):

r.add_(epsilon * egrad2rgrad(p, p.grad))
p_, r_ = retr_transp(p, r * epsilon, r)
copy_or_set_(p, p_)
r.set_(r_)
p.copy_(p_)
r.copy_(r_)

def step(self, closure):
logp = closure()
Expand Down Expand Up @@ -146,8 +145,8 @@ def stabilize_group(self, group):
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue
copy_or_set_(p, p.manifold.projx(p))
p.copy_(p.manifold.projx(p))
state = self.state[p]
if not state: # due to None grads
continue
copy_or_set_(state["old_p"], p.manifold.projx(state["old_p"]))
state["old_p"].copy_(p.manifold.projx(state["old_p"]))
5 changes: 2 additions & 3 deletions geoopt/samplers/rsgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.samplers.base import Sampler
from ..utils import copy_or_set_

__all__ = ["RSGLD"]

Expand Down Expand Up @@ -43,7 +42,7 @@ def step(self, closure):
n = torch.randn_like(p).mul_(math.sqrt(epsilon))
r = egrad2rgrad(p, 0.5 * epsilon * p.grad + n)
# use copy only for user facing point
copy_or_set_(p, retr(p, r))
p.copy_(retr(p, r))
p.grad.zero_()

if not self.burnin:
Expand All @@ -55,4 +54,4 @@ def stabilize_group(self, group):
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue
copy_or_set_(p, p.manifold.projx(p))
p.copy_(p.manifold.projx(p))
9 changes: 4 additions & 5 deletions geoopt/samplers/sgrhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.samplers.base import Sampler
from ..utils import copy_or_set_

__all__ = ["SGRHMC"]

Expand Down Expand Up @@ -68,8 +67,8 @@ def step(self, closure):
v = self.state[p]["v"]

p_, v_ = retr_transp(p, v, v)
copy_or_set_(p, p_)
v.set_(v_)
p.copy_(p_)
v.copy_(v_)

n = egrad2rgrad(p, torch.randn_like(v))
v.mul_(1 - alpha).add_(epsilon * p.grad).add_(
Expand All @@ -91,9 +90,9 @@ def stabilize_group(self, group):
continue

manifold = p.manifold
copy_or_set_(p, manifold.projx(p))
p.copy_(manifold.projx(p))
# proj here is ok
state = self.state[p]
if not state:
continue
state["v"].set_(manifold.proju(p, state["v"]))
state["v"].copy_(manifold.proju(p, state["v"]))
3 changes: 1 addition & 2 deletions geoopt/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .docutils import insert_docs
import functools
from typing import Union, Tuple
from .utils import copy_or_set_

__all__ = ["ManifoldTensor", "ManifoldParameter"]

Expand Down Expand Up @@ -52,7 +51,7 @@ def proj_(self) -> torch.Tensor:
tensor
same instance
"""
return copy_or_set_(self, self.manifold.projx(self))
return self.copy_(self.manifold.projx(self))

@insert_docs(Manifold.retr.__doc__, r"\s+x : .+\n.+", "")
def retr(self, u: torch.Tensor, **kwargs) -> torch.Tensor:
Expand Down
29 changes: 0 additions & 29 deletions geoopt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import geoopt

__all__ = [
"copy_or_set_",
"strip_tuple",
"size2shape",
"make_tuple",
Expand All @@ -24,34 +23,6 @@
]


def copy_or_set_(dest: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    """
    Copy the values of :code:`source` into :code:`dest` in place.

    Historically this helper called :code:`dest.set_(source)` whenever the
    strides of the two tensors matched, because rebinding storage is more
    efficient than copying (see the original discussion
    `here <https://github.com/geoopt/geoopt/issues/70>`_). That shortcut is
    unsafe: :code:`set_` makes :code:`dest` share the storage of
    :code:`source` instead of writing into the memory :code:`dest` already
    owns. As a result

    * later in-place changes to :code:`source` leak into :code:`dest`;
    * a :code:`dest` that is a view of a larger buffer is silently detached
      from that buffer;
    * equal strides do not imply equal shapes, so :code:`dest` could be
      resized without warning.

    The function therefore always uses :code:`copy_`, which respects the
    strides and storage of :code:`dest`. The name is kept only for backward
    compatibility with existing callers.

    Parameters
    ----------
    dest : torch.Tensor
        Destination tensor where to store new data
    source : torch.Tensor
        Source data to put in the destination tensor

    Returns
    -------
    dest
        torch.Tensor, modified inplace
    """
    # copy_ writes element-wise into dest's existing storage, so dest keeps
    # its own memory, shape and strides regardless of how source is laid out.
    return dest.copy_(source)


def strip_tuple(tup: Tuple) -> Union[Tuple, Any]:
if len(tup) == 1:
return tup[0]
Expand Down

0 comments on commit db5ce84

Please sign in to comment.