XLA HorovodAllreduce for tf.function(jit_compile=True) #3053
Docs index change (`@@ -101,6 +101,8 @@ Guides`) — the guides toc gains an ``xla`` entry:

    tensorflow
    xla
    keras
    pytorch
New documentation file:

Horovod with XLA in Tensorflow
===============================

Basic usage
-----------

XLA Horovod ops can be enabled by setting ``HOROVOD_ENABLE_XLA_OPS=1``, which controls whether these ops are registered with TensorFlow/XLA.
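Because the variable gates op registration, it must be set before Horovod's op library is imported. The gate can be sketched in plain Python (``xla_ops_enabled`` is a hypothetical helper for illustration, not Horovod's actual API; the default of "disabled" is an assumption):

```python
import os

def xla_ops_enabled() -> bool:
    # Hypothetical helper mirroring the env-var gate: "1" enables
    # registration of the XLA Horovod ops; any other value (or an
    # unset variable, assumed here) leaves them disabled.
    return os.environ.get("HOROVOD_ENABLE_XLA_OPS", "0") == "1"

# Set the flag at the top of the training script, before importing
# horovod.tensorflow, so registration sees it.
os.environ["HOROVOD_ENABLE_XLA_OPS"] = "1"
print(xla_ops_enabled())
```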
There are two main ways to enable XLA, and they work with Horovod in different ways.

For **explicit compilation with tf.function(jit_compile=True)**:

.. code-block:: python

    import os
    os.environ["HOROVOD_ENABLE_XLA_OPS"] = "1"

    import tensorflow as tf
    import horovod.tensorflow as hvd

    @tf.function(jit_compile=True)
    def compiled_hvd_allreduce(dtype, dim):
        tensor = tf.random.uniform([17] * dim, -100, 100, dtype=dtype)
        summed = hvd.allreduce(tensor, average=False)
        return summed

In this way, all ops in the ``compiled_hvd_allreduce`` function are lowered into XLA, as the compilation requires. If the XLA Horovod ops are not enabled, XLA reports compilation errors.
For **auto-clustering**:

Auto-clustering is a convenient way to use XLA: simply set ``TF_XLA_FLAGS=--tf_xla_auto_jit=2`` and the XLA JIT automatically selects ops in the TensorFlow graph to lower into XLA. In this mode, enabling the XLA Horovod ops is optional: auto-clustering works even if the Horovod ops are left to run on TensorFlow devices, while only parts of the graph are lowered onto XLA devices.
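A minimal sketch of how a training script opts in to auto-clustering (the script structure is illustrative; only the flag itself comes from the text above):

```python
import os

# TF_XLA_FLAGS is read when TensorFlow initializes, so set it at the very
# top of the script, before `import tensorflow`.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"

# HOROVOD_ENABLE_XLA_OPS is optional in this mode; if it is unset, Horovod
# ops simply run on TensorFlow devices outside the XLA clusters.
print(os.environ["TF_XLA_FLAGS"])
```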
List of supported XLA Horovod ops
---------------------------------

The supported op list is:

``HorovodAllreduce``
Changes in the common C++ header:

    @@ -137,6 +137,7 @@ namespace common {
     #define HOROVOD_DISABLE_NVTX_RANGES "HOROVOD_DISABLE_NVTX_RANGES"
     #define HOROVOD_ENABLE_ASYNC_COMPLETION "HOROVOD_ENABLE_ASYNC_COMPLETION"
     #define HOROVOD_DYNAMIC_PROCESS_SETS "HOROVOD_DYNAMIC_PROCESS_SETS"
    +#define HOROVOD_ENABLE_XLA_OPS "HOROVOD_ENABLE_XLA_OPS"

     // String constant for gloo interface.
     #define GLOO_DEFAULT_IFACE ""

    @@ -153,7 +154,7 @@ namespace common {
     #define JOIN_TENSOR_NAME "join.noname"

     // List of supported frameworks.
    -enum Framework { TENSORFLOW, PYTORCH, MXNET };
    +enum Framework { TENSORFLOW, PYTORCH, MXNET, XLA };
**Review comment:** It seems that XLA is implemented only for TensorFlow, so why is XLA declared here as a framework? The same question applies to the design at the C++ namespace level. Thanks in advance for your explanation.

**Author reply:** Thanks for the question. My rationale is explained below:

**Review comment:** Thanks for your reply; you are right that XLA is a layer independent of TF or even Torch. XLA is a layer that TF/Torch can build on, in my opinion, which is why I raised the question. I'm not sure whether I'm right with the following considerations:

If so, maybe we should keep the framework declaration and consider implementing XLA without changing it, even by adding a backend? Overall, your work is wonderful; it's fine to keep this version if we cannot find a better one.

**Author reply:** You are right that XLA itself is not complete as a framework. How about we rename it? This is conceptually clean and simple. The only theoretical drawback I can think of for this approach is that you might later have PYTORCH_XLA, etc., due to combinations of framework and XLA, but I doubt this will really become a problem in practice, mainly because few such combinations actually exist. Even if it does become a problem in the future, we will know better how to deal with it at that time, given more data points. I will make the change accordingly if this makes sense to you. Please let me know.

**Review comment:** Great, we have clarity on the issues but have not reached a solution yet; maybe we should discuss it with the community.
     enum StatusType { OK, UNKNOWN_ERROR, PRECONDITION_ERROR, ABORTED, INVALID_ARGUMENT, IN_PROGRESS };

    @@ -228,6 +229,8 @@ const Status DUPLICATE_NAME_ERROR = Status::InvalidArgument(
     class TensorShape {
     public:
       TensorShape() : shape_() {}
       TensorShape(std::vector<int64_t> vec) : shape_(vec) {}
       void AddDim(int64_t dim);
       void AppendShape(TensorShape& other);
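The shape helper's behavior can be sketched in Python (a hypothetical mirror of the C++ class above, for illustration only):

```python
class TensorShape:
    # Python sketch of the C++ TensorShape helper: a thin wrapper
    # around a list of dimension sizes.
    def __init__(self, dims=None):
        self._shape = list(dims) if dims is not None else []

    def add_dim(self, dim):
        # Mirrors AddDim: append one dimension.
        self._shape.append(dim)

    def append_shape(self, other):
        # Mirrors AppendShape: concatenate another shape's dimensions.
        self._shape.extend(other._shape)

    @property
    def dims(self):
        return list(self._shape)

s = TensorShape([17, 17])
s.add_dim(3)
print(s.dims)  # [17, 17, 3]
```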
New FlatBuffers wire schema:

    // Copyright 2021 The TensorFlow Authors. All Rights Reserved.
    // Modifications copyright (C) 2017 Uber Technologies, Inc.
    // Modifications copyright Microsoft
    //
    // Licensed under the Apache License, Version 2.0 (the "License");
    // you may not use this file except in compliance with the License.
    // You may obtain a copy of the License at
    //
    //     http://www.apache.org/licenses/LICENSE-2.0
    //
    // Unless required by applicable law or agreed to in writing, software
    // distributed under the License is distributed on an "AS IS" BASIS,
    // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    // See the License for the specific language governing permissions and
    // limitations under the License.
    // =============================================================================

    include "horovod/common/wire/message.fbs";

    namespace horovod.xla.wire;

    table TensorShape {
      dims:[long];
    }

    table CustomCallConfig {
      tensor_name:string;
      tensor_type:common.wire.DataType;
      input_shapes:[TensorShape];
      output_shapes:[TensorShape];

      // Prescale and postscale factors
      prescale_factor:float;
      postscale_factor:float;

      // Root rank is necessary for broadcast operation.
      root_rank:int;

      // Reduce op.
      reduce_op:int;

      process_set_id:int;
    }
**Review comment:** @trentlo @romerojosh This ``_pywrap_tensorflow_internal.so`` is not available on OSX (#3132).