115 changes: 115 additions & 0 deletions README.md
@@ -836,6 +836,121 @@ Note that since our `sample_text.txt` file is very small, this example training
will overfit that data in only a few steps and produce unrealistically high
accuracy numbers.

Many people have asked how to report the loss during pre-training. Pass the
`--report_loss` flag:

```shell
python run_pretraining.py \
--input_file=/tmp/tf_examples.tfrecord \
--output_dir=/tmp/pretraining_output \
--do_train=True \
--do_eval=True \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--train_batch_size=32 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=20 \
--num_warmup_steps=10 \
--learning_rate=2e-5 \
--report_loss
```

This will produce the following output during training:

```shell
Step samples/sec Loss Learning-rate
100 122.9 9.019 3.9200e-06
200 174.9 8.255 7.9200e-06
300 174.8 7.962 1.1920e-05
```

Here is how to run pre-training with FP16 arithmetic on GPUs by passing
`--use_fp16`. Doing this triples throughput on most GPUs.

```shell
python run_pretraining.py \
--input_file=/tmp/tf_examples.tfrecord \
--output_dir=/tmp/pretraining_output \
--do_train=True \
--do_eval=True \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--train_batch_size=32 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=20 \
--num_warmup_steps=10 \
--learning_rate=2e-5 \
--use_fp16
```

Here is how to enable XLA JIT compilation on GPUs by passing `--use_xla`. Doing
this boosts throughput by 1.3x for FP32 and 1.7x for FP16 arithmetic.

```shell
python run_pretraining.py \
--input_file=/tmp/tf_examples.tfrecord \
--output_dir=/tmp/pretraining_output \
--do_train=True \
--do_eval=True \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--train_batch_size=32 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=20 \
--num_warmup_steps=10 \
--learning_rate=2e-5 \
--use_xla
```

This version of BERT supports pre-training on multiple GPUs. This requires
Horovod, and the input dataset must be split over multiple files (at least one
per GPU). Assuming you have split your input dataset over 8 files named
`tf_examples.part01.tfrecord` through `tf_examples.part08.tfrecord`, here is
how you run it:

```shell
mpiexec --allow-run-as-root --bind-to socket -np 8 python run_pretraining.py \
--input_file=/tmp/tf_examples.part01.tfrecord,/tmp/tf_examples.part02.tfrecord,/tmp/tf_examples.part03.tfrecord,/tmp/tf_examples.part04.tfrecord,/tmp/tf_examples.part05.tfrecord,/tmp/tf_examples.part06.tfrecord,/tmp/tf_examples.part07.tfrecord,/tmp/tf_examples.part08.tfrecord \
--output_dir=/tmp/pretraining_output \
--do_train=True \
--do_eval=True \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--train_batch_size=32 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=20 \
--num_warmup_steps=10 \
--learning_rate=2e-5 \
--horovod
```

You can combine `--report_loss`, `--use_fp16`, `--use_xla`, and `--horovod`:

```shell
mpiexec --allow-run-as-root --bind-to socket -np 8 python run_pretraining.py \
--input_file=/tmp/tf_examples.part01.tfrecord,/tmp/tf_examples.part02.tfrecord,/tmp/tf_examples.part03.tfrecord,/tmp/tf_examples.part04.tfrecord,/tmp/tf_examples.part05.tfrecord,/tmp/tf_examples.part06.tfrecord,/tmp/tf_examples.part07.tfrecord,/tmp/tf_examples.part08.tfrecord \
--output_dir=/tmp/pretraining_output \
--do_train=True \
--do_eval=True \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--train_batch_size=32 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=20 \
--num_warmup_steps=10 \
--learning_rate=2e-5 \
--report_loss \
--use_fp16 \
--use_xla \
--horovod
```

### Pre-training tips and caveats

* **If using your own vocabulary, make sure to change `vocab_size` in
36 changes: 36 additions & 0 deletions gpu_environment.py
@@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np

def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
initializer=None, regularizer=None,
trainable=True,
*args, **kwargs):
"""Custom variable getter that forces trainable variables to be stored in
float32 precision and then casts them to the training precision.
"""
storage_dtype = tf.float32 if trainable else dtype
variable = getter(name, shape, dtype=storage_dtype,
initializer=initializer, regularizer=regularizer,
trainable=trainable,
*args, **kwargs)
if trainable and dtype != tf.float32:
variable = tf.cast(variable, dtype)
return variable

def get_custom_getter(compute_type):
return float32_variable_storage_getter if compute_type == tf.float16 else None
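
A minimal usage sketch (not part of this diff; the scope and layer names are illustrative) of how the getter is meant to be attached to a TF 1.x variable scope, mirroring the `modeling.py` change below:

```python
import tensorflow as tf
from gpu_environment import get_custom_getter

compute_type = tf.float16  # tf.float32 disables the cast entirely

with tf.variable_scope("model", custom_getter=get_custom_getter(compute_type)):
    # Trainable variables requested here in float16 are created and stored in
    # float32 (master weights) and returned cast to float16 for the forward pass.
    inputs = tf.placeholder(compute_type, shape=[None, 128])
    hidden = tf.layers.dense(inputs, 64, activation=tf.nn.relu)
```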
12 changes: 7 additions & 5 deletions modeling.py
@@ -27,6 +27,7 @@
import six
import tensorflow as tf

from gpu_environment import get_custom_getter

class BertConfig(object):
"""Configuration for `BertModel`."""
@@ -135,7 +136,8 @@ def __init__(self,
input_mask=None,
token_type_ids=None,
use_one_hot_embeddings=False,
scope=None):
scope=None,
compute_type=tf.float32):
"""Constructor for BertModel.

Args:
@@ -168,7 +170,7 @@
if token_type_ids is None:
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

with tf.variable_scope(scope, default_name="bert"):
with tf.variable_scope(scope, default_name="bert", custom_getter=get_custom_getter(compute_type)):
with tf.variable_scope("embeddings"):
# Perform embedding lookup on the word ids.
(self.embedding_output, self.embedding_table) = embedding_lookup(
@@ -203,7 +205,7 @@
# Run the stacked transformer.
# `sequence_output` shape = [batch_size, seq_length, hidden_size].
self.all_encoder_layers = transformer_model(
input_tensor=self.embedding_output,
input_tensor=tf.saturate_cast(self.embedding_output, compute_type),
attention_mask=attention_mask,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers,
Expand All @@ -215,7 +217,7 @@ def __init__(self,
initializer_range=config.initializer_range,
do_return_all_layers=True)

self.sequence_output = self.all_encoder_layers[-1]
self.sequence_output = tf.cast(self.all_encoder_layers[-1], tf.float32)
# The "pooler" converts the encoded sequence tensor of shape
# [batch_size, seq_length, hidden_size] to a tensor of shape
# [batch_size, hidden_size]. This is necessary for segment-level
@@ -709,7 +711,7 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0

# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
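As a hedged sketch of how the new `compute_type` argument would be used when constructing `BertModel` (the config path and input tensors here are illustrative, not taken from the diff):

```python
import tensorflow as tf
import modeling

# Illustrative inputs; in run_pretraining.py / run_classifier.py these come
# from the input pipeline.
bert_config = modeling.BertConfig.from_json_file("bert_config.json")
input_ids = tf.placeholder(tf.int32, shape=[None, 128])

model = modeling.BertModel(
    config=bert_config,
    is_training=True,
    input_ids=input_ids,
    use_one_hot_embeddings=False,
    compute_type=tf.float16)  # run the transformer stack in half precision

# The encoder output is cast back to float32 before the pooler and any heads.
sequence_output = model.get_sequence_output()
```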
30 changes: 24 additions & 6 deletions optimization.py
@@ -22,7 +22,7 @@
import tensorflow as tf


def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, hvd=None, use_fp16=False):
"""Creates an optimizer training op."""
global_step = tf.train.get_or_create_global_step()

@@ -66,20 +66,38 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):

if use_tpu:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
else:
if hvd is not None:
from horovod.tensorflow.compression import Compression
optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True, compression=Compression.fp16)
if use_fp16:
loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(init_loss_scale=2**32, incr_every_n_steps=1000, decr_every_n_nan_or_inf=2, decr_ratio=0.5)
optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager)

tvars = tf.trainable_variables()
grads = tf.gradients(loss, tvars)
grads_and_vars = optimizer.compute_gradients(loss, tvars)
grads_and_vars = [(g,v) for g,v in grads_and_vars if g is not None]
grads, tvars = list(zip(*grads_and_vars))
all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) if use_fp16 else tf.constant(True, dtype=tf.bool)

# This is how the model was pre-trained.
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
# Ensure the global norm is a finite number
# to prevent clip_by_global_norm from throwing a hissy fit.
(clipped_grads, _) = tf.clip_by_global_norm(
grads, clip_norm=1.0,
use_norm=tf.cond(
all_are_finite,
lambda: tf.global_norm(grads),
lambda: tf.constant(1.0)))

train_op = optimizer.apply_gradients(
zip(grads, tvars), global_step=global_step)
list(zip(clipped_grads, tvars)), global_step=global_step)

# Normally the global step update is done inside of `apply_gradients`.
# However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
# a different optimizer, you should probably take this line out.
new_global_step = global_step + 1
new_global_step = tf.cond(all_are_finite, lambda: global_step+1, lambda: global_step)
new_global_step = tf.identity(new_global_step, name='step_update')
train_op = tf.group(train_op, [global_step.assign(new_global_step)])
return train_op

Expand All @@ -98,7 +116,7 @@ def __init__(self,
"""Constructs a AdamWeightDecayOptimizer."""
super(AdamWeightDecayOptimizer, self).__init__(False, name)

self.learning_rate = learning_rate
self.learning_rate = tf.identity(learning_rate, name='learning_rate')
self.weight_decay_rate = weight_decay_rate
self.beta_1 = beta_1
self.beta_2 = beta_2
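The FP16 path added to `create_optimizer` above relies on the TF 1.x contrib loss-scaling API. A standalone sketch of that pattern (the Adam base optimizer here is illustrative only, not the PR's `AdamWeightDecayOptimizer`):

```python
import tensorflow as tf

# Dynamic loss scaling for FP16 training: the scale grows after 1000
# consecutive finite steps and halves when gradients overflow to NaN/Inf.
base_optimizer = tf.train.AdamOptimizer(learning_rate=2e-5)
loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
    init_loss_scale=2**32,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    decr_ratio=0.5)
optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(
    base_optimizer, loss_scale_manager)
# compute_gradients() scales the loss up before differentiation and unscales
# the gradients afterwards, so apply_gradients() sees correctly scaled values.
```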
14 changes: 12 additions & 2 deletions run_classifier.py
@@ -123,6 +123,10 @@
"num_tpu_cores", 8,
"Only used if `use_tpu` is True. Total number of TPU cores to use.")

flags.DEFINE_bool("use_fp16", False, "Whether to use fp32 or fp16 arithmetic on GPU.")

flags.DEFINE_bool("use_xla", False, "Whether to enable XLA JIT compilation.")


class InputExample(object):
"""A single training/test example for simple sequence classification."""
@@ -580,7 +584,8 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings)
use_one_hot_embeddings=use_one_hot_embeddings,
compute_type=tf.float16 if FLAGS.use_fp16 else tf.float32)

# In the demo, we are doing a simple classification task on the entire
# segment.
@@ -672,7 +677,8 @@ def tpu_scaffold():
if mode == tf.estimator.ModeKeys.TRAIN:

train_op = optimization.create_optimizer(
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu,
None, FLAGS.use_fp16)

output_spec = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode,
@@ -824,11 +830,15 @@ def main(_):
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

config = tf.ConfigProto()
if FLAGS.use_xla:
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
master=FLAGS.master,
model_dir=FLAGS.output_dir,
session_config=config,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
tpu_config=tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop,
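For reference, the `--use_xla` flag takes effect through the session config shown above. A minimal standalone sketch of the same TF 1.x pattern (the graph built inside the session is omitted):

```python
import tensorflow as tf

# Enable XLA JIT compilation for the whole graph (TF 1.x global JIT level).
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
    tf.OptimizerOptions.ON_1)

with tf.Session(config=config) as sess:
    # Build and run the model under this session as usual; compatible ops are
    # clustered and compiled by XLA automatically.
    pass
```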