This repository has been archived by the owner on Aug 18, 2020. It is now read-only.

Commit
Rework on optimizer and native mixed precision
sgugger committed Apr 4, 2020
1 parent 90bac7a commit 60d63c3
Showing 5 changed files with 352 additions and 134 deletions.
4 changes: 4 additions & 0 deletions fastai2/_nbdev.py
@@ -557,6 +557,10 @@
"MixedPrecision": "18_callback.fp16.ipynb",
"Learner.to_fp16": "18_callback.fp16.ipynb",
"Learner.to_fp32": "18_callback.fp16.ipynb",
"mixed_precision_one_batch": "18_callback.fp16.ipynb",
"NativeMixedPrecision": "18_callback.fp16.ipynb",
"Learner.to_native_fp16": "18_callback.fp16.ipynb",
"Learner.to_native_fp32": "18_callback.fp16.ipynb",
"ShortEpochCallback": "18a_callback.training.ipynb",
"GradientAccumulation": "18a_callback.training.ipynb",
"set_bn_eval": "18a_callback.training.ipynb",
56 changes: 50 additions & 6 deletions fastai2/callback/fp16.py
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/18_callback.fp16.ipynb (unless otherwise specified).

__all__ = ['get_master', 'to_master_grads', 'to_model_params', 'test_overflow', 'grad_overflow', 'copy_clone',
'ModelToHalf', 'MixedPrecision']
'ModelToHalf', 'MixedPrecision', 'mixed_precision_one_batch', 'NativeMixedPrecision']

# Cell
from ..basics import *
@@ -14,7 +14,7 @@
from torch.nn.utils import parameters_to_vector

def get_master(opt, flat_master=False):
model_params = [[param for param in pg if param.requires_grad] for pg in opt.param_groups]
model_params = [[param for param in pg if param.requires_grad] for pg in opt.param_lists]
if flat_master:
master_params = []
for pg in model_params:
@@ -32,7 +32,7 @@ def to_master_grads(model_pgs, master_pgs, flat_master=False):
model_grads_to_master_grads(model_params, master_params, flat_master=flat_master)

# Cell
def to_model_params(model_pgs, master_pgs, flat_master:bool=False)->None:
def to_model_params(model_pgs, master_pgs, flat_master=False)->None:
for (model_params,master_params) in zip(model_pgs,master_pgs):
master_params_to_model_params(model_params, master_params, flat_master=flat_master)

@@ -54,7 +54,7 @@ def copy_clone(d):

# Cell
def _copy_state(opt, pgs1, pgs2):
opt.param_groups = pgs2
opt.param_lists = pgs2
for pg1,pg2 in zip(pgs1, pgs2):
for p1,p2 in zip(pg1, pg2):
opt.state[p2] = copy_clone(opt.state[p1])
@@ -84,7 +84,7 @@ def begin_fit(self):
assert self.dls.device.type == 'cuda', "Mixed-precision training requires a GPU, remove the call `to_fp16`"
if self.learn.opt is None: self.learn.create_opt()
self.model_pgs,self.master_pgs = get_master(self.opt, self.flat_master)
self.old_pgs = self.opt.param_groups
self.old_pgs = self.opt.param_lists
#Changes the optimizer so that the optimization step is done in FP32.
_copy_state(self.learn.opt, self.model_pgs, self.master_pgs)
if self.dynamic: self.count = 0
@@ -121,7 +121,7 @@ def after_step(self):

def after_fit(self):
_copy_state(self.learn.opt, self.master_pgs, self.model_pgs)
self.learn.opt.param_groups = self.old_pgs
self.learn.opt.param_lists = self.old_pgs
delattr(self, "master_pgs")
delattr(self, "model_pgs")
delattr(self, "old_pgs")
@@ -146,4 +146,48 @@ def to_fp16(self:Learner, **kwargs):
@patch
def to_fp32(self: Learner):
self.remove_cbs([ModelToHalf, MixedPrecision])
return self

# Cell
def mixed_precision_one_batch(self, i, b):

[Inline comment by @mcarilli, Apr 4, 2020]: Great to see this, thanks! Detailed first impression here: #241 (comment)

from torch.cuda.amp import autocast
self.iter = i
try:
self._split(b); self('begin_batch')
with autocast():
self.pred = self.model(*self.xb); self('after_pred')
if len(self.yb) == 0: return
self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
if not self.training: return
self.scaler.scale(self.loss).backward(); self('after_backward')
self.scaler.step(self.opt); self('after_step')
self.opt.zero_grad()
except CancelBatchException: self('after_cancel_batch')
finally: self('after_batch')

# Cell
class NativeMixedPrecision(Callback):
def __init__(self):
try: from torch.cuda.amp import GradScaler, autocast
except: raise Exception("NativeMixedPrecision requires PyTorch nightlies")
def begin_fit(self):
from torch.cuda.amp import GradScaler
self.old_one_batch = self.learn.one_batch
self.learn.one_batch = partial(mixed_precision_one_batch, self.learn)
self.learn.scaler = GradScaler()

def after_step(self): self.learn.scaler.update()
def after_fit(self):
if getattr(self, 'old_one_batch', None) is not None: self.learn.one_batch = self.old_one_batch

# Cell
@patch
def to_native_fp16(self:Learner):
self.add_cb(NativeMixedPrecision())
return self

# Cell
@patch
def to_native_fp32(self:Learner):
self.remove_cb(NativeMixedPrecision)
return self
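
For orientation, a minimal usage sketch of the native mixed-precision API added above (the dataset, `dls`, and model names are illustrative assumptions, not part of this commit):

from fastai2.vision.all import *

dls = ImageDataLoaders.from_folder(untar_data(URLs.MNIST_SAMPLE))
learn = cnn_learner(dls, resnet18, metrics=accuracy).to_native_fp16()
learn.fit_one_cycle(1)   # forward pass and loss run under autocast; GradScaler scales the backward pass
learn.to_native_fp32()   # removes NativeMixedPrecision and returns to full precision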
90 changes: 50 additions & 40 deletions fastai2/optimizer.py
@@ -12,23 +12,23 @@
class _BaseOptimizer():
"Common functionality between `Optimizer` and `OptimWrapper`"
def all_params(self, n=slice(None), with_grad=False):
res = L((p,pg,self.state[p],hyper) for pg,hyper in zip(self.param_groups[n],self.hypers[n]) for p in pg)
res = L((p,pg,self.state[p],hyper) for pg,hyper in zip(self.param_lists[n],self.hypers[n]) for p in pg)
return L(o for o in res if o[0].grad is not None) if with_grad else res

def _set_require_grad(self, rg, p,pg,state,h): p.requires_grad_(rg or state.get('force_train', False))
def freeze_to(self, n):
self.frozen_idx = n if n >= 0 else len(self.param_groups) + n
if self.frozen_idx >= len(self.param_groups):
warn(f"Freezing {self.frozen_idx} groups; model has {len(self.param_groups)}; whole model is frozen.")
self.frozen_idx = n if n >= 0 else len(self.param_lists) + n
if self.frozen_idx >= len(self.param_lists):
warn(f"Freezing {self.frozen_idx} groups; model has {len(self.param_lists)}; whole model is frozen.")
for o in self.all_params(slice(n, None)): self._set_require_grad(True, *o)
for o in self.all_params(slice(None, n)): self._set_require_grad(False, *o)

def freeze(self):
assert(len(self.param_groups)>1)
assert(len(self.param_lists)>1)
self.freeze_to(-1)

def set_freeze(self, n, rg, ignore_force_train=False):
for p in self.param_groups[n]: p.requires_grad_(rg or (state.get('force_train', False) and not ignore_force_train))
for p in self.param_lists[n]: p.requires_grad_(rg or (state.get('force_train', False) and not ignore_force_train))

def unfreeze(self): self.freeze_to(0)
def set_hypers(self, **kwargs): L(kwargs.items()).starmap(self.set_hyper)
@@ -37,13 +37,22 @@ def _set_hyper(self, k, v):

def set_hyper(self, k, v):
if isinstance(v, slice):
if v.start: v = even_mults(v.start, v.stop, len(self.param_groups))
else: v = [v.stop/10]*(len(self.param_groups)-1) + [v.stop]
if v.start: v = even_mults(v.start, v.stop, len(self.param_lists))
else: v = [v.stop/10]*(len(self.param_lists)-1) + [v.stop]
v = L(v, use_list=None)
if len(v)==1: v = v*len(self.param_groups)
assert len(v) == len(self.hypers), f"Trying to set {len(v)} values for {k} but there are {len(self.param_groups)} parameter groups."
if len(v)==1: v = v*len(self.param_lists)
assert len(v) == len(self.hypers), f"Trying to set {len(v)} values for {k} but there are {len(self.param_lists)} parameter groups."
self._set_hyper(k, v)

@property
def param_groups(self): return [{**{'params': pg}, **hp} for pg,hp in zip(self.param_lists, self.hypers)]
@param_groups.setter
def param_groups(self, v):
for pg,v_ in zip(self.param_lists,v): pg = v_['params']
for hyper,v_ in zip(self.hypers,v):
for k,t in v_.items():
if k != 'params': hyper[k] = t
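
A small sketch of what the new property pair exposes (assuming a plain fastai2 SGD optimizer; the tensor shapes are arbitrary and not part of the commit):

import torch
from fastai2.optimizer import SGD

params = [torch.nn.Parameter(torch.randn(4))]
opt = SGD(params, lr=0.01, mom=0.9)
print(opt.param_lists)             # raw parameter groups, as stored by the optimizer
print(opt.param_groups[0].keys())  # PyTorch-style view, e.g. dict_keys(['params', 'lr', 'mom', 'wd'])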

# Cell
def _update(state, new=None):
if new is None: return state
@@ -58,8 +67,8 @@ def __init__(self, params, cbs, train_bn=True, **defaults):
params = L(params)
self.cbs,self.state,self.train_bn = L(cbs),defaultdict(dict),train_bn
defaults = merge(*self.cbs.attrgot('defaults'), defaults)
self.param_groups = L(L(p) for p in params) if isinstance(params[0], (L,list)) else L([params])
self.hypers = L({} for _ in range_of(self.param_groups))
self.param_lists = L(L(p) for p in params) if isinstance(params[0], (L,list)) else L([params])
self.hypers = L({} for _ in range_of(self.param_lists))
self.set_hypers(**defaults)
self.frozen_idx = 0

@@ -82,14 +91,14 @@ def state_dict(self):
return {'state': state, 'hypers': self.hypers}

def load_state_dict(self, sd):
assert len(sd["hypers"]) == len(self.param_groups)
assert len(sd["state"]) == sum([len(pg) for pg in self.param_groups])
assert len(sd["hypers"]) == len(self.param_lists)
assert len(sd["state"]) == sum([len(pg) for pg in self.param_lists])
self.hypers = sd['hypers']
self.state = {p: s for p,s in zip(self.all_params().itemgot(0), sd['state'])}

# Cell
def sgd_step(p, lr, **kwargs):
p.data.add_(-lr, p.grad.data)
p.data.add_(p.grad.data, alpha=-lr)
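
The recurring change in the rest of this file swaps PyTorch's deprecated positional-scalar form of the in-place ops for the keyword form; a standalone illustration (not part of the commit):

import torch

p, g, lr = torch.ones(3), torch.ones(3), 0.1
# old, deprecated form removed by this commit: p.add_(-lr, g)
p.add_(g, alpha=-lr)                       # tensor first, scalar passed as `alpha`
p.addcdiv_(g, torch.ones(3), value=-lr)    # same keyword pattern for addcdiv_/addcmul_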

# Cell
def weight_decay(p, lr, wd, do_wd=True, **kwargs):
@@ -101,7 +110,7 @@ def weight_decay(p, lr, wd, do_wd=True, **kwargs):
# Cell
def l2_reg(p, lr, wd, do_wd=True, **kwargs):
"L2 regularization as adding `wd*p` to `p.grad`"
if do_wd and wd!=0: p.grad.data.add_(wd, p.data)
if do_wd and wd!=0: p.grad.data.add_(p.data, alpha=wd)

l2_reg.defaults = dict(wd=0.)

@@ -110,7 +119,7 @@ def average_grad(p, mom, dampening=False, grad_avg=None, **kwargs):
"Keeps track of the avg grads of `p` in `state` with `mom`."
if grad_avg is None: grad_avg = torch.zeros_like(p.grad.data)
damp = 1-mom if dampening else 1.
grad_avg.mul_(mom).add_(damp, p.grad.data)
grad_avg.mul_(mom).add_(p.grad.data, alpha=damp)
return {'grad_avg': grad_avg}

average_grad.defaults = dict(mom=0.9)
@@ -119,15 +128,15 @@ def average_grad(p, mom, dampening=False, grad_avg=None, **kwargs):
def average_sqr_grad(p, sqr_mom, dampening=True, sqr_avg=None, **kwargs):
if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)
damp = 1-sqr_mom if dampening else 1.
sqr_avg.mul_(sqr_mom).addcmul_(damp, p.grad.data, p.grad.data)
sqr_avg.mul_(sqr_mom).addcmul_(p.grad.data, p.grad.data, value=damp)
return {'sqr_avg': sqr_avg}

average_sqr_grad.defaults = dict(sqr_mom=0.99)

# Cell
def momentum_step(p, lr, grad_avg, **kwargs):
"Step for SGD with momentum with `lr`"
p.data.add_(-lr, grad_avg)
p.data.add_(grad_avg, alpha=-lr)

# Cell
def SGD(params, lr, mom=0., wd=0., decouple_wd=True):
@@ -141,7 +150,7 @@ def SGD(params, lr, mom=0., wd=0., decouple_wd=True):
def rms_prop_step(p, lr, sqr_avg, eps, grad_avg=None, **kwargs):
"Step for SGD with momentum with `lr`"
denom = sqr_avg.sqrt().add_(eps)
p.data.addcdiv_(-lr, (grad_avg if grad_avg is not None else p.grad), denom)
p.data.addcdiv_((grad_avg if grad_avg is not None else p.grad), denom, value=-lr)

rms_prop_step.defaults = dict(eps=1e-8)

@@ -167,7 +176,7 @@ def adam_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):
"Step for Adam with `lr` on `p`"
debias1 = debias(mom, 1-mom, step)
debias2 = debias(sqr_mom, 1-sqr_mom, step)
p.data.addcdiv_(-lr / debias1, grad_avg, (sqr_avg/debias2).sqrt() + eps)
p.data.addcdiv_(grad_avg, (sqr_avg/debias2).sqrt() + eps, value = -lr / debias1)
return p

adam_step._defaults = dict(eps=1e-5)
@@ -191,8 +200,8 @@ def radam_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, beta, **kwargs
denom = (sqr_avg/debias2).sqrt()
if eps: denom += eps
if beta: denom = F.softplus(denom, beta)
p.data.addcdiv_(-lr*v / debias1, grad_avg, denom)
else: p.data.add_(-lr / debias1, grad_avg)
p.data.addcdiv_(grad_avg, denom, value = -lr*v / debias1)
else: p.data.add_(grad_avg, alpha=-lr / debias1)
return p

radam_step._defaults = dict(eps=1e-5)
@@ -208,8 +217,9 @@ def RAdam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., decouple_
def qhadam_step(p, lr, mom, sqr_mom, sqr_avg, nu_1, nu_2, step, grad_avg, eps, **kwargs):
debias1 = debias(mom, 1-mom, step)
debias2 = debias(sqr_mom, 1-sqr_mom, step)
p.data.addcdiv_(-lr, ((1-nu_1) * p.grad.data) + (nu_1 * (grad_avg / debias1)),
(((1 - nu_2) * (p.grad.data)**2) + (nu_2 * (sqr_avg / debias2))).sqrt() + eps)
p.data.addcdiv_(((1-nu_1) * p.grad.data) + (nu_1 * (grad_avg / debias1)),
(((1 - nu_2) * (p.grad.data)**2) + (nu_2 * (sqr_avg / debias2))).sqrt() + eps,
value = -lr)
return p

qhadam_step._defaults = dict(eps=1e-8)
@@ -234,7 +244,7 @@ def larc_layer_lr(p, lr, trust_coeff, wd, eps, clip=True, **kwargs):
# Cell
def larc_step(p, local_lr, grad_avg=None, **kwargs):
"Step for LARC `local_lr` on `p`"
p.data.add_(-local_lr, p.grad.data if grad_avg is None else grad_avg)
p.data.add_(p.grad.data if grad_avg is None else grad_avg, alpha = -local_lr)

# Cell
def Larc(params, lr, mom=0.9, clip=True, trust_coeff=0.02, eps=1e-8, wd=0., decouple_wd=True):
Expand All @@ -253,7 +263,7 @@ def lamb_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):
step = (grad_avg/debias1) / ((sqr_avg/debias2).sqrt()+eps)
r2 = step.pow(2).mean().sqrt()
q = 1 if r1 == 0 or r2 == 0 else min(r1/r2,10)
p.data.add_(-lr * q, step)
p.data.add_(step, alpha = -lr * q)

lamb_step._defaults = dict(eps=1e-6, wd=0.)

Expand All @@ -277,9 +287,9 @@ def step(self):
self.opt.step()
self.count += 1
if self.count%self.k != 0: return
for slow_pg,fast_pg in zip(self.slow_weights,self.param_groups):
for slow_pg,fast_pg in zip(self.slow_weights,self.param_lists):
for slow_p,fast_p in zip(slow_pg,fast_pg):
slow_p.data.add_(self.alpha, fast_p.data-slow_p.data)
slow_p.data.add_(fast_p.data-slow_p.data, alpha=self.alpha)
fast_p.data.copy_(slow_p.data)

def clear_state(self):
@@ -297,12 +307,12 @@ def load_state_dict(self, sd):
self.opt.load_state_dict(sd)

def _init_state(self): self.count,self.slow_weights = 0,None
def _copy_weights(self): self.slow_weights = L(L(p.clone().detach() for p in pg) for pg in self.param_groups)
def _copy_weights(self): self.slow_weights = L(L(p.clone().detach() for p in pg) for pg in self.param_lists)

@property
def param_groups(self): return self.opt.param_groups
@param_groups.setter
def param_groups(self, v): self.opt.param_groups = v
def param_lists(self): return self.opt.param_lists
@param_lists.setter
def param_lists(self, v): self.opt.param_lists = v

# Cell
@delegates(RAdam)
Expand Down Expand Up @@ -342,17 +352,17 @@ def __init__(self, opt, hp_map=None):
self.state = defaultdict(dict, {})
self.frozen_idx = 0

@property
def param_groups(self): return [pg['params'] for pg in self.opt.param_groups]
@param_groups.setter
def param_groups(self, v):
for pg,v_ in zip(self.opt.param_groups,v): pg['params'] = v_

@property
def hypers(self):
return [{self.fwd_map[k]:v for k,v in detuplify_pg(pg).items() if k != 'params'} for pg in self.opt.param_groups]

def _set_hyper(self, k, v):
for pg,v_ in zip(self.opt.param_groups,v): pg = set_item_pg(pg, self.bwd_map[k], v_)

def clear_state(self): self.opt.state = defaultdict(dict, {})
def clear_state(self): self.opt.state = defaultdict(dict, {})

@property
def param_lists(self): return [pg['params'] for pg in self.opt.param_groups]
@param_lists.setter
def param_lists(self, v):
for pg,v_ in zip(self.opt.param_groups,v): pg['params'] = v_
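
A hypothetical end-to-end sketch of OptimWrapper after this rename (the wrapped optimizer and parameter shapes are assumptions; hyper-parameter names are translated by the module's hp_map):

import torch
from fastai2.optimizer import OptimWrapper

params = [torch.nn.Parameter(torch.randn(2, 2))]
opt = OptimWrapper(torch.optim.SGD(params, lr=0.1, momentum=0.9))
print(opt.param_lists)   # raw tensors per group (this view was `param_groups` before this commit)
print(opt.hypers)        # fastai-style hyper names, e.g. [{'lr': 0.1, 'mom': 0.9, ...}]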

0 comments on commit 60d63c3
