Skip to content

Commit

Permalink
Fixes for TF2.
Browse files Browse the repository at this point in the history
Signed-off-by: Josh Romero <joshr@nvidia.com>
  • Loading branch information
romerojosh committed Jul 13, 2020
1 parent e3f88a8 commit f8d7e69
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
7 changes: 6 additions & 1 deletion horovod/tensorflow/sync_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
from horovod.tensorflow.mpi_ops import size, rank
from horovod.tensorflow.mpi_ops import Sum

class SyncBatchNormalization(tf.layers.BatchNormalization):
try:
_BatchNormalization = tf.compat.v1.layers.BatchNormalization
except AttributeError:
_BatchNormalization = tf.layers.BatchNormalization

class SyncBatchNormalization(_BatchNormalization):
""" Synchronous batch normalization. Stats are synchronized across all workers during training. """

def __init__(self, fused=False, **kwargs):
Expand Down
20 changes: 16 additions & 4 deletions test/test_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,12 +1386,18 @@ def test_horovod_syncbn_gpu(self):
]

for x in x_list:
bn = tf.layers.BatchNormalization(axis=1)
try:
bn = tf.layers.BatchNormalization(axis=1)
except AttributeError:
bn = tf.compat.v1.layers.BatchNormalization(axis=1)
sync_bn = hvd.SyncBatchNormalization(axis=1)
bn_func = bn.apply(x, training=True)
sync_bn_func = sync_bn.apply(tf.expand_dims(x[hvd.rank()], 0), training=True)

init = tf.global_variables_initializer()
try:
init = tf.global_variables_initializer()
except AttributeError:
init = tf.compat.v1.global_variables_initializer()
self.evaluate(init)
bn_out = self.evaluate(bn_func)
sync_bn_out = self.evaluate(sync_bn_func)
Expand Down Expand Up @@ -1427,12 +1433,18 @@ def test_horovod_syncbn_cpu(self):
]

for x in x_list:
bn = tf.layers.BatchNormalization(axis=1)
try:
bn = tf.layers.BatchNormalization(axis=1)
except AttributeError:
bn = tf.compat.v1.layers.BatchNormalization(axis=1)
sync_bn = hvd.SyncBatchNormalization(axis=1)
bn_func = bn.apply(x, training=True)
sync_bn_func = sync_bn.apply(tf.expand_dims(x[hvd.rank()], 0), training=True)

init = tf.global_variables_initializer()
try:
init = tf.global_variables_initializer()
except AttributeError:
init = tf.compat.v1.global_variables_initializer()
self.evaluate(init)
bn_out = self.evaluate(bn_func)
sync_bn_out = self.evaluate(sync_bn_func)
Expand Down

0 comments on commit f8d7e69

Please sign in to comment.