Skip to content
Branch: master
Find file Copy path
Find file Copy path
4 contributors

Users who have contributed to this file

@sgugger @jph00 @lgvaz @takotab
633 lines (545 sloc) 27.7 KB
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/13_learner.ipynb (unless otherwise specified).
# Public names of this module: what a star-import of it exposes.
__all__ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException',
'CancelBatchException', 'Callback', 'TrainEvalCallback', 'GatherPredsCallback', 'event', 'replacing_yield',
'mk_metric', 'save_model', 'load_model', 'Learner', 'VerboseCallback', 'Metric', 'AvgMetric', 'AvgLoss',
'AvgSmoothLoss', 'Recorder', 'FetchPreds', 'load_learner']
# Cell
from contextlib import contextmanager

from .data.all import *
from .optimizer import *
# Cell
# Names of the batch-level events; `Callback.__call__` checks membership in this list when deciding
# whether an event should run for the current train/valid phase.
_inner_loop = "begin_batch after_pred after_loss after_backward after_step after_cancel_batch after_batch".split()
# Cell
class Callback(GetAttr):
    "Basic class handling tweaks of the training loop by changing a `Learner` in various events"
    # `learn` is set by `Learner.add_cb`. `run` gates every event; `run_train`/`run_valid` gate the
    # batch-level events (`_inner_loop`) depending on the current phase.
    # NOTE(review): `_default='learn'` presumably makes `GetAttr` forward unknown attribute reads to
    # the learner (callbacks below read `self.model`, `self.xb`, ... this way) - confirm in fastcore.
    _default,learn,run,run_train,run_valid = 'learn',None,True,True,True
    def __repr__(self): return type(self).__name__

    def __call__(self, event_name):
        "Call `self.{event_name}` if it's defined"
        # Events outside the batch loop always qualify; batch-level ones only in an enabled phase.
        _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
               (self.run_valid and not getattr(self, 'training', False)))
        if self.run and _run: getattr(self, event_name, noop)()
        if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit

    def __setattr__(self, name, value):
        # Writing `self.x = ...` stores on the callback, never on the learner: warn when the name
        # also exists on the learner, since that is usually a mistake.
        if hasattr(self.learn,name):
            warn(f"You are setting an attribute ({name}) that also exists in the learner. Please be advised that you're not setting it in the learner but in the callback. Use `self.learn.{name}` if you would like to change it in the learner.")
        super().__setattr__(name, value)

    @property
    def name(self):
        "Name of the `Callback`, camel-cased and with '*Callback*' removed"
        return class2attr(self, 'Callback')
# Cell
class TrainEvalCallback(Callback):
    "`Callback` that tracks the number of iterations done and properly sets training/eval mode"
    # Batch-level events (here `after_batch`) are skipped during validation.
    run_valid = False
    def begin_fit(self):
        "Set the iter and epoch counters to 0, put the model and the right device"
        self.learn.train_iter,self.learn.pct_train = 0,0.
        self.model.to(self.dls.device)

    def after_batch(self):
        "Update the iter counter (in training mode)"
        # `pct_train` advances by one batch's share of the whole fit (n_iter batches * n_epoch epochs).
        self.learn.pct_train += 1./(self.n_iter*self.n_epoch)
        self.learn.train_iter += 1

    def begin_train(self):
        "Set the model in training mode"
        # Re-sync the progress fraction at each epoch start so float drift or skipped batches do not accumulate.
        self.learn.pct_train = self.epoch/self.n_epoch
        self.model.train()
        # Must go through `self.learn`: assigning `self.training` would only set it on the callback.
        self.learn.training = True

    def begin_validate(self):
        "Set the model in validation mode"
        self.model.eval()
        self.learn.training = False
# Cell
#TODO: save_targs and save_preds only handle preds/targets that have one tensor, not tuples of tensors.
class GatherPredsCallback(Callback):
    "`Callback` that saves the predictions and targets, optionally `with_loss`"
    def __init__(self, with_input=False, with_loss=False, save_preds=None, save_targs=None, concat_dim=0):
        # `save_preds`/`save_targs`: when given, each batch is written under that path (one array per
        # batch index) instead of being kept in memory.
        store_attr(self, "with_input,with_loss,save_preds,save_targs,concat_dim")

    def begin_batch(self):
        "Record the detached inputs of the batch if `with_input`"
        if self.with_input: self.inputs.append((to_detach(self.xb)))

    def begin_validate(self):
        "Initialize containers"
        self.preds,self.targets = [],[]
        if self.with_input: self.inputs = []
        if self.with_loss:  self.losses = []

    def after_batch(self):
        "Save predictions, targets and potentially losses"
        preds,targs = to_detach(self.pred),to_detach(self.yb)
        if self.save_preds is None: self.preds.append(preds)
        else: (self.save_preds/str(self.iter)).save_array(preds)
        if self.save_targs is None: self.targets.append(targs)
        else: (self.save_targs/str(self.iter)).save_array(targs[0])
        if self.with_loss:
            bs = find_bs(self.yb)
            # Keep one loss per item: if the loss has more than `bs` elements, average the extra ones per item.
            loss = self.loss if self.loss.numel() == bs else self.loss.view(bs,-1).mean(1)
            # Without this append `self.losses` stays empty and `after_fit` has nothing to concatenate.
            self.losses.append(to_detach(loss))

    def after_fit(self):
        "Concatenate all recorded tensors"
        if self.with_input:     self.inputs  = detuplify(to_concat(self.inputs, dim=self.concat_dim))
        if not self.save_preds: self.preds   = detuplify(to_concat(self.preds, dim=self.concat_dim))
        if not self.save_targs: self.targets = detuplify(to_concat(self.targets, dim=self.concat_dim))
        if self.with_loss:      self.losses  = to_concat(self.losses)

    def all_tensors(self):
        "Return `[inputs?, preds, targets, losses?]`; preds/targets are `None` when they were saved to disk"
        res = [None if self.save_preds else self.preds, None if self.save_targs else self.targets]
        if self.with_input: res = [self.inputs] + res
        if self.with_loss:  res.append(self.losses)
        return res
# Cell
# Docstring for each control-flow exception of the training loop. Each text names the event that runs
# in the `finally` clause of the `Learner` method catching that exception.
_ex_docs = dict(
    CancelFitException="Interrupts training and go to `after_fit`",
    CancelEpochException="Skip the rest of this epoch and go to `after_epoch`",
    CancelTrainException="Skip the rest of the training part of the epoch and go to `after_train`",
    CancelValidException="Skip the rest of the validation part of the epoch and go to `after_validate`",
    CancelBatchException="Skip the rest of this batch and go to `after_batch`")
# Create one `Exception` subclass per entry, carrying its docstring.
for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)
# Cell
# Every event name the training loop can fire, in the order they are listed in the loop outline.
_events = L.split('begin_fit begin_epoch begin_train begin_batch after_pred after_loss \
after_backward after_step after_cancel_batch after_batch after_cancel_train \
after_train begin_validate after_cancel_validate after_validate after_cancel_epoch \
after_epoch after_cancel_fit after_fit')
mk_class('event', **_events.map_dict(),
doc="All possible events as attributes to get tab-completion and typo-proofing")
# Event groups fired around a stand-alone validation pass (see `Learner.get_preds`).
_before_epoch = [event.begin_fit, event.begin_epoch]
_after_epoch = [event.after_epoch, event.after_fit]
# Cell
# Outline of the training loop printed by `Learner.show_training_loop`: entries starting with
# 'Start'/'End' open/close an indented section, every other entry is treated as an event name.
_loop = ['Start Fit', 'begin_fit', 'Start Epoch Loop', 'begin_epoch', 'Start Train', 'begin_train',
'Start Batch Loop', 'begin_batch', 'after_pred', 'after_loss', 'after_backward',
'after_step', 'after_cancel_batch', 'after_batch','End Batch Loop','End Train',
'after_cancel_train', 'after_train', 'Start Valid', 'begin_validate','Start Batch Loop',
'**CBs same as train batch**', 'End Batch Loop', 'End Valid', 'after_cancel_validate',
'after_validate', 'End Epoch Loop', 'after_cancel_epoch', 'after_epoch', 'End Fit',
'after_cancel_fit', 'after_fit']
# Cell
# Library-wide defaults: learning rate, weight decay, and the callbacks every `Learner` starts with.
defaults.lr = 1e-3
defaults.wd = 1e-2
defaults.callbacks = [TrainEvalCallback]
# Cell
@contextmanager
def replacing_yield(o, attr, val):
    "Context manager to temporarily replace an attribute"
    old = getattr(o,attr)
    # The `finally` restores the previous value even when the `with` body raises.
    try:     yield setattr(o,attr,val)
    finally: setattr(o,attr,old)
# Cell
def mk_metric(m):
    "Return `m` unchanged when it is a `Metric`, otherwise wrap it in an `AvgMetric`"
    if isinstance(m, Metric): return m
    return AvgMetric(m)
# Cell
def save_model(file, model, opt, with_opt=True):
    "Save `model` to `file` along with `opt` (if available, and if `with_opt`)"
    if opt is None: with_opt=False
    state = get_model(model).state_dict()
    # With an optimizer the file holds a dict with exactly the keys 'model' and 'opt' (the layout
    # `load_model` detects); otherwise it holds the bare model state dict.
    if with_opt: state = {'model': state, 'opt':opt.state_dict()}
    torch.save(state, file)
# Cell
def load_model(file, model, opt, with_opt=None, device=None, strict=True):
    "Load `model` from `file` along with `opt` (if available, and if `with_opt`)"
    # `device`: an int selects that CUDA device, `None` means CPU, anything else is passed through.
    # `with_opt=None` means "load the optimizer state when present, but stay silent when it is not".
    if isinstance(device, int): device = torch.device('cuda', device)
    elif device is None: device = 'cpu'
    state = torch.load(file, map_location=device)
    # Files written with an optimizer are dicts with exactly these two keys.
    hasopt = set(state)=={'model', 'opt'}
    model_state = state['model'] if hasopt else state
    get_model(model).load_state_dict(model_state, strict=strict)
    if hasopt and ifnone(with_opt,True):
        # Best effort: a failure to restore the optimizer must not prevent using the loaded weights.
        try: opt.load_state_dict(state['opt'])
        except Exception:
            if with_opt: warn("Could not load the optimizer state.")
    elif with_opt: warn("Saved filed doesn't contain an optimizer state.")
# Cell
def _try_concat(o):
    "Concatenate the tensors in `o`, or fall back to a flat `L` of their rows when they cannot be concatenated"
    try: return torch.cat(o)
    # `torch.cat` failed (e.g. mismatched shapes): flatten to one entry per row instead.
    except Exception: return sum([L(o_[i,:] for i in range_of(o_)) for o_ in o], L())
# Cell
from contextlib import ExitStack
# Cell
class Learner():
    # Ties together the data (`dls`), a `model`, a loss function and an optimizer factory, and runs the
    # training loop by firing named events on the registered callbacks (`self.cbs`).
    # Method docstrings are attached separately through `add_docs`, so methods carry `#` comments here.
    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=getattr(defaults, 'lr', 1e-3),
                 splitter=trainable_params, cbs=None, metrics=None, path=None, model_dir='models',
                 wd=defaults.wd, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95)):
        # NOTE(review): `lr` falls back to 1e-3 if `defaults.lr` has not been set yet; the `moms`
        # default was reconstructed from the upstream library - confirm both.
        store_attr(self, "dls,model,opt_func,lr,splitter,model_dir,wd,wd_bn_bias,train_bn,metrics,moms")
        self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()
        if loss_func is None:
            loss_func = getattr(dls.train_ds, 'loss_func', None)
            assert loss_func is not None, "Could not infer loss function from the data, please pass a loss function."
        self.loss_func = loss_func
        self.path = path if path is not None else getattr(dls, 'path', Path('.'))
        # Callbacks may be given as classes (instantiated here) or as instances.
        self.add_cbs([(cb() if isinstance(cb, type) else cb) for cb in L(defaults.callbacks)+L(cbs)])
        if hasattr(self.model, 'reset'): self.model.reset()
        self.epoch,self.n_epoch,self.loss = 0,1,tensor(0.)

    # `metrics` always holds `Metric` objects: plain functions are wrapped on assignment.
    @property
    def metrics(self): return self._metrics
    @metrics.setter
    def metrics(self,v): self._metrics = L(v).map(mk_metric)

    def add_cbs(self, cbs): L(cbs).map(self.add_cb)
    def remove_cbs(self, cbs): L(cbs).map(self.remove_cb)
    def add_cb(self, cb):
        # Each callback is also reachable as `self.<cb.name>` (e.g. `self.recorder`); a callback of a
        # different type may not take over a name already in use.
        old = getattr(self, cb.name, None)
        assert not old or isinstance(old, type(cb)), f"self.{cb.name} already registered"
        cb.learn = self
        setattr(self, cb.name, cb)
        self.cbs.append(cb)
        return self

    def remove_cb(self, cb):
        cb.learn = None
        if hasattr(self, cb.name): delattr(self, cb.name)
        if cb in self.cbs: self.cbs.remove(cb)

    @contextmanager
    def added_cbs(self, cbs):
        self.add_cbs(cbs)
        # Always unregister, even when the body raises, so temporary callbacks never leak.
        try: yield
        finally: self.remove_cbs(cbs)

    def ordered_cbs(self, cb_func): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, cb_func)]

    # Calling the learner with one event name (or a list of them) dispatches to every callback.
    def __call__(self, event_name): L(event_name).map(self._call_one)
    def _call_one(self, event_name):
        assert hasattr(event, event_name)
        for cb in sort_by_run(self.cbs): cb(event_name)

    def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)
    def create_opt(self):
        self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)
        if not self.wd_bn_bias:
            for p in self._bn_bias_state(True ): p['do_wd'] = False
        if self.train_bn:
            for p in self._bn_bias_state(False): p['force_train'] = True

    def _split(self, b):
        # The first `n_inp` elements of a batch are inputs, the rest targets; without `dls.n_inp`,
        # everything but the last element is input (a 1-element batch is all input).
        i = getattr(self.dls, 'n_inp', 1 if len(b)==1 else len(b)-1)
        self.xb,self.yb = b[:i],b[i:]

    def all_batches(self):
        self.n_iter = len(self.dl)
        for o in enumerate(self.dl): self.one_batch(*o)

    def one_batch(self, i, b):
        self.iter = i
        try:
            self._split(b);                                  self('begin_batch')
            self.pred = self.model(*self.xb);                self('after_pred')
            if len(self.yb) == 0: return                     # no targets: nothing to compute a loss on
            self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
            if not self.training: return                     # validation stops after the loss
            self.loss.backward();                            self('after_backward')
            self.opt.step();                                 self('after_step')
            self.opt.zero_grad()
        except CancelBatchException:                         self('after_cancel_batch')
        finally:                                             self('after_batch')

    def _do_begin_fit(self, n_epoch):
        self.n_epoch,self.loss = n_epoch,tensor(0.);         self('begin_fit')

    def _do_epoch_train(self):
        try:
            self.dl = self.dls.train;                        self('begin_train')
            self.all_batches()
        except CancelTrainException:                         self('after_cancel_train')
        finally:                                             self('after_train')

    def _do_epoch_validate(self, ds_idx=1, dl=None):
        if dl is None: dl = self.dls[ds_idx]
        names = ['shuffle', 'drop_last']
        # Validation must see every item once and in order; the previous settings are restored below.
        dl,old,has = change_attrs(dl, names, [False,False])
        try:
            self.dl = dl;                                    self('begin_validate')
            with torch.no_grad(): self.all_batches()
        except CancelValidException:                         self('after_cancel_validate')
        finally:
            dl,*_ = change_attrs(dl, names, old, has);       self('after_validate')

    def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
        with self.added_cbs(cbs):
            if reset_opt or not self.opt: self.create_opt()
            self.opt.set_hypers(wd=self.wd if wd is None else wd, lr=self.lr if lr is None else lr)
            try:
                self._do_begin_fit(n_epoch)
                for epoch in range(n_epoch):
                    try:
                        self.epoch=epoch;          self('begin_epoch')
                        self._do_epoch_train()
                        self._do_epoch_validate()
                    except CancelEpochException:   self('after_cancel_epoch')
                    finally:                       self('after_epoch')
            except CancelFitException:             self('after_cancel_fit')
            finally:                               self('after_fit')

    def validate(self, ds_idx=1, dl=None, cbs=None):
        if dl is None: dl = self.dls[ds_idx]
        with self.added_cbs(cbs), self.no_logging(), self.no_mbar():
            # The fit/epoch events run too so the recorder resets and then stores this pass's values.
            self(_before_epoch)
            self._do_epoch_validate(ds_idx, dl)
            self(_after_epoch)
        return self.recorder.values[-1]

    def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,
                  inner=False, **kwargs):
        # `inner=True` means we are already inside a fit: only the epoch events are fired.
        # Extra `kwargs` go to `GatherPredsCallback`.
        if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
        cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
        ctx_mgrs = [self.no_logging(), self.added_cbs(cb), self.no_mbar()]
        if with_loss: ctx_mgrs.append(self.loss_not_reduced())
        with ExitStack() as stack:
            for mgr in ctx_mgrs: stack.enter_context(mgr)
            self(event.begin_epoch if inner else _before_epoch)
            self._do_epoch_validate(dl=dl)
            self(event.after_epoch if inner else _after_epoch)
            if act is None: act = getattr(self.loss_func, 'activation', noop)
            res = cb.all_tensors()
            pred_i = 1 if with_input else 0
            if res[pred_i] is not None:
                res[pred_i] = act(res[pred_i])
                # Decoded predictions are inserted right after the targets.
                if with_decoded: res.insert(pred_i+2, getattr(self.loss_func, 'decodes', noop)(res[pred_i]))
            return tuple(res)

    def predict(self, item, rm_type_tfms=None, with_input=False):
        dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms)
        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
        dec = self.dls.decode_batch((*tuplify(inp),*tuplify(dec_preds)))[0]
        i = getattr(self.dls, 'n_inp', -1)
        dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
        res = dec_targ,dec_preds[0],preds[0]
        if with_input: res = (dec_inp,) + res
        return res

    def show_results(self, ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs):
        if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
        b = dl.one_batch()
        _,_,preds = self.get_preds(dl=[b], with_decoded=True)
        self.dls.show_results(b, preds, max_n=max_n, **kwargs)

    def show_training_loop(self):
        indent = 0
        for s in _loop:
            if s.startswith('Start'): print(f'{" "*indent}{s}'); indent += 2
            elif s.startswith('End'): indent -= 2; print(f'{" "*indent}{s}')
            else: print(f'{" "*indent} - {s:15}:', self.ordered_cbs(s))

    def no_logging(self): return replacing_yield(self, 'logger', noop)
    def no_mbar(self):    return replacing_yield(self, 'create_mbar', False)

    def loss_not_reduced(self):
        # Loss objects exposing `reduction` are switched in place; plain functions get it as a kwarg.
        if hasattr(self.loss_func, 'reduction'): return replacing_yield(self.loss_func, 'reduction', 'none')
        else: return replacing_yield(self, 'loss_func', partial(self.loss_func, reduction='none'))

    def save(self, file, with_opt=True):
        if rank_distrib(): return # don't save if slave proc
        file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        save_model(file, self.model, getattr(self,'opt',None), with_opt)

    def load(self, file, with_opt=None, device=None, strict=True):
        if device is None: device = self.dls.device
        if self.opt is None: self.create_opt()
        file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        load_model(file, self.model, self.opt, with_opt=with_opt, device=device, strict=strict)
        return self
# `learn.x` / `learn.y`: the current batch's inputs (`xb`) / targets (`yb`) passed through `detuplify`.
# NOTE(review): presumably `add_props` builds one property per index (0 -> x, 1 -> y) - confirm in fastcore.
Learner.x,Learner.y = add_props(lambda i,x: detuplify((x.xb,x.yb)[i]))
# Cell
# Attach the class docstring and one docstring per public method of `Learner`.
add_docs(Learner, "Group together a `model`, some `dls` and a `loss_func` to handle training",
         add_cbs="Add `cbs` to the list of `Callback` and register `self` as their learner",
         add_cb="Add `cb` to the list of `Callback` and register `self` as their learner",
         remove_cbs="Remove `cbs` from the list of `Callback` and deregister `self` as their learner",
         remove_cb="Remove `cb` from the list of `Callback` and deregister `self` as their learner",
         added_cbs="Context manager that temporarily adds `cbs`",
         ordered_cbs="Return a list of `Callback` for one step `cb_func` in the training loop",
         create_opt="Create an optimizer with `lr`",
         one_batch="Train or evaluate `self.model` on batch `(xb,yb)`",
         all_batches="Train or evaluate `self.model` on all batches of `self.dl`",
         fit="Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`.",
         validate="Validate on `dl` with potential new `cbs`.",
         get_preds="Get the predictions and targets on the `ds_idx`-th dbunchset or `dl`, optionally `with_input` and `with_loss`",
         predict="Return the prediction on `item`, fully decoded, loss function decoded and probabilities",
         show_results="Show some predictions on `ds_idx`-th dbunchset or `dl`",
         show_training_loop="Show each step in the training loop",
         no_logging="Context manager to temporarily remove `logger`",
         no_mbar="Context manager to temporarily prevent the master progress bar from being created",
         loss_not_reduced="A context manager to evaluate `loss_func` with reduction set to none.",
         save="Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`",
         load="Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`"
)
# Cell
class VerboseCallback(Callback):
    "Callback that prints the name of each event called"
    def __call__(self, event_name):
        "Print `event_name`, then dispatch it as a regular `Callback` would"
        print(event_name)
        super().__call__(event_name)
# Cell
class Metric():
    "Blueprint for defining a metric"
    def reset(self):
        "Reset inner state to prepare for new computation"
        pass
    def accumulate(self, learn):
        "Use `learn` to update the state with new results"
        pass
    # `value` and `name` are properties: the rest of the file reads them as plain attributes
    # (`t.value`, `metrics.attrgot('name')`).
    @property
    def value(self):
        "The value of the metric"
        raise NotImplementedError
    @property
    def name(self):
        "Name of the `Metric`, camel-cased and with Metric removed"
        return class2attr(self, 'Metric')

    # Same texts as the docstrings above, kept as data for tooling that reads `_docs`.
    _docs = dict(
        reset="Reset inner state to prepare for new computation",
        name="Name of the `Metric`, camel-cased and with Metric removed",
        accumulate="Use `learn` to update the state with new results",
        value="The value of the metric")
# Cell
def _maybe_reduce(val):
    "Average `val` over all processes when more than one is running, otherwise return it untouched"
    if num_distrib() <= 1: return val
    # Work on a copy so the caller's tensor is not modified by the in-place all-reduce.
    reduced = val.clone()
    torch.distributed.all_reduce(reduced, op=torch.distributed.ReduceOp.SUM)
    reduced /= num_distrib()
    return reduced
# Cell
class AvgMetric(Metric):
    "Average the values of `func` taking into account potential different batch sizes"
    def __init__(self, func): self.func = func
    def reset(self): self.total,self.count = 0.,0
    def accumulate(self, learn):
        # Weight each batch's value by its size so a smaller last batch does not skew the average.
        bs = find_bs(learn.yb)
        self.total += to_detach(self.func(learn.pred, *learn.yb))*bs
        self.count += bs
    @property
    def value(self): return self.total/self.count if self.count != 0 else None
    @property
    def name(self):
        # `functools.partial`-style wrappers keep the real function under `.func`.
        return self.func.func.__name__ if hasattr(self.func, 'func') else self.func.__name__
# Cell
class AvgLoss(Metric):
    "Average the losses taking into account potential different batch sizes"
    def reset(self): self.total,self.count = 0.,0
    def accumulate(self, learn):
        # `learn.loss` may be unreduced (one value per item), hence the `.mean()` before weighting.
        bs = find_bs(learn.yb)
        self.total += to_detach(learn.loss.mean())*bs
        self.count += bs
    @property
    def value(self): return self.total/self.count if self.count != 0 else None
    @property
    def name(self): return "loss"
# Cell
class AvgSmoothLoss(Metric):
    "Smooth average of the losses (exponentially weighted with `beta`)"
    def __init__(self, beta=0.98): self.beta = beta
    def reset(self): self.count,self.val = 0,tensor(0.)
    def accumulate(self, learn):
        self.count += 1
        # lerp(new, old, beta) = new + beta*(old-new) = beta*old + (1-beta)*new
        self.val = torch.lerp(to_detach(learn.loss.mean(), gather=False), self.val, self.beta)
    @property
    def value(self):
        # Debias: the running average starts from 0, so divide by (1 - beta**count).
        return self.val/(1-self.beta**self.count)
# Cell
from fastprogress.fastprogress import format_time
def _maybe_item(t):
    "Read `t.value` and unwrap it to a Python scalar when it is a one-element `Tensor`"
    val = t.value
    if isinstance(val, Tensor) and val.numel() == 1: return val.item()
    return val
# Cell
class Recorder(Callback):
    "Callback that registers statistics (lr, loss and metrics) during training"
    run_after = TrainEvalCallback

    def __init__(self, add_time=True, train_metrics=False, valid_metrics=True, beta=0.98):
        store_attr(self, 'add_time,train_metrics,valid_metrics')
        # `loss` is the batch-size-weighted validation loss, `smooth_loss` the running training loss.
        self.loss,self.smooth_loss = AvgLoss(),AvgSmoothLoss(beta=beta)

    def begin_fit(self):
        "Prepare state for training"
        self.lrs,self.iters,self.losses,self.values = [],[],[],[]
        names = self.metrics.attrgot('name')
        if self.train_metrics and self.valid_metrics:
            names = L('loss') + names
            names = names.map('train_{}') + names.map('valid_{}')
        elif self.valid_metrics: names = L('train_loss', 'valid_loss') + names
        else: names = L('train_loss') + names
        if self.add_time: names.append('time')
        self.metric_names = 'epoch'+names
        # The smooth loss is skipped by `begin_train` (it spans epochs), so it is reset once per fit.
        self.smooth_loss.reset()

    def after_batch(self):
        "Update all metrics and records lr and smooth loss in training"
        if len(self.yb) == 0: return
        mets = self._train_mets if self.training else self._valid_mets
        for met in mets: met.accumulate(self.learn)
        if not self.training: return
        # NOTE(review): assumes `opt.hypers` is a list of per-group dicts with an 'lr' key - confirm.
        self.lrs.append(self.opt.hypers[-1]['lr'])
        self.losses.append(self.smooth_loss.value)
        self.learn.smooth_loss = self.smooth_loss.value

    def begin_epoch(self):
        "Set timer if `self.add_time=True`"
        self.cancel_train,self.cancel_valid = False,False
        if self.add_time: self.start_epoch = time.time()
        # The log row starts with the epoch number; metric values are appended as phases finish.
        self.log = L(getattr(self, 'epoch', 0))

    # The first training metric is the smooth loss, which is not reset per epoch (hence `[1:]`).
    def begin_train   (self): self._train_mets[1:].map(Self.reset())
    def begin_validate(self): self._valid_mets.map(Self.reset())
    def after_train   (self): self.log += self._train_mets.map(_maybe_item)
    def after_validate(self): self.log += self._valid_mets.map(_maybe_item)
    def after_cancel_train(self):    self.cancel_train = True
    def after_cancel_validate(self): self.cancel_valid = True

    def after_epoch(self):
        "Store and log the loss/metric values"
        # Stored values exclude the leading epoch number and the trailing time column.
        self.values.append(self.log[1:].copy())
        if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))
        self.logger(self.log)
        # Batches seen so far: the x-position of this epoch's validation point in `plot_loss`.
        self.iters.append(self.smooth_loss.count)

    # Metrics active in each phase; empty when that phase was cancelled for the current epoch.
    @property
    def _train_mets(self):
        if getattr(self, 'cancel_train', False): return L()
        return L(self.smooth_loss) + (self.metrics if self.train_metrics else L())

    @property
    def _valid_mets(self):
        if getattr(self, 'cancel_valid', False): return L()
        return (L(self.loss) + self.metrics if self.valid_metrics else L())

    def plot_loss(self, skip_start=5, with_valid=True):
        plt.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')
        if with_valid:
            # Drop validation points recorded before `skip_start` training batches.
            idx = (np.array(self.iters)<skip_start).sum()
            plt.plot(self.iters[idx:], L(self.values[idx:]).itemgot(1), label='valid')
            plt.legend()
# Cell
# Docstrings for the `Recorder` events that are defined as one-liners without an inline docstring.
add_docs(Recorder,
         begin_train = "Reset loss and metrics state",
         after_train = "Log loss and metric values on the training set (if `self.training_metrics=True`)",
         begin_validate = "Reset loss and metrics state",
         after_validate = "Log loss and metric values on the validation set",
         after_cancel_train = "Ignore training metrics for this epoch",
         after_cancel_validate = "Ignore validation metrics for this epoch",
         plot_loss = "Plot the losses from `skip_start` and onward")

# From here on every new `Learner` also gets a `Recorder`.
defaults.callbacks = [TrainEvalCallback, Recorder]
# Cell
class FetchPreds(Callback):
    "A callback to fetch predictions during the training loop"
    def __init__(self, ds_idx=1, dl=None): store_attr(self, 'ds_idx,dl')
    def after_validate(self):
        "Run `get_preds` on `ds_idx`/`dl` and store the result in `self.preds`"
        # Keep local references: `remove_cbs` resets `self.learn` to None.
        learn,rec = self.learn,self.learn.recorder
        # `get_preds` runs a validation pass of its own. Detaching this callback prevents that pass
        # from firing `after_validate` here again (endless recursion); detaching the recorder keeps
        # the extra pass out of the recorded statistics.
        learn.remove_cbs([self,rec])
        try:     self.preds = learn.get_preds(ds_idx=self.ds_idx, dl=self.dl, inner=True)
        finally: learn.add_cbs([self, rec])
# Cell
# NOTE(review): the `self:Learner` annotations are fastcore's `@patch` idiom (the functions become
# `Learner` methods); the decorators and the optimizer calls in `freeze_to` were reconstructed from
# the upstream library - confirm `patch` is in scope and that `opt` has `freeze_to`/`clear_state`.
@patch
def freeze_to(self:Learner, n):
    if self.opt is None: self.create_opt()
    self.opt.freeze_to(n)
    # Optimizer statistics gathered under the previous freezing layout no longer apply.
    self.opt.clear_state()

@patch
def freeze(self:Learner): self.freeze_to(-1)

@patch
def unfreeze(self:Learner): self.freeze_to(0)

add_docs(Learner,
         freeze_to="Freeze parameter groups up to `n`",
         freeze="Freeze up to last parameter group",
         unfreeze="Unfreeze the entire model")
# Cell
@patch
def export(self:Learner, fname='export.pkl'):
    "Export the content of `self` without the items and the optimizer state for inference"
    if rank_distrib(): return # don't export if slave proc
    old_dbunch = self.dls
    # The optimizer only exists after a fit/load/freeze, so there may be no state to put back.
    state = self.opt.state_dict() if self.opt is not None else None
    self.dls = self.dls.new_empty()
    self.opt = None
    try:
        with warnings.catch_warnings():
            #To avoid the warning that come from PyTorch about model not being checked
            warnings.simplefilter("ignore")
            torch.save(self, self.path/fname)
    finally:
        # Put the learner back in its pre-export state even if saving failed.
        self.dls = old_dbunch
        if state is not None:
            self.create_opt()
            self.opt.load_state_dict(state)
# Cell
def load_learner(fname, cpu=True):
    "Load a `Learner` object in `fname`, optionally putting it on the `cpu`"
    # NOTE(review): `torch.load` unpickles `fname`; only load files from a trusted source.
    target = 'cpu' if cpu else None
    learn = torch.load(fname, map_location=target)
    # Learners exposing `to_fp32` are converted back to full precision.
    if hasattr(learn, 'to_fp32'): learn = learn.to_fp32()
    if cpu: learn.dls.cpu()
    return learn
# Cell
# NOTE(review): `@patch` reconstructed from the `self:Learner` annotation - confirm it is in scope.
@patch
def tta(self:Learner, ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None, beta=0.25, use_max=False):
    "Return predictions on the `ds_idx` dataset or `dl` using Test Time Augmentation"
    # Returns `(preds, targs)`. `preds` is the element-wise max of augmented and plain predictions
    # when `use_max`, the pair `(aug_preds, preds)` when `beta is None`, and otherwise
    # `lerp(aug_preds, preds, beta)`, i.e. `beta` is the weight of the plain predictions.
    if dl is None: dl = self.dls[ds_idx]
    if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)
    # Split index 0 makes the dataset use its training-time transforms, i.e. the augmentations.
    with dl.dataset.set_split_idx(0), self.no_mbar():
        if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))
        aug_preds = []
        for i in self.progress.mbar if hasattr(self,'progress') else range(n):
            self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch
            # Pass `dl` explicitly so a caller-supplied loader or custom transforms are actually used.
            aug_preds.append(self.get_preds(dl=dl, inner=True)[0][None])
    aug_preds = torch.cat(aug_preds)
    aug_preds = aug_preds.max(0)[0] if use_max else aug_preds.mean(0)
    self.epoch = n
    # Split index 1: plain validation-time transforms for the reference predictions.
    with dl.dataset.set_split_idx(1): preds,targs = self.get_preds(dl=dl, inner=True)
    if use_max: return torch.stack([preds, aug_preds], 0).max(0)[0],targs
    preds = (aug_preds,preds) if beta is None else torch.lerp(aug_preds, preds, beta)
    return preds,targs
You can’t perform that action at this time.