
Commit

Revert "Drop radam optimizer. (#377)" (#388)
This reverts commit 4bb5e4b.
jettify authored Oct 31, 2021
1 parent d2ff5c2 commit ef33fad
Showing 13 changed files with 238 additions and 0 deletions.
30 changes: 30 additions & 0 deletions README.rst
@@ -125,6 +125,9 @@ Supported Optimizers
| `QHM`_ | https://arxiv.org/abs/1810.06801 |
+---------------+--------------------------------------------------------------------------------------------------------------------------------------+
| | |
| `RAdam`_ | https://arxiv.org/abs/1908.03265 |
+---------------+--------------------------------------------------------------------------------------------------------------------------------------+
| | |
| `Ranger`_ | https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d |
+---------------+--------------------------------------------------------------------------------------------------------------------------------------+
| | |
@@ -765,6 +768,33 @@ QHM
**Reference Code**: https://github.com/facebookresearch/qhoptim


RAdam
-----

+---------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------+
| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_RAdam.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_RAdam.png |
+---------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------+

.. code:: python

  import torch_optimizer as optim

  # model = ...
  optimizer = optim.RAdam(
      model.parameters(),
      lr=1e-3,
      betas=(0.9, 0.999),
      eps=1e-8,
      weight_decay=0,
  )
  optimizer.step()

**Paper**: *On the Variance of the Adaptive Learning Rate and Beyond* (2019) [https://arxiv.org/abs/1908.03265]

**Reference Code**: https://github.com/LiyuanLucasLiu/RAdam
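
For anyone trying the restored optimizer end to end, below is a minimal training-loop sketch. The toy model, the random data, and the optional ``Lookahead`` wrapping are illustrative assumptions, not part of this commit:

.. code:: python

  import torch
  import torch_optimizer as optim

  # Toy regression model and data, purely for illustration.
  model = torch.nn.Linear(10, 1)
  x, y = torch.randn(64, 10), torch.randn(64, 1)
  loss_fn = torch.nn.MSELoss()

  optimizer = optim.RAdam(
      model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
  )
  # Optionally wrap RAdam in Lookahead (roughly what Ranger does):
  # optimizer = optim.Lookahead(optimizer)

  for _ in range(100):
      optimizer.zero_grad()
      loss_fn(model(x), y).backward()
      optimizer.step()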


Ranger
------

8 changes: 8 additions & 0 deletions docs/api.rst
@@ -97,6 +97,14 @@ QHM
.. autoclass:: torch_optimizer.QHM
:members:

.. _RAdam:

RAdam
-----

.. autoclass:: torch_optimizer.RAdam
:members:

.. _SGDP:

SGDP
3 changes: 3 additions & 0 deletions docs/index.rst
@@ -84,6 +84,9 @@ Supported Optimizers
| :ref:`QHM` | https://arxiv.org/abs/1810.06801 |
+-----------------+-------------------------------------------------------------------------------+
| | |
| :ref:`RAdam` | https://arxiv.org/abs/1908.03265 |
+-----------------+-------------------------------------------------------------------------------+
| | |
| :ref:`Ranger` | https://arxiv.org/abs/1908.00700v2 |
+-----------------+-------------------------------------------------------------------------------+
| | |
Binary file added docs/rastrigin_RAdam.png
Binary file added docs/rosenbrock_RAdam.png
1 change: 1 addition & 0 deletions examples/viz_optimizers.py
@@ -173,6 +173,7 @@ def LookaheadYogi(*a, **kw):
(optim.Lamb, -8, -2.9),
(optim.MADGRAD, -8, 0.5),
(optim.NovoGrad, -8, -1.7),
(optim.RAdam, -8, 0.5),
(optim.Yogi, -8, 0.1),
# SGD/Momentum based
(optim.AccSGD, -8, -1.4),
1 change: 1 addition & 0 deletions setup.py
@@ -57,6 +57,7 @@ def _read_version():
'pid',
'qhadam',
'qhm',
'radam',
'sgdw',
'yogi',
'ranger',
1 change: 1 addition & 0 deletions tests/test_basic.py
@@ -52,6 +52,7 @@ def build_lookahead(*a, **kw):
{'lr': 2.9, 'betas': (0.9, 0.999), 'grad_averaging': True},
900,
),
(optim.RAdam, {'lr': 0.01, 'betas': (0.9, 0.95), 'eps': 1e-3}, 800),
(optim.SGDW, {'lr': 0.002, 'momentum': 0.91}, 900),
(optim.DiffGrad, {'lr': 0.5}, 500),
(optim.AdaMod, {'lr': 1.0}, 800),
1 change: 1 addition & 0 deletions tests/test_optimizer.py
@@ -85,6 +85,7 @@ def build_lookahead(*a, **kw):
optim.PID,
optim.QHAdam,
optim.QHM,
optim.RAdam,
optim.Ranger,
optim.RangerQH,
optim.RangerVA,
1 change: 1 addition & 0 deletions tests/test_optimizer_with_nn.py
@@ -75,6 +75,7 @@ def build_lookahead(*a, **kw):
(optim.PID, {'lr': 0.01, 'weight_decay': 1e-3, 'momentum': 0.1}, 200),
(optim.QHAdam, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
(optim.QHM, {'lr': 0.1, 'weight_decay': 1e-5, 'momentum': 0.2}, 200),
(optim.RAdam, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
(optim.Ranger, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
(optim.RangerQH, {'lr': 0.0124, 'weight_decay': 1e-3}, 1100),
(optim.RangerVA, {'lr': 0.2214, 'weight_decay': 1e-3}, 500),
5 changes: 5 additions & 0 deletions tests/test_param_validation.py
@@ -23,6 +23,7 @@ def assert_sparse_not_supported(optimizer_class, err_msg=None):
optim.DiffGrad,
optim.Lamb,
optim.NovoGrad,
optim.RAdam,
optim.Yogi,
]

@@ -48,6 +49,7 @@ def test_sparse_not_supported(optimizer_class):
optim.PID,
optim.QHAdam,
optim.QHM,
optim.RAdam,
optim.SGDP,
optim.SGDW,
optim.SWATS,
@@ -77,6 +79,7 @@ def test_learning_rate(optimizer_class):
optim.MADGRAD,
optim.NovoGrad,
optim.QHAdam,
optim.RAdam,
optim.SGDP,
optim.SWATS,
optim.Yogi,
@@ -109,6 +112,7 @@ def test_eps_validation(optimizer_class):
optim.PID,
optim.QHAdam,
optim.QHM,
optim.RAdam,
optim.SGDP,
optim.SGDW,
optim.SWATS,
@@ -135,6 +139,7 @@ def test_weight_decay_validation(optimizer_class):
optim.Lamb,
optim.NovoGrad,
optim.QHAdam,
optim.RAdam,
optim.Yogi,
]

3 changes: 3 additions & 0 deletions torch_optimizer/__init__.py
@@ -38,6 +38,7 @@
from .pid import PID
from .qhadam import QHAdam
from .qhm import QHM
from .radam import RAdam
from .sgdp import SGDP
from .sgdw import SGDW
from .shampoo import Shampoo
@@ -66,6 +67,7 @@
'PID',
'QHAdam',
'QHM',
'RAdam',
'Ranger',
'RangerQH',
'RangerVA',
@@ -96,6 +98,7 @@
PID,
QHAdam,
QHM,
RAdam,
Ranger,
RangerQH,
RangerVA,
184 changes: 184 additions & 0 deletions torch_optimizer/radam.py
@@ -0,0 +1,184 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from .types import Betas2, OptFloat, OptLossClosure, Params

__all__ = ('RAdam',)


class RAdam(Optimizer):
    r"""Implements RAdam optimization algorithm.

    It has been proposed in `On the Variance of the Adaptive Learning
    Rate and Beyond`__.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-3)
        betas: coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps: term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay: weight decay (L2 penalty) (default: 0)

    Example:
        >>> import torch_optimizer as optim
        >>> optimizer = optim.RAdam(model.parameters(), lr=0.1)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ https://arxiv.org/abs/1908.03265

    Note:
        Reference code: https://github.com/LiyuanLucasLiu/RAdam
    """

    def __init__(
        self,
        params: Params,
        lr: float = 1e-3,
        betas: Betas2 = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
    ) -> None:
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if eps < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 0: {}'.format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 1: {}'.format(betas[1])
            )
        if weight_decay < 0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay)
            )

        # Parameter groups that override ``betas`` get their own cache of
        # rectification terms, since the cached step sizes depend on betas.
        if (
            isinstance(params, (list, tuple))
            and len(params) > 0
            and isinstance(params[0], dict)
        ):
            for param in params:
                if 'betas' in param and (
                    param['betas'][0] != betas[0]
                    or param['betas'][1] != betas[1]
                ):
                    param['buffer'] = [[None, None, None] for _ in range(10)]

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)],
        )
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            beta1, beta2 = group['betas']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    msg = (
                        'RAdam does not support sparse gradients, '
                        'please consider SparseAdam instead'
                    )
                    raise RuntimeError(msg)

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(
                        p_data_fp32, memory_format=torch.preserve_format
                    )
                    state['exp_avg_sq'] = torch.zeros_like(
                        p_data_fp32, memory_format=torch.preserve_format
                    )
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
                        p_data_fp32
                    )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Update biased first and second moment estimates.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                # The rectification term depends only on the step count and
                # betas, so it is cached (keyed by step % 10) and reused for
                # every parameter that is at the same step.
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    # N_sma is the length of the approximated SMA (rho_t in
                    # the paper); N_sma_max is its limit rho_infinity.
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (
                        1 - beta2_t
                    )
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = (
                            lr
                            * math.sqrt(
                                (1 - beta2_t)
                                * (N_sma - 4)
                                / (N_sma_max - 4)
                                * (N_sma - 2)
                                / N_sma
                                * N_sma_max
                                / (N_sma_max - 2)
                            )
                            / (1 - beta1 ** state['step'])
                        )
                    else:
                        step_size = lr / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if weight_decay != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    # Variance of the adaptive learning rate is tractable:
                    # apply the rectified adaptive update.
                    denom = exp_avg_sq.sqrt().add_(eps)
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    # Early steps: fall back to an un-adapted momentum update.
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
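
To make the warmup behaviour of ``step`` concrete, here is a small standalone sketch, not part of the commit, that mirrors the formulas above with the default ``betas`` and ``lr`` and prints the approximated SMA length and step size for a few step counts:

.. code:: python

  import math

  beta1, beta2, lr = 0.9, 0.999, 1e-3
  rho_inf = 2 / (1 - beta2) - 1  # N_sma_max in the code above

  for t in (1, 2, 5, 10, 100, 1000):
      beta2_t = beta2 ** t
      rho_t = rho_inf - 2 * t * beta2_t / (1 - beta2_t)  # N_sma
      if rho_t >= 5:
          # Rectified adaptive step, as in the N_sma >= 5 branch above.
          rect = math.sqrt(
              (rho_t - 4) / (rho_inf - 4) * (rho_t - 2) / rho_t
              * rho_inf / (rho_inf - 2)
          )
          step_size = lr * math.sqrt(1 - beta2_t) * rect / (1 - beta1 ** t)
          kind = 'rectified adaptive'
      else:
          # Early steps: plain bias-corrected momentum step.
          step_size = lr / (1 - beta1 ** t)
          kind = 'plain momentum'
      print(f't={t:4d}  N_sma={rho_t:8.2f}  {kind}  step_size={step_size:.6f}')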
