Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SyncBatchNormalization layer for TensorFlow. #2075

Merged
merged 2 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions horovod/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from horovod.tensorflow.mpi_ops import handle_average_backwards_compatibility, check_num_rank_power_of_2
from horovod.tensorflow.util import _executing_eagerly, _make_subgraph, _cache
from horovod.tensorflow.mpi_ops import join
from horovod.tensorflow.sync_batch_norm import SyncBatchNormalization

import tensorflow as tf

Expand Down
57 changes: 57 additions & 0 deletions horovod/tensorflow/sync_batch_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf
from horovod.tensorflow.mpi_ops import _allreduce
from horovod.tensorflow.mpi_ops import size, rank
from horovod.tensorflow.mpi_ops import Sum

try:
_BatchNormalization = tf.compat.v1.layers.BatchNormalization
except AttributeError:
_BatchNormalization = tf.layers.BatchNormalization

class SyncBatchNormalization(_BatchNormalization):
""" Synchronous batch normalization. Stats are synchronized across all workers during training. """

def __init__(self, fused=False, **kwargs):
if fused in (True, None):
raise ValueError('SyncBatchNormalization does not support fused=True.')
if not kwargs.get('name', None):
kwargs['name'] = 'sync_batch_normalization'
super(SyncBatchNormalization, self).__init__(fused=fused, **kwargs)

def _moments(self, inputs, reduction_axes, keep_dims):
"""Compute the mean and variance: it overrides the original _moments."""

worker_mean, worker_variance = super(SyncBatchNormalization, self)._moments(
inputs, reduction_axes, keep_dims=keep_dims)

if size() > 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to make this work with dynamic worker count in a follow-up PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sounds good to me. Thanks for pointing that out and opening the issue.

# Compute variance using: Var[X] = E[X^2] - E[X]^2.
worker_square_of_mean = tf.math.square(worker_mean)
worker_mean_of_square = worker_variance + worker_square_of_mean

# Average stats across all workers
group_mean = _allreduce(worker_mean, op=Sum)
group_mean_of_square = _allreduce(worker_mean_of_square, op=Sum)
group_mean /= size()
group_mean_of_square /= size()

group_variance = group_mean_of_square - tf.math.square(group_mean)

return (group_mean, group_variance)
else:
return (worker_mean, worker_variance)
96 changes: 96 additions & 0 deletions test/test_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,102 @@ def test_horovod_join_allreduce(self):
self.assertTrue(diff <= threshold,
"hvd.join with hvd.allreduce on GPU produces incorrect results")

def test_horovod_syncbn_gpu(self):
"""Test that the SyncBatchNormalization implementation is correct on GPU."""
# Only do this test if there are GPUs available.
if not tf.test.is_gpu_available(cuda_only=True):
self.skipTest(("No GPUs available"))

hvd.init()
with tf.device("/gpu:%d" % hvd.local_rank()):
x_list = [
tf.convert_to_tensor(np.stack([
np.array([
[r, r + 1],
[r * 2, r * 2 + 1],
[r * 3, r * 3 + 1],
[r * 4, r * 4 + 1]
], dtype=np.float32)
for r in range(hvd.size())
]), np.float32),
tf.convert_to_tensor(np.stack([
np.array([
[r + 1],
[r * 2 + 1],
[r * 3 + 1],
[r * 4 + 1]
], dtype=np.float32)
for r in range(hvd.size())
]), np.float32),
]

for x in x_list:
try:
bn = tf.layers.BatchNormalization(axis=1)
except AttributeError:
bn = tf.compat.v1.layers.BatchNormalization(axis=1)
sync_bn = hvd.SyncBatchNormalization(axis=1)
bn_func = bn.apply(x, training=True)
sync_bn_func = sync_bn.apply(tf.expand_dims(x[hvd.rank()], 0), training=True)

try:
init = tf.global_variables_initializer()
except AttributeError:
init = tf.compat.v1.global_variables_initializer()
self.evaluate(init)
bn_out = self.evaluate(bn_func)
sync_bn_out = self.evaluate(sync_bn_func)

self.assertAllClose(sync_bn_out, np.expand_dims(bn_out[hvd.rank()], 0))
self.assertAllClose(self.evaluate(sync_bn.moving_mean), self.evaluate(bn.moving_mean))
self.assertAllClose(self.evaluate(sync_bn.moving_variance), self.evaluate(bn.moving_variance))

def test_horovod_syncbn_cpu(self):
"""Test that the SyncBatchNormalization implementation is correct on CPU."""

hvd.init()
with tf.device("/cpu:0"):
x_list = [
tf.convert_to_tensor(np.stack([
np.array([
[r, r + 1],
[r * 2, r * 2 + 1],
[r * 3, r * 3 + 1],
[r * 4, r * 4 + 1]
], dtype=np.float32)
for r in range(hvd.size())
]), np.float32),
tf.convert_to_tensor(np.stack([
np.array([
[r + 1],
[r * 2 + 1],
[r * 3 + 1],
[r * 4 + 1]
], dtype=np.float32)
for r in range(hvd.size())
]), np.float32),
]

for x in x_list:
try:
bn = tf.layers.BatchNormalization(axis=1)
except AttributeError:
bn = tf.compat.v1.layers.BatchNormalization(axis=1)
sync_bn = hvd.SyncBatchNormalization(axis=1)
bn_func = bn.apply(x, training=True)
sync_bn_func = sync_bn.apply(tf.expand_dims(x[hvd.rank()], 0), training=True)

try:
init = tf.global_variables_initializer()
except AttributeError:
init = tf.compat.v1.global_variables_initializer()
self.evaluate(init)
bn_out = self.evaluate(bn_func)
sync_bn_out = self.evaluate(sync_bn_func)

self.assertAllClose(sync_bn_out, np.expand_dims(bn_out[hvd.rank()], 0))
self.assertAllClose(self.evaluate(sync_bn.moving_mean), self.evaluate(bn.moving_mean))
self.assertAllClose(self.evaluate(sync_bn.moving_variance), self.evaluate(bn.moving_variance))

from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes
run_all_in_graph_and_eager_modes(TensorFlowTests)
Expand Down