New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
XLA HorovodAllreduce for tf.function(jit_compile=True) #3053
XLA HorovodAllreduce for tf.function(jit_compile=True) #3053
Conversation
@romerojosh FYI. |
a57296c
to
41d9872
Compare
The build failed because the XLA header files are missing in the container?
Could someone give me some pointers about how I could fix it? Thanks! |
From the build log it looks like only I guess what headers are put into tensorflow_core would be controlled by https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/core/BUILD, but I also get the impression that the tensorflow_core package disappeared with version 2.2. When I build locally with TF 2.4, So maybe it would suffice to limit building the XLA support to TF 2.2+? |
Thanks for taking a look at it. Right, I built locally with TF2.5 and it includes only BTW, I committed some changes into XLA to make its schedule work better with the Horovod changes in this PR for performance. So, this PR depends on TF2.5+ anyway. Let me see if I can find a way to skip building the new codes if TF version is too old. |
@@ -153,7 +154,7 @@ namespace common { | |||
#define JOIN_TENSOR_NAME "join.noname" | |||
|
|||
// List of supported frameworks. | |||
enum Framework { TENSORFLOW, PYTORCH, MXNET }; | |||
enum Framework { TENSORFLOW, PYTORCH, MXNET, XLA }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that XLA is implemented only for tensorflow, why XLA declare here as a framework ? So does the design of cpp namespace level. Thx for your explain in advance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the question. My rationale is explained as below:
- XLA is independent of Tensorflow, although they are in the same repo now. There are talks about separating XLA from Tensorflow although I don't know when that will happen. In addition, XLA (as a DL compiler) is used by frontends other than Tensorflow, such as JAX and Pytorch (not in its main repo, I guess), etc.
- I certainly can understand that it is inaccurate to say XLA is a framework. However, as XLA is independent of Tensorflow, it does not use Tensorflow constructs and cannot use TFOpContext. It then needs to create a new XLAOpContext with a new "framework" name. I don't see alternatives.
- If the name is really confusing to people, I may suggest rename "framework" to "backend".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your reply, you are right that XLA is a layer independent of tf or even torch. XLA is a layer tf/torch can leverage on in my opinion, that's why I propose the question. I'm not sure am I right with following consideration,
- XLA can not be used/run without a framework(tf/torch/mxnet)
- XLA will be run in the way tf+XLA (the version you are working on), or torch+XLA, or more.
If so, may be we should keep framework declaration and consider to implement XLA without changing it even by add backend ?
Overall, your work is wonderful, it's ok to keep this version if we cannot figure out a better one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right that XLA itself is not complete as a framework.
How about let's rename XLA
to TF_XLA
?
That is, enum Framework { TENSORFLOW, PYTORCH, MXNET, TF_XLA };
This is conceptually clean and simple. The only theoretical drawback I can think of for this approach is that you might later have PYTORCH_XLA, etc. due to combination (of framework and XLA) but I doubt that if this will really become a problem in practice (mainly due to few combinations that really exist in the world). Even if this becomes a problem in practice in the future, we will know better how to deal with it at that time given more data points.
I will make the change accordingly if this makes sense to you. Please let me know.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, we have clarity on the issues but not achieve a solution yet, maybe we should discuss it with the community.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Responded. Let me know if this addresses your comments. Thanks!
@@ -153,7 +154,7 @@ namespace common { | |||
#define JOIN_TENSOR_NAME "join.noname" | |||
|
|||
// List of supported frameworks. | |||
enum Framework { TENSORFLOW, PYTORCH, MXNET }; | |||
enum Framework { TENSORFLOW, PYTORCH, MXNET, XLA }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the question. My rationale is explained as below:
- XLA is independent of Tensorflow, although they are in the same repo now. There are talks about separating XLA from Tensorflow although I don't know when that will happen. In addition, XLA (as a DL compiler) is used by frontends other than Tensorflow, such as JAX and Pytorch (not in its main repo, I guess), etc.
- I certainly can understand that it is inaccurate to say XLA is a framework. However, as XLA is independent of Tensorflow, it does not use Tensorflow constructs and cannot use TFOpContext. It then needs to create a new XLAOpContext with a new "framework" name. I don't see alternatives.
- If the name is really confusing to people, I may suggest rename "framework" to "backend".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Responded. I will act according to your response.
@@ -153,7 +154,7 @@ namespace common { | |||
#define JOIN_TENSOR_NAME "join.noname" | |||
|
|||
// List of supported frameworks. | |||
enum Framework { TENSORFLOW, PYTORCH, MXNET }; | |||
enum Framework { TENSORFLOW, PYTORCH, MXNET, XLA }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right that XLA itself is not complete as a framework.
How about let's rename XLA
to TF_XLA
?
That is, enum Framework { TENSORFLOW, PYTORCH, MXNET, TF_XLA };
This is conceptually clean and simple. The only theoretical drawback I can think of for this approach is that you might later have PYTORCH_XLA, etc. due to combination (of framework and XLA) but I doubt that if this will really become a problem in practice (mainly due to few combinations that really exist in the world). Even if this becomes a problem in practice in the future, we will know better how to deal with it at that time given more data points.
I will make the change accordingly if this makes sense to you. Please let me know.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution @trentlo! These XLA updates look quite nice, and it will be good to see the implementation extended to other operations in the future.
Given that the feature is opt-in via HOROVOD_ENABLE_XLA_OPS
, seems safe to merge this initial implementation in to me.
Thanks for taking a look at it, @romerojosh. I do have a plan to add some more Horovod ops for XLA. :-) |
Unit Test Results 385 files ±0 385 suites ±0 4h 45m 0s ⏱️ ±0s For more details on these failures, see this check. Results for commit f4d519c. ± Comparison against base commit f4d519c. ♻️ This comment has been updated with latest results. |
72df0ae
to
0b7e04d
Compare
There was 1 build failure because Some details: It did not make into TF2.5, i.e., no such construct. I was confused by NV's TF source code because I probably back ported the feature into NV's TF2.5. |
@romerojosh, looks like I need your help to hit the CI once again. There was another linking failure due to missing A guard is introduced in cmake to fix it as in 1a031b3. |
1a031b3
to
1ae617d
Compare
It depends on TF2.6 because of the new CustomCallSchedule to give scheduling hints to HLOs, which is essentail to performance when lowering HorovodAllreduce into XLA. Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
XLA may have problem dealing with it. Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
to latencies. Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Also, prefix with EARLIEST and LATEST with SCHEDULE_. Signed-off-by: Trent Lo <trentl@nvidia.com>
Signed-off-by: Trent Lo <trentl@nvidia.com>
Failing test is unrelated. This PR looks good to me. Thanks for the contribution @trentlo! |
Unit Test Results (with flaky tests) 449 files ±0 449 suites ±0 6h 36m 53s ⏱️ ±0s For more details on these failures, see this check. Results for commit f4d519c. ± Comparison against base commit f4d519c. ♻️ This comment has been updated with latest results. |
@cheshire FYI. XLA |
list(GET Tensorflow_LIBRARIES_LIST 0 Tensorflow_LIB_PATH) | ||
if (Tensorflow_VERSION VERSION_GREATER_EQUAL "2.6") | ||
# XLA implementations are in _pywrap_tensorflow_internal.so | ||
set(Tensorflow_LIBRARIES "${Tensorflow_LIBRARIES} ${Tensorflow_LIB_PATH}/python/ -l:_pywrap_tensorflow_internal.so") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@trentlo @romerojosh
This _pywrap_tensorflow_internal.so
is not available on OSX
The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR horovod#3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR horovod#3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR horovod#3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
Great work! @trentlo is there a minimum TF version required? |
TF2.6+ is required. |
The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR horovod#3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR horovod#3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR horovod#3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
* Update comment in FindTensorflow.cmake Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Add in-place broadcast_() and broadcast_variables() for TF Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Include source files from TF in build to avoid missing symbol errors Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Limit build and test to TF 2.6+ Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Remove source files copied from TensorFlow The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR #3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Fix possible type attribute values for HorovodBroadcastInplace Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Add reference variables to test Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Update comments, doc strings, changelog Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
…horovod#3173) Spark/Lightning: fix the usage of checkpoint callback (horovod#3186) Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Fix Cometlogger experiment key lost issue (horovod#3184) * test Signed-off-by: Peng Zhang <pengz@uber.com> * test Signed-off-by: Peng Zhang <pengz@uber.com> * fix_logger Signed-off-by: Peng Zhang <pengz@uber.com> * fix_logger Signed-off-by: Peng Zhang <pengz@uber.com> * recreate_loger Signed-off-by: Peng Zhang <pengz@uber.com> * fix_var Signed-off-by: Peng Zhang <pengz@uber.com> * test Signed-off-by: Peng Zhang <pengz@uber.com> * test Signed-off-by: Peng Zhang <pengz@uber.com> Updated torch c++ to use new aten api (horovod#3175) Spark/Keras: remove bare Keras support (horovod#3191) Make fork PRs publish test change stats (horovod#3185) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Support for nccl on cuda 11.4 (horovod#3182) Signed-off-by: Evan Brossard <evanb@maka-ars.com> Fix MPICH support (horovod#3148) * fix MPICH implementation * enable tests for MPICH and Intel MPI Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Increase build timeout to 40m on Buildkite (horovod#3192) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Change CMake syntax to be compatible with old versions of CMake (horovod#3196) Signed-off-by: Max H. Gerlach <git@maxgerlach.de> Reinit every torch test (horovod#3194) Add barrier call to torch module to support easy synchronization for process sets (horovod#3139) * Added barrier call to torch module Signed-off-by: TJ <tix@uber.com> Bump version to 0.23.0 (horovod#3200) Signed-off-by: Travis Addair <tgaddair@gmail.com> Co-authored-by: Max H. Gerlach <git@maxgerlach.de> Increase Parallel PyTest timeout to 10m (horovod#3198) * Increase MPI and Gloo Parallel PyTest timeout to 10m Signed-off-by: Enrico Minack <github@enrico.minack.dev> Spark/Lightning: don't overwrite model with checkpoint by default (horovod#3201) Lightning estimator saves model by default if there is no specified checkpoint callback. However, model is not overwritten with checkpoint file in that case. Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Spark/Lightning: fix checkpoint callback dirpath typo (horovod#3204) Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Rework events in CI workflows (horovod#3202) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Allow for concurrent schedule and master build, document concurrency (horovod#3206) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Ray: fix RayExecutor to fail when num_workers=0 and num_hosts=None (horovod#3210) Signed-off-by: Travis Addair <tgaddair@gmail.com> add_history_in_lightning_estimator (horovod#3214) Signed-off-by: Peng Zhang <pengz@uber.com> Allow buildkite building merge commits on forks (horovod#3215) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Fix json output in ci-results.yaml (horovod#3217) Spark/Lightning: fix history metrics for estimator serialization (horovod#3216) Save metrics inside the checkpoint dict , which will be load with map_location=torch.device('cpu') Signed-off-by: Peng Zhang <pengz@uber.com> patch python source files on macCI (horovod#3220) * patch python source files on macCI * Trigger build and test CI Signed-off-by: TJ <tix@uber.com> Co-authored-by: Enrico Minack <github@enrico.minack.dev> Updated examples of torch and tf to include mixed precision training (horovod#3222) * Added mixed precision example for pytorch * added mixed precision for keras Signed-off-by: TJ <tix@uber.com> Job buildkite-heads accesses ci-workflow outputs, add it to the needs (horovod#3225) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Fixes race condition for ray scale up down tests (horovod#3205) Ensure that at least one host from the previous set of hosts have been registered. Without this, the discovery script will "discover" the new set of hosts before the current set can register. This would result in a race condition. Consider a discovery schedule: ``` discovery_schedule = [ (10, ['host-1:2']), (30, ['host-1:2', 'host-2:1', 'host-3:1']), (None, ['host-2:1']), ] ``` The initial set is: ['host-1:2']. Before this is registered in the driver, the discovery script discovers the set: ['host-1:2', 'host-2:1', 'host-3:1'], and adds ['host-2:1', 'host-3:1']. However, since ['host-1:2'] has not registered, there is no coordinator to notify the workers. When host-1 and host-3 are removed, driver.resume will call _activate_workers, which will update the host assignments. It has a check to see if the intersection between the previous and current set of hosts. It finds that the previous set is ['host-1:2'], and the current set is ['host-2:1'], since there was no notification for the added and removed hosts. This ensures that the previous set of hosts can register before the current set is discovered. Signed-off-by: Abin Shahab <ashahab@linkedin.com> Removed a case of the default mutable argument pitfall (horovod#3227) Signed-off-by: Naelson Douglas <naelson17@gmail.com> Updates to TSC members (horovod#3234) Signed-off-by: Travis Addair <tgaddair@gmail.com> Add in-place broadcast for TensorFlow (horovod#3128) * Update comment in FindTensorflow.cmake Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Add in-place broadcast_() and broadcast_variables() for TF Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Include source files from TF in build to avoid missing symbol errors Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Limit build and test to TF 2.6+ Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Remove source files copied from TensorFlow The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR horovod#3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Fix possible type attribute values for HorovodBroadcastInplace Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Add reference variables to test Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Update comments, doc strings, changelog Signed-off-by: Max H. Gerlach <git@maxgerlach.de> [Elastic Horovod] Fix the bug for ElasticSampler and hvd.elastic.state (horovod#3144) Co-authored-by: gethinhu <gethinhu@tencent.com> a better way to handle nccl error under elastic scenario (horovod#3112) Signed-off-by: guoze.lin <guozelin@tencent.com> check torch version for mixed precision example (horovod#3238) Lightning: set limit_train_batches and limit_val_batches (horovod#3237) Tell Lightning trainer that how many batches a single epoch needs. Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Spark/Lightning: reduce memory footprint of async dataloader (horovod#3239) Limit async data loader queue size. Signed-off-by: Peng Zhang <pengz@uber.com> Change default fusion threshold from 64MB to 128MB in docs (horovod#3241) fix the example of pytorch_lightning_mnist.py (horovod#3245) - remove unused arg parameters - fix model test issue on GPU Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> CI: use latest pytorch_lightning with torchhead (horovod#3243) test_gradient_aggregation with real gradient instead of a constant (horovod#3176) This fixes issue horovod#2664 by performing gradient aggregation with a real gradient instead of a constant. PR: horovod#2647 shifts the gradient allreduce when the gradient is computed (both through the DistributedOptimizer or through the DistributedGradientTape). Which means that this unittest, by design in TF2.4, doesn't call allreduce in _aggregate_gradients(). Since this unittest provide a gradient as constant (without effectively computing it), the gradient will never be allreduced. The current change ensure that instead of a constant a real gradient is computed from a loss-function. Note: The current loss-function intentionally evaluates to zero. A future PR should convert it to a real loss function(e.g. MeanSquaredError) and compute gradients from that to test gradient aggregation. Signed-off-by: Abin Shahab <ashahab@linkedin.com>
* Update comment in FindTensorflow.cmake Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Add in-place broadcast_() and broadcast_variables() for TF Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Include source files from TF in build to avoid missing symbol errors Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Limit build and test to TF 2.6+ Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Remove source files copied from TensorFlow The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR horovod#3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Fix possible type attribute values for HorovodBroadcastInplace Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Add reference variables to test Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Update comments, doc strings, changelog Signed-off-by: Max H. Gerlach <git@maxgerlach.de> Signed-off-by: weihanmines <weihan13@amd.com>
- Fixes issue when start_epoch != 0 Signed-off-by: Dinesh Ramasamy <89654805+iitmdinesh@users.noreply.github.com> Signed-off-by: weihanmines <weihan13@amd.com> fix torch op handles lazy release which may cause oom in elastic scenario (horovod#3110) * fix torch op handles lazy release which may cause oom in elastic scenario Signed-off-by: guoze.lin <guozelin@tencent.com> * Update mpi_ops.py Co-authored-by: guoze.lin <guozelin@tencent.com> Co-authored-by: Travis Addair <tgaddair@gmail.com> Signed-off-by: weihanmines <weihan13@amd.com> Added support for extraction of storage options from url. (horovod#3137) * Added support for extraction of storage options from url. Signed-off-by: Manjur Ansari <maansar@microsoft.com> * mock fsspec.utils Signed-off-by: Manjur Ansari <maansar@microsoft.com> * Added missing comma Co-authored-by: Travis Addair <tgaddair@gmail.com> Signed-off-by: weihanmines <weihan13@amd.com> Make RayExecutor use the current placement group if one exists (horovod#3134) Signed-off-by: weihanmines <weihan13@amd.com> Fix the mapping btw pyspark and numpy (horovod#3146) Signed-off-by: Haoyang Chen <haoyang@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Add tests for Keras callbacks: MetricAverageCallback, LearningRateScheduleCallback and LearningRateWarmupCallback (horovod#3102) There were no tests for MetricAverageCallback, LearningRateScheduleCallback and LearningRateWarmupCallback from hvd as noted in horovod#2659. This PR adds testing to verify the callback works. Signed-off-by: Moses Lee <14leeyuchieh@gmail.com> Co-authored-by: Moses Lee <molee@molee-ld4.linkedin.biz> Signed-off-by: weihanmines <weihan13@amd.com> Split gpu tests in head and non-head versions (horovod#3155) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Allow caller to customize the Tensorboard callback (horovod#3153) * Keras Estimator: Allow user to pass in TensorBoard callback Signed-off-by: Rich Porter <rich.porter@uber.com> * Remove callback from other processes on the same machine Signed-off-by: Rich Porter <rich.porter@uber.com> * Allow other ranks to profile as well. Doesn't seem to conflict Signed-off-by: Rich Porter <rich.porter@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> test_torch.py: add explicit join() for testing duplicated name errors (horovod#3159) For torch nightly >=10.0, we need to add an explict join() call to avoid hanging when testing duplicated name errors. Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Disable TF2.6.0 XLA support on OSX (horovod#3133) Related to issue#3132 Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Fix linking _pywrap_tensorflow_internal.so and re-enable XLA on macOS (horovod#3173) Signed-off-by: weihanmines <weihan13@amd.com> Spark/Lightning: fix the usage of checkpoint callback (horovod#3186) Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Fix Cometlogger experiment key lost issue (horovod#3184) * test Signed-off-by: Peng Zhang <pengz@uber.com> * test Signed-off-by: Peng Zhang <pengz@uber.com> * fix_logger Signed-off-by: Peng Zhang <pengz@uber.com> * fix_logger Signed-off-by: Peng Zhang <pengz@uber.com> * recreate_loger Signed-off-by: Peng Zhang <pengz@uber.com> * fix_var Signed-off-by: Peng Zhang <pengz@uber.com> * test Signed-off-by: Peng Zhang <pengz@uber.com> * test Signed-off-by: Peng Zhang <pengz@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Updated torch c++ to use new aten api (horovod#3175) Signed-off-by: weihanmines <weihan13@amd.com> Spark/Keras: remove bare Keras support (horovod#3191) Signed-off-by: weihanmines <weihan13@amd.com> Make fork PRs publish test change stats (horovod#3185) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Support for nccl on cuda 11.4 (horovod#3182) Signed-off-by: Evan Brossard <evanb@maka-ars.com> Signed-off-by: weihanmines <weihan13@amd.com> Fix MPICH support (horovod#3148) * fix MPICH implementation * enable tests for MPICH and Intel MPI Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Signed-off-by: weihanmines <weihan13@amd.com> Increase build timeout to 40m on Buildkite (horovod#3192) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Change CMake syntax to be compatible with old versions of CMake (horovod#3196) Signed-off-by: Max H. Gerlach <git@maxgerlach.de> Signed-off-by: weihanmines <weihan13@amd.com> Reinit every torch test (horovod#3194) Signed-off-by: weihanmines <weihan13@amd.com> Add barrier call to torch module to support easy synchronization for process sets (horovod#3139) * Added barrier call to torch module Signed-off-by: TJ <tix@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Bump version to 0.23.0 (horovod#3200) Signed-off-by: Travis Addair <tgaddair@gmail.com> Co-authored-by: Max H. Gerlach <git@maxgerlach.de> Signed-off-by: weihanmines <weihan13@amd.com> Increase Parallel PyTest timeout to 10m (horovod#3198) * Increase MPI and Gloo Parallel PyTest timeout to 10m Signed-off-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Spark/Lightning: don't overwrite model with checkpoint by default (horovod#3201) Lightning estimator saves model by default if there is no specified checkpoint callback. However, model is not overwritten with checkpoint file in that case. Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Spark/Lightning: fix checkpoint callback dirpath typo (horovod#3204) Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Rework events in CI workflows (horovod#3202) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Allow for concurrent schedule and master build, document concurrency (horovod#3206) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Ray: fix RayExecutor to fail when num_workers=0 and num_hosts=None (horovod#3210) Signed-off-by: Travis Addair <tgaddair@gmail.com> Signed-off-by: weihanmines <weihan13@amd.com> add_history_in_lightning_estimator (horovod#3214) Signed-off-by: Peng Zhang <pengz@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Allow buildkite building merge commits on forks (horovod#3215) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Fix json output in ci-results.yaml (horovod#3217) Signed-off-by: weihanmines <weihan13@amd.com> Spark/Lightning: fix history metrics for estimator serialization (horovod#3216) Save metrics inside the checkpoint dict , which will be load with map_location=torch.device('cpu') Signed-off-by: Peng Zhang <pengz@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> patch python source files on macCI (horovod#3220) * patch python source files on macCI * Trigger build and test CI Signed-off-by: TJ <tix@uber.com> Co-authored-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Updated examples of torch and tf to include mixed precision training (horovod#3222) * Added mixed precision example for pytorch * added mixed precision for keras Signed-off-by: TJ <tix@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Job buildkite-heads accesses ci-workflow outputs, add it to the needs (horovod#3225) Signed-off-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Fixes race condition for ray scale up down tests (horovod#3205) Ensure that at least one host from the previous set of hosts have been registered. Without this, the discovery script will "discover" the new set of hosts before the current set can register. This would result in a race condition. Consider a discovery schedule: ``` discovery_schedule = [ (10, ['host-1:2']), (30, ['host-1:2', 'host-2:1', 'host-3:1']), (None, ['host-2:1']), ] ``` The initial set is: ['host-1:2']. Before this is registered in the driver, the discovery script discovers the set: ['host-1:2', 'host-2:1', 'host-3:1'], and adds ['host-2:1', 'host-3:1']. However, since ['host-1:2'] has not registered, there is no coordinator to notify the workers. When host-1 and host-3 are removed, driver.resume will call _activate_workers, which will update the host assignments. It has a check to see if the intersection between the previous and current set of hosts. It finds that the previous set is ['host-1:2'], and the current set is ['host-2:1'], since there was no notification for the added and removed hosts. This ensures that the previous set of hosts can register before the current set is discovered. Signed-off-by: Abin Shahab <ashahab@linkedin.com> Signed-off-by: weihanmines <weihan13@amd.com> Removed a case of the default mutable argument pitfall (horovod#3227) Signed-off-by: Naelson Douglas <naelson17@gmail.com> Signed-off-by: weihanmines <weihan13@amd.com> Updates to TSC members (horovod#3234) Signed-off-by: Travis Addair <tgaddair@gmail.com> Signed-off-by: weihanmines <weihan13@amd.com> Add in-place broadcast for TensorFlow (horovod#3128) * Update comment in FindTensorflow.cmake Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Add in-place broadcast_() and broadcast_variables() for TF Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Include source files from TF in build to avoid missing symbol errors Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Limit build and test to TF 2.6+ Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Remove source files copied from TensorFlow The missing symbols are resolved by linking against _pywrap_tensorflow_internal.so, which was introduced to Horovod with PR horovod#3053. Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Fix possible type attribute values for HorovodBroadcastInplace Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Add reference variables to test Signed-off-by: Max H. Gerlach <git@maxgerlach.de> * Update comments, doc strings, changelog Signed-off-by: Max H. Gerlach <git@maxgerlach.de> Signed-off-by: weihanmines <weihan13@amd.com> [Elastic Horovod] Fix the bug for ElasticSampler and hvd.elastic.state (horovod#3144) Co-authored-by: gethinhu <gethinhu@tencent.com> Signed-off-by: weihanmines <weihan13@amd.com> a better way to handle nccl error under elastic scenario (horovod#3112) Signed-off-by: guoze.lin <guozelin@tencent.com> Signed-off-by: weihanmines <weihan13@amd.com> check torch version for mixed precision example (horovod#3238) Signed-off-by: weihanmines <weihan13@amd.com> Lightning: set limit_train_batches and limit_val_batches (horovod#3237) Tell Lightning trainer that how many batches a single epoch needs. Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Spark/Lightning: reduce memory footprint of async dataloader (horovod#3239) Limit async data loader queue size. Signed-off-by: Peng Zhang <pengz@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Change default fusion threshold from 64MB to 128MB in docs (horovod#3241) Signed-off-by: weihanmines <weihan13@amd.com> fix the example of pytorch_lightning_mnist.py (horovod#3245) - remove unused arg parameters - fix model test issue on GPU Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> CI: use latest pytorch_lightning with torchhead (horovod#3243) Signed-off-by: weihanmines <weihan13@amd.com> test_gradient_aggregation with real gradient instead of a constant (horovod#3176) This fixes issue horovod#2664 by performing gradient aggregation with a real gradient instead of a constant. PR: horovod#2647 shifts the gradient allreduce when the gradient is computed (both through the DistributedOptimizer or through the DistributedGradientTape). Which means that this unittest, by design in TF2.4, doesn't call allreduce in _aggregate_gradients(). Since this unittest provide a gradient as constant (without effectively computing it), the gradient will never be allreduced. The current change ensure that instead of a constant a real gradient is computed from a loss-function. Note: The current loss-function intentionally evaluates to zero. A future PR should convert it to a real loss function(e.g. MeanSquaredError) and compute gradients from that to test gradient aggregation. Signed-off-by: Abin Shahab <ashahab@linkedin.com> Signed-off-by: weihanmines <weihan13@amd.com> Remove MetricAverageCallback warning on tf >= 2.5 (horovod#3050) Signed-off-by: Henrique Mendonça <henrique.mendonca@cscs.ch> Signed-off-by: weihanmines <weihan13@amd.com> Fix Horovod pyarrow IndexError: list index out of range (horovod#3255) Signed-off-by: Weichen Xu <weichen.xu@databricks.com> Signed-off-by: weihanmines <weihan13@amd.com> Fixing up current CI test failures. (horovod#3259) Signed-off-by: Josh Romero <joshr@nvidia.com> Co-authored-by: Travis Addair <tgaddair@gmail.com> Co-authored-by: Enrico Minack <github@enrico.minack.dev> Signed-off-by: weihanmines <weihan13@amd.com> Revert "Fix Horovod pyarrow IndexError: list index out of range (horovod#3255)" (horovod#3265) This reverts commit 3efc229. Signed-off-by: Travis Addair <tgaddair@gmail.com> Signed-off-by: weihanmines <weihan13@amd.com> Debugging for lightning data loader and fix for simple profiler. (horovod#3253) add debugging flag for lightning data loader , make async data loader queue size configurable Signed-off-by: weihanmines <weihan13@amd.com> Call process_set._setup in init() to point to the correct native lib path (horovod#3258) * call setup for common process_set in remote trainers moved _setup call to init() Signed-off-by: TJ <tix@uber.com> Signed-off-by: weihanmines <weihan13@amd.com> Add support for MXNet async dependency engine. (horovod#3242) Signed-off-by: Josh Romero <joshr@nvidia.com> Signed-off-by: weihanmines <weihan13@amd.com>
Checklist before submitting
Description
For fixing #2590, although only HorovodAllreduce is implemented in this PR but the reported issue might need other Horovod ops. More ops can be added later based on the same code structure and facilities built.
Mainly, this PR implements the XLA HorovodAllreduce op, which provides a way to lower the op into XLA. The main idea is to lower the op into XLA (HLO)
custom-calls
that enqueue requests to the Horovod runtime and process the response.This new XLA op is disabled by default now mainly because of its dependency on
ASYNC_COMPLETION
, which is essential to mitigating host latencies of the XLA runtime. AnHOROVOD_ENABLE_XLA_OPS
environment variable is added to control the on/off of the XLA op. We will make the op(s) on by default after ASYNC_COMPLETION is on by default or after we implement a way to enable ASYNC_COMPLETION on demand.See
xla.rst
for an example oftf.function(jit_compile=True)
with lowering HorovodAllreduce into XLA.Bump the c++ version to c++14 as it is needed to compile XLA.
Review process to land