Add TF Compute Server #3525

Merged (47 commits, Jun 19, 2022)

Commits

3c12358  Add support for Tensorflow Data Service (EnricoMi, Apr 20, 2022)
6b335ea  Move compute_*.py into horovod.tensorflow.data, fix examples (EnricoMi, Apr 20, 2022)
6444108  Make output_filename configurable in compute_worker.py (EnricoMi, Apr 20, 2022)
546fa49  Add tf data service to docs (EnricoMi, Apr 20, 2022)
67b8618  Make worker and example work with horovodrun, move docs into tensorfl… (EnricoMi, Apr 20, 2022)
c75c984  Make spark worker work with spark-submit (EnricoMi, Apr 20, 2022)
83d3c50  Add horovodrun example to CI (EnricoMi, Apr 20, 2022)
d8c1476  Remove tensorflow_data_service.rst from summary.rst (EnricoMi, Apr 20, 2022)
38119b8  Download to CWD directly, not mnist sub-directory (EnricoMi, Apr 21, 2022)
d9491ea  Add spark-submit example to docs and CI (EnricoMi, Apr 21, 2022)
bd7c6a9  Reduce run time for examples in CI (EnricoMi, Apr 21, 2022)
53b2ede  Use default path to fetch mnist dataset, which is pre-fetched in test… (EnricoMi, Apr 22, 2022)
ff9c5b0  Use --mpi instead of --gloo for MPI tests (EnricoMi, Apr 22, 2022)
22f9116  Run two workers to save ram, remove -H option (EnricoMi, Apr 22, 2022)
c6eac52  Escape $ differently in test command, but only for Buildkite (EnricoMi, Apr 25, 2022)
f423a2f  Revert "Escape $ differently in test command, but only for Buildkite" (EnricoMi, Apr 25, 2022)
ba22df1  Reference the worker file directly (EnricoMi, Apr 26, 2022)
9e3998b  Initialize Horovod for Tensorflow in tf worker (EnricoMi, Apr 26, 2022)
8596ead  Pin horovod task to GPU (EnricoMi, Apr 27, 2022)
9cf4813  Move rank and size around (EnricoMi, Apr 27, 2022)
847c01b  Update CHANGELOG.md (EnricoMi, May 5, 2022)
c0cabf5  Introducing TimeoutException (EnricoMi, May 5, 2022)
2d23441  Add tests for compute service (EnricoMi, May 5, 2022)
5c1c819  Fix shutdown test (EnricoMi, May 6, 2022)
5664728  Add timeout parameter to TfDataServiceConfig (EnricoMi, May 6, 2022)
02597e9  Add tf unit tests (EnricoMi, May 7, 2022)
d98faf3  Syncronize tests, assert batches (EnricoMi, May 10, 2022)
cc4f773  Add processing_mode to send_to_data_service, improve logging (EnricoMi, May 10, 2022)
7fef190  Relax assertions, add logging, add timeout parameter to compute worke… (EnricoMi, May 10, 2022)
80171b2  Add training-side tests (EnricoMi, May 10, 2022)
d1ed69a  Add DEBUG level to pytest tests (EnricoMi, May 10, 2022)
6b1c85d  Skip round-robin test (EnricoMi, May 10, 2022)
7c55576  Test processing modes (EnricoMi, May 10, 2022)
df42dc8  Remove expected batches, skip pre tf2 (EnricoMi, May 10, 2022)
f6d829f  Minor restructure of tests (EnricoMi, May 10, 2022)
e18c2f8  Remove port detection and address spec for worker (EnricoMi, May 11, 2022)
35c7cf4  Bind to single GPU (EnricoMi, May 11, 2022)
d26780d  Revert "Add DEBUG level to pytest tests" (EnricoMi, May 11, 2022)
5eba267  Minor comment fix (EnricoMi, May 11, 2022)
61f711e  Have horovod.tensorflow.data.compute_worker.py script broadcast config (EnricoMi, May 30, 2022)
91afa2f  Add some words about TF data service to docs (EnricoMi, Jun 8, 2022)
cbae912  Shutdown dispatcher in finally clause (EnricoMi, Jun 8, 2022)
9e6946e  Move the finished config file into place (EnricoMi, Jun 8, 2022)
a24c9f5  Fix config broadcast for MPI in GPU environment (EnricoMi, Jun 16, 2022)
5b6e765  Fixing typos in docs (EnricoMi, Jun 17, 2022)
8319e44  Add tensorflow issue to skipped test (EnricoMi, Jun 17, 2022)
5becdd0  Remove extra timeout from compute_worker_fn (EnricoMi, Jun 17, 2022)
14 changes: 14 additions & 0 deletions .buildkite/gen-pipeline.sh
@@ -240,6 +240,10 @@ run_mpi_integration() {
":tensorflow: MPI TensorFlow 2.0 Keras MNIST api (${test})" \
"bash -c \"${oneccl_env} python /horovod/examples/tensorflow2/tensorflow2_keras_mnist.py 2 localhost:2 mpi\""
fi

run_test "${test}" "${queue}" \
":tensorflow: MPI TensorFlow 2.0 MNIST Data Service (${test})" \
"bash -c \"${oneccl_env} horovodrun -np 2 python -m horovod.tensorflow.data.compute_worker /tmp/compute.json & horovodrun -np 2 --mpi python /horovod/examples/tensorflow2/tensorflow2_mnist_data_service.py /tmp/compute.json\""
fi
}

@@ -307,6 +311,10 @@ run_gloo_integration() {
":tensorflow: Gloo TensorFlow 2.0 MNIST Elastic api (${test})" \
"python /horovod/examples/elastic/tensorflow2/tensorflow2_mnist_elastic.py 2 2 2 localhost:2,127.0.0.1:2"
fi

run_test "${test}" "${queue}" \
":tensorflow: Gloo TensorFlow 2.0 MNIST Data Service (${test})" \
"bash -c \"horovodrun -np 2 python -m horovod.tensorflow.data.compute_worker /tmp/compute.json & horovodrun -np 2 --gloo python /horovod/examples/tensorflow2/tensorflow2_mnist_data_service.py /tmp/compute.json\""
else
run_test "${test}" "${queue}" \
":tensorflow: Gloo TensorFlow MNIST (${test})" \
@@ -411,6 +419,12 @@ run_spark_integration() {
"bash -c \"OMP_NUM_THREADS=1 /spark_env.sh python /horovod/examples/spark/keras/keras_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3\""
fi

if [[ ${test} == *"tf2_"* ]] || [[ ${test} == *"tfhead"* ]]; then
run_test "${test}" "${queue}" \
":spark: Spark TensorFlow 2.0 MNIST Data Service (${test})" \
"bash -c \"cd /horovod/examples/spark/tensorflow2; spark-submit --master \\\"local[2]\\\" \\\"/horovod/horovod/spark/tensorflow/compute_worker.py\\\" /tmp/compute.json & OMP_NUM_THREADS=1 /spark_env.sh spark-submit --master \\\"local[2]\\\" --py-files tensorflow2_mnist_data_service_train_fn_compute_side_dispatcher.py,tensorflow2_mnist_data_service_train_fn_training_side_dispatcher.py tensorflow2_mnist_data_service.py /tmp/compute.json\""
fi

run_test "${test}" "${queue}" \
":spark: Spark Torch MNIST (${test})" \
"bash -c \"OMP_NUM_THREADS=1 /spark_env.sh python /horovod/examples/spark/pytorch/pytorch_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3\""
288 changes: 288 additions & 0 deletions .github/workflows/ci.yaml

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

- Spark: expose random seed as an optional parameter. ([#3517](https://github.com/horovod/horovod/pull/3517))

- Added Horovod job to spin up distributed TensorFlow Data Service. ([#3525](https://github.com/horovod/horovod/pull/3525))

### Changed

- MXNet: Updated allreduce functions to newer `op` API. ([#3299](https://github.com/horovod/horovod/pull/3299))
3 changes: 3 additions & 0 deletions Dockerfile.test.cpu
@@ -250,6 +250,9 @@ RUN sed -i "s/dataset.take(20000/" /horovod/examples/tensorflow
# Hack TensorFlow 2.0 example to be smaller.
RUN sed -i "s/dataset.take(10000/dataset.take(100/" /horovod/examples/tensorflow2/tensorflow2_mnist.py

# Hack TensorFlow 2.0 Data Service example to be smaller.
RUN sed -i "s/ epochs=24/ epochs=4/" /horovod/examples/tensorflow2/tensorflow2_mnist_data_service_train_fn_*_side_dispatcher.py

# Hack Keras MNIST advanced example to be smaller.
RUN sed -i "s/'--epochs', type=int, default=24,/'--epochs', type=int, default=9,/" /horovod/examples/keras/keras_mnist_advanced.py
RUN sed -i "s/model.add(Conv2D(32, kernel_size=(3, 3),/model.add(Conv2D(1, kernel_size=(3, 3),/" /horovod/examples/keras/keras_mnist_advanced.py
3 changes: 3 additions & 0 deletions Dockerfile.test.gpu
@@ -224,6 +224,9 @@ RUN sed -i "s/dataset.take(20000/" /horovod/examples/tensorflow
# Hack TensorFlow 2.0 example to be smaller.
RUN sed -i "s/dataset.take(10000/dataset.take(100/" /horovod/examples/tensorflow2/tensorflow2_mnist.py

# Hack TensorFlow 2.0 Data Service example to be smaller.
RUN sed -i "s/ epochs=24/ epochs=4/" /horovod/examples/tensorflow2/tensorflow2_mnist_data_service_train_fn_*_side_dispatcher.py

# Hack Keras MNIST advanced example to be smaller.
RUN sed -i "s/'--epochs', type=int, default=24,/'--epochs', type=int, default=9,/" /horovod/examples/keras/keras_mnist_advanced.py

149 changes: 149 additions & 0 deletions docs/tensorflow.rst
@@ -184,3 +184,152 @@ TensorFlow v2 Example (from the `MNIST <https://github.com/horovod/horovod/blob/
# corrupting it.
if hvd.rank() == 0:
checkpoint.save(checkpoint_dir)

Horovod with TensorFlow Data Service
------------------------------------

A `TensorFlow Data Service <https://www.tensorflow.org/api_docs/python/tf/data/experimental/service>`_
allows you to move CPU-intensive processing of your dataset from your training process to a cluster of
CPU-rich processes.

With Horovod, it is easy to spin up a TensorFlow Data Service on your Horovod cluster and to connect
your Horovod training job to it.

Run the following command to start a TensorFlow Data Service via Horovod:

.. code-block:: bash

horovodrun -np 4 python -m horovod.tensorflow.data.compute_worker /tmp/compute.json

This starts a TensorFlow Data Service (here called the compute job) with one dispatcher and four workers.

.. note:: The config file is written by the compute job and has to be located on a path that is accessible
   to both the compute job and the training job, e.g. on a distributed file system.
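
For example, you could write the config file to a shared mount point (the path below is illustrative):

.. code-block:: bash

# /mnt/shared is an assumed shared mount; adjust to your environment
horovodrun -np 4 python -m horovod.tensorflow.data.compute_worker /mnt/shared/compute.json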

Your training job can then move CPU-intensive dataset operations to this data service by
calling ``.send_to_data_service(…)`` on the TensorFlow dataset:

.. code-block:: python

import tensorflow as tf
import horovod.tensorflow as hvd

from horovod.tensorflow.data.compute_service import TfDataServiceConfig

hvd.init()
rank = hvd.rank()
size = hvd.size()

compute_config = TfDataServiceConfig.read('/tmp/compute.json', wait_for_file_creation=True)

dataset = dataset.repeat() \
.shuffle(10000) \
.batch(128) \
.send_to_data_service(compute_config, rank, size) \
.prefetch(tf.data.experimental.AUTOTUNE)

All transformations before calling ``send_to_data_service`` will be executed by the data service,
while all transformations after it are executed locally by the training script.
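
A minimal sketch of that split (``expensive_augmentation`` is a placeholder for your own map function, not part of the API):

.. code-block:: python

dataset = (
    dataset
    .map(expensive_augmentation)  # placeholder preprocessing; runs remotely on the data service
    .send_to_data_service(compute_config, rank, size)
    .prefetch(tf.data.experimental.AUTOTUNE)  # runs locally in the training process
)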

You can find the `tensorflow2_mnist_data_service.py <https://github.com/horovod/horovod/blob/master/examples/tensorflow2/tensorflow2_mnist_data_service.py>`_
example in the examples directory.

First start the data service as shown above. While the data service is running, start the example training script:

.. code-block:: bash

horovodrun -np 2 python tensorflow2_mnist_data_service.py /tmp/compute.json

The compute job normally runs on CPU nodes while the training job runs on GPU nodes. This allows you to run
CPU-intensive dataset transformations on CPU nodes while running GPU-intensive training on GPU nodes. Multiple
CPU workers can be dedicated to a single GPU task.

Use the ``--hosts`` argument to run the compute job on CPU nodes (here ``cpu-node-1`` and ``cpu-node-2``)
and the training job on GPU nodes (here ``gpu-node-1`` and ``gpu-node-2``):

.. code-block:: bash

horovodrun -np 4 --hosts cpu-node-1:2,cpu-node-2:2 python -m horovod.tensorflow.data.compute_worker /tmp/compute.json
horovodrun -np 2 --hosts gpu-node-1:1,gpu-node-2:1 python tensorflow2_mnist_data_service.py /tmp/compute.json

.. note::

Please make sure you understand how TensorFlow Data Service distributes dataset transformations:
See the `distribute <https://www.tensorflow.org/api_docs/python/tf/data/experimental/service/distribute>`_ transformation.

Multiple Dispatchers
~~~~~~~~~~~~~~~~~~~~

The data service allows for multiple dispatchers, one per training task. Each dispatcher gets the same number of workers.
Since each worker is dedicated to a single dispatcher, each worker is also dedicated to a single training task.
The size of your compute job (``-np 4``) therefore has to be a multiple of the number of dispatchers
(``--dispatchers 2``); here, each of the two dispatchers gets two of the four workers:

.. code-block:: bash

horovodrun -np 4 python -m horovod.tensorflow.data.compute_worker --dispatchers 2 /tmp/compute.json

This requires the number of dispatchers (``--dispatchers 2``) to match the size of your training job (``-np 2``):

.. code-block:: bash

horovodrun -np 2 python tensorflow2_mnist_data_service.py /tmp/compute.json

Single Dispatcher
~~~~~~~~~~~~~~~~~

With a single dispatcher, TensorFlow allows the dataset to be reused across all training tasks, either on a
first-come-first-served basis or round-robin. The only supported processing mode is ``"distributed_epoch"``.
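
For instance, a sketch of sharing a single-dispatcher dataset round-robin across training tasks
(assuming ``send_to_data_service`` accepts the same ``reuse_dataset`` and ``round_robin`` flags
that the example scripts expose):

.. code-block:: python

dataset = dataset.repeat() \
    .batch(128) \
    .send_to_data_service(compute_config, rank, size,
                          reuse_dataset=True,  # all training tasks read from one shared job
                          round_robin=True) \
    .prefetch(tf.data.experimental.AUTOTUNE)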

Training-side dispatchers
~~~~~~~~~~~~~~~~~~~~~~~~~

The dispatchers by default run inside the compute job. You can, however, also run them inside the training job.
Add ``--dispatcher-side training`` to tell the compute job that dispatchers are started by the training job.

.. code-block:: bash

horovodrun -np 4 python -m horovod.tensorflow.data.compute_worker --dispatcher-side training /tmp/compute.json

The training script then starts the dispatchers via ``with tf_data_service(…)`` and distributes the dataset itself:

.. code-block:: python

import tensorflow as tf
import horovod.tensorflow as hvd

from horovod.tensorflow.data.compute_service import TfDataServiceConfig, tf_data_service

hvd.init()
rank = hvd.rank()
size = hvd.size()

compute_config = TfDataServiceConfig.read('/tmp/compute.json', wait_for_file_creation=True)

# choose how the training tasks consume the dataset (see the example script's flags)
reuse_dataset = False
round_robin = False

with tf_data_service(compute_config, rank) as dispatcher_address:

dataset = dataset.repeat() \
.shuffle(10000) \
.batch(128) \
.apply(tf.data.experimental.service.distribute(
processing_mode="distributed_epoch",
service=dispatcher_address,
job_name='job' if reuse_dataset else None,
consumer_index=rank if round_robin else None,
num_consumers=size if round_robin else None)) \
.prefetch(tf.data.experimental.AUTOTUNE)

To see the specific changes needed to make the training job run dispatchers,
simply diff the training-side example with the compute-side example:

.. code-block:: bash

diff -w examples/tensorflow2/tensorflow2_mnist_data_service_train_fn_*

Compute job on Spark cluster
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The compute job can be started on a Spark cluster using ``spark-submit``:

.. code-block:: bash

worker_py=$(python -c "import horovod.spark.tensorflow.compute_worker as worker; print(worker.__file__)")
spark-submit --master "local[4]" "$worker_py" /tmp/compute.json


While the compute job is running, start the training job:

.. code-block:: bash

cd examples/spark/tensorflow2
spark-submit --master "local[2]" --py-files tensorflow2_mnist_data_service_train_fn_compute_side_dispatcher.py,tensorflow2_mnist_data_service_train_fn_training_side_dispatcher.py tensorflow2_mnist_data_service.py /tmp/compute.json

As usual, the config file has to be located on a path that is accessible to all nodes that run the compute job.
84 changes: 84 additions & 0 deletions examples/spark/tensorflow2/tensorflow2_mnist_data_service.py
@@ -0,0 +1,84 @@
# Copyright 2022 G-Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse
import os
import sys

from pyspark import SparkConf
from pyspark.sql import SparkSession

from horovod.spark import run
from horovod.tensorflow.data.compute_service import TfDataServiceConfig
from tensorflow2_mnist_data_service_train_fn_compute_side_dispatcher import train_fn as train_fn_compute_side
from tensorflow2_mnist_data_service_train_fn_training_side_dispatcher import train_fn as train_fn_training_side

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


# This exemplifies how to use the Tensorflow Compute Service with Horovod.
# The Tensorflow Dispatcher can reside with the training script, or the compute service.
# If you use only one of these options, you can ignore the respective code of the other option in this example.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("configfile", type=str,
                        help="The path to the compute service config file.")

    parser.add_argument("--reuse-dataset", required=False, action="store_true", default=False,
                        help="Reusing the dataset allows the training tasks to read from a single dataset "
                             "in a first-come-first-served manner.",
                        dest="reuse_dataset")

    parser.add_argument("--round-robin", required=False, action="store_true", default=False,
                        help="Reusing the dataset can be done round-robin instead of first-come-first-served.",
                        dest="round_robin")

    parsed_args = parser.parse_args()

    compute_config = TfDataServiceConfig.read(parsed_args.configfile, wait_for_file_creation=True)

    conf = SparkConf()
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    spark_context = spark.sparkContext
    training_tasks = spark_context.defaultParallelism

    if compute_config.dispatchers > 1 and training_tasks != compute_config.dispatchers:
        print(f'The number of training tasks ({training_tasks}) must match '
              f'the number of dispatchers ({compute_config.dispatchers}) configured in the '
              f'data service config file ({parsed_args.configfile}).', file=sys.stderr)
        sys.exit(1)

    # pick the right train_fn depending on the dispatcher side
    if compute_config.dispatcher_side == 'training':
        train_fn = train_fn_training_side
    elif compute_config.dispatcher_side == 'compute':
        train_fn = train_fn_compute_side
    else:
        raise ValueError(f'Unsupported dispatcher side: {compute_config.dispatcher_side}')

    # run the distributed training
    run(train_fn,
        args=(compute_config,),
        kwargs={
            'reuse_dataset': parsed_args.reuse_dataset,
            'round_robin': parsed_args.round_robin
        },
        num_proc=training_tasks,
        stdout=sys.stdout,
        stderr=sys.stderr)

    # shut down the compute service once training has finished
    compute = compute_config.compute_client(verbose=2)
    compute.shutdown()
76 changes: 76 additions & 0 deletions examples/tensorflow2/tensorflow2_mnist_data_service.py
@@ -0,0 +1,76 @@
# Copyright 2022 G-Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse
import os
import sys

from horovod.runner.common.util import env
from horovod.tensorflow.data.compute_service import TfDataServiceConfig
from tensorflow2_mnist_data_service_train_fn_compute_side_dispatcher import train_fn as train_fn_compute_side
from tensorflow2_mnist_data_service_train_fn_training_side_dispatcher import train_fn as train_fn_training_side

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


# This exemplifies how to use the Tensorflow Compute Service with Horovod.
# The Tensorflow Dispatcher can reside with the training script, or the compute service.
# If you use only one of these options, you can ignore the respective code of the other option in this example.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("configfile", type=str,
                        help="The path to the compute service config file.")

    parser.add_argument("--training-tasks", required=False, type=int,
                        help="The number of training tasks when there is only one dispatcher. "
                             "Otherwise there are as many training tasks as there are dispatchers.",
                        dest="training_tasks")

    parser.add_argument("--reuse-dataset", required=False, action="store_true", default=False,
                        help="Reusing the dataset allows the training tasks to read from a single dispatcher "
                             "in a first-come-first-served manner.",
                        dest="reuse_dataset")

    parser.add_argument("--round-robin", required=False, action="store_true", default=False,
                        help="Reusing the dataset can be done round-robin instead of first-come-first-served.",
                        dest="round_robin")

    parsed_args = parser.parse_args()

    compute_config = TfDataServiceConfig.read(parsed_args.configfile, wait_for_file_creation=True)

    rank, size = env.get_env_rank_and_size()

    if compute_config.dispatchers > 1 and compute_config.dispatchers != size:
        print(f'Unless there is only one dispatcher, the number of training tasks ({size}) must match '
              f'the number of dispatchers ({compute_config.dispatchers}) configured in the '
              f'data service config file ({parsed_args.configfile}).', file=sys.stderr)
        sys.exit(1)

    # pick the right train_fn depending on the dispatcher side
    if compute_config.dispatcher_side == 'training':
        train_fn = train_fn_training_side
    elif compute_config.dispatcher_side == 'compute':
        train_fn = train_fn_compute_side
    else:
        raise ValueError(f'Unsupported dispatcher side: {compute_config.dispatcher_side}')

    # run the distributed training
    train_fn(compute_config, reuse_dataset=parsed_args.reuse_dataset, round_robin=parsed_args.round_robin)

    # the first rank shuts down the compute service after training
    if rank == 0:
        compute = compute_config.compute_client(verbose=2)
        compute.shutdown()