feat: Added implementation of TAdam #52

Merged (6 commits, Jul 1, 2020)
README.md: 1 addition, 1 deletion
@@ -78,7 +78,7 @@ conda install -c frgfm pylocron

##### Main features

- Optimizer: [LARS](https://arxiv.org/abs/1708.03888), [Lamb](https://arxiv.org/abs/1904.00962), [RAdam](https://arxiv.org/abs/1908.03265) and customized versions (RaLars)
- Optimizer: [LARS](https://arxiv.org/abs/1708.03888), [Lamb](https://arxiv.org/abs/1904.00962), [RAdam](https://arxiv.org/abs/1908.03265), [TAdam](https://arxiv.org/pdf/2003.00179.pdf) and customized versions (RaLars)
- Optimizer wrapper: [Lookahead](https://arxiv.org/abs/1907.08610), Scout (experimental)
- Scheduler: [OneCycleScheduler](https://arxiv.org/abs/1803.09820)
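The optimizers listed above are exposed at the top level of `holocron.optim` (the `__init__.py` change further down is what makes `TAdam` importable the same way). A minimal import sketch, limited to the class names visible in this diff; nothing beyond those names is taken from the repository:

```python
# Sketch only: these classes are re-exported by holocron/optim/__init__.py,
# as shown in the diff of that file below.
from holocron.optim import Lamb, RAdam, RaLars, TAdam
```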

docs/source/optim.rst: 2 additions, 0 deletions
@@ -21,6 +21,8 @@ Implementations of recent parameter optimizer for Pytorch modules.

.. autoclass:: RaLars

.. autoclass:: TAdam


Optimizer wrappers
------------------
holocron/optim/__init__.py: 1 addition, 0 deletions
@@ -2,6 +2,7 @@
from .lamb import Lamb
from .radam import RAdam
from .ralars import RaLars
from .tadam import TAdam
from . import wrapper
from . import lr_scheduler

holocron/optim/tadam.py: 98 additions, 0 deletions (new file)
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-

'''
Extended version of Adam optimizer with Student-t mean estimation
'''

import torch
from torch.optim.optimizer import Optimizer


class TAdam(Optimizer):
    """Implements the TAdam optimizer from `"TAdam: A Robust Stochastic Gradient Optimizer"
    <https://arxiv.org/pdf/2003.00179.pdf>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for running averages (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dof (int, optional): degrees of freedom (default: None, which falls back to the number of
            elements of each parameter tensor)
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, dof=None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, dof=dof)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            # Get group-shared variables
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('TAdam does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    # Accumulated sample weight of the robust first-moment estimate
                    state['W_t'] = beta1 / (1 - beta1)
                    # Dimension of the parameter tensor
                    state['d'] = p.data.numel()
                    # Degrees of freedom of the Student-t distribution (defaults to the dimension)
                    state['dof'] = state['d'] if group['dof'] is None else group['dof']

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1

                # Student-t weight of the current gradient sample:
                # w_t = (dof + d) / (dof + sum((grad - exp_avg) ** 2 / (exp_avg_sq + eps)))
                wt = grad.sub(exp_avg).pow_(2).div_(exp_avg_sq.add(group['eps'])).sum()
                wt.add_(state['dof']).pow_(-1).mul_(state['dof'] + state['d'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(state['W_t'] / (state['W_t'] + wt)).add_(grad, alpha=wt / (state['W_t'] + wt))
                state['W_t'] *= (2 * beta1 - 1) / beta1
                state['W_t'] += wt
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias corrections
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Weight decay
                if group['weight_decay'] != 0:
                    p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])

                # Adaptive momentum
                p.data.addcdiv_(exp_avg / bias_correction1,
                                (exp_avg_sq / bias_correction2).sqrt().add_(group['eps']), value=-group['lr'])

        return loss
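For reference, a minimal usage sketch of the new optimizer follows; the toy model, data, and hyperparameter values are illustrative placeholders and not part of this PR:

```python
import torch
from torch import nn

from holocron.optim import TAdam

# Placeholder model and data for a toy regression problem
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
x, y = torch.randn(32, 10), torch.randn(32, 1)

# dof=None falls back to the dimensionality of each parameter tensor,
# matching the state initialization in tadam.py above
optimizer = TAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, dof=None)

for _ in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```

Apart from the extra `dof` argument, the constructor signature closely mirrors `torch.optim.Adam`, so it can be swapped in where a more outlier-robust first-moment estimate is wanted.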
test/test_optim.py: 1 addition, 1 deletion
@@ -115,7 +115,7 @@ def _test_wrapper(self, name):
self.assertFalse(torch.equal(_p.data, p_val - lr * _p.grad))


for opt_name in ['Lars', 'Lamb', 'RAdam', 'RaLars']:
for opt_name in ['Lars', 'Lamb', 'RAdam', 'RaLars', 'TAdam']:
def opt_test(self, opt_name=opt_name):
self._test_optimizer(opt_name)
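The loop above generates one test per optimizer name, now including the new `TAdam` entry; the line that attaches each generated function to the test case presumably sits below the visible hunk. A hypothetical sketch of that pattern (the class name `Tester` and the method naming are assumptions, not taken from the file):

```python
import unittest


class Tester(unittest.TestCase):  # stand-in for the actual test class in test/test_optim.py
    def _test_optimizer(self, name):
        ...  # the real check builds the optimizer by name and runs a training step


# Dynamically register one test method per optimizer, mirroring the loop in the diff
for opt_name in ['Lars', 'Lamb', 'RAdam', 'RaLars', 'TAdam']:
    def opt_test(self, opt_name=opt_name):
        self._test_optimizer(opt_name)

    setattr(Tester, 'test_' + opt_name.lower(), opt_test)
```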
