Merge pull request #102 from kozistr/feature/stable-weight-decay
[Feature] Stable Weight Decay
kozistr committed Feb 4, 2023
2 parents 75a023a + 115906a commit de06f63
Showing 7 changed files with 182 additions and 8 deletions.
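
For orientation, the update that the new AdamS optimizer applies can be read off the adams.py hunk further below; the following is a reconstruction in my own notation, not text from the commit. With m_t and v_t the first and second moment estimates, \hat{v}_t = v_t / (1 - \beta_2^t), N the total number of parameters, \eta the learning rate and \lambda the weight decay, the step is roughly

    \bar{v} = \sqrt{\tfrac{1}{N} \textstyle\sum_{i=1}^{N} \hat{v}_{t,i}} + \epsilon, \qquad
    \theta \leftarrow \theta \Bigl( 1 - \tfrac{\eta \lambda}{\bar{v}} \Bigr), \qquad
    \theta \leftarrow \theta - \tfrac{\eta}{1 - \beta_1^{t}} \cdot \tfrac{m_t}{\sqrt{\hat{v}_t} + \epsilon}

i.e. the decay is rescaled by a single global second-moment statistic instead of being folded into the per-coordinate adaptive step (and with adamd_debias_term=True the 1 - \beta_1^t correction in the last term is dropped).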
8 changes: 8 additions & 0 deletions docs/optimizer_api.rst
@@ -224,3 +224,11 @@ DAdaptSGD

.. autoclass:: pytorch_optimizer.DAdaptSGD
:members:

.. _AdamS:

AdamS
-----

.. autoclass:: pytorch_optimizer.AdamS
:members:
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -18,6 +18,7 @@
from pytorch_optimizer.optimizer.adabound import AdaBound
from pytorch_optimizer.optimizer.adai import Adai
from pytorch_optimizer.optimizer.adamp import AdamP
from pytorch_optimizer.optimizer.adams import AdamS
from pytorch_optimizer.optimizer.adan import Adan
from pytorch_optimizer.optimizer.adapnm import AdaPNM
from pytorch_optimizer.optimizer.agc import agc
@@ -88,6 +89,7 @@
DAdaptAdaGrad,
DAdaptAdam,
DAdaptSGD,
AdamS,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

19 changes: 14 additions & 5 deletions pytorch_optimizer/optimizer/adai.py
@@ -17,8 +17,8 @@ class Adai(Optimizer, BaseOptimizer):
:param betas: BETAS. coefficients used for computing running averages of gradient and its square.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param dampening: float. dampening for momentum. where dampening < 1,
it will show some adaptive-moment behavior.
:param use_stable_weight_decay: bool. perform stable weight decay.
:param dampening: float. dampening for momentum. where dampening < 1, it will show some adaptive-moment behavior.
:param use_gc: bool. use gradient centralization.
:param eps: float. term added to the denominator to improve numerical stability.
"""
@@ -30,6 +30,7 @@ def __init__(
betas: BETAS = (0.1, 0.99),
weight_decay: float = 0.0,
weight_decouple: bool = False,
use_stable_weight_decay: bool = False,
dampening: float = 1.0,
use_gc: bool = False,
eps: float = 1e-3,
@@ -38,6 +39,7 @@
self.betas = betas
self.weight_decay = weight_decay
self.weight_decouple = weight_decouple
self.use_stable_weight_decay = use_stable_weight_decay
self.dampening = dampening
self.use_gc = use_gc
self.eps = eps
@@ -111,7 +113,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

bias_correction2 = 1.0 - beta2 ** state['step']

if group['weight_decay'] > 0.0:
if not self.use_stable_weight_decay and group['weight_decay'] > 0.0:
if self.weight_decouple:
p.mul_(1.0 - group['lr'] * group['weight_decay'])
else:
@@ -137,8 +139,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
grad = p.grad
state = self.state[p]

if self.use_stable_weight_decay and group['weight_decay'] > 0.0:
if self.weight_decouple:
p.mul_(1.0 - group['lr'] * group['weight_decay'])
else:
grad.add_(p, alpha=group['weight_decay'])

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1_prod = state['beta1_prod']

bias_correction2 = 1.0 - beta2 ** state['step']

@@ -148,11 +155,13 @@
).clamp(0.0, 1.0 - group['eps'])
beta3 = (1.0 - beta1).pow(group['dampening'])

beta1_prod = state['beta1_prod']
beta1_prod.mul_(beta1)

bias_correction1 = 1.0 - beta1_prod

exp_avg.mul_(beta1).addcmul_(beta3, grad)
exp_avg_hat = exp_avg / bias_correction1 * beta0_dp
exp_avg_hat = exp_avg.div(bias_correction1).mul(beta0_dp)

p.add_(exp_avg_hat, alpha=-group['lr'])

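For reference, the new Adai flag is opt-in and defaults to False. A minimal sketch of enabling it, assuming `model` is an existing torch.nn.Module (the hyper-parameter values simply mirror the test cases added in tests/constants.py below):

    from pytorch_optimizer import Adai

    # use_stable_weight_decay=True routes the (optionally decoupled) decay through
    # the stable weight decay branch added in this diff; it is off by default.
    optimizer = Adai(
        model.parameters(),
        lr=1e-1,
        weight_decay=1e-4,
        weight_decouple=True,
        use_stable_weight_decay=True,
    )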
147 changes: 147 additions & 0 deletions pytorch_optimizer/optimizer/adams.py
@@ -0,0 +1,147 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdamS(Optimizer, BaseOptimizer):
r"""Adam with stable weight decay.
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of gradient and its square.
:param weight_decay: float. weight decay (L2 penalty).
:param amsgrad: bool. whether to use the AMSGrad variant of this algorithm from the paper.
:param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training.
:param eps: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
betas: BETAS = (0.9, 0.999),
weight_decay: float = 1e-4,
amsgrad: bool = False,
adamd_debias_term: bool = False,
eps: float = 1e-8,
):
self.lr = lr
self.betas = betas
self.weight_decay = weight_decay
self.amsgrad = amsgrad
self.adamd_debias_term = adamd_debias_term
self.eps = eps

self.validate_parameters()

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'weight_decay': weight_decay,
'eps': eps,
}
super().__init__(params, defaults)

def validate_parameters(self):
self.validate_learning_rate(self.lr)
self.validate_betas(self.betas)
self.validate_weight_decay(self.weight_decay)
self.validate_epsilon(self.eps)

@property
def __str__(self) -> str:
return 'AdamS'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
for p in group['params']:
state = self.state[p]

state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

param_size: int = 0
exp_avg_sq_hat_sum: float = 0.0

for group in self.param_groups:
beta1, beta2 = group['betas']
for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(self.__str__)

param_size += p.numel()

state = self.state[p]

if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
if self.amsgrad:
state['max_exp_avg_sq'] = torch.zeros_like(p)

state['step'] += 1
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

bias_correction2 = 1.0 - beta2 ** state['step']

exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

if self.amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
exp_avg_sq_hat = max_exp_avg_sq
else:
exp_avg_sq_hat = exp_avg_sq

exp_avg_sq_hat_sum += exp_avg_sq_hat.sum() / bias_correction2

if param_size == 0:
raise ZeroParameterSizeError()

exp_avg_sq_hat_mean = math.sqrt(exp_avg_sq_hat_sum / param_size) + self.eps

for group in self.param_groups:
beta1, beta2 = group['betas']
for p in group['params']:
if p.grad is None:
continue

state = self.state[p]

if group['weight_decay'] > 0.0:
p.mul_(1.0 - group['lr'] * group['weight_decay'] / exp_avg_sq_hat_mean)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']

exp_avg_sq_hat = state['max_exp_avg_sq'] if self.amsgrad else exp_avg_sq
exp_avg_sq_hat.div_(bias_correction2)

de_nom = exp_avg_sq_hat.sqrt().add(group['eps'])

step_size = group['lr'] if self.adamd_debias_term else group['lr'] / bias_correction1
p.addcdiv_(exp_avg, de_nom, value=-step_size)

return loss
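A minimal end-to-end sketch of the new optimizer, using a hypothetical toy model and random data purely for illustration (AdamS is importable from the package root per the __init__.py change above):

    import torch
    from pytorch_optimizer import AdamS

    # Toy regression model and batch, for illustration only.
    model = torch.nn.Sequential(torch.nn.Linear(10, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
    x, y = torch.randn(32, 10), torch.randn(32, 1)

    # Non-zero weight_decay is the interesting case: it is rescaled by the global
    # root-mean of the bias-corrected second moment (exp_avg_sq_hat_mean above).
    optimizer = AdamS(model.parameters(), lr=1e-3, weight_decay=1e-4, amsgrad=False)

    for _ in range(10):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()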
10 changes: 9 additions & 1 deletion tests/constants.py
@@ -9,6 +9,7 @@
AdaBound,
Adai,
AdamP,
AdamS,
Adan,
AdaPNM,
DAdaptAdaGrad,
@@ -55,6 +56,7 @@
'pnm',
'dadaptadam',
'dadaptsgd',
'adams',
]
VALID_OPTIMIZER_NAMES: List[str] = [
'adamp',
@@ -79,6 +81,7 @@
'dadaptadagrad',
'dadaptadam',
'dadaptsgd',
'adams',
]
INVALID_OPTIMIZER_NAMES: List[str] = [
'asam',
@@ -105,6 +108,7 @@
'adai',
'shampoo',
'dadaptadam',
'adams',
]

VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -135,6 +139,8 @@
(Adai, {'lr': 1e-1, 'weight_decay': 0.0, 'dampening': 0.9}, 150),
(Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False}, 150),
(Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True}, 150),
(Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False, 'use_stable_weight_decay': True}, 150),
(Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True, 'use_stable_weight_decay': True}, 150),
(AdamP, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
(AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 10),
(AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'nesterov': True}, 10),
@@ -188,12 +194,13 @@
(DAdaptAdam, {'lr': 1.0, 'weight_decay': 1e-2, 'weight_decouple': True}, 50),
(DAdaptSGD, {'lr': 1.0, 'weight_decay': 1e-2}, 50),
(DAdaptSGD, {'lr': 1.0, 'momentum': 0.9, 'weight_decay': 1e-3}, 50),
(AdamS, {'lr': 1.0, 'weight_decay': 1e-3}, 50),
(AdamS, {'lr': 1.0, 'weight_decay': 1e-3, 'amsgrad': True}, 50),
]
ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 50),
(AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),
(AdaBound, {'lr': 1e-2, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True, 'adamd_debias_term': True}, 100),
(AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),
(DiffGrad, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),
(DiffRGrad, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),
@@ -202,4 +209,5 @@
(Ranger, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),
(Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True, 'num_iterations': 200}, 200),
(AdaPNM, {'lr': 3e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 50),
(AdamS, {'lr': 2e1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 50),
]
2 changes: 1 addition & 1 deletion tests/test_load_optimizers.py
@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):


def test_get_supported_optimizers():
assert len(get_supported_optimizers()) == 22
assert len(get_supported_optimizers()) == 23
2 changes: 1 addition & 1 deletion tests/test_optimizers.py
@@ -227,7 +227,7 @@ def test_closure(optimizer):
optimizer = optimizer([param], num_iterations=1) if optimizer_name == 'Ranger21' else optimizer([param])
optimizer.zero_grad()

if optimizer_name in ('Ranger21', 'Adai'):
if optimizer_name in ('Ranger21', 'Adai', 'AdamS'):
with pytest.raises(ZeroParameterSizeError):
optimizer.step(closure=dummy_closure)
else: