Skip to content

Commit

Permalink
Allow train_and_eval to stop and resume.
Browse files Browse the repository at this point in the history
Fix #680
Thanks to @NikZak for the suggestion.
  • Loading branch information
mingxingtan committed Aug 16, 2020
1 parent 68fbf15 commit 5ba7223
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions efficientdet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,17 +327,29 @@ def _eval(steps):
logging.info('Checkpoint %s no longer exists, skipping.', ckpt)

elif FLAGS.mode == 'train_and_eval':
ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
try:
step = int(os.path.basename(ckpt).split("-")[1])
current_epoch = (
step * FLAGS.train_batch_size // FLAGS.num_examples_per_epoch)
logging.info('found ckpt at step %d (epoch %d)', step, current_epoch)
except (IndexError, TypeError):
logging.info("Folder has no ckpt with valid step.", FLAGS.model_dir)
current_epoch = 0

epochs_per_cycle = 1 # higher number has less graph construction overhead.
for e in range(1, config.num_epochs + 1, epochs_per_cycle):
logging.info('Starting training, epoch: %d.', e)
for e in range(current_epoch + 1, config.num_epochs + 1, epochs_per_cycle):
print('-----------------------------------------------------\n'
'=====> Starting training, epoch: %d.' % e)
_train(e * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size)
logging.info('Starting evaluation, epoch: %d.', e)
print('-----------------------------------------------------\n'
'=====> Starting evaluation, epoch: %d.' % e)
eval_results = _eval(eval_steps)
ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)

else:
logging.info('Mode not found.')
logging.info('Invalid mode: %s', FLAGS.mode)


if __name__ == '__main__':
Expand Down

0 comments on commit 5ba7223

Please sign in to comment.