Skip to content

Commit

Permalink
adds egrad2rgrad
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Jan 28, 2019
1 parent 46749e5 commit b6074b2
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 12 deletions.
30 changes: 30 additions & 0 deletions geoopt/manifolds/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class Manifold(metaclass=abc.ABCMeta):
* ``_retr_transp(x, u, t, *vs)`` desired
Combines ``_transp_many(x, u, t, *vs)`` and ``_retr(x, u, t)``
* ``__eq__(other)`` if needed
Checks if manifolds are the same
* ``_egrad2rgrad(u)`` if needed
Transforms euclidean grad to Riemannian gradient. It is the same as projection in most cases
Notes
-----
Expand Down Expand Up @@ -349,6 +352,24 @@ def proju(self, x, u):
"""
return self._proju(x, u)

def egrad2rgrad(self, x, u):
"""
Embed euclidean gradient into Riemannian manifold
Parameters
----------
x : tensor
point on the manifold
u : tensor
gradient to be projected
Returns
-------
tensor
grad vector in the Riemainnian manifold
"""
return self._egrad2rgrad(x, u)

def projx(self, x):
"""
Project point :math:`x` on the manifold
Expand Down Expand Up @@ -545,6 +566,15 @@ def _projx(self, x):
"""
raise NotImplementedError

def _egrad2rgrad(self, x, u):
"""
Developer Guide
Private implementation for gradient transformation, may do things efficiently in some cases.
Should allow broadcasting.
"""
return self._proju(x, u)

def __repr__(self):
return self.name + " manifold"

Expand Down
2 changes: 1 addition & 1 deletion geoopt/optim/radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def perform_step(
amsgrad,
):
grad.add_(weight_decay, point)
grad = manifold.proju(point, grad)
grad = manifold.egrad2rgrad(point, grad)
exp_avg.mul_(betas[0]).add_(1 - betas[0], grad)
exp_avg_sq.mul_(betas[1]).add_(1 - betas[1], manifold.inner(point, grad))
if amsgrad:
Expand Down
2 changes: 1 addition & 1 deletion geoopt/optim/rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def perform_step(
use_momentum,
):
grad.add_(weight_decay, point)
grad = manifold.proju(point, grad)
grad = manifold.egrad2rgrad(point, grad)
if use_momentum:
momentum_buffer.mul_(momentum).add_(1 - dampening, grad)
if nesterov:
Expand Down
12 changes: 6 additions & 6 deletions geoopt/samplers/rhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def _step(self, p, r, epsilon):
else:
manifold = Euclidean()

proju = manifold.proju
egrad2rgrad = manifold.egrad2rgrad
retr_transp = manifold.retr_transp

r.add_(epsilon * proju(p, p.grad))
r.add_(epsilon * egrad2rgrad(p, p.grad))
p_, r_ = retr_transp(p, r, epsilon, r)
p.set_(p_)
r.set_(r_)
Expand Down Expand Up @@ -69,7 +69,7 @@ def step(self, closure):
else:
manifold = Euclidean()

proju = manifold.proju
egrad2rgrad = manifold.egrad2rgrad
state = self.state[p]

if "r" not in state:
Expand All @@ -79,7 +79,7 @@ def step(self, closure):

r = state["r"]
r.normal_()
r.set_(proju(p, r))
r.set_(egrad2rgrad(p, r))

old_H += 0.5 * (r * r).sum().item()

Expand Down Expand Up @@ -118,10 +118,10 @@ def step(self, closure):
else:
manifold = Euclidean()

proju = manifold.proju
egrad2rgrad = manifold.egrad2rgrad

r = self.state[p]["r"]
r.add_(0.5 * epsilon * proju(p, p.grad))
r.add_(0.5 * epsilon * egrad2rgrad(p, p.grad))
p.grad.zero_()

new_H += 0.5 * (r * r).sum().item()
Expand Down
4 changes: 2 additions & 2 deletions geoopt/samplers/rsgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def step(self, closure):
else:
manifold = Euclidean()

proju, retr = manifold.proju, manifold.retr
egrad2rgrad, retr = manifold.egrad2rgrad, manifold.retr
epsilon = group["epsilon"]

n = torch.randn_like(p).mul_(math.sqrt(epsilon))
r = proju(p, 0.5 * epsilon * p.grad + n)
r = egrad2rgrad(p, 0.5 * epsilon * p.grad + n)

p.set_(retr(p, r, 1.0))
p.grad.zero_()
Expand Down
5 changes: 3 additions & 2 deletions geoopt/samplers/sgrhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def step(self, closure):
else:
manifold = Euclidean()

proju = manifold.proju
egrad2rgrad = manifold.egrad2rgrad
retr_transp = manifold.retr_transp

epsilon, alpha = group["epsilon"], group["alpha"]
Expand All @@ -79,7 +79,7 @@ def step(self, closure):
p.set_(p_)
v.set_(v_)

n = proju(p, torch.randn_like(v))
n = egrad2rgrad(p, torch.randn_like(v))
v.mul_(1 - alpha).add_(epsilon * p.grad).add_(
math.sqrt(2 * alpha * epsilon) * n
)
Expand All @@ -105,4 +105,5 @@ def stabilize(self):
v = self.state[p]["v"]

p.set_(manifold.projx(p))
# proj here is ok
v.set_(manifold.proju(p, v))

0 comments on commit b6074b2

Please sign in to comment.