Skip to content

Commit

Permalink
Add on_fit_epoch_end callback (ultralytics#5232)
Browse files Browse the repository at this point in the history
* Add `on_fit_epoch_end` callback

* Add results to train

* Update __init__.py
  • Loading branch information
glenn-jocher committed Oct 18, 2021
1 parent 724f355 commit dafbd67
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,10 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
plots=True,
callbacks=callbacks,
compute_loss=compute_loss) # val best model with plots
if is_coco:
callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)

callbacks.run('on_train_end', last, best, plots, epoch)
callbacks.run('on_train_end', last, best, plots, epoch, results)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

torch.cuda.empty_cache()
Expand Down
2 changes: 1 addition & 1 deletion utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

def on_train_end(self, last, best, plots, epoch):
def on_train_end(self, last, best, plots, epoch, results):
# Callback runs on training end
if plots:
plot_results(file=self.save_dir / 'results.csv') # save results.png
Expand Down

0 comments on commit dafbd67

Please sign in to comment.