Revert hvd.DistributedOptimizer to allreduce in compute_gradients(), require graph mode (#1347)

* Revert "Add *args to opt.apply_gradients() (#1316)"

This reverts commit bbda2aa.

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>

* Revert "Move allreduce to opt.apply_gradients when possible (#1311)"

This reverts commit 2be9d14.

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>

* Fixes from #1311

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>

* Force graph execution for hvd.DistributedOptimizer()

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>

* Move run_eagerly=False to compile()

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>

* Eager execution is insufficient to distinguish Optimizer usage

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>

* Cosmetic fix

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>
alsrgv committed Aug 25, 2019
1 parent b21f0fb commit 26b55a7890f6923ca58cdb68a765ed0ec436ab0f
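
For context, the net effect of this revert is that hvd.DistributedOptimizer averages gradients inside compute_gradients() again and is intended for graph-mode training, while eager custom loops keep using hvd.DistributedGradientTape. A minimal sketch of the graph-mode path (not part of this commit; assumes TensorFlow 1.x-style graph execution and a toy model):

    import tensorflow as tf
    import horovod.tensorflow as hvd

    hvd.init()

    # Toy graph-mode model: a placeholder input and a single weight matrix.
    x = tf.placeholder(tf.float32, shape=[None, 10])
    w = tf.Variable(tf.zeros([10, 1]))
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

    # Horovod: scale the learning rate and wrap the optimizer.
    opt = hvd.DistributedOptimizer(tf.train.AdamOptimizer(0.001 * hvd.size()))

    # minimize() calls compute_gradients(), where the allreduce now happens,
    # then apply_gradients(), which after this change is a plain passthrough.
    train_op = opt.minimize(loss, global_step=tf.train.get_or_create_global_step())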
@@ -52,9 +52,12 @@
 # Horovod: add Horovod DistributedOptimizer.
 opt = hvd.DistributedOptimizer(opt)

+# Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow
+# uses hvd.DistributedOptimizer() to compute gradients.
 mnist_model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
                     optimizer=opt,
-                    metrics=['accuracy'])
+                    metrics=['accuracy'],
+                    experimental_run_tf_function=False)

 callbacks = [
     # Horovod: broadcast initial variable states from rank 0 to all other processes.
@@ -50,9 +50,6 @@
 # Horovod: adjust learning rate based on number of GPUs.
 opt = tf.optimizers.Adam(0.001 * hvd.size())

-# Horovod: add Horovod Distributed Optimizer.
-opt = hvd.DistributedOptimizer(opt)
-
 checkpoint_dir = './checkpoints'
 checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)

@@ -63,6 +60,9 @@ def training_step(images, labels, first_batch):
         probs = mnist_model(images, training=True)
         loss_value = loss(labels, probs)

+    # Horovod: add Horovod Distributed GradientTape.
+    tape = hvd.DistributedGradientTape(tape)
+
     grads = tape.gradient(loss_value, mnist_model.trainable_variables)
     opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

@@ -64,22 +64,23 @@
 model = getattr(applications, args.model)(weights=None)
 opt = tf.optimizers.SGD(0.01)

-# Horovod: (optional) compression algorithm.
-compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
-
-# Horovod: add Horovod Distributed Optimizer.
-opt = hvd.DistributedOptimizer(opt, compression=compression)

 data = tf.random.uniform([args.batch_size, 224, 224, 3])
 target = tf.random.uniform([args.batch_size, 1], minval=0, maxval=999, dtype=tf.int64)


 @tf.function
 def benchmark_step(first_batch):
+    # Horovod: (optional) compression algorithm.
+    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
+
+    # Horovod: use DistributedGradientTape
     with tf.GradientTape() as tape:
         probs = model(data, training=True)
         loss = tf.losses.categorical_crossentropy(target, probs)

+    # Horovod: add Horovod Distributed GradientTape.
+    tape = hvd.DistributedGradientTape(tape, compression=compression)
+
     gradients = tape.gradient(loss, model.trainable_variables)
     opt.apply_gradients(zip(gradients, model.trainable_variables))

@@ -37,9 +37,6 @@ def main(_):
     # Horovod: adjust learning rate based on number of GPUs.
     opt = tf.train.AdamOptimizer(0.001 * hvd.size())

-    # Horovod: add Horovod Distributed Optimizer.
-    opt = hvd.DistributedOptimizer(opt)
-
     (mnist_images, mnist_labels), _ = \
         tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank())

@@ -61,6 +58,9 @@ def main(_):
             logits = mnist_model(images, training=True)
             loss_value = tf.losses.sparse_softmax_cross_entropy(labels, logits)

+        # Horovod: add Horovod Distributed GradientTape.
+        tape = hvd.DistributedGradientTape(tape)
+
         grads = tape.gradient(loss_value, mnist_model.variables)
         opt.apply_gradients(zip(grads, mnist_model.variables),
                             global_step=tf.train.get_or_create_global_step())
@@ -19,36 +19,17 @@

 def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse,
                                  compression, sparse_as_dense):
-    class _DistributedOptimizerWithApplyGradients(keras.optimizers.Optimizer):
+    class _DistributedOptimizer(keras.optimizers.Optimizer):
         def __init__(self, name, device_dense, device_sparse, compression, sparse_as_dense,
                      config):
             if name is None:
                 name = "Distributed%s" % self.__class__.__base__.__name__
-            self._allreduce_grads = hvd._make_allreduce_grads_fn(
-                name, device_dense, device_sparse, compression, sparse_as_dense)
-            super(self.__class__, self).__init__(**config)
-
-        def apply_gradients(self, grads_and_vars, *args, **kwargs):
-            """Apply gradients to provided variables.
-            See Optimizer.apply_gradients() for more info.
-            In DistributedOptimizer, apply_gradients() is overriden to also
-            allreduce the gradients before applying them.
-            """
-            if hvd.size() > 1:
-                grads, vars = zip(*grads_and_vars)
-                avg_grads = self._allreduce_grads(grads)
-                grads_and_vars = list(zip(avg_grads, vars))
-            return super(self.__class__, self).apply_gradients(grads_and_vars, *args, **kwargs)
-
-    class _DistributedOptimizerWithGetGradients(keras.optimizers.Optimizer):
-        def __init__(self, name, device_dense, device_sparse, compression, sparse_as_dense,
-                     config):
-            if name is None:
-                name = "Distributed%s" % self.__class__.__base__.__name__
-            self._allreduce_grads = hvd._make_allreduce_grads_fn(
-                name, device_dense, device_sparse, compression, sparse_as_dense)
+            self._name = name
+            self._device_dense = device_dense
+            self._device_sparse = device_sparse
+            self._compression = compression
+            self._sparse_as_dense = sparse_as_dense
+            self._get_gradients_used = False
             super(self.__class__, self).__init__(**config)

         def get_gradients(self, loss, params):
@@ -60,22 +41,41 @@ def get_gradients(self, loss, params):
             In DistributedOptimizer, get_gradients() is overriden to also
             allreduce the gradients before returning them.
             """
+            self._get_gradients_used = True
             gradients = super(self.__class__, self).get_gradients(loss, params)
             if hvd.size() > 1:
-                return self._allreduce_grads(gradients)
+                averaged_gradients = []
+                with tf.name_scope(self._name + "_Allreduce"):
+                    for grad in gradients:
+                        if grad is not None:
+                            if self._sparse_as_dense and \
+                                    isinstance(grad, tf.IndexedSlices):
+                                grad = tf.convert_to_tensor(grad)
+                            avg_grad = hvd.allreduce(grad,
+                                                     device_dense=self._device_dense,
+                                                     device_sparse=self._device_sparse,
+                                                     compression=self._compression)
+                            averaged_gradients.append(avg_grad)
+                        else:
+                            averaged_gradients.append(None)
+                    return averaged_gradients
             else:
                 return gradients

+        def apply_gradients(self, *args, **kwargs):
+            if not self._get_gradients_used:
+                raise Exception('`apply_gradients()` was called without a call to '
+                                '`get_gradients()`. If you\'re using TensorFlow 2.0, '
+                                'please specify `experimental_run_tf_function=False` in '
+                                '`compile()`.')
+            return super(self.__class__, self).apply_gradients(*args, **kwargs)
+
     # We dynamically create a new class that inherits from the optimizer that was passed in.
     # The goal is to override get_gradients() method with an allreduce implementation.
     # This class will have the same name as the optimizer it's wrapping, so that the saved
     # model could be easily restored without Horovod.
-    if hasattr(optimizer, 'apply_gradients'):
-        cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
-                   dict(_DistributedOptimizerWithApplyGradients.__dict__))
-    else:
-        cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
-                   dict(_DistributedOptimizerWithGetGradients.__dict__))
+    cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
+               dict(_DistributedOptimizer.__dict__))
     return cls(name, device_dense, device_sparse, compression, sparse_as_dense,
                optimizer.get_config())
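
The _get_gradients_used guard above exists because Keras in TF 2.0 defaults to a tf.function-based training path that never calls get_gradients(), which would silently skip the allreduce. A hedged sketch of what a user sees (the model and loss here are made up):

    import tensorflow as tf
    import horovod.tensorflow.keras as hvd

    hvd.init()

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    opt = hvd.DistributedOptimizer(tf.optimizers.Adam(0.001 * hvd.size()))

    # Without experimental_run_tf_function=False, TF 2.0 Keras computes
    # gradients without calling get_gradients(), so the first fit() step hits
    # the Exception above instead of silently training un-synchronized:
    #   model.compile(loss='mse', optimizer=opt)

    # Usage this commit expects:
    model.compile(loss='mse', optimizer=opt, experimental_run_tf_function=False)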

@@ -33,7 +33,6 @@
 from horovod.tensorflow.util import _executing_eagerly, _make_subgraph, _cache

 import tensorflow as tf
-import warnings


 def allreduce(tensor, average=True, device_dense='', device_sparse='',
@@ -134,7 +133,7 @@ def broadcast_global_variables(root_rank):
     """
     if _executing_eagerly():
         raise RuntimeError(
-            "Eager Execution is not supported by `hvd.BroadcastGlobalVariablesHook()`. "
+            "hvd.broadcast_global_variables() does not support eager execution. "
             "Please use `hvd.broadcast_variables(<model/optimizer variables>)` instead."
         )
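
The reworded error points eager users at hvd.broadcast_variables(). A minimal sketch of that replacement (assumes TF 2.x eager execution; the model is made up):

    import tensorflow as tf
    import horovod.tensorflow as hvd

    hvd.init()

    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    model.build(input_shape=(None, 4))  # create the variables before broadcasting

    # Broadcast the initial model state from rank 0 to all other processes.
    hvd.broadcast_variables(model.variables, root_rank=0)
    # Optimizer slot variables are created lazily; broadcast opt.variables()
    # the same way after the first apply_gradients() call.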

@@ -243,23 +242,25 @@ def __init__(self, optimizer, name=None, use_locking=False, device_dense='',
         self._allreduce_grads = _make_allreduce_grads_fn(
             name, device_dense, device_sparse, compression, sparse_as_dense)

-    def apply_gradients(self, grads_and_vars, *args, **kwargs):
-        """Apply gradients to provided variables.
-        See Optimizer.apply_gradients() for more info.
-        In DistributedOptimizer, apply_gradients() is overriden to also
-        allreduce the gradients before applying them.
+    def compute_gradients(self, *args, **kwargs):
+        """Compute gradients of all trainable variables.
+        See Optimizer.compute_gradients() for more info.
+        In DistributedOptimizer, compute_gradients() is overriden to also
+        allreduce the gradients before returning them.
         """
+        gradients = self._optimizer.compute_gradients(*args, **kwargs)
         if size() > 1:
-            grads, vars = zip(*grads_and_vars)
+            grads, vars = zip(*gradients)
             avg_grads = self._allreduce_grads(grads)
-            grads_and_vars = list(zip(avg_grads, vars))
-        return self._optimizer.apply_gradients(grads_and_vars, *args, **kwargs)
+            return list(zip(avg_grads, vars))
+        else:
+            return gradients

-    def compute_gradients(self, *args, **kwargs):
+    def apply_gradients(self, *args, **kwargs):
         """Calls this same method on the underlying optimizer."""
-        return self._optimizer.compute_gradients(*args, **kwargs)
+        return self._optimizer.apply_gradients(*args, **kwargs)

     def get_slot(self, *args, **kwargs):
         """Calls this same method on the underlying optimizer."""
@@ -364,8 +365,6 @@ def DistributedGradientTape(gradtape, device_dense='', device_sparse='',
         performance and memory utilization if the original sparse gradient
         has high density. Defaults to false.
     """
-    warnings.warn('`hvd.DistributedGradientTape()` has been deprecated. '
-                  'Please use `hvd.DistributedOptimizer()` instead.')
     cls = type(gradtape.__class__.__name__, (gradtape.__class__,),
                dict(_DistributedGradientTape.__dict__))
     if hasattr(gradtape, '_watch_accessed_variables'):
