Add metrics to recorder
sgugger committed Oct 11, 2018
1 parent 3e6147e commit 0681146
Showing 3 changed files with 20 additions and 9 deletions.
fastai/basic_train.py: 14 additions & 4 deletions
@@ -43,10 +43,11 @@ def validate(model:Model, dl:DataLoader, loss_fn:OptLossFunc=None,
     with torch.no_grad():
         val_metrics,nums = [],[]
         for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
+            xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)
             val_metrics.append(loss_batch(model, xb, yb, loss_fn, cb_handler=cb_handler, metrics=metrics))
             if not is_listy(yb): yb = [yb]
             nums.append(yb[0].shape[0])
-            if cb_handler and cb_handler.on_batch_end(val_metrics[0], train=False): break
+            if cb_handler and cb_handler.on_batch_end(val_metrics[0]): break
         nums = np.array(nums, dtype=np.float32)
         if average: return [(to_np(torch.stack(val)) * nums).sum() / nums.sum() for val in zip(*val_metrics)]
         else: return val_metrics
@@ -194,13 +195,15 @@ def on_train_begin(self, pbar:PBar, metrics:MetricFuncList, **kwargs:Any)->None:
         "Initialize recording status at beginning of training."
         self.pbar = pbar
         self.names = ['epoch', 'train loss', 'valid loss'] + [fn.__name__ for fn in metrics]
+        if hasattr(self, '_added_met_names'): self.names += self._added_met_names
         self.pbar.write(' '.join(self.names))
         self.losses,self.val_losses,self.lrs,self.moms,self.metrics,self.nb_batches = [],[],[],[],[],[]

-    def on_batch_begin(self, **kwargs:Any)->None:
+    def on_batch_begin(self, train, **kwargs:Any)->None:
         "Record learning rate and momentum at beginning of batch."
-        self.lrs.append(self.opt.lr)
-        self.moms.append(self.opt.mom)
+        if train:
+            self.lrs.append(self.opt.lr)
+            self.moms.append(self.opt.mom)

     def on_backward_begin(self, smooth_loss:Tensor, **kwargs:Any)->None:
         "Record the loss before any other callback has a chance to modify it."
@@ -214,6 +217,7 @@ def on_epoch_end(self, epoch:int, num_batch:int, smooth_loss:Tensor,
         self.nb_batches.append(num_batch)
         if last_metrics is not None:
             self.val_losses.append(last_metrics[0])
+            if hasattr(self, '_added_mets'): last_metrics += self._added_mets
             if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])
             self.format_stats([epoch, smooth_loss] + last_metrics)
         else: self.format_stats([epoch, smooth_loss])
@@ -227,7 +231,13 @@ def format_stats(self, stats:TensorOrNumList)->None:
             t += ' ' * (len(name) - len(t))
             str_stats.append(t)
         self.pbar.write(' '.join(str_stats))

+    def add_metrics(self, metrics):
+        self._added_mets = metrics
+
+    def add_metric_names(self, names):
+        self._added_met_names = names
+
     def plot_lr(self, show_moms=False)->None:
         "Plot learning rate, `show_moms` to include momentum."
         iterations = range_of(self.lrs)
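Note: a minimal sketch of how a callback might feed an extra column through the new add_metrics/add_metric_names hooks on the recorder. The ExtraMetrics class and its metric name are hypothetical, not part of this commit; the sketch also assumes learn.recorder is available and that this callback runs before the recorder, so the name is stashed before the header is printed and the value before the epoch row is formatted (otherwise the value would land one epoch late):

    class ExtraMetrics(Callback):
        "Hypothetical callback reporting one extra value per epoch via the recorder."
        def __init__(self, learn):
            self.learn = learn

        def on_train_begin(self, **kwargs):
            # Stash the extra column name so on_train_begin above can append it
            # to self.names before writing the header.
            self.learn.recorder.add_metric_names(['my_metric'])

        def on_epoch_begin(self, **kwargs):
            self.vals = []

        def on_batch_end(self, last_loss, train, **kwargs):
            # Only accumulate on training batches, using the new `train` flag.
            if train: self.vals.append(last_loss)

        def on_epoch_end(self, **kwargs):
            # Stash the epoch value; on_epoch_end above appends it to last_metrics.
            self.learn.recorder.add_metrics([sum(self.vals)/len(self.vals)])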
fastai/callback.py: 4 additions & 4 deletions
@@ -192,9 +192,10 @@ def on_epoch_begin(self)->None:
         self.state_dict['num_batch'] = 0
         self('epoch_begin')

-    def on_batch_begin(self, xb:Tensor, yb:Tensor)->None:
+    def on_batch_begin(self, xb:Tensor, yb:Tensor, train:bool=True)->None:
         "Handle new batch `xb`,`yb`."
         self.state_dict['last_input'], self.state_dict['last_target'] = xb, yb
+        self.state_dict['train'] = train
         for cb in self.callbacks:
             a = cb.on_batch_begin(**self.state_dict)
             if a is not None: self.state_dict['last_input'], self.state_dict['last_target'] = a
@@ -224,12 +225,11 @@ def on_step_end(self)->None:
         "Handle end of optimization step."
         self('step_end')

-    def on_batch_end(self, loss:Tensor, train:bool=True)->None:
+    def on_batch_end(self, loss:Tensor)->None:
         "Handle end of processing one batch with `loss`."
         self.state_dict['last_loss'] = loss
-        self.state_dict['train'] = train
         stop = np.any(self('batch_end'))
-        if train:
+        if self.state_dict['train']:
             self.state_dict['iteration'] += 1
             self.state_dict['num_batch'] += 1
         return stop
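Note: the net effect of the callback.py changes is that the train/valid flag is now set once per batch in on_batch_begin, stored in state_dict, and read back in on_batch_end, so every callback receives it as a `train` kwarg. A toy sketch (the BatchCounter class is hypothetical); this same flag is what lets MixUpCallback below bail out on validation batches:

    class BatchCounter(Callback):
        "Hypothetical callback counting training and validation batches separately."
        def on_train_begin(self, **kwargs):
            self.n_train, self.n_valid = 0, 0

        def on_batch_begin(self, train, **kwargs):
            # `train` comes from state_dict, set by CallbackHandler.on_batch_begin.
            if train: self.n_train += 1
            else:     self.n_valid += 1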
fastai/callbacks/mixup.py: 2 additions & 1 deletion
@@ -11,7 +11,8 @@ class MixUpCallback(Callback):
     stack_x:bool=False
     stack_y:bool=True

-    def on_batch_begin(self, last_input, last_target, **kwargs):
+    def on_batch_begin(self, last_input, last_target, train, **kwargs):
+        if not train: return
         lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
         lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
         lambd = last_input.new(lambd)
