Commit
Maxim Kochurov committed on Jan 28, 2020
1 parent 523ba0f · commit 7d5c477
Showing 8 changed files with 405 additions and 1 deletion.
@@ -1,2 +1,4 @@
from .rsgd import RiemannianSGD
from .radam import RiemannianAdam
from .sparse_radam import SparseRiemannianAdam
from .sparse_rsgd import SparseRiemannianSGD
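The package `__init__` now re-exports the two new sparse optimizers next to their dense counterparts. A minimal construction sketch, assuming the same embedding-style setup as the test file further down (the hyperparameter values are illustrative only):

import geoopt

manifold = geoopt.PoincareBall()
emb = geoopt.ManifoldParameter(manifold.random(10, 2), manifold=manifold)

# both optimizers expect parameters whose .grad is a sparse tensor,
# e.g. produced by torch.nn.functional.embedding(..., sparse=True)
adam = geoopt.optim.SparseRiemannianAdam([emb], lr=1e-3)
sgd = geoopt.optim.SparseRiemannianSGD([emb], lr=1e-2, momentum=0.9)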
@@ -0,0 +1,162 @@
import torch.optim

from .mixin import OptimMixin, SparseMixin
from ..tensor import ManifoldParameter, ManifoldTensor
from ..utils import copy_or_set_


__all__ = ["SparseRiemannianAdam"]


class SparseRiemannianAdam(OptimMixin, SparseMixin, torch.optim.Optimizer):
    r"""
    Implements a lazy version of the Adam algorithm suitable for sparse gradients.

    In this variant, only moments that show up in the gradient get updated, and
    only those portions of the gradient get applied to the parameters.

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining
        parameter groups
    lr : float (optional)
        learning rate (default: 1e-3)
    betas : Tuple[float, float] (optional)
        coefficients used for computing
        running averages of the gradient and its square (default: (0.9, 0.999))
    eps : float (optional)
        term added to the denominator to improve
        numerical stability (default: 1e-8)
    amsgrad : bool (optional)
        whether to use the AMSGrad variant of this
        algorithm from the paper `On the Convergence of Adam and Beyond`_
        (default: False)

    Other Parameters
    ----------------
    stabilize : int
        Stabilize parameters if they drift off the manifold for numerical
        reasons, every ``stabilize`` steps (default: ``None`` -- no stabilization)

    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, stabilize=None
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, amsgrad=amsgrad)
        # pass ``stabilize`` through to OptimMixin so the documented option takes effect
        super(SparseRiemannianAdam, self).__init__(params, defaults, stabilize=stabilize)

    def __setstate__(self, state):
        super(SparseRiemannianAdam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        with torch.no_grad():
            for group in self.param_groups:
                if "step" not in group:
                    group["step"] = 0
                betas = group["betas"]
                eps = group["eps"]
                learning_rate = group["lr"]
                amsgrad = group["amsgrad"]
                for point in group["params"]:
                    grad = point.grad
                    if grad is None:
                        continue
                    if isinstance(point, (ManifoldParameter, ManifoldTensor)):
                        manifold = point.manifold
                    else:
                        manifold = self._default_manifold

                    if not grad.is_sparse:
                        raise RuntimeError(
                            "SparseRiemannianAdam only supports sparse gradients, use RiemannianAdam instead"
                        )
                    rows = grad.coalesce().indices()[0].unique()
                    state = self.state[point]

                    # State initialization
                    if len(state) == 0:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(point)
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq"] = torch.zeros_like(point)
                        if amsgrad:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state["max_exp_avg_sq"] = torch.zeros_like(point)

                    full_point = point
                    # only the nonzero rows are required to make an update
                    grad = grad.index_select(0, rows).to_dense()
                    # advanced indexing returns a copy, not a view; the updated rows are written back below
                    point = point[rows]
                    exp_avg = state["exp_avg"][rows]
                    exp_avg_sq = state["exp_avg_sq"][rows]
                    # actual step
                    grad = manifold.egrad2rgrad(point, grad)
                    exp_avg.mul_(betas[0]).add_(1 - betas[0], grad)
                    exp_avg_sq.mul_(betas[1]).add_(
                        1 - betas[1], manifold.component_inner(point, grad)
                    )
                    if amsgrad:
                        max_exp_avg_sq = state["max_exp_avg_sq"][rows]
                        # Maintains the maximum of all 2nd moment running avg. till now
                        torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                        # Use the max. for normalizing running avg. of gradient
                        denom = max_exp_avg_sq.sqrt().add_(eps)
                        # do not forget to update the state
                        state["max_exp_avg_sq"][rows] = max_exp_avg_sq
                    else:
                        denom = exp_avg_sq.sqrt().add_(eps)
                    group["step"] += 1
                    bias_correction1 = 1 - betas[0] ** group["step"]
                    bias_correction2 = 1 - betas[1] ** group["step"]
                    step_size = (
                        learning_rate * bias_correction2 ** 0.5 / bias_correction1
                    )

                    # compute the update direction in the tangent space
                    direction = exp_avg / denom
                    # retract along the negative direction and transport the exponential average to the new point
                    new_point, exp_avg_new = manifold.retr_transp(
                        point, -step_size * direction, exp_avg
                    )
                    # now we update all full tensors
                    full_point[rows] = new_point
                    state["exp_avg"][rows] = exp_avg_new
                    state["exp_avg_sq"][rows] = exp_avg_sq

                if self._stabilize is not None and group["step"] % self._stabilize == 0:
                    self.stabilize_group(group)
        return loss

    @torch.no_grad()
    def stabilize_group(self, group):
        for p in group["params"]:
            if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
                continue
            state = self.state[p]
            if not state:  # due to None grads
                continue
            manifold = p.manifold
            exp_avg = state["exp_avg"]
            copy_or_set_(p, manifold.projx(p))
            exp_avg.set_(manifold.proju(p, exp_avg))
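The lazy update above hinges on extracting the set of rows that actually received gradient and touching only those slices of the parameter and the moment buffers. A small standalone sketch of that row-selection step in plain PyTorch (the shapes and values are illustrative):

import torch

# a sparse gradient touching rows 2 and 7 of a 10 x 3 embedding table
grad = torch.sparse_coo_tensor(
    torch.tensor([[2, 7, 7], [0, 1, 2]]),
    torch.tensor([1.0, 2.0, 3.0]),
    (10, 3),
)

rows = grad.coalesce().indices()[0].unique()        # tensor([2, 7])
dense_rows = grad.index_select(0, rows).to_dense()  # shape (2, 3)
# only these rows (and the matching slices of exp_avg / exp_avg_sq)
# take part in the Adam update; every other row is left untouched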
@@ -0,0 +1,128 @@
import torch.optim.optimizer

from ..tensor import ManifoldParameter, ManifoldTensor
from .mixin import OptimMixin, SparseMixin
from ..utils import copy_or_set_

__all__ = ["SparseRiemannianSGD"]


class SparseRiemannianSGD(OptimMixin, SparseMixin, torch.optim.Optimizer):
    r"""
    Implements a lazy version of the SGD algorithm suitable for sparse gradients.

    In this variant, only moments that show up in the gradient get updated, and
    only those portions of the gradient get applied to the parameters.

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining
        parameter groups
    lr : float
        learning rate
    momentum : float (optional)
        momentum factor (default: 0)
    dampening : float (optional)
        dampening for momentum (default: 0)
    nesterov : bool (optional)
        enables Nesterov momentum (default: False)

    Other Parameters
    ----------------
    stabilize : int
        Stabilize parameters if they drift off the manifold for numerical
        reasons, every ``stabilize`` steps (default: ``None`` -- no stabilization)
    """

    def __init__(
        self, params, lr, momentum=0, dampening=0, nesterov=False, stabilize=None,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))

        defaults = dict(
            lr=lr, momentum=momentum, dampening=dampening, nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults, stabilize=stabilize)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        with torch.no_grad():
            for group in self.param_groups:
                if "step" not in group:
                    group["step"] = 0
                momentum = group["momentum"]
                dampening = group["dampening"]
                nesterov = group["nesterov"]
                learning_rate = group["lr"]
                for point in group["params"]:
                    grad = point.grad
                    if grad is None:
                        continue
                    if not grad.is_sparse:
                        raise RuntimeError(
                            "SparseRiemannianSGD only supports sparse gradients, use RiemannianSGD instead"
                        )
                    # select rows that contain gradient
                    rows = grad.coalesce().indices()[0].unique()
                    state = self.state[point]

                    # State initialization
                    if len(state) == 0:
                        if momentum > 0:
                            state["momentum_buffer"] = grad.to_dense().clone()
                    if isinstance(point, (ManifoldParameter, ManifoldTensor)):
                        manifold = point.manifold
                    else:
                        manifold = self._default_manifold

                    full_point = point
                    # only the nonzero rows are required to make an update
                    grad = grad.index_select(0, rows).to_dense()
                    point = point[rows]

                    grad = manifold.egrad2rgrad(point, grad)
                    if momentum > 0:
                        momentum_buffer = state["momentum_buffer"][rows]
                        momentum_buffer.mul_(momentum).add_(1 - dampening, grad)
                        if nesterov:
                            grad = grad.add_(momentum, momentum_buffer)
                        else:
                            grad = momentum_buffer
                        # gradient and momentum are already projected onto the tangent space
                        new_point, new_momentum_buffer = manifold.retr_transp(
                            point, -learning_rate * grad, momentum_buffer
                        )
                        # write the updated rows back into the full tensors
                        state["momentum_buffer"][rows] = new_momentum_buffer
                        full_point[rows] = new_point
                    else:
                        new_point = manifold.retr(point, -learning_rate * grad)
                        full_point[rows] = new_point

                group["step"] += 1
                if self._stabilize is not None and group["step"] % self._stabilize == 0:
                    self.stabilize_group(group)
        return loss

    @torch.no_grad()
    def stabilize_group(self, group):
        for p in group["params"]:
            if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
                continue
            manifold = p.manifold
            momentum = group["momentum"]
            copy_or_set_(p, manifold.projx(p))
            if momentum > 0:
                param_state = self.state[p]
                if not param_state:  # due to None grads
                    continue
                if "momentum_buffer" in param_state:
                    buf = param_state["momentum_buffer"]
                    buf.set_(manifold.proju(p, buf))
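The SGD variant is driven the same way as the Adam one. A hedged end-to-end sketch, mirroring the sparse-embedding setup used in the tests below (hyperparameters and iteration count are illustrative, not tuned):

import torch
import geoopt

manifold = geoopt.PoincareBall()
target = manifold.random(10, 2)
emb = geoopt.ManifoldParameter(manifold.random(10, 2), manifold=manifold)
optim = geoopt.optim.SparseRiemannianSGD([emb], lr=1e-1, momentum=0.9)

for _ in range(100):
    idx = torch.randint(10, size=(3,))
    # sparse=True makes emb.grad a sparse tensor, which is what this optimizer expects
    pred = torch.nn.functional.embedding(idx, emb, sparse=True)
    tgt = torch.nn.functional.embedding(idx, target, sparse=True)
    loss = manifold.dist2(pred, tgt).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()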
@@ -0,0 +1,37 @@
import geoopt
import torch
import numpy as np
import pytest


@pytest.mark.parametrize("params", [dict(lr=1e-1), dict(lr=1, amsgrad=True)])
def test_adam_poincare(params):
    torch.manual_seed(44)
    manifold = geoopt.PoincareBall()
    ideal = manifold.random(10, 2)
    start = manifold.random(10, 2)
    start = geoopt.ManifoldParameter(start, manifold=manifold)

    def closure():
        idx = torch.randint(10, size=(3,))
        start_select = torch.nn.functional.embedding(idx, start, sparse=True)
        ideal_select = torch.nn.functional.embedding(idx, ideal, sparse=True)
        optim.zero_grad()
        loss = manifold.dist2(start_select, ideal_select).sum()
        loss.backward()
        assert start.grad.is_sparse
        return loss.item()

    optim = geoopt.optim.SparseRiemannianAdam([start], **params)

    for _ in range(2000):
        optim.step(closure)
    np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5)


def test_incorrect_init():
    manifold = geoopt.PoincareBall()
    param = manifold.random(2, 10, 2).requires_grad_()
    with pytest.raises(ValueError) as e:
        geoopt.optim.SparseRiemannianAdam([param])
    assert e.match("should be matrix valued")