
Add TensorFlow Estimator example, update existing examples (#140)

* Add TensorFlow Estimator example, update existing examples

* Update README, bump number of steps for MNIST training

* Adjust MNIST model learning rate

* Add LR scaling
alsrgv committed Jan 5, 2018
1 parent dfbd64a commit 9bdd70dad6d35fe7fa7c4b0875a27c5cf668ca83
@@ -90,6 +90,9 @@ script:
# run unit tests
- docker exec $CONTAINER /bin/sh -c "$MPIRUN python /horovod/horovod/tensorflow/mpi_ops_test.py"

# hack TensorFlow MNIST example to be smaller
- docker exec $CONTAINER /bin/sh -c "sed -i \"s/last_step=20000/last_step=100/\" /horovod/examples/tensorflow_mnist.py"

# run TensorFlow MNIST example
- docker exec $CONTAINER /bin/sh -c "$MPIRUN python /horovod/examples/tensorflow_mnist.py"

@@ -62,16 +62,19 @@ To use Horovod, make the following additions to your program:
With the typical setup of one GPU per process, this can be set to *local rank*. In that case, the first process on
the server will be allocated the first GPU, the second process will be allocated the second GPU, and so forth.

3. Wrap optimizer in `hvd.DistributedOptimizer`. The distributed optimizer delegates gradient computation
3. Scale the learning rate by the number of workers. The effective batch size in synchronous distributed training
scales with the number of workers. An increase in learning rate compensates for the increased effective batch size.

4. Wrap optimizer in `hvd.DistributedOptimizer`. The distributed optimizer delegates gradient computation
to the original optimizer, averages gradients using *allreduce* or *allgather*, and then applies those averaged
gradients.

4. Add `hvd.BroadcastGlobalVariablesHook(0)` to broadcast initial variable states from rank 0 to all other processes.
5. Add `hvd.BroadcastGlobalVariablesHook(0)` to broadcast initial variable states from rank 0 to all other processes.
This is necessary to ensure consistent initialization of all workers when training is started with random weights or
restored from a checkpoint. Alternatively, if you're not using `MonitoredTrainingSession`, you can simply execute
the `hvd.broadcast_global_variables` op after global variables have been initialized.

5. Modify your code to save checkpoints only on worker 0 to prevent other workers from corrupting them.
6. Modify your code to save checkpoints only on worker 0 to prevent other workers from corrupting them.
This can be accomplished by passing `checkpoint_dir=None` to `tf.train.MonitoredTrainingSession` if
`hvd.rank() != 0`. A minimal sketch that pulls these steps together is shown below.
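The sketch below assembles the steps above into one script; the model itself is elided, and the optimizer, learning rate, and checkpoint path are placeholders rather than part of this commit:

```python
import tensorflow as tf
import horovod.tensorflow as hvd

# Horovod: initialize Horovod.
hvd.init()

# Horovod: pin GPU to be used to process local rank (one GPU per process).
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())

# Build model...
loss = ...

# Horovod: scale the learning rate by the number of workers and wrap the
# optimizer, which then averages gradients across workers with allreduce.
opt = tf.train.AdagradOptimizer(0.01 * hvd.size())
opt = hvd.DistributedOptimizer(opt)

global_step = tf.contrib.framework.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)

# Horovod: broadcast initial variable states from rank 0 to all other
# processes, and save checkpoints only on worker 0.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]
checkpoint_dir = './checkpoints' if hvd.rank() == 0 else None

with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                       config=config,
                                       hooks=hooks) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)
```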

@@ -91,7 +94,7 @@ config.gpu_options.visible_device_list = str(hvd.local_rank())
# Build model...
loss = ...
opt = tf.train.AdagradOptimizer(0.01)
opt = tf.train.AdagradOptimizer(0.01 * hvd.size())
# Add Horovod Distributed Optimizer
opt = hvd.DistributedOptimizer(opt)
@@ -152,6 +155,12 @@ See full training [simple](examples/keras_mnist.py) and [advanced](examples/kera
all GPUs on the server, instead of the GPU assigned by the *local rank*. If you have multiple GPUs per server, upgrade
to Keras 2.1.2, or downgrade to Keras 2.0.8.

## Estimator API

Horovod supports the Estimator API and regular TensorFlow in similar ways.

See a full training [example](examples/tensorflow_mnist_estimator.py).
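The integration follows the same pattern as the regular TensorFlow steps above. A rough sketch under the TF 1.x Estimator API is shown below; the `model_fn`, `train_input_fn`, step count, and checkpoint path are assumptions for illustration and are not taken from this commit:

```python
import tensorflow as tf
import horovod.tensorflow as hvd

# Horovod: initialize Horovod.
hvd.init()

# Horovod: pin GPU to be used to process local rank (one GPU per process).
session_config = tf.ConfigProto()
session_config.gpu_options.visible_device_list = str(hvd.local_rank())

# Horovod: save checkpoints only on worker 0 to prevent other workers from
# corrupting them.
model_dir = './checkpoints' if hvd.rank() == 0 else None

# Inside model_fn, scale the learning rate by hvd.size() and wrap the optimizer
# in hvd.DistributedOptimizer, exactly as in the regular TensorFlow case.
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    config=tf.estimator.RunConfig(session_config=session_config))

# Horovod: broadcast initial variable states from rank 0 to all other
# processes, and divide the number of steps among the workers.
estimator.train(
    input_fn=train_input_fn,
    steps=20000 // hvd.size(),
    hooks=[hvd.BroadcastGlobalVariablesHook(0)])
```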

## Inference

Learn how to optimize your model for inference and remove Horovod operations from the graph [here](docs/inference.md).
@@ -9,10 +9,10 @@
import tensorflow as tf
import horovod.keras as hvd

# Initialize Horovod.
# Horovod: initialize Horovod.
hvd.init()

# Pin GPU to be used to process local rank (one GPU per process)
# Horovod: pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
@@ -21,7 +21,7 @@
batch_size = 128
num_classes = 10

# Adjust number of epochs based on number of GPUs.
# Horovod: adjust number of epochs based on number of GPUs.
epochs = int(math.ceil(12.0 / hvd.size()))

# Input image dimensions
@@ -63,24 +63,24 @@
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# Adjust learning rate based on number of GPUs.
# Horovod: adjust learning rate based on number of GPUs.
opt = keras.optimizers.Adadelta(1.0 * hvd.size())

# Add Horovod Distributed Optimizer.
# Horovod: add Horovod Distributed Optimizer.
opt = hvd.DistributedOptimizer(opt)

model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=opt,
metrics=['accuracy'])

callbacks = [
# Broadcast initial variable states from rank 0 to all other processes.
# Horovod: broadcast initial variable states from rank 0 to all other processes.
# This is necessary to ensure consistent initialization of all workers when
# training is started with random weights or restored from a checkpoint.
hvd.callbacks.BroadcastGlobalVariablesCallback(0),
]

# Save checkpoints only on worker 0 to prevent other workers from corrupting them.
# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
if hvd.rank() == 0:
    callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

@@ -9,10 +9,10 @@
import tensorflow as tf
import horovod.keras as hvd

# Initialize Horovod.
# Horovod: initialize Horovod.
hvd.init()

# Pin GPU to be used to process local rank (one GPU per process)
# Horovod: pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
@@ -68,29 +68,29 @@
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# Adjust learning rate based on number of GPUs.
# Horovod: adjust learning rate based on number of GPUs.
opt = keras.optimizers.Adadelta(lr=1.0 * hvd.size())

# Add Horovod Distributed Optimizer.
# Horovod: add Horovod Distributed Optimizer.
opt = hvd.DistributedOptimizer(opt)

model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=opt,
metrics=['accuracy'])

callbacks = [
# Broadcast initial variable states from rank 0 to all other processes.
# Horovod: broadcast initial variable states from rank 0 to all other processes.
# This is necessary to ensure consistent initialization of all workers when
# training is started with random weights or restored from a checkpoint.
hvd.callbacks.BroadcastGlobalVariablesCallback(0),

# Average metrics among workers at the end of every epoch.
# Horovod: average metrics among workers at the end of every epoch.
#
# Note: This callback must be in the list before the ReduceLROnPlateau,
# TensorBoard or other metrics-based callbacks.
hvd.callbacks.MetricAverageCallback(),

# Using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
# Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
# accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
# the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, verbose=1),
@@ -99,7 +99,7 @@
keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1),
]

# Save checkpoints only on worker 0 to prevent other workers from corrupting them.
# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
if hvd.rank() == 0:
    callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

@@ -108,7 +108,8 @@
height_shift_range=0.08, zoom_range=0.08)
test_gen = ImageDataGenerator()

# Train the model. The training will randomly sample 1 / N batches of training data and
# Train the model.
# Horovod: the training will randomly sample 1 / N batches of training data and
# 3 / N batches of validation data on every worker, where N is the number of workers.
# Over-sampling of validation data helps to increase the probability that every validation
# example will be evaluated.
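The `fit_generator` call that these comments annotate falls outside this hunk; presumably it divides the per-epoch step counts by the number of workers, roughly along the lines below. Variable names follow the usual Keras MNIST setup and the exact values are illustrative, not taken from the commit:

```python
model.fit_generator(
    train_gen.flow(x_train, y_train, batch_size=batch_size),
    # Horovod: each worker walks through roughly 1 / N of the training batches
    # per epoch, where N is the number of workers.
    steps_per_epoch=len(x_train) // batch_size // hvd.size(),
    callbacks=callbacks,
    epochs=epochs,
    verbose=1,
    validation_data=test_gen.flow(x_test, y_test, batch_size=batch_size),
    # Horovod: over-sample validation batches by 3x, as described above.
    validation_steps=3 * len(x_test) // batch_size // hvd.size())
```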
@@ -63,7 +63,7 @@ def conv_model(feature, target, mode):


def main(_):
# Initialize Horovod.
# Horovod: initialize Horovod.
hvd.init()

# Download and load MNIST dataset.
@@ -75,30 +75,36 @@ def main(_):
label = tf.placeholder(tf.float32, [None], name='label')
predict, loss = conv_model(image, label, tf.contrib.learn.ModeKeys.TRAIN)

opt = tf.train.RMSPropOptimizer(0.01)
# Horovod: adjust learning rate based on number of GPUs.
opt = tf.train.RMSPropOptimizer(0.001 * hvd.size())

# Add Horovod Distributed Optimizer.
# Horovod: add Horovod Distributed Optimizer.
opt = hvd.DistributedOptimizer(opt)

global_step = tf.contrib.framework.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)

# BroadcastGlobalVariablesHook broadcasts initial variable states from rank 0
# to all other processes. This is necessary to ensure consistent initialization
# of all workers when training is started with random weights or restored
# from a checkpoint.
hooks = [hvd.BroadcastGlobalVariablesHook(0),
tf.train.StopAtStepHook(last_step=100),
tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss},
every_n_iter=10),
]

# Pin GPU to be used to process local rank (one GPU per process)
hooks = [
# Horovod: BroadcastGlobalVariablesHook broadcasts initial variable states
# from rank 0 to all other processes. This is necessary to ensure consistent
# initialization of all workers when training is started with random weights
# or restored from a checkpoint.
hvd.BroadcastGlobalVariablesHook(0),

# Horovod: adjust number of steps based on number of GPUs.
tf.train.StopAtStepHook(last_step=20000 // hvd.size()),

tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss},
every_n_iter=10),
]

# Horovod: pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())

# Save checkpoints only on worker 0 to prevent other workers from corrupting them.
# Horovod: save checkpoints only on worker 0 to prevent other workers from
# corrupting them.
checkpoint_dir = './checkpoints' if hvd.rank() == 0 else None

# The MonitoredTrainingSession takes care of session initialization,
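The hunk ends here; the training loop that follows in the example presumably hands the hooks, session config, and checkpoint directory built above to the `MonitoredTrainingSession`, roughly as sketched below (the MNIST dataset handle and batch size are placeholders):

```python
with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                       hooks=hooks,
                                       config=config) as mon_sess:
    while not mon_sess.should_stop():
        # Run a synchronous training step; the DistributedOptimizer averages
        # the gradients across workers with allreduce before applying them.
        image_, label_ = mnist.train.next_batch(100)
        mon_sess.run(train_op, feed_dict={image: image_, label: label_})
```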
