TF 2.x BroadcastGlobalVariablesHook Fix (#1265)
* TF BroadcastGlobalVariablesHook Fix

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* tf.global_variables() fix and uniform try/except handling

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* get_default_graph issue fix

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Eager Execution Protection for TF Hook

Signed-off-by: DEKHTIARJonathan <contact@jonathandekhtiar.eu>

* Missing Symbol Fix

Signed-off-by: DEKHTIARJonathan <contact@jonathandekhtiar.eu>

* Missing Symbol Fix - TF 1.6.0

Signed-off-by: DEKHTIARJonathan <contact@jonathandekhtiar.eu>

* Removing deprecated function

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/context.py#L1615
Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Changes requested before merge

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Typo Fix

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Unittests added

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Unittest Fix

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Unittest Fix

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Unittest Fix

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Python2 Fix Unittest

Signed-off-by: DEKHTIARJonathan <contact@jonathandekhtiar.eu>

* Unittest Fix

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Requested Changes Applied

Signed-off-by: DEKHTIARJonathan <jdekhtiar@nvidia.com>

* Remove spaces

Signed-off-by: Alex Sergeev <alexander.sergeev@live.com>
Jonathan DEKHTIAR authored and alsrgv committed Aug 13, 2019
1 parent 339f850 commit 6078b46
Showing 4 changed files with 74 additions and 24 deletions.
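
Before the diffs, a minimal sketch of how the fixed hook is used in graph mode. The toy model and optimizer below are illustrative only, and the compat-v1 calls assume TF 1.13+ or TF 2.x; under eager execution the hook now raises a RuntimeError, and `hvd.DistributedGradientTape` is the recommended path (see the diff below).

import tensorflow as tf
import horovod.tensorflow as hvd

tf.compat.v1.disable_eager_execution()  # force graph mode on TF 2.x
hvd.init()

# Illustrative single-variable "model".
w = tf.compat.v1.get_variable('w', shape=[1],
                              initializer=tf.compat.v1.zeros_initializer())
loss = tf.reduce_sum(w * w)
opt = hvd.DistributedOptimizer(tf.compat.v1.train.GradientDescentOptimizer(0.01))
train_op = opt.minimize(loss)

# The hook broadcasts rank 0's initial variable values to every other rank,
# so all workers start training from identical weights.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]
with tf.compat.v1.train.MonitoredTrainingSession(hooks=hooks) as sess:
    sess.run(train_op)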
11 changes: 10 additions & 1 deletion .gitignore
@@ -1,13 +1,22 @@
 *.pyc
 *.so
 .idea
 .eggs
+
+.vscode/
+.idea/
+
 horovod.egg-info
 dist
 build
 docs/_build
-.vscode/
+
 env
+venv/
+
 examples/checkpoint*
+
+examples/checkpoints/
 horovod/tensorflow/mpi_lib.so
 horovod/torch/test_cuda/
+
64 changes: 43 additions & 21 deletions horovod/tensorflow/__init__.py
@@ -19,7 +19,6 @@
 from __future__ import division
 from __future__ import print_function
 
-
 from horovod.common.util import check_extension
 
 check_extension('horovod.tensorflow', 'HOROVOD_WITH_TENSORFLOW', __file__, 'mpi_lib')
@@ -111,7 +110,15 @@ def broadcast_variables(variables, root_rank):
     return broadcast_group(variables, root_rank)
 
 
-if hasattr(tf, 'global_variables'):
+try:
+    _global_variables = tf.global_variables
+except AttributeError:
+    try:
+        _global_variables = tf.compat.v1.global_variables
+    except AttributeError:
+        _global_variables = None
+
+if _global_variables is not None:
     def broadcast_global_variables(root_rank):
         """Broadcasts all global variables from root rank to all other processes.
@@ -121,15 +128,31 @@ def broadcast_global_variables(root_rank):
         root_rank: rank of the process from which global variables will be broadcasted
             to all other processes.
         """
-        return broadcast_variables(tf.global_variables(), root_rank)
-
-
-if hasattr(tf, 'train') and hasattr(tf.train, 'SessionRunHook'):
-    if hasattr(tf, 'estimator') and hasattr(tf.estimator, 'SessionRunHook'):
-        _SessionRunHook = tf.estimator.SessionRunHook
-    else:
-        _SessionRunHook = tf.train.SessionRunHook
+        if _executing_eagerly():
+            raise RuntimeError(
+                "Eager Execution is not supported by `hvd.BroadcastGlobalVariablesHook`\n"
+                "We recommend using `hvd.DistributedGradientTape` instead"
+            )
+
+        return broadcast_variables(_global_variables(), root_rank)
+
+try:
+    _get_default_graph = tf.get_default_graph
+except AttributeError:
+    try:
+        _get_default_graph = tf.compat.v1.get_default_graph
+    except AttributeError:
+        _get_default_graph = None
+
+try:
+    _SessionRunHook = tf.estimator.SessionRunHook
+except AttributeError:
+    try:
+        _SessionRunHook = tf.train.SessionRunHook
+    except AttributeError:
+        _SessionRunHook = None
+
+if _SessionRunHook is not None and _get_default_graph is not None:
     class BroadcastGlobalVariablesHook(_SessionRunHook):
         """
         SessionRunHook that will broadcast all global variables from root rank
@@ -158,7 +181,7 @@ def __init__(self, root_rank, device=''):
             self.device = device
 
         def begin(self):
-            if not self.bcast_op or self.bcast_op.graph != tf.get_default_graph():
+            if not self.bcast_op or self.bcast_op.graph != _get_default_graph():
                 with tf.device(self.device):
                     self.bcast_op = broadcast_global_variables(self.root_rank)
@@ -189,19 +212,18 @@ def allreduce_grads(grads):
     return allreduce_grads
 
 
-if hasattr(tf, 'compat') and hasattr(tf.compat, 'v1') and \
-        hasattr(tf.compat.v1, 'train') and hasattr(tf.compat.v1.train, 'Optimizer'):
+try:
     # TensorFlow 2.x
     _LegacyOptimizer = tf.compat.v1.train.Optimizer
-elif hasattr(tf, 'train') and hasattr(tf.train, 'Optimizer'):
-    # TensorFlow 1.x
-    _LegacyOptimizer = tf.train.Optimizer
-else:
-    # Future TensorFlow versions
-    _LegacyOptimizer = None
-
+except AttributeError:
+    try:
+        # TensorFlow 1.x
+        _LegacyOptimizer = tf.train.Optimizer
+    except AttributeError:
+        # Future TensorFlow versions
+        _LegacyOptimizer = None
 
-if _LegacyOptimizer:
+if _LegacyOptimizer is not None:
     class _DistributedOptimizer(_LegacyOptimizer):
         """An optimizer that wraps another tf.Optimizer, using an allreduce to
         average gradient values before applying gradients to model weights."""
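
The file now resolves every version-dependent TensorFlow symbol (_global_variables, _get_default_graph, _SessionRunHook, _LegacyOptimizer) with the same try/except-AttributeError idiom. A standalone sketch of that idiom; the helper name _resolve_tf_symbol is hypothetical and does not appear in the commit:

import tensorflow as tf

def _resolve_tf_symbol(*candidates):
    """Return the first dotted path under `tf` that resolves, else None."""
    for path in candidates:
        obj = tf
        try:
            for name in path.split('.'):
                obj = getattr(obj, name)
        except AttributeError:
            continue
        return obj
    return None

# Equivalent to the nested blocks above: prefer the TF 2.x location,
# fall back to TF 1.x, and settle on None for future releases.
_LegacyOptimizer = _resolve_tf_symbol('compat.v1.train.Optimizer', 'train.Optimizer')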
4 changes: 2 additions & 2 deletions horovod/tensorflow/util.py
@@ -17,7 +17,7 @@
 import tensorflow as tf
 
 
-if LooseVersion(tf.__version__) >= LooseVersion("1.9.0"):
+if LooseVersion(tf.__version__) >= LooseVersion('1.7.0'):  # Eager Mode has been introduced in TF 1.7.0
     from tensorflow.python.eager import context
     _has_eager = True
 else:
@@ -26,7 +26,7 @@
 
 def _executing_eagerly():
     """Returns true if eager execution is supported and enabled."""
-    return _has_eager and context.in_eager_mode()
+    return _has_eager and context.executing_eagerly()
 
 
 def _make_subgraph(f):
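
context.in_eager_mode() was removed upstream (the tensorflow/tensorflow link in the commit message points at the removal), so the shim switches to context.executing_eagerly() and lowers the version gate to 1.7.0, where eager execution first shipped. On releases that expose the public API, the private tensorflow.python.eager import can be avoided entirely; a sketch, hedged with getattr because tf.executing_eagerly only exists on newer releases:

import tensorflow as tf

def executing_eagerly():
    # Public API on recent TF releases; report False on old versions
    # that have no eager support instead of raising AttributeError.
    fn = getattr(tf, 'executing_eagerly', None)
    return fn() if fn is not None else False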
19 changes: 19 additions & 0 deletions test/test_tensorflow.py
@@ -975,6 +975,25 @@ def test_horovod_broadcast_grad_gpu(self):
                             "gradient %s differs from expected %s, "
                             "error: %s" % (grad_out, expected, str(err)))
 
+    def test_horovod_broadcast_eager_mode_error(self):
+        """Test that tries to broadcast tensorflow global variables
+        in eager execution mode. This call should raise a RuntimeError."""
+
+        if not hvd.util._executing_eagerly():
+            return
+
+        with self.assertRaises(RuntimeError):
+            hvd.broadcast_global_variables(root_rank=0)
+
+    def test_horovod_broadcast_graph_mode(self):
+        """Test that tries to broadcast tensorflow global variables
+        in graph execution mode. This call should not raise any exception."""
+
+        if hvd.util._executing_eagerly():
+            return
+
+        hvd.broadcast_global_variables(root_rank=0)
+
     def test_compression_fp16(self):
         valid_dtypes = [tf.float16, tf.float32, tf.float64]
         invalid_dtypes = [tf.uint8, tf.int8, tf.uint16, tf.int16,
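
The two tests pin down the new contract: hvd.broadcast_global_variables() raises under eager execution and succeeds in graph mode. A sketch of exercising both paths by hand, assuming a TF release with tf.executing_eagerly (1.7+), Horovod built with TensorFlow support, and a launcher such as horovodrun -np 2 python check_broadcast.py (script name illustrative):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

if tf.executing_eagerly():
    # TF 2.x default: the call must fail fast with a clear error.
    try:
        hvd.broadcast_global_variables(root_rank=0)
    except RuntimeError as err:
        print('raised as expected:', err)
else:
    # Graph mode: the broadcast op builds and runs in a session.
    w = tf.compat.v1.get_variable('w', shape=[1],
                                  initializer=tf.compat.v1.zeros_initializer())
    bcast_op = hvd.broadcast_global_variables(root_rank=0)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(bcast_op)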
