
Horovod behavior with dataset API #223

Closed

benyoti opened this issue Mar 27, 2018 · 10 comments


benyoti commented Mar 27, 2018

Hi,
I was doing some tests with the TensorFlow dataset API (tf.data.Dataset) on a single machine with multiple GPUs, but it looks like Horovod is sending the same data to each GPU every time the dataset's iterator is called (I am not using MonitoredTrainingSession but a standard tf.Session()).
Is this behavior intended? Do you have any idea how to overcome this problem?

alsrgv (Member) commented Mar 28, 2018

@benyoti, that's very interesting. Are you manually setting a random seed? Are you using shuffling?

benyoti (Author) commented Mar 28, 2018

I am not doing any shuffling or setting a random seed, in order to see what's going on. I wrote a basic script using the Dataset API and AlexNet, and incorporated Horovod following examples/tensorflow_word2vec.py.

If I use one GPU:

$ CUDA_VISIBLE_DEVICES=3 mpirun -np 1 -H localhost:1 -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib python test_multigpus.py
Building graph...
(...)
2018-03-28 18:24:07.886129: I tensorflow/core/common_runtime/gpu/gpu_device.cc:993] Creating TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 11366 MB memory) -> physical GPU (device: 0, name: TITAN Xp, pci bus id: 0000:b3:00.0, compute capability: 6.1)
Batch 1 - gpu 0
[0 0 0 0]
Batch 2 - gpu 0
[1 1 1 1]
Batch 3 - gpu 0
[2 2 2 2]
End of iterator/epoch 1
Batch 4 - gpu 0
[0 0 0 0]
Batch 5 - gpu 0
[1 1 1 1]
Batch 6 - gpu 0
[2 2 2 2]
End of iterator/epoch 2
Done training for 2 epochs, 6 batches.

If I use 2 GPUs:

$ CUDA_VISIBLE_DEVICES=2,3 mpirun -np 2 -H localhost:2 -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib python test_multigpus.py
Building graph...
Building graph...
(...)
2018-03-28 18:24:37.228971: I tensorflow/core/common_runtime/gpu/gpu_device.cc:993] Creating TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 11358 MB memory) -> physical GPU (device: 0, name: TITAN Xp, pci bus id: 0000:65:00.0, compute capability: 6.1)
2018-03-28 18:24:37.238365: I tensorflow/core/common_runtime/gpu/gpu_device.cc:993] Creating TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 11359 MB memory) -> physical GPU (device: 1, name: TITAN Xp, pci bus id: 0000:b3:00.0, compute capability: 6.1)
Batch 1 - gpu 0
[0 0 0 0]
Batch 1 - gpu 1
[0 0 0 0]
Batch 2 - gpu 0
[1 1 1 1]
Batch 2 - gpu 1
[1 1 1 1]
Batch 3 - gpu 0
[2 2 2 2]
Batch 3 - gpu 1
[2 2 2 2]
End of iterator/epoch 1
End of iterator/epoch 1
Batch 4 - gpu 0
[0 0 0 0]
Batch 4 - gpu 1
[0 0 0 0]
Batch 5 - gpu 0
[1 1 1 1]
Batch 5 - gpu 1
[1 1 1 1]
Batch 6 - gpu 0
[2 2 2 2]
Batch 6 - gpu 1
[2 2 2 2]
End of iterator/epoch 2
Done training for 2 epochs, 6 batches

What is shown are the label batches returned by sess.run(). I have 3 classes with 4 samples per class (12 samples in total), a batch size of 4, and 2 epochs.

Ideally, with 2 GPUs you would get something like:

Batch 1 - gpu 0
[0 0 0 0]
Batch 1 - gpu 1
[1 1 1 1]
Batch 2 - gpu 0
[2 2 2 2]
End of iterator/epoch 1

This is my training loop, in case it helps:

with tf.Session(graph=g, config=config) as sess:
    sess.run(init_op)
    sess.run(bcast)  # Broadcast initial variable states from rank 0
    batch = 1
    for epoch in range(1, N_EPOCHS+1):
        sess.run(train_iter_init_op)  # Initialise train iterator
        while True:
            try:
                img, lbl, _ = sess.run([images, labels, train_op], feed_dict={is_training: True})
                print('Batch {} - gpu {}'.format(batch, hvd.rank()))
                print(lbl)

                for k in range(img.shape[0]):
                    imsave('imgs/img_epoch' + str(epoch) + '_batch' + str(batch) + '_gpu' + str(hvd.rank()+1) + '_img' + str(k+1) + '.png', img[k, :, :, :])

                time.sleep(1)
                batch += 1
            except tf.errors.OutOfRangeError:
                print("End of iterator/epoch {}".format(epoch))
                break

    print('Done training for {} epochs, {} batches.'.format(N_EPOCHS, batch-1))

@alsrgv
Copy link
Member

alsrgv commented Mar 28, 2018

Can you add shuffling to your dataset? That should help different processes read different data. It also helps reduce overfitting.
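
For instance, a one-line sketch (the dataset variable d and the buffer size are illustrative):

d = d.shuffle(buffer_size=1000, seed=hvd.rank())  # per-rank seed so each process draws a different order

If no seed is set, each process should pick its own random one; seeding on hvd.rank() just makes the per-process divergence explicit.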


ghost commented Mar 31, 2018

@alsrgv @benyoti I have the same problem. In my tests I have found that if, given my batch size, one epoch should finish in 200 steps, then when training on 8 GPUs one epoch instead finishes in 1600 steps. This corresponds exactly to the issue reported here.

@alsrgv I do not see any reason why shuffling the dataset would improve things. This is not a problem of dataset shuffling; instead it has to do with how the get_next() method of tf.data.Iterator behaves when each Horovod process runs its own copy of the pipeline.
BTW: I added shuffling to my dataset and it did not work. If you'd like to check, I can share my repository with you.
I am trying to look into ways to correct this problem, and I hope other contributors look into it as well.


ghost commented Apr 1, 2018

@alsrgv @benyoti The problem can be solved by using the shard() method of the tf.data.Dataset class.

alsrgv (Member) commented Apr 2, 2018

@calledbymountains, thanks for sharing. Indeed, per https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard, you can do:

d = d.shard(hvd.size(), hvd.rank())
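
For reference, a minimal sketch of how this fits into an input pipeline (TF 1.x API as used in this thread; the dataset construction, features, labels, and BATCH_SIZE are placeholders, not from the original script):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# Shard before shuffling/batching so each rank reads a disjoint
# 1/hvd.size() slice of the samples.
d = tf.data.Dataset.from_tensor_slices((features, labels))
d = d.shard(hvd.size(), hvd.rank())
d = d.shuffle(buffer_size=1000)
d = d.batch(BATCH_SIZE)
iterator = d.make_initializable_iterator()
images, labels = iterator.get_next()
train_iter_init_op = iterator.initializer  # matches the training loop above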

benyoti (Author) commented Apr 3, 2018

Great, thanks! I can confirm this is working as intended :)

benyoti closed this as completed Apr 3, 2018
@gururao001

So will the same behavior hold true in a multi-node setting as well? Say there are 2 servers with 2 GPUs each: each server will shard the dataset based on rank, but both servers will get the same initial dataset to shard, right?

server 1
Batch 1 - gpu 0
[0 0 0 0]
Batch 1 - gpu 1
[1 1 1 1]

server 2
Batch 1 - gpu 0
[0 0 0 0]
Batch 1 - gpu 1
[1 1 1 1]

How can I modify this for the multi-node setting, so that server 2 gets:
Batch 1 - gpu 0
[3 3 3 3]
Batch 1 - gpu 1
[4 4 4 4]

alsrgv (Member) commented Apr 13, 2018

@gururao001, I believe your desired outcome is exactly what the Dataset API will do for you. hvd.size() will be 4, and hvd.rank() will be in [0..3].
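
For example, launching along the lines of the single-node commands above (hostnames server1 and server2 are hypothetical) gives hvd.size() == 4, so the same d.shard(hvd.size(), hvd.rank()) call splits the dataset four ways across both servers:

$ mpirun -np 4 -H server1:2,server2:2 -bind-to none -map-by slot \
    -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH \
    -mca pml ob1 -mca btl ^openib python test_multigpus.py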


albertz commented May 23, 2020

Note that this pipeline with shard is somewhat inefficient: you load and preprocess all the data in every instance, and then throw away (N-1)/N of it.

It would be much more efficient to load (and preprocess) the data only once (e.g. in instance 0, or in some external instance), and then use something like MultiDeviceIterator (or a similar equivalent). But I'm not sure if there is something equivalent already for Horovod. Is there?

I also posted this question on StackOverflow.
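
One common mitigation, short of a MultiDeviceIterator-style solution, is to shard at the file level before any decoding or preprocessing, so each rank only reads its own files. A sketch, assuming TFRecord inputs (the file pattern, parse_fn, and BATCH_SIZE are placeholders):

files = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
files = files.shard(hvd.size(), hvd.rank())  # each rank opens 1/hvd.size() of the files
d = files.interleave(tf.data.TFRecordDataset, cycle_length=4)
d = d.map(parse_fn, num_parallel_calls=4)
d = d.batch(BATCH_SIZE)

Each rank still runs its own iterator, but the expensive I/O and preprocessing are no longer duplicated and then thrown away.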
