
Does training resume from last saved checkpoint? #21

KishoreP1 opened this issue Feb 26, 2024 · 1 comment

Comments

@KishoreP1

When training is interrupted and later resumed, I expect the process to restart from the last saved checkpoint iteration. However, even when specifying the same --checkpoint_dir flag, the training process restarts from iteration 0, disregarding previously completed iterations.

I tried:

  1. Start training with a specified --checkpoint_dir.
  2. Allow the training to proceed past a few iterations (e.g., 12 iterations).
  3. Interrupt the training process.
  4. Resume training with the same --checkpoint_dir flag.

I expected the training to resume from iteration 13, considering the last completed iteration was 12. However, the training restarts from iteration 1, ignoring the checkpoints saved in the specified directory.

Inside run_learner of main_loop.py, the checkpointing and iteration logging logic seems correct. However, I cannot find where the code loads the checkpoint to resume training from the last saved state.

# Start training
for iteration in range(1, num_iterations + 1):
    logging.info(f'Training iteration {iteration}')
    logging.info(f'Starting {learner.agent_name} ...')

    # Update shared iteration count.
    iteration_count.value = iteration

    # Set start training event.
    start_iteration_event.set()
    learner.reset()

    run_learner_loop(learner, data_queue, num_actors, learner_trackers)

    start_iteration_event.clear()
    checkpoint.set_iteration(iteration)
    saved_ckpt = checkpoint.save()

    if saved_ckpt:
        logging.info(f'New checkpoint created at "{saved_ckpt}"')
@michaelnny
Owner

Hi, currently the training scripts do not support resuming training. As you can see from the code, the --checkpoint_dir argument only specifies the path where model checkpoints are saved; it does not look for an existing checkpoint to continue training from.

You should be able to adapt the code to add logic that looks for the latest model checkpoint if required. Here's an example of manually loading a checkpoint file in the eval_agent.py module.

    if FLAGS.load_checkpoint_file:
        checkpoint.restore(FLAGS.load_checkpoint_file)

However, keep in mind that the code only saves the model state, not the optimizer state or the agent's internal state (number of updates, etc.). You would also need to correctly handle logging to TensorBoard and the CSV files.
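
For reference, a minimal sketch of what that adaptation might look like before the training loop in run_learner. The *.ckpt file pattern, the FLAGS.checkpoint_dir access at that point, and the get_iteration() accessor are assumptions for illustration, not part of the existing code:

    # Hypothetical sketch: find the newest checkpoint in --checkpoint_dir and resume from it.
    import pathlib
    from typing import Optional

    def find_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
        """Return the most recently modified checkpoint file, or None if none exist."""
        files = sorted(
            pathlib.Path(checkpoint_dir).glob('*.ckpt'),  # assumed file extension
            key=lambda p: p.stat().st_mtime,
        )
        return str(files[-1]) if files else None

    # Before the training loop in run_learner:
    start_iteration = 1
    latest_ckpt = find_latest_checkpoint(FLAGS.checkpoint_dir)
    if latest_ckpt is not None:
        checkpoint.restore(latest_ckpt)  # restores model state only (see note above)
        # Assumption: the checkpoint object exposes the stored iteration;
        # otherwise parse it from the file name or persist it separately.
        start_iteration = checkpoint.get_iteration() + 1
        logging.info(f'Resuming training from iteration {start_iteration}')

    for iteration in range(start_iteration, num_iterations + 1):
        ...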
