From ee2053fad8cb87a1295301de7f1c4f3ac003ccd4 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 24 May 2022 21:21:45 +0000 Subject: [PATCH 1/6] tpu --- examples/bert/bert_finetune_glue.py | 2 + examples/bert/bert_train.py | 79 ++++++++++++++++++----------- 2 files changed, 50 insertions(+), 31 deletions(-) diff --git a/examples/bert/bert_finetune_glue.py b/examples/bert/bert_finetune_glue.py index 598a09fb3b..daccfd9d6f 100644 --- a/examples/bert/bert_finetune_glue.py +++ b/examples/bert/bert_finetune_glue.py @@ -161,12 +161,14 @@ def __init__( kernel_initializer=initializer, name="logits", ) + self._drop_out = tf.keras.layers.Dropout(0.1) def call(self, inputs): outputs = self.bert_model(inputs) # Get the first [CLS] token from each output. outputs = outputs[:, 0, :] outputs = self._pooler_layer(outputs) + outputs = self._drop_out(outputs) return self._logit_layer(outputs) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 1029ae077c..5adcce7b13 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -16,6 +16,8 @@ import shutil import sys +import os + import tensorflow as tf from absl import app from absl import flags @@ -339,6 +341,13 @@ def __call__(self, step): 0.0, peak_lr * (training - step) / (training - warmup) ), ) + + def get_config(self): + return { + "learning_rate": self.learning_rate, + "num_warmup_steps": self.warmup_steps, + "num_train_steps": self.train_steps, + } def get_config(self): return { @@ -386,6 +395,12 @@ def main(_): model_config = MODEL_CONFIGS[FLAGS.model_size] + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local') + tf.config.experimental_connect_to_cluster(resolver) + # This is the TPU initialization code that has to be at the beginning. + tf.tpu.experimental.initialize_tpu_system(resolver) + strategy = tf.distribute.TPUStrategy(resolver) + # Decode and batch data. dataset = tf.data.TFRecordDataset(input_filenames) dataset = dataset.map( @@ -395,33 +410,41 @@ def main(_): dataset = dataset.batch(TRAINING_CONFIG["batch_size"], drop_remainder=True) dataset = dataset.repeat() - # Create a BERT model the input config. - model = BertModel( - vocab_size=len(vocab), - **model_config, - ) - # Make sure model has been called. - model(model.inputs) - model.summary() - - # Allow overriding train steps from the command line for quick testing. - if FLAGS.num_train_steps is not None: - num_train_steps = FLAGS.num_train_steps - else: - num_train_steps = TRAINING_CONFIG["num_train_steps"] - num_warmup_steps = int( - num_train_steps * TRAINING_CONFIG["warmup_percentage"] - ) - epochs = TRAINING_CONFIG["epochs"] - steps_per_epoch = num_train_steps // epochs + with strategy.scope(): + # Create a BERT model the input config. + model = BertModel( + vocab_size=len(vocab), + **model_config, + ) + # Make sure model has been called. + model(model.inputs) + model.summary() + + # Allow overriding train steps from the command line for quick testing. 
+ if FLAGS.num_train_steps is not None: + num_train_steps = FLAGS.num_train_steps + else: + num_train_steps = TRAINING_CONFIG["num_train_steps"] + num_warmup_steps = int( + num_train_steps * TRAINING_CONFIG["warmup_percentage"] + ) + learning_rate_schedule = LinearDecayWithWarmup( + learning_rate=TRAINING_CONFIG["learning_rate"], + num_warmup_steps=num_warmup_steps, + num_train_steps=num_train_steps, + ) + optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule) + + pretraining_model = BertPretrainer(model) + pretraining_model.compile( + optimizer=optimizer, + ) - learning_rate_schedule = LinearDecayWithWarmup( - learning_rate=TRAINING_CONFIG["learning_rate"], - num_warmup_steps=num_warmup_steps, - num_train_steps=num_train_steps, - ) - optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule) + + epochs = TRAINING_CONFIG["epochs"] + steps_per_epoch = num_train_steps // epochs + callbacks = [] if FLAGS.checkpoint_save_directory is not None: if os.path.exists(FLAGS.checkpoint_save_directory): @@ -440,12 +463,6 @@ def main(_): tf.keras.callbacks.BackupAndRestore(backup_dir=checkpoint_path) ) - # Wrap with pretraining heads and call fit. - pretraining_model = BertPretrainer(model) - pretraining_model.compile( - optimizer=optimizer, - ) - # TODO(mattdangerw): Add TPU strategy support. pretraining_model.fit( dataset, From 7ccf2e047495c2c0bef3406ce5f5070e0632c7a1 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 25 May 2022 19:43:34 +0000 Subject: [PATCH 2/6] Add TPU support --- examples/bert/bert_train.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 5adcce7b13..c5f970289e 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -55,6 +55,12 @@ "Skip restoring from checkpoint if True", ) +flags.DEFINE_bool( + "use_tpu", + False, + "Use TPU for training if True", +) + flags.DEFINE_string( "model_size", @@ -395,11 +401,19 @@ def main(_): model_config = MODEL_CONFIGS[FLAGS.model_size] - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local') - tf.config.experimental_connect_to_cluster(resolver) - # This is the TPU initialization code that has to be at the beginning. - tf.tpu.experimental.initialize_tpu_system(resolver) - strategy = tf.distribute.TPUStrategy(resolver) + if FLAGS.use_tpu: + if not tf.config.list_logical_devices("TPU"): + raise RuntimeError("`use_tpu` is set to True while no TPU is found. " + "Please either set `use_tpu` as False or check if TPU is available.") + + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local') + tf.config.experimental_connect_to_cluster(resolver) + # This is the TPU initialization code that has to be at the beginning. + tf.tpu.experimental.initialize_tpu_system(resolver) + strategy = tf.distribute.TPUStrategy(resolver) + else: + # Use default strategy if not using TPU. + strategy = tf.distribute.get_strategy() # Decode and batch data. dataset = tf.data.TFRecordDataset(input_filenames) @@ -440,8 +454,6 @@ def main(_): optimizer=optimizer, ) - - epochs = TRAINING_CONFIG["epochs"] steps_per_epoch = num_train_steps // epochs @@ -463,12 +475,11 @@ def main(_): tf.keras.callbacks.BackupAndRestore(backup_dir=checkpoint_path) ) - # TODO(mattdangerw): Add TPU strategy support. 
pretraining_model.fit( dataset, epochs=epochs, steps_per_epoch=steps_per_epoch, - callbacks=callbacks, + callbacks=[callbacks], ) print(f"Saving to {FLAGS.saved_model_output}") From 9e11f8ae24bf7a037a476db315b576b7aa68be49 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Wed, 25 May 2022 13:50:14 -0700 Subject: [PATCH 3/6] Style fix --- examples/bert/bert_train.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index c5f970289e..22ddec43e0 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -16,8 +16,6 @@ import shutil import sys -import os - import tensorflow as tf from absl import app from absl import flags @@ -55,12 +53,7 @@ "Skip restoring from checkpoint if True", ) -flags.DEFINE_bool( - "use_tpu", - False, - "Use TPU for training if True", -) - +flags.DEFINE_bool("use_tpu", False, "Use TPU for training if True") flags.DEFINE_string( "model_size", @@ -347,13 +340,6 @@ def __call__(self, step): 0.0, peak_lr * (training - step) / (training - warmup) ), ) - - def get_config(self): - return { - "learning_rate": self.learning_rate, - "num_warmup_steps": self.warmup_steps, - "num_train_steps": self.train_steps, - } def get_config(self): return { @@ -403,16 +389,19 @@ def main(_): if FLAGS.use_tpu: if not tf.config.list_logical_devices("TPU"): - raise RuntimeError("`use_tpu` is set to True while no TPU is found. " - "Please either set `use_tpu` as False or check if TPU is available.") - - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local') + raise RuntimeError( + "`use_tpu` is set to True while no TPU is found. Please " + "check if your machine has TPU or set `use_tpu` as False." + ) + # Connect to TPU and create TPU strategy. + resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + tpu="local" + ) tf.config.experimental_connect_to_cluster(resolver) - # This is the TPU initialization code that has to be at the beginning. tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) else: - # Use default strategy if not using TPU. + # Use default strategy if not using TPU. strategy = tf.distribute.get_strategy() # Decode and batch data. @@ -456,7 +445,7 @@ def main(_): epochs = TRAINING_CONFIG["epochs"] steps_per_epoch = num_train_steps // epochs - + callbacks = [] if FLAGS.checkpoint_save_directory is not None: if os.path.exists(FLAGS.checkpoint_save_directory): From df8a22b9f7adb94ee4516d70481db4216b6745b4 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Wed, 25 May 2022 13:53:39 -0700 Subject: [PATCH 4/6] small fix --- examples/bert/bert_finetune_glue.py | 2 -- examples/bert/bert_train.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/bert/bert_finetune_glue.py b/examples/bert/bert_finetune_glue.py index daccfd9d6f..598a09fb3b 100644 --- a/examples/bert/bert_finetune_glue.py +++ b/examples/bert/bert_finetune_glue.py @@ -161,14 +161,12 @@ def __init__( kernel_initializer=initializer, name="logits", ) - self._drop_out = tf.keras.layers.Dropout(0.1) def call(self, inputs): outputs = self.bert_model(inputs) # Get the first [CLS] token from each output. 
outputs = outputs[:, 0, :] outputs = self._pooler_layer(outputs) - outputs = self._drop_out(outputs) return self._logit_layer(outputs) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 22ddec43e0..b7d28b786d 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -468,7 +468,7 @@ def main(_): dataset, epochs=epochs, steps_per_epoch=steps_per_epoch, - callbacks=[callbacks], + callbacks=callbacks, ) print(f"Saving to {FLAGS.saved_model_output}") From 1c684bbbd83b789e6699da7c5d3370b83a3c4fc6 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Fri, 27 May 2022 13:49:00 -0700 Subject: [PATCH 5/6] remove tpu flag --- examples/bert/bert_train.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index b7d28b786d..5246743bd3 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -53,8 +53,6 @@ "Skip restoring from checkpoint if True", ) -flags.DEFINE_bool("use_tpu", False, "Use TPU for training if True") - flags.DEFINE_string( "model_size", "tiny", @@ -387,12 +385,7 @@ def main(_): model_config = MODEL_CONFIGS[FLAGS.model_size] - if FLAGS.use_tpu: - if not tf.config.list_logical_devices("TPU"): - raise RuntimeError( - "`use_tpu` is set to True while no TPU is found. Please " - "check if your machine has TPU or set `use_tpu` as False." - ) + if tf.config.list_logical_devices("TPU"): # Connect to TPU and create TPU strategy. resolver = tf.distribute.cluster_resolver.TPUClusterResolver( tpu="local" From a4441e73b563ce38daff10732b8903e82845d8ec Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Fri, 27 May 2022 14:26:22 -0700 Subject: [PATCH 6/6] nit --- examples/bert/bert_train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 5246743bd3..77aca5911e 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -387,11 +387,9 @@ def main(_): if tf.config.list_logical_devices("TPU"): # Connect to TPU and create TPU strategy. - resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect( tpu="local" ) - tf.config.experimental_connect_to_cluster(resolver) - tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) else: # Use default strategy if not using TPU.
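
---

For reference, below is a minimal sketch of how the strategy selection reads once the full series (through PATCH 6/6) is applied: detect an attached TPU, build a `TPUStrategy` from `TPUClusterResolver.connect(tpu="local")`, and otherwise fall back to the default strategy, with model and optimizer creation placed under `strategy.scope()`. The `get_distribution_strategy` helper, the placeholder `Dense` model, and the optimizer/loss values are illustrative only and are not part of the patches; the actual script builds `BertModel`/`BertPretrainer` with a `LinearDecayWithWarmup` schedule.

```python
import tensorflow as tf


def get_distribution_strategy():
    """Return a TPUStrategy when a TPU is attached, else the default strategy.

    Hypothetical helper for illustration; the patched script inlines this logic.
    """
    if tf.config.list_logical_devices("TPU"):
        # TPUClusterResolver.connect() both connects to the cluster and
        # initializes the TPU system, replacing the older three-call sequence
        # (TPUClusterResolver + experimental_connect_to_cluster +
        # initialize_tpu_system) removed in the final patch.
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
            tpu="local"
        )
        return tf.distribute.TPUStrategy(resolver)
    # No TPU found: the default strategy runs training unmodified.
    return tf.distribute.get_strategy()


strategy = get_distribution_strategy()

# Variables (model weights, optimizer slots) must be created under the scope
# so they are placed and replicated correctly on the TPU.
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(2)])  # placeholder model
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss="mse")
```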