diff --git a/README.md b/README.md index 8ec95c875..62441a04b 100644 --- a/README.md +++ b/README.md @@ -836,6 +836,121 @@ Note that since our `sample_text.txt` file is very small, this example training will overfit that data in only a few steps and produce unrealistically high accuracy numbers. +Many people have asked how to report the loss during pre-training. Here is +how you do it: + +```shell +python run_pretraining.py \ + --input_file=/tmp/tf_examples.tfrecord \ + --output_dir=/tmp/pretraining_output \ + --do_train=True \ + --do_eval=True \ + --bert_config_file=$BERT_BASE_DIR/bert_config.json \ + --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ + --train_batch_size=32 \ + --max_seq_length=128 \ + --max_predictions_per_seq=20 \ + --num_train_steps=20 \ + --num_warmup_steps=10 \ + --learning_rate=2e-5 \ + --report_loss +``` + +This will produce the following output during training: + +```shell +Step samples/sec Loss Learning-rate + 100 122.9 9.019 3.9200e-06 + 200 174.9 8.255 7.9200e-06 + 300 174.8 7.962 1.1920e-05 +``` + +Here is how to run the pre-training with FP16 arithmetic on GPUs. Doing this +triples throughput on most GPUs. + +```shell +python run_pretraining.py \ + --input_file=/tmp/tf_examples.tfrecord \ + --output_dir=/tmp/pretraining_output \ + --do_train=True \ + --do_eval=True \ + --bert_config_file=$BERT_BASE_DIR/bert_config.json \ + --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ + --train_batch_size=32 \ + --max_seq_length=128 \ + --max_predictions_per_seq=20 \ + --num_train_steps=20 \ + --num_warmup_steps=10 \ + --learning_rate=2e-5 \ + --use_fp16 +``` + +Here is how to enable XLA JIT compilation for GPUs. Doing this boosts +throughput by 1.3x for FP32 and 1.7x for FP16 arithmetic. + +```shell + +python run_pretraining.py \ + --input_file=/tmp/tf_examples.tfrecord \ + --output_dir=/tmp/pretraining_output \ + --do_train=True \ + --do_eval=True \ + --bert_config_file=$BERT_BASE_DIR/bert_config.json \ + --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ + --train_batch_size=32 \ + --max_seq_length=128 \ + --max_predictions_per_seq=20 \ + --num_train_steps=20 \ + --num_warmup_steps=10 \ + --learning_rate=2e-5 \ + --use_xla +``` + +This version of BERT supports pre-training on multiple GPUs. You need Horovod +installed for this. You also need to split your input dataset over multiple +files (at least one per GPU). Assuming you have split your input dataset +over 8 files named tf_examples.part01.tfrecord through +tf_examples.part.08.tfrecord, here is how you run it: + +```shell +mpiexec --allow-run-as-root --bind-to socket -np 8 python run_pretraining.py \ + --input_file=/tmp/tf_examples.part01.tfrecord,/tmp/tf_examples.part02.tfrecord,/tmp/tf_examples.part03.tfrecord,/tmp/tf_examples.part04.tfrecord,/tmp/tf_examples.part05.tfrecord,/tmp/tf_examples.part06.tfrecord,/tmp/tf_examples.part07.tfrecord,/tmp/tf_examples.part08.tfrecord \ + --output_dir=/tmp/pretraining_output \ + --do_train=True \ + --do_eval=True \ + --bert_config_file=$BERT_BASE_DIR/bert_config.json \ + --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ + --train_batch_size=32 \ + --max_seq_length=128 \ + --max_predictions_per_seq=20 \ + --num_train_steps=20 \ + --num_warmup_steps=10 \ + --learning_rate=2e-5 \ + --horovod +``` + +You can combine --report_loss, --use_fp16, --use_xla and --horovod: + +```shell +mpiexec --allow-run-as-root --bind-to socket -np 8 python run_pretraining.py \ + --input_file=/tmp/tf_examples.part01.tfrecord,/tmp/tf_examples.part02.tfrecord,/tmp/tf_examples.part03.tfrecord,/tmp/tf_examples.part04.tfrecord,/tmp/tf_examples.part05.tfrecord,/tmp/tf_examples.part06.tfrecord,/tmp/tf_examples.part07.tfrecord,/tmp/tf_examples.part08.tfrecord \ + --output_dir=/tmp/pretraining_output \ + --do_train=True \ + --do_eval=True \ + --bert_config_file=$BERT_BASE_DIR/bert_config.json \ + --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ + --train_batch_size=32 \ + --max_seq_length=128 \ + --max_predictions_per_seq=20 \ + --num_train_steps=20 \ + --num_warmup_steps=10 \ + --learning_rate=2e-5 \ + --report_loss \ + --use_fp16 \ + --use_xla \ + --horovod +``` + ### Pre-training tips and caveats * **If using your own vocabulary, make sure to change `vocab_size` in diff --git a/gpu_environment.py b/gpu_environment.py new file mode 100644 index 000000000..948c3fa44 --- /dev/null +++ b/gpu_environment.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import numpy as np + +def float32_variable_storage_getter(getter, name, shape=None, dtype=None, + initializer=None, regularizer=None, + trainable=True, + *args, **kwargs): + """Custom variable getter that forces trainable variables to be stored in + float32 precision and then casts them to the training precision. + """ + storage_dtype = tf.float32 if trainable else dtype + variable = getter(name, shape, dtype=storage_dtype, + initializer=initializer, regularizer=regularizer, + trainable=trainable, + *args, **kwargs) + if trainable and dtype != tf.float32: + variable = tf.cast(variable, dtype) + return variable + +def get_custom_getter(compute_type): + return float32_variable_storage_getter if compute_type == tf.float16 else None diff --git a/modeling.py b/modeling.py index fed525971..45cc67393 100644 --- a/modeling.py +++ b/modeling.py @@ -27,6 +27,7 @@ import six import tensorflow as tf +from gpu_environment import get_custom_getter class BertConfig(object): """Configuration for `BertModel`.""" @@ -135,7 +136,8 @@ def __init__(self, input_mask=None, token_type_ids=None, use_one_hot_embeddings=False, - scope=None): + scope=None, + compute_type=tf.float32): """Constructor for BertModel. Args: @@ -168,7 +170,7 @@ def __init__(self, if token_type_ids is None: token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) - with tf.variable_scope(scope, default_name="bert"): + with tf.variable_scope(scope, default_name="bert", custom_getter=get_custom_getter(compute_type)): with tf.variable_scope("embeddings"): # Perform embedding lookup on the word ids. (self.embedding_output, self.embedding_table) = embedding_lookup( @@ -203,7 +205,7 @@ def __init__(self, # Run the stacked transformer. # `sequence_output` shape = [batch_size, seq_length, hidden_size]. self.all_encoder_layers = transformer_model( - input_tensor=self.embedding_output, + input_tensor=tf.saturate_cast(self.embedding_output, compute_type), attention_mask=attention_mask, hidden_size=config.hidden_size, num_hidden_layers=config.num_hidden_layers, @@ -215,7 +217,7 @@ def __init__(self, initializer_range=config.initializer_range, do_return_all_layers=True) - self.sequence_output = self.all_encoder_layers[-1] + self.sequence_output = tf.cast(self.all_encoder_layers[-1], tf.float32) # The "pooler" converts the encoded sequence tensor of shape # [batch_size, seq_length, hidden_size] to a tensor of shape # [batch_size, hidden_size]. This is necessary for segment-level @@ -709,7 +711,7 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. - adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 + adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0 # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. diff --git a/optimization.py b/optimization.py index d33dabd91..80ed30d75 100644 --- a/optimization.py +++ b/optimization.py @@ -22,7 +22,7 @@ import tensorflow as tf -def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): +def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, hvd=None, use_fp16=False): """Creates an optimizer training op.""" global_step = tf.train.get_or_create_global_step() @@ -66,20 +66,38 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): if use_tpu: optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) + else: + if hvd is not None: + from horovod.tensorflow.compression import Compression + optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True, compression=Compression.fp16) + if use_fp16: + loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(init_loss_scale=2**32, incr_every_n_steps=1000, decr_every_n_nan_or_inf=2, decr_ratio=0.5) + optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager) tvars = tf.trainable_variables() - grads = tf.gradients(loss, tvars) + grads_and_vars = optimizer.compute_gradients(loss, tvars) + grads_and_vars = [(g,v) for g,v in grads_and_vars if g is not None] + grads, tvars = list(zip(*grads_and_vars)) + all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) if use_fp16 else tf.constant(True, dtype=tf.bool) # This is how the model was pre-trained. - (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) + # ensure global norm is a finite number + # to prevent clip_by_global_norm from having a hizzy fit. + (clipped_grads, _) = tf.clip_by_global_norm( + grads, clip_norm=1.0, + use_norm=tf.cond( + all_are_finite, + lambda: tf.global_norm(grads), + lambda: tf.constant(1.0))) train_op = optimizer.apply_gradients( - zip(grads, tvars), global_step=global_step) + list(zip(clipped_grads, tvars)), global_step=global_step) # Normally the global step update is done inside of `apply_gradients`. # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use # a different optimizer, you should probably take this line out. - new_global_step = global_step + 1 + new_global_step = tf.cond(all_are_finite, lambda: global_step+1, lambda: global_step) + new_global_step = tf.identity(new_global_step, name='step_update') train_op = tf.group(train_op, [global_step.assign(new_global_step)]) return train_op @@ -98,7 +116,7 @@ def __init__(self, """Constructs a AdamWeightDecayOptimizer.""" super(AdamWeightDecayOptimizer, self).__init__(False, name) - self.learning_rate = learning_rate + self.learning_rate = tf.identity(learning_rate, name='learning_rate') self.weight_decay_rate = weight_decay_rate self.beta_1 = beta_1 self.beta_2 = beta_2 diff --git a/run_classifier.py b/run_classifier.py index 817b14720..812e92f82 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -123,6 +123,10 @@ "num_tpu_cores", 8, "Only used if `use_tpu` is True. Total number of TPU cores to use.") +flags.DEFINE_bool("use_fp16", False, "Whether to use fp32 or fp16 arithmetic on GPU.") + +flags.DEFINE_bool("use_xla", False, "Whether to enable XLA JIT compilation.") + class InputExample(object): """A single training/test example for simple sequence classification.""" @@ -580,7 +584,8 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) + use_one_hot_embeddings=use_one_hot_embeddings, + compute_type=tf.float16 if FLAGS.use_fp16 else tf.float32) # In the demo, we are doing a simple classification task on the entire # segment. @@ -672,7 +677,8 @@ def tpu_scaffold(): if mode == tf.estimator.ModeKeys.TRAIN: train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, + None, FLAGS.use_fp16) output_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, @@ -824,11 +830,15 @@ def main(_): tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) + config = tf.ConfigProto() + if FLAGS.use_xla: + config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 run_config = tf.contrib.tpu.RunConfig( cluster=tpu_cluster_resolver, master=FLAGS.master, model_dir=FLAGS.output_dir, + session_config=config, save_checkpoints_steps=FLAGS.save_checkpoints_steps, tpu_config=tf.contrib.tpu.TPUConfig( iterations_per_loop=FLAGS.iterations_per_loop, diff --git a/run_pretraining.py b/run_pretraining.py index b118f62a3..9c7b70e13 100644 --- a/run_pretraining.py +++ b/run_pretraining.py @@ -19,6 +19,7 @@ from __future__ import print_function import os +import time import modeling import optimization import tensorflow as tf @@ -105,10 +106,70 @@ "num_tpu_cores", 8, "Only used if `use_tpu` is True. Total number of TPU cores to use.") +flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs") + +flags.DEFINE_bool("report_loss", False, "Whether to report total loss during training.") + +flags.DEFINE_bool("use_fp16", False, "Whether to use fp32 or fp16 arithmetic on GPU.") + +flags.DEFINE_bool("use_xla", False, "Whether to enable XLA JIT compilation.") + +# report samples/sec, total loss and learning rate during training +class _LogSessionRunHook(tf.train.SessionRunHook): + def __init__(self, global_batch_size, display_every=10, hvd_rank=-1): + self.global_batch_size = global_batch_size + self.display_every = display_every + self.hvd_rank = hvd_rank + def after_create_session(self, session, coord): + if FLAGS.use_fp16: + print(' Step samples/sec MLM Loss NSP Loss Loss Learning-rate Loss-scaler') + else: + print(' Step samples/sec MLM Loss NSP Loss Loss Learning-rate') + self.elapsed_secs = 0. + self.count = 0 + def before_run(self, run_context): + self.t0 = time.time() + if FLAGS.use_fp16: + return tf.train.SessionRunArgs( + fetches=['step_update:0', 'total_loss:0', + 'learning_rate:0', 'nsp_loss:0', + 'mlm_loss:0', 'loss_scale:0']) + else: + return tf.train.SessionRunArgs( + fetches=['step_update:0', 'total_loss:0', + 'learning_rate:0', 'nsp_loss:0', + 'mlm_loss:0']) + def after_run(self, run_context, run_values): + self.elapsed_secs += time.time() - self.t0 + self.count += 1 + if FLAGS.use_fp16: + global_step, total_loss, lr, nsp_loss, mlm_loss, loss_scaler = run_values.results + else: + global_step, total_loss, lr, nsp_loss, mlm_loss = run_values.results + print_step = global_step + 1 # One-based index for printing. + if print_step == 1 or print_step % self.display_every == 0: + dt = self.elapsed_secs / self.count + img_per_sec = self.global_batch_size / dt + if self.hvd_rank >= 0: + if FLAGS.use_fp16: + print('%2d :: %6i %11.1f %10.4e %10.4e %6.3f %6.4e %6.4e' % + (self.hvd_rank, print_step, img_per_sec, mlm_loss, nsp_loss, total_loss, lr, loss_scaler)) + else: + print('%2d :: %6i %11.1f %10.4e %10.4e %6.3f %6.4e' % + (self.hvd_rank, print_step, img_per_sec, mlm_loss, nsp_loss, total_loss, lr)) + else: + if FLAGS.use_fp16: + print('%6i %11.1f %10.4e %10.4e %6.3f %6.4e %6.4e' % + (print_step, img_per_sec, mlm_loss, nsp_loss, total_loss, lr, loss_scaler)) + else: + print('%6i %11.1f %10.4e %10.4e %6.3f %6.4e' % + (print_step, img_per_sec, mlm_loss, nsp_loss, total_loss, lr)) + self.elapsed_secs = 0. + self.count = 0 def model_fn_builder(bert_config, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu, - use_one_hot_embeddings): + use_one_hot_embeddings, hvd=None): """Returns `model_fn` closure for TPUEstimator.""" def model_fn(features, labels, mode, params): # pylint: disable=unused-argument @@ -134,7 +195,8 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) + use_one_hot_embeddings=use_one_hot_embeddings, + compute_type=tf.float16 if FLAGS.use_fp16 else tf.float32) (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( @@ -145,13 +207,16 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument next_sentence_log_probs) = get_next_sentence_output( bert_config, model.get_pooled_output(), next_sentence_labels) + masked_lm_loss = tf.identity(masked_lm_loss, name="mlm_loss") + next_sentence_loss = tf.identity(next_sentence_loss, name="nsp_loss") total_loss = masked_lm_loss + next_sentence_loss + total_loss = tf.identity(total_loss, name='total_loss') tvars = tf.trainable_variables() initialized_variable_names = {} scaffold_fn = None - if init_checkpoint: + if init_checkpoint and (hvd is None or hvd.rank() == 0): (assignment_map, initialized_variable_names ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) if use_tpu: @@ -169,13 +234,14 @@ def tpu_scaffold(): init_string = "" if var.name in initialized_variable_names: init_string = ", *INIT_FROM_CKPT*" - tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, + tf.logging.info(" %d :: name = %s, shape = %s%s", 0 if hvd is None else hvd.rank(), var.name, var.shape, init_string) output_spec = None if mode == tf.estimator.ModeKeys.TRAIN: train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, + hvd, FLAGS.use_fp16) output_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, @@ -325,7 +391,8 @@ def input_fn_builder(input_files, max_seq_length, max_predictions_per_seq, is_training, - num_cpu_threads=4): + num_cpu_threads=4, + hvd=None): """Creates an `input_fn` closure to be passed to TPUEstimator.""" def input_fn(params): @@ -353,6 +420,7 @@ def input_fn(params): # For eval, we want no shuffling and parallel reading doesn't matter. if is_training: d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) + if hvd is not None: d = d.shard(hvd.size(), hvd.rank()) d = d.repeat() d = d.shuffle(buffer_size=len(input_files)) @@ -409,13 +477,51 @@ def main(_): if not FLAGS.do_train and not FLAGS.do_eval: raise ValueError("At least one of `do_train` or `do_eval` must be True.") + if FLAGS.horovod: + import horovod.tensorflow as hvd + hvd.init() + bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) tf.gfile.MakeDirs(FLAGS.output_dir) + def walk_path(location: str, + only_dir: bool = False, + depth: int = None, + extension: str = None): + """Walks through specified remote or local directory. + + Args: + location: local or remote directory to start walk. + only_dir: if True, only directories are yielded, + else only files. + depth: number of subdirectories to recursively walk through. + if unspecified, walk through all subdirectories. + extension: if specified, only files the end with this + extension are returned. + Yields: + local or remote path. + + """ + for level, (root, dirs, file_names) in enumerate( + tf.gfile.Walk(top=location)): + if only_dir: + for dir_name in dirs: + yield os.path.join(root, dir_name) + else: + for file_name in file_names: + if extension and not file_name.endswith(extension): + continue + yield os.path.join(root, file_name) + if depth is not None and depth == level: + return + input_files = [] - for input_pattern in FLAGS.input_file.split(","): - input_files.extend(tf.gfile.Glob(input_pattern)) + if tf.gfile.Exists(FLAGS.input_file): + input_files = list(walk_path(FLAGS.input_file)) + else: + for input_pattern in FLAGS.input_file.split(","): + input_files.extend(tf.gfile.Glob(input_pattern)) tf.logging.info("*** Input Files ***") for input_file in input_files: @@ -426,25 +532,45 @@ def main(_): tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) + config = tf.ConfigProto() + if FLAGS.horovod: + config.gpu_options.visible_device_list = str(hvd.local_rank()) + if FLAGS.use_xla: + config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 run_config = tf.contrib.tpu.RunConfig( cluster=tpu_cluster_resolver, master=FLAGS.master, model_dir=FLAGS.output_dir, - save_checkpoints_steps=FLAGS.save_checkpoints_steps, + session_config=config, + save_checkpoints_steps=FLAGS.save_checkpoints_steps if not FLAGS.horovod or hvd.rank() == 0 else None, tpu_config=tf.contrib.tpu.TPUConfig( iterations_per_loop=FLAGS.iterations_per_loop, num_shards=FLAGS.num_tpu_cores, - per_host_input_for_training=is_per_host)) + per_host_input_for_training=is_per_host), + # This variable controls how often estimator reports examples/sec. + # Default value is every 100 steps. + # When --report_loss is True, we set to very large value to prevent + # default info reporting from estimator. + # Ideally we should set it to None, but that does not work. + log_step_count_steps=10000 if FLAGS.report_loss else 100) model_fn = model_fn_builder( bert_config=bert_config, init_checkpoint=FLAGS.init_checkpoint, - learning_rate=FLAGS.learning_rate, + learning_rate=FLAGS.learning_rate if not FLAGS.horovod else FLAGS.learning_rate*hvd.size(), num_train_steps=FLAGS.num_train_steps, num_warmup_steps=FLAGS.num_warmup_steps, use_tpu=FLAGS.use_tpu, - use_one_hot_embeddings=FLAGS.use_tpu) + use_one_hot_embeddings=FLAGS.use_tpu, + hvd=None if not FLAGS.horovod else hvd) + + training_hooks = [] + if FLAGS.horovod and hvd.size() > 1: + training_hooks.append(hvd.BroadcastGlobalVariablesHook(0)) + if FLAGS.report_loss: + global_batch_size = FLAGS.train_batch_size if not FLAGS.horovod else FLAGS.train_batch_size*hvd.size() + training_hooks.append(_LogSessionRunHook(global_batch_size,1,-1 if not FLAGS.horovod else hvd.rank())) # If TPU is not available, this will fall back to normal Estimator on CPU # or GPU. @@ -462,10 +588,11 @@ def main(_): input_files=input_files, max_seq_length=FLAGS.max_seq_length, max_predictions_per_seq=FLAGS.max_predictions_per_seq, - is_training=True) - estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) + is_training=True, + hvd=None if not FLAGS.horovod else hvd) + estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=FLAGS.num_train_steps) - if FLAGS.do_eval: + if FLAGS.do_eval and (not FLAGS.horovod or hvd.rank() == 0): tf.logging.info("***** Running evaluation *****") tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) @@ -473,7 +600,8 @@ def main(_): input_files=input_files, max_seq_length=FLAGS.max_seq_length, max_predictions_per_seq=FLAGS.max_predictions_per_seq, - is_training=False) + is_training=False, + hvd=None if not FLAGS.horovod else hvd) result = estimator.evaluate( input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) diff --git a/run_squad.py b/run_squad.py index edd4c3ed9..dba036198 100644 --- a/run_squad.py +++ b/run_squad.py @@ -153,6 +153,10 @@ "null_score_diff_threshold", 0.0, "If null_score - best_non_null is greater than the threshold predict null.") +flags.DEFINE_bool("use_fp16", False, "Whether to use fp32 or fp16 arithmetic on GPU.") + +flags.DEFINE_bool("use_xla", False, "Whether to enable XLA JIT compilation.") + class SquadExample(object): """A single training/test example for simple sequence classification. @@ -556,7 +560,8 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) + use_one_hot_embeddings=use_one_hot_embeddings, + compute_type=tf.float16 if FLAGS.use_fp16 else tf.float32) final_hidden = model.get_sequence_output() @@ -660,7 +665,8 @@ def compute_loss(logits, positions): total_loss = (start_loss + end_loss) / 2.0 train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, + None, FLAGS.use_fp16) output_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, @@ -1140,11 +1146,15 @@ def main(_): tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) + config = tf.ConfigProto() + if FLAGS.use_xla: + config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 run_config = tf.contrib.tpu.RunConfig( cluster=tpu_cluster_resolver, master=FLAGS.master, model_dir=FLAGS.output_dir, + session_config=config, save_checkpoints_steps=FLAGS.save_checkpoints_steps, tpu_config=tf.contrib.tpu.TPUConfig( iterations_per_loop=FLAGS.iterations_per_loop,