diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index ce831a303d..1029ae077c 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import shutil import sys import tensorflow as tf @@ -39,6 +41,19 @@ "Output directory to save the model to.", ) +flags.DEFINE_string( + "checkpoint_save_directory", + None, + "Output directory to save checkpoints to.", +) + +flags.DEFINE_bool( + "skip_restore", + False, + "Skip restoring from checkpoint if True", +) + + flags.DEFINE_string( "model_size", "tiny", @@ -325,6 +340,13 @@ def __call__(self, step): ), ) + def get_config(self): + return { + "learning_rate": self.learning_rate, + "num_warmup_steps": self.warmup_steps, + "num_train_steps": self.train_steps, + } + def decode_record(record): """Decodes a record to a TensorFlow example.""" @@ -398,15 +420,38 @@ def main(_): num_warmup_steps=num_warmup_steps, num_train_steps=num_train_steps, ) + optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule) + + callbacks = [] + if FLAGS.checkpoint_save_directory is not None: + if os.path.exists(FLAGS.checkpoint_save_directory): + if not os.path.isdir(FLAGS.checkpoint_save_directory): + raise ValueError( + "`checkpoint_save_directory` should be a directory, but " + f"{FLAGS.checkpoint_save_directory} is not a directory." + " Please set `checkpoint_save_directory` as a directory." + ) + + elif FLAGS.skip_restore: + # Clear up the directory if users want to skip restoring. + shutil.rmtree(FLAGS.checkpoint_save_directory) + checkpoint_path = FLAGS.checkpoint_save_directory + "/checkpoint" + callbacks.append( + tf.keras.callbacks.BackupAndRestore(backup_dir=checkpoint_path) + ) # Wrap with pretraining heads and call fit. pretraining_model = BertPretrainer(model) pretraining_model.compile( - optimizer=keras.optimizers.Adam(learning_rate=learning_rate_schedule) + optimizer=optimizer, ) + # TODO(mattdangerw): Add TPU strategy support. pretraining_model.fit( - dataset, epochs=epochs, steps_per_epoch=steps_per_epoch + dataset, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + callbacks=callbacks, ) print(f"Saving to {FLAGS.saved_model_output}")