Find file Copy path
67e46d1 Jan 2, 2019
2 contributors

Users who have contributed to this file

@sgugger @jph00
44 lines (38 sloc) 2.06 KB
"Supports 1-Cycle style training: a `LearnerCallback` that anneals learning rate and momentum over two annealing phases."
from ..core import *
from ..callback import *
from ..basic_train import Learner,LearnerCallback
# Public API of this module: only the scheduler callback is exported by `import *`.
__all__ = ['OneCycleScheduler']
class OneCycleScheduler(LearnerCallback):
    "Manage 1-Cycle style training as outlined in Leslie Smith's [paper](https://arxiv.org/pdf/1803.09820.pdf)."
    def __init__(self, learn:Learner, lr_max:float, moms:Floats=(0.95,0.85), div_factor:float=25., pct_start:float=0.3):
        """Store the 1-cycle hyper-parameters.

        `lr_max`: peak learning rate (a list/tuple is converted to an array, e.g. one value per layer group).
        `moms`: (start, middle) momentum pair; momentum goes moms[0] -> moms[1] in phase 1 and back in phase 2.
                A single value is used for both ends.
        `div_factor`: the run starts at `lr_max/div_factor`.
        `pct_start`: fraction of all iterations spent in phase 1 (the learning-rate warm-up).
        """
        # Registers the callback with `learn` and makes `self.learn` available to the hooks below.
        super().__init__(learn)
        self.lr_max,self.div_factor,self.pct_start = lr_max,div_factor,pct_start
        # Both phases index moms[0] and moms[1], so always keep a pair.
        self.moms = tuple(moms) if is_listy(moms) else (moms,moms)
        # An array lets `lr_max/div_factor` below work elementwise on per-group learning rates.
        if is_listy(self.lr_max): self.lr_max = np.array(self.lr_max)

    def steps(self, *steps_cfg:StartOptEnd):
        "Build anneal schedule for all of the parameters."
        # One `Stepper` per phase: the i-th (start, end) config is paired with the i-th (n_iter, func) phase.
        # Requires `self.phases`, which `on_train_begin` sets before calling this.
        return [Stepper(step, n_iter, func=func)
                for (step,(n_iter,func)) in zip(steps_cfg, self.phases)]

    def on_train_begin(self, n_epochs:int, **kwargs:Any)->None:
        "Initialize our optimization params based on our annealing schedule."
        # Total number of optimizer steps in the run.
        # NOTE(review): the `len(...)` argument was lost in the scraped source; restored as the
        # training dataloader (one batch per step) — confirm against the Learner's data object.
        n = len(self.learn.data.train_dl) * n_epochs
        a1 = int(n * self.pct_start)
        a2 = n-a1
        self.phases = ((a1, annealing_cos), (a2, annealing_cos))
        low_lr = self.lr_max/self.div_factor
        # Phase 1 raises lr from low_lr to lr_max; phase 2 brings it down to low_lr/1e4,
        # far below where the run started.
        self.lr_scheds = self.steps((low_lr, self.lr_max), (self.lr_max, low_lr/1e4))
        # Momentum runs moms[0] -> moms[1] during phase 1, then back to moms[0]
        # (with the defaults it falls while lr rises, and vice versa).
        self.mom_scheds = self.steps(self.moms, (self.moms[1], self.moms[0]))
        self.opt = self.learn.opt
        self.opt.lr,self.opt.mom = self.lr_scheds[0].start,self.mom_scheds[0].start
        # Index of the phase currently being stepped.
        self.idx_s = 0

    def on_batch_end(self, train, **kwargs:Any)->None:
        """Take one step forward on the annealing schedule for the optim params.

        Does nothing for non-training batches. Returns True once every phase has finished
        (presumably a stop signal for the training loop — confirm with the callback handler).
        """
        if not train: return
        if self.idx_s >= len(self.lr_scheds): return True
        self.opt.lr = self.lr_scheds[self.idx_s].step()
        self.opt.mom = self.mom_scheds[self.idx_s].step()
        # when the current schedule is complete we move onto the next
        # schedule. (in 1-cycle there are two schedules)
        if self.lr_scheds[self.idx_s].is_done:
            self.idx_s += 1