feat: Added implementation of TAdam (#52)
* feat: Added TAdam optimizer implementation

* test: Added unittest for TAdam

* docs: Added TAdam to documentation

* docs: Updated README

* style: Fixed lint
frgfm committed Jul 1, 2020
1 parent 1171313 commit a901973
Showing 5 changed files with 103 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -78,7 +78,7 @@ conda install -c frgfm pylocron

##### Main features

- Optimizer: [LARS](https://arxiv.org/abs/1708.03888), [Lamb](https://arxiv.org/abs/1904.00962), [RAdam](https://arxiv.org/abs/1908.03265) and customized versions (RaLars)
- Optimizer: [LARS](https://arxiv.org/abs/1708.03888), [Lamb](https://arxiv.org/abs/1904.00962), [RAdam](https://arxiv.org/abs/1908.03265), [TAdam](https://arxiv.org/pdf/2003.00179.pdf) and customized versions (RaLars)
- Optimizer wrapper: [Lookahead](https://arxiv.org/abs/1907.08610), Scout (experimental)
- Scheduler: [OneCycleScheduler](https://arxiv.org/abs/1803.09820)
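
A minimal training-step sketch with the new optimizer (assuming a standard PyTorch setup; `model`, `criterion`, `inputs` and `targets` below are placeholders):

```python
import torch
from holocron.optim import TAdam

# Placeholder model and loss, purely for illustration
model = torch.nn.Linear(10, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = TAdam(model.parameters(), lr=1e-3, weight_decay=0, dof=None)

inputs = torch.randn(4, 10)
targets = torch.randint(0, 2, (4,))

# Standard training step: TAdam replaces Adam's running mean of the gradient
# with a robust Student-t estimate that down-weights outlier gradients
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()
```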

2 changes: 2 additions & 0 deletions docs/source/optim.rst
@@ -21,6 +21,8 @@ Implementations of recent parameter optimizer for Pytorch modules.

.. autoclass:: RaLars

.. autoclass:: TAdam


Optimizer wrappers
------------------
1 change: 1 addition & 0 deletions holocron/optim/__init__.py
@@ -2,6 +2,7 @@
from .lamb import Lamb
from .radam import RAdam
from .ralars import RaLars
from .tadam import TAdam
from . import wrapper
from . import lr_scheduler

98 changes: 98 additions & 0 deletions holocron/optim/tadam.py
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-

'''
Extended version of the Adam optimizer with Student-t mean estimation
'''

import torch
from torch.optim.optimizer import Optimizer


class TAdam(Optimizer):
"""Implements the TAdam optimizer from `"TAdam: A Robust Stochastic Gradient Optimizer"
<https://arxiv.org/pdf/2003.00179.pdf>`_.
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): learning rate
betas (Tuple[float, float], optional): coefficients used for running averages (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dof (int, optional): degrees of freedom
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, dof=None):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, dof=dof)
super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            # Get group-shared variables
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('TAdam does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    # Weight accumulator of the Student-t mean estimation
                    state['W_t'] = beta1 / (1 - beta1)
                    state['d'] = p.data.numel()
                    state['dof'] = state['d'] if group['dof'] is None else group['dof']

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1

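                # Student-t weighting of the incoming gradient:
                # w_t = (dof + d) / (dof + sum((grad - exp_avg) ** 2 / (exp_avg_sq + eps)))
                # so gradients far from the current mean estimate (likely outliers) get a smaller weight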
                wt = grad.sub(exp_avg).pow_(2).div_(exp_avg_sq.add(group['eps'])).sum()
                wt.add_(state['dof']).pow_(-1).mul_(state['dof'] + state['d'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(state['W_t'] / (state['W_t'] + wt)).add_(grad, alpha=wt / (state['W_t'] + wt))
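                # Update the weight accumulator of the Student-t mean estimate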
                state['W_t'] *= (2 * beta1 - 1) / beta1
                state['W_t'] += wt
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias corrections
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Weight decay (applied directly to the parameters)
                if group['weight_decay'] != 0:
                    p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])

                # Adaptive momentum
                p.data.addcdiv_(exp_avg / bias_correction1,
                                (exp_avg_sq / bias_correction2).sqrt().add_(group['eps']),
                                value=-group['lr'])

        return loss
2 changes: 1 addition & 1 deletion test/test_optim.py
@@ -115,7 +115,7 @@ def _test_wrapper(self, name):
self.assertFalse(torch.equal(_p.data, p_val - lr * _p.grad))


for opt_name in ['Lars', 'Lamb', 'RAdam', 'RaLars']:
for opt_name in ['Lars', 'Lamb', 'RAdam', 'RaLars', 'TAdam']:
    def opt_test(self, opt_name=opt_name):
        self._test_optimizer(opt_name)

