Add CyclicLR implementation
Slightly modified version of pull request #2016 from the pytorch repo: pytorch/pytorch#2016
The implementation follows the paper "Cyclical Learning Rates for Training Neural Networks": https://arxiv.org/abs/1506.01186
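For reference, the basic "triangular" policy described in the paper (and in the `bckenstler/CLR` reference implementation) can be sketched as below. This snippet is only an illustration of the formula and is not part of the commit; `step_size` here is the half-cycle length in iterations, so one full cycle spans 2 * step_size batches.

import numpy as np

def triangular_lr(iteration, base_lr, max_lr, step_size):
    # Cycle index (1-based); one full cycle spans 2 * step_size iterations.
    cycle = np.floor(1 + iteration / (2 * step_size))
    # x goes 1 -> 0 -> 1 over a cycle, so max(0, 1 - x) traces a triangle wave.
    x = np.abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * np.maximum(0, 1 - x)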
davidtvs committed Dec 14, 2018
1 parent 8b10698 commit 9dd7425
Showing 2 changed files with 165 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/__init__.py
@@ -4,3 +4,4 @@
from .predict import predict, predict_batch, predict_yield_batch
from .evaluate import evaluate
from .lr_finder import LRFinder
from .cyclic_lr import CyclicLR
164 changes: 164 additions & 0 deletions core/cyclic_lr.py
@@ -0,0 +1,164 @@
import numpy as np
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer


class CyclicLR(_LRScheduler):
"""Sets the learning rate of each parameter group according to
cyclical learning rate policy (CLR). The policy cycles the learning
rate between two boundaries with a constant frequency, as detailed in
the paper `Cyclical Learning Rates for Training Neural Networks`_.
The distance between the two boundaries can be scaled on a per-iteration
or per-cycle basis.
The cyclical learning rate policy changes the learning rate after every batch;
`step` should be called after each batch has been used for training.
To resume training, save `last_epoch` (the last batch index tracked by the
scheduler) and pass it as `last_batch_idx` when instantiating `CyclicLR`.
This class has three built-in policies, as put forth in the paper:
"triangular":
A basic triangular cycle w/ no amplitude scaling.
"triangular2":
A basic triangular cycle that scales initial amplitude by half each cycle.
"exp_range":
A cycle that scales initial amplitude by gamma**(cycle iterations) at each
cycle iteration.
This implementation is a slightly modified version of pytorch#2016 by thomasjpfan,
which in turn is adapted from the GitHub repo `bckenstler/CLR`_.
Args:
optimizer (Optimizer): Wrapped optimizer.
max_lr (float or list): Upper learning rate boundaries in the cycle
    for each parameter group. Functionally, it defines the cycle
    amplitude (max_lr - base_lr), where base_lr is the learning rate
    set by the wrapped optimizer. The lr at any point in the cycle is
    the sum of base_lr and some scaling of the amplitude; therefore
    max_lr may not actually be reached, depending on the scaling
    function. Default: 0.006
step_size_up (int): Number of training iterations in the
    increasing half of a cycle. Default: 2000
step_size_down (int): Number of training iterations in the
    decreasing half of a cycle. If step_size_down is None,
    it is set to step_size_up. Default: None
mode (str): One of {triangular, triangular2, exp_range}.
Values correspond to policies detailed above.
If scale_fn is not None, this argument is ignored.
Default: 'triangular'
gamma (float): Constant in 'exp_range' scaling function:
gamma**(cycle iterations)
Default: 1.0
scale_fn (function): Custom scaling policy defined by a single
    argument lambda function, where
    0 <= scale_fn(x) <= 1 for all x >= 0.
    If specified, the mode parameter is ignored.
    Default: None
scale_mode (str): {'cycle', 'iterations'}.
Defines whether scale_fn is evaluated on
cycle number or cycle iterations (training
iterations since start of cycle).
Default: 'cycle'
last_batch_idx (int): The index of the last batch. Default: -1
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = CyclicLR(optimizer)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>>     for batch in data_loader:
>>>         train_batch(...)
>>>         scheduler.step()
.. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
.. _bckenstler/CLR: https://github.com/bckenstler/CLR
"""

def __init__(
self,
optimizer,
max_lr=6e-3,
step_size_up=2000,
step_size_down=None,
mode="triangular",
gamma=1.0,
scale_fn=None,
scale_mode="cycle",
last_batch_idx=-1,
):

if not isinstance(optimizer, Optimizer):
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
self.optimizer = optimizer

self.max_lrs = self._format_lr("max_lr", optimizer, max_lr)

step_size_down = step_size_down or step_size_up
self.total_size = float(step_size_up + step_size_down)
self.step_ratio = float(step_size_up) / self.total_size

if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None:
raise ValueError("mode is invalid and scale_fn is None")

self.mode = mode
self.gamma = gamma

if scale_fn is None:
if self.mode == "triangular":
self.scale_fn = self._triangular_scale_fn
self.scale_mode = "cycle"
elif self.mode == "triangular2":
self.scale_fn = self._triangular2_scale_fn
self.scale_mode = "cycle"
elif self.mode == "exp_range":
self.scale_fn = self._exp_range_scale_fn
self.scale_mode = "iterations"
else:
self.scale_fn = scale_fn
self.scale_mode = scale_mode
super(CyclicLR, self).__init__(optimizer, last_batch_idx)

def _format_lr(self, name, optimizer, lr):
"""Return correctly formatted lr for each param group."""
if isinstance(lr, (list, tuple)):
if len(lr) != len(optimizer.param_groups):
raise ValueError(
"expected {} values for {}, got {}".format(
len(optimizer.param_groups), name, len(lr)
)
)
return np.array(lr)
else:
return lr * np.ones(len(optimizer.param_groups))

def _triangular_scale_fn(self, x):
return 1.0

def _triangular2_scale_fn(self, x):
return 1 / (2.0 ** (x - 1))

def _exp_range_scale_fn(self, x):
return self.gamma ** (x)

def get_lr(self):
"""Calculates the learning rate at batch index. This function treats
`self.last_epoch` as the last batch index.
"""
        # `cycle` is the 1-based index of the current cycle and `x` is the
        # position within it, in [0, 1).
        cycle = np.floor(1 + self.last_epoch / self.total_size)
        x = 1 + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            # Increasing half of the cycle: scale_factor rises from 0 to 1.
            scale_factor = x / self.step_ratio
        else:
            # Decreasing half of the cycle: scale_factor falls from 1 back to 0.
            scale_factor = (x - 1) / (self.step_ratio - 1)

lrs = []
for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
base_height = (max_lr - base_lr) * scale_factor
if self.scale_mode == "cycle":
lr = base_lr + base_height * self.scale_fn(cycle)
else:
lr = base_lr + base_height * self.scale_fn(self.last_epoch)

lrs.append(lr)

return lrs
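Below is a minimal usage sketch of the new scheduler within this repository; it is not part of the commit. It assumes the package is importable as `core` (which re-exports `CyclicLR` via the `core/__init__.py` change above); the toy model and random batches are placeholders for illustration only.

import torch
import torch.nn as nn

from core import CyclicLR  # re-exported by core/__init__.py in this commit

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# base_lr comes from the optimizer (1e-3); each cycle peaks at max_lr.
scheduler = CyclicLR(optimizer, max_lr=6e-3, step_size_up=200, mode="triangular2")

for epoch in range(2):
    for _ in range(100):
        inputs = torch.randn(16, 10)
        targets = torch.randint(0, 2, (16,))

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # CLR adjusts the learning rate after every batch, not every epoch.
        scheduler.step()

A custom policy can be supplied instead of the built-in modes, e.g. CyclicLR(optimizer, max_lr=6e-3, scale_fn=lambda x: 0.95 ** x, scale_mode="cycle"), since any function satisfying 0 <= scale_fn(x) <= 1 is accepted.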
