"""
---
title: Optimizers
summary: >
A set of PyTorch implementations/tutorials of popular gradient descent based optimizers.
Currently includes Adam, AMSGrad, Adam with warmup, Noam, RAdam, and AdaBelief optimizers.
---
# Optimizers
## Optimizer Implementations
* [Adam Optimizer](adam.html)
* [AMSGrad Optimizer](amsgrad.html)
* [Adam Optimizer with warmup](adam_warmup.html)
* [Noam Optimizer](noam.html)
* [Rectified Adam Optimizer](radam.html)
* [AdaBelief Optimizer](ada_belief.html)
This [MNIST example](mnist_experiment.html) uses these optimizers.
## Generic Adaptive Optimizer Base class and Weight Decay
This file defines a common base class for *Adam* and its extensions.
The base class helps us implement other optimizers with minimal code,
since the shared logic is written once and reused.
We also define a separate class for L2 weight decay, so that we don't
have to implement it inside each of the optimizers,
and can easily extend to other weight decays, like L1, without
changing the optimizers.
Here are some key concepts of PyTorch optimizers:
### Parameter groups
PyTorch optimizers group parameters into sets called groups.
Each group can have its own hyper-parameters, like learning rates.
In the most common case there is only one group.
This is the case when you initialize your optimizer with:
```python
Optimizer(model.parameters())
```
You can define multiple parameter groups when initializing the optimizer:
```python
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])
```
Here we pass a list of groups. Each group is a dictionary with its parameters under the key `'params'`.
You can specify any group-level hyper-parameters as well; hyper-parameters that are not defined
fall back to the optimizer-level defaults.
You can access (and even change) these groups and their hyper-parameters with `optimizer.param_groups`.
Most learning rate schedule implementations I've come across access this and change `'lr'`,
as in the sketch below.
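As a hedged illustration (the decay factor is arbitrary, and `optimizer` is assumed to be any PyTorch optimizer), a minimal learning rate schedule can do exactly that:
```python
# Multiply the learning rate of every parameter group by 0.9
for group in optimizer.param_groups:
    group['lr'] *= 0.9
```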
### States
The optimizer maintains a state (a dictionary) for each parameter (a tensor), in a dictionary `optimizer.state`.
This is where the optimizer keeps things like the exponential moving averages of gradients.
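For example (a hedged sketch; `'exp_avg'` is the key `torch.optim.Adam` uses for its first-moment estimate, and `model`/`optimizer` are assumed to exist), you can inspect the per-parameter state like this:
```python
# Print the norm of the exponential average of gradients for each parameter
for param in model.parameters():
    state = optimizer.state[param]
    if 'exp_avg' in state:
        print(param.shape, state['exp_avg'].norm())
```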
"""
from typing import Dict, Tuple, Any
import torch
from torch import nn
from torch.optim.optimizer import Optimizer
class GenericAdaptiveOptimizer(Optimizer):
"""
## Base class for *Adam* and extensions
"""
def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):
"""
### Initialize
* `params` is the collection of parameters or set of parameter groups.
* `defaults` is a dictionary of default hyper-parameters
* `lr` is the learning rate, $\alpha$
* `betas` is the tuple $(\beta_1, \beta_2)$
* `eps` is $\epsilon$
"""
# Check the hyper-parameters
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
# Add the hyper-parameters to the defaults
defaults.update(dict(lr=lr, betas=betas, eps=eps))
# Initialize the PyTorch optimizer.
# This will create parameter groups with the default hyper-parameters
super().__init__(params, defaults)
def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
"""
### Initialize state for a given parameter tensor
This should be overridden with code to initialize `state` for the parameter `param`.
`group` is the parameter group dictionary to which `param` belongs.
"""
pass
def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
"""
### Take optimizer step on a parameter tensor
This should be overridden to take the optimization step on the parameter tensor `param`, $\theta$,
where `grad` is the gradient for that parameter, $g_t$,
`state` is the optimizer state dictionary for that parameter, and
`group` is the parameter group dictionary `param` belongs to.
"""
pass
@torch.no_grad()
def step(self, closure=None):
"""
### Optimizer step
We have created a template method that does the common stuff every *Adam* based optimizer needs.
"""
# Calculate loss.
#
# 🤔 I'm not sure when you need this. I guess if you define a function that
# calculates the loss, calls `loss.backward`, and returns the loss, you can pass it
# to `optimizer.step` instead of calling it yourself. 🤷‍♂️
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
# Iterate through the parameter groups
for group in self.param_groups:
# Iterate through the parameters in the parameter group
for param in group['params']:
# Skip if the parameter has no gradient
if param.grad is None:
continue
# Get the gradient tensor
grad = param.grad.data
# We don't handle sparse gradients
if grad.is_sparse:
raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
' please consider SparseAdam instead')
# Get the state for the parameter
state = self.state[param]
# Initialize the state if state is uninitialized
if len(state) == 0:
self.init_state(state, group, param)
# Take the optimization step on the parameter
self.step_param(state, group, grad, param)
# Return the loss, calculated from closure
return loss
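# As a quick, hedged illustration of the template above (`ExampleMomentum` is a
# made-up sketch for this note, not part of the library), a subclass only needs to
# fill in `init_state` and `step_param`; `step` handles looping over parameter
# groups, skipping parameters without gradients, and creating fresh state.
class ExampleMomentum(GenericAdaptiveOptimizer):
    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8):
        super().__init__(params, {}, lr, betas, eps)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        beta1, _ = group['betas']
        # $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
        # $\theta_t = \theta_{t-1} - \alpha m_t$
        param.data.add_(state['exp_avg'], alpha=-group['lr'])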
class WeightDecay:
"""
## L2 Weight decay
"""
def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
"""
### Initialize weight decay
* `weight_decay` is the decay coefficient
* `weight_decouple` is a flag indicating whether to perform the decay directly on the parameter (when `True`)
or to add the weight decay to the gradient (when `False`), in which case it goes through the normal optimizer update.
* `absolute` indicates whether the weight decay coefficient is absolute. This only applies
when the decay is performed directly on the parameter. If this is false, the actual decay is
`weight_decay` * `learning_rate`.
"""
# Check hyper-parameters
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
self.absolute = absolute
self.weight_decouple = weight_decouple
self.weight_decay = weight_decay
def defaults(self):
"""
Return defaults for parameter groups
"""
return dict(weight_decay=self.weight_decay)
def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):
"""
### Perform weight decay and return the gradient
"""
# If we are doing the decay on the parameter directly
if self.weight_decouple:
# If the weight decay coefficient is absolute
if self.absolute:
param.data.mul_(1.0 - group['weight_decay'])
# Otherwise,
else:
param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
# Return the unmodified gradient
return grad
else:
if group['weight_decay'] != 0:
# Add the weight decay to the gradient and return the modified gradient
return grad.add(param.data, alpha=group['weight_decay'])
else:
return grad
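# A small, hedged sanity check of `WeightDecay` (the function name and the toy
# values are assumptions for illustration, not part of the library): it shows
# the two modes of `__call__` side by side.
def _weight_decay_demo():
    # A toy parameter, gradient, and parameter-group dictionary
    param = nn.Parameter(torch.ones(3))
    grad = torch.full((3,), 0.5)
    group = {'lr': 0.1, 'weight_decay': 0.01}

    # Decoupled decay: the parameter is shrunk in place by `lr * weight_decay`
    # (or by `weight_decay` alone if `absolute=True`) and the gradient is
    # returned unmodified.
    decoupled = WeightDecay(weight_decay=0.01, weight_decouple=True)
    grad_decoupled = decoupled(param, grad, group)

    # L2-regularization style: `weight_decay * param` is added to the gradient,
    # and the parameter itself is left to the normal optimizer update.
    l2 = WeightDecay(weight_decay=0.01, weight_decouple=False)
    grad_l2 = l2(param, grad, group)

    return grad_decoupled, grad_l2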