diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py
index 1029ae077c..77aca5911e 100644
--- a/examples/bert/bert_train.py
+++ b/examples/bert/bert_train.py
@@ -53,7 +53,6 @@
     "Skip restoring from checkpoint if True",
 )
 
-
 flags.DEFINE_string(
     "model_size",
     "tiny",
@@ -386,6 +385,16 @@ def main(_):
 
     model_config = MODEL_CONFIGS[FLAGS.model_size]
 
+    if tf.config.list_logical_devices("TPU"):
+        # Connect to TPU and create TPU strategy.
+        resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
+            tpu="local"
+        )
+        strategy = tf.distribute.TPUStrategy(resolver)
+    else:
+        # Use default strategy if not using TPU.
+        strategy = tf.distribute.get_strategy()
+
     # Decode and batch data.
     dataset = tf.data.TFRecordDataset(input_filenames)
     dataset = dataset.map(
@@ -395,33 +404,39 @@ def main(_):
     dataset = dataset.batch(TRAINING_CONFIG["batch_size"], drop_remainder=True)
     dataset = dataset.repeat()
 
-    # Create a BERT model the input config.
-    model = BertModel(
-        vocab_size=len(vocab),
-        **model_config,
-    )
-    # Make sure model has been called.
-    model(model.inputs)
-    model.summary()
+    with strategy.scope():
+        # Create a BERT model the input config.
+        model = BertModel(
+            vocab_size=len(vocab),
+            **model_config,
+        )
+        # Make sure model has been called.
+        model(model.inputs)
+        model.summary()
+
+        # Allow overriding train steps from the command line for quick testing.
+        if FLAGS.num_train_steps is not None:
+            num_train_steps = FLAGS.num_train_steps
+        else:
+            num_train_steps = TRAINING_CONFIG["num_train_steps"]
+        num_warmup_steps = int(
+            num_train_steps * TRAINING_CONFIG["warmup_percentage"]
+        )
+        learning_rate_schedule = LinearDecayWithWarmup(
+            learning_rate=TRAINING_CONFIG["learning_rate"],
+            num_warmup_steps=num_warmup_steps,
+            num_train_steps=num_train_steps,
+        )
+        optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule)
+
+        pretraining_model = BertPretrainer(model)
+        pretraining_model.compile(
+            optimizer=optimizer,
+        )
 
-    # Allow overriding train steps from the command line for quick testing.
-    if FLAGS.num_train_steps is not None:
-        num_train_steps = FLAGS.num_train_steps
-    else:
-        num_train_steps = TRAINING_CONFIG["num_train_steps"]
-    num_warmup_steps = int(
-        num_train_steps * TRAINING_CONFIG["warmup_percentage"]
-    )
     epochs = TRAINING_CONFIG["epochs"]
     steps_per_epoch = num_train_steps // epochs
 
-    learning_rate_schedule = LinearDecayWithWarmup(
-        learning_rate=TRAINING_CONFIG["learning_rate"],
-        num_warmup_steps=num_warmup_steps,
-        num_train_steps=num_train_steps,
-    )
-    optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule)
-
     callbacks = []
     if FLAGS.checkpoint_save_directory is not None:
         if os.path.exists(FLAGS.checkpoint_save_directory):
@@ -440,13 +455,6 @@ def main(_):
             tf.keras.callbacks.BackupAndRestore(backup_dir=checkpoint_path)
         )
 
-    # Wrap with pretraining heads and call fit.
-    pretraining_model = BertPretrainer(model)
-    pretraining_model.compile(
-        optimizer=optimizer,
-    )
-
-    # TODO(mattdangerw): Add TPU strategy support.
     pretraining_model.fit(
         dataset,
         epochs=epochs,