TensorFlow sync batch norm elastic compatibility #2100

Open
tgaddair opened this issue Jul 13, 2020 · 15 comments

@tgaddair
Collaborator

Following #2075, TensorFlow now supports sync batch norm.

Currently, we use the size() constant to determine whether to do sync batch norm and how to scale. This works in eager mode but not in graph mode. We should use the newly introduced size_op() instead.
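
For illustration, a minimal sketch of the idea (not Horovod's actual internals; the statistics tensor below is just a placeholder):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# size() is a Python int baked in at graph-construction time; size_op() is a TF op
# whose value is read when the graph runs, which is what graph mode (and elastic) needs.
world_size = tf.cast(hvd.size_op(), tf.float32)
local_stats = tf.constant([1.0, 2.0, 3.0])  # stand-in for per-rank batch statistics
global_mean = hvd.allreduce(local_stats, op=hvd.Sum) / world_size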

@weiminggao

Does TF 1.15 support sync batch norm?

@tgaddair
Collaborator Author

Yes, it should be supported in TF 1.15. @romerojosh can you confirm which versions of TF are supported?

@weiminggao

> Yes, it should be supported in TF 1.15. @romerojosh can you confirm which versions of TF are supported?

Does graph mode also work in TF 1.15?

@weiminggao

I tested sync batch norm in graph mode (TF 1.14), but it does not work. Will graph mode be supported in the future?

@tgaddair
Collaborator Author

Hey @weiminggao, what's the error you're seeing? We have tests that run with graph mode on TF 1.14 and 1.15 here, so it should be working.

@weiminggao

weiminggao commented Aug 1, 2020

Thanks very much; sync batch norm now works well. But when I use it to train, it shows "Stalled ranks:". I added the UPDATE_OPS dependency.

What is the right way to use SyncBatchNormalization? @tgaddair

@weiminggao

weiminggao commented Aug 3, 2020

My Code:
# model
bn = hvd.SyncBatchNormalization()
cards_feature1 = bn.apply(cards_feature1, training = is_training)
cards_feature1 = tf.nn.relu(cards_feature1)
cards_feature1 = tf.layers.conv1d(inputs = cards_feature1, filters = 256, kernel_size = 3, strides = 1, padding = 'same', use_bias = True)
bn = hvd.SyncBatchNormalization()
cards_feature1 = bn.apply(cards_feature1, training = is_training)
cards_feature1 = tf.nn.relu(cards_feature1)
cards_feature = cards_feature + cards_feature1

# train
self.loss = self.cross_entropy  # + self.mse
self.opt = tf.train.AdamOptimizer(self.learning_rate * hvd.size())
self.opt = hvd.DistributedOptimizer(self.opt)
global_step = tf.train.get_or_create_global_step()
# print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = self.opt.minimize(self.loss, global_step)

@tgaddair
Collaborator Author

tgaddair commented Aug 3, 2020

@romerojosh can you take a look at the usage of Sync Batch Norm here?

@romerojosh
Collaborator

I will take a look and report back.

@romerojosh
Collaborator

Per the tf.layers.batch_normalization documentation, the dependency on UPDATE_OPS is required for the BN statistics to be updated between training steps. So this appears to be the correct way to use hvd.SyncBatchNormalization, as it is a modified form of that layer.
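
For reference, a minimal sketch of that pattern (variable names and shapes here are illustrative, not taken from the code above):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

x = tf.placeholder(tf.float32, shape=[None, 32])
bn = hvd.SyncBatchNormalization()
y = bn.apply(x, training=True)
loss = tf.reduce_mean(tf.square(y))

opt = hvd.DistributedOptimizer(tf.train.AdamOptimizer(1e-3))

# The moving mean/variance updates are collected in UPDATE_OPS; without this
# control dependency the BN statistics are never refreshed between training steps.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_op = opt.minimize(loss)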

@weiminggao

weiminggao commented Aug 4, 2020

Thanks, but it does not work when training this way; it shows "Stalled ranks:". I am using TF 1.14. Can you test it the same way? @romerojosh
[screenshot of the "Stalled ranks" output]

@weiminggao

weiminggao commented Aug 4, 2020

I also tried deleting the UPDATE_OPS dependency, but the same problem happens during training. @romerojosh

def model(is_value = True):
       ...
        cards_feature = tf.layers.conv1d(inputs = cards_input1, filters = 256, kernel_size = 3, strides = 1, padding = 'same', use_bias = True, activation = tf.nn.relu)
        for i in range(20):
            cards_feature1 = tf.layers.conv1d(inputs = cards_feature, filters = 256, kernel_size = 3, strides = 1, padding = 'same', use_bias = True, activation = tf.nn.relu)
            cards_feature1 = tf.layers.conv1d(inputs = cards_feature1, filters = 256, kernel_size = 3, strides = 1, padding = 'same', use_bias = True)
            cards_feature = cards_feature + cards_feature1
            bn = hvd.SyncBatchNormalization()
            cards_feature = bn.apply(cards_feature, training = is_training)
            cards_feature = tf.nn.relu(cards_feature)
        cards_feature = tf.layers.conv1d(inputs = cards_feature, filters = 1, kernel_size = 1, strides = 1, padding = 'same', use_bias = True, activation = tf.nn.relu)
        _cards_feature = tf.reshape(cards_feature, shape = (-1, 34))
        cards_feature = tf.reshape(tf.tile(_cards_feature, [1, 22]), shape = (-1, 22, 34))
self.cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels = tf.reshape(self.input_dict['label'], shape = (-1, )), logits = self.output))  # cross entropy
self.mse = tf.reduce_mean(tf.square(self.value_predict - self.input_dict['value']))
self.loss = self.cross_entropy# + self.mse
self.opt = tf.train.AdamOptimizer(self.learning_rate * hvd.size())
self.opt = hvd.DistributedOptimizer(self.opt)
global_step = tf.train.get_or_create_global_step()
#with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
self.train_op = self.opt.minimize(self.loss, global_step)

Running script:

horovodrun -np 2 -H localhost:2 python train.py

@romerojosh
Collaborator

Hi @weiminggao,
I cannot reproduce the stall with either TF 1.14 or TF 1.15 using a variation of the model you provided with dummy input data. This is the script I tested (script.py):

import tensorflow as tf
import horovod.tensorflow as hvd
import numpy as np

# Horovod/TensorFlow setup
hvd.init()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())

# Dummy model
x_in = tf.placeholder(dtype = tf.float32, shape = [None, 34, 34])
x = tf.layers.conv1d(inputs = x_in, filters = 256, kernel_size = 3, strides = 1, padding = 'same', use_bias = True, activation = tf.nn.relu)
for i in range(20):
  x1 = tf.layers.conv1d(inputs = x, filters = 256, kernel_size = 3, strides = 1, padding = 'same', use_bias = True, activation = tf.nn.relu)
  x1 = tf.layers.conv1d(inputs = x1, filters = 256, kernel_size = 3, strides = 1, padding = 'same', use_bias = True)
  x = x + x1
  #bn = tf.layers.BatchNormalization()
  bn = hvd.SyncBatchNormalization()
  x = bn.apply(x, training=True)
  x = tf.nn.relu(x)
x = tf.layers.conv1d(inputs = x, filters = 1, kernel_size = 1, strides = 1, padding = 'same', use_bias = True, activation = tf.nn.relu)
_x = tf.reshape(x, shape = (-1, 34))
x = tf.reshape(tf.tile(_x, [1, 22]), shape = (-1, 22, 34))

# Optimizer and dummy loss
opt = tf.train.AdamOptimizer(learning_rate=0.001)
opt = hvd.DistributedOptimizer(opt)
loss = tf.reduce_mean(x)

with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
  train_op = opt.minimize(loss)


with tf.Session(config=config) as sess:
  sess.run(tf.global_variables_initializer())

  # Run 100 steps
  for i in range(100):
    print(f"step {i}")
    sess.run(train_op, feed_dict={x_in: np.random.rand(128,34,34)})

Running this script with horovodrun -np 2 python script.py does not stall and all 100 steps are completed with TF 1.14 and TF 1.15.

If you try this script, does it still stall? Also, in your original case with the stall, how long did you wait before cancelling the run? It is possible the stall message is just due to rank 0 taking more time to start up than the other ranks.

@weiminggao

Thanks very much, it works well now.

@fferroni

fferroni commented May 18, 2021

@tgaddair TensorFlow now mainly encourages the tf.keras APIs. There are tf.keras.layers.BatchNormalization and tf.keras.layers.experimental.SyncBatchNormalization. How does hvd.SyncBatchNormalization() fit in with respect to these? Should we expect batch statistics to be synchronized by Horovod if we use these Keras layers? Thank you.
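
To make the question concrete, here is the kind of usage I mean; a rough sketch, assuming hvd.SyncBatchNormalization can be dropped into a Keras model (the layer sizes and model here are made up for illustration):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# Hypothetical usage: replace the stock Keras BN layer with Horovod's layer so
# batch statistics are allreduced across ranks; the stock Keras layers presumably
# only synchronize within a tf.distribute strategy, not across Horovod ranks.
model = tf.keras.Sequential([
    tf.keras.layers.Conv1D(64, 3, padding='same', activation='relu', input_shape=(34, 34)),
    hvd.SyncBatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Dense(10),
])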
