XLA HorovodAllreduce for tf.function(jit_compile=True) #3053

Merged
28 commits merged on Aug 25, 2021
Changes from all commits (28 commits)
1620cc5
Add implementation of XLA HorovodAllreduce.
trentlo Jul 12, 2021
463d17f
Fix a build break due to interface change in TF2.6.
trentlo Jul 13, 2021
aab248d
Implement customized XLA Op registrar as we want to make it an opt-in.
trentlo Jul 14, 2021
2f41921
Ran clang-format.
trentlo Jul 14, 2021
606a6fb
Polish codes.
trentlo Jul 15, 2021
e348e62
Improve cmake for XLA.
trentlo Jul 15, 2021
1dfa8df
Polish comments.
trentlo Jul 15, 2021
043a092
Minor polishing.
trentlo Jul 15, 2021
e5f9680
Don't set alias for the `start` custom-call.
trentlo Jul 17, 2021
7fbcc19
Add a unittest for XLA.
trentlo Jul 17, 2021
5f97707
Add process_id in XLA Ops.
trentlo Jul 20, 2021
22f4e0e
Add test_xla.py
trentlo Jul 20, 2021
129374b
Embed HOROVOD_ENABLE_XLA_OPS.
trentlo Jul 20, 2021
cb320cb
Ran clang-format.
trentlo Jul 20, 2021
eb52a4b
autopep8 for python formatting.
trentlo Jul 20, 2021
6b7fddc
Add documentation for Horovod XLA Ops.
trentlo Jul 20, 2021
1a3fee7
Format docs/xla.rst.
trentlo Jul 20, 2021
20dac65
Automatically set HOROVOD_ENABLE_ASYNC_COMPLETION for xla ops.
trentlo Jul 20, 2021
f9224f6
Add a link to XLA in summary.rst.
trentlo Jul 22, 2021
fdd2658
Make title line long enough in xla.rst.
trentlo Jul 22, 2021
93f5a1c
Add xla into toctree.
trentlo Jul 22, 2021
1b47a4b
Compile XLA Horovod ops only for TF2.5+
trentlo Jul 23, 2021
a4af503
Setting the default Cycle Time to 0 because the XLA runtime is sensitive
trentlo Jul 23, 2021
e409ded
Skip XLA tests if TF is older than TF2.5.
trentlo Jul 23, 2021
2db7ffe
Don't use tf.function() as decorator.
trentlo Jul 23, 2021
911408f
Remove a redundant test.
trentlo Jul 23, 2021
026d784
xla::CustomCallSchedule requires TF2.6.
trentlo Aug 11, 2021
1ae617d
Do not link _pywrap_tensorflow_internal.so if XLA is not enabled.
trentlo Aug 12, 2021
1 change: 1 addition & 0 deletions README.rst
@@ -171,6 +171,7 @@ Supported frameworks
See these pages for Horovod examples and best practices:

- `Horovod with TensorFlow <docs/tensorflow.rst>`_
- `Horovod with XLA in TensorFlow <docs/xla.rst>`_
- `Horovod with Keras <docs/keras.rst>`_
- `Horovod with PyTorch <docs/pytorch.rst>`_
- `Horovod with MXNet <docs/mxnet.rst>`_
8 changes: 7 additions & 1 deletion cmake/Modules/FindTensorflow.cmake
@@ -19,7 +19,13 @@ if (LEN EQUAL "4")
list(GET Tensorflow_OUTPUT 0 Tensorflow_VERSION)
list(GET Tensorflow_OUTPUT 1 Tensorflow_INCLUDE_DIRS)
list(GET Tensorflow_OUTPUT 2 Tensorflow_LIBRARIES)
string(REPLACE " " ";" Tensorflow_LIBRARIES "${Tensorflow_LIBRARIES}")
string(REPLACE " " ";" Tensorflow_LIBRARIES_LIST "${Tensorflow_LIBRARIES}")
list(GET Tensorflow_LIBRARIES_LIST 0 Tensorflow_LIB_PATH)
if (Tensorflow_VERSION VERSION_GREATER_EQUAL "2.6")
# XLA implementations are in _pywrap_tensorflow_internal.so
set(Tensorflow_LIBRARIES "${Tensorflow_LIBRARIES} ${Tensorflow_LIB_PATH}/python/ -l:_pywrap_tensorflow_internal.so")
Collaborator
@trentlo @romerojosh
This _pywrap_tensorflow_internal.so is not available on OSX (see #3132).

endif()
message("Tensorflow_LIBRARIES := ${Tensorflow_LIBRARIES}")
list(GET Tensorflow_OUTPUT 3 Tensorflow_COMPILE_FLAGS)
if("${Tensorflow_COMPILE_FLAGS}" MATCHES "-D_GLIBCXX_USE_CXX11_ABI=1")
set(Tensorflow_CXX11 TRUE)
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -101,6 +101,8 @@ Guides

tensorflow

xla

keras

pytorch
1 change: 1 addition & 0 deletions docs/summary.rst
@@ -163,6 +163,7 @@ Supported frameworks
See these pages for Horovod examples and best practices:

- `Horovod with TensorFlow <tensorflow.rst>`_
- `Horovod with XLA in TensorFlow <xla.rst>`_
- `Horovod with Keras <keras.rst>`_
- `Horovod with PyTorch <pytorch.rst>`_
- `Horovod with MXNet <mxnet.rst>`_
37 changes: 37 additions & 0 deletions docs/xla.rst
@@ -0,0 +1,37 @@
Horovod with XLA in TensorFlow
===============================

Basic usage
-----------

XLA Horovod ops can be enabled by setting ``HOROVOD_ENABLE_XLA_OPS=1``, which controls whether the ops are registered with TensorFlow/XLA.

There are two main ways to enable XLA, and they interact with Horovod differently:

For **Explicit compilation with tf.function(jit_compile=True)**:

.. code-block:: python

    import os
    os.environ["HOROVOD_ENABLE_XLA_OPS"] = "1"  # set before hvd.init()

    import tensorflow as tf
    import horovod.tensorflow as hvd

    hvd.init()

    @tf.function(jit_compile=True)
    def compiled_hvd_allreduce(dtype, dim):
        tensor = tf.random.uniform([17] * dim, -100, 100, dtype=dtype)
        summed = hvd.allreduce(tensor, average=False)
        return summed

In this mode, all ops in ``compiled_hvd_allreduce`` are lowered into XLA, since explicit compilation requires the whole function to compile. If the XLA Horovod ops are not enabled, XLA reports compilation errors.
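For example, calling ``compiled_hvd_allreduce(tf.float32, 3)`` then runs the whole function as a single XLA executable containing the Horovod allreduce as a custom call. One way to confirm the lowering is TensorFlow's IR-introspection hook (a sketch; ``experimental_get_compiler_ir`` is a stock TF 2.x API, not part of this PR):

.. code-block:: python

    # Inspect the HLO that XLA generates for these argument values; the dump
    # should include a custom call for the Horovod allreduce.
    hlo_text = compiled_hvd_allreduce.experimental_get_compiler_ir(
        tf.float32, 3)(stage="hlo")
    print(hlo_text)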


For **Auto-clustering**:

Auto-clustering is a convenient way to use XLA: simply set ``TF_XLA_FLAGS=--tf_xla_auto_jit=2`` and the XLA JIT automatically selects ops in the TensorFlow graph to lower into XLA. In this mode, enabling the XLA Horovod ops is optional, because auto-clustering works even if the Horovod ops are left to run on TensorFlow (devices) while only parts of the graph are lowered onto XLA (devices).
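A minimal sketch of this mode, assuming the flag is set before TensorFlow initializes (exporting it in the shell works equally well):

.. code-block:: python

    import os
    # Must be visible before TensorFlow initializes its XLA JIT.
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"

    import tensorflow as tf
    import horovod.tensorflow as hvd

    hvd.init()

    @tf.function  # no jit_compile: XLA auto-clusters eligible ops on its own
    def train_step(tensor):
        return hvd.allreduce(tensor, average=False)

    print(train_step(tf.ones([17, 17])))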

List of supported XLA Horovod ops
---------------------------------

The supported op list is:

``HorovodAllreduce``

5 changes: 4 additions & 1 deletion horovod/common/common.h
@@ -137,6 +137,7 @@ namespace common {
#define HOROVOD_DISABLE_NVTX_RANGES "HOROVOD_DISABLE_NVTX_RANGES"
#define HOROVOD_ENABLE_ASYNC_COMPLETION "HOROVOD_ENABLE_ASYNC_COMPLETION"
#define HOROVOD_DYNAMIC_PROCESS_SETS "HOROVOD_DYNAMIC_PROCESS_SETS"
#define HOROVOD_ENABLE_XLA_OPS "HOROVOD_ENABLE_XLA_OPS"

// String constant for gloo interface.
#define GLOO_DEFAULT_IFACE ""
@@ -153,7 +154,7 @@ namespace common {
#define JOIN_TENSOR_NAME "join.noname"

// List of supported frameworks.
enum Framework { TENSORFLOW, PYTORCH, MXNET };
enum Framework { TENSORFLOW, PYTORCH, MXNET, XLA };

Reviewer
It seems that XLA is implemented only for TensorFlow, so why is XLA declared here as a framework? The same question applies to the C++ namespace design. Thanks in advance for your explanation.

Contributor Author
Thanks for the question. My rationale is as follows:

  1. XLA is independent of TensorFlow, although they live in the same repo for now. There is talk of separating XLA from TensorFlow, although I don't know when that will happen. In addition, XLA (as a DL compiler) is used by frontends other than TensorFlow, such as JAX and PyTorch (not in its main repo, I guess), etc.
  2. I can certainly understand that it is inaccurate to call XLA a framework. However, since XLA is independent of TensorFlow, it does not use TensorFlow constructs and cannot use TFOpContext. It therefore needs a new XLAOpContext with a new "framework" name. I don't see alternatives.
  3. If the name is really confusing, I would suggest renaming "framework" to "backend".

Reviewer
Thanks for your reply; you are right that XLA is a layer independent of TF or even Torch. In my opinion, XLA is a layer TF/Torch can build on, which is why I raised the question. I'm not sure whether I'm right, given the following considerations:

  1. XLA cannot be used or run without a framework (TF/Torch/MXNet).
  2. XLA runs as TF+XLA (the version you are working on), Torch+XLA, or more.

If so, maybe we should keep the framework declaration as it is and consider implementing XLA without changing it, rather than adding a backend concept?

Overall, your work is wonderful; it's OK to keep this version if we cannot figure out a better one.

Contributor Author

You are right that XLA itself is not complete as a framework.

How about renaming XLA to TF_XLA?
That is, enum Framework { TENSORFLOW, PYTORCH, MXNET, TF_XLA };

This is conceptually clean and simple. The only theoretical drawback I can think of is that you might later need PYTORCH_XLA, etc., from combinations of framework and XLA, but I doubt this will become a problem in practice (mainly because few such combinations actually exist). Even if it does become a problem in the future, we will know better how to deal with it then, given more data points.

I will make the change accordingly if this makes sense to you. Please let me know.

Reviewer
Great, we have clarity on the issues but have not reached a solution yet; maybe we should discuss it with the community.

/cc @maxhgerlach @EnricoMi


enum StatusType { OK, UNKNOWN_ERROR, PRECONDITION_ERROR, ABORTED, INVALID_ARGUMENT, IN_PROGRESS };

@@ -228,6 +229,8 @@ const Status DUPLICATE_NAME_ERROR = Status::InvalidArgument(

class TensorShape {
public:
TensorShape() : shape_() {}
TensorShape(std::vector<int64_t> vec) : shape_(vec) {}
void AddDim(int64_t dim);
void AppendShape(TensorShape& other);

13 changes: 13 additions & 0 deletions horovod/common/operations.cc
@@ -494,6 +494,14 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {

// Override the cycle time.
state.parameter_manager.SetCycleTimeMs(1);
bool enable_xla_ops = false;
common::SetBoolFromEnv(HOROVOD_ENABLE_XLA_OPS, enable_xla_ops, true);
if (enable_xla_ops) {
// Setting the default Cycle Time to 0 because the XLA runtime is sensitive
// to latencies.
state.parameter_manager.SetCycleTimeMs(0);
}

auto horovod_cycle_time = std::getenv(HOROVOD_CYCLE_TIME);
if (horovod_cycle_time != nullptr) {
state.parameter_manager.SetCycleTimeMs(
@@ -563,6 +571,11 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {

// Check if async completion should be enabled
SetBoolFromEnv(HOROVOD_ENABLE_ASYNC_COMPLETION, state.enable_async_completion, true);
if (enable_xla_ops) {
// Enable async completion when XLA ops are enabled. Since the XLA runtime is
// single-threaded, async completion is essential to reduce host overhead.
state.enable_async_completion = true;
}

// Enable auto-tuning.
auto horovod_autotune = std::getenv(HOROVOD_AUTOTUNE);
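Taken together, these two hunks change two defaults whenever the XLA ops are enabled. A minimal sketch of the resulting user-facing configuration, assuming the environment is set before ``hvd.init()`` (the variable names come from this diff):

.. code-block:: python

    import os

    # Registers the XLA Horovod ops; per the first hunk this also defaults
    # the tensor-fusion cycle time to 0 ms, since the XLA runtime is
    # sensitive to latencies.
    os.environ["HOROVOD_ENABLE_XLA_OPS"] = "1"

    # An explicit HOROVOD_CYCLE_TIME still wins, because it is read after
    # the XLA default is applied (value in milliseconds).
    os.environ["HOROVOD_CYCLE_TIME"] = "0.5"

    # Per the second hunk, async completion is forced on when XLA ops are
    # enabled (the XLA runtime is single-threaded), regardless of
    # HOROVOD_ENABLE_ASYNC_COMPLETION.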
1 change: 1 addition & 0 deletions horovod/tensorflow/CMakeLists.txt
@@ -59,6 +59,7 @@ set(Tensorflow_CXX11 ${Tensorflow_CXX11} PARENT_SCOPE)

# TF SOURCES
list(APPEND TF_SOURCES "${PROJECT_SOURCE_DIR}/horovod/tensorflow/mpi_ops.cc")
list(APPEND TF_SOURCES "${PROJECT_SOURCE_DIR}/horovod/tensorflow/xla_mpi_ops.cc")

# Create library
set_output_dir()
43 changes: 43 additions & 0 deletions horovod/tensorflow/custom_call_config.fbs
@@ -0,0 +1,43 @@
// Copyright 2021 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2017 Uber Technologies, Inc.
// Modifications copyright Microsoft
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

include "horovod/common/wire/message.fbs";

namespace horovod.xla.wire;

table TensorShape {
dims:[long];
}

table CustomCallConfig {
tensor_name:string;
tensor_type:common.wire.DataType;
input_shapes:[TensorShape];
output_shapes:[TensorShape];

// Prescale and postscale factors
prescale_factor:float;
postscale_factor:float;

// Root rank is necessary for broadcast operation.
root_rank:int;

// Reduce op.
reduce_op:int;

process_set_id:int;
}
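Presumably, this table is what gets serialized into the opaque backend-config string of the XLA custom calls, so that the Horovod runtime can recover the tensor name, type, shapes, and reduction parameters at execution time; the exact encoding is defined by the flatbuffers schema above.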