Add checkpoints to BERT training #184
Conversation
Force-pushed from 561928d to 08fa570
examples/bert/run_pretraining.py
Outdated
)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule)

if FLAGS.restore_from_checkpoint is not None:
I think we need to make this work so it restores after a failure in an automatic fashion.
If the checkpoint save path is set, we should save a checkpoint after each epoch (is this currently saving the best checkpoint, or one per epoch?).
If the script is terminated and re-run with the same arguments, we should automatically pick up where we left off. And then we could maybe add a skip_restore flag, defaulting to false, to skip this behavior and always start from scratch.
+1, it is very important that training be fully autonomous: there will be failures, and we do NOT want to have to rerun the script with different manually specified arguments every time one happens. The restart should be automated and should resume from the latest saved state (both model-wise and data-pipeline-wise), and that state should be retrievable without modifying any command line argument.
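For context, a minimal sketch of the kind of per-epoch save plus automatic resume being described (placeholder model, data, paths, and flag values standing in for the script's FLAGS; assumes tf.keras in TF 2.x; not the PR's actual code):

import os

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical stand-ins for FLAGS.checkpoint_save_dir / FLAGS.skip_restore.
checkpoint_save_dir = "/tmp/bert_checkpoints"
skip_restore = False

# Tiny placeholder model and data; the real script trains the BERT model instead.
model = keras.Sequential(
    [keras.layers.Input(shape=(8,)), keras.layers.Dense(4), keras.layers.Dense(1)]
)
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(32, 8), np.random.rand(32, 1)

callbacks = []
if checkpoint_save_dir is not None:
    # Save weights at the end of every epoch.
    callbacks.append(
        keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(checkpoint_save_dir, "checkpoint_{epoch:02d}"),
            save_weights_only=True,
        )
    )
    # On a re-run with identical arguments, resume from the newest checkpoint,
    # unless the user explicitly opted out via skip_restore.
    latest = tf.train.latest_checkpoint(checkpoint_save_dir)
    if latest is not None and not skip_restore:
        model.load_weights(latest)

model.fit(x, y, epochs=3, callbacks=callbacks)

Note that a weights-only restore like this recovers the model but not the optimizer state or the epoch counter, which is part of why the callbacks discussed below end up being a better fit.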
Updated this PR, please take another look, thanks!
The file loading logic looks a little fragile to me. Have you taken a look at these:
- https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/BackupAndRestore
- https://www.tensorflow.org/api_docs/python/tf/train/CheckpointManager (lower level)
We should make sure the example here models the best practices for doing this. It seems like there should be a cleaner flow than the current one?
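For reference, a minimal sketch of what a BackupAndRestore-based flow might look like (placeholder model, data, and backup_dir; assumes tf.keras in a recent TF 2.x release; not the code in this PR):

import numpy as np
from tensorflow import keras

# Tiny placeholder model and data; the real script trains the BERT model instead.
model = keras.Sequential([keras.layers.Input(shape=(8,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(32, 8), np.random.rand(32, 1)

# BackupAndRestore writes a temporary checkpoint at the end of each epoch; if
# training is interrupted and the script is re-run with the same arguments, it
# restores the model and resumes from the last completed epoch automatically.
backup = keras.callbacks.BackupAndRestore(backup_dir="/tmp/bert_backup")
model.fit(x, y, epochs=3, callbacks=[backup])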
Changed!
Force-pushed from 7bc9424 to e348fd8
Running just the testing code snippet we have in the README, I get the following. Is this expected?
examples/bert/run_pretraining.py
Outdated
callbacks = []
if FLAGS.checkpoint_save_path is not None:
    checkpoint_path = FLAGS.checkpoint_save_path + "/checkpoint_{epoch:2d}"
Have you tested this out? Just looking through the BackupAndRestore code, this directory gets joined with other directory names and passed directly to tf.train.CheckpointManager, which I don't think supports the {epoch:2d} format you are using.
Make sure you actually ls the directory you are saving to when testing this out.
Yeah, this was legacy code. This style works well with the ModelCheckpoint callback, but not with BackupAndRestore. I will update the code.
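To illustrate the distinction (hypothetical paths; a sketch rather than the updated PR code): ModelCheckpoint formats its filepath per epoch, while BackupAndRestore only takes a plain directory that it manages internally.

from tensorflow import keras

# ModelCheckpoint formats the filepath itself, so an epoch template is fine here.
ckpt_cb = keras.callbacks.ModelCheckpoint(
    filepath="/tmp/bert_ckpts/checkpoint_{epoch:02d}",
    save_weights_only=True,
)

# BackupAndRestore just takes a directory and manages the checkpoint files
# inside it; a format string like {epoch:2d} would not be expanded there.
backup_cb = keras.callbacks.BackupAndRestore(backup_dir="/tmp/bert_backup")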
examples/bert/run_pretraining.py
Outdated
)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule)

if FLAGS.skip_restore:
Maybe we want skip_restore to just clear the directory without a warning?
Otherwise I don't totally get this. If you point checkpoint_save_path at an empty directory, doesn't this flag do nothing?
Either:
- the directory is empty, and the flag does nothing
- the directory is not empty, and the flag causes an error
What is the use case?
If we clear the directory when skip_restore is set, I understand the use case better: you can re-run the script with the same training args, but make sure you restart from scratch each time.
Sounds good, I am clearing out the directory if skip_restore is set to True.
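A sketch of that clearing logic (local names standing in for the script's FLAGS):

import os
import shutil

# Hypothetical stand-ins for FLAGS.checkpoint_save_dir / FLAGS.skip_restore.
checkpoint_save_dir = "/tmp/bert_backup"
skip_restore = True

if skip_restore and checkpoint_save_dir is not None and os.path.isdir(checkpoint_save_dir):
    # Wipe any state from a previous run so training restarts from scratch,
    # even though the same checkpoint directory is passed on the command line.
    shutil.rmtree(checkpoint_save_dir)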
examples/bert/run_pretraining.py
Outdated
)

flags.DEFINE_string(
    "checkpoint_save_path",
checkpoint_directory? To hint that this should be a directory, not a filepath?
Yeah, I updated the code to use checkpoint_save_dir.
Use the whole word directory to agree with the arg naming in create_sentence_split_data.
done
examples/bert/run_pretraining.py
Outdated
)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule)

if FLAGS.skip_restore and len(os.listdir(FLAGS.checkpoint_save_dir)) > 0:
This hits some weird undefined behavior if skip_restore=True and checkpoint_save_dir=None. Maybe something like:

if FLAGS.skip_restore and FLAGS.checkpoint_save_dir is not None:
    if os.path.exists(FLAGS.checkpoint_save_dir):
        if os.path.isdir(FLAGS.checkpoint_save_dir):
            rmdir...
        else:
            raise error should be directory
Yes, we should check if it is a dir before deleting it, but we probably don't want to error out when it's not a directory? If it's not a directory they can still write to the path. My concern is that we are exposing strange logic: if you want to skip restoring, you have to make sure that checkpoint_save_directory either points to nothing or to a directory.
Updated the code with the directory check, and we can discuss whether an error is needed here.
I'm not sure what use case you are trying to protect against. There is never a way a user can write checkpoints to a file.
In the version I suggested, if the checkpoint dir is a file, you get an error right away. In the version you pushed, you just get the error after the first epoch, from deeper within TensorFlow: tensorflow.python.framework.errors_impl.FailedPreconditionError: /home/matt/bert_test_output/myfile.ckpt is not a directory.
Thinking about it more, it would probably just be nicer to give that error regardless:
if FLAGS.checkpoint_save_dir is not None and os.path.exists(FLAGS.checkpoint_save_dir):
    if not os.path.isdir(FLAGS.checkpoint_save_dir):
        raise error should be directory
    elif FLAGS.skip_restore:
        rmdir
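A runnable rendering of that suggestion (assumes the script's absl FLAGS plus os/shutil imports; the exact error type is an assumption):

import os
import shutil

if FLAGS.checkpoint_save_dir is not None and os.path.exists(FLAGS.checkpoint_save_dir):
    if not os.path.isdir(FLAGS.checkpoint_save_dir):
        raise ValueError(
            "`checkpoint_save_dir` should be a directory, "
            f"got {FLAGS.checkpoint_save_dir}"
        )
    elif FLAGS.skip_restore:
        # Start fresh: remove state left over from a previous run.
        shutil.rmtree(FLAGS.checkpoint_save_dir)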
I made a mistake - I thought it was okay to have a directory and a file with the same name under the same parent directory, but apparently that is not allowed.
Code updated.
lgtm thanks!
Force-pushed from a27cd57 to f744f5a
get_config.