Support for TensorFlow GradientTape (#670)
kuroko1t authored and alsrgv committed Jan 9, 2019
1 parent 4c84663 commit 65bb358
Showing 5 changed files with 567 additions and 356 deletions.
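In short, the commit adds hvd.broadcast_variables() and hvd.DistributedGradientTape() to horovod.tensorflow, enabling eager-mode distributed training with tf.GradientTape. A minimal usage sketch, paraphrasing the tensorflow_mnist_eager.py example added below (model, opt, dataset, and loss_fn are placeholders, not part of the diff):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
tf.enable_eager_execution()

# Build `model`, `opt` (learning rate scaled by hvd.size()), and `dataset` as usual.
for batch, (images, labels) in enumerate(dataset):
    with tf.GradientTape() as tape:
        loss = loss_fn(model(images, training=True), labels)

    if batch == 0:
        # New in this commit: sync initial weights from rank 0.
        hvd.broadcast_variables(0, model.variables)

    # New in this commit: wrap the tape so gradients are averaged
    # across workers with allreduce before being applied.
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss, model.variables)
    opt.apply_gradients(zip(grads, model.variables))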
11 changes: 10 additions & 1 deletion .travis.yml
@@ -119,7 +119,7 @@ script:
  else
    export MPIRUN="mpirun -np 2"
  fi
# run unit tests
- docker exec ${CONTAINER} /bin/sh -c "pip install pytest && cd /horovod/test && (echo test_*.py | xargs -n 1 ${MPIRUN} pytest -v)"
@@ -136,6 +136,15 @@ script:
# run TensorFlow MNIST example
- docker exec ${CONTAINER} /bin/sh -c "${MPIRUN} python /horovod/examples/tensorflow_mnist.py"

# hack TensorFlow Eager MNIST example to be smaller
- docker exec ${CONTAINER} /bin/sh -c "sed -i \"s/dataset.take(20000/dataset.take(100/\" /horovod/examples/tensorflow_mnist_eager.py"

# run TensorFlow Eager MNIST example
- |
  if [[ ${TF_PACKAGE} == "tensorflow==1.12.0" ]]; then
    docker exec ${CONTAINER} /bin/sh -c "${MPIRUN} python /horovod/examples/tensorflow_mnist_eager.py"
  fi
# download Keras MNIST dataset
- docker exec ${CONTAINER} /bin/sh -c "python -c \"from keras.datasets import mnist; mnist.load_data()\""

82 changes: 82 additions & 0 deletions examples/tensorflow_mnist_eager.py
@@ -0,0 +1,82 @@
# Copyright 2017 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/usr/bin/env python

import tensorflow as tf
import horovod.tensorflow as hvd

def main(_):
    # Horovod: initialize Horovod.
    hvd.init()

    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    tf.enable_eager_execution(config=config)

    mnist_model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(16, [3, 3], activation='relu'),
        tf.keras.layers.Conv2D(16, [3, 3], activation='relu'),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10)
    ])

    # Horovod: adjust learning rate based on number of GPUs.
    opt = tf.train.RMSPropOptimizer(0.001 * hvd.size())

    (mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data()

    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(mnist_images[..., tf.newaxis] / 255, tf.float32),
         tf.cast(mnist_labels, tf.int64))
    )
    dataset = dataset.shuffle(1000).batch(32)

    # Horovod: save checkpoints only on worker 0 to prevent other workers from
    # corrupting them.
    checkpoint_dir = './checkpoints'
    step_counter = tf.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(
        model=mnist_model, optimizer=opt, step_counter=step_counter)

    # Horovod: adjust number of steps based on number of GPUs.
    for (batch, (images, labels)) in enumerate(
            dataset.take(20000 // hvd.size())):
        with tf.GradientTape() as tape:
            logits = mnist_model(images, training=True)
            loss_value = tf.losses.sparse_softmax_cross_entropy(labels, logits)

        # Horovod: broadcast initial variable states from rank 0 to all other
        # processes. This is necessary to ensure consistent initialization of
        # all workers when training is started with random weights or restored
        # from a checkpoint.
        if batch == 0:
            hvd.broadcast_variables(0, mnist_model.variables)

        # Horovod: add Horovod Distributed GradientTape.
        tape = hvd.DistributedGradientTape(tape)

        grads = tape.gradient(loss_value, mnist_model.variables)
        opt.apply_gradients(zip(grads, mnist_model.variables),
                            global_step=tf.train.get_or_create_global_step())
        if batch % 10 == 0 and hvd.local_rank() == 0:
            print('Step #%d\tLoss: %.6f' % (batch, loss_value))

    if hvd.rank() == 0:
        checkpoint.save(checkpoint_dir)


if __name__ == "__main__":
    tf.app.run()
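The example saves a checkpoint on rank 0 but never restores one. For reference only (not part of the commit), a minimal restore sketch under the same TF 1.12 eager setup, reusing the mnist_model, opt, and checkpoint_dir names from above:

# Hypothetical follow-up, not in this commit: resume from the latest
# checkpoint written by rank 0, then re-sync all workers.
checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt,
                                 step_counter=tf.train.get_or_create_global_step())
# checkpoint.save('./checkpoints') above uses the path as a file prefix,
# so the checkpoint index file lands in the current directory.
checkpoint.restore(tf.train.latest_checkpoint('.'))

# Broadcast the restored values from rank 0 so every worker resumes from
# identical weights (argument order as added in this commit:
# broadcast_variables(root_rank, variables)).
hvd.broadcast_variables(0, mnist_model.variables)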
90 changes: 88 additions & 2 deletions horovod/tensorflow/__init__.py
@@ -43,7 +43,6 @@

import tensorflow as tf


def allreduce(tensor, average=True, device_dense='', device_sparse='',
              compression=Compression.none):
    """Perform an allreduce on a tf.Tensor or tf.IndexedSlices.
@@ -100,8 +99,18 @@ def broadcast_global_variables(root_rank):
    root_rank: rank of the process from which global variables will be broadcasted
        to all other processes.
    """
    return broadcast_variables(root_rank, tf.global_variables())

def broadcast_variables(root_rank, variables):
    """Broadcasts variables from root rank to all other processes.
    Arguments:
        root_rank: rank of the process from which variables will be broadcast
            to all other processes.
        variables: variables to broadcast.
    """
    return tf.group(*[tf.assign(var, broadcast(var, root_rank))
                      for var in variables])


class BroadcastGlobalVariablesHook(tf.train.SessionRunHook):
@@ -237,3 +246,80 @@ def get_slot_names(self, *args, **kwargs):
    def variables(self, *args, **kwargs):
        """Calls this same method on the underlying optimizer."""
        return self._optimizer.variables(*args, **kwargs)


if hasattr(tf, 'GradientTape'):
    class _DistributedGradientTape(tf.GradientTape):

        def __init__(self, tape, device_dense, device_sparse, compression,
                     sparse_as_dense, persistent=False, watch_accessed_variables=True):
            if hasattr(tape, '_watch_accessed_variables'):
                super(self.__class__, self).__init__(persistent, watch_accessed_variables)
            else:
                super(self.__class__, self).__init__(persistent)
            self._tape = tape
            self._persistent = persistent
            self._watch_accessed_variables = watch_accessed_variables
            self._name = "Distributed"
            self._device_dense = device_dense
            self._device_sparse = device_sparse
            self._compression = compression
            self._sparse_as_dense = sparse_as_dense

            def allreduce_grads(grads):
                with tf.name_scope(self._name + "_Allreduce"):
                    if self._sparse_as_dense:
                        grads = [tf.convert_to_tensor(grad)
                                 if grad is not None and isinstance(grad, tf.IndexedSlices)
                                 else grad for grad in grads]
                    return [allreduce(grad,
                                      device_dense=self._device_dense,
                                      device_sparse=self._device_sparse,
                                      compression=self._compression)
                            if grad is not None else grad
                            for grad in grads]

            self._allreduce_grads = tf.contrib.eager.defun(allreduce_grads)

        def gradient(self, target, sources, output_gradients=None):
            gradients = super(self.__class__, self).gradient(target, sources, output_gradients)
            if size() > 1:
                avg_grads = self._allreduce_grads(gradients)
                return avg_grads
            else:
                return gradients


    def DistributedGradientTape(gradtape, device_dense='', device_sparse='',
                                compression=Compression.none, sparse_as_dense=False):
        """A tape that wraps another tf.GradientTape, using an allreduce to
        average gradient values before applying gradients to model weights.
        Args:
          gradtape:
            GradientTape to use for computing gradients and applying updates.
          device_dense:
            Device to be used for dense tensors. Uses GPU by default
            if Horovod was built with HOROVOD_GPU_ALLREDUCE.
          device_sparse:
            Device to be used for sparse tensors. Uses GPU by default
            if Horovod was built with HOROVOD_GPU_ALLGATHER.
          compression:
            Compression algorithm used during allreduce to reduce the amount
            of data sent during each parameter update step. Defaults to
            not using compression.
          sparse_as_dense:
            Treat all sparse gradients as dense tensors. This can help improve
            performance and memory utilization if the original sparse gradient
            has high density. Defaults to false.
        """
        cls = type(gradtape.__class__.__name__, (gradtape.__class__,),
                   dict(_DistributedGradientTape.__dict__))
        if hasattr(gradtape, '_watch_accessed_variables'):
            return cls(gradtape._tape, device_dense, device_sparse,
                       compression, sparse_as_dense,
                       gradtape._persistent, gradtape._watch_accessed_variables)
        else:
            return cls(gradtape._tape, device_dense, device_sparse,
                       compression, sparse_as_dense,
                       gradtape._persistent)
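An editor's note on the factory above (not part of the diff): DistributedGradientTape does not return _DistributedGradientTape directly. It builds a new class at runtime that inherits from the caller's own tape class and copies in _DistributedGradientTape's namespace, so whatever tape subclass the user passed in keeps its behavior while gradient() gains the allreduce. A stripped-down sketch of the same type() pattern, with hypothetical names:

class PlainTape(object):
    """Stand-in for the user's tape class (hypothetical)."""
    def gradient(self):
        return "plain gradient"

class _DistributedMixin(PlainTape):
    """Stand-in for _DistributedGradientTape (hypothetical)."""
    def gradient(self):
        # Post-process the result of the wrapped class's method.
        return "allreduced " + super(self.__class__, self).gradient()

# Mirror the factory: subclass the caller's class and copy in the mixin's
# namespace so the overriding gradient() takes effect on the new type.
WrappedTape = type(PlainTape.__name__, (PlainTape,),
                   dict(_DistributedMixin.__dict__))

print(WrappedTape().gradient())  # -> allreduced plain gradient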
2 changes: 1 addition & 1 deletion horovod/tensorflow/util.py
@@ -17,7 +17,7 @@
import tensorflow as tf


if LooseVersion(tf.__version__) >= LooseVersion("1.4.0"):
if LooseVersion(tf.__version__) >= LooseVersion("1.9.0"):
    from tensorflow.python.eager import context
    _has_eager = True
else:
