diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000..17de768cbe --- /dev/null +++ b/.clang-format @@ -0,0 +1,7 @@ +--- +BasedOnStyle: LLVM +--- +Language: Cpp +# Force pointers to the type for C++. +DerivePointerAlignment: false +PointerAlignment: Left diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..503569b57a --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +*.pyc +.idea +horovod.egg-info +dist +build diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..f70e1e0f41 --- /dev/null +++ b/LICENSE @@ -0,0 +1,249 @@ + Horovod + Copyright 2017 Uber Technologies, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Horovod includes: + + FlatBuffers + Copyright (c) 2014 Google Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + baidu-research/tensorflow-allreduce + Copyright (c) 2015, The TensorFlow Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000..e24a4c600f
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+recursive-include * *.h *.cc
+include README.md
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000..2f003bf6f8
--- /dev/null
+++ b/README.md
@@ -0,0 +1,289 @@
+# Horovod
+
+Horovod is a distributed training framework for TensorFlow. The goal of Horovod is to make distributed Deep Learning
+fast and easy to use.
+
+# Install
+
+To install Horovod:
+
+1. Install [Open MPI](https://www.open-mpi.org/).
+
+2. Install the `horovod` pip package.
+
+```bash
+$ pip install horovod
+```
+
+This basic installation is good for laptops and for getting to know Horovod.
+If you're installing Horovod on a server with GPUs, read the [Horovod on GPU](#horovod-on-gpu) section.
+
+# Concepts
+
+Horovod's core principles are based on [MPI](http://mpi-forum.org/) concepts such as *size*, *rank*,
+*local rank*, *allreduce*, *allgather* and *broadcast*. These are best explained by example. Say we launched
+a training script on 4 servers, each having 4 GPUs. If we launched one copy of the script per GPU:
+
+1. *Size* would be the number of processes, in this case 16.
+
+2. *Rank* would be the unique process ID from 0 to 15 (*size* - 1).
+
+3. *Local rank* would be the unique process ID within the server from 0 to 3.
+
+4. *Allreduce* is an operation that aggregates data among multiple processes and distributes
+   results back to them. *Allreduce* is used to average dense tensors. Here's an illustration from the
+   [MPI Tutorial](http://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/):
+
+   ![Allreduce Illustration](http://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/mpi_allreduce_1.png)
+
+5. *Allgather* is an operation that gathers data from all processes on every process. *Allgather* is used to collect
+   values of sparse tensors. Here's an illustration from the [MPI Tutorial](http://mpitutorial.com/tutorials/mpi-scatter-gather-and-allgather/):
+
+   ![Allgather Illustration](http://mpitutorial.com/tutorials/mpi-scatter-gather-and-allgather/allgather.png)
+
+6. *Broadcast* is an operation that broadcasts data from one process, identified by the root rank, to every other process.
+   Here's an illustration from the [MPI Tutorial](http://mpitutorial.com/tutorials/mpi-broadcast-and-collective-communication/):
+
+   ![Broadcast Illustration](http://mpitutorial.com/tutorials/mpi-broadcast-and-collective-communication/broadcast_pattern.png)
+
+# Usage
+
+To use Horovod, make the following additions to your program:
+
+1. Run `hvd.init()`.
+
+2. Pin a server GPU to be used by this process using `config.gpu_options.visible_device_list`.
+   With the typical setup of one GPU per process, this can be set to *local rank*.
+
+3. Wrap the optimizer in `hvd.DistributedOptimizer`. The distributed optimizer delegates gradient computation
+   to the original optimizer, averages gradients using *allreduce* or *allgather*, and then applies those averaged
+   gradients.
+
+4. Add `hvd.BroadcastGlobalVariablesHook(0)` to broadcast initial variable states from rank 0 to all other
+   processes. Alternatively, if you're not using `MonitoredTrainingSession`, you can simply execute the
+   `hvd.broadcast_global_variables` op after global variables have been initialized (see the sketch below).
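+
+For instance, if you're not using `MonitoredTrainingSession`, the broadcast can be run explicitly.
+A minimal sketch, mirroring what `examples/tensorflow_word2vec.py` in this repository does:
+
+```python
+import tensorflow as tf
+import horovod.tensorflow as hvd
+
+hvd.init()
+
+# Build model...
+
+init = tf.global_variables_initializer()
+bcast = hvd.broadcast_global_variables(0)  # broadcast from rank 0
+
+with tf.Session() as session:
+    session.run(init)   # initialize variables on every rank first
+    session.run(bcast)  # then overwrite them with rank 0's values
+```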
+
+Example (full MNIST training example is available [here](examples/tensorflow_mnist.py)):
+
+```python
+import tensorflow as tf
+import horovod.tensorflow as hvd
+
+
+# Initialize Horovod
+hvd.init()
+
+# Pin the GPU to be used by this process (one GPU per local rank)
+config = tf.ConfigProto()
+config.gpu_options.visible_device_list = str(hvd.local_rank())
+
+# Build model...
+loss = ...
+opt = tf.train.AdagradOptimizer(0.01)
+
+# Add Horovod Distributed Optimizer
+opt = hvd.DistributedOptimizer(opt)
+
+# Add hook to broadcast variables from rank 0 to all other processes during
+# initialization.
+hooks = [hvd.BroadcastGlobalVariablesHook(0)]
+
+# Make training operation
+train_op = opt.minimize(loss)
+
+# The MonitoredTrainingSession takes care of session initialization,
+# restoring from a checkpoint, saving to a checkpoint, and closing when done
+# or an error occurs.
+with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/train_logs",
+                                       config=config,
+                                       hooks=hooks) as mon_sess:
+    while not mon_sess.should_stop():
+        # Perform synchronous training.
+        mon_sess.run(train_op)
+```
+
+To run on a machine with 4 GPUs:
+
+```bash
+$ mpirun -np 4 python train.py
+```
+
+## Horovod on GPU
+
+To use Horovod on GPU, read the options below and see which one applies to you best.
+
+### Have GPUs?
+
+In most situations, using NCCL 2 will significantly improve performance over the CPU version. NCCL 2 provides the *allreduce*
+operation optimized for NVIDIA GPUs and a variety of networking devices, such as InfiniBand or RoCE.
+
+1. Install [NCCL 2](https://developer.nvidia.com/nccl).
+
+2. Install [Open MPI](https://www.open-mpi.org/).
+
+3. Install the `horovod` pip package.
+
+```bash
+$ HOROVOD_GPU_ALLREDUCE=NCCL pip install horovod
+```
+
+**Note**: Some networks with a high computation to communication ratio benefit from doing allreduce on CPU, even if a
+GPU version is available. Inception V3 is an example of such a network. To force allreduce to happen on CPU, pass
+`device_dense='/cpu:0'` to `hvd.DistributedOptimizer`:
+
+```python
+opt = hvd.DistributedOptimizer(opt, device_dense='/cpu:0')
+```
+
+### Advanced: Have GPUs and networking with GPUDirect?
+
+[GPUDirect](https://developer.nvidia.com/gpudirect) allows GPUs to transfer memory among each other without CPU
+involvement, which significantly reduces latency and CPU load. NCCL 2 is able to use GPUDirect automatically for
+the *allreduce* operation if it detects it.
+
+Additionally, Horovod uses *allgather* and *broadcast* operations from MPI. They are used for averaging sparse tensors
+that are typically used for embeddings, and for broadcasting initial state. To speed these operations up with GPUDirect,
+make sure your MPI implementation supports CUDA and add `HOROVOD_GPU_ALLGATHER=MPI HOROVOD_GPU_BROADCAST=MPI` to the pip
+command.
+
+1. Install [NCCL 2](https://developer.nvidia.com/nccl).
+
+2. Install [Open MPI](https://www.open-mpi.org/).
+
+3. Install the `horovod` pip package.
+
+```bash
+$ HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_ALLGATHER=MPI HOROVOD_GPU_BROADCAST=MPI pip install horovod
+```
+
+**Note**: Allgather allocates an output tensor which is proportional to the number of processes participating in the
+training. If you find yourself running out of GPU memory, you can force allgather to happen on CPU by passing
+`device_sparse='/cpu:0'` to `hvd.DistributedOptimizer`:
+
+```python
+opt = hvd.DistributedOptimizer(opt, device_sparse='/cpu:0')
+```
+
+### Advanced: Have MPI optimized for your network?
+
+If your network hardware is not supported by NCCL 2, or if your MPI vendor's GPU implementation is faster,
+you can also use the pure MPI version of *allreduce*, *allgather* and *broadcast* on GPU.
+
+1. Make sure your MPI implementation is installed.
+
+2. Install the `horovod` pip package.
+
+```bash
+$ HOROVOD_GPU_ALLREDUCE=MPI HOROVOD_GPU_ALLGATHER=MPI HOROVOD_GPU_BROADCAST=MPI pip install horovod
+```
+
+## Inference
+
+What about inference? Inference is often done outside of the Python script that was used to train the model; such a
+script will not have references to the Horovod library.
+
+To run inference on a checkpoint generated by the Horovod-enabled training script, you should optimize the graph and
+keep only the operations necessary for a forward pass through the network. The [Optimize for Inference](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py)
+script from the TensorFlow repository will do that for you.
+
+If you want to convert your checkpoint to a [Frozen Graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py),
+you should do so after doing the optimization described above; otherwise the [Freeze Graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)
+script will fail to load the Horovod op:
+
+```
+ValueError: No op named HorovodAllreduce in defined operations.
+```
+
+## Troubleshooting
+
+### Import TensorFlow failed during installation
+
+1. Is TensorFlow installed?
+
+If you see the error message below, it means that TensorFlow is not installed. Please install TensorFlow before installing
+Horovod.
+
+```
+error: import tensorflow failed, is it installed?
+
+Traceback (most recent call last):
+  File "/tmp/pip-OfE_YX-build/setup.py", line 29, in fully_define_extension
+    import tensorflow as tf
+ImportError: No module named tensorflow
+```
+
+2. Are the CUDA libraries available?
+
+If you see the error message below, it means that TensorFlow cannot be loaded. If you're installing Horovod into a container
+on a machine without GPUs, you may use CUDA stub drivers to work around the issue.
+
+```
+error: import tensorflow failed, is it installed?
+
+Traceback (most recent call last):
+  File "/tmp/pip-41aCq9-build/setup.py", line 29, in fully_define_extension
+    import tensorflow as tf
+  File "/usr/local/lib/python2.7/dist-packages/tensorflow/__init__.py", line 24, in <module>
+    from tensorflow.python import *
+  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/__init__.py", line 49, in <module>
+    from tensorflow.python import pywrap_tensorflow
+  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/pywrap_tensorflow.py", line 52, in <module>
+    raise ImportError(msg)
+ImportError: Traceback (most recent call last):
+  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/pywrap_tensorflow.py", line 41, in <module>
+    from tensorflow.python.pywrap_tensorflow_internal import *
+  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/pywrap_tensorflow_internal.py", line 28, in <module>
+    _pywrap_tensorflow_internal = swig_import_helper()
+  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/pywrap_tensorflow_internal.py", line 24, in swig_import_helper
+    _mod = imp.load_module('_pywrap_tensorflow_internal', fp, pathname, description)
+ImportError: libcuda.so.1: cannot open shared object file: No such file or directory
+```
+
+To use CUDA stub drivers:
+
+```bash
+# temporarily add stub drivers to ld.so.cache
+$ ldconfig /usr/local/cuda/lib64/stubs
+
+# install Horovod, add other HOROVOD_* environment variables as necessary
+$ pip install horovod
+
+# revert to standard libraries
+$ ldconfig
+```
+
+### Running out of memory
+
+If you notice that your program is running out of GPU memory and multiple processes
+are being placed on the same GPU, it's likely that your program (or its dependencies)
+creates a `tf.Session` that does not use the `config` that pins a specific GPU.
+
+If possible, track down the part of the program that uses these additional `tf.Session`s and pass
+it the same configuration.
+
+Alternatively, you can place the following snippet at the beginning of your program to ask TensorFlow
+to minimize the amount of memory it will pre-allocate on each GPU:
+
+```python
+small_cfg = tf.ConfigProto()
+small_cfg.gpu_options.allow_growth = True
+with tf.Session(config=small_cfg):
+    pass
+```
+
+As a last resort, you can **replace** setting `config.gpu_options.visible_device_list`
+with different code:
+
+```python
+# Pin GPU to be used
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())
+```
+
+**Note**: Setting `CUDA_VISIBLE_DEVICES` is incompatible with `config.gpu_options.visible_device_list`.
+
+Setting `CUDA_VISIBLE_DEVICES` has an additional disadvantage for the GPU version: CUDA will not be able to use IPC, which
+will likely cause NCCL and MPI to fail. In order to disable IPC in NCCL and MPI and allow them to fall back to shared
+memory, use:
+* `export NCCL_P2P_DISABLE=1` for NCCL.
+* `--mca btl_smcuda_use_cuda_ipc 0` flag for Open MPI and similar flags for other vendors.
diff --git a/examples/tensorflow_mnist.py b/examples/tensorflow_mnist.py
new file mode 100644
index 0000000000..63795bf16a
--- /dev/null
+++ b/examples/tensorflow_mnist.py
@@ -0,0 +1,110 @@
+# Copyright 2017 Uber Technologies, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#!/usr/bin/env python
+
+import tensorflow as tf
+import horovod.tensorflow as hvd
+layers = tf.contrib.layers
+learn = tf.contrib.learn
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def conv_model(feature, target, mode):
+    """2-layer convolution model."""
+    # Convert the target to a one-hot tensor of shape (batch_size, 10) and
+    # with an on-value of 1 for each one-hot vector of length 10.
+    target = tf.one_hot(tf.cast(target, tf.int32), 10, 1, 0)
+
+    # Reshape feature to a 4d tensor with the 2nd and 3rd dimensions being
+    # image width and height, and the final dimension being the number of
+    # color channels.
+    feature = tf.reshape(feature, [-1, 28, 28, 1])
+
+    # First conv layer will compute 32 features for each 5x5 patch
+    with tf.variable_scope('conv_layer1'):
+        h_conv1 = layers.conv2d(
+            feature, 32, kernel_size=[5, 5], activation_fn=tf.nn.relu)
+        h_pool1 = tf.nn.max_pool(
+            h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
+
+    # Second conv layer will compute 64 features for each 5x5 patch.
+    with tf.variable_scope('conv_layer2'):
+        h_conv2 = layers.conv2d(
+            h_pool1, 64, kernel_size=[5, 5], activation_fn=tf.nn.relu)
+        h_pool2 = tf.nn.max_pool(
+            h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
+        # reshape tensor into a batch of vectors
+        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
+
+    # Densely connected layer with 1024 neurons.
+    h_fc1 = layers.dropout(
+        layers.fully_connected(
+            h_pool2_flat, 1024, activation_fn=tf.nn.relu),
+        keep_prob=0.5,
+        is_training=mode == tf.contrib.learn.ModeKeys.TRAIN)
+
+    # Compute logits (1 per class) and compute loss.
+    logits = layers.fully_connected(h_fc1, 10, activation_fn=None)
+    loss = tf.losses.softmax_cross_entropy(target, logits)
+
+    return tf.argmax(logits, 1), loss
+
+
+def main(_):
+    # Initialize Horovod.
+    hvd.init()
+
+    # Download and load MNIST dataset.
+    mnist = learn.datasets.mnist.read_data_sets('MNIST-data-%d' % hvd.rank())
+
+    # Build model...
+    with tf.name_scope('input'):
+        image = tf.placeholder(tf.float32, [None, 784], name='image')
+        label = tf.placeholder(tf.float32, [None], name='label')
+    predict, loss = conv_model(image, label, tf.contrib.learn.ModeKeys.TRAIN)
+
+    opt = tf.train.RMSPropOptimizer(0.01)
+
+    # Add Horovod Distributed Optimizer.
+    opt = hvd.DistributedOptimizer(opt)
+
+    global_step = tf.contrib.framework.get_or_create_global_step()
+    train_op = opt.minimize(loss, global_step=global_step)
+
+    # BroadcastGlobalVariablesHook broadcasts variables from rank 0 to all other
+    # processes during initialization.
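+    # The StopAtStepHook and LoggingTensorHook below are standard TensorFlow
+    # hooks; only the broadcast hook is Horovod-specific.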
+    hooks = [hvd.BroadcastGlobalVariablesHook(0),
+             tf.train.StopAtStepHook(last_step=100),
+             tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss},
+                                        every_n_iter=10),
+             ]
+
+    # Pin the GPU to be used by this process (one GPU per local rank)
+    config = tf.ConfigProto()
+    config.gpu_options.allow_growth = True
+    config.gpu_options.visible_device_list = str(hvd.local_rank())
+
+    # The MonitoredTrainingSession takes care of session initialization,
+    # restoring from a checkpoint, saving to a checkpoint, and closing when done
+    # or an error occurs.
+    with tf.train.SingularMonitoredSession(hooks=hooks, config=config) as mon_sess:
+        while not mon_sess.should_stop():
+            # Run a training step synchronously.
+            image_, label_ = mnist.train.next_batch(100)
+            mon_sess.run(train_op, feed_dict={image: image_, label: label_})
+
+
+if __name__ == "__main__":
+    tf.app.run()
diff --git a/examples/tensorflow_word2vec.py b/examples/tensorflow_word2vec.py
new file mode 100644
index 0000000000..bd4f2f9efa
--- /dev/null
+++ b/examples/tensorflow_word2vec.py
@@ -0,0 +1,245 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Modifications copyright (C) 2017 Uber Technologies, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Basic word2vec example."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import math
+import os
+import random
+import zipfile
+
+import numpy as np
+from six.moves import urllib
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+import horovod.tensorflow as hvd
+
+# Initialize Horovod.
+hvd.init()
+
+
+# Step 1: Download the data.
+url = 'http://mattmahoney.net/dc/text8.zip'
+
+
+def maybe_download(filename, expected_bytes):
+    """Download a file if not present, and make sure it's the right size."""
+    if not os.path.exists(filename):
+        filename, _ = urllib.request.urlretrieve(url, filename)
+    statinfo = os.stat(filename)
+    if statinfo.st_size == expected_bytes:
+        print('Found and verified', filename)
+    else:
+        print(statinfo.st_size)
+        raise Exception(
+            'Failed to verify ' + url + '. Can you get to it with a browser?')
+    return filename
+
+filename = maybe_download('text8-%d.zip' % hvd.rank(), 31344016)
+
+
+# Read the data into a list of strings.
+def read_data(filename):
+    """Extract the first file enclosed in a zip file as a list of words."""
+    with zipfile.ZipFile(filename) as f:
+        data = tf.compat.as_str(f.read(f.namelist()[0])).split()
+    return data
+
+vocabulary = read_data(filename)
+print('Data size', len(vocabulary))
+
+# Step 2: Build the dictionary and replace rare words with UNK token.
+vocabulary_size = 50000
+
+
+def build_dataset(words, n_words):
+    """Process raw inputs into a dataset."""
+    count = [['UNK', -1]]
+    count.extend(collections.Counter(words).most_common(n_words - 1))
+    dictionary = dict()
+    for word, _ in count:
+        dictionary[word] = len(dictionary)
+    data = list()
+    unk_count = 0
+    for word in words:
+        if word in dictionary:
+            index = dictionary[word]
+        else:
+            index = 0  # dictionary['UNK']
+            unk_count += 1
+        data.append(index)
+    count[0][1] = unk_count
+    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
+    return data, count, dictionary, reversed_dictionary
+
+data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
+                                                            vocabulary_size)
+del vocabulary  # Hint to reduce memory.
+print('Most common words (+UNK)', count[:5])
+print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])
+
+
+# Step 3: Function to generate a training batch for the skip-gram model.
+def generate_batch(batch_size, num_skips, skip_window):
+    assert num_skips <= 2 * skip_window
+    # Adjust batch_size to match num_skips
+    batch_size = batch_size // num_skips * num_skips
+    span = 2 * skip_window + 1  # [ skip_window target skip_window ]
+    # Backtrack a little bit to avoid skipping words at the end of a batch
+    data_index = random.randint(0, len(data) - span - 1)
+    batch = np.ndarray(shape=(batch_size), dtype=np.int32)
+    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
+    buffer = collections.deque(maxlen=span)
+    for _ in range(span):
+        buffer.append(data[data_index])
+        data_index = (data_index + 1) % len(data)
+    for i in range(batch_size // num_skips):
+        target = skip_window  # target label at the center of the buffer
+        targets_to_avoid = [skip_window]
+        for j in range(num_skips):
+            while target in targets_to_avoid:
+                target = random.randint(0, span - 1)
+            targets_to_avoid.append(target)
+            batch[i * num_skips + j] = buffer[skip_window]
+            labels[i * num_skips + j, 0] = buffer[target]
+        buffer.append(data[data_index])
+        data_index = (data_index + 1) % len(data)
+    return batch, labels
+
+batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
+for i in range(8):
+    print(batch[i], reverse_dictionary[batch[i]],
+          '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
+
+# Step 4: Build and train a skip-gram model.
+
+max_batch_size = 128
+embedding_size = 128  # Dimension of the embedding vector.
+skip_window = 1       # How many words to consider left and right.
+num_skips = 2         # How many times to reuse an input to generate a label.
+
+# We pick a random validation set to sample nearest neighbors. Here we limit the
+# validation samples to the words that have a low numeric ID, which by
+# construction are also the most frequent.
+valid_size = 16     # Random set of words to evaluate similarity on.
+valid_window = 100  # Only pick dev samples in the head of the distribution.
+valid_examples = np.random.choice(valid_window, valid_size, replace=False)
+num_sampled = 64    # Number of negative examples to sample.
+
+graph = tf.Graph()
+
+with graph.as_default():
+
+    # Input data.
+    train_inputs = tf.placeholder(tf.int32, shape=[None])
+    train_labels = tf.placeholder(tf.int32, shape=[None, 1])
+    valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
+
+    # Look up embeddings for inputs.
+    embeddings = tf.Variable(
+        tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
+    embed = tf.nn.embedding_lookup(embeddings, train_inputs)
+
+    # Construct the variables for the NCE loss
+    nce_weights = tf.Variable(
+        tf.truncated_normal([vocabulary_size, embedding_size],
+                            stddev=1.0 / math.sqrt(embedding_size)))
+    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
+
+    # Compute the average NCE loss for the batch.
+    # tf.nce_loss automatically draws a new sample of the negative labels each
+    # time we evaluate the loss.
+    loss = tf.reduce_mean(
+        tf.nn.nce_loss(weights=nce_weights,
+                       biases=nce_biases,
+                       labels=train_labels,
+                       inputs=embed,
+                       num_sampled=num_sampled,
+                       num_classes=vocabulary_size))
+
+    # Construct the SGD optimizer using a learning rate of 1.0.
+    optimizer = tf.train.GradientDescentOptimizer(1.0)
+
+    # Add Horovod Distributed Optimizer.
+    optimizer = hvd.DistributedOptimizer(optimizer)
+
+    train_op = optimizer.minimize(loss)
+
+    # Compute the cosine similarity between minibatch examples and all embeddings.
+    norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
+    normalized_embeddings = embeddings / norm
+    valid_embeddings = tf.nn.embedding_lookup(
+        normalized_embeddings, valid_dataset)
+    similarity = tf.matmul(
+        valid_embeddings, normalized_embeddings, transpose_b=True)
+
+    # Add variable initializer.
+    init = tf.global_variables_initializer()
+
+    # Broadcast variables from rank 0 to all other processes.
+    bcast = hvd.broadcast_global_variables(0)
+
+# Step 5: Begin training.
+num_steps = 100001
+
+# Pin the GPU to be used by this process (one GPU per local rank)
+config = tf.ConfigProto()
+config.gpu_options.allow_growth = True
+config.gpu_options.visible_device_list = str(hvd.local_rank())
+
+with tf.Session(graph=graph, config=config) as session:
+    # We must initialize all variables before we use them.
+    init.run()
+    bcast.run()
+    print('Initialized')
+
+    average_loss = 0
+    for step in xrange(num_steps):
+        # Simulate varying sentence lengths by randomizing the batch size.
+        batch_size = random.randint(max_batch_size // 2, max_batch_size)
+        batch_inputs, batch_labels = generate_batch(
+            batch_size, num_skips, skip_window)
+        feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}
+
+        # We perform one update step by evaluating the optimizer op (including
+        # it in the list of returned values for session.run()).
+        _, loss_val = session.run([train_op, loss], feed_dict=feed_dict)
+        average_loss += loss_val
+
+        if step % 2000 == 0:
+            if step > 0:
+                average_loss /= 2000
+            # The average loss is an estimate of the loss over the last 2000 batches.
+            print('Average loss at step ', step, ': ', average_loss)
+            average_loss = 0
+
+        # Note that this is expensive (~20% slowdown if computed every 500 steps)
+        if step % 10000 == 0:
+            sim = similarity.eval()
+            for i in xrange(valid_size):
+                valid_word = reverse_dictionary[valid_examples[i]]
+                top_k = 8  # number of nearest neighbors
+                nearest = (-sim[i, :]).argsort()[1:top_k + 1]
+                log_str = 'Nearest to %s:' % valid_word
+                for k in xrange(top_k):
+                    close_word = reverse_dictionary[nearest[k]]
+                    log_str = '%s %s,' % (log_str, close_word)
+                print(log_str)
+    final_embeddings = normalized_embeddings.eval()
diff --git a/horovod/__init__.py b/horovod/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/horovod/tensorflow/__init__.py b/horovod/tensorflow/__init__.py
new file mode 100644
index 0000000000..f859e01b95
--- /dev/null
+++ b/horovod/tensorflow/__init__.py
@@ -0,0 +1,214 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# Modifications copyright (C) 2017 Uber Technologies, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-short-docstring-punctuation
+"""## Communicating Between Processes with MPI
+
+TensorFlow natively provides inter-device communication through send and
+receive ops and inter-node communication through Distributed TensorFlow, based
+on the same send and receive abstractions. On HPC clusters where InfiniBand or
+other high-speed node interconnects are available, these abstractions can end
+up being insufficient for synchronous data-parallel training (without
+asynchronous gradient descent). This module implements a variety of MPI ops
+which can take advantage of hardware-specific MPI libraries for efficient
+communication.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from horovod.tensorflow.mpi_ops import size
+from horovod.tensorflow.mpi_ops import rank
+from horovod.tensorflow.mpi_ops import local_rank
+from horovod.tensorflow.mpi_ops import allgather
+from horovod.tensorflow.mpi_ops import broadcast
+from horovod.tensorflow.mpi_ops import _allreduce
+from horovod.tensorflow.mpi_ops import init
+
+
+def allreduce(tensor, average=True, device_dense='', device_sparse=''):
+    """Perform an allreduce on a tf.Tensor or tf.IndexedSlices.
+
+    Arguments:
+        tensor: tf.Tensor, tf.Variable, or tf.IndexedSlices to reduce.
+                The shape of the input must be identical across all ranks.
+        average: If True, computes the average over all ranks.
+                 Otherwise, computes the sum over all ranks.
+        device_dense: Device to be used for dense tensors. Uses GPU by default
+                      if Horovod was built with HOROVOD_GPU_ALLREDUCE.
+        device_sparse: Device to be used for sparse tensors. Uses GPU by default
+                       if Horovod was built with HOROVOD_GPU_ALLGATHER.
+
+    This function performs a bandwidth-optimal ring allreduce on the input
+    tensor.
+    If the input is a tf.IndexedSlices, the function instead does an
+    allgather on the values and the indices, effectively doing an allreduce on
+    the represented tensor.
+    """
+    if isinstance(tensor, tf.IndexedSlices):
+        with tf.device(device_sparse):
+            # For IndexedSlices, do two allgathers instead of an allreduce.
+            horovod_size = tf.cast(size(), tensor.values.dtype)
+            values = allgather(tensor.values)
+            indices = allgather(tensor.indices)
+
+            # To make this operation into an average, divide all gathered values by
+            # the Horovod size.
+            new_values = tf.div(values, horovod_size) if average else values
+        return tf.IndexedSlices(new_values, indices,
+                                dense_shape=tensor.dense_shape)
+    else:
+        with tf.device(device_dense):
+            horovod_size = tf.cast(size(), tensor.dtype)
+            summed_tensor = _allreduce(tensor)
+            new_tensor = (tf.div(summed_tensor, horovod_size)
+                          if average else summed_tensor)
+        return new_tensor
+
+
+def broadcast_global_variables(root_rank):
+    """Broadcasts all global variables from root rank to all other processes.
+
+    Arguments:
+        root_rank: rank of the process from which global variables will be
+                   broadcast to all other processes.
+    """
+    return tf.group(*[tf.assign(var, broadcast(var, root_rank))
+                      for var in tf.global_variables()])
+
+
+class BroadcastGlobalVariablesHook(tf.train.SessionRunHook):
+    """SessionRunHook that will broadcast all global variables from root rank
+    to all other processes during initialization."""
+
+    def __init__(self, root_rank, device=''):
+        """Construct a new BroadcastGlobalVariablesHook that will broadcast all
+        global variables from root rank to all other processes during initialization.
+
+        Args:
+            root_rank:
+                Rank that will send data, other ranks will receive data.
+            device:
+                Device to be used for broadcasting. Uses GPU by default
+                if Horovod was built with HOROVOD_GPU_BROADCAST.
+        """
+        self.root_rank = root_rank
+        self.bcast_op = None
+        self.device = device
+
+    def begin(self):
+        if not self.bcast_op:
+            with tf.device(self.device):
+                self.bcast_op = broadcast_global_variables(self.root_rank)
+
+    def after_create_session(self, session, coord):
+        session.run(self.bcast_op)
+
+
+class DistributedOptimizer(tf.train.Optimizer):
+    """An optimizer that wraps another tf.Optimizer, using an allreduce to
+    average gradient values before applying gradients to model weights."""
+
+    def __init__(self, optimizer, name=None, use_locking=False, device_dense='',
+                 device_sparse=''):
+        """Construct a new DistributedOptimizer, which uses another optimizer
+        under the hood for computing single-process gradient values and
+        applying gradient updates after the gradient values have been averaged
+        across all the Horovod ranks.
+
+        Args:
+            optimizer:
+                Optimizer to use for computing gradients and applying updates.
+            name:
+                Optional name prefix for the operations created when applying
+                gradients. Defaults to "Distributed" followed by the provided
+                optimizer type.
+            use_locking:
+                Whether to use locking when updating variables.
+                See Optimizer.__init__ for more info.
+            device_dense:
+                Device to be used for dense tensors. Uses GPU by default
+                if Horovod was built with HOROVOD_GPU_ALLREDUCE.
+            device_sparse:
+                Device to be used for sparse tensors. Uses GPU by default
+                if Horovod was built with HOROVOD_GPU_ALLGATHER.
+ """ + if name is None: + name = "Distributed{}".format(type(optimizer).__name__) + + self._optimizer = optimizer + self._device_dense = device_dense + self._device_sparse = device_sparse + super(DistributedOptimizer, self).__init__( + name=name, use_locking=use_locking) + + def compute_gradients(self, *args, **kwargs): + """Compute gradients of all trainable variables. + + See Optimizer.compute_gradients() for more info. + + In DistributedOptimizer, compute_gradients() is overriden to also + allreduce the gradients before returning them. + """ + gradients = (super(DistributedOptimizer, self) + .compute_gradients(*args, **kwargs)) + if size() > 1: + with tf.name_scope(self._name + "_Allreduce"): + return [(allreduce(gradient, device_dense=self._device_dense, + device_sparse=self._device_sparse), var) + for (gradient, var) in gradients] + else: + return gradients + + def _apply_dense(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._apply_dense(*args, **kwargs) + + def _resource_apply_dense(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._resource_apply_dense(*args, **kwargs) + + def _resource_apply_sparse_duplicate_indices(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._resource_apply_sparse_duplicate_indices(*args, **kwargs) + + def _resource_apply_sparse(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._resource_apply_sparse(*args, **kwargs) + + def _apply_sparse_duplicate_indices(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._apply_sparse_duplicate_indices(*args, **kwargs) + + def _apply_sparse(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._apply_sparse(*args, **kwargs) + + def _prepare(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._prepare(*args, **kwargs) + + def _create_slots(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._create_slots(*args, **kwargs) + + def _valid_dtypes(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._valid_dtypes(*args, **kwargs) + + def _finish(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._finish(*args, **kwargs) diff --git a/horovod/tensorflow/hash_vector.h b/horovod/tensorflow/hash_vector.h new file mode 100644 index 0000000000..b2d9f30565 --- /dev/null +++ b/horovod/tensorflow/hash_vector.h @@ -0,0 +1,38 @@ +// Copyright 2017 Uber Technologies, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// =============================================================================
+
+#ifndef HOROVOD_HASH_VECTOR_H
+#define HOROVOD_HASH_VECTOR_H
+
+#include <vector>
+
+namespace std {
+
+template <typename T> struct hash<std::vector<T>> {
+  typedef std::vector<T> argument_type;
+  typedef std::size_t result_type;
+
+  result_type operator()(argument_type const& in) const {
+    size_t size = in.size();
+    size_t seed = 0;
+    for (size_t i = 0; i < size; i++)
+      // Combine the hash of each element with the accumulated seed.
+      seed ^= std::hash<T>()(in[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    return seed;
+  }
+};
+
+} // namespace std
+
+#endif // HOROVOD_HASH_VECTOR_H
diff --git a/horovod/tensorflow/mpi_message.cc b/horovod/tensorflow/mpi_message.cc
new file mode 100644
index 0000000000..1a8d4d12b6
--- /dev/null
+++ b/horovod/tensorflow/mpi_message.cc
@@ -0,0 +1,246 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+// Modifications copyright (C) 2017 Uber Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mpi_message.h"
+#include "wire/mpi_message_generated.h"
+#include <string>
+
+namespace horovod {
+namespace tensorflow {
+
+const std::string& MPIDataType_Name(MPIDataType value) {
+  switch (value) {
+  case TF_MPI_UINT8:
+    static const std::string uint8("uint8");
+    return uint8;
+  case TF_MPI_INT8:
+    static const std::string int8("int8");
+    return int8;
+  case TF_MPI_UINT16:
+    static const std::string uint16("uint16");
+    return uint16;
+  case TF_MPI_INT16:
+    static const std::string int16("int16");
+    return int16;
+  case TF_MPI_INT32:
+    static const std::string int32("int32");
+    return int32;
+  case TF_MPI_INT64:
+    static const std::string int64("int64");
+    return int64;
+  case TF_MPI_FLOAT32:
+    static const std::string float32("float32");
+    return float32;
+  case TF_MPI_FLOAT64:
+    static const std::string float64("float64");
+    return float64;
+  default:
+    static const std::string unknown("<unknown>");
+    return unknown;
+  }
+}
+
+const std::string& MPIRequest::RequestType_Name(RequestType value) {
+  switch (value) {
+  case RequestType::ALLREDUCE:
+    static const std::string allreduce("ALLREDUCE");
+    return allreduce;
+  case RequestType::ALLGATHER:
+    static const std::string allgather("ALLGATHER");
+    return allgather;
+  case RequestType::BROADCAST:
+    static const std::string broadcast("BROADCAST");
+    return broadcast;
+  default:
+    static const std::string unknown("<unknown>");
+    return unknown;
+  }
+}
+
+int32_t MPIRequest::request_rank() const { return request_rank_; }
+
+void MPIRequest::set_request_rank(int32_t value) { request_rank_ = value; }
+
+MPIRequest::RequestType MPIRequest::request_type() const {
+  return request_type_;
+}
+
+void MPIRequest::set_request_type(RequestType value) { request_type_ = value; }
+
+MPIDataType MPIRequest::tensor_type() const { return tensor_type_; }
+
+void MPIRequest::set_tensor_type(MPIDataType value) { tensor_type_ = value; }
+
+const std::string& MPIRequest::tensor_name() const { return tensor_name_; }
+
+void MPIRequest::set_tensor_name(const std::string& value) {
+  tensor_name_ = value;
+}
+
+int32_t MPIRequest::root_rank() const { return root_rank_; }
+
+void MPIRequest::set_root_rank(int32_t value) { root_rank_ = value; }
+
+int32_t MPIRequest::device() const { return device_; }
+
+void MPIRequest::set_device(int32_t value) { device_ = value; }
+
+const std::vector<int64_t>& MPIRequest::tensor_shape() const {
+  return tensor_shape_;
+}
+
+void MPIRequest::set_tensor_shape(const std::vector<int64_t>& value) {
+  tensor_shape_ = value;
+}
+
+void MPIRequest::add_tensor_shape(int64_t value) {
+  tensor_shape_.push_back(value);
+}
+
+void MPIRequest::ParseFromString(MPIRequest& request,
+                                 const std::string& input) {
+  auto obj = flatbuffers::GetRoot<wire::MPIRequest>((uint8_t*)input.c_str());
+  request.set_request_rank(obj->request_rank());
+  request.set_request_type((MPIRequest::RequestType)obj->request_type());
+  request.set_tensor_type((MPIDataType)obj->tensor_type());
+  request.set_tensor_name(obj->tensor_name()->str());
+  request.set_root_rank(obj->root_rank());
+  request.set_device(obj->device());
+  request.set_tensor_shape(std::vector<int64_t>(obj->tensor_shape()->begin(),
+                                                obj->tensor_shape()->end()));
+}
+
+void MPIRequest::SerializeToString(MPIRequest& request, std::string& output) {
+  flatbuffers::FlatBufferBuilder builder(1024);
+  // Nested strings and vectors must be created before the table builder is
+  // constructed; FlatBuffers does not allow allocations while a table is
+  // being built.
+  auto tensor_name_wire = builder.CreateString(request.tensor_name());
+  auto tensor_shape_wire = builder.CreateVector(request.tensor_shape());
+  wire::MPIRequestBuilder request_builder(builder);
+  request_builder.add_request_rank(request.request_rank());
+  request_builder.add_request_type(
+      (wire::MPIRequestType)request.request_type());
+  request_builder.add_tensor_type((wire::MPIDataType)request.tensor_type());
+  request_builder.add_tensor_name(tensor_name_wire);
+  request_builder.add_root_rank(request.root_rank());
+  request_builder.add_device(request.device());
+  request_builder.add_tensor_shape(tensor_shape_wire);
+  auto obj = request_builder.Finish();
+  builder.Finish(obj);
+
+  uint8_t* buf = builder.GetBufferPointer();
+  auto size = builder.GetSize();
+  output = std::string((char*)buf, size);
+}
+
+const std::string& MPIResponse::ResponseType_Name(ResponseType value) {
+  switch (value) {
+  case ResponseType::ALLREDUCE:
+    static const std::string allreduce("ALLREDUCE");
+    return allreduce;
+  case ResponseType::ALLGATHER:
+    static const std::string allgather("ALLGATHER");
+    return allgather;
+  case ResponseType::BROADCAST:
+    static const std::string broadcast("BROADCAST");
+    return broadcast;
+  case ResponseType::ERROR:
+    static const std::string error("ERROR");
+    return error;
+  case ResponseType::DONE:
+    static const std::string done("DONE");
+    return done;
+  case ResponseType::SHUTDOWN:
+    static const std::string shutdown("SHUTDOWN");
+    return shutdown;
+  default:
+    static const std::string unknown("<unknown>");
+    return unknown;
+  }
+}
+
+MPIResponse::ResponseType MPIResponse::response_type() const {
+  return response_type_;
+}
+
+void MPIResponse::set_response_type(ResponseType value) {
+  response_type_ = value;
+}
+
+const std::string& MPIResponse::tensor_name() const { return tensor_name_; }
+
+void MPIResponse::set_tensor_name(const std::string& value) {
+  tensor_name_ = value;
+}
+
+const std::string& MPIResponse::error_message() const { return error_message_; }
+
+void MPIResponse::set_error_message(const std::string& value) {
+  error_message_ = value;
+}
+
+const std::vector<int32_t>& MPIResponse::devices() const { return devices_; }
+
+void MPIResponse::set_devices(const std::vector<int32_t>& value) {
+  devices_ = value;
+}
+
+void MPIResponse::add_devices(int32_t value) { devices_.push_back(value); }
+
+const std::vector<int64_t>& MPIResponse::tensor_sizes() const {
+  return tensor_sizes_;
+}
+
+void MPIResponse::set_tensor_sizes(const std::vector<int64_t>& value) {
+  tensor_sizes_ = value;
+}
+
+void MPIResponse::add_tensor_sizes(int64_t value) {
+  tensor_sizes_.push_back(value);
+}
+
+void MPIResponse::ParseFromString(MPIResponse& response,
+                                  const std::string& input) {
+  auto obj = flatbuffers::GetRoot<wire::MPIResponse>((uint8_t*)input.c_str());
+  response.set_response_type((MPIResponse::ResponseType)obj->response_type());
+  response.set_tensor_name(obj->tensor_name()->str());
+  response.set_error_message(obj->error_message()->str());
+  response.set_devices(
+      std::vector<int32_t>(obj->devices()->begin(), obj->devices()->end()));
+  response.set_tensor_sizes(std::vector<int64_t>(obj->tensor_sizes()->begin(),
+                                                 obj->tensor_sizes()->end()));
+}
+
+void MPIResponse::SerializeToString(MPIResponse& response,
+                                    std::string& output) {
+  flatbuffers::FlatBufferBuilder builder(1024);
+  // As with MPIRequest, create nested strings and vectors before the table
+  // builder so no allocation happens while the table is being built.
+  auto tensor_name_wire = builder.CreateString(response.tensor_name());
+  auto error_message_wire = builder.CreateString(response.error_message());
+  auto devices_wire = builder.CreateVector(response.devices());
+  auto tensor_sizes_wire = builder.CreateVector(response.tensor_sizes());
+  wire::MPIResponseBuilder response_builder(builder);
+  response_builder.add_response_type(
+      (wire::MPIResponseType)response.response_type());
+  response_builder.add_tensor_name(tensor_name_wire);
+  response_builder.add_error_message(error_message_wire);
+  response_builder.add_devices(devices_wire);
+  response_builder.add_tensor_sizes(tensor_sizes_wire);
+  auto obj = response_builder.Finish();
+  builder.Finish(obj);
+
+  uint8_t* buf = builder.GetBufferPointer();
+  auto size = builder.GetSize();
+  output = std::string((char*)buf, size);
+}
+
+} // namespace tensorflow
+} // namespace horovod
diff --git a/horovod/tensorflow/mpi_message.h b/horovod/tensorflow/mpi_message.h
new file mode 100644
index 0000000000..6da559219a
--- /dev/null
+++ b/horovod/tensorflow/mpi_message.h
@@ -0,0 +1,142 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+// Modifications copyright (C) 2017 Uber Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef HOROVOD_MPI_MESSAGE_H
+#define HOROVOD_MPI_MESSAGE_H
+
+#include <string>
+#include <vector>
+
+namespace horovod {
+namespace tensorflow {
+
+enum MPIDataType {
+  TF_MPI_UINT8 = 0,
+  TF_MPI_INT8 = 1,
+  TF_MPI_UINT16 = 2,
+  TF_MPI_INT16 = 3,
+  TF_MPI_INT32 = 4,
+  TF_MPI_INT64 = 5,
+  TF_MPI_FLOAT32 = 6,
+  TF_MPI_FLOAT64 = 7
+};
+
+const std::string& MPIDataType_Name(MPIDataType value);
+
+// An MPIRequest is a message sent from a rank greater than zero to the
+// coordinator (rank zero), informing the coordinator of an operation that
+// the rank wants to do and the tensor that it wants to apply the operation to.
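+//
+// For illustration, a request is converted to and from its FlatBuffers wire
+// format roughly like this (sender_rank is a hypothetical variable):
+//
+//   MPIRequest req;
+//   req.set_request_rank(sender_rank);
+//   req.set_request_type(MPIRequest::ALLREDUCE);
+//   std::string payload;
+//   MPIRequest::SerializeToString(req, payload);    // encode for MPI transport
+//   MPIRequest decoded;
+//   MPIRequest::ParseFromString(decoded, payload);  // decode on the coordinator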
+class MPIRequest {
+public:
+  enum RequestType { ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2 };
+
+  static const std::string& RequestType_Name(RequestType value);
+
+  // The request rank is necessary to create a consistent ordering of results,
+  // for example in the allgather where the order of outputs should be sorted
+  // by rank.
+  int32_t request_rank() const;
+  void set_request_rank(int32_t value);
+
+  RequestType request_type() const;
+  void set_request_type(RequestType value);
+
+  MPIDataType tensor_type() const;
+  void set_tensor_type(MPIDataType value);
+
+  const std::string& tensor_name() const;
+  void set_tensor_name(const std::string& value);
+
+  int32_t root_rank() const;
+  void set_root_rank(int32_t value);
+
+  int32_t device() const;
+  void set_device(int32_t value);
+
+  const std::vector<int64_t>& tensor_shape() const;
+  void set_tensor_shape(const std::vector<int64_t>& value);
+  void add_tensor_shape(int64_t value);
+
+  static void ParseFromString(MPIRequest& request, const std::string& input);
+  static void SerializeToString(MPIRequest& request, std::string& output);
+
+private:
+  int32_t request_rank_;
+  RequestType request_type_;
+  MPIDataType tensor_type_;
+  int32_t root_rank_;
+  int32_t device_;
+  std::string tensor_name_;
+  std::vector<int64_t> tensor_shape_;
+};
+
+// An MPIResponse is a message sent from the coordinator (rank zero) to a rank
+// greater than zero, informing the rank that an operation should be performed
+// now. If the operation requested would result in an error (for example, due
+// to a type or shape mismatch), then the MPIResponse can contain an error and
+// an error message instead. Finally, an MPIResponse can be a DONE message (if
+// there are no more tensors to reduce on this tick of the background loop) or
+// SHUTDOWN if all MPI processes should shut down.
+class MPIResponse {
+public:
+  enum ResponseType {
+    ALLREDUCE = 0,
+    ALLGATHER = 1,
+    BROADCAST = 2,
+    ERROR = 3,
+    DONE = 4,
+    SHUTDOWN = 5
+  };
+
+  static const std::string& ResponseType_Name(ResponseType value);
+
+  ResponseType response_type() const;
+  void set_response_type(ResponseType value);
+
+  // Empty if the type is DONE or SHUTDOWN.
+  const std::string& tensor_name() const;
+  void set_tensor_name(const std::string& value);
+
+  // Empty unless response_type is ERROR.
+  const std::string& error_message() const;
+  void set_error_message(const std::string& value);
+
+  const std::vector<int32_t>& devices() const;
+  void set_devices(const std::vector<int32_t>& value);
+  void add_devices(int32_t value);
+
+  // Empty unless response_type is ALLGATHER.
+  // These tensor sizes are the dimension zero sizes of all the input matrices,
+  // indexed by the rank.
+  const std::vector<int64_t>& tensor_sizes() const;
+  void set_tensor_sizes(const std::vector<int64_t>& value);
+  void add_tensor_sizes(int64_t value);
+
+  static void ParseFromString(MPIResponse& response, const std::string& input);
+  static void SerializeToString(MPIResponse& response, std::string& output);
+
+private:
+  ResponseType response_type_;
+  std::string tensor_name_;
+  std::string error_message_;
+  std::vector<int32_t> devices_;
+  std::vector<int64_t> tensor_sizes_;
+};
+
+} // namespace tensorflow
+} // namespace horovod
+
+#endif // HOROVOD_MPI_MESSAGE_H
diff --git a/horovod/tensorflow/mpi_ops.cc b/horovod/tensorflow/mpi_ops.cc
new file mode 100644
index 0000000000..27ee4b94b9
--- /dev/null
+++ b/horovod/tensorflow/mpi_ops.cc
@@ -0,0 +1,1451 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+// Modifications copyright (C) 2017 Uber Technologies, Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +#define EIGEN_USE_THREADS + +#if HAVE_CUDA +#include "tensorflow/stream_executor/stream.h" +#include +#endif + +#if HAVE_NCCL +#include +#endif + +#define OMPI_SKIP_MPICXX +#include "mpi.h" +#include "mpi_message.h" +#include "hash_vector.h" + +/* + * Allreduce, Allgather and Broadcast Ops for TensorFlow. + * + * TensorFlow natively provides inter-device communication through send and + * receive ops and inter-node communication through Distributed TensorFlow, + * based on the same send and receive abstractions. These end up being + * insufficient for synchronous data-parallel training on HPC clusters where + * Infiniband or other high-speed interconnects are available. This module + * implements MPI ops for allgather, allreduce and broadcast, which do + * optimized gathers, reductions and broadcasts and can take advantage of + * hardware-optimized communication libraries through the MPI implementation. + * + * The primary logic of the allreduce, allgather and broadcast are in MPI and + * NCCL implementations. The background thread which facilitates MPI operations + * is run in BackgroundThreadLoop(). The provided ops are: + * – HorovodAllreduce: + * Perform an allreduce on a Tensor, returning the sum + * across all MPI processes in the global communicator. + * – HorovodAllgather: + * Perform an allgather on a Tensor, returning the concatenation of + * the tensor on the first dimension across all MPI processes in the + * global communicator. + * - HorovodBroadcast: + * Perform a broadcast on a Tensor, broadcasting Tensor + * value from root rank to all other ranks. + * + * Additionally, this library provides C APIs to initialize Horovod and query + * rank, local rank and world size. These are used in Python directly through + * ctypes. + */ + +using namespace tensorflow; + +namespace horovod { +namespace tensorflow { + +namespace { + +// Device ID used for CPU. +#define CPU_DEVICE_ID -1 + +// Use void pointer for ready event if CUDA is not present to avoid linking +// error. +#if HAVE_CUDA +#define GPU_EVENT_IF_CUDA perftools::gputools::Event* +#else +#define GPU_EVENT_IF_CUDA void* +#endif + +// A callback to call after the MPI communication completes. Since the +// allreduce and allgather ops are asynchronous, this callback is what resumes +// computation after the reduction is completed. +typedef std::function StatusCallback; + +// Table storing Tensors to be reduced, keyed by unique name. +// This table contains everything necessary to do the reduction. +typedef struct { + // Operation context. + OpKernelContext* context; + // Input tensor. + Tensor tensor; + // Pre-allocated output tensor. + Tensor* output; + // Root rank for broadcast operation. 
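+  // Only meaningful for broadcast; allreduce and allgather leave it unset.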
+  int root_rank;
+  // Event indicating that data is ready.
+  GPU_EVENT_IF_CUDA ready_event;
+  // GPU to do reduction on, or CPU_DEVICE_ID in case of CPU.
+  int device;
+  // A callback to call with the status.
+  StatusCallback callback;
+} TensorTableEntry;
+typedef std::unordered_map<std::string, TensorTableEntry> TensorTable;
+
+// Table for storing Tensor metadata on rank zero. This is used for error
+// checking, stall checking and size calculations, as well as determining
+// when a reduction is ready to be done (when all nodes are ready to do it).
+typedef std::unordered_map<
+    std::string,
+    std::tuple<std::vector<MPIRequest>, std::chrono::system_clock::time_point>>
+    MessageTable;
+
+// The global state required for the MPI ops.
+//
+// MPI is a library that stores a lot of global per-program state and often
+// requires running on a single thread. As a result, we have to have a single
+// background thread responsible for all MPI operations, and communicate with
+// that background thread through global state.
+struct HorovodGlobalState {
+  // An atomic boolean which is set to true when background thread is started.
+  // This ensures that only one background thread is spawned.
+  std::atomic_flag initialize_flag = ATOMIC_FLAG_INIT;
+
+  // A mutex that needs to be used whenever MPI operations are done.
+  std::mutex mutex;
+
+  // Tensors waiting to be allreduced or allgathered.
+  TensorTable tensor_table;
+
+  // Queue of MPI requests waiting to be sent to the coordinator node.
+  std::queue<MPIRequest> message_queue;
+
+  // Background thread running MPI communication.
+  std::thread background_thread;
+
+  // Whether the background thread should shut down.
+  bool shut_down = false;
+
+  // Only exists on the coordinator node (rank zero). Maintains a count of
+  // how many nodes are ready to allreduce every tensor (keyed by tensor
+  // name) and time point when tensor started allreduce op.
+  std::unique_ptr<MessageTable> message_table;
+
+  // Time point when coordinator last checked for stalled tensors.
+  std::chrono::system_clock::time_point last_stall_check;
+
+  // Whether MPI_Init has been completed on the background thread.
+  bool initialization_done = false;
+
+  // The MPI rank, local rank, and size.
+  int rank = 0;
+  int local_rank = 0;
+  int size = 1;
+
+// The CUDA stream used for data transfers and within-allreduce operations.
+// A naive implementation would use the TensorFlow StreamExecutor CUDA
+// stream. However, the allreduce and allgather require doing memory copies
+// and kernel executions (for accumulation of values on the GPU), and the
+// subsequent operations must wait for those operations to complete,
+// otherwise MPI (which uses its own stream internally) will begin the data
+// transfers before the CUDA calls are complete. In order to wait for those
+// CUDA operations, if we were using the TensorFlow stream, we would have to
+// synchronize that stream; however, other TensorFlow threads may be
+// submitting more work to that stream, so synchronizing on it can cause the
+// allreduce to be delayed, waiting for compute totally unrelated to it in
+// other parts of the graph. Overlapping memory transfers and compute during
+// backpropagation is crucial for good performance, so we cannot use the
+// TensorFlow stream, and must use our own stream.
+#if HAVE_NCCL
+  std::unordered_map<int, cudaStream_t> streams;
+  std::unordered_map<std::vector<int32_t>, ncclComm_t> nccl_comms;
+#endif
+
+  ~HorovodGlobalState() {
+    // Make sure that the destructor of the background thread is safe to
+    // call. If a thread is still joinable (not detached or complete) its
+    // destructor cannot be called.
+    if (background_thread.joinable()) {
+      shut_down = true;
+      background_thread.join();
+    }
+  }
+};
+
+// All the Horovod state that must be stored globally per-process.
+static HorovodGlobalState horovod_global;
+
+// For clarity in argument lists.
+#define RANK_ZERO 0
+
+// A tag used for all coordinator messaging.
+#define TAG_NOTIFY 1
+
+// Stall-check warning time.
+#define STALL_WARNING_TIME std::chrono::seconds(60)
+
+// Store the MPIRequest for a name, and return whether the total count of
+// MPIRequests for that tensor is now equal to the MPI size (and thus we are
+// ready to reduce the tensor).
+bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
+                          MPIRequest msg, int mpi_size) {
+  auto name = msg.tensor_name();
+  auto table_iter = message_table->find(name);
+  if (table_iter == message_table->end()) {
+    std::vector<MPIRequest> messages = {msg};
+    auto now = std::chrono::system_clock::now();
+    message_table->emplace(name, std::make_tuple(std::move(messages), now));
+    table_iter = message_table->find(name);
+  } else {
+    std::vector<MPIRequest>& messages = std::get<0>(table_iter->second);
+    messages.push_back(msg);
+  }
+
+  std::vector<MPIRequest>& messages = std::get<0>(table_iter->second);
+  int count = (int)messages.size();
+  return count == mpi_size;
+}
+
+// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
+// to all ranks instructing them to start the reduction. The MPIResponse
+// also contains error messages in case the submitted MPIRequests were not
+// valid (for example, contained mismatched shapes or types).
+//
+// Constructing the MPIResponse, thus, requires a whole lot of error checking.
+MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
+                                 std::string name) {
+  bool error = false;
+  auto it = message_table->find(name);
+  assert(it != message_table->end());
+
+  std::vector<MPIRequest>& requests = std::get<0>(it->second);
+  assert(requests.size() > 0);
+
+  std::ostringstream error_message_stream;
+
+  // Check that all data types of tensors being reduced, gathered or
+  // broadcasted are identical.
+  auto data_type = requests[0].tensor_type();
+  for (unsigned int i = 1; i < requests.size(); i++) {
+    auto request_type = requests[i].tensor_type();
+    if (data_type != request_type) {
+      error = true;
+      error_message_stream << "Mismatched data types: One rank had type "
+                           << MPIDataType_Name(data_type)
+                           << ", but another rank had type "
+                           << MPIDataType_Name(request_type) << ".";
+      break;
+    }
+  }
+
+  // Check that all requested operations are the same.
+  auto message_type = requests[0].request_type();
+  for (unsigned int i = 1; i < requests.size(); i++) {
+    if (error) {
+      break;
+    }
+
+    auto request_type = requests[i].request_type();
+    if (message_type != request_type) {
+      error = true;
+      error_message_stream << "Mismatched MPI operations: One rank did an "
+                           << MPIRequest::RequestType_Name(message_type)
+                           << ", but another rank did an "
+                           << MPIRequest::RequestType_Name(request_type) << ".";
+      break;
+    }
+  }
+
+  // If we are doing an allreduce or broadcast, check that all tensor shapes
+  // are identical.
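+  // (For example, shape [2, 3] on one rank and [2, 4] on another produces an
+  // ERROR response whose message includes both shapes' DebugString() forms.)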
+ if (message_type == MPIRequest::ALLREDUCE || + message_type == MPIRequest::BROADCAST) { + TensorShape tensor_shape; + for (auto it = requests[0].tensor_shape().begin(); + it != requests[0].tensor_shape().end(); it++) { + tensor_shape.AddDim(*it); + } + for (unsigned int i = 1; i < requests.size(); i++) { + if (error) { + break; + } + + TensorShape request_shape; + for (auto it = requests[i].tensor_shape().begin(); + it != requests[i].tensor_shape().end(); it++) { + request_shape.AddDim(*it); + } + if (tensor_shape != request_shape) { + error = true; + error_message_stream + << "Mismatched " << MPIRequest::RequestType_Name(message_type) + << " tensor shapes: One rank sent a tensor of shape " + << tensor_shape.DebugString() + << ", but another rank sent a tensor of shape " + << request_shape.DebugString() << "."; + break; + } + } + } + + // If we are doing an allgather, make sure all but the first dimension are + // the same. The first dimension may be different and the output tensor is + // the sum of the first dimension. Collect the sizes by rank. + std::vector tensor_sizes(requests.size()); + if (message_type == MPIRequest::ALLGATHER) { + TensorShape tensor_shape; + for (auto it = requests[0].tensor_shape().begin(); + it != requests[0].tensor_shape().end(); it++) { + tensor_shape.AddDim(*it); + } + + if (tensor_shape.dims() == 0) { + error = true; + error_message_stream << "Rank zero tried to " + << MPIRequest::RequestType_Name(message_type) + << " a rank-zero tensor."; + } else { + tensor_sizes[requests[0].request_rank()] = + size_t(tensor_shape.dim_size(0)); + } + + for (unsigned int i = 1; i < requests.size(); i++) { + if (error) { + break; + } + + TensorShape request_shape; + for (auto it = requests[i].tensor_shape().begin(); + it != requests[i].tensor_shape().end(); it++) { + request_shape.AddDim(*it); + } + if (tensor_shape.dims() != request_shape.dims()) { + error = true; + error_message_stream + << "Mismatched " << MPIRequest::RequestType_Name(message_type) + << " tensor shapes: One rank sent a tensor of rank " + << tensor_shape.dims() + << ", but another rank sent a tensor of rank " + << request_shape.dims() << "."; + break; + } + + bool dim_mismatch = false; + for (int dim = 1; dim < tensor_shape.dims(); dim++) { + if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) { + error = true; + error_message_stream + << "Mismatched " << MPIRequest::RequestType_Name(message_type) + << " tensor shapes: One rank sent a tensor with dimension " << dim + << " equal to " << tensor_shape.dim_size(dim) + << ", but another rank sent a tensor with dimension " << dim + << " equal to " << request_shape.dim_size(dim) << "."; + dim_mismatch = true; + break; + } + } + if (dim_mismatch) { + break; + } + + tensor_sizes[requests[i].request_rank()] = + size_t(request_shape.dim_size(0)); + } + } + + // If we are doing a broadcast, check that all root ranks are identical. 
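+  // (For example, if one rank requests root rank 0 and another requests root
+  // rank 1 for the same tensor name, the coordinator responds with an ERROR
+  // naming both root ranks instead of silently deadlocking.)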
+ if (message_type == MPIRequest::BROADCAST) { + int first_root_rank = requests[0].root_rank(); + for (unsigned int i = 1; i < requests.size(); i++) { + if (error) { + break; + } + + int this_root_rank = requests[i].root_rank(); + if (first_root_rank != this_root_rank) { + error = true; + error_message_stream + << "Mismatched " << MPIRequest::RequestType_Name(message_type) + << " root ranks: One rank specified root rank " << first_root_rank + << ", but another rank specified root rank " << this_root_rank + << "."; + break; + } + } + } + + bool first_device_is_cpu = requests[0].device() == CPU_DEVICE_ID; + for (unsigned int i = 1; i < requests.size(); i++) { + if (error) { + break; + } + + bool this_device_is_cpu = requests[i].device() == CPU_DEVICE_ID; + if (first_device_is_cpu != this_device_is_cpu) { + error = true; + error_message_stream + << "Mismatched " << MPIRequest::RequestType_Name(message_type) + << " CPU/GPU device selection: One rank specified device " + << (first_device_is_cpu ? "CPU" : "GPU") + << ", but another rank specified device " + << (this_device_is_cpu ? "CPU" : "GPU") << "."; + break; + } + } + std::vector devices(requests.size()); + for (auto it = requests.begin(); it != requests.end(); it++) { + devices[it->request_rank()] = it->device(); + } + + MPIResponse response; + response.set_tensor_name(name); + if (error) { + std::string error_message = error_message_stream.str(); + response.set_response_type(MPIResponse::ERROR); + response.set_error_message(error_message); + } else if (message_type == MPIRequest::ALLGATHER) { + response.set_response_type(MPIResponse::ALLGATHER); + for (auto dim : tensor_sizes) { + response.add_tensor_sizes(dim); + } + } else if (message_type == MPIRequest::ALLREDUCE) { + response.set_response_type(MPIResponse::ALLREDUCE); + } else if (message_type == MPIRequest::BROADCAST) { + response.set_response_type(MPIResponse::BROADCAST); + } + response.set_devices(devices); + + // Clear all queued up requests for this name. They are now taken care of + // by the constructed MPI response. + message_table->erase(it); + + return response; +} + +Status GetMPIDataType(const Tensor tensor, MPI_Datatype* dtype) { + switch (tensor.dtype()) { + case DT_UINT8: + *dtype = MPI_UINT8_T; + return Status::OK(); + case DT_INT8: + *dtype = MPI_INT8_T; + return Status::OK(); + case DT_UINT16: + *dtype = MPI_UINT16_T; + return Status::OK(); + case DT_INT16: + *dtype = MPI_INT16_T; + return Status::OK(); + case DT_INT32: + *dtype = MPI_INT32_T; + return Status::OK(); + case DT_INT64: + *dtype = MPI_INT64_T; + return Status::OK(); + case DT_FLOAT: + *dtype = MPI_FLOAT; + return Status::OK(); + case DT_DOUBLE: + *dtype = MPI_DOUBLE; + return Status::OK(); + default: + // This is not reachable normally since we specify acceptable + // data types in Op definition. + return errors::Internal("Invalid tensor type."); + } +} + +#if HAVE_NCCL +Status GetNCCLDataType(const Tensor tensor, ncclDataType_t* dtype) { + switch (tensor.dtype()) { + case DT_INT32: + *dtype = ncclInt32; + return Status::OK(); + case DT_INT64: + *dtype = ncclInt64; + return Status::OK(); + case DT_FLOAT: + *dtype = ncclFloat32; + return Status::OK(); + case DT_DOUBLE: + *dtype = ncclFloat64; + return Status::OK(); + default: + // This is not reachable normally since we specify acceptable + // data types in Op definition. 
+ return errors::Internal("Invalid tensor type."); + } +} +#endif + +#define MPI_CHECK(e, op_name, op) \ + { \ + auto mpi_result = (op); \ + if (mpi_result != MPI_SUCCESS) { \ + e.callback( \ + errors::Unknown(op_name, " failed, see MPI output for details.")); \ + return; \ + } \ + } + +#define CUDA_CHECK(e, op_name, op) \ + { \ + auto cuda_result = (op); \ + if (cuda_result != cudaSuccess) { \ + e.callback(errors::Unknown( \ + op_name, " failed: ", cudaGetErrorString(cuda_result))); \ + return; \ + } \ + } + +#define NCCL_CHECK(e, op_name, op) \ + { \ + auto nccl_result = (op); \ + if (nccl_result != ncclSuccess) { \ + e.callback(errors::Unknown( \ + op_name, " failed: ", ncclGetErrorString(nccl_result))); \ + return; \ + } \ + } + +// Process an MPIResponse by doing a reduction, a gather, a broadcast, or +// raising an error. +void PerformOperation(TensorTable& tensor_table, MPIResponse response) { + TensorTableEntry e; + { + // Lock on the tensor table. + std::lock_guard guard(horovod_global.mutex); + + // We should never fail at finding this key in the tensor table. + auto name = response.tensor_name(); + auto iter = tensor_table.find(name); + assert(iter != tensor_table.end()); + + assert(response.response_type() == MPIResponse::ALLREDUCE || + response.response_type() == MPIResponse::ALLGATHER || + response.response_type() == MPIResponse::BROADCAST || + response.response_type() == MPIResponse::ERROR); + + e = iter->second; + + // Clear the tensor table of this tensor and its callbacks; the rest of + // this function takes care of it. + tensor_table.erase(iter); + } + +#if HAVE_CUDA + // On GPU data readiness is signalled by ready_event. + if (e.ready_event != nullptr) { + while (e.ready_event->PollForStatus() == + perftools::gputools::Event::Status::kPending) { + std::this_thread::sleep_for(std::chrono::nanoseconds(100)); + } + } +#endif + + Status status; + if (response.response_type() == MPIResponse::ALLGATHER) { + // Copy tensor sizes from the MPI response into a vector of size_t + // and compute total size. This is size of first dimension. + std::vector tensor_sizes; + size_t total_dimension_size = 0; + for (auto it = response.tensor_sizes().begin(); + it != response.tensor_sizes().end(); it++) { + tensor_sizes.push_back(size_t(*it)); + total_dimension_size += size_t(*it); + } + + // Every tensor participating in Allgather operation may have different + // first dimension size, but the rest of dimensions are same for all + // tensors. Here we get shape of tensor sliced by first dimension. + TensorShape single_slice_shape; + for (int i = 1; i < e.tensor.shape().dims(); i++) { + single_slice_shape.AddDim(e.tensor.dim_size(i)); + } + + // Allgather output will have shape of: + // (sum of first dimension of every tensor) x (tensor slice shape). + TensorShape output_shape; + output_shape.AddDim((int64)total_dimension_size); + output_shape.AppendShape(single_slice_shape); + + status = e.context->allocate_output(0, output_shape, &e.output); + if (!status.ok()) { + e.callback(status); + return; + } + +#if HAVE_CUDA + // On GPU allocation is asynchronous, we need to wait for it to complete. 
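+    // (BlockHostUntilDone below is a conservative way to guarantee that the
+    // output buffer exists before MPI reads from it. Further down, the
+    // Allgatherv displacements are a running prefix sum of recvcounts: as a
+    // hypothetical example, with per-rank first dimensions {4, 2, 3} and a
+    // 10-element slice, recvcounts = {40, 20, 30} and displcmnts = {0, 40,
+    // 60}.)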
+ auto device_context = e.context->op_device_context(); + if (device_context != nullptr) { + device_context->stream()->BlockHostUntilDone(); + } +#endif + + MPI_Datatype dtype; + status = GetMPIDataType(e.tensor, &dtype); + if (!status.ok()) { + e.callback(status); + return; + } + + // Tensors may have different first dimension, so we need to use + // MPI_Allgatherv API that supports gathering arrays of different length. + int* recvcounts = new int[tensor_sizes.size()]; + int* displcmnts = new int[tensor_sizes.size()]; + for (size_t i = 0; i < tensor_sizes.size(); i++) { + recvcounts[i] = + (int)(single_slice_shape.num_elements() * tensor_sizes[i]); + if (i == 0) { + displcmnts[i] = 0; + } else { + displcmnts[i] = recvcounts[i - 1] + displcmnts[i - 1]; + } + } + auto result = MPI_Allgatherv((const void*)e.tensor.tensor_data().data(), + (int)e.tensor.NumElements(), dtype, + (void*)e.output->tensor_data().data(), + recvcounts, displcmnts, dtype, MPI_COMM_WORLD); + delete[] recvcounts; + delete[] displcmnts; + MPI_CHECK(e, "MPI_Allgatherv", result) + + e.callback(Status::OK()); + + } else if (response.response_type() == MPIResponse::ALLREDUCE) { +#if HOROVOD_GPU_ALLREDUCE == 'N' // 'N' stands for NCCL + bool on_gpu = e.device != CPU_DEVICE_ID; + if (on_gpu) { + CUDA_CHECK(e, "cudaSetDevice", cudaSetDevice(e.device)) + + // Ensure stream is in the map before executing reduction. + cudaStream_t& stream = horovod_global.streams[e.device]; + if (stream == nullptr) { + CUDA_CHECK(e, "cudaStreamCreate", cudaStreamCreate(&stream)) + } + + // Ensure NCCL communicator is in the map before executing reduction. + ncclComm_t& nccl_comm = horovod_global.nccl_comms[response.devices()]; + if (nccl_comm == nullptr) { + ncclUniqueId nccl_id; + if (horovod_global.rank == 0) { + NCCL_CHECK(e, "ncclGetUniqueId", ncclGetUniqueId(&nccl_id)) + } + + MPI_CHECK(e, "MPI_Bcast", + MPI_Bcast((void*)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, + MPI_COMM_WORLD)); + + NCCL_CHECK(e, "ncclCommInitRank", + ncclCommInitRank(&nccl_comm, horovod_global.size, nccl_id, + horovod_global.rank)) + + // TODO: Rohit (NVIDIA): figure out why we need this sleep + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + + ncclDataType_t dtype; + status = GetNCCLDataType(e.tensor, &dtype); + if (!status.ok()) { + e.callback(status); + return; + } + + NCCL_CHECK(e, "ncclAllReduce", + ncclAllReduce((const void*)e.tensor.tensor_data().data(), + (void*)e.output->tensor_data().data(), + (size_t)e.tensor.NumElements(), dtype, ncclSum, + nccl_comm, stream)) + + // Use completion marker via event because it's faster than + // cudaStreamSynchronize(). 
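+      // (cudaEventBlockingSync makes the cudaEventSynchronize call below
+      // yield its thread rather than busy-spin, which matters because the
+      // wait happens on a detached finalizer thread.)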
+ cudaEvent_t event; + CUDA_CHECK(e, "cudaEventCreateWithFlags", + cudaEventCreateWithFlags(&event, cudaEventBlockingSync | + cudaEventDisableTiming)) + CUDA_CHECK(e, "cudaEventRecord", + cudaEventRecord(event, horovod_global.streams[e.device])) + + // TODO: use thread pool or single thread for callbacks + std::thread finalizer_thread([e, event] { + CUDA_CHECK(e, "cudaSetDevice", cudaSetDevice(e.device)) + CUDA_CHECK(e, "cudaEventSynchronize", cudaEventSynchronize(event)) + e.callback(Status::OK()); + cudaEventDestroy(event); + }); + finalizer_thread.detach(); + return; + } +#endif + + MPI_Datatype dtype; + status = GetMPIDataType(e.tensor, &dtype); + if (!status.ok()) { + e.callback(status); + return; + } + + MPI_CHECK(e, "MPI_Allreduce", + MPI_Allreduce((const void*)e.tensor.tensor_data().data(), + (void*)e.output->tensor_data().data(), + (int)e.tensor.NumElements(), dtype, MPI_SUM, + MPI_COMM_WORLD)) + + e.callback(Status::OK()); + + } else if (response.response_type() == MPIResponse::BROADCAST) { + MPI_Datatype dtype; + status = GetMPIDataType(e.tensor, &dtype); + if (!status.ok()) { + e.callback(status); + return; + } + + // On root rank, MPI_Bcast sends data, on other ranks it receives data. + void* data_ptr; + if (horovod_global.rank == e.root_rank) { + data_ptr = (void*)e.tensor.tensor_data().data(); + } else { + data_ptr = (void*)e.output->tensor_data().data(); + } + + MPI_CHECK(e, "MPI_Bcast", + MPI_Bcast(data_ptr, (int)e.tensor.NumElements(), dtype, + e.root_rank, MPI_COMM_WORLD)) + + e.callback(Status::OK()); + + } else if (response.response_type() == MPIResponse::ERROR) { + status = errors::FailedPrecondition(response.error_message()); + e.callback(status); + } +} + +// Report Tensors that were submitted to be reduced, gathered or broadcasted by +// some ranks but not others and are waiting for long time to get processed. +void CheckForStalledTensors(HorovodGlobalState& state) { + bool preamble = false; + auto now = std::chrono::system_clock::now(); + for (auto it = state.message_table->begin(); it != state.message_table->end(); + it++) { + auto tensor_name = it->first; + std::vector& messages = std::get<0>(it->second); + std::chrono::system_clock::time_point start_at = std::get<1>(it->second); + + if (now - start_at > STALL_WARNING_TIME) { + if (!preamble) { + std::cerr << "WARNING: One or more tensors were submitted to be " + "reduced, gathered or broadcasted by subset of ranks and " + "are waiting for remainder of ranks for more than " + << std::chrono::duration_cast( + STALL_WARNING_TIME) + .count() + << " seconds. "; + std::cerr << "This may indicate that different ranks are trying to " + "submit different tensors or that only subset of ranks is " + "submitting tensors, which will cause deadlock. "; + std::cerr << "Stalled ops: "; + preamble = true; + } else { + std::cerr << ", "; + } + std::cerr << tensor_name; + std::cerr << " [ready ranks:"; + for (auto msg_iter = messages.begin(); msg_iter != messages.end(); + msg_iter++) { + if (msg_iter == messages.begin()) { + std::cerr << " "; + } else { + std::cerr << ", "; + } + std::cerr << msg_iter->request_rank(); + } + std::cerr << "]"; + } + } + if (preamble) { + std::cerr << std::endl; + } +} + +// The MPI background thread loop coordinates all the MPI processes and the +// tensor reductions. The design of the communicator mechanism is limited by a +// few considerations: +// +// 1. Some MPI implementations require all MPI calls to happen from a +// single thread. 
Since TensorFlow may use several threads for graph +// processing, this means we must have our own dedicated thread for dealing +// with MPI. +// 2. We want to gracefully handle errors, when MPI processes do not +// properly agree upon what should happen (such as mismatched types or +// shapes). To do so requires the MPI processes to know about the shapes +// and types of the relevant tensors on the other processes. +// 3. The MPI reductions and gathers should be able to happen in parallel +// with other ongoing operations. This means that they cannot be blocking +// ops, but rather must be async ops, the execution of which happens on a +// separate thread. +// 4. We cannot guarantee that all the MPI processes reduce their tensors +// in the same order, so we cannot dispatch one thread per tensor, +// otherwise we may end up dispatching many blocked threads and never make +// progress if we have a thread pool limit. +// +// The coordinator currently follows a master-worker paradigm. Rank zero acts +// as the master (the "coordinator"), whereas all other ranks are simply +// workers. Each rank runs its own background thread which progresses in ticks. +// In each tick, the following actions happen: +// +// a) The workers send an MPIRequest to the coordinator, indicating what +// they would like to do (which tensor they would like to gather and +// reduce, as well as their shape and type). They repeat this for every +// tensor that they would like to operate on. +// +// b) The workers send an empty "DONE" message to the coordinator to +// indicate that there are no more tensors they wish to operate on. +// +// c) The coordinator receives the MPIRequests from the workers, as well +// as from its own TensorFlow ops, and stores them in a request table. The +// coordinator continues to receive MPIRequest messages until it has +// received MPI_SIZE number of empty "DONE" messages. +// +// d) The coordinator finds all tensors that are ready to be reduced, +// gathered, or all operations that result in an error. For each of those, +// it sends an MPIResponse to all the workers. When no more MPIResponses +// are available, it sends a "DONE" response to the workers. If the process +// is being shutdown, it instead sends a "SHUTDOWN" response. +// +// e) The workers listen for MPIResponse messages, processing each one by +// doing the required reduce or gather, until they receive a "DONE" +// response from the coordinator. At that point, the tick ends. +// If instead of "DONE" they receive "SHUTDOWN", they exit their background +// loop. +void BackgroundThreadLoop(HorovodGlobalState& state) { + // Initialize MPI. This must happen on the background thread, since not all + // MPI implementations support being called from multiple threads. + MPI_Init(NULL, NULL); + + // Get MPI rank to determine if we are rank zero. + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + bool is_coordinator = rank == 0; + + // Get MPI size to determine how many tensors to wait for before reducing. + int size; + MPI_Comm_size(MPI_COMM_WORLD, &size); + + // Determine local rank by querying the local communicator. + MPI_Comm local_comm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, + &local_comm); + int local_rank; + MPI_Comm_rank(local_comm, &local_rank); + + state.rank = rank; + state.local_rank = local_rank; + state.size = size; + state.initialization_done = true; + + // Initialize the tensor count table. No tensors are available yet. 
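+  // (Only rank zero ever allocates a MessageTable; workers do not aggregate
+  // counts locally, they forward their requests to rank zero instead.)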
+  if (is_coordinator) {
+    state.message_table = std::unique_ptr<MessageTable>(new MessageTable());
+  }
+
+  // The coordinator sends a SHUTDOWN message to trigger shutdown.
+  bool should_shut_down = false;
+  do {
+    // This delay determines thread frequency and MPI message latency.
+    std::this_thread::sleep_for(std::chrono::milliseconds(5));
+
+    // Copy the data structures from global state under this lock.
+    // However, don't keep the lock for the rest of the loop, so that
+    // enqueued stream callbacks can continue.
+    std::queue<MPIRequest> message_queue;
+    {
+      std::lock_guard<std::mutex> guard(state.mutex);
+      while (!state.message_queue.empty()) {
+        MPIRequest message = state.message_queue.front();
+        state.message_queue.pop();
+        message_queue.push(message);
+      }
+    }
+
+    // Collect all tensors that are ready to be reduced. Record them in the
+    // tensor count table (rank zero) or send them to rank zero to be
+    // recorded (everyone else).
+    std::vector<std::string> ready_to_reduce;
+    while (!message_queue.empty()) {
+      // Pop the first available message.
+      MPIRequest message = message_queue.front();
+      message_queue.pop();
+
+      if (is_coordinator) {
+        bool reduce = IncrementTensorCount(state.message_table, message, size);
+        if (reduce) {
+          ready_to_reduce.push_back(message.tensor_name());
+        }
+      } else {
+        std::string encoded_message;
+        MPIRequest::SerializeToString(message, encoded_message);
+        MPI_Send(encoded_message.c_str(), (int)encoded_message.length() + 1,
+                 MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+      }
+    }
+
+    // Rank zero has put all its own tensors in the tensor count table.
+    // Now, it should count all the tensors that are coming from other
+    // ranks at this tick. It should keep getting tensors until it gets a
+    // DONE message from all the other ranks.
+    if (is_coordinator) {
+      // Count of DONE messages. Keep receiving messages until the number
+      // of messages is equal to the number of processes. Initialize to
+      // one since the coordinator is effectively done.
+      int completed_ranks = 1;
+      while (completed_ranks != size) {
+        MPI_Status status;
+        MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+        // Find number of characters in message (including zero byte).
+        int source_rank = status.MPI_SOURCE;
+        int msg_length;
+        MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+        // If the length is zero, this is a DONE message.
+        if (msg_length == 0) {
+          completed_ranks++;
+          MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
+                   &status);
+          continue;
+        }
+
+        // Get tensor name from MPI into an std::string.
+        char* buffer = new char[msg_length];
+        MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
+                 MPI_COMM_WORLD, &status);
+        std::string received_data(buffer, (size_t)msg_length);
+        delete[] buffer;
+
+        MPIRequest received_message;
+        MPIRequest::ParseFromString(received_message, received_data);
+        auto received_name = received_message.tensor_name();
+
+        bool reduce =
+            IncrementTensorCount(state.message_table, received_message, size);
+        if (reduce) {
+          ready_to_reduce.push_back(received_name);
+        }
+      }
+
+      // At this point, rank zero should have a fully updated tensor count
+      // table and should know all the tensors that need to be reduced or
+      // gathered, and everyone else should have sent all their information
+      // to rank zero. We can now do reductions and gathers; rank zero will
+      // choose which ones and in what order, and will notify the other ranks
+      // before doing each reduction.
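+      // (Worked example, size == 2: rank 1 queues a tensor "grad", sends its
+      // serialized MPIRequest and then a zero-length DONE; rank 0 counts its
+      // own "grad" request plus rank 1's, reaches count == size, and "grad"
+      // enters ready_to_reduce. The loop below then broadcasts the ALLREDUCE
+      // response and both ranks call PerformOperation.)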
+      for (size_t i = 0; i < ready_to_reduce.size(); i++) {
+        // Notify all nodes which tensor we'd like to reduce at this step.
+        auto name = ready_to_reduce[i];
+        MPIResponse response = ConstructMPIResponse(state.message_table, name);
+
+        std::string encoded_response;
+        MPIResponse::SerializeToString(response, encoded_response);
+        for (int r = 1; r < size; r++) {
+          MPI_Send(encoded_response.c_str(),
+                   (int)encoded_response.length() + 1, MPI_BYTE, r, TAG_NOTIFY,
+                   MPI_COMM_WORLD);
+        }
+
+        // Perform the collective operation. All nodes should end up
+        // performing the same operation.
+        PerformOperation(state.tensor_table, response);
+      }
+
+      // Notify all nodes that we are done with the reductions for this tick.
+      MPIResponse done_response;
+      should_shut_down = state.shut_down;
+      done_response.set_response_type(should_shut_down ? MPIResponse::SHUTDOWN
+                                                       : MPIResponse::DONE);
+      std::string encoded_response;
+      MPIResponse::SerializeToString(done_response, encoded_response);
+      for (int r = 1; r < size; r++) {
+        MPI_Send(encoded_response.c_str(), (int)encoded_response.length() + 1,
+                 MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+      }
+
+      // Check for stalled tensors.
+      if (std::chrono::system_clock::now() - state.last_stall_check >
+          STALL_WARNING_TIME) {
+        CheckForStalledTensors(state);
+        state.last_stall_check = std::chrono::system_clock::now();
+      }
+    } else {
+      // Notify the coordinator that this node is done sending messages.
+      // A DONE message is encoded as a zero-length message.
+      MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+
+      // Receive names for tensors to reduce from rank zero.
+      // Once we receive an empty DONE message, stop waiting for more names.
+      while (true) {
+        MPI_Status status;
+        MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+        // Find number of characters in message (including zero byte).
+        int msg_length;
+        MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+        // Get tensor name from MPI into an std::string.
+        char* buffer = new char[msg_length];
+        MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
+                 &status);
+        std::string received_message(buffer, (size_t)msg_length);
+        delete[] buffer;
+
+        MPIResponse response;
+        MPIResponse::ParseFromString(response, received_message);
+        if (response.response_type() == MPIResponse::DONE) {
+          // No more messages this tick.
+          break;
+        } else if (response.response_type() == MPIResponse::SHUTDOWN) {
+          // No more messages this tick, and the background thread should
+          // shut down.
+          should_shut_down = true;
+          break;
+        } else {
+          // Process the current message.
+          PerformOperation(state.tensor_table, response);
+        }
+      }
+    }
+  } while (!should_shut_down);
+
+  // TODO: init.cu:645 WARN Cuda failure 'driver shutting down'
+  //#if HAVE_NCCL
+  //  for (auto it = horovod_global.streams.begin();
+  //       it != horovod_global.streams.end(); it++) {
+  //    cudaStreamSynchronize(it->second);
+  //  }
+  //  for (auto it = horovod_global.nccl_comms.begin();
+  //       it != horovod_global.nccl_comms.end(); it++) {
+  //    ncclCommDestroy(it->second);
+  //  }
+  //#endif
+  MPI_Finalize();
+}
+
+// Start Horovod background thread. Ensure that this is
+// only done once no matter how many times this function is called.
+void InitializeHorovodOnce() {
+  // Ensure background thread is only started once.
+  if (!horovod_global.initialize_flag.test_and_set()) {
+    horovod_global.background_thread =
+        std::thread(BackgroundThreadLoop, std::ref(horovod_global));
+  }
+
+  // Wait to ensure that the background thread has finished initializing MPI.
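+  // (std::atomic_flag::test_and_set above returns the previously stored
+  // value, so only the first caller observes false and spawns the thread;
+  // every later caller skips straight to this wait.)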
+ while (!horovod_global.initialization_done) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } +} + +// Check that Horovod is initialized. +Status CheckInitialized() { + if (!horovod_global.initialization_done) { + return errors::FailedPrecondition( + "Horovod has not been initialized; use horovod.tensorflow.init()."); + } + return Status::OK(); +} + +// C interface to initialize Horovod. +extern "C" void horovod_tensorflow_init() { InitializeHorovodOnce(); } + +// C interface to get index of current Horovod process. +// Returns -1 if Horovod is not initialized. +extern "C" int horovod_tensorflow_rank() { + if (!horovod_global.initialization_done) { + return -1; + } + return horovod_global.rank; +} + +// C interface to get index of current Horovod process in the node it is on.. +// Returns -1 if Horovod is not initialized. +extern "C" int horovod_tensorflow_local_rank() { + if (!horovod_global.initialization_done) { + return -1; + } + return horovod_global.local_rank; +} + +// C interface to return number of Horovod processes. +// Returns -1 if Horovod is not initialized. +extern "C" int horovod_tensorflow_size() { + if (!horovod_global.initialization_done) { + return -1; + } + return horovod_global.size; +} + +// Convert a TensorFlow DataType to our MPIDataType. +Status DataTypeToMPIType(DataType tf_dtype, MPIDataType* mpi_dtype) { + switch (tf_dtype) { + case DT_UINT8: + *mpi_dtype = TF_MPI_UINT8; + return Status::OK(); + case DT_INT8: + *mpi_dtype = TF_MPI_INT8; + return Status::OK(); + case DT_UINT16: + *mpi_dtype = TF_MPI_UINT16; + return Status::OK(); + case DT_INT16: + *mpi_dtype = TF_MPI_INT16; + return Status::OK(); + case DT_INT32: + *mpi_dtype = TF_MPI_INT32; + return Status::OK(); + case DT_INT64: + *mpi_dtype = TF_MPI_INT64; + return Status::OK(); + case DT_FLOAT: + *mpi_dtype = TF_MPI_FLOAT32; + return Status::OK(); + case DT_DOUBLE: + *mpi_dtype = TF_MPI_FLOAT64; + return Status::OK(); + default: + return errors::Internal("Invalid tensor type."); + } +} + +// MPI must be initialized and the background thread must be running before +// this function is called. +void EnqueueTensorAllreduce(OpKernelContext* context, const Tensor& tensor, + Tensor* output, GPU_EVENT_IF_CUDA ready_event, + const std::string name, const int device, + StatusCallback callback) { + MPIDataType dtype; + Status status = DataTypeToMPIType(tensor.dtype(), &dtype); + if (!status.ok()) { + callback(status); + return; + } + + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + MPIRequest message; + message.set_request_rank(rank); + message.set_tensor_name(name); + message.set_tensor_type(dtype); + message.set_device(device); + message.set_request_type(MPIRequest::ALLREDUCE); + for (int i = 0; i < tensor.shape().dims(); i++) { + message.add_tensor_shape(tensor.shape().dim_size(i)); + } + + TensorTableEntry e; + e.context = context; + e.tensor = tensor; + e.output = output; + e.ready_event = ready_event; + e.device = device; + e.callback = callback; + + std::lock_guard guard(horovod_global.mutex); + horovod_global.tensor_table.emplace(name, std::move(e)); + horovod_global.message_queue.push(message); +} + +// MPI must be initialized and the background thread must be running before +// this function is called. 
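+// (Each Enqueue* function pairs a TensorTableEntry, holding the local state
+// needed to complete the op, with an MPIRequest advertising the op to the
+// coordinator; both are published under horovod_global.mutex so the
+// background thread observes them together.)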
+void EnqueueTensorAllgather(OpKernelContext* context, const Tensor& tensor, + GPU_EVENT_IF_CUDA ready_event, + const std::string name, const int device, + StatusCallback callback) { + MPIDataType dtype; + Status status = DataTypeToMPIType(tensor.dtype(), &dtype); + if (!status.ok()) { + callback(status); + return; + } + + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + MPIRequest message; + message.set_request_rank(rank); + message.set_tensor_name(name); + message.set_tensor_type(dtype); + message.set_device(device); + message.set_request_type(MPIRequest::ALLGATHER); + for (int i = 0; i < tensor.shape().dims(); i++) { + message.add_tensor_shape(tensor.shape().dim_size(i)); + } + + TensorTableEntry e; + e.context = context; + e.tensor = tensor; + e.ready_event = ready_event; + e.device = device; + e.callback = callback; + + std::lock_guard guard(horovod_global.mutex); + horovod_global.tensor_table.emplace(name, std::move(e)); + horovod_global.message_queue.push(message); +} + +// MPI must be initialized and the background thread must be running before +// this function is called. +void EnqueueTensorBroadcast(OpKernelContext* context, const Tensor& tensor, + Tensor* output, int root_rank, + GPU_EVENT_IF_CUDA ready_event, + const std::string name, const int device, + StatusCallback callback) { + MPIDataType dtype; + Status status = DataTypeToMPIType(tensor.dtype(), &dtype); + if (!status.ok()) { + callback(status); + return; + } + + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + MPIRequest message; + message.set_request_rank(rank); + message.set_tensor_name(name); + message.set_tensor_type(dtype); + message.set_root_rank(root_rank); + message.set_device(device); + message.set_request_type(MPIRequest::BROADCAST); + for (int i = 0; i < tensor.shape().dims(); i++) { + message.add_tensor_shape(tensor.shape().dim_size(i)); + } + + TensorTableEntry e; + e.context = context; + e.tensor = tensor; + e.output = output; + e.root_rank = root_rank; + e.ready_event = ready_event; + e.device = device; + e.callback = callback; + + std::lock_guard guard(horovod_global.mutex); + horovod_global.tensor_table.emplace(name, std::move(e)); + horovod_global.message_queue.push(message); +} + +int GetDeviceID(OpKernelContext* context) { + int device = CPU_DEVICE_ID; + if (context->device() != nullptr && + context->device()->tensorflow_gpu_device_info() != nullptr) { + device = context->device()->tensorflow_gpu_device_info()->gpu_id; + } + return device; +} + +// On GPU this event will signal that data is ready, and tensors are +// allocated. 
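+// (Returns nullptr when built without CUDA or when the op runs on CPU;
+// PerformOperation treats a null ready_event as "no wait required".)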
+GPU_EVENT_IF_CUDA RecordReadyEvent(OpKernelContext* context) { +#if HAVE_CUDA + auto device_context = context->op_device_context(); + if (device_context != nullptr) { + auto executor = device_context->stream()->parent(); + GPU_EVENT_IF_CUDA ready_event = new perftools::gputools::Event(executor); + ready_event->Init(); + device_context->stream()->ThenRecordEvent(ready_event); + return ready_event; + } +#endif + return nullptr; +} + +} // namespace tensorflow + +class HorovodAllreduceOp : public AsyncOpKernel { +public: + explicit HorovodAllreduceOp(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + OP_REQUIRES_OK(context, CheckInitialized()); + + auto node_name = name(); + auto device = GetDeviceID(context); + auto tensor = context->input(0); + Tensor* output; + OP_REQUIRES_OK(context, + context->allocate_output(0, tensor.shape(), &output)); + GPU_EVENT_IF_CUDA ready_event = RecordReadyEvent(context); + EnqueueTensorAllreduce(context, tensor, output, ready_event, node_name, + device, [context, done](const Status& status) { + context->SetStatus(status); + done(); + }); + } +}; + +REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_CPU), + HorovodAllreduceOp); +#if HOROVOD_GPU_ALLREDUCE +REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_GPU), + HorovodAllreduceOp); +#endif + +REGISTER_OP("HorovodAllreduce") + .Attr("T: {int32, int64, float32, float64}") + .Input("tensor: T") + .Output("sum: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Allreduce on a tensor. All other processes that do a reduction +on a tensor with the same name must have the same dimension for that tensor. +Tensors are reduced with other tensors that have the same node name for the +allreduce. + +Arguments + tensor: A tensor to reduce. + +Output + sum: A tensor with the same shape as `tensor`, summed across all MPI processes. +)doc"); + +class HorovodAllgatherOp : public AsyncOpKernel { +public: + explicit HorovodAllgatherOp(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + OP_REQUIRES_OK(context, CheckInitialized()); + + auto node_name = name(); + auto device = GetDeviceID(context); + auto tensor = context->input(0); + // We cannot pre-allocate output for allgather, since shape of result + // is only known after all ranks make a request. + GPU_EVENT_IF_CUDA ready_event = RecordReadyEvent(context); + EnqueueTensorAllgather(context, tensor, ready_event, node_name, device, + [context, done](const Status& status) { + context->SetStatus(status); + done(); + }); + } +}; // namespace tensorflow + +REGISTER_KERNEL_BUILDER(Name("HorovodAllgather").Device(DEVICE_CPU), + HorovodAllgatherOp); +#if HOROVOD_GPU_ALLGATHER +REGISTER_KERNEL_BUILDER(Name("HorovodAllgather").Device(DEVICE_GPU), + HorovodAllgatherOp); +#endif + +REGISTER_OP("HorovodAllgather") + .Attr("T: {uint8, int8, uint16, int16, int32, int64, float32, float64}") + .Input("tensor: T") + .Output("output: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle output; + TF_RETURN_IF_ERROR( + c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output)); + c->set_output(0, output); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Allgather on a tensor. 
All other processes that do a gather on a +tensor with the same name must have the same rank for that tensor, and have the +same dimension on all but the first dimension. + +Arguments + tensor: A tensor to gather. + +Output + gathered: A tensor with the same shape as `tensor` except for the first dimension. +)doc"); + +class HorovodBroadcastOp : public AsyncOpKernel { +public: + explicit HorovodBroadcastOp(OpKernelConstruction* context) + : AsyncOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("root_rank", &root_rank_)); + } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + OP_REQUIRES_OK(context, CheckInitialized()); + + auto node_name = name(); + auto device = GetDeviceID(context); + auto tensor = context->input(0); + Tensor* output = nullptr; + if (horovod_global.rank == root_rank_) { + context->set_output(0, tensor); + } else { + OP_REQUIRES_OK(context, + context->allocate_output(0, tensor.shape(), &output)); + } + GPU_EVENT_IF_CUDA ready_event = RecordReadyEvent(context); + EnqueueTensorBroadcast(context, tensor, output, root_rank_, ready_event, + node_name, device, + [context, done](const Status& status) { + context->SetStatus(status); + done(); + }); + } + +private: + int root_rank_; +}; + +REGISTER_KERNEL_BUILDER(Name("HorovodBroadcast").Device(DEVICE_CPU), + HorovodBroadcastOp); +#if HOROVOD_GPU_BROADCAST +REGISTER_KERNEL_BUILDER(Name("HorovodBroadcast").Device(DEVICE_GPU), + HorovodBroadcastOp); +#endif + +REGISTER_OP("HorovodBroadcast") + .Attr("T: {uint8, int8, uint16, int16, int32, int64, float32, float64}") + .Attr("root_rank: int") + .Input("tensor: T") + .Output("output: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Broadcast on a tensor. All other processes that do a broadcast +on a tensor with the same name must have the same dimension for that tensor. + +Arguments + tensor: A tensor to broadcast. + root_rank: Rank that will send data, other ranks will receive data. + +Output + output: A tensor with the same shape as `tensor` and same value as + `tensor` on root rank. +)doc"); + +} // namespace tensorflow +} // namespace horovod diff --git a/horovod/tensorflow/mpi_ops.py b/horovod/tensorflow/mpi_ops.py new file mode 100644 index 0000000000..8db0c9e387 --- /dev/null +++ b/horovod/tensorflow/mpi_ops.py @@ -0,0 +1,190 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Modifications copyright (C) 2017 Uber Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +"""Inter-process communication using MPI.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ctypes +import re +import sysconfig +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.platform import resource_loader + + +def _get_ext_suffix(): + """Determine library extension for various versions of Python.""" + ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') + if ext_suffix: + return ext_suffix + + ext_suffix = sysconfig.get_config_var('SO') + if ext_suffix: + return ext_suffix + + return '.so' + + +def _load_library(name, op_list=None): + """Loads a .so file containing the specified operators. + + Args: + name: The name of the .so file to load. + op_list: A list of names of operators that the library should have. If None + then the .so file's contents will not be verified. + + Raises: + NameError if one of the required ops is missing. + NotFoundError if were not able to load .so file. + """ + filename = resource_loader.get_path_to_datafile(name) + library = load_library.load_op_library(filename) + for expected_op in (op_list or []): + for lib_op in library.OP_LIST.op: + if lib_op.name == expected_op: + break + else: + raise NameError( + 'Could not find operator %s in dynamic library %s' % + (expected_op, name)) + return library + + +def _load_ctypes_dll(name): + filename = resource_loader.get_path_to_datafile(name) + return ctypes.CDLL(filename, mode=ctypes.RTLD_GLOBAL) + + +MPI_LIB = _load_library('mpi_lib' + _get_ext_suffix(), + ['HorovodAllgather', 'HorovodAllreduce']) + + +MPI_LIB_CTYPES = _load_ctypes_dll('mpi_lib' + _get_ext_suffix()) + + +def init(): + """A function which initializes Horovod. + """ + return MPI_LIB_CTYPES.horovod_tensorflow_init() + + +def size(): + """A function which returns the number of Horovod processes. + + Returns: + An integer scalar containing the number of Horovod processes. + """ + size = MPI_LIB_CTYPES.horovod_tensorflow_size() + if size == -1: + raise ValueError( + 'Horovod has not been initialized; use horovod.tensorflow.init().') + return size + + +def rank(): + """A function which returns the Horovod rank of the calling process. + + Returns: + An integer scalar with the Horovod rank of the calling process. + """ + rank = MPI_LIB_CTYPES.horovod_tensorflow_rank() + if rank == -1: + raise ValueError( + 'Horovod has not been initialized; use horovod.tensorflow.init().') + return rank + + +def local_rank(): + """A function which returns the local Horovod rank of the calling process, within the + node that it is running on. For example, if there are seven processes running + on a node, their local ranks will be zero through six, inclusive. + + Returns: + An integer scalar with the local Horovod rank of the calling process. + """ + local_rank = MPI_LIB_CTYPES.horovod_tensorflow_local_rank() + if local_rank == -1: + raise ValueError( + 'Horovod has not been initialized; use horovod.tensorflow.init().') + return local_rank + + +def _normalize_name(name): + """Normalizes operation name to TensorFlow rules.""" + return re.sub('[^a-zA-Z0-9_]', '_', name) + + +def _allreduce(tensor, name=None): + """An op which sums an input tensor over all the Horovod processes. + + The reduction operation is keyed by the name of the op. The tensor type and + shape must be the same on all Horovod processes for a given name. 
The reduction + will not start until all processes are ready to send and receive the tensor. + + Returns: + A tensor of the same shape and type as `tensor`, summed across all + processes. + """ + if name is None: + name = 'HorovodAllreduce_%s' % _normalize_name(tensor.name) + return MPI_LIB.horovod_allreduce(tensor, name=name) + + +ops.NotDifferentiable('HorovodAllreduce') + + +def allgather(tensor, name=None): + """An op which concatenates the input tensor with the same input tensor on + all other Horovod processes. + + The concatenation is done on the first dimension, so the input tensors on the + different processes must have the same rank and shape, except for the first + dimension, which is allowed to be different. + + Returns: + A tensor of the same type as `tensor`, concatenated on dimension zero + across all processes. The shape is identical to the input shape, except for + the first dimension, which may be greater and is the sum of all first + dimensions of the tensors in different Horovod processes. + """ + if name is None: + name = 'HorovodAllgather_%s' % _normalize_name(tensor.name) + return MPI_LIB.horovod_allgather(tensor, name=name) + + +ops.NotDifferentiable('HorovodAllgather') + + +def broadcast(tensor, root_rank, name=None): + """An op which broadcasts the input tensor on root rank to the same input tensor + on all other Horovod processes. + + The broadcast operation is keyed by the name of the op. The tensor type and + shape must be the same on all Horovod processes for a given name. The broadcast + will not start until all processes are ready to send and receive the tensor. + + Returns: + A tensor of the same shape and type as `tensor`, with the value broadcasted + from root rank. + """ + if name is None: + name = 'HorovodBroadcast_%s' % _normalize_name(tensor.name) + return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank) + + +ops.NotDifferentiable('HorovodBroadcast') diff --git a/horovod/tensorflow/mpi_ops_test.py b/horovod/tensorflow/mpi_ops_test.py new file mode 100644 index 0000000000..f12397ce6f --- /dev/null +++ b/horovod/tensorflow/mpi_ops_test.py @@ -0,0 +1,462 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Modifications copyright (C) 2017 Uber Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for horovod.tensorflow.mpi_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import itertools + +import tensorflow as tf + +import horovod.tensorflow as hvd + + +def mpi_env_rank_and_size(): + """Get MPI rank and size from environment variables and return them as a + tuple of integers. + + Most MPI implementations have an `mpirun` or `mpiexec` command that will + run an MPI executable and set up all communication necessary between the + different processors. 
+    As part of that setup, they will set environment
+    variables that contain the rank and size of the MPI_COMM_WORLD
+    communicator. We can read those environment variables from Python in order
+    to ensure that `hvd.rank()` and `hvd.size()` return the expected values.
+
+    Since MPI is just a standard, not an implementation, implementations
+    typically choose their own environment variable names. This function tries
+    to support several different implementations, but really it only needs to
+    support whatever implementation we want to use for the TensorFlow test
+    suite.
+
+    If this is not running under MPI, then defaults of rank zero and size one
+    are returned. (This is appropriate because when you call MPI_Init in an
+    application not started with mpirun, it will create a new independent
+    communicator with only one process in it.)
+    """
+    rank_env = "PMI_RANK OMPI_COMM_WORLD_RANK".split()
+    size_env = "PMI_SIZE OMPI_COMM_WORLD_SIZE".split()
+
+    for rank_var, size_var in zip(rank_env, size_env):
+        rank = os.environ.get(rank_var)
+        size = os.environ.get(size_var)
+        if rank is not None and size is not None:
+            return int(rank), int(size)
+
+    # Default to rank zero and size one if there are no environment variables
+    return 0, 1
+
+
+class MPITests(tf.test.TestCase):
+    """
+    Tests for ops in horovod.tensorflow.
+    """
+
+    def test_horovod_rank(self):
+        """Test that the rank returned by hvd.rank() is correct."""
+        true_rank, _ = mpi_env_rank_and_size()
+        hvd.init()
+        rank = hvd.rank()
+        self.assertEqual(true_rank, rank)
+
+    def test_horovod_size(self):
+        """Test that the size returned by hvd.size() is correct."""
+        _, true_size = mpi_env_rank_and_size()
+        hvd.init()
+        size = hvd.size()
+        self.assertEqual(true_size, size)
+
+    def test_horovod_allreduce_cpu(self):
+        """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors."""
+        hvd.init()
+        size = hvd.size()
+        with self.test_session() as session:
+            dtypes = [tf.int32, tf.int64, tf.float32, tf.float64]
+            dims = [1, 2, 3]
+            for dtype, dim in itertools.product(dtypes, dims):
+                with tf.device("/cpu:0"):
+                    tf.set_random_seed(1234)
+                    tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype)
+                    summed = hvd.allreduce(tensor, average=False)
+                multiplied = tensor * size
+                max_difference = tf.reduce_max(tf.abs(summed - multiplied))
+
+                # Threshold for floating point equality depends on number of
+                # ranks, since we're comparing against precise multiplication.
+                if size <= 3:
+                    threshold = 0
+                elif size < 10:
+                    threshold = 1e-4
+                elif size < 15:
+                    threshold = 5e-4
+                else:
+                    break
+
+                diff = session.run(max_difference)
+                self.assertTrue(diff <= threshold,
+                                "hvd.allreduce produces incorrect results")
+
+    def test_horovod_allreduce_gpu(self):
+        """Test that the allreduce works on GPUs.
+
+        This test will crash badly if used with an MPI implementation that does
+        not support GPU memory transfers directly, as it will call MPI_Send on
+        a GPU data pointer."""
+        # Only do this test if there are GPUs available.
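+        # (tf.test.is_gpu_available(cuda_only=True) returns False on CPU-only
+        # TensorFlow builds, so this test silently becomes a no-op there.)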
+ if not tf.test.is_gpu_available(cuda_only=True): + return + + hvd.init() + local_rank = hvd.local_rank() + size = hvd.size() + + one_gpu = tf.GPUOptions(visible_device_list=str(local_rank)) + gpu_config = tf.ConfigProto(gpu_options=one_gpu) + with self.test_session(config=gpu_config) as session: + dtypes = [tf.int32, tf.int64, tf.float32, tf.float64] + dims = [1, 2, 3] + for dtype, dim in itertools.product(dtypes, dims): + with tf.device("/gpu:0"): + tf.set_random_seed(1234) + tensor = tf.random_uniform( + [17] * dim, -100, 100, dtype=dtype) + summed = hvd.allreduce(tensor, average=False) + multiplied = tensor * size + max_difference = tf.reduce_max(tf.abs(summed - multiplied)) + + # Threshold for floating point equality depends on number of + # ranks, since we're comparing against precise multiplication. + if size <= 3: + threshold = 0 + elif size < 10: + threshold = 1e-4 + elif size < 15: + threshold = 5e-4 + else: + return + + diff = session.run(max_difference) + self.assertTrue(diff <= threshold, + "hvd.allreduce on GPU produces incorrect results") + + def test_horovod_allreduce_multi_gpu(self): + """Test that the allreduce works on multiple GPUs. + + This test will crash badly if used with an MPI implementation that does + not support GPU memory transfers directly, as it will call MPI_Send on + a GPU data pointer.""" + # Only do this test if there are GPUs available. + if not tf.test.is_gpu_available(cuda_only=True): + return + + hvd.init() + local_rank = hvd.local_rank() + size = hvd.size() + + iter = 0 + two_gpus = tf.GPUOptions(visible_device_list=('%d,%d' % (local_rank * 2, local_rank * 2 + 1))) + gpu_config = tf.ConfigProto(gpu_options=two_gpus) + with self.test_session(config=gpu_config) as session: + dtypes = [tf.int32, tf.int64, tf.float32, tf.float64] + dims = [1, 2, 3] + for dtype, dim in itertools.product(dtypes, dims): + iter += 1 + with tf.device("/gpu:%d" % ((iter + local_rank) % 2)): + tf.set_random_seed(1234) + tensor = tf.random_uniform( + [17] * dim, -100, 100, dtype=dtype) + summed = hvd.allreduce(tensor, average=False) + multiplied = tensor * size + max_difference = tf.reduce_max(tf.abs(summed - multiplied)) + + # Threshold for floating point equality depends on number of + # ranks, since we're comparing against precise multiplication. + if size <= 3: + threshold = 0 + elif size < 10: + threshold = 1e-4 + elif size < 15: + threshold = 5e-4 + else: + return + + diff = session.run(max_difference) + self.assertTrue(diff <= threshold, + "hvd.allreduce on GPU produces incorrect results") + + def test_horovod_allreduce_error(self): + """Test that the allreduce raises an error if different ranks try to + send tensors of different rank or dimension.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. 
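+        # (With a single process there is no second tensor to mismatch
+        # against, so the error cannot be provoked.)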
+        if size == 1:
+            return
+
+        with self.test_session() as session:
+            # Same rank, different dimension
+            tf.set_random_seed(1234)
+            dims = [17 + rank] * 3
+            tensor = tf.random_uniform(dims, -1.0, 1.0)
+            with self.assertRaises(tf.errors.FailedPreconditionError):
+                session.run(hvd.allreduce(tensor))
+
+            # Same number of elements, different rank
+            tf.set_random_seed(1234)
+            if rank == 0:
+                dims = [17, 23 * 57]
+            else:
+                dims = [17, 23, 57]
+            tensor = tf.random_uniform(dims, -1.0, 1.0)
+            with self.assertRaises(tf.errors.FailedPreconditionError):
+                session.run(hvd.allreduce(tensor))
+
+    def test_horovod_allreduce_type_error(self):
+        """Test that the allreduce raises an error if different ranks try to
+        send tensors of different type."""
+        hvd.init()
+        rank = hvd.rank()
+        size = hvd.size()
+
+        # This test does not apply if there is only one worker.
+        if size == 1:
+            return
+
+        with self.test_session() as session:
+            # Same dimensions, different type
+            dims = [17] * 3
+            tensor = tf.ones(dims,
+                             dtype=tf.int32 if rank % 2 == 0 else tf.float32)
+            with self.assertRaises(tf.errors.FailedPreconditionError):
+                session.run(hvd.allreduce(tensor))
+
+    def test_horovod_allreduce_cpu_gpu_error(self):
+        """Test that the allreduce raises an error if different ranks try to
+        perform reduction on CPU and GPU."""
+        # Only do this test if there are GPUs available.
+        if not tf.test.is_gpu_available(cuda_only=True):
+            return
+
+        hvd.init()
+        local_rank = hvd.local_rank()
+        size = hvd.size()
+
+        # This test does not apply if there is only one worker.
+        if size == 1:
+            return
+
+        device = "/gpu:0" if local_rank % 2 == 0 else "/cpu:0"
+        one_gpu = tf.GPUOptions(visible_device_list=str(local_rank))
+        gpu_config = tf.ConfigProto(gpu_options=one_gpu)
+        with self.test_session(config=gpu_config) as session:
+            with tf.device(device):
+                # Same dimensions, but alternating ranks place the tensor on
+                # CPU and GPU respectively.
+                dims = [17] * 3
+                tensor = tf.ones(dims, dtype=tf.int32)
+                with self.assertRaises(tf.errors.FailedPreconditionError):
+                    session.run(hvd.allreduce(tensor))
+
+    def test_horovod_allgather(self):
+        """Test that the allgather correctly gathers 1D, 2D, 3D tensors."""
+        hvd.init()
+        rank = hvd.rank()
+        size = hvd.size()
+
+        with self.test_session() as session:
+            dtypes = [tf.uint8, tf.int8, tf.uint16, tf.int16,
+                      tf.int32, tf.int64, tf.float32, tf.float64]
+            dims = [1, 2, 3]
+            for dtype, dim in itertools.product(dtypes, dims):
+                tensor = tf.ones([17] * dim, dtype=dtype) * rank
+                gathered = hvd.allgather(tensor)
+
+                gathered_tensor = session.run(gathered)
+                self.assertEqual(list(gathered_tensor.shape),
+                                 [17 * size] + [17] * (dim - 1))
+
+                for i in range(size):
+                    rank_tensor = tf.slice(gathered_tensor,
+                                           [i * 17] + [0] * (dim - 1),
+                                           [17] + [-1] * (dim - 1))
+                    self.assertEqual(list(rank_tensor.shape), [17] * dim)
+                    # tf.equal() does not support tf.uint16 as of TensorFlow 1.2,
+                    # so we need to cast rank_tensor to tf.int32.
+ self.assertTrue( + session.run(tf.reduce_all( + tf.equal(tf.cast(rank_tensor, tf.int32), i))), + "hvd.allgather produces incorrect gathered tensor") + + def test_horovod_allgather_variable_size(self): + """Test that the allgather correctly gathers 1D, 2D, 3D tensors, + even if those tensors have different sizes along the first dim.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + with self.test_session() as session: + dtypes = [tf.uint8, tf.int8, tf.uint16, tf.int16, + tf.int32, tf.int64, tf.float32, tf.float64] + dims = [1, 2, 3] + for dtype, dim in itertools.product(dtypes, dims): + # Support tests up to MPI Size of 35 + if size > 35: + break + + tensor_sizes = [17, 32, 81, 12, 15, 23, 22] * 5 + tensor_sizes = tensor_sizes[:size] + + tensor = tf.ones([tensor_sizes[rank]] + [17] * (dim - 1), + dtype=dtype) * rank + gathered = hvd.allgather(tensor) + + gathered_tensor = session.run(gathered) + expected_size = sum(tensor_sizes) + self.assertEqual(list(gathered_tensor.shape), + [expected_size] + [17] * (dim - 1)) + + for i in range(size): + rank_size = [tensor_sizes[i]] + [17] * (dim - 1) + rank_tensor = tf.slice( + gathered, [sum(tensor_sizes[:i])] + [0] * (dim - 1), + rank_size) + self.assertEqual(list(rank_tensor.shape), rank_size) + # tf.equal() does not support tf.uint16 as of TensorFlow 1.2, + # so need to cast rank_tensor to tf.int32. + self.assertTrue( + session.run(tf.reduce_all( + tf.equal(tf.cast(rank_tensor, tf.int32), i))), + "hvd.allgather produces incorrect gathered tensor") + + def test_horovod_allgather_error(self): + """Test that the allgather returns an error if any dimension besides + the first is different among the tensors being gathered.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + with self.test_session() as session: + tensor_size = [17] * 3 + tensor_size[1] = 10 * (rank + 1) + tensor = tf.ones(tensor_size, dtype=tf.float32) * rank + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(hvd.allgather(tensor)) + + def test_horovod_allgather_type_error(self): + """Test that the allgather returns an error if the types being gathered + differ among the processes""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + with self.test_session() as session: + tensor_size = [17] * 3 + dtype = tf.int32 if rank % 2 == 0 else tf.float32 + tensor = tf.ones(tensor_size, dtype=dtype) * rank + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(hvd.allgather(tensor)) + + def test_horovod_broadcast(self): + """Test that the broadcast correctly broadcasts 1D, 2D, 3D tensors.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. 
+ if size == 1: + return + + with self.test_session() as session: + dtypes = [tf.uint8, tf.int8, tf.uint16, tf.int16, + tf.int32, tf.int64, tf.float32, tf.float64] + dims = [1, 2, 3] + root_ranks = list(range(size)) + for dtype, dim, root_rank in itertools.product(dtypes, dims, root_ranks): + tensor = tf.ones([17] * dim, dtype=dtype) * rank + root_tensor = tf.ones([17] * dim, dtype=dtype) * root_rank + broadcasted_tensor = hvd.broadcast(tensor, root_rank) + self.assertTrue( + session.run(tf.reduce_all(tf.equal( + tf.cast(root_tensor, tf.int32), tf.cast(broadcasted_tensor, tf.int32)))), + "hvd.broadcast produces incorrect broadcasted tensor") + + def test_horovod_broadcast_error(self): + """Test that the broadcast returns an error if any dimension besides + the first is different among the tensors being broadcasted.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + with self.test_session() as session: + tensor_size = [17] * 3 + tensor_size[1] = 10 * (rank + 1) + tensor = tf.ones(tensor_size, dtype=tf.float32) * rank + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(hvd.broadcast(tensor, 0)) + + def test_horovod_broadcast_type_error(self): + """Test that the broadcast returns an error if the types being broadcasted + differ among the processes""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + with self.test_session() as session: + tensor_size = [17] * 3 + dtype = tf.int32 if rank % 2 == 0 else tf.float32 + tensor = tf.ones(tensor_size, dtype=dtype) * rank + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(hvd.broadcast(tensor, 0)) + + def test_horovod_broadcast_rank_error(self): + """Test that the broadcast returns an error if different ranks + specify different root rank.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + with self.test_session() as session: + tensor = tf.ones([17] * 3, dtype=tf.float32) + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(hvd.broadcast(tensor, rank)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/horovod/tensorflow/wire/flatbuffers/flatbuffers.h b/horovod/tensorflow/wire/flatbuffers/flatbuffers.h new file mode 100644 index 0000000000..a13921595f --- /dev/null +++ b/horovod/tensorflow/wire/flatbuffers/flatbuffers.h @@ -0,0 +1,1852 @@ +/* + * Copyright 2014 Google Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef FLATBUFFERS_H_
+#define FLATBUFFERS_H_
+
+#include <assert.h>
+
+#include <cstdint>
+#include <cstddef>
+#include <cstdlib>
+#include <cstring>
+#include <string>
+#include <type_traits>
+#include <vector>
+#include <set>
+#include <algorithm>
+#include <iterator>
+#include <memory>
+
+#ifdef _STLPORT_VERSION
+  #define FLATBUFFERS_CPP98_STL
+#endif
+#ifndef FLATBUFFERS_CPP98_STL
+  #include <functional>
+#endif
+
+/// @cond FLATBUFFERS_INTERNAL
+#if __cplusplus <= 199711L && \
+    (!defined(_MSC_VER) || _MSC_VER < 1600) && \
+    (!defined(__GNUC__) || \
+      (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__ < 40400))
+  #error A C++11 compatible compiler with support for the auto typing is \
+         required for FlatBuffers.
+  #error __cplusplus _MSC_VER __GNUC__ __GNUC_MINOR__ __GNUC_PATCHLEVEL__
+#endif
+
+#if !defined(__clang__) && \
+    defined(__GNUC__) && \
+    (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__ < 40600)
+  // Backwards compatability for g++ 4.4, and 4.5 which don't have the nullptr
+  // and constexpr keywords. Note the __clang__ check is needed, because clang
+  // presents itself as an older GNUC compiler.
+  #ifndef nullptr_t
+    const class nullptr_t {
+    public:
+      template<class T> inline operator T*() const { return 0; }
+    private:
+      void operator&() const;
+    } nullptr = {};
+  #endif
+  #ifndef constexpr
+    #define constexpr const
+  #endif
+#endif
+
+// The wire format uses a little endian encoding (since that's efficient for
+// the common platforms).
+#if !defined(FLATBUFFERS_LITTLEENDIAN)
+  #if defined(__GNUC__) || defined(__clang__)
+    #ifdef __BIG_ENDIAN__
+      #define FLATBUFFERS_LITTLEENDIAN 0
+    #else
+      #define FLATBUFFERS_LITTLEENDIAN 1
+    #endif // __BIG_ENDIAN__
+  #elif defined(_MSC_VER)
+    #if defined(_M_PPC)
+      #define FLATBUFFERS_LITTLEENDIAN 0
+    #else
+      #define FLATBUFFERS_LITTLEENDIAN 1
+    #endif
+  #else
+    #error Unable to determine endianness, define FLATBUFFERS_LITTLEENDIAN.
+  #endif
+#endif // !defined(FLATBUFFERS_LITTLEENDIAN)
+
+#define FLATBUFFERS_VERSION_MAJOR 1
+#define FLATBUFFERS_VERSION_MINOR 6
+#define FLATBUFFERS_VERSION_REVISION 0
+#define FLATBUFFERS_STRING_EXPAND(X) #X
+#define FLATBUFFERS_STRING(X) FLATBUFFERS_STRING_EXPAND(X)
+
+#if (!defined(_MSC_VER) || _MSC_VER > 1600) && \
+    (!defined(__GNUC__) || (__GNUC__ * 100 + __GNUC_MINOR__ >= 407))
+  #define FLATBUFFERS_FINAL_CLASS final
+#else
+  #define FLATBUFFERS_FINAL_CLASS
+#endif
+
+#if (!defined(_MSC_VER) || _MSC_VER >= 1900) && \
+    (!defined(__GNUC__) || (__GNUC__ * 100 + __GNUC_MINOR__ >= 406))
+  #define FLATBUFFERS_CONSTEXPR constexpr
+#else
+  #define FLATBUFFERS_CONSTEXPR
+#endif
+
+/// @endcond
+
+/// @file
+namespace flatbuffers {
+
+/// @cond FLATBUFFERS_INTERNAL
+// Our default offset / size type, 32bit on purpose on 64bit systems.
+// Also, using a consistent offset type maintains compatibility of serialized
+// offset values between 32bit and 64bit systems.
+typedef uint32_t uoffset_t;
+
+// Signed offsets for references that can go in both directions.
+typedef int32_t soffset_t;
+
+// Offset/index used in v-tables, can be changed to uint8_t in
+// format forks to save a bit of space if desired.
+typedef uint16_t voffset_t;
+
+typedef uintmax_t largest_scalar_t;
+
+// In 32bits, this evaluates to 2GB - 1
+#define FLATBUFFERS_MAX_BUFFER_SIZE ((1ULL << (sizeof(soffset_t) * 8 - 1)) - 1)
+
+// We support aligning the contents of buffers up to this size.
+#define FLATBUFFERS_MAX_ALIGNMENT 16
+
+#ifndef FLATBUFFERS_CPP98_STL
+// Pointer to relinquished memory.
+typedef std::unique_ptr<uint8_t, std::function<void(uint8_t * /* unused */)>>
+          unique_ptr_t;
+#endif
+
+// Wrapper for uoffset_t to allow safe template specialization.
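+// For example, a string written when the builder's GetSize() is 64 becomes
+// Offset<String>(64): offsets produced while building measure the distance
+// from the end of the (downward-growing) buffer, not from its start.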
+template<typename T> struct Offset {
+  uoffset_t o;
+  Offset() : o(0) {}
+  Offset(uoffset_t _o) : o(_o) {}
+  Offset<void> Union() const { return Offset<void>(o); }
+};
+
+inline void EndianCheck() {
+  int endiantest = 1;
+  // If this fails, see FLATBUFFERS_LITTLEENDIAN above.
+  assert(*reinterpret_cast<char *>(&endiantest) == FLATBUFFERS_LITTLEENDIAN);
+  (void)endiantest;
+}
+
+template<typename T> T EndianSwap(T t) {
+  #if defined(_MSC_VER)
+    #define FLATBUFFERS_BYTESWAP16 _byteswap_ushort
+    #define FLATBUFFERS_BYTESWAP32 _byteswap_ulong
+    #define FLATBUFFERS_BYTESWAP64 _byteswap_uint64
+  #else
+    #if defined(__GNUC__) && __GNUC__ * 100 + __GNUC_MINOR__ < 408
+      // __builtin_bswap16 was missing prior to GCC 4.8.
+      #define FLATBUFFERS_BYTESWAP16(x) \
+        static_cast<uint16_t>(__builtin_bswap32(static_cast<uint32_t>(x) << 16))
+    #else
+      #define FLATBUFFERS_BYTESWAP16 __builtin_bswap16
+    #endif
+    #define FLATBUFFERS_BYTESWAP32 __builtin_bswap32
+    #define FLATBUFFERS_BYTESWAP64 __builtin_bswap64
+  #endif
+  if (sizeof(T) == 1) {   // Compile-time if-then's.
+    return t;
+  } else if (sizeof(T) == 2) {
+    auto r = FLATBUFFERS_BYTESWAP16(*reinterpret_cast<uint16_t *>(&t));
+    return *reinterpret_cast<T *>(&r);
+  } else if (sizeof(T) == 4) {
+    auto r = FLATBUFFERS_BYTESWAP32(*reinterpret_cast<uint32_t *>(&t));
+    return *reinterpret_cast<T *>(&r);
+  } else if (sizeof(T) == 8) {
+    auto r = FLATBUFFERS_BYTESWAP64(*reinterpret_cast<uint64_t *>(&t));
+    return *reinterpret_cast<T *>(&r);
+  } else {
+    assert(0);
+  }
+}
+
+template<typename T> T EndianScalar(T t) {
+  #if FLATBUFFERS_LITTLEENDIAN
+    return t;
+  #else
+    return EndianSwap(t);
+  #endif
+}
+
+template<typename T> T ReadScalar(const void *p) {
+  return EndianScalar(*reinterpret_cast<const T *>(p));
+}
+
+template<typename T> void WriteScalar(void *p, T t) {
+  *reinterpret_cast<T *>(p) = EndianScalar(t);
+}
+
+template<typename T> size_t AlignOf() {
+  #ifdef _MSC_VER
+    return __alignof(T);
+  #else
+    #ifndef alignof
+      return __alignof__(T);
+    #else
+      return alignof(T);
+    #endif
+  #endif
+}
+
+// When we read serialized data from memory, in the case of most scalars,
+// we want to just read T, but in the case of Offset, we want to actually
+// perform the indirection and return a pointer.
+// The template specialization below does just that.
+// It is wrapped in a struct since function templates can't overload on the
+// return type like this.
+// The typedef is for the convenience of callers of this function
+// (avoiding the need for a trailing return decltype)
+template<typename T> struct IndirectHelper {
+  typedef T return_type;
+  typedef T mutable_return_type;
+  static const size_t element_stride = sizeof(T);
+  static return_type Read(const uint8_t *p, uoffset_t i) {
+    return EndianScalar((reinterpret_cast<const T *>(p))[i]);
+  }
+};
+template<typename T> struct IndirectHelper<Offset<T>> {
+  typedef const T *return_type;
+  typedef T *mutable_return_type;
+  static const size_t element_stride = sizeof(uoffset_t);
+  static return_type Read(const uint8_t *p, uoffset_t i) {
+    p += i * sizeof(uoffset_t);
+    return reinterpret_cast<return_type>(p + ReadScalar<uoffset_t>(p));
+  }
+};
+template<typename T> struct IndirectHelper<const T *> {
+  typedef const T *return_type;
+  typedef T *mutable_return_type;
+  static const size_t element_stride = sizeof(T);
+  static return_type Read(const uint8_t *p, uoffset_t i) {
+    return reinterpret_cast<const T *>(p + i * sizeof(T));
+  }
+};
+
+// An STL compatible iterator implementation for Vector below, effectively
+// calling Get() for every element.
+template +struct VectorIterator + : public std::iterator { + + typedef std::iterator super_type; + +public: + VectorIterator(const uint8_t *data, uoffset_t i) : + data_(data + IndirectHelper::element_stride * i) {} + VectorIterator(const VectorIterator &other) : data_(other.data_) {} + #ifndef FLATBUFFERS_CPP98_STL + VectorIterator(VectorIterator &&other) : data_(std::move(other.data_)) {} + #endif + + VectorIterator &operator=(const VectorIterator &other) { + data_ = other.data_; + return *this; + } + + VectorIterator &operator=(VectorIterator &&other) { + data_ = other.data_; + return *this; + } + + bool operator==(const VectorIterator &other) const { + return data_ == other.data_; + } + + bool operator!=(const VectorIterator &other) const { + return data_ != other.data_; + } + + ptrdiff_t operator-(const VectorIterator &other) const { + return (data_ - other.data_) / IndirectHelper::element_stride; + } + + typename super_type::value_type operator *() const { + return IndirectHelper::Read(data_, 0); + } + + typename super_type::value_type operator->() const { + return IndirectHelper::Read(data_, 0); + } + + VectorIterator &operator++() { + data_ += IndirectHelper::element_stride; + return *this; + } + + VectorIterator operator++(int) { + VectorIterator temp(data_, 0); + data_ += IndirectHelper::element_stride; + return temp; + } + + VectorIterator operator+(const uoffset_t &offset) { + return VectorIterator(data_ + offset * IndirectHelper::element_stride, 0); + } + + VectorIterator& operator+=(const uoffset_t &offset) { + data_ += offset * IndirectHelper::element_stride; + return *this; + } + + VectorIterator &operator--() { + data_ -= IndirectHelper::element_stride; + return *this; + } + + VectorIterator operator--(int) { + VectorIterator temp(data_, 0); + data_ -= IndirectHelper::element_stride; + return temp; + } + + VectorIterator operator-(const uoffset_t &offset) { + return VectorIterator(data_ - offset * IndirectHelper::element_stride, 0); + } + + VectorIterator& operator-=(const uoffset_t &offset) { + data_ -= offset * IndirectHelper::element_stride; + return *this; + } + +private: + const uint8_t *data_; +}; + +// This is used as a helper type for accessing vectors. +// Vector::data() assumes the vector elements start after the length field. +template class Vector { +public: + typedef VectorIterator::mutable_return_type> + iterator; + typedef VectorIterator::return_type> + const_iterator; + + uoffset_t size() const { return EndianScalar(length_); } + + // Deprecated: use size(). Here for backwards compatibility. + uoffset_t Length() const { return size(); } + + typedef typename IndirectHelper::return_type return_type; + typedef typename IndirectHelper::mutable_return_type mutable_return_type; + + return_type Get(uoffset_t i) const { + assert(i < size()); + return IndirectHelper::Read(Data(), i); + } + + return_type operator[](uoffset_t i) const { return Get(i); } + + // If this is a Vector of enums, T will be its storage type, not the enum + // type. This function makes it convenient to retrieve value with enum + // type E. 
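+  // e.g. colors()->GetEnum<Color>(i) for a schema field `colors: [Color]`
+  // (Color and colors() are illustrative names, not defined in this header).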
+ template E GetEnum(uoffset_t i) const { + return static_cast(Get(i)); + } + + const void *GetStructFromOffset(size_t o) const { + return reinterpret_cast(Data() + o); + } + + iterator begin() { return iterator(Data(), 0); } + const_iterator begin() const { return const_iterator(Data(), 0); } + + iterator end() { return iterator(Data(), size()); } + const_iterator end() const { return const_iterator(Data(), size()); } + + // Change elements if you have a non-const pointer to this object. + // Scalars only. See reflection.h, and the documentation. + void Mutate(uoffset_t i, const T& val) { + assert(i < size()); + WriteScalar(data() + i, val); + } + + // Change an element of a vector of tables (or strings). + // "val" points to the new table/string, as you can obtain from + // e.g. reflection::AddFlatBuffer(). + void MutateOffset(uoffset_t i, const uint8_t *val) { + assert(i < size()); + assert(sizeof(T) == sizeof(uoffset_t)); + WriteScalar(data() + i, + static_cast(val - (Data() + i * sizeof(uoffset_t)))); + } + + // Get a mutable pointer to tables/strings inside this vector. + mutable_return_type GetMutableObject(uoffset_t i) const { + assert(i < size()); + return const_cast(IndirectHelper::Read(Data(), i)); + } + + // The raw data in little endian format. Use with care. + const uint8_t *Data() const { + return reinterpret_cast(&length_ + 1); + } + + uint8_t *Data() { + return reinterpret_cast(&length_ + 1); + } + + // Similarly, but typed, much like std::vector::data + const T *data() const { return reinterpret_cast(Data()); } + T *data() { return reinterpret_cast(Data()); } + + template return_type LookupByKey(K key) const { + void *search_result = std::bsearch(&key, Data(), size(), + IndirectHelper::element_stride, KeyCompare); + + if (!search_result) { + return nullptr; // Key not found. + } + + const uint8_t *element = reinterpret_cast(search_result); + + return IndirectHelper::Read(element, 0); + } + +protected: + // This class is only used to access pre-existing data. Don't ever + // try to construct these manually. + Vector(); + + uoffset_t length_; + +private: + template static int KeyCompare(const void *ap, const void *bp) { + const K *key = reinterpret_cast(ap); + const uint8_t *data = reinterpret_cast(bp); + auto table = IndirectHelper::Read(data, 0); + + // std::bsearch compares with the operands transposed, so we negate the + // result here. + return -table->KeyCompareWithValue(*key); + } +}; + +// Represent a vector much like the template above, but in this case we +// don't know what the element types are (used with reflection.h). +class VectorOfAny { +public: + uoffset_t size() const { return EndianScalar(length_); } + + const uint8_t *Data() const { + return reinterpret_cast(&length_ + 1); + } + uint8_t *Data() { + return reinterpret_cast(&length_ + 1); + } +protected: + VectorOfAny(); + + uoffset_t length_; +}; + +// Convenient helper function to get the length of any vector, regardless +// of wether it is null or not (the field is not set). +template static inline size_t VectorLength(const Vector *v) { + return v ? v->Length() : 0; +} + +struct String : public Vector { + const char *c_str() const { return reinterpret_cast(Data()); } + std::string str() const { return std::string(c_str(), Length()); } + + bool operator <(const String &o) const { + return strcmp(c_str(), o.c_str()) < 0; + } +}; + +// Simple indirection for buffer allocation, to allow this to be overridden +// with custom allocation (see the FlatBufferBuilder constructor). 
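+// A minimal sketch of such an override (counting_allocator is illustrative,
+// not part of this header):
+//
+//   class counting_allocator : public simple_allocator {
+//    public:
+//     mutable size_t allocated = 0;
+//     uint8_t *allocate(size_t size) const override {
+//       allocated += size;
+//       return new uint8_t[size];
+//     }
+//   };
+//   counting_allocator alloc;
+//   flatbuffers::FlatBufferBuilder fbb(1024, &alloc);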
+class simple_allocator { + public: + virtual ~simple_allocator() {} + virtual uint8_t *allocate(size_t size) const { return new uint8_t[size]; } + virtual void deallocate(uint8_t *p) const { delete[] p; } +}; + +// This is a minimal replication of std::vector functionality, +// except growing from higher to lower addresses. i.e push_back() inserts data +// in the lowest address in the vector. +class vector_downward { + public: + explicit vector_downward(size_t initial_size, + const simple_allocator &allocator) + : reserved_((initial_size + sizeof(largest_scalar_t) - 1) & + ~(sizeof(largest_scalar_t) - 1)), + buf_(allocator.allocate(reserved_)), + cur_(buf_ + reserved_), + allocator_(allocator) {} + + ~vector_downward() { + if (buf_) + allocator_.deallocate(buf_); + } + + void clear() { + if (buf_ == nullptr) + buf_ = allocator_.allocate(reserved_); + + cur_ = buf_ + reserved_; + } + + #ifndef FLATBUFFERS_CPP98_STL + // Relinquish the pointer to the caller. + unique_ptr_t release() { + // Actually deallocate from the start of the allocated memory. + std::function deleter( + std::bind(&simple_allocator::deallocate, allocator_, buf_)); + + // Point to the desired offset. + unique_ptr_t retval(data(), deleter); + + // Don't deallocate when this instance is destroyed. + buf_ = nullptr; + cur_ = nullptr; + + return retval; + } + #endif + + size_t growth_policy(size_t bytes) { + return (bytes / 2) & ~(sizeof(largest_scalar_t) - 1); + } + + uint8_t *make_space(size_t len) { + if (len > static_cast(cur_ - buf_)) { + reallocate(len); + } + cur_ -= len; + // Beyond this, signed offsets may not have enough range: + // (FlatBuffers > 2GB not supported). + assert(size() < FLATBUFFERS_MAX_BUFFER_SIZE); + return cur_; + } + + uoffset_t size() const { + assert(cur_ != nullptr && buf_ != nullptr); + return static_cast(reserved_ - (cur_ - buf_)); + } + + uint8_t *data() const { + assert(cur_ != nullptr); + return cur_; + } + + uint8_t *data_at(size_t offset) const { return buf_ + reserved_ - offset; } + + void push(const uint8_t *bytes, size_t num) { + auto dest = make_space(num); + memcpy(dest, bytes, num); + } + + // Specialized version of push() that avoids memcpy call for small data. + template void push_small(T little_endian_t) { + auto dest = make_space(sizeof(T)); + *reinterpret_cast(dest) = little_endian_t; + } + + // fill() is most frequently called with small byte counts (<= 4), + // which is why we're using loops rather than calling memset. + void fill(size_t zero_pad_bytes) { + auto dest = make_space(zero_pad_bytes); + for (size_t i = 0; i < zero_pad_bytes; i++) dest[i] = 0; + } + + // Version for when we know the size is larger. + void fill_big(size_t zero_pad_bytes) { + auto dest = make_space(zero_pad_bytes); + memset(dest, 0, zero_pad_bytes); + } + + void pop(size_t bytes_to_remove) { cur_ += bytes_to_remove; } + + private: + // You shouldn't really be copying instances of this class. + vector_downward(const vector_downward &); + vector_downward &operator=(const vector_downward &); + + size_t reserved_; + uint8_t *buf_; + uint8_t *cur_; // Points at location between empty (below) and used (above). + const simple_allocator &allocator_; + + void reallocate(size_t len) { + auto old_size = size(); + auto largest_align = AlignOf(); + reserved_ += (std::max)(len, growth_policy(reserved_)); + // Round up to avoid undefined behavior from unaligned loads and stores. 
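+    // e.g. with largest_align == 8, sizes 1001..1008 all round up to 1008.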
+ reserved_ = (reserved_ + (largest_align - 1)) & ~(largest_align - 1); + auto new_buf = allocator_.allocate(reserved_); + auto new_cur = new_buf + reserved_ - old_size; + memcpy(new_cur, cur_, old_size); + cur_ = new_cur; + allocator_.deallocate(buf_); + buf_ = new_buf; + } +}; + +// Converts a Field ID to a virtual table offset. +inline voffset_t FieldIndexToOffset(voffset_t field_id) { + // Should correspond to what EndTable() below builds up. + const int fixed_fields = 2; // Vtable size and Object Size. + return static_cast((field_id + fixed_fields) * sizeof(voffset_t)); +} + +// Computes how many bytes you'd have to pad to be able to write an +// "scalar_size" scalar if the buffer had grown to "buf_size" (downwards in +// memory). +inline size_t PaddingBytes(size_t buf_size, size_t scalar_size) { + return ((~buf_size) + 1) & (scalar_size - 1); +} + +template const T* data(const std::vector &v) { + return v.empty() ? nullptr : &v.front(); +} +template T* data(std::vector &v) { + return v.empty() ? nullptr : &v.front(); +} + +/// @endcond + +/// @addtogroup flatbuffers_cpp_api +/// @{ +/// @class FlatBufferBuilder +/// @brief Helper class to hold data needed in creation of a FlatBuffer. +/// To serialize data, you typically call one of the `Create*()` functions in +/// the generated code, which in turn call a sequence of `StartTable`/ +/// `PushElement`/`AddElement`/`EndTable`, or the builtin `CreateString`/ +/// `CreateVector` functions. Do this is depth-first order to build up a tree to +/// the root. `Finish()` wraps up the buffer ready for transport. +class FlatBufferBuilder +/// @cond FLATBUFFERS_INTERNAL +FLATBUFFERS_FINAL_CLASS +/// @endcond +{ + public: + /// @brief Default constructor for FlatBufferBuilder. + /// @param[in] initial_size The initial size of the buffer, in bytes. Defaults + /// to`1024`. + /// @param[in] allocator A pointer to the `simple_allocator` that should be + /// used. Defaults to `nullptr`, which means the `default_allocator` will be + /// be used. + explicit FlatBufferBuilder(uoffset_t initial_size = 1024, + const simple_allocator *allocator = nullptr) + : buf_(initial_size, allocator ? *allocator : default_allocator), + nested(false), finished(false), minalign_(1), force_defaults_(false), + dedup_vtables_(true), string_pool(nullptr) { + offsetbuf_.reserve(16); // Avoid first few reallocs. + vtables_.reserve(16); + EndianCheck(); + } + + ~FlatBufferBuilder() { + if (string_pool) delete string_pool; + } + + /// @brief Reset all the state in this FlatBufferBuilder so it can be reused + /// to construct another buffer. + void Clear() { + buf_.clear(); + offsetbuf_.clear(); + nested = false; + finished = false; + vtables_.clear(); + minalign_ = 1; + if (string_pool) string_pool->clear(); + } + + /// @brief The current size of the serialized buffer, counting from the end. + /// @return Returns an `uoffset_t` with the current size of the buffer. + uoffset_t GetSize() const { return buf_.size(); } + + /// @brief Get the serialized buffer (after you call `Finish()`). + /// @return Returns an `uint8_t` pointer to the FlatBuffer data inside the + /// buffer. + uint8_t *GetBufferPointer() const { + Finished(); + return buf_.data(); + } + + /// @brief Get a pointer to an unfinished buffer. + /// @return Returns a `uint8_t` pointer to the unfinished buffer. + uint8_t *GetCurrentBufferPointer() const { return buf_.data(); } + + #ifndef FLATBUFFERS_CPP98_STL + /// @brief Get the released pointer to the serialized buffer. 
+ /// @warning Do NOT attempt to use this FlatBufferBuilder afterwards! + /// @return The `unique_ptr` returned has a special allocator that knows how + /// to deallocate this pointer (since it points to the middle of an + /// allocation). Thus, do not mix this pointer with other `unique_ptr`'s, or + /// call `release()`/`reset()` on it. + unique_ptr_t ReleaseBufferPointer() { + Finished(); + return buf_.release(); + } + #endif + + /// @brief get the minimum alignment this buffer needs to be accessed + /// properly. This is only known once all elements have been written (after + /// you call Finish()). You can use this information if you need to embed + /// a FlatBuffer in some other buffer, such that you can later read it + /// without first having to copy it into its own buffer. + size_t GetBufferMinAlignment() { + Finished(); + return minalign_; + } + + /// @cond FLATBUFFERS_INTERNAL + void Finished() const { + // If you get this assert, you're attempting to get access a buffer + // which hasn't been finished yet. Be sure to call + // FlatBufferBuilder::Finish with your root table. + // If you really need to access an unfinished buffer, call + // GetCurrentBufferPointer instead. + assert(finished); + } + /// @endcond + + /// @brief In order to save space, fields that are set to their default value + /// don't get serialized into the buffer. + /// @param[in] bool fd When set to `true`, always serializes default values. + void ForceDefaults(bool fd) { force_defaults_ = fd; } + + /// @brief By default vtables are deduped in order to save space. + /// @param[in] bool dedup When set to `true`, dedup vtables. + void DedupVtables(bool dedup) { dedup_vtables_ = dedup; } + + /// @cond FLATBUFFERS_INTERNAL + void Pad(size_t num_bytes) { buf_.fill(num_bytes); } + + void Align(size_t elem_size) { + if (elem_size > minalign_) minalign_ = elem_size; + buf_.fill(PaddingBytes(buf_.size(), elem_size)); + } + + void PushFlatBuffer(const uint8_t *bytes, size_t size) { + PushBytes(bytes, size); + finished = true; + } + + void PushBytes(const uint8_t *bytes, size_t size) { + buf_.push(bytes, size); + } + + void PopBytes(size_t amount) { buf_.pop(amount); } + + template void AssertScalarT() { + #ifndef FLATBUFFERS_CPP98_STL + // The code assumes power of 2 sizes and endian-swap-ability. + static_assert(std::is_scalar::value + // The Offset type is essentially a scalar but fails is_scalar. + || sizeof(T) == sizeof(Offset), + "T must be a scalar type"); + #endif + } + + // Write a single aligned scalar to the buffer + template uoffset_t PushElement(T element) { + AssertScalarT(); + T litle_endian_element = EndianScalar(element); + Align(sizeof(T)); + buf_.push_small(litle_endian_element); + return GetSize(); + } + + template uoffset_t PushElement(Offset off) { + // Special case for offsets: see ReferTo below. + return PushElement(ReferTo(off.o)); + } + + // When writing fields, we track where they are, so we can create correct + // vtables later. + void TrackField(voffset_t field, uoffset_t off) { + FieldLoc fl = { off, field }; + offsetbuf_.push_back(fl); + } + + // Like PushElement, but additionally tracks the field this represents. + template void AddElement(voffset_t field, T e, T def) { + // We don't serialize values equal to the default. + if (e == def && !force_defaults_) return; + auto off = PushElement(e); + TrackField(field, off); + } + + template void AddOffset(voffset_t field, Offset off) { + if (!off.o) return; // An offset of 0 means NULL, don't store. 
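+    // ReferTo() (below) converts the end-relative offset into one relative
+    // to the spot where it is about to be written.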
+ AddElement(field, ReferTo(off.o), static_cast(0)); + } + + template void AddStruct(voffset_t field, const T *structptr) { + if (!structptr) return; // Default, don't store. + Align(AlignOf()); + buf_.push_small(*structptr); + TrackField(field, GetSize()); + } + + void AddStructOffset(voffset_t field, uoffset_t off) { + TrackField(field, off); + } + + // Offsets initially are relative to the end of the buffer (downwards). + // This function converts them to be relative to the current location + // in the buffer (when stored here), pointing upwards. + uoffset_t ReferTo(uoffset_t off) { + // Align to ensure GetSize() below is correct. + Align(sizeof(uoffset_t)); + // Offset must refer to something already in buffer. + assert(off && off <= GetSize()); + return GetSize() - off + static_cast(sizeof(uoffset_t)); + } + + void NotNested() { + // If you hit this, you're trying to construct a Table/Vector/String + // during the construction of its parent table (between the MyTableBuilder + // and table.Finish(). + // Move the creation of these sub-objects to above the MyTableBuilder to + // not get this assert. + // Ignoring this assert may appear to work in simple cases, but the reason + // it is here is that storing objects in-line may cause vtable offsets + // to not fit anymore. It also leads to vtable duplication. + assert(!nested); + } + + // From generated code (or from the parser), we call StartTable/EndTable + // with a sequence of AddElement calls in between. + uoffset_t StartTable() { + NotNested(); + nested = true; + return GetSize(); + } + + // This finishes one serialized object by generating the vtable if it's a + // table, comparing it against existing vtables, and writing the + // resulting vtable offset. + uoffset_t EndTable(uoffset_t start, voffset_t numfields) { + // If you get this assert, a corresponding StartTable wasn't called. + assert(nested); + // Write the vtable offset, which is the start of any Table. + // We fill it's value later. + auto vtableoffsetloc = PushElement(0); + // Write a vtable, which consists entirely of voffset_t elements. + // It starts with the number of offsets, followed by a type id, followed + // by the offsets themselves. In reverse: + buf_.fill_big(numfields * sizeof(voffset_t)); + auto table_object_size = vtableoffsetloc - start; + assert(table_object_size < 0x10000); // Vtable use 16bit offsets. + PushElement(static_cast(table_object_size)); + PushElement(FieldIndexToOffset(numfields)); + // Write the offsets into the table + for (auto field_location = offsetbuf_.begin(); + field_location != offsetbuf_.end(); + ++field_location) { + auto pos = static_cast(vtableoffsetloc - field_location->off); + // If this asserts, it means you've set a field twice. + assert(!ReadScalar(buf_.data() + field_location->id)); + WriteScalar(buf_.data() + field_location->id, pos); + } + offsetbuf_.clear(); + auto vt1 = reinterpret_cast(buf_.data()); + auto vt1_size = ReadScalar(vt1); + auto vt_use = GetSize(); + // See if we already have generated a vtable with this exact same + // layout before. If so, make it point to the old one, remove this one. + if (dedup_vtables_) { + for (auto it = vtables_.begin(); it != vtables_.end(); ++it) { + auto vt2 = reinterpret_cast(buf_.data_at(*it)); + auto vt2_size = *vt2; + if (vt1_size != vt2_size || memcmp(vt2, vt1, vt1_size)) continue; + vt_use = *it; + buf_.pop(GetSize() - vtableoffsetloc); + break; + } + } + // If this is a new vtable, remember it. 
+ if (vt_use == GetSize()) { + vtables_.push_back(vt_use); + } + // Fill the vtable offset we created above. + // The offset points from the beginning of the object to where the + // vtable is stored. + // Offsets default direction is downward in memory for future format + // flexibility (storing all vtables at the start of the file). + WriteScalar(buf_.data_at(vtableoffsetloc), + static_cast(vt_use) - + static_cast(vtableoffsetloc)); + + nested = false; + return vtableoffsetloc; + } + + // This checks a required field has been set in a given table that has + // just been constructed. + template void Required(Offset table, voffset_t field) { + auto table_ptr = buf_.data_at(table.o); + auto vtable_ptr = table_ptr - ReadScalar(table_ptr); + bool ok = ReadScalar(vtable_ptr + field) != 0; + // If this fails, the caller will show what field needs to be set. + assert(ok); + (void)ok; + } + + uoffset_t StartStruct(size_t alignment) { + Align(alignment); + return GetSize(); + } + + uoffset_t EndStruct() { return GetSize(); } + + void ClearOffsets() { offsetbuf_.clear(); } + + // Aligns such that when "len" bytes are written, an object can be written + // after it with "alignment" without padding. + void PreAlign(size_t len, size_t alignment) { + buf_.fill(PaddingBytes(GetSize() + len, alignment)); + } + template void PreAlign(size_t len) { + AssertScalarT(); + PreAlign(len, sizeof(T)); + } + /// @endcond + + /// @brief Store a string in the buffer, which can contain any binary data. + /// @param[in] str A const char pointer to the data to be stored as a string. + /// @param[in] len The number of bytes that should be stored from `str`. + /// @return Returns the offset in the buffer where the string starts. + Offset CreateString(const char *str, size_t len) { + NotNested(); + PreAlign(len + 1); // Always 0-terminated. + buf_.fill(1); + PushBytes(reinterpret_cast(str), len); + PushElement(static_cast(len)); + return Offset(GetSize()); + } + + /// @brief Store a string in the buffer, which is null-terminated. + /// @param[in] str A const char pointer to a C-string to add to the buffer. + /// @return Returns the offset in the buffer where the string starts. + Offset CreateString(const char *str) { + return CreateString(str, strlen(str)); + } + + /// @brief Store a string in the buffer, which can contain any binary data. + /// @param[in] str A const reference to a std::string to store in the buffer. + /// @return Returns the offset in the buffer where the string starts. + Offset CreateString(const std::string &str) { + return CreateString(str.c_str(), str.length()); + } + + /// @brief Store a string in the buffer, which can contain any binary data. + /// @param[in] str A const pointer to a `String` struct to add to the buffer. + /// @return Returns the offset in the buffer where the string starts + Offset CreateString(const String *str) { + return str ? CreateString(str->c_str(), str->Length()) : 0; + } + + /// @brief Store a string in the buffer, which can contain any binary data. + /// If a string with this exact contents has already been serialized before, + /// instead simply returns the offset of the existing string. + /// @param[in] str A const char pointer to the data to be stored as a string. + /// @param[in] len The number of bytes that should be stored from `str`. + /// @return Returns the offset in the buffer where the string starts. 
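+  /// For example, calling CreateSharedString("id") twice yields equal
+  /// offsets, and the bytes of "id" are stored in the buffer only once.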
+ Offset CreateSharedString(const char *str, size_t len) { + if (!string_pool) + string_pool = new StringOffsetMap(StringOffsetCompare(buf_)); + auto size_before_string = buf_.size(); + // Must first serialize the string, since the set is all offsets into + // buffer. + auto off = CreateString(str, len); + auto it = string_pool->find(off); + // If it exists we reuse existing serialized data! + if (it != string_pool->end()) { + // We can remove the string we serialized. + buf_.pop(buf_.size() - size_before_string); + return *it; + } + // Record this string for future use. + string_pool->insert(off); + return off; + } + + /// @brief Store a string in the buffer, which null-terminated. + /// If a string with this exact contents has already been serialized before, + /// instead simply returns the offset of the existing string. + /// @param[in] str A const char pointer to a C-string to add to the buffer. + /// @return Returns the offset in the buffer where the string starts. + Offset CreateSharedString(const char *str) { + return CreateSharedString(str, strlen(str)); + } + + /// @brief Store a string in the buffer, which can contain any binary data. + /// If a string with this exact contents has already been serialized before, + /// instead simply returns the offset of the existing string. + /// @param[in] str A const reference to a std::string to store in the buffer. + /// @return Returns the offset in the buffer where the string starts. + Offset CreateSharedString(const std::string &str) { + return CreateSharedString(str.c_str(), str.length()); + } + + /// @brief Store a string in the buffer, which can contain any binary data. + /// If a string with this exact contents has already been serialized before, + /// instead simply returns the offset of the existing string. + /// @param[in] str A const pointer to a `String` struct to add to the buffer. + /// @return Returns the offset in the buffer where the string starts + Offset CreateSharedString(const String *str) { + return CreateSharedString(str->c_str(), str->Length()); + } + + /// @cond FLATBUFFERS_INTERNAL + uoffset_t EndVector(size_t len) { + assert(nested); // Hit if no corresponding StartVector. + nested = false; + return PushElement(static_cast(len)); + } + + void StartVector(size_t len, size_t elemsize) { + NotNested(); + nested = true; + PreAlign(len * elemsize); + PreAlign(len * elemsize, elemsize); // Just in case elemsize > uoffset_t. + } + + // Call this right before StartVector/CreateVector if you want to force the + // alignment to be something different than what the element size would + // normally dictate. + // This is useful when storing a nested_flatbuffer in a vector of bytes, + // or when storing SIMD floats, etc. + void ForceVectorAlignment(size_t len, size_t elemsize, size_t alignment) { + PreAlign(len * elemsize, alignment); + } + + uint8_t *ReserveElements(size_t len, size_t elemsize) { + return buf_.make_space(len * elemsize); + } + /// @endcond + + /// @brief Serialize an array into a FlatBuffer `vector`. + /// @tparam T The data type of the array elements. + /// @param[in] v A pointer to the array of type `T` to serialize into the + /// buffer as a `vector`. + /// @param[in] len The number of elements to serialize. + /// @return Returns a typed `Offset` into the serialized data indicating + /// where the vector is stored. 
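+  /// A short sketch (fbb is an illustrative FlatBufferBuilder instance):
+  ///   int16_t nums[] = {1, 2, 3};
+  ///   auto off = fbb.CreateVector(nums, 3);  // Offset<Vector<int16_t>>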
+ template Offset> CreateVector(const T *v, size_t len) { + StartVector(len, sizeof(T)); + for (auto i = len; i > 0; ) { + PushElement(v[--i]); + } + return Offset>(EndVector(len)); + } + + /// @brief Serialize a `std::vector` into a FlatBuffer `vector`. + /// @tparam T The data type of the `std::vector` elements. + /// @param v A const reference to the `std::vector` to serialize into the + /// buffer as a `vector`. + /// @return Returns a typed `Offset` into the serialized data indicating + /// where the vector is stored. + template Offset> CreateVector(const std::vector &v) { + return CreateVector(data(v), v.size()); + } + + // vector may be implemented using a bit-set, so we can't access it as + // an array. Instead, read elements manually. + // Background: https://isocpp.org/blog/2012/11/on-vectorbool + Offset> CreateVector(const std::vector &v) { + StartVector(v.size(), sizeof(uint8_t)); + for (auto i = v.size(); i > 0; ) { + PushElement(static_cast(v[--i])); + } + return Offset>(EndVector(v.size())); + } + + #ifndef FLATBUFFERS_CPP98_STL + /// @brief Serialize values returned by a function into a FlatBuffer `vector`. + /// This is a convenience function that takes care of iteration for you. + /// @tparam T The data type of the `std::vector` elements. + /// @param f A function that takes the current iteration 0..vector_size-1 and + /// returns any type that you can construct a FlatBuffers vector out of. + /// @return Returns a typed `Offset` into the serialized data indicating + /// where the vector is stored. + template Offset> CreateVector(size_t vector_size, + const std::function &f) { + std::vector elems(vector_size); + for (size_t i = 0; i < vector_size; i++) elems[i] = f(i); + return CreateVector(elems); + } + #endif + + /// @brief Serialize a `std::vector` into a FlatBuffer `vector`. + /// This is a convenience function for a common case. + /// @param v A const reference to the `std::vector` to serialize into the + /// buffer as a `vector`. + /// @return Returns a typed `Offset` into the serialized data indicating + /// where the vector is stored. + Offset>> CreateVectorOfStrings( + const std::vector &v) { + std::vector> offsets(v.size()); + for (size_t i = 0; i < v.size(); i++) offsets[i] = CreateString(v[i]); + return CreateVector(offsets); + } + + /// @brief Serialize an array of structs into a FlatBuffer `vector`. + /// @tparam T The data type of the struct array elements. + /// @param[in] v A pointer to the array of type `T` to serialize into the + /// buffer as a `vector`. + /// @param[in] len The number of elements to serialize. + /// @return Returns a typed `Offset` into the serialized data indicating + /// where the vector is stored. + template Offset> CreateVectorOfStructs( + const T *v, size_t len) { + StartVector(len * sizeof(T) / AlignOf(), AlignOf()); + PushBytes(reinterpret_cast(v), sizeof(T) * len); + return Offset>(EndVector(len)); + } + + #ifndef FLATBUFFERS_CPP98_STL + /// @brief Serialize an array of structs into a FlatBuffer `vector`. + /// @tparam T The data type of the struct array elements. + /// @param[in] f A function that takes the current iteration 0..vector_size-1 + /// and a pointer to the struct that must be filled. + /// @return Returns a typed `Offset` into the serialized data indicating + /// where the vector is stored. + /// This is mostly useful when flatbuffers are generated with mutation + /// accessors. 
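+  /// A sketch (Vec3 is an illustrative generated struct, fbb a builder):
+  ///   fbb.CreateVectorOfStructs<Vec3>(10, [](size_t i, Vec3 *v) {
+  ///     *v = Vec3(float(i), 0.f, 0.f);
+  ///   });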
+ template Offset> CreateVectorOfStructs( + size_t vector_size, const std::function &filler) { + StartVector(vector_size * sizeof(T) / AlignOf(), AlignOf()); + T *structs = reinterpret_cast(buf_.make_space(vector_size * sizeof(T))); + for (size_t i = 0; i < vector_size; i++) { + filler(i, structs); + structs++; + } + return Offset>(EndVector(vector_size)); + } + #endif + + /// @brief Serialize a `std::vector` of structs into a FlatBuffer `vector`. + /// @tparam T The data type of the `std::vector` struct elements. + /// @param[in]] v A const reference to the `std::vector` of structs to + /// serialize into the buffer as a `vector`. + /// @return Returns a typed `Offset` into the serialized data indicating + /// where the vector is stored. + template Offset> CreateVectorOfStructs( + const std::vector &v) { + return CreateVectorOfStructs(data(v), v.size()); + } + + /// @cond FLATBUFFERS_INTERNAL + template + struct TableKeyComparator { + TableKeyComparator(vector_downward& buf) : buf_(buf) {} + bool operator()(const Offset &a, const Offset &b) const { + auto table_a = reinterpret_cast(buf_.data_at(a.o)); + auto table_b = reinterpret_cast(buf_.data_at(b.o)); + return table_a->KeyCompareLessThan(table_b); + } + vector_downward& buf_; + + private: + TableKeyComparator& operator= (const TableKeyComparator&); + }; + /// @endcond + + /// @brief Serialize an array of `table` offsets as a `vector` in the buffer + /// in sorted order. + /// @tparam T The data type that the offset refers to. + /// @param[in] v An array of type `Offset` that contains the `table` + /// offsets to store in the buffer in sorted order. + /// @param[in] len The number of elements to store in the `vector`. + /// @return Returns a typed `Offset` into the serialized data indicating + /// where the vector is stored. + template Offset>> CreateVectorOfSortedTables( + Offset *v, size_t len) { + std::sort(v, v + len, TableKeyComparator(buf_)); + return CreateVector(v, len); + } + + /// @brief Serialize an array of `table` offsets as a `vector` in the buffer + /// in sorted order. + /// @tparam T The data type that the offset refers to. + /// @param[in] v An array of type `Offset` that contains the `table` + /// offsets to store in the buffer in sorted order. + /// @return Returns a typed `Offset` into the serialized data indicating + /// where the vector is stored. + template Offset>> CreateVectorOfSortedTables( + std::vector> *v) { + return CreateVectorOfSortedTables(data(*v), v->size()); + } + + /// @brief Specialized version of `CreateVector` for non-copying use cases. + /// Write the data any time later to the returned buffer pointer `buf`. + /// @param[in] len The number of elements to store in the `vector`. + /// @param[in] elemsize The size of each element in the `vector`. + /// @param[out] buf A pointer to a `uint8_t` pointer that can be + /// written to at a later time to serialize the data into a `vector` + /// in the buffer. + uoffset_t CreateUninitializedVector(size_t len, size_t elemsize, + uint8_t **buf) { + NotNested(); + StartVector(len, elemsize); + buf_.make_space(len * elemsize); + auto vec_start = GetSize(); + auto vec_end = EndVector(len); + *buf = buf_.data_at(vec_start); + return vec_end; + } + + /// @brief Specialized version of `CreateVector` for non-copying use cases. + /// Write the data any time later to the returned buffer pointer `buf`. + /// @tparam T The data type of the data that will be stored in the buffer + /// as a `vector`. + /// @param[in] len The number of elements to store in the `vector`. 
+ /// @param[out] buf A pointer to a pointer of type `T` that can be + /// written to at a later time to serialize the data into a `vector` + /// in the buffer. + template Offset> CreateUninitializedVector( + size_t len, T **buf) { + return CreateUninitializedVector(len, sizeof(T), + reinterpret_cast(buf)); + } + + /// @brief The length of a FlatBuffer file header. + static const size_t kFileIdentifierLength = 4; + + /// @brief Finish serializing a buffer by writing the root offset. + /// @param[in] file_identifier If a `file_identifier` is given, the buffer + /// will be prefixed with a standard FlatBuffers file header. + template void Finish(Offset root, + const char *file_identifier = nullptr) { + + Finish(root.o, file_identifier, false); + } + + /// @brief Finish a buffer with a 32 bit size field pre-fixed (size of the + /// buffer following the size field). These buffers are NOT compatible + /// with standard buffers created by Finish, i.e. you can't call GetRoot + /// on them, you have to use GetSizePrefixedRoot instead. + /// All >32 bit quantities in this buffer will be aligned when the whole + /// size pre-fixed buffer is aligned. + /// These kinds of buffers are useful for creating a stream of FlatBuffers. + template void FinishSizePrefixed(Offset root, + const char *file_identifier = nullptr) { + Finish(root.o, file_identifier, true); + } + + private: + // You shouldn't really be copying instances of this class. + FlatBufferBuilder(const FlatBufferBuilder &); + FlatBufferBuilder &operator=(const FlatBufferBuilder &); + + void Finish(uoffset_t root, const char *file_identifier, bool size_prefix) { + NotNested(); + // This will cause the whole buffer to be aligned. + PreAlign((size_prefix ? sizeof(uoffset_t) : 0) + + sizeof(uoffset_t) + + (file_identifier ? kFileIdentifierLength : 0), + minalign_); + if (file_identifier) { + assert(strlen(file_identifier) == kFileIdentifierLength); + PushBytes(reinterpret_cast(file_identifier), + kFileIdentifierLength); + } + PushElement(ReferTo(root)); // Location of root. + if (size_prefix) { + PushElement(GetSize()); + } + finished = true; + } + + struct FieldLoc { + uoffset_t off; + voffset_t id; + }; + + simple_allocator default_allocator; + + vector_downward buf_; + + // Accumulating offsets of table members while it is being built. + std::vector offsetbuf_; + + // Ensure objects are not nested. + bool nested; + + // Ensure the buffer is finished before it is being accessed. + bool finished; + + std::vector vtables_; // todo: Could make this into a map? + + size_t minalign_; + + bool force_defaults_; // Serialize values equal to their defaults anyway. + + bool dedup_vtables_; + + struct StringOffsetCompare { + StringOffsetCompare(const vector_downward &buf) : buf_(&buf) {} + bool operator() (const Offset &a, const Offset &b) const { + auto stra = reinterpret_cast(buf_->data_at(a.o)); + auto strb = reinterpret_cast(buf_->data_at(b.o)); + return strncmp(stra->c_str(), strb->c_str(), + std::min(stra->size(), strb->size()) + 1) < 0; + } + const vector_downward *buf_; + }; + + // For use with CreateSharedString. Instantiated on first use only. + typedef std::set, StringOffsetCompare> StringOffsetMap; + StringOffsetMap *string_pool; +}; +/// @} + +/// @cond FLATBUFFERS_INTERNAL +// Helpers to get a typed pointer to the root object contained in the buffer. 
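+// e.g. auto monster = GetRoot<Monster>(buf); (Monster is an illustrative
+// generated type, not defined in this header.)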
+
+/// @cond FLATBUFFERS_INTERNAL
+// Helpers to get a typed pointer to the root object contained in the buffer.
+template<typename T> T *GetMutableRoot(void *buf) {
+  EndianCheck();
+  return reinterpret_cast<T *>(reinterpret_cast<uint8_t *>(buf) +
+    EndianScalar(*reinterpret_cast<uoffset_t *>(buf)));
+}
+
+template<typename T> const T *GetRoot(const void *buf) {
+  return GetMutableRoot<T>(const_cast<void *>(buf));
+}
+
+template<typename T> const T *GetSizePrefixedRoot(const void *buf) {
+  return GetRoot<T>(reinterpret_cast<const uint8_t *>(buf) + sizeof(uoffset_t));
+}
+
+/// Helpers to get a typed pointer to objects that are currently being built.
+/// @warning Creating new objects will lead to reallocations and invalidates
+/// the pointer!
+template<typename T> T *GetMutableTemporaryPointer(FlatBufferBuilder &fbb,
+                                                   Offset<T> offset) {
+  return reinterpret_cast<T *>(fbb.GetCurrentBufferPointer() +
+    fbb.GetSize() - offset.o);
+}
+
+template<typename T> const T *GetTemporaryPointer(FlatBufferBuilder &fbb,
+                                                  Offset<T> offset) {
+  return GetMutableTemporaryPointer<T>(fbb, offset);
+}
+
+// Helper to see if the identifier in a buffer has the expected value.
+inline bool BufferHasIdentifier(const void *buf, const char *identifier) {
+  return strncmp(reinterpret_cast<const char *>(buf) + sizeof(uoffset_t),
+                 identifier, FlatBufferBuilder::kFileIdentifierLength) == 0;
+}
+
+// Helper class to verify the integrity of a FlatBuffer
+class Verifier FLATBUFFERS_FINAL_CLASS {
+ public:
+  Verifier(const uint8_t *buf, size_t buf_len, size_t _max_depth = 64,
+           size_t _max_tables = 1000000)
+    : buf_(buf), end_(buf + buf_len), depth_(0), max_depth_(_max_depth),
+      num_tables_(0), max_tables_(_max_tables)
+    #ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+        , upper_bound_(buf)
+    #endif
+    {}
+
+  // Central location where any verification failures register.
+  bool Check(bool ok) const {
+    #ifdef FLATBUFFERS_DEBUG_VERIFICATION_FAILURE
+      assert(ok);
+    #endif
+    #ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+      if (!ok)
+        upper_bound_ = buf_;
+    #endif
+    return ok;
+  }
+
+  // Verify any range within the buffer.
+  bool Verify(const void *elem, size_t elem_len) const {
+    #ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+      auto upper_bound = reinterpret_cast<const uint8_t *>(elem) + elem_len;
+      if (upper_bound_ < upper_bound)
+        upper_bound_ = upper_bound;
+    #endif
+    return Check(elem_len <= (size_t) (end_ - buf_) &&
+                 elem >= buf_ &&
+                 elem <= end_ - elem_len);
+  }
+
+  // Verify a range indicated by sizeof(T).
+  template<typename T> bool Verify(const void *elem) const {
+    return Verify(elem, sizeof(T));
+  }
+
+  // Verify a pointer (may be NULL) of a table type.
+  template<typename T> bool VerifyTable(const T *table) {
+    return !table || table->Verify(*this);
+  }
+
+  // Verify a pointer (may be NULL) of any vector type.
+  template<typename T> bool Verify(const Vector<T> *vec) const {
+    const uint8_t *end;
+    return !vec ||
+           VerifyVector(reinterpret_cast<const uint8_t *>(vec), sizeof(T),
+                        &end);
+  }
+
+  // Verify a pointer (may be NULL) of a vector to struct.
+  template<typename T> bool Verify(const Vector<const T *> *vec) const {
+    return Verify(reinterpret_cast<const Vector<T> *>(vec));
+  }
+
+  // Verify a pointer (may be NULL) to string.
+  bool Verify(const String *str) const {
+    const uint8_t *end;
+    return !str ||
+           (VerifyVector(reinterpret_cast<const uint8_t *>(str), 1, &end) &&
+            Verify(end, 1) &&      // Must have terminator
+            Check(*end == '\0'));  // Terminating byte must be 0.
+  }
+
+  // Common code between vectors and strings.
+  bool VerifyVector(const uint8_t *vec, size_t elem_size,
+                    const uint8_t **end) const {
+    // Check we can read the size field.
+    if (!Verify<uoffset_t>(vec)) return false;
+    // Check the whole array. If this is a string, the byte past the array
+    // must be 0.
+    auto size = ReadScalar<uoffset_t>(vec);
+    auto max_elems = FLATBUFFERS_MAX_BUFFER_SIZE / elem_size;
+    if (!Check(size < max_elems))
+      return false;  // Protect against byte_size overflowing.
+    auto byte_size = sizeof(size) + elem_size * size;
+    *end = vec + byte_size;
+    return Verify(vec, byte_size);
+  }
+
+  // Special case for string contents, after the above has been called.
+  bool VerifyVectorOfStrings(const Vector<Offset<String>> *vec) const {
+    if (vec) {
+      for (uoffset_t i = 0; i < vec->size(); i++) {
+        if (!Verify(vec->Get(i))) return false;
+      }
+    }
+    return true;
+  }
+
+  // Special case for table contents, after the above has been called.
+  template<typename T> bool VerifyVectorOfTables(const Vector<Offset<T>> *vec) {
+    if (vec) {
+      for (uoffset_t i = 0; i < vec->size(); i++) {
+        if (!vec->Get(i)->Verify(*this)) return false;
+      }
+    }
+    return true;
+  }
+
+  template<typename T> bool VerifyBufferFromStart(const char *identifier,
+                                                  const uint8_t *start) {
+    if (identifier &&
+        (size_t(end_ - start) < 2 * sizeof(flatbuffers::uoffset_t) ||
+         !BufferHasIdentifier(start, identifier))) {
+      return false;
+    }
+
+    // Call T::Verify, which must be in the generated code for this type.
+    return Verify<uoffset_t>(start) &&
+      reinterpret_cast<const T *>(start + ReadScalar<uoffset_t>(start))->
+        Verify(*this)
+      #ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+          && GetComputedSize()
+      #endif
+          ;
+  }
+
+  // Verify this whole buffer, starting with root type T.
+  template<typename T> bool VerifyBuffer(const char *identifier) {
+    return VerifyBufferFromStart<T>(identifier, buf_);
+  }
+
+  template<typename T> bool VerifySizePrefixedBuffer(const char *identifier) {
+    return Verify<uoffset_t>(buf_) &&
+           ReadScalar<uoffset_t>(buf_) == end_ - buf_ - sizeof(uoffset_t) &&
+           VerifyBufferFromStart<T>(identifier, buf_ + sizeof(uoffset_t));
+  }
+
+  // Called at the start of a table to increase counters measuring data
+  // structure depth and amount, and possibly bails out with false if
+  // limits set by the constructor have been hit. Needs to be balanced
+  // with EndTable().
+  bool VerifyComplexity() {
+    depth_++;
+    num_tables_++;
+    return Check(depth_ <= max_depth_ && num_tables_ <= max_tables_);
+  }
+
+  // Called at the end of a table to pop the depth count.
+  bool EndTable() {
+    depth_--;
+    return true;
+  }
+
+  #ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+  // Returns the message size in bytes.
+  size_t GetComputedSize() const {
+    uintptr_t size = upper_bound_ - buf_;
+    // Align the size to uoffset_t.
+    size = (size - 1 + sizeof(uoffset_t)) & ~(sizeof(uoffset_t) - 1);
+    return (buf_ + size > end_) ? 0 : size;
+  }
+  #endif
+
+ private:
+  const uint8_t *buf_;
+  const uint8_t *end_;
+  size_t depth_;
+  size_t max_depth_;
+  size_t num_tables_;
+  size_t max_tables_;
+#ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+  mutable const uint8_t *upper_bound_;
+#endif
+};
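A minimal sketch (ours, for illustration) of the intended flow: run the Verifier over an untrusted buffer before trusting GetRoot. `MyTable` stands for any generated root type whose Verify method flatc emits.

    // Verify an untrusted buffer before calling GetRoot on it.
    template<typename MyTable>
    bool SafeToRead(const uint8_t *data, size_t len) {
      flatbuffers::Verifier verifier(data, len);
      // nullptr skips the optional file_identifier check.
      return verifier.VerifyBuffer<MyTable>(nullptr);
    }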
+
+// Convenient way to bundle a buffer and its length, to pass it around
+// typed by its root.
+// A BufferRef does not own its buffer.
+struct BufferRefBase {};  // for std::is_base_of
+template<typename T> struct BufferRef : BufferRefBase {
+  BufferRef() : buf(nullptr), len(0), must_free(false) {}
+  BufferRef(uint8_t *_buf, uoffset_t _len)
+    : buf(_buf), len(_len), must_free(false) {}
+
+  ~BufferRef() { if (must_free) free(buf); }
+
+  const T *GetRoot() const { return flatbuffers::GetRoot<T>(buf); }
+
+  bool Verify() {
+    Verifier verifier(buf, len);
+    return verifier.VerifyBuffer<T>(nullptr);
+  }
+
+  uint8_t *buf;
+  uoffset_t len;
+  bool must_free;
+};
+
+// "structs" are flat structures that do not have an offset table, thus
+// always have all members present and do not support forwards/backwards
+// compatible extensions.
+
+class Struct FLATBUFFERS_FINAL_CLASS {
+ public:
+  template<typename T> T GetField(uoffset_t o) const {
+    return ReadScalar<T>(&data_[o]);
+  }
+
+  template<typename T> T GetStruct(uoffset_t o) const {
+    return reinterpret_cast<T>(&data_[o]);
+  }
+
+  const uint8_t *GetAddressOf(uoffset_t o) const { return &data_[o]; }
+  uint8_t *GetAddressOf(uoffset_t o) { return &data_[o]; }
+
+ private:
+  uint8_t data_[1];
+};
+
+// "tables" use an offset table (possibly shared) that allows fields to be
+// omitted and added at will, but uses an extra indirection to read.
+class Table {
+ public:
+  const uint8_t *GetVTable() const {
+    return data_ - ReadScalar<soffset_t>(data_);
+  }
+
+  // This gets the field offset for any of the functions below it, or 0
+  // if the field was not present.
+  voffset_t GetOptionalFieldOffset(voffset_t field) const {
+    // The vtable offset is always at the start.
+    auto vtable = GetVTable();
+    // The first element is the size of the vtable (fields + type id + itself).
+    auto vtsize = ReadScalar<voffset_t>(vtable);
+    // If the field we're accessing is outside the vtable, we're reading older
+    // data, so it's the same as if the offset was 0 (not present).
+    return field < vtsize ? ReadScalar<voffset_t>(vtable + field) : 0;
+  }
+
+  template<typename T> T GetField(voffset_t field, T defaultval) const {
+    auto field_offset = GetOptionalFieldOffset(field);
+    return field_offset ? ReadScalar<T>(data_ + field_offset) : defaultval;
+  }
+
+  template<typename P> P GetPointer(voffset_t field) {
+    auto field_offset = GetOptionalFieldOffset(field);
+    auto p = data_ + field_offset;
+    return field_offset
+      ? reinterpret_cast<P>(p + ReadScalar<uoffset_t>(p))
+      : nullptr;
+  }
+  template<typename P> P GetPointer(voffset_t field) const {
+    return const_cast<Table *>(this)->GetPointer<P>(field);
+  }
+
+  template<typename P> P GetStruct(voffset_t field) const {
+    auto field_offset = GetOptionalFieldOffset(field);
+    auto p = const_cast<uint8_t *>(data_ + field_offset);
+    return field_offset ? reinterpret_cast<P>(p) : nullptr;
+  }
+
+  template<typename T> bool SetField(voffset_t field, T val) {
+    auto field_offset = GetOptionalFieldOffset(field);
+    if (!field_offset) return false;
+    WriteScalar(data_ + field_offset, val);
+    return true;
+  }
+
+  bool SetPointer(voffset_t field, const uint8_t *val) {
+    auto field_offset = GetOptionalFieldOffset(field);
+    if (!field_offset) return false;
+    WriteScalar(data_ + field_offset,
+                static_cast<uoffset_t>(val - (data_ + field_offset)));
+    return true;
+  }
+
+  uint8_t *GetAddressOf(voffset_t field) {
+    auto field_offset = GetOptionalFieldOffset(field);
+    return field_offset ? data_ + field_offset : nullptr;
+  }
+  const uint8_t *GetAddressOf(voffset_t field) const {
+    return const_cast<Table *>(this)->GetAddressOf(field);
+  }
+
+  bool CheckField(voffset_t field) const {
+    return GetOptionalFieldOffset(field) != 0;
+  }
+
+  // Verify the vtable of this table.
+  // Call this once per table, followed by VerifyField once per field.
+  bool VerifyTableStart(Verifier &verifier) const {
+    // Check the vtable offset.
+    if (!verifier.Verify<soffset_t>(data_)) return false;
+    auto vtable = GetVTable();
+    // Check the vtable size field, then check vtable fits in its entirety.
+    return verifier.VerifyComplexity() &&
+           verifier.Verify<voffset_t>(vtable) &&
+           (ReadScalar<voffset_t>(vtable) & (sizeof(voffset_t) - 1)) == 0 &&
+           verifier.Verify(vtable, ReadScalar<voffset_t>(vtable));
+  }
+
+  // Verify a particular field.
+  template<typename T> bool VerifyField(const Verifier &verifier,
+                                        voffset_t field) const {
+    // Calling GetOptionalFieldOffset should be safe now thanks to
+    // VerifyTable().
+    auto field_offset = GetOptionalFieldOffset(field);
+    // Check the actual field.
+    return !field_offset || verifier.Verify<T>(data_ + field_offset);
+  }
+
+  // VerifyField for required fields.
+  template<typename T> bool VerifyFieldRequired(const Verifier &verifier,
+                                                voffset_t field) const {
+    auto field_offset = GetOptionalFieldOffset(field);
+    return verifier.Check(field_offset != 0) &&
+           verifier.Verify<T>(data_ + field_offset);
+  }
+
+ private:
+  // private constructor & copy constructor: you obtain instances of this
+  // class by pointing to existing data only
+  Table();
+  Table(const Table &other);
+
+  uint8_t data_[1];
+};
+
+/// @brief This can compute the start of a FlatBuffer from a root pointer, i.e.
+/// it is the opposite transformation of GetRoot().
+/// This may be useful if you want to pass on a root and have the recipient
+/// delete the buffer afterwards.
+inline const uint8_t *GetBufferStartFromRootPointer(const void *root) {
+  auto table = reinterpret_cast<const Table *>(root);
+  auto vtable = table->GetVTable();
+  // Either the vtable is before the root or after the root.
+  auto start = std::min(vtable, reinterpret_cast<const uint8_t *>(root));
+  // Align to at least sizeof(uoffset_t).
+  start = reinterpret_cast<const uint8_t *>(
+    reinterpret_cast<uintptr_t>(start) & ~(sizeof(uoffset_t) - 1));
+  // Additionally, there may be a file_identifier in the buffer, and the root
+  // offset. The buffer may have been aligned to any size between
+  // sizeof(uoffset_t) and FLATBUFFERS_MAX_ALIGNMENT (see "force_align").
+  // Sadly, the exact alignment is only known when constructing the buffer,
+  // since it depends on the presence of values with said alignment properties.
+  // So instead, we simply look at the next uoffset_t values (root,
+  // file_identifier, and alignment padding) to see which points to the root.
+  // None of the other values can "impersonate" the root since they will either
+  // be 0 or four ASCII characters.
+  static_assert(FlatBufferBuilder::kFileIdentifierLength == sizeof(uoffset_t),
+                "file_identifier is assumed to be the same size as uoffset_t");
+  for (auto possible_roots = FLATBUFFERS_MAX_ALIGNMENT / sizeof(uoffset_t) + 1;
+       possible_roots;
+       possible_roots--) {
+    start -= sizeof(uoffset_t);
+    if (ReadScalar<uoffset_t>(start) + start ==
+        reinterpret_cast<const uint8_t *>(root)) return start;
+  }
+  // We didn't find the root, either the "root" passed isn't really a root,
+  // or the buffer is corrupt.
+  // Assert, because calling this function with bad data may cause reads
+  // outside of buffer boundaries.
+  assert(false);
+  return nullptr;
+}
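A short sketch of why this helper exists; the assumption that the sender allocated the buffer with malloc is ours, for illustration only.

    // Recover the start of the enclosing buffer from a root pointer so a
    // recipient that was only handed the root can release the allocation.
    void FreeBufferFromRoot(const void *root) {
      const uint8_t *start = flatbuffers::GetBufferStartFromRootPointer(root);
      free(const_cast<uint8_t *>(start));  // assumes malloc-allocated buffer
    }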
+
+// Base class for native objects (FlatBuffer data de-serialized into native
+// C++ data structures).
+// Contains no functionality, purely documentative.
+struct NativeTable {
+};
+
+/// @brief Function types to be used with resolving hashes into objects and
+/// back again. The resolver gets a pointer to a field inside an object API
+/// object that is of the type specified in the schema using the attribute
+/// `cpp_type` (it is thus important whatever you write to this address
+/// matches that type). The value of this field is initially null, so you
+/// may choose to implement a delayed binding lookup using this function
+/// if you wish. The rehasher does the opposite lookup, for when the object
+/// is being serialized again.
+typedef uint64_t hash_value_t;
+#ifdef FLATBUFFERS_CPP98_STL
+  typedef void (*resolver_function_t)(void **pointer_adr, hash_value_t hash);
+  typedef hash_value_t (*rehasher_function_t)(void *pointer);
+#else
+  typedef std::function<void (void **pointer_adr, hash_value_t hash)>
+          resolver_function_t;
+  typedef std::function<hash_value_t (void *pointer)> rehasher_function_t;
+#endif
+
+// Helper function to test if a field is present, using any of the field
+// enums in the generated code.
+// `table` must be a generated table type. Since this is a template parameter,
+// this is not typechecked to be a subclass of Table, so beware!
+// Note: this function will return false for fields equal to the default
+// value, since they're not stored in the buffer (unless force_defaults was
+// used).
+template<typename T> bool IsFieldPresent(const T *table, voffset_t field) {
+  // Cast, since Table is a private baseclass of any table types.
+  return reinterpret_cast<const Table *>(table)->CheckField(field);
+}
+
+// Utility function for reverse lookups on the EnumNames*() functions
+// (in the generated C++ code)
+// names must be NULL terminated.
+inline int LookupEnum(const char **names, const char *name) {
+  for (const char **p = names; *p; p++)
+    if (!strcmp(*p, name))
+      return static_cast<int>(p - names);
+  return -1;
+}
+
+// These macros allow us to layout a struct with a guarantee that they'll end
+// up looking the same on different compilers and platforms.
+// It does this by disallowing the compiler to do any padding, and then
+// does padding itself by inserting extra padding fields that make every
+// element aligned to its own size.
+// Additionally, it manually sets the alignment of the struct as a whole,
+// which is typically its largest element, or a custom size set in the schema
+// by the force_align attribute.
+// These are used in the generated code only.
+
+#if defined(_MSC_VER)
+  #define MANUALLY_ALIGNED_STRUCT(alignment) \
+    __pragma(pack(1)); \
+    struct __declspec(align(alignment))
+  #define STRUCT_END(name, size) \
+    __pragma(pack()); \
+    static_assert(sizeof(name) == size, "compiler breaks packing rules")
+#elif defined(__GNUC__) || defined(__clang__)
+  #define MANUALLY_ALIGNED_STRUCT(alignment) \
+    _Pragma("pack(1)") \
+    struct __attribute__((aligned(alignment)))
+  #define STRUCT_END(name, size) \
+    _Pragma("pack()") \
+    static_assert(sizeof(name) == size, "compiler breaks packing rules")
+#else
+  #error Unknown compiler, please define structure alignment macros
+#endif
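To see the macros in use, here is a hypothetical expansion mirroring what flatc emits for a struct of two doubles; `Vec2` is illustrative and not part of this change.

    MANUALLY_ALIGNED_STRUCT(8) Vec2 {
      double x_;
      double y_;
    };
    STRUCT_END(Vec2, 16);  // static_asserts that packing produced 16 bytes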
+
+// String which identifies the current version of FlatBuffers.
+// flatbuffer_version_string is used by Google developers to identify which
+// applications uploaded to Google Play are using this library. This allows
+// the development team at Google to determine the popularity of the library.
+// How it works: Applications that are uploaded to the Google Play Store are
+// scanned for this version string. We track which applications are using it
+// to measure popularity. You are free to remove it (of course) but we would
+// appreciate if you left it in.
+
+// Weak linkage is culled by VS & doesn't work on cygwin.
+#if !defined(_WIN32) && !defined(__CYGWIN__)
+
+extern volatile __attribute__((weak)) const char *flatbuffer_version_string;
+volatile __attribute__((weak)) const char *flatbuffer_version_string =
+  "FlatBuffers "
+  FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MAJOR) "."
+  FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MINOR) "."
+  FLATBUFFERS_STRING(FLATBUFFERS_VERSION_REVISION);
+
+#endif // !defined(_WIN32) && !defined(__CYGWIN__)
+
+#define DEFINE_BITMASK_OPERATORS(E, T)\
+    inline E operator | (E lhs, E rhs){\
+        return E(T(lhs) | T(rhs));\
+    }\
+    inline E operator & (E lhs, E rhs){\
+        return E(T(lhs) & T(rhs));\
+    }\
+    inline E operator ^ (E lhs, E rhs){\
+        return E(T(lhs) ^ T(rhs));\
+    }\
+    inline E operator ~ (E lhs){\
+        return E(~T(lhs));\
+    }\
+    inline E operator |= (E &lhs, E rhs){\
+        lhs = lhs | rhs;\
+        return lhs;\
+    }\
+    inline E operator &= (E &lhs, E rhs){\
+        lhs = lhs & rhs;\
+        return lhs;\
+    }\
+    inline E operator ^= (E &lhs, E rhs){\
+        lhs = lhs ^ rhs;\
+        return lhs;\
+    }\
+    inline bool operator !(E rhs) \
+    {\
+        return !bool(T(rhs)); \
+    }
+/// @endcond
+}  // namespace flatbuffers
+
+#endif  // FLATBUFFERS_H_
\ No newline at end of file
diff --git a/horovod/tensorflow/wire/mpi_message.fbs b/horovod/tensorflow/wire/mpi_message.fbs
new file mode 100644
index 0000000000..464f37e9ab
--- /dev/null
+++ b/horovod/tensorflow/wire/mpi_message.fbs
@@ -0,0 +1,89 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+// Modifications copyright (C) 2017 Uber Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+namespace horovod.tensorflow.wire;
+
+// Supported data types.
+enum MPIDataType:byte {
+    TF_MPI_UINT8 = 0,
+    TF_MPI_INT8 = 1,
+    TF_MPI_UINT16 = 2,
+    TF_MPI_INT16 = 3,
+    TF_MPI_INT32 = 4,
+    TF_MPI_INT64 = 5,
+    TF_MPI_FLOAT32 = 6,
+    TF_MPI_FLOAT64 = 7
+}
+
+// An MPIRequest is a message sent from a rank greater than zero to the
+// coordinator (rank zero), informing the coordinator of an operation that
+// the rank wants to do and the tensor that it wants to apply the operation to.
+enum MPIRequestType:byte {
+    ALLREDUCE = 0,
+    ALLGATHER = 1,
+    BROADCAST = 2
+}
+table MPIRequest {
+    // The request rank is necessary to create a consistent ordering of results,
+    // for example in the allgather where the order of outputs should be sorted
+    // by rank.
+    request_rank:int;
+    request_type:MPIRequestType;
+    tensor_type:MPIDataType;
+    tensor_name:string;
+
+    // Root rank is necessary for broadcast operation.
+    root_rank:int;
+
+    // Device this request is made on.
+    device:int;
+
+    // We use a repeated integer instead of a TensorShapeProto because linking
+    // directly to TensorFlow protos causes issues. See the comment for
+    // MPIDataType.
+    tensor_shape:[long];
+}
+
+// An MPIResponse is a message sent from the coordinator (rank zero) to a rank
+// greater than zero, informing the rank that an operation should be performed
+// now. If the operation requested would result in an error (for example, due
+// to a type or shape mismatch), then the MPIResponse can contain an error and
+// an error message instead. Finally, an MPIResponse can be a DONE message (if
+// there are no more tensors to reduce on this tick of the background loop) or
+// SHUTDOWN if all MPI processes should shut down.
+enum MPIResponseType:byte {
+    ALLREDUCE = 0,
+    ALLGATHER = 1,
+    BROADCAST = 2,
+    ERROR = 3,
+    DONE = 4,
+    SHUTDOWN = 5
+}
+table MPIResponse {
+    // Empty if the type is DONE or SHUTDOWN.
+    response_type:MPIResponseType;
+    tensor_name:string;
+
+    // Empty unless response_type is ERROR.
+    error_message:string;
+
+    // List of devices participating in this operation.
+    devices:[int];
+
+    // Empty unless response_type is ALLGATHER.
+    // These tensor sizes are the dimension zero sizes of all the input matrices,
+    // indexed by the rank.
+    tensor_sizes:[long];
+}
diff --git a/horovod/tensorflow/wire/mpi_message_generated.h b/horovod/tensorflow/wire/mpi_message_generated.h
new file mode 100644
index 0000000000..0a7e80ea71
--- /dev/null
+++ b/horovod/tensorflow/wire/mpi_message_generated.h
@@ -0,0 +1,343 @@
+// Copyright 2017 Uber Technologies, Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_MPIMESSAGE_HOROVOD_TENSORFLOW_WIRE_H_
+#define FLATBUFFERS_GENERATED_MPIMESSAGE_HOROVOD_TENSORFLOW_WIRE_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace horovod {
+namespace tensorflow {
+namespace wire {
+
+struct MPIRequest;
+
+struct MPIResponse;
+
+enum MPIDataType {
+  MPIDataType_TF_MPI_UINT8 = 0,
+  MPIDataType_TF_MPI_INT8 = 1,
+  MPIDataType_TF_MPI_UINT16 = 2,
+  MPIDataType_TF_MPI_INT16 = 3,
+  MPIDataType_TF_MPI_INT32 = 4,
+  MPIDataType_TF_MPI_INT64 = 5,
+  MPIDataType_TF_MPI_FLOAT32 = 6,
+  MPIDataType_TF_MPI_FLOAT64 = 7,
+  MPIDataType_MIN = MPIDataType_TF_MPI_UINT8,
+  MPIDataType_MAX = MPIDataType_TF_MPI_FLOAT64
+};
+
+inline const char **EnumNamesMPIDataType() {
+  static const char *names[] = {
+    "TF_MPI_UINT8",
+    "TF_MPI_INT8",
+    "TF_MPI_UINT16",
+    "TF_MPI_INT16",
+    "TF_MPI_INT32",
+    "TF_MPI_INT64",
+    "TF_MPI_FLOAT32",
+    "TF_MPI_FLOAT64",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameMPIDataType(MPIDataType e) {
+  const size_t index = static_cast<int>(e);
+  return EnumNamesMPIDataType()[index];
+}
+
+enum MPIRequestType {
+  MPIRequestType_ALLREDUCE = 0,
+  MPIRequestType_ALLGATHER = 1,
+  MPIRequestType_BROADCAST = 2,
+  MPIRequestType_MIN = MPIRequestType_ALLREDUCE,
+  MPIRequestType_MAX = MPIRequestType_BROADCAST
+};
+
+inline const char **EnumNamesMPIRequestType() {
+  static const char *names[] = {
+    "ALLREDUCE",
+    "ALLGATHER",
+    "BROADCAST",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameMPIRequestType(MPIRequestType e) {
+  const size_t index = static_cast<int>(e);
+  return EnumNamesMPIRequestType()[index];
+}
+
+enum MPIResponseType {
+  MPIResponseType_ALLREDUCE = 0,
+  MPIResponseType_ALLGATHER = 1,
+  MPIResponseType_BROADCAST = 2,
+  MPIResponseType_ERROR = 3,
+  MPIResponseType_DONE = 4,
+  MPIResponseType_SHUTDOWN = 5,
+  MPIResponseType_MIN = MPIResponseType_ALLREDUCE,
+  MPIResponseType_MAX = MPIResponseType_SHUTDOWN
+};
+
+inline const char **EnumNamesMPIResponseType() {
+  static const char *names[] = {
+    "ALLREDUCE",
+    "ALLGATHER",
+    "BROADCAST",
+    "ERROR",
+    "DONE",
+    "SHUTDOWN",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameMPIResponseType(MPIResponseType e) {
+  const size_t index = static_cast<int>(e);
+  return EnumNamesMPIResponseType()[index];
+}
+
+struct MPIRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  enum {
+    VT_REQUEST_RANK = 4,
+    VT_REQUEST_TYPE = 6,
+    VT_TENSOR_TYPE = 8,
+    VT_TENSOR_NAME = 10,
+    VT_ROOT_RANK = 12,
+    VT_DEVICE = 14,
+    VT_TENSOR_SHAPE = 16
+  };
+  int32_t request_rank() const {
+    return GetField<int32_t>(VT_REQUEST_RANK, 0);
+  }
+  MPIRequestType request_type() const {
+    return static_cast<MPIRequestType>(GetField<int8_t>(VT_REQUEST_TYPE, 0));
+  }
+  MPIDataType tensor_type() const {
+    return static_cast<MPIDataType>(GetField<int8_t>(VT_TENSOR_TYPE, 0));
+  }
+  const flatbuffers::String *tensor_name() const {
+    return GetPointer<const flatbuffers::String *>(VT_TENSOR_NAME);
+  }
+  int32_t root_rank() const {
+    return GetField<int32_t>(VT_ROOT_RANK, 0);
+  }
+  int32_t device() const {
+    return GetField<int32_t>(VT_DEVICE, 0);
+  }
+  const flatbuffers::Vector<int64_t> *tensor_shape() const {
+    return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_TENSOR_SHAPE);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_REQUEST_RANK) &&
+           VerifyField<int8_t>(verifier, VT_REQUEST_TYPE) &&
+           VerifyField<int8_t>(verifier, VT_TENSOR_TYPE) &&
+           VerifyField<flatbuffers::uoffset_t>(verifier, VT_TENSOR_NAME) &&
+           verifier.Verify(tensor_name()) &&
+           VerifyField<int32_t>(verifier, VT_ROOT_RANK) &&
+           VerifyField<int32_t>(verifier, VT_DEVICE) &&
+           VerifyField<flatbuffers::uoffset_t>(verifier, VT_TENSOR_SHAPE) &&
+           verifier.Verify(tensor_shape()) &&
+           verifier.EndTable();
+  }
+};
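To ground the accessors above, a sketch (our names, not part of this change) of the coordinator decoding a request buffer received over MPI:

    #include "horovod/tensorflow/wire/mpi_message_generated.h"

    void InspectRequest(const uint8_t *data) {
      using namespace horovod::tensorflow::wire;
      auto req = flatbuffers::GetRoot<MPIRequest>(data);
      // Pointer-typed fields may be null if they were never written.
      if (req->request_type() == MPIRequestType_ALLREDUCE &&
          req->tensor_shape() != nullptr) {
        for (flatbuffers::uoffset_t i = 0; i < req->tensor_shape()->size(); i++) {
          int64_t dim = req->tensor_shape()->Get(i);
          (void)dim;  // e.g. rebuild a TensorShape from the dims
        }
      }
    }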
+
+struct MPIRequestBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_request_rank(int32_t request_rank) {
+    fbb_.AddElement<int32_t>(MPIRequest::VT_REQUEST_RANK, request_rank, 0);
+  }
+  void add_request_type(MPIRequestType request_type) {
+    fbb_.AddElement<int8_t>(MPIRequest::VT_REQUEST_TYPE, static_cast<int8_t>(request_type), 0);
+  }
+  void add_tensor_type(MPIDataType tensor_type) {
+    fbb_.AddElement<int8_t>(MPIRequest::VT_TENSOR_TYPE, static_cast<int8_t>(tensor_type), 0);
+  }
+  void add_tensor_name(flatbuffers::Offset<flatbuffers::String> tensor_name) {
+    fbb_.AddOffset(MPIRequest::VT_TENSOR_NAME, tensor_name);
+  }
+  void add_root_rank(int32_t root_rank) {
+    fbb_.AddElement<int32_t>(MPIRequest::VT_ROOT_RANK, root_rank, 0);
+  }
+  void add_device(int32_t device) {
+    fbb_.AddElement<int32_t>(MPIRequest::VT_DEVICE, device, 0);
+  }
+  void add_tensor_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> tensor_shape) {
+    fbb_.AddOffset(MPIRequest::VT_TENSOR_SHAPE, tensor_shape);
+  }
+  MPIRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+    : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  MPIRequestBuilder &operator=(const MPIRequestBuilder &);
+  flatbuffers::Offset<MPIRequest> Finish() {
+    const auto end = fbb_.EndTable(start_, 7);
+    auto o = flatbuffers::Offset<MPIRequest>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<MPIRequest> CreateMPIRequest(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t request_rank = 0,
+    MPIRequestType request_type = MPIRequestType_ALLREDUCE,
+    MPIDataType tensor_type = MPIDataType_TF_MPI_UINT8,
+    flatbuffers::Offset<flatbuffers::String> tensor_name = 0,
+    int32_t root_rank = 0,
+    int32_t device = 0,
+    flatbuffers::Offset<flatbuffers::Vector<int64_t>> tensor_shape = 0) {
+  MPIRequestBuilder builder_(_fbb);
+  builder_.add_tensor_shape(tensor_shape);
+  builder_.add_device(device);
+  builder_.add_root_rank(root_rank);
+  builder_.add_tensor_name(tensor_name);
+  builder_.add_request_rank(request_rank);
+  builder_.add_tensor_type(tensor_type);
+  builder_.add_request_type(request_type);
+  return builder_.Finish();
+}
+
+inline flatbuffers::Offset<MPIRequest> CreateMPIRequestDirect(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t request_rank = 0,
+    MPIRequestType request_type = MPIRequestType_ALLREDUCE,
+    MPIDataType tensor_type = MPIDataType_TF_MPI_UINT8,
+    const char *tensor_name = nullptr,
+    int32_t root_rank = 0,
+    int32_t device = 0,
+    const std::vector<int64_t> *tensor_shape = nullptr) {
+  return horovod::tensorflow::wire::CreateMPIRequest(
+      _fbb,
+      request_rank,
+      request_type,
+      tensor_type,
+      tensor_name ? _fbb.CreateString(tensor_name) : 0,
+      root_rank,
+      device,
+      tensor_shape ? _fbb.CreateVector<int64_t>(*tensor_shape) : 0);
+}
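Conversely, a worker can serialize a request with the Direct helper above; the tensor name and shape below are illustrative:

    #include <vector>
    #include "horovod/tensorflow/wire/mpi_message_generated.h"

    std::vector<uint8_t> EncodeAllreduceRequest(int rank) {
      using namespace horovod::tensorflow::wire;
      flatbuffers::FlatBufferBuilder fbb;
      std::vector<int64_t> shape = {64, 128};  // illustrative dims
      auto req = CreateMPIRequestDirect(
          fbb, rank, MPIRequestType_ALLREDUCE, MPIDataType_TF_MPI_FLOAT32,
          "gradients/dense_1", /*root_rank=*/0, /*device=*/0, &shape);
      fbb.Finish(req);
      // Copy the finished buffer out so it can be handed to MPI.
      return std::vector<uint8_t>(fbb.GetBufferPointer(),
                                  fbb.GetBufferPointer() + fbb.GetSize());
    }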
+
+struct MPIResponse FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  enum {
+    VT_RESPONSE_TYPE = 4,
+    VT_TENSOR_NAME = 6,
+    VT_ERROR_MESSAGE = 8,
+    VT_DEVICES = 10,
+    VT_TENSOR_SIZES = 12
+  };
+  MPIResponseType response_type() const {
+    return static_cast<MPIResponseType>(GetField<int8_t>(VT_RESPONSE_TYPE, 0));
+  }
+  const flatbuffers::String *tensor_name() const {
+    return GetPointer<const flatbuffers::String *>(VT_TENSOR_NAME);
+  }
+  const flatbuffers::String *error_message() const {
+    return GetPointer<const flatbuffers::String *>(VT_ERROR_MESSAGE);
+  }
+  const flatbuffers::Vector<int32_t> *devices() const {
+    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DEVICES);
+  }
+  const flatbuffers::Vector<int64_t> *tensor_sizes() const {
+    return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_TENSOR_SIZES);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int8_t>(verifier, VT_RESPONSE_TYPE) &&
+           VerifyField<flatbuffers::uoffset_t>(verifier, VT_TENSOR_NAME) &&
+           verifier.Verify(tensor_name()) &&
+           VerifyField<flatbuffers::uoffset_t>(verifier, VT_ERROR_MESSAGE) &&
+           verifier.Verify(error_message()) &&
+           VerifyField<flatbuffers::uoffset_t>(verifier, VT_DEVICES) &&
+           verifier.Verify(devices()) &&
+           VerifyField<flatbuffers::uoffset_t>(verifier, VT_TENSOR_SIZES) &&
+           verifier.Verify(tensor_sizes()) &&
+           verifier.EndTable();
+  }
+};
+
+struct MPIResponseBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_response_type(MPIResponseType response_type) {
+    fbb_.AddElement<int8_t>(MPIResponse::VT_RESPONSE_TYPE, static_cast<int8_t>(response_type), 0);
+  }
+  void add_tensor_name(flatbuffers::Offset<flatbuffers::String> tensor_name) {
+    fbb_.AddOffset(MPIResponse::VT_TENSOR_NAME, tensor_name);
+  }
+  void add_error_message(flatbuffers::Offset<flatbuffers::String> error_message) {
+    fbb_.AddOffset(MPIResponse::VT_ERROR_MESSAGE, error_message);
+  }
+  void add_devices(flatbuffers::Offset<flatbuffers::Vector<int32_t>> devices) {
+    fbb_.AddOffset(MPIResponse::VT_DEVICES, devices);
+  }
+  void add_tensor_sizes(flatbuffers::Offset<flatbuffers::Vector<int64_t>> tensor_sizes) {
+    fbb_.AddOffset(MPIResponse::VT_TENSOR_SIZES, tensor_sizes);
+  }
+  MPIResponseBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+    : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  MPIResponseBuilder &operator=(const MPIResponseBuilder &);
+  flatbuffers::Offset<MPIResponse> Finish() {
+    const auto end = fbb_.EndTable(start_, 5);
+    auto o = flatbuffers::Offset<MPIResponse>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<MPIResponse> CreateMPIResponse(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    MPIResponseType response_type = MPIResponseType_ALLREDUCE,
+    flatbuffers::Offset<flatbuffers::String> tensor_name = 0,
+    flatbuffers::Offset<flatbuffers::String> error_message = 0,
+    flatbuffers::Offset<flatbuffers::Vector<int32_t>> devices = 0,
+    flatbuffers::Offset<flatbuffers::Vector<int64_t>> tensor_sizes = 0) {
+  MPIResponseBuilder builder_(_fbb);
+  builder_.add_tensor_sizes(tensor_sizes);
+  builder_.add_devices(devices);
+  builder_.add_error_message(error_message);
+  builder_.add_tensor_name(tensor_name);
+  builder_.add_response_type(response_type);
+  return builder_.Finish();
+}
+
+inline flatbuffers::Offset<MPIResponse> CreateMPIResponseDirect(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    MPIResponseType response_type = MPIResponseType_ALLREDUCE,
+    const char *tensor_name = nullptr,
+    const char *error_message = nullptr,
+    const std::vector<int32_t> *devices = nullptr,
+    const std::vector<int64_t> *tensor_sizes = nullptr) {
+  return horovod::tensorflow::wire::CreateMPIResponse(
+      _fbb,
+      response_type,
+      tensor_name ? _fbb.CreateString(tensor_name) : 0,
+      error_message ? _fbb.CreateString(error_message) : 0,
+      devices ? _fbb.CreateVector<int32_t>(*devices) : 0,
+      tensor_sizes ? _fbb.CreateVector<int64_t>(*tensor_sizes) : 0);
+}
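And on the receiving side, a worker would branch on response_type(); the comments note which fields are meaningful per type. This sketch is illustrative, not part of the patch:

    void DispatchResponse(const uint8_t *data) {
      using namespace horovod::tensorflow::wire;
      auto resp = flatbuffers::GetRoot<MPIResponse>(data);
      switch (resp->response_type()) {
        case MPIResponseType_ERROR:
          // error_message() is only populated for ERROR responses.
          break;
        case MPIResponseType_SHUTDOWN:
          // Tear down the background loop.
          break;
        default:
          // ALLREDUCE / ALLGATHER / BROADCAST carry devices() and, for
          // allgather, tensor_sizes().
          break;
      }
    }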
+
+}  // namespace wire
+}  // namespace tensorflow
+}  // namespace horovod
+
+#endif  // FLATBUFFERS_GENERATED_MPIMESSAGE_HOROVOD_TENSORFLOW_WIRE_H_
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000..1d7dd34746
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,281 @@
+# Copyright 2017 Uber Technologies, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import os
+from distutils.errors import CompileError, DistutilsError, DistutilsPlatformError, LinkError
+from setuptools import setup, Extension, find_packages
+from setuptools.command.build_ext import build_ext
+import shlex
+import subprocess
+import textwrap
+import traceback
+
+
+tensorflow_mpi_lib = Extension('horovod.tensorflow.mpi_lib', [])
+
+
+def get_tf_include():
+    try:
+        import tensorflow as tf
+        return tf.sysconfig.get_include()
+    except ImportError:
+        raise DistutilsPlatformError(
+            'import tensorflow failed, is it installed?\n\n%s' % traceback.format_exc())
+
+
+def get_tf_abi(build_ext, tf_include):
+    last_err = None
+    cxx11_abi_macro = '_GLIBCXX_USE_CXX11_ABI'
+    for cxx11_abi in ['0', '1']:
+        try:
+            lib_file = test_compile(build_ext, 'test_tensorflow_abi',
+                                    macros=[(cxx11_abi_macro, cxx11_abi)],
+                                    include_dirs=[tf_include], code=textwrap.dedent('''\
+                #include <string>
+                #include "tensorflow/core/framework/op.h"
+                #include "tensorflow/core/framework/op_kernel.h"
+                #include "tensorflow/core/framework/shape_inference.h"
+                void test() {
+                    auto ignore = tensorflow::strings::StrCat("a", "b");
+                }
+                '''))
+
+            from tensorflow.python.framework import load_library
+            load_library.load_op_library(lib_file)
+
+            return cxx11_abi_macro, cxx11_abi
+        except (CompileError, LinkError):
+            last_err = 'Unable to determine CXX11 ABI to use with TensorFlow (see error above).'
+        except Exception:
+            last_err = 'Unable to determine CXX11 ABI to use with TensorFlow. ' \
+                       'Last error:\n\n%s' % traceback.format_exc()
+
+    raise DistutilsPlatformError(last_err)
+
+
+def get_mpi_flags():
+    try:
+        mpi_show_output = subprocess.check_output(
+            ['mpicxx', '-show'], universal_newlines=True).strip()
+        # strip off compiler call portion and always escape each arg
+        return ' '.join(['"' + arg.replace('"', '"\'"\'"') + '"'
+                         for arg in shlex.split(mpi_show_output)[1:]])
+    except Exception:
+        raise DistutilsPlatformError(
+            'mpicxx -show failed, is mpicxx in $PATH?\n\n%s' % traceback.format_exc())
+
+
+def test_compile(build_ext, name, code, libraries=None, include_dirs=None,
+                 library_dirs=None, macros=None):
+    test_compile_dir = os.path.join(build_ext.build_temp, 'test_compile')
+    if not os.path.exists(test_compile_dir):
+        os.makedirs(test_compile_dir)
+
+    source_file = os.path.join(test_compile_dir, '%s.cc' % name)
+    with open(source_file, 'w') as f:
+        f.write(code)
+
+    compiler = build_ext.compiler
+    [object_file] = compiler.object_filenames([source_file])
+    shared_object_file = compiler.shared_object_filename(
+        name, output_dir=test_compile_dir)
+
+    compiler.compile([source_file], extra_preargs=['-std=c++11'],
+                     include_dirs=include_dirs, macros=macros)
+    compiler.link_shared_object(
+        [object_file], shared_object_file, libraries=libraries,
+        library_dirs=library_dirs)
+
+    return shared_object_file
+
+
+def get_cuda_dirs(build_ext):
+    cuda_include_dirs = []
+    cuda_lib_dirs = []
+
+    cuda_home = os.environ.get('HOROVOD_CUDA_HOME')
+    if cuda_home:
+        cuda_include_dirs += ['%s/include' % cuda_home]
+        cuda_lib_dirs += ['%s/lib' % cuda_home, '%s/lib64' % cuda_home]
+
+    cuda_include = os.environ.get('HOROVOD_CUDA_INCLUDE')
+    if cuda_include:
+        cuda_include_dirs += [cuda_include]
+
+    cuda_lib = os.environ.get('HOROVOD_CUDA_LIB')
+    if cuda_lib:
+        cuda_lib_dirs += [cuda_lib]
+
+    if not cuda_include_dirs and not cuda_lib_dirs:
+        # default to /usr/local/cuda
+        cuda_include_dirs += ['/usr/local/cuda/include']
+        cuda_lib_dirs += ['/usr/local/cuda/lib', '/usr/local/cuda/lib64']
+
+    try:
+        test_compile(build_ext, 'test_cuda', libraries=['cudart'],
+                     include_dirs=cuda_include_dirs,
+                     library_dirs=cuda_lib_dirs, code=textwrap.dedent('''\
+            #include <cuda_runtime.h>
+            void test() {
+                cudaSetDevice(0);
+            }
+            '''))
+    except (CompileError, LinkError):
+        raise DistutilsPlatformError(
+            'CUDA library was not found (see error above).\n'
+            'Please specify correct CUDA location via HOROVOD_CUDA_HOME '
+            'environment variable or combination of HOROVOD_CUDA_INCLUDE and '
+            'HOROVOD_CUDA_LIB environment variables.\n\n'
+            'HOROVOD_CUDA_HOME - path where CUDA include and lib directories can be found\n'
+            'HOROVOD_CUDA_INCLUDE - path to CUDA include directory\n'
+            'HOROVOD_CUDA_LIB - path to CUDA lib directory')
+
+    return cuda_include_dirs, cuda_lib_dirs
+
+
+def get_nccl_dirs(build_ext, cuda_include_dirs, cuda_lib_dirs):
+    nccl_include_dirs = []
+    nccl_lib_dirs = []
+
+    nccl_home = os.environ.get('HOROVOD_NCCL_HOME')
+    if nccl_home:
+        nccl_include_dirs += ['%s/include' % nccl_home]
+        nccl_lib_dirs += ['%s/lib' % nccl_home, '%s/lib64' % nccl_home]
+
+    nccl_include = os.environ.get('HOROVOD_NCCL_INCLUDE')
+    if nccl_include:
+        nccl_include_dirs += [nccl_include]
+
+    nccl_lib = os.environ.get('HOROVOD_NCCL_LIB')
+    if nccl_lib:
+        nccl_lib_dirs += [nccl_lib]
+
+    try:
+        test_compile(build_ext, 'test_nccl', libraries=['nccl'],
+                     include_dirs=nccl_include_dirs + cuda_include_dirs,
+                     library_dirs=nccl_lib_dirs + cuda_lib_dirs,
+                     code=textwrap.dedent('''\
+            #include <nccl.h>
+            #if NCCL_MAJOR < 2
+            #error Horovod requires NCCL 2.0 or later version, please upgrade.
+            #endif
+            void test() {
+                ncclUniqueId nccl_id;
+                ncclGetUniqueId(&nccl_id);
+            }
+            '''))
+    except (CompileError, LinkError):
+        raise DistutilsPlatformError(
+            'NCCL 2.0 library or its later version was not found (see error above).\n'
+            'Please specify correct NCCL location via HOROVOD_NCCL_HOME '
+            'environment variable or combination of HOROVOD_NCCL_INCLUDE and '
+            'HOROVOD_NCCL_LIB environment variables.\n\n'
+            'HOROVOD_NCCL_HOME - path where NCCL include and lib directories can be found\n'
+            'HOROVOD_NCCL_INCLUDE - path to NCCL include directory\n'
+            'HOROVOD_NCCL_LIB - path to NCCL lib directory')
+
+    return nccl_include_dirs, nccl_lib_dirs
+
+
+def fully_define_extension(build_ext):
+    tf_include = get_tf_include()
+    tf_abi = get_tf_abi(build_ext, tf_include)
+    mpi_flags = get_mpi_flags()
+
+    gpu_allreduce = os.environ.get('HOROVOD_GPU_ALLREDUCE')
+    if gpu_allreduce and gpu_allreduce != 'MPI' and gpu_allreduce != 'NCCL':
+        raise DistutilsError('HOROVOD_GPU_ALLREDUCE=%s is invalid, supported '
+                             'values are "", "MPI", "NCCL".' % gpu_allreduce)
+
+    gpu_allgather = os.environ.get('HOROVOD_GPU_ALLGATHER')
+    if gpu_allgather and gpu_allgather != 'MPI':
+        raise DistutilsError('HOROVOD_GPU_ALLGATHER=%s is invalid, supported '
+                             'values are "", "MPI".' % gpu_allgather)
+
+    gpu_broadcast = os.environ.get('HOROVOD_GPU_BROADCAST')
+    if gpu_broadcast and gpu_broadcast != 'MPI':
+        raise DistutilsError('HOROVOD_GPU_BROADCAST=%s is invalid, supported '
+                             'values are "", "MPI".' % gpu_broadcast)
+
+    if gpu_allreduce or gpu_allgather or gpu_broadcast:
+        have_cuda = True
+        cuda_include_dirs, cuda_lib_dirs = get_cuda_dirs(build_ext)
+    else:
+        have_cuda = False
+        cuda_include_dirs = cuda_lib_dirs = []
+
+    if gpu_allreduce == 'NCCL':
+        have_nccl = True
+        nccl_include_dirs, nccl_lib_dirs = get_nccl_dirs(
+            build_ext, cuda_include_dirs, cuda_lib_dirs)
+    else:
+        have_nccl = False
+        nccl_include_dirs = nccl_lib_dirs = []
+
+    MACROS = []
+    INCLUDES = [tf_include]
+    SOURCES = ['horovod/tensorflow/mpi_message.cc',
+               'horovod/tensorflow/mpi_ops.cc']
+    COMPILE_FLAGS = ['-std=c++11', '-fPIC', '-O2'] + shlex.split(mpi_flags)
+    LINK_FLAGS = shlex.split(mpi_flags)
+    LIBRARY_DIRS = []
+    LIBRARIES = []
+
+    if tf_abi:
+        COMPILE_FLAGS += ['-D%s=%s' % tf_abi]
+
+    if have_cuda:
+        MACROS += [('HAVE_CUDA', '1')]
+        INCLUDES += cuda_include_dirs
+        LIBRARY_DIRS += cuda_lib_dirs
+        LIBRARIES += ['cudart']
+
+    if have_nccl:
+        MACROS += [('HAVE_NCCL', '1')]
+        INCLUDES += nccl_include_dirs
+        LIBRARY_DIRS += nccl_lib_dirs
+        LIBRARIES += ['nccl']
+
+    if gpu_allreduce:
+        MACROS += [('HOROVOD_GPU_ALLREDUCE', "'%s'" % gpu_allreduce[0])]
+
+    if gpu_allgather:
+        MACROS += [('HOROVOD_GPU_ALLGATHER', "'%s'" % gpu_allgather[0])]
+
+    if gpu_broadcast:
+        MACROS += [('HOROVOD_GPU_BROADCAST', "'%s'" % gpu_broadcast[0])]
+
+    tensorflow_mpi_lib.define_macros = MACROS
+    tensorflow_mpi_lib.include_dirs = INCLUDES
+    tensorflow_mpi_lib.sources = SOURCES
+    tensorflow_mpi_lib.extra_compile_args = COMPILE_FLAGS
+    tensorflow_mpi_lib.extra_link_args = LINK_FLAGS
+    tensorflow_mpi_lib.library_dirs = LIBRARY_DIRS
+    tensorflow_mpi_lib.libraries = LIBRARIES
+
+
+# run the customize_compiler
+class custom_build_ext(build_ext):
+    def build_extensions(self):
+        fully_define_extension(self)
+        build_ext.build_extensions(self)
+
+
+setup(name='horovod',
+      version='0.9.0',
+      packages=find_packages(),
+      description='Distributed training framework for TensorFlow.',
+      author='Uber Technologies, Inc.',
+      long_description=textwrap.dedent('''\
Horovod is a distributed training framework for TensorFlow. + The goal of Horovod is to make distributed Deep Learning + fast and easy to use.'''), + ext_modules=[tensorflow_mpi_lib], + cmdclass={'build_ext': custom_build_ext}, + zip_safe=False)