Horovod GradientTape performance for Tensorflow #1177

Closed
aj-prime opened this issue Jun 28, 2019 · 12 comments · Fixed by #1193

aj-prime commented Jun 28, 2019

Environment:

  1. Framework: TensorFlow
  2. Framework version: 1.12
  3. Horovod version: 0.16.4
  4. MPI version: MVAPICH2-GDR 2.3.1
  5. CUDA version: 9.2
  6. NCCL version:
  7. Python version: 3.6
  8. OS and version: Red Hat Enterprise Linux Server 7.5 (Maipo)
  9. GCC version: xl/2019.02.07
  10. Architecture: Power9, V100 GPU

Your question:
I have modified the tensorflow_mnist_eager.py script in the examples to print images per second (a sketch of one way to measure this is shown below).
Batch size: 32

GPUs   Perf (images/sec)
1      2216
2      425
4      640

It looks like there is an initial overhead from distributing the DNN training.
Is this the expected behavior?
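
A minimal sketch of such an images/sec measurement, assuming the dataset and training_step objects from the eager MNIST example (the names and batch count here are illustrative, not the reporter's exact modification):

import time

# Rough throughput measurement: `dataset` is assumed to yield (images, labels)
# batches of size 32, and `training_step` to run one optimization step.
batch_size = 32
num_batches = 100

start = time.time()
for batch, (images, labels) in enumerate(dataset.take(num_batches)):
    training_step(images, labels, batch == 0)
elapsed = time.time() - start

print('images/sec per GPU: %.1f' % (num_batches * batch_size / elapsed))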

alsrgv added the bug label and removed the question label on Jun 28, 2019
alsrgv (Member) commented Jun 28, 2019

This is not expected. @tgaddair, could you take a look?

alsrgv (Member) commented Jun 28, 2019

@aj-prime, by the way, any reason you're not using NCCL?

aj-prime (Author) commented

@alsrgv No specific reason

alsrgv (Member) commented Jun 28, 2019

@aj-prime, NCCL will give you better performance on GPU compared to MPI. That said, I do see a slowdown of eager TF compared to regular graph-mode TF even with NCCL in my environment.

aj-prime (Author) commented

@alsrgv I am trying to set up Horovod with NCCL. Is the slowdown as severe as with MV2-GDR?

alsrgv (Member) commented Jun 28, 2019

@aj-prime, I have not tried MV2-GDR, so I am not sure. What kind of performance are you seeing with the graph-mode TensorFlow MNIST example?

aj-prime (Author) commented

@alsrgv For the tensorflow_mnist.py script with batch size 32, I am getting the following numbers (images/sec):
1 GPU:  11292
2 GPUs: 12300
4 GPUs: 20080

alsrgv (Member) commented Jun 28, 2019

OK, that's much better. We'll look into the eager mode performance.

tgaddair (Collaborator) commented

Hey @aj-prime, can you try running tensorflow_synthetic_benchmark.py with --eager and without? That uses ResNet50, which might provide more interesting data.

aj-prime (Author) commented

Hello @tgaddair,

I ran tensorflow_synthetic_benchmark.py in both graph and eager mode. Here are the results (images/sec):

Graph Mode
1 GPU:  336.7 ± 0.3
2 GPUs: 574.3 ± 3.2
4 GPUs: 1090.5 ± 13.1

Eager Mode
1 GPU:  257.1 ± 0.6
2 GPUs: 249.5 ± 11.5
4 GPUs: 447.2 ± 56.4

tgaddair (Collaborator) commented Jul 3, 2019

Those numbers look a lot better, though still not great (but there are some significant performance penalties to eager execution in TensorFlow at present). One thing we do in the synthetic benchmarks but not in the MNIST example is device placement: with tf.device(device):.

Without device placement, allreduce happens on CPU, which can slow things down considerably. Can you try adding with tf.device('GPU'): to your MNIST benchmark and see if that makes any difference?
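
For illustration, a sketch of what that change might look like in an eager MNIST training step (mnist_model, loss, and opt are assumed to be defined as in the example; tf.device('GPU') follows the suggestion above):

import tensorflow as tf
import horovod.tensorflow as hvd

def training_step(images, labels):
    # Pin the forward pass, gradient computation, and Horovod allreduce to the
    # GPU so the allreduce does not fall back to the CPU.
    with tf.device('GPU'):
        with tf.GradientTape() as tape:
            logits = mnist_model(images, training=True)
            loss_value = loss(labels, logits)

        tape = hvd.DistributedGradientTape(tape)
        grads = tape.gradient(loss_value, mnist_model.trainable_variables)
        opt.apply_gradients(zip(grads, mnist_model.trainable_variables))
    return loss_value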

It could also be simply due to the fact that ResNet50 is a more complex model, so more of the time will be spent in computation vs communication.

I'll see if I can repro on our end.

alsrgv (Member) commented Jul 5, 2019

@aj-prime, we figured out why tensorflow_mnist_eager.py had such bad performance: we were recompiling the tf.function() that contains the allreduce subgraph. This is getting fixed in #1193.

There is another issue in eager mode, though. In graph mode, TensorFlow can start reducing gradients for layers close to the loss while the rest of the gradients are still being computed, which also ensures proper ordering of the allreduce operations. In eager mode, allreduce starts only after all the gradients have been computed, which causes an additional delay and randomizes the ordering of gradient reductions.

Because of that, it's recommended to wrap the whole training step in @tf.function, like this:

@tf.function
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        logits = mnist_model(images, training=True)
        loss_value = loss(labels, logits)

    # Horovod: add Horovod Distributed GradientTape.
    tape = hvd.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    if first_batch:
        hvd.broadcast_variables(mnist_model.variables, root_rank=0)
        hvd.broadcast_variables(opt.variables(), root_rank=0)

    return loss_value
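
For completeness, a driver loop for this training_step might look roughly like the following (a sketch only; the dataset name, step count, and logging are assumptions based on the Horovod MNIST examples, not part of the original comment):

# Illustrative training loop: the first batch triggers the one-time broadcast
# of variables from rank 0 inside training_step.
for batch, (images, labels) in enumerate(dataset.take(10000 // hvd.size())):
    loss_value = training_step(images, labels, batch == 0)

    if batch % 10 == 0 and hvd.local_rank() == 0:
        print('Step #%d\tLoss: %.6f' % (batch, loss_value))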
