
TF: Add register_local_var to distributed optimizers and gradient aggrega… #3695

Merged

merged 4 commits into horovod:master on Sep 19, 2022

Conversation

@MrAta (Contributor) commented Sep 9, 2022

Signed-off-by: Ata FatahiBaarzi afatahibaarzi@linkedin.com

Checklist before submitting

  • Did you read the contributor guide?
  • Did you update the docs?
  • Did you write any tests to validate this change?
  • Did you update the CHANGELOG, if this change affects users?

Description

This continues the work done in #3628 and #3643 and adds the same functionality to the distributed optimizers and local gradient aggregators. It is useful for model-parallel use cases that use distributed optimizers and want to skip syncing their "local" gradients.

An example usage is shown in the unit tests included in this PR; in short, it looks like the following:

    ...
    opt = hvd.DistributedOptimizer(opt)

    # Register worker local variables (i.e. local source)
    for var in model.trainable_variables:
      if <var is worker local>:
          opt.register_local_var(var)

    # Compute gradients. Any gradient associated with a var passed to register_local_var will not be modified by Horovod.
    gradients = opt.compute_gradients(loss, model.trainable_variables, tape)

If this change gets merged, then similar to #3643 we could add a new API called PartialDistributedOptimizer.
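
For a bit more context, here is a slightly fuller sketch modeled on the unit test added in this PR (the toy model, the random data, and the choice to treat the first layer as worker local are placeholders, not part of the PR):

    # Minimal sketch modeled on the unit test; the model and the "local" layer choice are placeholders.
    import tensorflow as tf
    import horovod.tensorflow.keras as hvd

    hvd.init()
    model = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,)),
                                 tf.keras.layers.Dense(1)])
    opt = hvd.DistributedOptimizer(tf.keras.optimizers.SGD(0.01), sparse_as_dense=True)

    # Treat the first layer as worker local and register its variables.
    for var in model.layers[0].trainable_weights:
        opt.register_local_var(var)

    features = tf.random.normal([16, 4])
    labels = tf.random.normal([16, 1])
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(features) - labels))

    # Gradients of registered vars pass through untouched; the rest are allreduced.
    grads_and_vars = opt._compute_gradients(loss, model.trainable_weights, tape=tape)
    opt.apply_gradients(grads_and_vars)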

@MrAta (Contributor, Author) commented Sep 9, 2022

cc @romerojosh @maxhgerlach: reopening this as a new PR since #3663 was reverted.

github-actions bot commented Sep 10, 2022

Unit Test Results

1 087 files (+38), 1 087 suites (+38), 10h 55m 11s duration (-10m 37s)
816 tests (+3): 758 passed (+3), 58 skipped (±0), 0 failed (±0)
21 530 runs (+938): 15 146 passed (+610), 6 384 skipped (+328), 0 failed (±0)

Results for commit d4e91b0. Comparison against base commit ab97fd1.

♻️ This comment has been updated with latest results.

github-actions bot commented Sep 10, 2022

Unit Test Results (with flaky tests)

1 263 files (+82), 1 263 suites (+82), 11h 39m 9s duration (+3m 40s)
816 tests (+3): 758 passed (+3), 58 skipped (±0), 0 failed (±0)
24 978 runs (+1 666): 17 064 passed (+946), 7 914 skipped (+720), 0 failed (±0)

Results for commit d4e91b0. Comparison against base commit ab97fd1.

♻️ This comment has been updated with latest results.

@maxhgerlach (Collaborator) commented

@MrAta, this seems to be failing on GPU consistently (potentially also causing a hang).

Failed GitHub pipeline:
https://github.com/horovod/horovod/runs/8282323717?check_suite_focus=true#step:5:7

Error message on Buildkite, which that pipeline leads to:
https://buildkite.com/horovod/horovod/builds/8318#01832608-aa43-415b-8e61-cd74e07c18aa/6-4588

@MrAta (Contributor, Author) commented Sep 13, 2022

Hi @maxhgerlach, can you please help with finding the test config as well as the stack traces for the failure?
I couldn't see any useful info in the log regarding the test failure.

@romerojosh (Collaborator) commented

@MrAta The CI reporting on GitHub is a bit odd, but if you click on the "Build and Test GPU (on Buildkite)" entry, you can eventually get to this page, which reports the GPU test results: https://buildkite.com/horovod/horovod/builds/8318

The test log here shows some useful stall information: https://buildkite.com/horovod/horovod/builds/8318#01832614-9be1-4c56-98ea-767272468e6e/6-4795

Looks like maybe one of the ranks in the test isn't broadcasting variables?

@romerojosh (Collaborator) commented

@MrAta I tried out the failing tests on a local GPU system, and it seems the problem stems from GPU device assignment. I was getting an error similar to one reported in an old Horovod issue with TF1 (#646). In your case, the issue seems to come from the manual device placement and the call to tf.test.is_gpu_available(cuda_only=True). At any rate, this device placement code isn't required, as the visible device list is already limited to the local_rank-indexed GPU in the __init__ method of the Tf2KerasTests class:

    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

I'm not sure I can push to the master branch of your repository that this PR is sourced from, so instead here is a git diff/patch that I used to get the tests to pass on my system:

diff --git a/test/parallel/test_tensorflow2_keras.py b/test/parallel/test_tensorflow2_keras.py
index 6c4575c..48142f5 100644
--- a/test/parallel/test_tensorflow2_keras.py
+++ b/test/parallel/test_tensorflow2_keras.py
@@ -402,22 +402,12 @@ class Tf2KerasTests(tf.test.TestCase):
                 return var.assign_add(grad)

         backward_passes_per_step = 4
-        local_rank = hvd.local_rank()
-        if tf.test.is_gpu_available(cuda_only=True):
-            with tf.device("/gpu:%d" % local_rank):
-                hvd_optimizer = hvd.DistributedOptimizer(
-                    optimizer=TestingOptimizer("test"),
-                    backward_passes_per_step=backward_passes_per_step,
-                    average_aggregated_gradients=average_aggregated_gradients,
-                    sparse_as_dense=True,
-                )
-        else:
-            hvd_optimizer = hvd.DistributedOptimizer(
-                    optimizer=TestingOptimizer("test"),
-                    backward_passes_per_step=backward_passes_per_step,
-                    average_aggregated_gradients=average_aggregated_gradients,
-                    sparse_as_dense=True,
-                )
+        hvd_optimizer = hvd.DistributedOptimizer(
+            optimizer=TestingOptimizer("test"),
+            backward_passes_per_step=backward_passes_per_step,
+            average_aggregated_gradients=average_aggregated_gradients,
+            sparse_as_dense=True,
+        )

         _ = hvd_optimizer.iterations

@@ -487,15 +477,9 @@ class Tf2KerasTests(tf.test.TestCase):
         def compute_and_apply_gradients_in_tf_function(var_list, **kwargs):
             # Compute and apply gradient updates in tf.function to reproduce
             # how it is done inside `model.fit()`.
-            if tf.test.is_gpu_available(cuda_only=True):
-                with tf.device("/gpu:%d" % local_rank):
-                    grads_and_vars = hvd_optimizer._compute_gradients(
-                        loss, var_list=var_list)
-                    hvd_optimizer.apply_gradients(grads_and_vars, **kwargs)
-            else:
-                grads_and_vars = hvd_optimizer._compute_gradients(
-                    loss, var_list=var_list)
-                hvd_optimizer.apply_gradients(grads_and_vars, **kwargs)
+            grads_and_vars = hvd_optimizer._compute_gradients(
+                loss, var_list=var_list)
+            hvd_optimizer.apply_gradients(grads_and_vars, **kwargs)

         total_num_of_steps = 10
         for idx in range(total_num_of_steps):
@@ -523,7 +507,6 @@ class Tf2KerasTests(tf.test.TestCase):

     def test_distributed_optimizer_with_local_vars(self):
         """ Note: test makes most sense with more than 1 nodes. """
-        hvd.init()
         if hvd.size() == 1:
             self.skipTest("Only one worker available")

@@ -557,20 +540,11 @@ class Tf2KerasTests(tf.test.TestCase):
             local_layers = model.layers[:num_local_layers]
             local_vars = [var for layer in local_layers for var in layer.trainable_weights]

-            local_rank = hvd.local_rank()
-            if tf.test.is_gpu_available(cuda_only=True):
-                with tf.device("/gpu:%d" % local_rank):
-                    opt = hvd.DistributedOptimizer(opt, sparse_as_dense=True)
-                    # register local vars to the opt
-                    for var in local_vars:
-                        opt.register_local_var(var)
-                    gradients_vars_opt = opt._compute_gradients(l, model.trainable_weights, tape=tape)
-            else:
-                opt = hvd.DistributedOptimizer(opt, sparse_as_dense=True)
-                # register local vars to the opt
-                for var in local_vars:
-                    opt.register_local_var(var)
-                gradients_vars_opt = opt._compute_gradients(l, model.trainable_weights, tape=tape)
+            opt = hvd.DistributedOptimizer(opt, sparse_as_dense=True)
+            # register local vars to the opt
+            for var in local_vars:
+                opt.register_local_var(var)
+            gradients_vars_opt = opt._compute_gradients(l, model.trainable_weights, tape=tape)

             var_grad_tape = {var.ref():grad for var,grad in zip(model.trainable_weights, gradients_tape)}
             var_grad_opt = {var.ref():grad for grad,var in gradients_vars_opt}

@MrAta (Contributor, Author) commented Sep 14, 2022

Thank you so much @romerojosh! Let me apply that patch and push.

@maxhgerlach (Collaborator) commented Sep 14, 2022

Thanks for the support, @romerojosh!

@MrAta, I resolved a minor conflict in CHANGELOG.md via the GitHub UI and that pushed a merge commit to your fork's master branch. Sorry if that causes any confusion... GitHub wouldn't kick off the CI pipelines as long as it saw a merge conflict.

@skyw commented Sep 14, 2022

Does it scale the gradients of local vars? Allreducing the data-parallel variables with a mean implies their gradients are divided by the allreduce size, so the gradients of local vars have to be scaled by the same size. Otherwise the gradients of the local vars and of the rest are technically calculated from different loss functions, which leads to a tiny accuracy loss that is very hard to spot.

@maxhgerlach (Collaborator) left a comment

Unit tests all pass now -- great work!

Josh already reviewed the changes in the earlier version of this PR, so I have very little to add; I really appreciate the effort of adding proper tests for the changes. One thing: #3700 just landed, and you should update some of the TF version checks accordingly after rebasing to master.

Review comments (outdated, resolved) on:
horovod/_keras/__init__.py
horovod/tensorflow/gradient_aggregation.py
horovod/tensorflow/gradient_aggregation_eager.py
CHANGELOG.md Outdated
@@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Added `transformation_edit_fields` and `transformation_removed_fields` param for EstimatorParams. ([#3651](https://github.com/horovod/horovod/pull/3651))
- Added `PartialDistributedGradientTape()` API for model parallel use cases. ([#3643](https://github.com/horovod/horovod/pull/3643))
- Enable use of native `ncclAvg` op for NCCL allreduces. ([#3646](https://github.com/horovod/horovod/pull/3646))
- Added `register_local_var` functionality to distributed optimizers and local gradient aggregators. ([3695](https://github.com/horovod/horovod/pull/3695))

Suggested change
- Added `register_local_var` functionality to distributed optimizers and local gradient aggregators. ([3695](https://github.com/horovod/horovod/pull/3695))
- TensorFlow: Added `register_local_var` functionality to distributed optimizers and local gradient aggregators. ([3695](https://github.com/horovod/horovod/pull/3695))

@maxhgerlach (Collaborator) commented

> Does it scale the gradients of local vars? Allreducing the data-parallel variables with a mean implies their gradients are divided by the allreduce size, so the gradients of local vars have to be scaled by the same size. Otherwise the gradients of the local vars and of the rest are technically calculated from different loss functions, which leads to a tiny accuracy loss that is very hard to spot.

I don't think that behavior is included with this PR, but I'm also not sure if that's really something that Horovod's DistributedOptimizer should do automatically or rather something that's better controlled explicitly in user code.
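
For illustration only, a minimal sketch of what that explicit user-side scaling could look like, continuing the snippet from the PR description (hypothetical; `local_vars` stands for the variables previously passed to `register_local_var`, and the division by `hvd.size()` is the scaling @skyw describes, not something this PR adds):

    # Hypothetical user-side rescaling of worker-local gradients (not part of this PR).
    local_refs = {v.ref() for v in local_vars}   # vars previously passed to opt.register_local_var
    grads_and_vars = opt._compute_gradients(loss, model.trainable_weights, tape=tape)
    rescaled = [(g / hvd.size() if v.ref() in local_refs and g is not None else g, v)
                for g, v in grads_and_vars]
    opt.apply_gradients(rescaled)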

@skyw commented Sep 14, 2022

> I don't think that behavior is included with this PR, but I'm also not sure if that's really something that Horovod's DistributedOptimizer should do automatically or rather something that's better controlled explicitly in user code.

I think it is fine to control it in user code with explicit use of a gradient tape. But the optimizer wraps gradient calculation, communication, and weight update, so combining it with another level of gradient-scale handling more or less defeats the purpose of using a distributed optimizer. Maybe add an option to the optimizer?

@MrAta (Contributor, Author) commented Sep 14, 2022

> Does it scale the gradients of local vars? Allreducing the data-parallel variables with a mean implies their gradients are divided by the allreduce size, so the gradients of local vars have to be scaled by the same size. Otherwise the gradients of the local vars and of the rest are technically calculated from different loss functions, which leads to a tiny accuracy loss that is very hard to spot.

I'm not sure I'm understanding your point correctly, but note that this is for model-parallel use cases, not data-parallel ones. In model-parallel use cases, each rank exclusively owns its local layers (their vars, and hence their gradients), which are not shared among ranks. Therefore, averaging the gradients of local layers defeats the purpose of model parallelism.

The example model in the unit test is probably not the best showcase, because all ranks deem the first (few) layers local. But in practice, in more real-world use cases (our models at LinkedIn, for example), models are multi-tower (multi-block), where each tower is local to only one rank.
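
For illustration, the multi-tower pattern could look roughly like this (hypothetical sketch; `build_tower` and `build_shared_top` are placeholders, not code from this PR):

    # Hypothetical multi-tower layout: each rank builds and registers only the tower it owns.
    import tensorflow as tf
    import horovod.tensorflow.keras as hvd

    hvd.init()
    local_tower = build_tower(hvd.rank())   # placeholder: one tower per rank
    shared_top = build_shared_top()         # placeholder: the data-parallel part of the model

    opt = hvd.DistributedOptimizer(tf.keras.optimizers.Adam(1e-3), sparse_as_dense=True)
    for var in local_tower.trainable_weights:
        opt.register_local_var(var)         # this tower's gradients are not allreduced
    # shared_top's variables stay unregistered, so their gradients are still averaged across ranks.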

@skyw commented Sep 14, 2022

> I'm not sure I'm understanding your point correctly, but note that this is for model-parallel use cases, not data-parallel ones. In model-parallel use cases, each rank exclusively owns its local layers (their vars, and hence their gradients), which are not shared among ranks. Therefore, averaging the gradients of local layers defeats the purpose of model parallelism.
>
> The example model in the unit test is probably not the best showcase, because all ranks deem the first (few) layers local. But in practice, in more real-world use cases (our models at LinkedIn, for example), models are multi-tower (multi-block), where each tower is local to only one rank.

I'm talking about the hybrid case, where part of the model is model parallel (e.g. different embedding tables on different ranks) and part is data parallel (e.g. an MLP sitting on top of the embeddings), with an alltoall between the model-parallel and data-parallel regions. In the data-parallel region, gradients are averaged by allreduce, so technically they are computed w.r.t. the average global loss. But when backpropagation reaches the model-parallel region, the activation gradient for the global batch is available on one rank, and the weight gradient computed from it is w.r.t. the sum of the per-rank losses, not the mean, hence the discrepancy. The model will still train but gets slightly worse accuracy. This only happens when backpropagating from the data-parallel region into the model-parallel region with an alltoall in between.

@MrAta (Contributor, Author) commented Sep 14, 2022

> I'm talking about the hybrid case, where part of the model is model parallel (e.g. different embedding tables on different ranks) and part is data parallel (e.g. an MLP sitting on top of the embeddings), with an alltoall between the model-parallel and data-parallel regions. In the data-parallel region, gradients are averaged by allreduce, so technically they are computed w.r.t. the average global loss. But when backpropagation reaches the model-parallel region, the activation gradient for the global batch is available on one rank, and the weight gradient computed from it is w.r.t. the sum of the per-rank losses, not the mean, hence the discrepancy. The model will still train but gets slightly worse accuracy. This only happens when backpropagating from the data-parallel region into the model-parallel region with an alltoall in between.

Suppose we have two ranks, each with one local embedding layer: E1 on rank 1 and E2 on rank 2.
The gradients of E1 exist ONLY on rank 1 (and are undefined on rank 2), and the gradients of E2 exist ONLY on rank 2 (and are undefined on rank 1).
If I understand your point correctly, we now need to divide the gradients of E1 and E2 by 2 (the world size). How does that help with the issue you're describing?

@skyw commented Sep 14, 2022

> Suppose we have two ranks, each with one local embedding layer: E1 on rank 1 and E2 on rank 2. The gradients of E1 exist ONLY on rank 1 (and are undefined on rank 2), and the gradients of E2 exist ONLY on rank 2 (and are undefined on rank 1). If I understand your point correctly, we now need to divide the gradients of E1 and E2 by 2 (the world size). How does that help with the issue you're describing?

After dividing the gradients of E1 and E2 by 2, all the gradients (embedding and MLP) are calculated w.r.t. the same loss function, so there is no issue anymore.
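
A small numeric sketch of that argument (two ranks, made-up numbers; illustration only, not code from this PR):

    # Two ranks with per-rank losses L1 and L2; the data-parallel training loss is L = (L1 + L2) / 2.
    world_size = 2
    g_mlp_rank1, g_mlp_rank2 = 0.4, 0.6  # dL1/dW and dL2/dW for a data-parallel MLP weight W
    g_e1_unscaled = 1.0                  # d(L1 + L2)/dE1, available only on rank 1 after the alltoall

    g_mlp = (g_mlp_rank1 + g_mlp_rank2) / world_size  # allreduce with mean -> dL/dW = 0.5
    g_e1 = g_e1_unscaled / world_size                 # proposed rescaling  -> dL/dE1 = 0.5
    # Both gradients now refer to the same loss L; without the division, g_e1 would be 2x too large.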

@maxhgerlach (Collaborator) left a comment

LGTM!

If we feel an extra option to scale gradients is warranted, we can always add it in a follow-up PR.

The latest docs build and the latest ppc64le Jenkins build failed, but both look like flakiness / external conditions (an earlier build of each succeeded for this PR: (1) docs, (2) Jenkins).

@MrAta (Contributor, Author) commented Sep 16, 2022

> After dividing the gradients of E1 and E2 by 2, all the gradients (embedding and MLP) are calculated w.r.t. the same loss function, so there is no issue anymore.

@skyw I'm not fully convinced by that. Can you please create a discussion thread so that others can chime in as well?

@maxhgerlach merged commit 4f723bb into horovod:master on Sep 19, 2022
@skyw commented Sep 19, 2022

> @skyw I'm not fully convinced by that. Can you please create a discussion thread so that others can chime in as well?

Will do. I'll probably create a simple reproduction case.
