diff --git a/.buildkite/gen-pipeline.sh b/.buildkite/gen-pipeline.sh index fc811d318a..e313763dda 100755 --- a/.buildkite/gen-pipeline.sh +++ b/.buildkite/gen-pipeline.sh @@ -8,22 +8,19 @@ repository=823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite # list of all the tests tests=( \ - test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 \ - test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 \ test-cpu-openmpi-py2_7-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2 \ test-cpu-openmpi-py3_6-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2 \ test-cpu-gloo-py2_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 \ test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 \ - test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 \ + test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 \ test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 \ test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 \ - test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 \ test-cpu-openmpi-py2_7-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 \ test-cpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 \ test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0 \ - test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 \ - test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 \ - test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 \ + test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 \ + test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 \ + test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 \ test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 \ test-gpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 \ test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 \ @@ -98,15 +95,20 @@ run_mpi_pytest() { local oneccl_env=${3:-} oneccl_env=$(echo ${oneccl_env//:/ }) - local exclude_keras_if_needed="" + local exclude_keras="" if [[ ${test} == *"tf2_"* ]] || [[ ${test} == *"tfhead"* ]]; then # TODO: support for Keras + TF 2.0 and TF-Keras 2.0 - exclude_keras_if_needed="| sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g'" + exclude_keras="| sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g'" else - exclude_keras_if_needed="| sed 's/[a-z_]*tensorflow2[a-z_.]*//g'" + exclude_keras="| sed 's/[a-z_]*tensorflow2[a-z_.]*//g'" fi - local exclude_interactiverun="| sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g'" + local exclude_elastic="" + if [[ ${test} == *"py2_"* ]]; then + exclude_elastic="| sed 's/test_elastic[a-z_.]*//g'" + fi + + local excluded_tests="| sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g'" # Spark and Run test does not need to be executed with horovodrun, but we still run it below. 
local exclude_standalone_test="| sed 's/test_spark.py//g' | sed 's/test_run.py//g'" @@ -121,7 +123,7 @@ run_mpi_pytest() { # pytests have 4x GPU use cases and require a separate queue run_test "${test}" "${queue}" \ ":pytest: Run PyTests (${test})" \ - "bash -c \"${oneccl_env} cd /horovod/test && (echo test_*.py ${exclude_keras_if_needed} ${exclude_interactiverun} ${exclude_standalone_test} | xargs -n 1 \\\$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no ${standalone_tests}\"" + "bash -c \"${oneccl_env} cd /horovod/test && (echo test_*.py ${exclude_keras} ${exclude_elastic} ${excluded_tests} ${exclude_standalone_test} | xargs -n 1 \\\$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no ${standalone_tests}\"" } run_mpi_integration() { @@ -156,7 +158,7 @@ run_mpi_integration() { fi run_test "${test}" "${queue}" \ - ":python: Test PyTorch MNIST (${test})" \ + ":fire: Test PyTorch MNIST (${test})" \ "bash -c \"${oneccl_env} \\\$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py\"" run_test "${test}" "${queue}" \ @@ -165,7 +167,7 @@ run_mpi_integration() { # tests that should be executed only with the latest release since they don't test # a framework-specific functionality - if [[ ${test} == *"tf1_14_0"* ]]; then + if [[ ${test} == *"tf1_15_0"* ]]; then run_test "${test}" "${queue}" \ ":muscle: Test Stall (${test})" \ "bash -c \"${oneccl_env} \\\$(cat /mpirun_command) python /horovod/test/test_stall.py\"" @@ -206,12 +208,17 @@ run_gloo_pytest() { local test=$1 local queue=$2 - local exclude_keras_if_needed="" + local exclude_keras="" if [[ ${test} == *"tf2_"* ]] || [[ ${test} == *"tfhead"* ]]; then # TODO: support for Keras + TF 2.0 and TF-Keras 2.0 - exclude_keras_if_needed="| sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g'" + exclude_keras="| sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g'" else - exclude_keras_if_needed="| sed 's/[a-z_]*tensorflow2[a-z_.]*//g'" + exclude_keras="| sed 's/[a-z_]*tensorflow2[a-z_.]*//g'" + fi + + local exclude_elastic="" + if [[ ${test} == *"py2_"* ]]; then + exclude_elastic="| sed 's/test_elastic[a-z_.]*//g'" fi # These are tested as integration style tests. 
@@ -229,7 +236,7 @@ run_gloo_pytest() { run_test "${test}" "${queue}" \ ":pytest: Run PyTests (${test})" \ - "bash -c \"cd /horovod/test && (echo test_*.py ${exclude_keras_if_needed} ${excluded_tests} ${exclude_standalone_test} | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no ${standalone_tests}\"" + "bash -c \"cd /horovod/test && (echo test_*.py ${exclude_keras} ${exclude_elastic} ${excluded_tests} ${exclude_standalone_test} | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no ${standalone_tests}\"" } run_gloo_integration() { @@ -256,12 +263,24 @@ run_gloo_integration() { fi run_test "${test}" "${queue}" \ - ":python: Test PyTorch MNIST (${test})" \ + ":fire: Test PyTorch MNIST (${test})" \ "horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py" run_test "${test}" "${queue}" \ ":muscle: Test MXNet MNIST (${test})" \ "horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/mxnet_mnist.py" + + # Elastic + if [[ ${test} == *"py3_"* ]]; then + local elastic_tensorflow="test_elastic_tensorflow.py" + if [[ ${test} == *"tf2_"* ]] || [[ ${test} == *"tfhead"* ]]; then + elastic_tensorflow="test_elastic_tensorflow2.py" + fi + + run_test "${test}" "${queue}" \ + ":factory: Elastic Tests (${test})" \ + "bash -c \"cd /horovod/test/integration && pytest -v --log-cli-level 10 --capture=no test_elastic_torch.py ${elastic_tensorflow}\"" + fi } run_gloo() { @@ -322,7 +341,7 @@ run_single_integration() { fi run_test "${test}" "${queue}" \ - ":python: Single PyTorch MNIST (${test})" \ + ":fire: Single PyTorch MNIST (${test})" \ "bash -c \"${oneccl_env} python /horovod/examples/pytorch_mnist.py --epochs 3\"" run_test "${test}" "${queue}" \ diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 7d1bb93e77..e2ab55795d 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -6,31 +6,6 @@ services: dockerfile: Dockerfile.test.cpu privileged: true shm_size: 8gb - test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2: - extends: test-cpu-base - build: - args: - MPI_KIND: OpenMPI - PYTHON_VERSION: 2.7 - TENSORFLOW_PACKAGE: tensorflow==1.1.0 - KERAS_PACKAGE: keras==2.0.0 - PYTORCH_PACKAGE: torch==0.4.0 - TORCHVISION_PACKAGE: torchvision==0.2.2.post3 - MXNET_PACKAGE: mxnet==1.4.1 - PYSPARK_PACKAGE: pyspark==2.3.2 - test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2: - extends: test-cpu-base - build: - args: - UBUNTU_VERSION: 18.04 - MPI_KIND: OpenMPI - PYTHON_VERSION: 3.6 - TENSORFLOW_PACKAGE: tensorflow==1.1.0 - KERAS_PACKAGE: keras==2.0.0 - PYTORCH_PACKAGE: torch==0.4.0 - TORCHVISION_PACKAGE: torchvision==0.2.2.post3 - MXNET_PACKAGE: mxnet==1.4.1 - PYSPARK_PACKAGE: pyspark==2.3.2 test-cpu-openmpi-py2_7-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2: extends: test-cpu-base build: @@ -81,16 +56,16 @@ services: TORCHVISION_PACKAGE: torchvision==0.5.0+cpu MXNET_PACKAGE: mxnet==1.5.0 PYSPARK_PACKAGE: pyspark==2.4.0 - test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0: + test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0: extends: test-cpu-base build: args: UBUNTU_VERSION: 18.04 MPI_KIND: None PYTHON_VERSION: 3.7 - TENSORFLOW_PACKAGE: tensorflow-cpu==1.15.0 + TENSORFLOW_PACKAGE: tensorflow==2.2.0 KERAS_PACKAGE: keras==2.3.1 - PYTORCH_PACKAGE: torch==1.4.0+cpu + PYTORCH_PACKAGE: torch==1.5.0+cpu TORCHVISION_PACKAGE: torchvision==0.5.0+cpu 
MXNET_PACKAGE: mxnet==1.5.0 PYSPARK_PACKAGE: pyspark==2.4.0 @@ -101,7 +76,7 @@ services: UBUNTU_VERSION: 18.04 MPI_KIND: None PYTHON_VERSION: 3.8 - TENSORFLOW_PACKAGE: tensorflow==2.2.0rc3 + TENSORFLOW_PACKAGE: tensorflow==2.2.0 KERAS_PACKAGE: keras==2.3.1 PYTORCH_PACKAGE: torch==1.5.0+cpu TORCHVISION_PACKAGE: torchvision==0.5.0+cpu @@ -120,19 +95,6 @@ services: TORCHVISION_PACKAGE: torchvision==0.4.1+cpu MXNET_PACKAGE: mxnet==1.4.1 PYSPARK_PACKAGE: pyspark==2.4.0 - test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0: - extends: test-cpu-base - build: - args: - UBUNTU_VERSION: 18.04 - MPI_KIND: OpenMPI - PYTHON_VERSION: 3.6 - TENSORFLOW_PACKAGE: tensorflow==1.14.0 - KERAS_PACKAGE: keras==2.3.1 - PYTORCH_PACKAGE: torch==1.3.0+cpu - TORCHVISION_PACKAGE: torchvision==0.4.1+cpu - MXNET_PACKAGE: mxnet==1.4.1 - PYSPARK_PACKAGE: pyspark==2.4.0 test-cpu-openmpi-py2_7-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0: extends: test-cpu-base build: @@ -171,40 +133,40 @@ services: TORCHVISION_PACKAGE: torchvision==0.6.0.dev20200413 MXNET_PACKAGE: mxnet-nightly PYSPARK_PACKAGE: pyspark==2.4.0 - test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0: + test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0: extends: test-cpu-base build: args: UBUNTU_VERSION: 18.04 MPI_KIND: MPICH PYTHON_VERSION: 3.6 - TENSORFLOW_PACKAGE: tensorflow==1.14.0 + TENSORFLOW_PACKAGE: tensorflow-cpu==1.15.0 KERAS_PACKAGE: keras==2.3.1 PYTORCH_PACKAGE: torch==1.3.0+cpu TORCHVISION_PACKAGE: torchvision==0.4.1+cpu MXNET_PACKAGE: mxnet==1.5.0 PYSPARK_PACKAGE: pyspark==2.4.0 - test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0: + test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0: extends: test-cpu-base build: args: UBUNTU_VERSION: 18.04 MPI_KIND: ONECCL PYTHON_VERSION: 3.6 - TENSORFLOW_PACKAGE: tensorflow==1.14.0 + TENSORFLOW_PACKAGE: tensorflow-cpu==1.15.0 KERAS_PACKAGE: keras==2.3.1 PYTORCH_PACKAGE: torch==1.3.0+cpu TORCHVISION_PACKAGE: torchvision==0.4.1+cpu MXNET_PACKAGE: mxnet==1.5.0 PYSPARK_PACKAGE: pyspark==2.4.0 - test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0: + test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0: extends: test-cpu-base build: args: UBUNTU_VERSION: 18.04 MPI_KIND: ONECCL PYTHON_VERSION: 3.6 - TENSORFLOW_PACKAGE: tensorflow==1.14.0 + TENSORFLOW_PACKAGE: tensorflow-cpu==1.15.0 KERAS_PACKAGE: keras==2.3.1 PYTORCH_PACKAGE: torch==1.3.0+cpu TORCHVISION_PACKAGE: torchvision==0.4.1+cpu diff --git a/docs/elastic.rst b/docs/elastic.rst new file mode 100644 index 0000000000..cf4ad6aac4 --- /dev/null +++ b/docs/elastic.rst @@ -0,0 +1,358 @@ +.. inclusion-marker-start-do-not-remove + +Elastic Horovod +=============== + + +Elastic training enables Horovod to scale up and down the number of workers dynamically at runtime, without +requiring a restart or resuming from checkpoints saved to durable storage. With elastic training, workers can come +and go from the Horovod job without interrupting the training process. + + +When to use elastic training +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- You are running an `autoscaling `__ job that may acquire more resources for training over time. +- Your job is running on preemptable or spot instances that may come and go with little warning. +- Your nodes are unreliable and you want your job to continue training if some of the hosts fail. 
+
+
+Requirements
+~~~~~~~~~~~~
+
+- Python >= 3.6
+- TensorFlow >= 1.15 or PyTorch >= 1.0
+- Horovod >= 0.20.0 with Gloo support (install Horovod using ``HOROVOD_WITH_GLOO=1`` to ensure it is installed)
+- A way to discover available hosts at runtime
+
+
+Modifying the training script with State Synchronization
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The biggest difference when moving from normal distributed training to elastic training is the need to track and synchronize
+state among the workers as workers are added or removed from the job.
+
+To enable elastic training, make the following changes to your training script:
+
+1. Wrap your main training process (everything following initialization) in a function decorated with ``hvd.elastic.run``.
+
+   The first argument to this decorated function should be an instance of ``hvd.elastic.State``. Before executing the
+   decorated function, this state object will be synchronized across workers. This ensures that newly added workers,
+   as well as workers that might have inconsistent state, all share the same state before training begins.
+
+   Because the sync function uses collective ops, and because active workers do not re-execute the code before this
+   function when a new worker is added, *no Horovod collective ops (broadcast, allreduce, allgather, etc.) can be called before this function*.
+
+2. Place all variables that need to be kept in sync between worker replicas (model parameters, optimizer state, epoch and batch numbers, etc.) into a ``hvd.elastic.State`` object.
+
+   Standard state implementations are provided for TensorFlow, Keras, and PyTorch. However, it may be necessary in some cases to override
+   the base ``hvd.elastic.State`` object to handle broadcasting custom types.
+
+3. Periodically call ``state.commit()`` to back up a copy of your state in memory.
+
+   This is useful to prevent corrupted state in the event that a worker fails unexpectedly. For example, if training fails
+   in the middle of a parameter update, some gradient updates may have been applied while others were still being allreduced. When this
+   happens, a ``HorovodInternalError`` will be raised, and all parameters will be restored to the values at the time of the last commit.
+
+   Because commits can be expensive (as the model size increases), there is a tradeoff between the per-batch processing time
+   and how far the training process needs to roll back in the event of a failure. For example, if you commit once every 10
+   batches, you reduce the amount of copying by a factor of 10. But if a failure occurs, you may need to redo up to 10
+   previously processed batches.
+
+   Elastic Horovod can avoid these rollbacks by performing what we call a *graceful removal* of a worker. If the driver
+   process discovers that a host has been made available or flagged for removal, it will push a notification to the workers.
+   The next time ``state.commit()`` or the more lightweight ``state.check_host_updates()`` is called, a ``HostsUpdatedInterrupt``
+   will be raised. This event is handled similarly to the ``HorovodInternalError``, except that parameter state will not be
+   restored to the last commit.
+
+   In general, if your hardware is reliable, and your orchestration system gives the driver ample warning
+   when a host is scheduled to be removed from the job, then you can safely call ``state.commit()`` at a reduced frequency,
+   and call ``state.check_host_updates()`` at the end of each batch instead.
+
+4.
Register callbacks with the ``hvd.elastic.State`` object to respond to changes in the worker membership in the job. + + For example, rescaling the learning rate with the new world size or repartitioning the dataset would commonly be done + through these callbacks. + + Callbacks are called after Horovod has reinitialized, but before state is synchronized across the workers. + +The reset process following a ``HorovodInternalError`` (failure) or ``HostsUpdatedInterrupt`` (add/remove request) is as follows: + +1. Catch exception within the ``hvd.elastic.run`` decorator. +2. Restore last committed state if ``HorovodInternalError`` was raised. +3. Reinitialize Horovod context performing a new round of rendezvous. +4. Synchronize state among the workers by broadcasting from the new worker-0. +5. Resume training by executing the underlying training function. + +During rendezvous, older workers will take priority in being assigned worker-0 status to ensure that the state that +is broadcast is up to date. + + +Elastic TensorFlow +~~~~~~~~~~~~~~~~~~ + +TensorFlow v1 Example: + +.. code-block:: python + :emphasize-lines: 17,18,23,29,32,33 + + import tensorflow as tf + import horovod.tensorflow as hvd + + hvd.init() + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + config.gpu_options.visible_device_list = str(hvd.local_rank()) + + dataset = ... + model = ... + + lr = tf.Variable(base_lr * hvd.size()) + optimizer = tf.train.GradientDescentOptimizer(lr) + optimizer = hvd.DistributedOptimizer(optimizer) + + @hvd.elastic.run + def train(state, train_one_batch): + for state.epoch in range(state.epoch, epochs): + for state.batch in range(state.batch, batches_per_epoch): + train_one_batch() + if state.batch % batches_per_commit == 0: + state.commit() + state.batch = 0 + + with tf.Session(config=config) as session: + session.run(tf.global_variables_initializer()) + + def on_state_reset(): + lr.load(base_lr * hvd.size(), session) + + state = hvd.elastic.TensorFlowState(session=session, batch=0, epoch=0) + state.register_reset_callbacks([on_state_reset]) + + train_opt = optimizer.minimize(loss) + train(state, lambda: session.run(train_opt)) + +TensorFlow v2 Example: + +.. code-block:: python + :emphasize-lines: 33,34,40,43,46,47 + + import tensorflow as tf + import horovod.tensorflow as hvd + + hvd.init() + + gpus = tf.config.experimental.list_physical_devices('GPU') + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + if gpus: + tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU') + + dataset = ... + model = ... 
+ + optimizer = tf.optimizers.Adam(lr * hvd.size()) + + @tf.function + def train_one_batch(data, target, allreduce=True): + with tf.GradientTape() as tape: + probs = model(data, training=True) + loss = tf.losses.categorical_crossentropy(target, probs) + + if allreduce: + tape = hvd.DistributedGradientTape(tape) + + gradients = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(gradients, model.trainable_variables)) + + # Initialize model and optimizer state so we can synchronize across workers + data, target = get_random_batch() + train_one_batch(data, target, allreduce=False) + + @hvd.elastic.run + def train(state): + for state.epoch in range(state.epoch, epochs): + for state.batch in range(state.batch, batches_per_epoch): + data, target = get_random_batch() + train_one_batch(data, target) + if state.batch % batches_per_commit == 0: + state.commit() + state.batch = 0 + + def on_state_reset(): + optimizer.lr.assign(lr * hvd.size()) + + state = hvd.elastic.TensorFlowKerasState(model, optimizer, batch=0, epoch=0) + state.register_reset_callbacks([on_state_reset]) + train(state) + + +Elastic Keras +~~~~~~~~~~~~~ + +.. code-block:: python + :emphasize-lines: 21,24,25,28,29,30,36,37 + + import tensorflow as tf + import horovod.tensorflow.keras as hvd + + hvd.init() + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + config.gpu_options.visible_device_list = str(hvd.local_rank()) + tf.keras.backend.set_session(tf.Session(config=config)) + + dataset = ... + model = ... + + opt = keras.optimizers.Adadelta(lr * hvd.size()) + opt = hvd.DistributedOptimizer(opt) + + model.compile(loss=keras.losses.sparse_categorical_crossentropy, + optimizer=opt, + metrics=['accuracy']) + + def on_state_reset(): + tf.keras.backend.set_value(model.optimizer.lr, lr * hvd.size()) + + state = hvd.elastic.KerasState(model, batch=100, epoch=0) + state.register_reset_callbacks([on_state_reset]) + + callbacks = [ + hvd.elastic.CommitStateCallback(state), + hvd.elastic.UpdateBatchStateCallback(state), + hvd.elastic.UpdateEpochStateCallback(state), + ] + + if hvd.rank() == 0: + callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5')) + + @hvd.elastic.run + def train(state): + model.fit(dataset, + steps_per_epoch=500 // hvd.size(), + callbacks=callbacks, + epochs=epochs - state.epoch, + verbose=1 if hvd.rank() == 0 else 0) + + train(state) + + +Elastic PyTorch +~~~~~~~~~~~~~~~ + +.. code-block:: python + :emphasize-lines: 14,15,28,31,36,37 + + import torch + import horovod.torch as hvd + + hvd.init() + + torch.cuda.set_device(hvd.local_rank()) + + dataset = ... + model = ... 
+
+    optimizer = optim.SGD(model.parameters(), lr * hvd.size())
+    optimizer = hvd.DistributedOptimizer(optimizer)
+
+    @hvd.elastic.run
+    def train(state, train_one_batch):
+        batch_offset = state.batch
+        for state.epoch in range(state.epoch, epochs):
+            for state.batch in range(state.batch, batches_per_epoch):
+                data, target = get_random_batch()
+
+                optimizer.zero_grad()
+                output = model(data)
+                loss = F.nll_loss(output, target)
+                loss.backward()
+                optimizer.step()
+
+                if state.batch % batches_per_commit == 0:
+                    state.commit()
+            state.batch = 0
+
+    def on_state_reset():
+        # adjust learning rate on reset
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr * hvd.size()
+
+    state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0)
+    state.register_reset_callbacks([on_state_reset])
+    train(state)
+
+
+Running with horovodrun
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Elastic training jobs are started using the ``horovodrun`` command line tool. The major difference when launching
+elastic jobs is that hosts are not specified explicitly, but instead **discovered** at runtime. The most general way
+to allow Horovod to discover available hosts is to provide a ``--host-discovery-script`` when launching the job:
+
+.. code-block:: bash
+
+    $ horovodrun -np 8 --host-discovery-script discover_hosts.sh python train.py
+
+The host discovery script must have user executable permissions, and return one host with its available slots per line
+of the form: ``<hostname>:<slots>``. For example:
+
+.. code-block:: bash
+
+    $ ./discover_hosts.sh
+    host-1:4
+    host-2:4
+    host-3:4
+
+A minimal example of such a script is sketched at the end of this section.
+
+If the host discovery script fails to execute (due to a permissions issue) or otherwise returns a non-zero exit code
+the first time it is called, the training process will fail immediately. However, subsequent errors will result in
+retries until the job times out (due to failure to discover a sufficient number of slots).
+
+Your discovery script may omit the ``:<slots>`` if you explicitly specify the number of slots per host as an argument:
+
+.. code-block:: bash
+
+    $ horovodrun -np 8 --host-discovery-script discover_hosts.sh --slots 4 python train.py
+
+The elastic training job will not start until at least ``-np`` slots are available for running worker processes.
+
+You can additionally specify the minimum and maximum number of processes to run with during the job:
+
+.. code-block:: bash
+
+    $ horovodrun -np 8 --min-np 4 --max-np 12 --host-discovery-script discover_hosts.sh python train.py
+
+If the number of available slots falls below ``--min-np`` (due to host failure, preemption, etc.), then the job will
+pause, waiting for more hosts to become available, or until ``HOROVOD_ELASTIC_TIMEOUT`` (default: 600 seconds) has
+elapsed. If unspecified, minimum np defaults to ``-np``.
+
+The maximum np can be used to cap the number of processes (to prevent over-utilizing available resources) and to serve
+as a reference point for learning rate scales and data partitions (in cases where these need to be held constant
+regardless of the current number of workers). If unspecified, maximum np also defaults to ``-np``.
+
+Instances that fail will be added to a blacklist, as they may have faulty hardware. Ranks that fail repeatedly
+will result in job failure, as it may be the case that the training process cannot make progress.
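+
+The following is the minimal discovery script sketch referenced above, written in Python purely for illustration. The
+script name ``discover_hosts.py``, the host names, and the slot counts are placeholders; a real script would typically
+query a cluster manager or cloud provider API for the hosts currently assigned to the job:
+
+.. code-block:: python
+
+    #!/usr/bin/env python
+    # Hypothetical discover_hosts.py: print one "<hostname>:<slots>" line per available host.
+    # The host list is hard-coded here for illustration only.
+    AVAILABLE_HOSTS = {'host-1': 4, 'host-2': 4, 'host-3': 4}
+
+    if __name__ == '__main__':
+        for hostname, slots in AVAILABLE_HOSTS.items():
+            print('{}:{}'.format(hostname, slots))
+
+Remember to make the script executable (for example, ``chmod +x discover_hosts.py``) before passing it to
+``--host-discovery-script``.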
+
+
+Practical Considerations: Consistent training
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Because workers are frequently added and removed during training, parameters that vary with the number of workers
+(learning rates, numbers of partitions, etc.) can hurt model convergence if they are not properly handled.
+
+The learning rate will need to be rescaled via a callback when using gradient averaging. When using Adasum, no
+adjustment is needed, provided that the local size remains the same.
+
+If data is read using random sampling, then no repartitioning needs to occur. For the time being, this is the recommended
+strategy to simplify elastic training configuration.
+
+If using dataset partitioning, callbacks may be used to repartition the dataset as necessary, skipping already-processed
+data. Care needs to be taken when partitioning the data to ensure that data is not processed more than once. As such,
+the preferred approach is to keep the number of partitions constant (from ``hvd.max_size()``), but redistribute
+partitions and use local gradient aggregation to keep the total batch size constant.
+
+.. inclusion-marker-end-do-not-remove
diff --git a/docs/elastic_include.rst b/docs/elastic_include.rst
new file mode 100644
index 0000000000..2c8385b421
--- /dev/null
+++ b/docs/elastic_include.rst
@@ -0,0 +1,3 @@
+.. include:: ./elastic.rst
+   :start-after: inclusion-marker-start-do-not-remove
+   :end-before: inclusion-marker-end-do-not-remove
diff --git a/docs/index.rst b/docs/index.rst
index 889da3737e..fc51c9435c 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -109,6 +109,8 @@ Guides
    running_include

+   elastic_include
+
    benchmarks_include

    inference_include
diff --git a/examples/elastic/pytorch_mnist_elastic.py b/examples/elastic/pytorch_mnist_elastic.py
new file mode 100644
index 0000000000..9115d99179
--- /dev/null
+++ b/examples/elastic/pytorch_mnist_elastic.py
@@ -0,0 +1,203 @@
+from __future__ import print_function
+import argparse
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+import torch.utils.data.distributed
+import horovod.torch as hvd
+import os
+
+# Training settings
+parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                    help='input batch size for training (default: 64)')
+parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                    help='input batch size for testing (default: 1000)')
+parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                    help='number of epochs to train (default: 10)')
+parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                    help='learning rate (default: 0.01)')
+parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
+                    help='SGD momentum (default: 0.5)')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='disables CUDA training')
+parser.add_argument('--seed', type=int, default=42, metavar='S',
+                    help='random seed (default: 42)')
+parser.add_argument('--log-interval', type=int, default=10, metavar='N',
+                    help='how many batches to wait before logging training status')
+parser.add_argument('--fp16-allreduce', action='store_true', default=False,
+                    help='use fp16 compression during allreduce')
+parser.add_argument('--use-adasum', action='store_true', default=False,
+                    help='use adasum algorithm to do reduction')
+
+args = parser.parse_args()
+args.cuda = not args.no_cuda and
torch.cuda.is_available() + +# Horovod: initialize library. +hvd.init() +torch.manual_seed(args.seed) + +if args.cuda: + # Horovod: pin GPU to local rank. + torch.cuda.set_device(hvd.local_rank()) + torch.cuda.manual_seed(args.seed) + + +# Horovod: limit # of CPU threads to be used per worker. +torch.set_num_threads(1) + +kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} +train_dataset = \ + datasets.MNIST('data-%d' % hvd.rank(), train=True, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])) +# Horovod: use DistributedSampler to partition the training data. +train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=hvd.size(), rank=hvd.rank()) +train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs) + +test_dataset = \ + datasets.MNIST('data-%d' % hvd.rank(), train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])) +# Horovod: use DistributedSampler to partition the test data. +test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dataset, num_replicas=hvd.size(), rank=hvd.rank()) +test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, + sampler=test_sampler, **kwargs) + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) + + +model = Net() + +# By default, Adasum doesn't need scaling up learning rate. +lr_scaler = hvd.size() if not args.use_adasum else 1 + +if args.cuda: + # Move model to GPU. + model.cuda() + # If using GPU Adasum allreduce, scale learning rate by local_size. + if args.use_adasum and hvd.nccl_built(): + lr_scaler = hvd.local_size() + +# Horovod: scale learning rate by lr_scaler. +optimizer = optim.SGD(model.parameters(), lr=args.lr * lr_scaler, + momentum=args.momentum) + +# Horovod: (optional) compression algorithm. +compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none + + +def metric_average(val, name): + tensor = torch.tensor(val) + avg_tensor = hvd.allreduce(tensor, name=name) + return avg_tensor.item() + + +def check_rank(epoch): + if epoch == 2 and int(os.environ.get('HOROVOD_RANK')) == 0: + print('exit rank {}'.format(hvd.rank())) + raise RuntimeError('check_rank and exit') + # exit(1) + + +@hvd.elastic.run +def train(state): + # post synchronization event (worker added, worker removed) init ... 
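+    # Note: state.epoch and state.batch are attributes of the hvd.elastic.TorchState
+    # object created below; after a reset event they hold the values restored from the
+    # last commit (or synchronized from the new worker 0), so the loops below resume
+    # where training left off rather than restarting from zero.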
+ for state.epoch in range(state.epoch, args.epochs + 1): + state.model.train() + + train_sampler.set_epoch(state.epoch) + steps_remaining = len(train_loader) - state.batch + + for state.batch, (data, target) in enumerate(train_loader): + if state.batch >= steps_remaining: + break + + check_rank(state.epoch) + if args.cuda: + data, target = data.cuda(), target.cuda() + state.optimizer.zero_grad() + output = state.model(data) + loss = F.nll_loss(output, target) + loss.backward() + state.optimizer.step() + if state.batch % args.log_interval == 0: + # Horovod: use train_sampler to determine the number of examples in + # this worker's partition. + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + state.epoch, state.batch * len(data), len(train_sampler), + 100.0 * state.batch / len(train_loader), loss.item())) + state.commit() + state.batch = 0 + + +def test(): + model.eval() + test_loss = 0. + test_accuracy = 0. + for data, target in test_loader: + if args.cuda: + data, target = data.cuda(), target.cuda() + output = model(data) + # sum up batch loss + test_loss += F.nll_loss(output, target, size_average=False).item() + # get the index of the max log-probability + pred = output.data.max(1, keepdim=True)[1] + test_accuracy += pred.eq(target.data.view_as(pred)).cpu().float().sum() + + # Horovod: use test_sampler to determine the number of examples in + # this worker's partition. + test_loss /= len(test_sampler) + test_accuracy /= len(test_sampler) + + # Horovod: average metric values across workers. + test_loss = metric_average(test_loss, 'avg_loss') + test_accuracy = metric_average(test_accuracy, 'avg_accuracy') + + # Horovod: print output only on first rank. + if hvd.rank() == 0: + print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format( + test_loss, 100. * test_accuracy)) + + +# Horovod: wrap optimizer with DistributedOptimizer. 
+optimizer = hvd.DistributedOptimizer(optimizer, + named_parameters=model.named_parameters(), + compression=compression, + op=hvd.Adasum if args.use_adasum else hvd.Average) + + +# adjust learning rate on reset +def on_state_reset(): + for param_group in optimizer.param_groups: + param_group['lr'] = args.lr * hvd.size() + + +state = hvd.elastic.TorchState(model, optimizer, epoch=1, batch=0) +state.register_reset_callbacks([on_state_reset]) +train(state) +test() diff --git a/examples/elastic/pytorch_synthetic_benchmark_elastic.py b/examples/elastic/pytorch_synthetic_benchmark_elastic.py new file mode 100644 index 0000000000..e1ca8e99ee --- /dev/null +++ b/examples/elastic/pytorch_synthetic_benchmark_elastic.py @@ -0,0 +1,150 @@ +from __future__ import print_function + +import argparse +import torch.backends.cudnn as cudnn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data.distributed +from torchvision import models +import horovod.torch as hvd +import timeit +import numpy as np + +# Benchmark settings +parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--fp16-allreduce', action='store_true', default=False, + help='use fp16 compression during allreduce') + +parser.add_argument('--model', type=str, default='resnet50', + help='model to benchmark') +parser.add_argument('--batch-size', type=int, default=32, + help='input batch size') + +parser.add_argument('--num-warmup-batches', type=int, default=10, + help='number of warm-up batches that don\'t count towards benchmark') +parser.add_argument('--num-batches-per-iter', type=int, default=10, + help='number of batches per benchmark iteration') +parser.add_argument('--num-iters', type=int, default=10, + help='number of benchmark iterations') +parser.add_argument('--num-batches-per-commit', type=int, default=1, + help='number of batches per commit of the elastic state object') + +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + +parser.add_argument('--use-adasum', action='store_true', default=False, + help='use adasum algorithm to do reduction') + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +hvd.init() + +if args.cuda: + # Horovod: pin GPU to local rank. + torch.cuda.set_device(hvd.local_rank()) + +cudnn.benchmark = True + +# Set up standard model. +model = getattr(models, args.model)() + + +# By default, Adasum doesn't need scaling up learning rate. +def lr_scaler(): + return hvd.size() if not args.use_adasum else 1 + + +if args.cuda: + # Move model to GPU. + model.cuda() + # If using GPU Adasum allreduce, scale learning rate by local_size. + if args.use_adasum and hvd.nccl_built(): + lr_scaler = hvd.local_size() + +lr = 0.01 +optimizer = optim.SGD(model.parameters(), lr=lr * lr_scaler()) + +# Horovod: (optional) compression algorithm. +compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none + +# Horovod: wrap optimizer with DistributedOptimizer. +optimizer = hvd.DistributedOptimizer(optimizer, + named_parameters=model.named_parameters(), + compression=compression, + op=hvd.Adasum if args.use_adasum else hvd.Average) + +# Horovod: broadcast parameters & optimizer state. 
+hvd.broadcast_parameters(model.state_dict(), root_rank=0) +hvd.broadcast_optimizer_state(optimizer, root_rank=0) + +# Set up fixed fake data +data = torch.randn(args.batch_size, 3, 224, 224) +target = torch.LongTensor(args.batch_size).random_() % 1000 +if args.cuda: + data, target = data.cuda(), target.cuda() + + +def benchmark_step(state): + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + optimizer.step() + + state.batch += 1 + if state.batch == args.num_batches_per_commit: + state.batch = 0 + state.commit() + + +def log(s, nl=True): + if hvd.rank() != 0: + return + print(s, end='\n' if nl else '') + + +log('Model: %s' % args.model) +log('Batch size: %d' % args.batch_size) +device = 'GPU' if args.cuda else 'CPU' +log('Number of %ss: %d' % (device, hvd.size())) + + +@hvd.elastic.run +def run_benchmark(state): + # Warm-up + if not state.warm: + log('Running warmup...') + timeit.timeit(lambda: benchmark_step(state), number=args.num_warmup_batches) + state.warm = True + state.commit() + + # Benchmark + if state.iter == 0: + log('Running benchmark...') + for x in range(state.iter, args.num_iters): + time = timeit.timeit(lambda: benchmark_step(state), number=args.num_batches_per_iter) + img_sec = args.batch_size * args.num_batches_per_iter / time + log('Iter #%d: %.1f img/sec per %s' % (x, img_sec, device)) + state.img_secs.append(img_sec) + state.iter = x + state.commit() + + +# adjust learning rate on reset +def on_state_reset(): + for param_group in optimizer.param_groups: + param_group['lr'] = lr * lr_scaler() + + +state = hvd.elastic.TorchState(model, optimizer, img_secs=[], iter=0, batch=0, warm=False) +state.register_reset_callbacks([on_state_reset]) +run_benchmark(state) + +# Results +img_sec_mean = np.mean(state.img_secs) +img_sec_conf = 1.96 * np.std(state.img_secs) +log('Img/sec per %s: %.1f +-%.1f' % (device, img_sec_mean, img_sec_conf)) +log('Total img/sec on %d %s(s): %.1f +-%.1f' % + (hvd.size(), device, hvd.size() * img_sec_mean, hvd.size() * img_sec_conf)) diff --git a/examples/elastic/tensorflow2_mnist_elastic.py b/examples/elastic/tensorflow2_mnist_elastic.py new file mode 100644 index 0000000000..25d56ba9d8 --- /dev/null +++ b/examples/elastic/tensorflow2_mnist_elastic.py @@ -0,0 +1,106 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf +import horovod.tensorflow as hvd + +# Horovod: initialize Horovod. 
+hvd.init() + +# Horovod: pin GPU to be used to process local rank (one GPU per process) +gpus = tf.config.experimental.list_physical_devices('GPU') +for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) +if gpus: + tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU') + +(mnist_images, mnist_labels), _ = \ + tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank()) + +dataset = tf.data.Dataset.from_tensor_slices( + (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32), + tf.cast(mnist_labels, tf.int64)) +) +dataset = dataset.repeat().shuffle(10000).batch(128) + +mnist_model = tf.keras.Sequential([ + tf.keras.layers.Conv2D(32, [3, 3], activation='relu'), + tf.keras.layers.Conv2D(64, [3, 3], activation='relu'), + tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), + tf.keras.layers.Dropout(0.25), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(10, activation='softmax') +]) +loss = tf.losses.SparseCategoricalCrossentropy() + +# Horovod: adjust learning rate based on number of GPUs. +lr = 0.001 +opt = tf.optimizers.Adam(lr * hvd.size()) + + +@tf.function +def training_step(images, labels, allreduce=True): + with tf.GradientTape() as tape: + probs = mnist_model(images, training=True) + loss_value = loss(labels, probs) + + # Horovod: add Horovod Distributed GradientTape. + if allreduce: + tape = hvd.DistributedGradientTape(tape) + + grads = tape.gradient(loss_value, mnist_model.trainable_variables) + opt.apply_gradients(zip(grads, mnist_model.trainable_variables)) + return loss_value + + +# Horovod: initialize model and optimizer state so we can synchronize across workers +for batch_idx, (images, labels) in enumerate(dataset.take(1)): + training_step(images, labels, allreduce=False) + + +@hvd.elastic.run +def train(state): + start_batch = state.batch + + # Horovod: adjust number of steps based on number of GPUs. + for batch_idx, (images, labels) in enumerate(dataset.skip(state.batch).take(10000 // hvd.size())): + state.batch = start_batch + batch_idx + loss_value = training_step(images, labels) + + if state.batch % 10 == 0 and hvd.local_rank() == 0: + print('Step #%d\tLoss: %.6f' % (state.batch, loss_value)) + + # Horovod: commit state at the end of each batch + state.commit() + + +def on_state_reset(): + opt.lr.assign(lr * hvd.size()) + + +state = hvd.elastic.TensorFlowKerasState(mnist_model, opt, batch=0) +state.register_reset_callbacks([on_state_reset]) + +train(state) + +checkpoint_dir = './checkpoints' +checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt) + +# Horovod: save checkpoints only on worker 0 to prevent other workers from +# corrupting it. +if hvd.rank() == 0: + checkpoint.save(checkpoint_dir) diff --git a/examples/elastic/tensorflow2_synthetic_benchmark_elastic.py b/examples/elastic/tensorflow2_synthetic_benchmark_elastic.py new file mode 100644 index 0000000000..d4abc26429 --- /dev/null +++ b/examples/elastic/tensorflow2_synthetic_benchmark_elastic.py @@ -0,0 +1,151 @@ +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import, division, print_function + +import argparse +import os +import numpy as np +import timeit + +import tensorflow as tf +import horovod.tensorflow as hvd +from tensorflow.keras import applications + +# Benchmark settings +parser = argparse.ArgumentParser(description='TensorFlow Synthetic Benchmark', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--fp16-allreduce', action='store_true', default=False, + help='use fp16 compression during allreduce') + +parser.add_argument('--model', type=str, default='ResNet50', + help='model to benchmark') +parser.add_argument('--batch-size', type=int, default=32, + help='input batch size') + +parser.add_argument('--num-warmup-batches', type=int, default=10, + help='number of warm-up batches that don\'t count towards benchmark') +parser.add_argument('--num-batches-per-iter', type=int, default=10, + help='number of batches per benchmark iteration') +parser.add_argument('--num-iters', type=int, default=10, + help='number of benchmark iterations') +parser.add_argument('--num-batches-per-commit', type=int, default=1, + help='number of batches per commit of the elastic state object') + +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + +args = parser.parse_args() +args.cuda = not args.no_cuda + +# Horovod: initialize Horovod. +hvd.init() + +# Horovod: pin GPU to be used to process local rank (one GPU per process) +if args.cuda: + gpus = tf.config.experimental.list_physical_devices('GPU') + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + if gpus: + tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU') +else: + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +# Set up standard model. +lr = 0.01 +model = getattr(applications, args.model)(weights=None) +opt = tf.optimizers.SGD(lr * hvd.size()) + +data = tf.random.uniform([args.batch_size, 224, 224, 3]) +target = tf.random.uniform([args.batch_size, 1], minval=0, maxval=999, dtype=tf.int64) + + +@tf.function +def train_one_batch(): + # Horovod: (optional) compression algorithm. + compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none + + # Horovod: use DistributedGradientTape + with tf.GradientTape() as tape: + probs = model(data, training=True) + loss = tf.losses.categorical_crossentropy(target, probs) + + # Horovod: add Horovod Distributed GradientTape. 
+ tape = hvd.DistributedGradientTape(tape, compression=compression) + + gradients = tape.gradient(loss, model.trainable_variables) + opt.apply_gradients(zip(gradients, model.trainable_variables)) + + +def benchmark_step(state): + train_one_batch() + if state is not None: + state.batch += 1 + if state.batch == args.num_batches_per_commit: + state.batch = 0 + state.commit() + + +def log(s, nl=True): + if hvd.rank() != 0: + return + print(s, end='\n' if nl else '') + + +log('Model: %s' % args.model) +log('Batch size: %d' % args.batch_size) +device = 'GPU' if args.cuda else 'CPU' +log('Number of %ss: %d' % (device, hvd.size())) + + +# Run one batch to initialize weights before synchronization +train_one_batch() + + +@hvd.elastic.run +def run_benchmark(state): + with tf.device(device): + # Warm-up + if not state.warm: + log('Running warmup...') + timeit.timeit(lambda: benchmark_step(state), number=args.num_warmup_batches) + state.warm = True + state.commit() + + # Benchmark + if state.iter == 0: + log('Running benchmark...') + for x in range(state.iter, args.num_iters): + time = timeit.timeit(lambda: benchmark_step(state), number=args.num_batches_per_iter) + img_sec = args.batch_size * args.num_batches_per_iter / time + log('Iter #%d: %.1f img/sec per %s' % (x, img_sec, device)) + state.img_secs.append(img_sec) + state.iter = x + state.commit() + + +def on_state_reset(): + tf.keras.backend.set_value(model.optimizer.lr, lr * hvd.size()) + + +state = hvd.elastic.TensorFlowKerasState(model, opt, img_secs=[], iter=0, batch=0, warm=False) +state.register_reset_callbacks([on_state_reset]) +run_benchmark(state) + +# Results +img_sec_mean = np.mean(state.img_secs) +img_sec_conf = 1.96 * np.std(state.img_secs) +log('Img/sec per %s: %.1f +-%.1f' % (device, img_sec_mean, img_sec_conf)) +log('Total img/sec on %d %s(s): %.1f +-%.1f' % + (hvd.size(), device, hvd.size() * img_sec_mean, hvd.size() * img_sec_conf)) diff --git a/examples/elastic/tensorflow_keras_mnist_elastic.py b/examples/elastic/tensorflow_keras_mnist_elastic.py new file mode 100644 index 0000000000..014547c5e1 --- /dev/null +++ b/examples/elastic/tensorflow_keras_mnist_elastic.py @@ -0,0 +1,87 @@ +from __future__ import print_function + +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import Dense, Dropout, Flatten +from tensorflow.keras.layers import Conv2D, MaxPooling2D +from tensorflow.keras import backend as K + +import horovod.tensorflow.keras as hvd + +# Horovod: initialize Horovod. 
+hvd.init() + +# Horovod: pin GPU to be used to process local rank (one GPU per process) +config = tf.ConfigProto() +config.gpu_options.allow_growth = True +config.gpu_options.visible_device_list = str(hvd.local_rank()) +K.set_session(tf.Session(config=config)) + +lr = 1.0 +batch_size = 128 +epochs = 24 +num_classes = 10 + +(mnist_images, mnist_labels), _ = \ + tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank()) + +dataset = tf.data.Dataset.from_tensor_slices( + (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32), + tf.cast(mnist_labels, tf.int64)) +) +dataset = dataset.repeat().shuffle(10000).batch(batch_size) + +model = Sequential() +model.add(Conv2D(32, kernel_size=(3, 3), + activation='relu', + input_shape=(28, 28, 1))) +model.add(Conv2D(64, (3, 3), activation='relu')) +model.add(MaxPooling2D(pool_size=(2, 2))) +model.add(Dropout(0.25)) +model.add(Flatten()) +model.add(Dense(128, activation='relu')) +model.add(Dropout(0.5)) +model.add(Dense(num_classes, activation='softmax')) + +# Horovod: adjust learning rate based on number of GPUs. +opt = keras.optimizers.Adadelta(lr * hvd.size()) + +# Horovod: add Horovod Distributed Optimizer. +opt = hvd.DistributedOptimizer(opt) + +model.compile(loss=keras.losses.sparse_categorical_crossentropy, + optimizer=opt, + metrics=['accuracy']) + + +def on_state_reset(): + tf.keras.backend.set_value(model.optimizer.lr, lr * hvd.size()) + + +state = hvd.elastic.KerasState(model, batch=100, epoch=0) +state.register_reset_callbacks([on_state_reset]) + +callbacks = [ + # Horovod: elastic training callbacks to update and commit state. + hvd.elastic.CommitStateCallback(state), + hvd.elastic.UpdateBatchStateCallback(state), + hvd.elastic.UpdateEpochStateCallback(state), +] + +# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them. +if hvd.rank() == 0: + callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5')) + + +@hvd.elastic.run +def train(state): + # Horovod: adjust number of steps based on number of GPUs. + state.model.fit(dataset, + steps_per_epoch=500 // hvd.size(), + callbacks=callbacks, + epochs=epochs - state.epoch, + verbose=1 if hvd.rank() == 0 else 0) + + +train(state) diff --git a/horovod/_keras/elastic.py b/horovod/_keras/elastic.py new file mode 100644 index 0000000000..7d5f3f5b21 --- /dev/null +++ b/horovod/_keras/elastic.py @@ -0,0 +1,59 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +class CommitStateCallbackImpl(object): + def __init__(self, backend, state, batches_per_commit, *args): + super(CommitStateCallbackImpl, self).__init__(*args) + self.backend = backend + self.state = state + self.batches_per_commit = batches_per_commit + self.batches_remaining = batches_per_commit + + def on_batch_end(self, batch, logs=None): + self.batches_remaining -= 1 + if self.batches_remaining == 0: + self.state.commit() + self.batches_remaining = self.batches_per_commit + + +class UpdateBatchStateCallbackImpl(object): + def __init__(self, backend, state, *args): + super(UpdateBatchStateCallbackImpl, self).__init__(*args) + self.backend = backend + self.state = state + self.steps_per_epoch = None + + def on_epoch_begin(self, epoch, logs=None): + if self.params.get('steps'): + if self.steps_per_epoch is None: + self.steps_per_epoch = self.params.get('steps') + self.params['steps'] = self.steps_per_epoch - self.state.batch + + def on_batch_end(self, batch, logs=None): + self.state.batch = batch + + def on_epoch_end(self, epoch, logs=None): + self.state.batch = 0 + + +class UpdateEpochStateCallbackImpl(object): + def __init__(self, backend, state, *args): + super(UpdateEpochStateCallbackImpl, self).__init__(*args) + self.backend = backend + self.state = state + + def on_epoch_end(self, epoch, logs=None): + self.state.epoch = epoch diff --git a/horovod/common/controller.cc b/horovod/common/controller.cc index 5a876a163e..7c44a86b15 100644 --- a/horovod/common/controller.cc +++ b/horovod/common/controller.cc @@ -163,7 +163,7 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down, } if (!message_queue_tmp.empty()) { - LOG(DEBUG, rank_) << "Sent " << message_queue_tmp.size() + LOG(TRACE, rank_) << "Sent " << message_queue_tmp.size() << " messages to coordinator."; } @@ -755,7 +755,7 @@ ResponseList Controller::FuseResponses(std::deque& responses) { } response_list.add_response(std::move(response)); - LOG(DEBUG) << "Created response of size " << tensor_size; + LOG(TRACE) << "Created response of size " << tensor_size; } return response_list; } diff --git a/horovod/common/elastic.py b/horovod/common/elastic.py new file mode 100644 index 0000000000..c253824bf5 --- /dev/null +++ b/horovod/common/elastic.py @@ -0,0 +1,171 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +import functools + +from six.moves import queue + +from horovod.common.exceptions import HorovodInternalError, HostsUpdatedInterrupt +from horovod.run.elastic.worker import WorkerNotificationManager + + +notification_manager = WorkerNotificationManager() + + +class State(object): + """State representation used for tracking in memory state across workers. 
+ + Args: + bcast_object: Function used to broadcast a variable from rank 0 to the other workers. + get_rank: Function that returns the current rank of this worker. + """ + def __init__(self, bcast_object, get_rank): + self._bcast_object = bcast_object + self._rank = get_rank + self._host_messages = queue.Queue() + self._last_updated_timestamp = 0 + self._reset_callbacks = [] + + def register_reset_callbacks(self, callbacks): + """Register callbacks that will be invoked following a reset event (worker added or removed). + + For example, a common use of a reset callback would be to update the learning rate scale with the + new number of workers. + + Args: + callbacks: list of functions to execute. + """ + self._reset_callbacks.extend(callbacks) + + def on_reset(self): + self._host_messages = queue.Queue() + self.reset() + for callback in self._reset_callbacks: + callback() + + def on_hosts_updated(self, timestamp): + self._host_messages.put(timestamp) + + def commit(self): + """Commits all modifications to state tracked by this object to host memory. + + This call will also check for any changes to known hosts, and raise a `HostsUpdatedInterrupt` + if any were detected. + + Because commits are a heavy operation involving data copy (potentially from GPU to host), it is + recommended to consider committing less frequently than once per batch. This allows users to tradeoff + between per-batch execution time and lost training steps in the event of a worker failure. + """ + self.save() + self.check_host_updates() + + def check_host_updates(self): + """Checks that a notification has been sent indicating that hosts can be added or will be removed. + + Raises a `HostsUpdatedInterrupt` if such a notification has been received. + """ + # Iterate through the update messages sent from the server. If the update timestamp + # is greater than the last update timestamp, then trigger a HostsUpdatedException. + last_updated_timestamp = prev_timestamp = self._last_updated_timestamp + while not self._host_messages.empty(): + timestamp = self._host_messages.get() + if timestamp > last_updated_timestamp: + last_updated_timestamp = timestamp + + # In order to ensure all workers raise the exception at the same time, we need to sync + # the updated state across all the workers. + # TODO(travis): this should be a max allreduce to account for changes in rank 0 + prev_timestamp, self._last_updated_timestamp = self._bcast_object((prev_timestamp, last_updated_timestamp)) + + # At this point, updated state is globally consistent across all ranks. + if self._last_updated_timestamp > prev_timestamp: + raise HostsUpdatedInterrupt() + + def save(self): + """Saves state to host memory.""" + raise NotImplementedError() + + def restore(self): + """Restores the last committed state, undoing any uncommitted modifications.""" + raise NotImplementedError() + + def sync(self): + """Synchronize state across workers.""" + raise NotImplementedError() + + def reset(self): + """Reset objects and variables following a reset event (before synchronization).""" + pass + + +class ObjectState(State): + """State for simple Python objects. + + Every object is specified as a keyword argument, and will be assigned as an attribute. + + Args: + bcast_object: Horovod broadcast object function used to sync state dictionary. + get_rank: Horovod rank function used to identify is this process is the coordinator. + kwargs: Properties to sync, will be exposed as attributes of the object. 
+ """ + def __init__(self, bcast_object, get_rank, **kwargs): + self._bcast_object = bcast_object + self._saved_state = kwargs + self._set_attrs() + super(ObjectState, self).__init__(bcast_object=bcast_object, get_rank=get_rank) + + def save(self): + new_state = {} + for attr in self._saved_state.keys(): + new_state[attr] = getattr(self, attr) + self._saved_state = new_state + + def restore(self): + self._set_attrs() + + def sync(self): + if self._saved_state: + self._saved_state = self._bcast_object(self._saved_state) + self._set_attrs() + + def _set_attrs(self): + for attr, value in self._saved_state.items(): + setattr(self, attr, value) + + +def run_fn(func, reset): + @functools.wraps(func) + def wrapper(state, *args, **kwargs): + notification_manager.init() + notification_manager.register_listener(state) + + try: + while True: + state.sync() + + try: + return func(state, *args, **kwargs) + except HorovodInternalError: + state.restore() + except HostsUpdatedInterrupt: + pass + + reset() + state.on_reset() + finally: + notification_manager.remove_listener(state) + return wrapper diff --git a/horovod/common/exceptions.py b/horovod/common/exceptions.py new file mode 100644 index 0000000000..83500fc5f8 --- /dev/null +++ b/horovod/common/exceptions.py @@ -0,0 +1,33 @@ +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. +# Modifications copyright Microsoft +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + + +class HorovodInternalError(RuntimeError): + """Internal error raised when a Horovod collective operation (e.g., allreduce) fails. + + This is handled in elastic mode as a recoverable error, and will result in a reset event. + """ + pass + + +class HostsUpdatedInterrupt(RuntimeError): + """Internal interrupt event indicating that the set of hosts in the job has changed. + + In elastic mode, this will result in a reset event without a restore to committed state. 
+ """ + pass diff --git a/horovod/common/gloo/gloo_context.cc b/horovod/common/gloo/gloo_context.cc index f5e34c00bc..db09d270f7 100644 --- a/horovod/common/gloo/gloo_context.cc +++ b/horovod/common/gloo/gloo_context.cc @@ -17,6 +17,8 @@ #include #include +#include +#include #include "gloo/rendezvous/context.h" #include "gloo/rendezvous/file_store.h" @@ -41,12 +43,24 @@ namespace common { #define HOROVOD_GLOO_GLOBAL_PREFIX "global_" #define HOROVOD_GLOO_LOCAL_PREFIX "local_" #define HOROVOD_GLOO_CROSS_PREFIX "cross_" +#define HOROVOD_GLOO_GET_RANK_AND_SIZE "rank_and_size" +#define HOROVOD_HOSTNAME "HOROVOD_HOSTNAME" #define HOROVOD_RANK "HOROVOD_RANK" #define HOROVOD_SIZE "HOROVOD_SIZE" #define HOROVOD_LOCAL_RANK "HOROVOD_LOCAL_RANK" #define HOROVOD_LOCAL_SIZE "HOROVOD_LOCAL_SIZE" #define HOROVOD_CROSS_RANK "HOROVOD_CROSS_RANK" #define HOROVOD_CROSS_SIZE "HOROVOD_CROSS_SIZE" +#define HOROVOD_ELASTIC "HOROVOD_ELASTIC" + +int ParseNextInt(std::stringstream& ss) { + assert(ss.good()); + + std::string substr; + getline(ss, substr, ','); + + return (int) std::strtol(substr.c_str(), nullptr, 10); +} std::chrono::milliseconds GetTimeoutFromEnv() { auto s = std::chrono::seconds(GetIntEnvOrDefault(HOROVOD_GLOO_TIMEOUT_SECONDS, 30)); @@ -140,6 +154,55 @@ void GlooContext::Initialize(const std::string& gloo_iface) { LOG(DEBUG) << "no rendezvous server provided, assuming single process execution"; } + bool elastic = GetBoolEnvOrDefault(HOROVOD_ELASTIC, false); + if (elastic && reset_) { + LOG(DEBUG) << "elastic mode reinitialization started, reset rank=" << rank << " size=" << size; + std::string hostname = std::getenv(HOROVOD_HOSTNAME); + std::string server_addr = rendezvous_addr_env; + std::string scope = HOROVOD_GLOO_GET_RANK_AND_SIZE; + HTTPStore init_store(server_addr, rendezvous_port, scope, rank); + + auto key = hostname + ":" + std::to_string(local_rank); + std::vector result = init_store.get(key); + std::string s(result.begin(), result.end()); + std::stringstream ss(s); + + int last_rank = rank; + int last_size = size; + int last_local_rank = local_rank; + int last_local_size = local_size; + int last_cross_rank = cross_rank; + int last_cross_size = cross_size; + + rank = ParseNextInt(ss); + if (rank == -1) { + // Signals that this host is not part of the job + std::ostringstream out; + out << hostname << "[" << local_rank << "] has been removed from elastic job"; + throw std::runtime_error(out.str()); + } + + size = ParseNextInt(ss); + local_rank = ParseNextInt(ss); + local_size = ParseNextInt(ss); + cross_rank = ParseNextInt(ss); + cross_size = ParseNextInt(ss); + + SetEnv(HOROVOD_RANK, std::to_string(rank).c_str()); + SetEnv(HOROVOD_SIZE, std::to_string(size).c_str()); + SetEnv(HOROVOD_LOCAL_RANK, std::to_string(local_rank).c_str()); + SetEnv(HOROVOD_LOCAL_SIZE, std::to_string(local_size).c_str()); + SetEnv(HOROVOD_CROSS_RANK, std::to_string(cross_rank).c_str()); + SetEnv(HOROVOD_CROSS_SIZE, std::to_string(cross_size).c_str()); + LOG(DEBUG) << "elastic mode reinitialization complete, updated" << + " rank: " << last_rank << " -> " << rank << + " size: " << last_size << " -> " << size << + " local_rank: " << last_local_rank << " -> " << local_rank << + " local_size: " << last_local_size << " -> " << local_size << + " cross_rank: " << last_cross_rank << " -> " << cross_rank << + " cross_size: " << last_cross_size << " -> " << cross_size; + } + ctx = Rendezvous(HOROVOD_GLOO_GLOBAL_PREFIX, rendezvous_addr_env, rendezvous_port, rank, size, dev, timeout); @@ -164,6 +227,7 @@ void 
GlooContext::Finalize() { ctx.reset(); cross_ctx.reset(); local_ctx.reset(); + reset_ = true; } std::shared_ptr diff --git a/horovod/common/gloo/gloo_context.h b/horovod/common/gloo/gloo_context.h index 11708bd200..1694981b17 100644 --- a/horovod/common/gloo/gloo_context.h +++ b/horovod/common/gloo/gloo_context.h @@ -47,7 +47,6 @@ struct GlooContext { bool IsEnabled() { return enabled_; } - std::shared_ptr ctx = nullptr; // Global context std::shared_ptr cross_ctx = nullptr; std::shared_ptr local_ctx = nullptr; @@ -55,6 +54,7 @@ struct GlooContext { private: // Flag indicating whether gloo is enabled. bool enabled_ = false; + bool reset_ = false; }; } // namespace common diff --git a/horovod/common/gloo/gloo_controller.cc b/horovod/common/gloo/gloo_controller.cc index 9b431ec820..00622cacb5 100644 --- a/horovod/common/gloo/gloo_controller.cc +++ b/horovod/common/gloo/gloo_controller.cc @@ -125,7 +125,7 @@ void GlooController::RecvReadyTensors(std::vector& ready_to_reduce, // ranks at this tick. // 1. Get message lengths from every rank. - auto recvcounts = new int[size_]; + std::unique_ptr recvcounts(new int[size_]); // do allgather { @@ -133,12 +133,12 @@ void GlooController::RecvReadyTensors(std::vector& ready_to_reduce, int send_data = 0; gloo::AllgatherOptions opts(gloo_context_.ctx); opts.setInput(&send_data, 1); - opts.setOutput(recvcounts, size_); + opts.setOutput(recvcounts.get(), size_); gloo::allgather(opts); } // 2. Compute displacements. - auto displcmnts = new int[size_]; + std::unique_ptr displcmnts(new int[size_]); size_t total_size = 0; for (int i = 0; i < size_; ++i) { if (i == 0) { @@ -150,15 +150,15 @@ void GlooController::RecvReadyTensors(std::vector& ready_to_reduce, } // 3. Collect messages from every rank. - auto buffer = new uint8_t[total_size]; + std::unique_ptr buffer(new uint8_t[total_size]); // do allgatherv { auto input = new uint8_t[0]; gloo::AllgathervOptions opts(gloo_context_.ctx); opts.setInput(input, 0); - std::vector count_vec(recvcounts, recvcounts + size_); - opts.setOutput(buffer, count_vec); + std::vector count_vec(recvcounts.get(), recvcounts.get() + size_); + opts.setOutput(buffer.get(), count_vec); gloo::allgatherv(opts); } @@ -166,16 +166,11 @@ void GlooController::RecvReadyTensors(std::vector& ready_to_reduce, // create a dummy list for rank 0 ready_list.emplace_back(); for (int i = 1; i < size_; ++i) { - auto rank_buffer_ptr = buffer + displcmnts[i]; + auto rank_buffer_ptr = buffer.get() + displcmnts[i]; RequestList received_message_list; RequestList::ParseFromBytes(received_message_list, rank_buffer_ptr); ready_list.push_back(std::move(received_message_list)); } - - // 5. Free buffers. - delete[] recvcounts; - delete[] displcmnts; - delete[] buffer; } void GlooController::SendFinalTensors(ResponseList& response_list) { @@ -209,16 +204,16 @@ void GlooController::SendReadyTensors(RequestList& message_list) { // Gloo doesn't have the gatherv options, using allgatherv instead. 
// send message length to root - auto recvcounts = new int[size_]; + std::unique_ptr recvcounts(new int[size_]); int encoded_message_length = (int)encoded_message.length() + 1; { gloo::AllgatherOptions opts(gloo_context_.ctx); opts.setInput(&encoded_message_length, 1); - opts.setOutput(recvcounts, size_); + opts.setOutput(recvcounts.get(), size_); gloo::allgather(opts); } - auto displcmnts = new int[size_]; + std::unique_ptr displcmnts(new int[size_]); size_t total_size = 0; for (int i = 0; i < size_; ++i) { if (i == 0) { @@ -230,19 +225,15 @@ void GlooController::SendReadyTensors(RequestList& message_list) { } // 3. Collect messages from every rank. - auto buffer = new uint8_t[total_size]; + std::unique_ptr buffer(new uint8_t[total_size]); // send message body to root { gloo::AllgathervOptions opts(gloo_context_.ctx); opts.setInput((uint8_t*)encoded_message.c_str(), encoded_message_length); - std::vector count_vec(recvcounts, recvcounts + size_); - opts.setOutput((uint8_t*)buffer, count_vec); + std::vector count_vec(recvcounts.get(), recvcounts.get() + size_); + opts.setOutput((uint8_t*)buffer.get(), count_vec); gloo::allgatherv(opts); } - - delete[] recvcounts; - delete[] displcmnts; - delete[] buffer; } void GlooController::RecvFinalTensors(ResponseList& response_list) { @@ -255,17 +246,16 @@ void GlooController::RecvFinalTensors(ResponseList& response_list) { gloo::broadcast(opts); } // root broadcast final message to others - auto buffer = new uint8_t[msg_length]; - memset(buffer, 0, msg_length); + std::unique_ptr buffer(new uint8_t[msg_length]); + memset(buffer.get(), 0, msg_length); { gloo::BroadcastOptions opts(gloo_context_.ctx); - opts.setOutput((uint8_t*)buffer, msg_length); + opts.setOutput((uint8_t*)buffer.get(), msg_length); opts.setRoot(RANK_ZERO); gloo::broadcast(opts); } - ResponseList::ParseFromBytes(response_list, buffer); - delete[] buffer; + ResponseList::ParseFromBytes(response_list, buffer.get()); } void GlooController::Bcast(void* buffer, size_t size, int root_rank, diff --git a/horovod/common/operations.cc b/horovod/common/operations.cc index 3d0335b3ea..45ea8fa98c 100644 --- a/horovod/common/operations.cc +++ b/horovod/common/operations.cc @@ -373,7 +373,6 @@ void BackgroundThreadLoop(HorovodGlobalState& state) { gloo_context.Initialize(ParseGlooIface()); } #endif - // Initialize controller state.controller->Initialize(); @@ -506,8 +505,11 @@ void BackgroundThreadLoop(HorovodGlobalState& state) { LOG(INFO, horovod_global.controller->GetRank()) << "Horovod Initialized"; // Iterate until shutdown. 
- while (RunLoopOnce(state)) - ; + try { + while (RunLoopOnce(state)); + } catch (const std::exception& ex) { + LOG(ERROR) << "Horovod background loop uncaught exception: " << ex.what(); + } // Finalize all contexts #if HAVE_NCCL @@ -580,7 +582,7 @@ bool RunLoopOnce(HorovodGlobalState& state) { int rank = state.controller->GetRank(); for (auto& response : response_list.responses()) { LOG(TRACE, rank) << "Performing " << response.tensor_names_string(); - LOG(DEBUG, rank) << "Processing " << response.tensor_names().size() + LOG(TRACE, rank) << "Processing " << response.tensor_names().size() << " tensors"; PerformOperation(response, horovod_global); LOG(TRACE, rank) << "Finished performing " diff --git a/horovod/common/ops/adasum_gpu_operations.cc b/horovod/common/ops/adasum_gpu_operations.cc index cc403a6c28..f771b984c0 100644 --- a/horovod/common/ops/adasum_gpu_operations.cc +++ b/horovod/common/ops/adasum_gpu_operations.cc @@ -159,7 +159,7 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector& entries, (size_t)num_elements_per_rank, GetNCCLDataType(first_entry.tensor), ncclSum, *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream); - nccl_context_->ErrorCheck("ncclReduceScatter", nccl_result); + nccl_context_->ErrorCheck("ncclReduceScatter", nccl_result, *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_REDUCESCATTER, *gpu_op_context_.stream); @@ -174,7 +174,7 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector& entries, (size_t)num_elements_remaining, GetNCCLDataType(first_entry.tensor), ncclSum, root_rank, *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream); - nccl_context_->ErrorCheck("ncclReduce", nccl_result); + nccl_context_->ErrorCheck("ncclReduce", nccl_result, *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_REDUCE, *gpu_op_context_.stream); @@ -267,7 +267,8 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector& entries, "ncclAllGather", ncclAllGather(buffer_data_at_rank_offset, buffer_data, (size_t)num_elements_per_rank, GetNCCLDataType(first_entry.tensor), - *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream)); + *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream), + *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_ALLGATHER, *gpu_op_context_.stream); @@ -278,7 +279,8 @@ AdasumGpuAllreduceOp::NcclHierarchical(std::vector& entries, "ncclBcast", ncclBcast(buffer_data_remainder, (size_t)num_elements_remaining, GetNCCLDataType(first_entry.tensor), root_rank, *nccl_op_context_.nccl_comm_, - *gpu_op_context_.stream)); + *gpu_op_context_.stream), + *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_BCAST, *gpu_op_context_.stream); diff --git a/horovod/common/ops/nccl_operations.cc b/horovod/common/ops/nccl_operations.cc index 29a0aa8fe0..787f28d788 100644 --- a/horovod/common/ops/nccl_operations.cc +++ b/horovod/common/ops/nccl_operations.cc @@ -41,8 +41,9 @@ ncclDataType_t GetNCCLDataType(const std::shared_ptr tensor) { } } -void NCCLContext::ErrorCheck(std::string op_name, ncclResult_t nccl_result) { +void NCCLContext::ErrorCheck(std::string op_name, ncclResult_t nccl_result, ncclComm_t& nccl_comm) { if (nccl_result != ncclSuccess) { + ncclCommAbort(nccl_comm); throw std::logic_error(std::string(op_name) + " failed: " + 
ncclGetErrorString(nccl_result)); } } @@ -70,7 +71,7 @@ void NCCLOpContext::InitNCCLComm(const std::vector& entries, ncclUniqueId nccl_id; if (nccl_rank == 0) { - nccl_context_->ErrorCheck("ncclGetUniqueId", ncclGetUniqueId(&nccl_id)); + nccl_context_->ErrorCheck("ncclGetUniqueId", ncclGetUniqueId(&nccl_id), nccl_comm); } global_state_->controller->Bcast((void*)&nccl_id, sizeof(nccl_id), 0, @@ -78,7 +79,7 @@ void NCCLOpContext::InitNCCLComm(const std::vector& entries, ncclComm_t new_nccl_comm; auto nccl_result = ncclCommInitRank(&new_nccl_comm, nccl_size, nccl_id, nccl_rank); - nccl_context_->ErrorCheck("ncclCommInitRank", nccl_result); + nccl_context_->ErrorCheck("ncclCommInitRank", nccl_result, nccl_comm); nccl_comm = new_nccl_comm; // Barrier helps NCCL to synchronize after initialization and avoid @@ -141,7 +142,7 @@ Status NCCLAllreduce::Execute(std::vector& entries, (size_t) num_elements, GetNCCLDataType(first_entry.tensor), ncclSum, *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream); - nccl_context_->ErrorCheck("ncclAllReduce", nccl_result); + nccl_context_->ErrorCheck("ncclAllReduce", nccl_result, *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_ALLREDUCE, *gpu_op_context_.stream); } @@ -266,7 +267,7 @@ NCCLHierarchicalAllreduce::Execute(std::vector& entries, (size_t) num_elements_per_rank, GetNCCLDataType(first_entry.tensor), ncclSum, *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream); - nccl_context_->ErrorCheck("ncclReduceScatter", nccl_result); + nccl_context_->ErrorCheck("ncclReduceScatter", nccl_result, *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_REDUCESCATTER, *gpu_op_context_.stream); } @@ -280,7 +281,7 @@ NCCLHierarchicalAllreduce::Execute(std::vector& entries, (size_t) num_elements_remaining, GetNCCLDataType(first_entry.tensor), ncclSum, root_rank, *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream); - nccl_context_->ErrorCheck("ncclReduce", nccl_result); + nccl_context_->ErrorCheck("ncclReduce", nccl_result, *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_REDUCE, *gpu_op_context_.stream); } @@ -325,7 +326,8 @@ NCCLHierarchicalAllreduce::Execute(std::vector& entries, ncclAllGather(buffer_data_at_rank_offset, buffer_data, (size_t) num_elements_per_rank, GetNCCLDataType(first_entry.tensor), - *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream)); + *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream), + *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_ALLGATHER, *gpu_op_context_.stream); } @@ -335,7 +337,8 @@ NCCLHierarchicalAllreduce::Execute(std::vector& entries, ncclBcast(buffer_data_remainder, (size_t) num_elements_remaining, GetNCCLDataType(first_entry.tensor), root_rank, - *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream)); + *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream), + *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_BCAST, *gpu_op_context_.stream); } @@ -387,7 +390,8 @@ Status NCCLBroadcast::Execute(std::vector& entries, e.tensor->shape().num_elements() * DataType_Size(e.tensor->dtype()), ncclChar, e.root_rank, - *nccl_op_context_.nccl_comm_, *gpu_op_context_.stream)); + *nccl_op_context_.nccl_comm_, 
*gpu_op_context_.stream), + *nccl_op_context_.nccl_comm_); if (global_state_->timeline.Initialized()) { gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_BCAST, *gpu_op_context_.stream); } diff --git a/horovod/common/ops/nccl_operations.h b/horovod/common/ops/nccl_operations.h index 078f61be4b..6629ae1574 100644 --- a/horovod/common/ops/nccl_operations.h +++ b/horovod/common/ops/nccl_operations.h @@ -37,7 +37,7 @@ ncclDataType_t GetNCCLDataType(const std::shared_ptr tensor); struct NCCLContext { std::vector, ncclComm_t>> nccl_comms; - void ErrorCheck(std::string op_name, ncclResult_t nccl_result); + void ErrorCheck(std::string op_name, ncclResult_t nccl_result, ncclComm_t& nccl_comm); void ShutDown(); }; diff --git a/horovod/common/utils/env_parser.cc b/horovod/common/utils/env_parser.cc index e291fadc74..bd3b6471e5 100644 --- a/horovod/common/utils/env_parser.cc +++ b/horovod/common/utils/env_parser.cc @@ -17,6 +17,7 @@ #include #include +#include #include "../logging.h" #include "../operations.h" @@ -142,6 +143,11 @@ void SetBoolFromEnv(const char* env, bool& val, bool value_if_set) { } } +bool GetBoolEnvOrDefault(const char* env_variable, bool default_value) { + auto env_value = std::getenv(env_variable); + return env_value != nullptr ? (bool) std::strtol(env_value, nullptr, 10) : default_value; +} + void SetIntFromEnv(const char* env, int& val) { auto env_value = std::getenv(env); if (env_value != nullptr) { @@ -159,5 +165,9 @@ double GetDoubleEnvOrDefault(const char* env_variable, double default_value) { return env_value != nullptr ? std::strtod(env_value, nullptr) : default_value; } +void SetEnv(const char* env_variable, const char* env_value) { + setenv(env_variable, env_value, true); +} + } // namespace common } diff --git a/horovod/common/utils/env_parser.h b/horovod/common/utils/env_parser.h index 46386ef6ae..5ce0cf54ca 100644 --- a/horovod/common/utils/env_parser.h +++ b/horovod/common/utils/env_parser.h @@ -37,12 +37,16 @@ void ParseStallInspectorFromEnv(StallInspector& stall_inspector); void SetBoolFromEnv(const char* env, bool& val, bool value_if_set); +bool GetBoolEnvOrDefault(const char* env_variable, bool default_value); + void SetIntFromEnv(const char* env, int& val); int GetIntEnvOrDefault(const char* env_variable, int default_value); double GetDoubleEnvOrDefault(const char* env_variable, double default_value); +void SetEnv(const char* env_variable, const char* env_value); + } // namespace common } // namespace horovod diff --git a/horovod/keras/__init__.py b/horovod/keras/__init__.py index ff4e10a360..f31014de2d 100644 --- a/horovod/keras/__init__.py +++ b/horovod/keras/__init__.py @@ -27,7 +27,7 @@ from horovod.tensorflow import nccl_built, ddl_built, ccl_built from horovod.tensorflow import Compression -from horovod.keras import callbacks +from horovod.keras import callbacks, elastic import horovod._keras as _impl diff --git a/horovod/keras/elastic.py b/horovod/keras/elastic.py new file mode 100644 index 0000000000..ba87c66e28 --- /dev/null +++ b/horovod/keras/elastic.py @@ -0,0 +1,85 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +import keras + +from horovod._keras import elastic as _impl +from horovod.tensorflow.elastic import TensorFlowKerasState + + +class KerasState(TensorFlowKerasState): + """State representation of a `keras` model and optimizer. + + Args: + model: Keras model. + optimizer: Optional optimizer, can be compiled into model instead. + kwargs: Additional properties to sync, will be exposed as attributes of the object. + """ + def __init__(self, model, optimizer=None, **kwargs): + super(KerasState, self).__init__(model, optimizer=optimizer, backend=keras.backend, **kwargs) + + +class CommitStateCallback(_impl.CommitStateCallbackImpl, keras.callbacks.Callback): + """ + Keras Callback that will commit the `state` object every `batches_per_commit` + batches at the end of each batch. + """ + + def __init__(self, state, batches_per_commit=1): + """ + Constructs a new CommitStateCallback. + + Args: + state: `horovod.common.elastic.State` object to be committed. + batches_per_commit: Number of batches to complete between each commit (default: 1). + """ + super(CommitStateCallback, self).__init__(keras.backend, state, batches_per_commit) + + +class UpdateBatchStateCallback(_impl.UpdateBatchStateCallbackImpl, keras.callbacks.Callback): + """ + Keras Callback that will update the value of `state.batch` with the current batch number at + the end of each batch. Batch will reset to 0 at the end of each epoch. + + If `steps_per_epoch` is set, then this callback will also ensure that the number of steps + in the first epoch following a reset is shortened by the number of batches already processed. + """ + + def __init__(self, state): + """ + Constructs a new UpdateBatchStateCallback. + + Args: + state: `horovod.common.elastic.State` object to be updated. + """ + super(UpdateBatchStateCallback, self).__init__(keras.backend, state) + + +class UpdateEpochStateCallback(_impl.UpdateEpochStateCallbackImpl, keras.callbacks.Callback): + """ + Keras Callback that will update the value of `state.epoch` with the current epoch number at + the end of each epoch. + """ + + def __init__(self, state): + """ + Constructs a new UpdateEpochStateCallback. + + Args: + state: `horovod.common.elastic.State` object to be updated. 
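Taken together, a hedged end-to-end sketch of how KerasState and these callbacks are meant to be used; model, x_train, y_train and NUM_EPOCHS are placeholders, and the elastic run wrapper is assumed to come from the same frontend:

    import horovod.keras as hvd

    state = hvd.elastic.KerasState(model, batch=0, epoch=0)

    callbacks = [
        hvd.elastic.CommitStateCallback(state, batches_per_commit=100),
        hvd.elastic.UpdateBatchStateCallback(state),
        hvd.elastic.UpdateEpochStateCallback(state),
    ]

    def train(state):
        # Resume from the last synced epoch; the callbacks keep state.batch / state.epoch
        # current and commit the state every 100 batches.
        model.fit(x_train, y_train,
                  epochs=NUM_EPOCHS - state.epoch,
                  callbacks=callbacks)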
+ """ + super(UpdateEpochStateCallback, self).__init__(keras.backend, state) diff --git a/horovod/run/common/util/codec.py b/horovod/run/common/util/codec.py index 4578255ba7..1b421b447a 100644 --- a/horovod/run/common/util/codec.py +++ b/horovod/run/common/util/codec.py @@ -22,6 +22,7 @@ def loads_base64(encoded): return cloudpickle.loads(decoded) -def dumps_base64(obj): +def dumps_base64(obj, to_ascii=True): serialized = cloudpickle.dumps(obj) - return base64.b64encode(serialized).decode('ascii') + encoded = base64.b64encode(serialized) + return encoded.decode('ascii') if to_ascii else encoded diff --git a/horovod/run/common/util/config_parser.py b/horovod/run/common/util/config_parser.py index 74f55d604a..155a98179b 100644 --- a/horovod/run/common/util/config_parser.py +++ b/horovod/run/common/util/config_parser.py @@ -1,3 +1,20 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import logging + # Parameter knobs HOROVOD_FUSION_THRESHOLD = 'HOROVOD_FUSION_THRESHOLD' HOROVOD_CYCLE_TIME = 'HOROVOD_CYCLE_TIME' diff --git a/horovod/run/common/util/hosts.py b/horovod/run/common/util/hosts.py new file mode 100644 index 0000000000..06ad8100fb --- /dev/null +++ b/horovod/run/common/util/hosts.py @@ -0,0 +1,123 @@ +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import + +import collections + + +class HostInfo: + def __init__(self, hostname, slots): + self.hostname = hostname + self.slots = slots + + @staticmethod + def from_string(host_string): + hostname, slots = host_string.strip().split(':') + return HostInfo(hostname, int(slots)) + + +class SlotInfo: + def __init__(self, hostname, rank, local_rank, cross_rank, size=None, local_size=None, cross_size=None): + self.hostname = hostname + self.rank = rank + self.size = size + self.local_rank = local_rank + self.local_size = local_size + self.cross_rank = cross_rank + self.cross_size = cross_size + + def to_response_string(self): + return ','.join(str(v) for v in [self.rank, self.size, + self.local_rank, self.local_size, + self.cross_rank, self.cross_size]) + + def __eq__(self, other): + if isinstance(other, SlotInfo): + return self.hostname == other.hostname and \ + self.rank == other.rank and self.size == other.size and \ + self.local_rank == other.local_rank and self.local_size == other.local_size and \ + self.cross_rank == other.cross_rank and self.cross_size == other.cross_size + return False + + +INVALID_SLOT_INFO = SlotInfo(hostname='', + rank=-1, local_rank=-1, cross_rank=-1, + size=-1, local_size=-1, cross_size=-1) + + +def parse_hosts(hosts_string): + """Parse a string of comma-separated hostname:slots mappings into a list of HostItem objects. + + :param hosts_string: list of addresses and number of processes on each host. + For example: + - 'worker-0:2,worker-1:2' + - '10.11.11.11:4,10.11.11.12:4' + :return: a list of HostInfo objects describing host to slot mappings + :rtype: list[HostInfo] + """ + return [HostInfo.from_string(host_string) for host_string in hosts_string.split(',')] + + +def get_host_assignments(hosts, min_np, max_np=None): + """Assign hosts with process capacities (slots) to ranks in the Horovod process. + + This function will try to allocate as many as possible processes on the same host to leverage + local network. + + :param hosts: list of HostInfo objects describing host and slot capacity + :type hosts: list[HostInfo] + :param np: total number of processes to be allocated + :type np: int + :return: a list of the allocation of process on hosts in a AllocInfo object. + Members in the object include: hostname, rank, local_rank, cross_rank, + total_size, local_size, cross_size + :rtype: list[SlotInfo] + """ + rank = 0 + alloc_list = [] + + # key: local_rank; value: cross_size for this local_rank + local_sizes = collections.defaultdict(int) + # key: cross_rank; value: local_size for this cross_rank + cross_sizes = collections.defaultdict(int) + + # allocate processes into slots + for host_idx, host_info in enumerate(hosts): + for local_rank in range(host_info.slots): + if rank == max_np: + break + cross_rank = host_idx + alloc_list.append( + SlotInfo( + host_info.hostname, + rank, + local_rank, + cross_rank)) + cross_sizes[local_rank] += 1 + local_sizes[cross_rank] += 1 + rank += 1 + + if rank < min_np: + raise ValueError('Requested more processes ({}) than there are available slots ({})' + .format(min_np, rank)) + + # Fill in the local_size and cross_size because we can only know these number after + # allocation is done. 
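# A hedged usage sketch (illustrative values) of the parse_hosts / get_host_assignments
# helpers in this module. With two hosts of two slots each, ranks are assigned host by
# host, local ranks run 0..1 on each host, and the cross rank equals the host index. The
# comma-separated string from SlotInfo.to_response_string() is the format the elastic
# rendezvous later returns to Gloo workers.
from horovod.run.common.util.hosts import parse_hosts, get_host_assignments

example_hosts = parse_hosts('worker-0:2,worker-1:2')
example_slots = get_host_assignments(example_hosts, min_np=4)
assert example_slots[0].to_response_string() == '0,4,0,2,0,2'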
+ for alloc_item in alloc_list: + alloc_item.local_size = local_sizes[alloc_item.cross_rank] + alloc_item.cross_size = cross_sizes[alloc_item.local_rank] + alloc_item.size = rank + return alloc_list diff --git a/horovod/run/common/util/network.py b/horovod/run/common/util/network.py index 7e30de8365..23a6a6d63b 100644 --- a/horovod/run/common/util/network.py +++ b/horovod/run/common/util/network.py @@ -13,10 +13,12 @@ # limitations under the License. # ============================================================================== +import psutil import socket import struct +import threading + import cloudpickle -import psutil from six.moves import queue, socketserver diff --git a/horovod/run/common/util/safe_shell_exec.py b/horovod/run/common/util/safe_shell_exec.py index 46748a0602..e2d530c8b0 100644 --- a/horovod/run/common/util/safe_shell_exec.py +++ b/horovod/run/common/util/safe_shell_exec.py @@ -47,7 +47,10 @@ def terminate_executor_shell_and_children(pid): gone, alive = psutil.wait_procs(p.children(), timeout=GRACEFUL_TERMINATION_TIME_S) # Freeze the process to prevent it from spawning any new children. - p.send_signal(signal.SIGSTOP) + try: + p.send_signal(signal.SIGSTOP) + except psutil.NoSuchProcess: + pass # Kill children recursively. for child in alive: @@ -62,7 +65,11 @@ def terminate_executor_shell_and_children(pid): pass # Kill shell itself. - p.terminate() + try: + p.terminate() + except psutil.NoSuchProcess: + pass + try: p.wait(timeout=GRACEFUL_TERMINATION_TIME_S) except psutil.TimeoutExpired: diff --git a/horovod/run/common/util/settings.py b/horovod/run/common/util/settings.py index 910718cf1e..d01d1c6926 100644 --- a/horovod/run/common/util/settings.py +++ b/horovod/run/common/util/settings.py @@ -13,13 +13,16 @@ # limitations under the License. # ============================================================================== +from __future__ import absolute_import -class Settings(object): - def __init__(self, verbose=0, ssh_port=None, extra_mpi_args=None, tcp_flag=None, - binding_args=None, key=None, timeout=None, num_hosts=None, num_proc=None, - hosts=None, output_filename=None, run_func_mode=None, nics=None): +class BaseSettings(object): + def __init__(self, num_proc=None, verbose=0, ssh_port=None, extra_mpi_args=None, tcp_flag=None, + binding_args=None, key=None, start_timeout=None, output_filename=None, + run_func_mode=None, nics=None, elastic=False): """ + :param num_proc: number of horovod processes (-np) + :type num_proc: int :param verbose: level of verbosity :type verbose: int :param ssh_port: SSH port on all the hosts @@ -32,38 +35,45 @@ def __init__(self, verbose=0, ssh_port=None, extra_mpi_args=None, tcp_flag=None, :type binding_args: string :param key: used for encryption of parameters passed across the hosts :type key: str - :param timeout: has to finish all the checks before this timeout runs - out. - :type timeout: horovod.run.common.util.timeout.Timeout - :param num_hosts: number of horovod hosts - :type num_hosts: int - :param num_proc: number of horovod processes (-np) - :type num_proc: int - :param hosts: string of hostname with slots number - :type hosts: string + :param start_timeout: has to finish all the checks before this timeout runs out. 
+ :type start_timeout: horovod.run.common.util.timeout.Timeout :param output_filename: optional filename to redirect stdout / stderr by process :type output_filename: string :param run_func_mode: whether it is run function mode :type run_func_mode: boolean :param nics: specify the NICs to be used for tcp network communication. :type nics: string + :param elastic: enable elastic auto-scaling and fault tolerance mode + :type elastic: boolean """ + self.num_proc = num_proc self.verbose = verbose self.ssh_port = ssh_port self.extra_mpi_args = extra_mpi_args self.tcp_flag = tcp_flag self.binding_args = binding_args self.key = key - self.timeout = timeout - self.num_hosts = num_hosts - self.num_proc = num_proc - self.hosts = hosts + self.start_timeout = start_timeout self.output_filename = output_filename self.run_func_mode = run_func_mode self.nics = nics - + self.elastic = elastic + # we do not serialize the key, as it is too risky that it could leak unintentionally def __getstate__(self): result = self.__dict__.copy() result['key'] = None return result + + +class Settings(BaseSettings): + def __init__(self, num_hosts=None, hosts=None, **kwargs): + """ + :param num_hosts: number of horovod hosts + :type num_hosts: int + :param hosts: string, comma-delimited, of hostname[s] with slots number[s] + :type hosts: string + """ + super(Settings, self).__init__(**kwargs) + self.num_hosts = num_hosts + self.hosts = hosts diff --git a/horovod/run/driver/driver_service.py b/horovod/run/driver/driver_service.py index 254aed6a9b..33ac00621b 100644 --- a/horovod/run/driver/driver_service.py +++ b/horovod/run/driver/driver_service.py @@ -20,10 +20,11 @@ from socket import AF_INET from psutil import net_if_addrs -from horovod.run.util import cache, lsf, threads from horovod.run.common.service import driver_service from horovod.run.common.util import codec, safe_shell_exec from horovod.run.task import task_service +from horovod.run.util import cache, lsf, network, threads + class HorovodRunDriverService(driver_service.BasicDriverService): NAME = 'horovod driver service' @@ -157,7 +158,7 @@ def _driver_fn(all_host_names, local_host_names, settings): # wait for all the hosts to register with the service service. if settings.verbose >= 2: print('Waiting for the hosts to acknowledge.') - driver.wait_for_initial_registration(settings.timeout) + driver.wait_for_initial_registration(settings.start_timeout) tasks = [ task_service.HorovodRunTaskClient( index, @@ -176,7 +177,7 @@ def _driver_fn(all_host_names, local_host_names, settings): # such as lo0 with address 127.0.0.1. if settings.verbose >= 2: print('Waiting for hosts to perform host-to-host interface checking.') - driver.wait_for_task_to_task_address_updates(settings.timeout) + driver.wait_for_task_to_task_address_updates(settings.start_timeout) if settings.verbose >= 2: print('Host-to-host interface checking successful.') # Determine a set of common interfaces for task-to-task communication. @@ -194,7 +195,7 @@ def _driver_fn(all_host_names, local_host_names, settings): driver.shutdown() -def get_common_interfaces(settings, all_host_names, remote_host_names, fn_cache): +def get_common_interfaces(settings, all_host_names, remote_host_names=None, fn_cache=None): ''' Find the set of common and routed interfaces on all the hosts. 
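For reference, a hedged sketch of how the refactored settings hierarchy is constructed for a classic (non-elastic) run; host-specific fields stay on Settings while shared fields such as num_proc and start_timeout now live on BaseSettings:

    from horovod.run.common.util.settings import Settings

    settings = Settings(num_proc=4,
                        verbose=2,
                        hosts='worker-0:2,worker-1:2',
                        num_hosts=2)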
:param settings: the object that contains the setting for running horovod @@ -211,6 +212,9 @@ def get_common_interfaces(settings, all_host_names, remote_host_names, fn_cache) if lsf.LSFUtils.using_lsf(): return None + if remote_host_names is None: + remote_host_names = network.filter_local_addresses(all_host_names) + if len(remote_host_names) > 0: if settings.nics: # If args.nics is provided, we will use those interfaces. All the workers @@ -251,4 +255,4 @@ def get_common_interfaces(settings, all_host_names, remote_host_names, fn_cache) if settings.verbose >= 2: print('Local interface found ' + ' '.join(nics)) - return nics \ No newline at end of file + return nics diff --git a/horovod/run/elastic/__init__.py b/horovod/run/elastic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/horovod/run/elastic/discovery.py b/horovod/run/elastic/discovery.py new file mode 100644 index 0000000000..876c2d2fd7 --- /dev/null +++ b/horovod/run/elastic/discovery.py @@ -0,0 +1,167 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +import logging +import threading + +from collections import defaultdict + +import six + +from horovod.run.common.util import safe_shell_exec + + +class HostState(object): + def __init__(self): + self._event = threading.Event() + + # TODO(travis): blacklisted hosts should have a timeout period that increases with each failure + self._blacklisted = False + + def get_event(self): + if self._event.is_set(): + event = threading.Event() + self._event = event + return self._event + + def set_event(self): + self._event.set() + + def blacklist(self): + self._blacklisted = True + self.set_event() + + def is_blacklisted(self): + return self._blacklisted + + +class DiscoveredHosts(object): + def __init__(self, host_slots, host_assignment_order): + self._host_slots = host_slots + self._host_assignment_order = host_assignment_order + + @property + def host_slots(self): + return self._host_slots + + @property + def available_hosts(self): + return set(self._host_assignment_order) + + @property + def host_assignment_order(self): + return self._host_assignment_order + + def get_slots(self, host): + return self._host_slots.get(host, 0) + + def count_available_slots(self): + # Use the host_assignment_order as it does not contain blacklisted hosts + return sum([self.get_slots(host) for host in self._host_assignment_order]) + + def update(self, hosts_state): + self._host_assignment_order = [host for host in self._host_assignment_order + if not hosts_state[host].is_blacklisted()] + return self + + +class HostManager(object): + def __init__(self, discovery): + self._current_hosts = DiscoveredHosts(host_slots={}, host_assignment_order=[]) + self._hosts_state = defaultdict(HostState) + self._discovery = discovery + + def update_available_hosts(self): + # TODO(travis): also check for hosts removed 
from the blacklist in the future + prev_host_slots = self._current_hosts.host_slots + prev_host_assignment_order = self._current_hosts.host_assignment_order + host_slots = self._discovery.find_available_hosts_and_slots() + if prev_host_slots != host_slots: + available_hosts = set([host for host in host_slots.keys() if not self._hosts_state[host].is_blacklisted()]) + host_assignment_order = HostManager.order_available_hosts(available_hosts, prev_host_assignment_order) + self._current_hosts = DiscoveredHosts(host_slots=host_slots, + host_assignment_order=host_assignment_order) + return True + return False + + @property + def current_hosts(self): + return self._current_hosts.update(self._hosts_state) + + def blacklist(self, host): + if not self._hosts_state[host].is_blacklisted(): + logging.warning('blacklist failing host: {}'.format(host)) + self._hosts_state[host].blacklist() + + def is_blacklisted(self, host): + return self._hosts_state[host].is_blacklisted() + + def get_host_event(self, host): + return self._hosts_state[host].get_event() + + @staticmethod + def order_available_hosts(available_hosts, prev_host_assignment_order): + # We need to ensure this list preserves relative order to ensure the oldest hosts are assigned lower ranks. + host_assignment_order = [host for host in prev_host_assignment_order if host in available_hosts] + known_hosts = set(host_assignment_order) + for host in available_hosts: + if host not in known_hosts: + host_assignment_order.append(host) + return host_assignment_order + + +class HostDiscovery(object): + def find_available_hosts_and_slots(self): + """Returns a dict mapping -> .""" + raise NotImplementedError() + + +class HostDiscoveryScript(HostDiscovery): + def __init__(self, discovery_script, slots): + self._discovery_script = discovery_script + self._default_slots = slots + super(HostDiscoveryScript, self).__init__() + + def find_available_hosts_and_slots(self): + stdout = six.StringIO() + exit_code = safe_shell_exec.execute(self._discovery_script, stdout=stdout) + if exit_code != 0: + raise RuntimeError('Failed to execute discovery script: {}. Exit code: {}' + .format(self._discovery_script, exit_code)) + + host_slots = {} + lines = set(stdout.getvalue().strip().split('\n')) + for line in lines: + host = line + if ':' in line: + host, slots = line.split(':') + host_slots[host] = int(slots) + else: + host_slots[host] = self._default_slots + return host_slots + + +class FixedHosts(HostDiscovery): + def __init__(self, host_slots): + super(FixedHosts, self).__init__() + self._host_slots = host_slots + + def find_available_hosts_and_slots(self): + return self._host_slots + + def set(self, host_slots): + self._host_slots = host_slots diff --git a/horovod/run/elastic/driver.py b/horovod/run/elastic/driver.py new file mode 100644 index 0000000000..7e81bf712e --- /dev/null +++ b/horovod/run/elastic/driver.py @@ -0,0 +1,299 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import + +import logging +import os +import threading +import time + +from collections import defaultdict + +from six.moves import queue + +from horovod.run.common.util import hosts, timeout +from horovod.run.elastic.discovery import HostManager +from horovod.run.elastic.registration import WorkerStateRegistry +from horovod.run.elastic.worker import WorkerNotificationClient + + +DISCOVER_HOSTS_FREQUENCY_SECS = 1.0 +ELASTIC_TIMEOUT_SECS = 600 + + +def _epoch_time_s(): + return int(time.time()) + + +class Results(object): + def __init__(self): + self._results = {} + self._worker_threads = queue.Queue() + + def expect(self, worker_thread): + self._worker_threads.put(worker_thread) + + def add_result(self, key, value): + if key in self._results: + return + self._results[key] = value + + def get_results(self): + while not self._worker_threads.empty(): + worker_thread = self._worker_threads.get() + worker_thread.join() + return self._results + + +class ElasticDriver(object): + def __init__(self, rendezvous, discovery, min_np, max_np, timeout=None, verbose=0): + self._rendezvous = rendezvous + self._host_manager = HostManager(discovery) + self._min_np = min_np + self._max_np = max_np + self._verbose = verbose + + self._host_assignments = {} + self._rank_assignments = {} + self._world_size = 0 + + self._wait_hosts_cond = threading.Condition() + self._timeout = timeout or int(os.getenv('HOROVOD_ELASTIC_TIMEOUT', ELASTIC_TIMEOUT_SECS)) + + self._create_worker_fn = None + self._worker_clients = {} + + self._worker_registry = WorkerStateRegistry(self, self._host_manager) + self._results = Results() + self._shutdown = threading.Event() + + self._discovery_thread = threading.Thread(target=self._discover_hosts) + self._discovery_thread.daemon = True + self._discovery_thread.start() + + def start(self, np, create_worker_fn): + self._create_worker_fn = create_worker_fn + self._activate_workers(np) + + def resume(self): + self._activate_workers(self._min_np) + + def stop(self): + self._shutdown.set() + self._discovery_thread.join() + + def finished(self): + return self._shutdown.is_set() + + def get_results(self): + return self._results.get_results() + + def register_worker_server(self, host, slot, addresses, secret_key): + self._worker_clients[(host, slot)] = WorkerNotificationClient( + addresses, secret_key, self._verbose) + + def get_worker_client(self, slot_info): + return self._worker_clients.get((slot_info.hostname, slot_info.local_rank)) + + def record_ready(self, host, slot): + self._worker_registry.record_ready(host, slot) + + def world_size(self): + return self._world_size + + def local_size(self, host): + return len(self._host_assignments[host]) + + def get_slot_info(self, host, slot): + return self._host_assignments[host][slot] if self.has_rank_assignment(host, slot) \ + else hosts.INVALID_SLOT_INFO + + def get_coordinator_info(self): + return self._rank_assignments.get(0) + + def has_rank_assignment(self, host, slot): + if self._host_manager.is_blacklisted(host): + return False + return host in self._host_assignments and len(self._host_assignments[host]) > slot + + @property + def host_assignments(self): + return self._host_assignments + + def wait_for_available_slots(self, min_np, min_hosts=1): + extra_message = ' An elastic job also requires that at least two hosts ' \ + 'are available to resolve compatible network interfaces. 
If you know which interfaces ' \ + 'are compatible in your network, set `--nic` to skip this check.' if min_hosts > 1 else '' + + tmout = timeout.Timeout( + self._timeout, + message='Timed out waiting for {{activity}}. Please check that you have ' + 'enough resources to run at least {min_np} Horovod processes.{extra_message}' + .format(min_np=min_np, extra_message=extra_message)) + + self._wait_hosts_cond.acquire() + try: + while True: + current_hosts = self._host_manager.current_hosts + if current_hosts.count_available_slots() >= min_np and len(current_hosts.available_hosts) >= min_hosts: + return current_hosts + if self._shutdown.is_set(): + raise RuntimeError('Job has been shutdown, see above error messages for details.') + self._wait_hosts_cond.wait(tmout.remaining()) + tmout.check_time_out_for('minimum number of slots to become available') + finally: + self._wait_hosts_cond.release() + + def _activate_workers(self, min_np): + logging.info('wait for available slots: {}'.format(min_np)) + current_hosts = self.wait_for_available_slots(min_np) + pending_slots = self._update_host_assignments(current_hosts) + self._worker_registry.reset(self.world_size()) + self._start_worker_processes(pending_slots) + + def _discover_hosts(self): + first_update = True + while not self._shutdown.is_set(): + self._wait_hosts_cond.acquire() + try: + if self._host_manager.update_available_hosts(): + self._notify_workers_host_changes(self._host_manager.current_hosts) + self._wait_hosts_cond.notify_all() + except RuntimeError as e: + if first_update: + # Misconfiguration, fail the job immediately + self._shutdown.set() + self._wait_hosts_cond.notify_all() + raise + # Transient error, retry until timeout + logging.warning(str(e)) + finally: + self._wait_hosts_cond.release() + first_update = False + self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS) + + def _notify_workers_host_changes(self, current_hosts): + next_host_assignments = {} + if current_hosts.count_available_slots() >= self._min_np: + # Assignments are required to be stable via contract + next_host_assignments, _ = self._get_host_assignments(current_hosts) + + if next_host_assignments == self.host_assignments: + # Skip notifying workers when host changes would not result in changes of host assignments + logging.debug('no host assignment changes, skipping notifications') + return + + coordinator_slot_info = self.get_coordinator_info() + if not coordinator_slot_info: + logging.debug('no coordinator info, skipping notifications') + return + + coordinator_client = self.get_worker_client(coordinator_slot_info) + if not coordinator_client: + logging.debug('no coordinator client, skipping notifications') + return + + timestamp = _epoch_time_s() + try: + coordinator_client.notify_hosts_updated(timestamp) + except: + if self._verbose >= 2: + logging.exception('failed to notify {}[{}] of host updates' + .format(coordinator_slot_info.hostname, + coordinator_slot_info.local_rank)) + + def _update_host_assignments(self, current_hosts): + # Determine the slots that are already filled so we do not respawn these processes + active_slots = set([(host, slot_info.local_rank) + for host, slots in self._host_assignments.items() + for slot_info in slots]) + + # Adjust the host assignments to account for added / removed hosts + host_assignments, host_assignments_list = self._get_host_assignments(current_hosts) + + if len(self._host_assignments) > 0: + # Ensure that at least one previously active host is still assigned, otherwise there is no + # way to sync the state to the 
new workers + prev_hosts = self._host_assignments.keys() + next_hosts = host_assignments.keys() + if not prev_hosts & next_hosts: + raise RuntimeError('No hosts from previous set remaining, unable to broadcast state.') + + self._host_assignments = host_assignments + self._world_size = len(host_assignments_list) + self._rendezvous.httpd.init(host_assignments_list) + + # Rank assignments map from world rank to slot info + rank_assignments = {} + for slot_info in host_assignments_list: + rank_assignments[slot_info.rank] = slot_info + self._rank_assignments = rank_assignments + + # Get the newly assigned slots that need to be started + pending_slots = [slot_info + for host, slots in self._host_assignments.items() + for slot_info in slots + if (host, slot_info.local_rank) not in active_slots] + return pending_slots + + def _get_host_assignments(self, current_hosts): + # Adjust the host assignments to account for added / removed hosts + host_list = [hosts.HostInfo(host, current_hosts.get_slots(host)) + for host in current_hosts.host_assignment_order] + host_assignments_list = hosts.get_host_assignments(host_list, self._min_np, self._max_np) + host_assignments = defaultdict(list) + for slot_info in host_assignments_list: + host_assignments[slot_info.hostname].append(slot_info) + return host_assignments, host_assignments_list + + def _start_worker_processes(self, pending_slots): + for slot_info in pending_slots: + logging.info('start worker process: {}[{}]'.format(slot_info.hostname, slot_info.local_rank)) + self._start_worker_process(slot_info) + + def _start_worker_process(self, slot_info): + create_worker_fn = self._create_worker_fn + shutdown_event = self._shutdown + host_event = self._host_manager.get_host_event(slot_info.hostname) + + def run_worker(): + res = create_worker_fn(slot_info, [shutdown_event, host_event]) + exit_code, timestamp = res + self._handle_worker_exit(slot_info, exit_code, timestamp) + + thread = threading.Thread(target=run_worker) + thread.daemon = True + thread.start() + self._results.expect(thread) + + def _handle_worker_exit(self, slot_info, exit_code, timestamp): + if not self.has_rank_assignment(slot_info.hostname, slot_info.local_rank): + # Ignore hosts that are not assigned a rank + logging.debug('host {} has been blacklisted, ignoring exit from local_rank={}' + .format(slot_info.hostname, slot_info.local_rank)) + return + + if exit_code == 0: + rendezvous_id = self._worker_registry.record_success(slot_info.hostname, slot_info.local_rank) + else: + rendezvous_id = self._worker_registry.record_failure(slot_info.hostname, slot_info.local_rank) + + if self.finished() and self._worker_registry.last_rendezvous() == rendezvous_id: + logging.debug('adding results for {}[{}]: ({}, {})' + .format(slot_info.hostname, slot_info.local_rank, exit_code, timestamp)) + name = '{}[{}]'.format(slot_info.hostname, slot_info.local_rank) + self._results.add_result(name, (exit_code, timestamp)) + diff --git a/horovod/run/elastic/registration.py b/horovod/run/elastic/registration.py new file mode 100644 index 0000000000..1d6e23c96f --- /dev/null +++ b/horovod/run/elastic/registration.py @@ -0,0 +1,151 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +import logging +import threading + +from collections import defaultdict + +READY = 'READY' +SUCCESS = 'SUCCESS' +FAILURE = 'FAILURE' + + +class WorkerStateRegistry(object): + def __init__(self, driver, host_manager, verbose=False): + self._driver = driver + self._host_manager = host_manager + self._lock = threading.Lock() + self._states = {} + self._workers = defaultdict(set) + self._barrier = None + self._rendezvous_id = 0 + self._verbose = verbose + self._size = 0 + + def get_recorded_slots(self): + return self._states.keys() + + def get(self, state): + return self._workers[state] + + def count(self, state): + return len(self._workers[state]) + + def reset(self, size): + with self._lock: + logging.info('reset workers: {}'.format(size)) + self._states.clear() + self._workers.clear() + self._barrier = threading.Barrier(parties=size, action=self._action) + self._rendezvous_id += 1 + self._size = size + + def size(self): + return self._size + + def last_rendezvous(self): + return self._rendezvous_id + + def record_ready(self, host, slot): + return self._record_state(host, slot, READY) + + def record_success(self, host, slot): + return self._record_state(host, slot, SUCCESS) + + def record_failure(self, host, slot): + return self._record_state(host, slot, FAILURE) + + def _record_state(self, host, slot, state): + if self._driver.finished(): + logging.info('driver finished, ignoring registration: {}[{}] = {}'.format(host, slot, state)) + return self._rendezvous_id + + key = (host, slot) + with self._lock: + if key in self._states: + # Worker originally recorded itself as READY, but the worker failed while waiting at the barrier. As + # such, we need to update the state to FAILURE, and we don't want two threads coming from the same + # worker at the barrier. + # + # In order to ensure that the new failing thread can record results in cases of total job failure, + # we also need to block this thread by waiting on the barrier. This requires us to reset the barrier, + # as otherwise this worker will be double-counted (once for the READY thread and once for FAILURE), + # which would cause the barrier to complete too early. 
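# To make the barrier-reset behaviour described above concrete, here is a small
# standalone sketch (not part of Horovod) using the same threading.Barrier primitives:
# reset() wakes current waiters with BrokenBarrierError while leaving the barrier
# usable again, whereas abort() or a wait timeout leaves barrier.broken set, which the
# registry's _wait() treats as non-recoverable.
import threading

barrier = threading.Barrier(parties=2, action=lambda: print('all parties recorded'))

def party(name):
    while True:
        try:
            barrier.wait()
            return
        except threading.BrokenBarrierError:
            if barrier.broken:
                raise  # aborted or timed out: give up, mirroring _wait()
            # reset() was called: simply wait again on the same barrier

threads = [threading.Thread(target=party, args=('worker-%d' % i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()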
+ logging.info('key exists, reset barrier: {}[{}] = {}'.format(host, slot, state)) + self._barrier.reset() + logging.info('record state: {}[{}] = {}'.format(host, slot, state)) + self._states[key] = state + self._workers[state].add(key) + rendezvous_id = self._rendezvous_id + + rendezvous_id = self._wait(key, state, rendezvous_id) + return rendezvous_id + + def _wait(self, key, state, rendezvous_id): + while True: + try: + self._barrier.wait() + return rendezvous_id + except threading.BrokenBarrierError: + if self._barrier.broken: + # Timeout or other non-recoverable error, so exit + raise + + # Barrier has been reset + with self._lock: + # Check to make sure the reset was not caused by a change of state for this key + rendezvous_id = self._rendezvous_id + saved_state = self._states.get(key, state) + if saved_state != state: + # This worker changed its state, so do not attempt to wait again to avoid double-counting + raise RuntimeError('State {} overridden by {}'.format(state, saved_state)) + + def _action(self): + self._on_workers_recorded() + + def _on_workers_recorded(self): + logging.info('all {} workers recorded'.format(self.size())) + + # Check for success state, if any process succeeded, shutdown all other processes + if self.count(SUCCESS) > 0: + logging.info('success count == {} -> stop running'.format(self.count(SUCCESS))) + self._driver.stop() + return + + # Check that all processes failed, indicating that processing should stop + if self.count(FAILURE) == self._size: + logging.error('failure count == {} -> stop running'.format(self._size)) + self._driver.stop() + return + + # Check for failures, and add them to the blacklisted hosts list + failures = self.get(FAILURE) + for host, slot in failures: + self._host_manager.blacklist(host) + + # If every active host is blacklisted, then treat this as job failure + if all([self._host_manager.is_blacklisted(host) for host, slot in self.get_recorded_slots()]): + logging.error('blacklisted slots count == {} -> stop running'.format(self._size)) + self._driver.stop() + return + + try: + self._driver.resume() + except Exception: + logging.exception('failed to activate new hosts -> stop running') + self._driver.stop() diff --git a/horovod/run/elastic/rendezvous.py b/horovod/run/elastic/rendezvous.py new file mode 100644 index 0000000000..a4cd75a756 --- /dev/null +++ b/horovod/run/elastic/rendezvous.py @@ -0,0 +1,57 @@ +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +from __future__ import absolute_import + +import logging + +from horovod.run.common.util import codec +from horovod.run.http.http_server import RendezvousHandler + +# GET methods +GET_RANK_AND_SIZE = 'rank_and_size' + +# PUT methods +PUT_WORKER_ADDRESSES = 'worker_addresses' + + +def create_rendezvous_handler(driver): + class ElasticRendezvousHandler(RendezvousHandler): + def _get_value(self, scope, key): + if scope == GET_RANK_AND_SIZE: + host, local_rank = key.split(':') + return self._get_rank_and_size(host, int(local_rank)) + + return super(RendezvousHandler, self)._get_value(scope, key) + + def _get_rank_and_size(self, host, local_rank): + logging.info('_get_rank_and_size: {} {}'.format(host, local_rank)) + driver.record_ready(host, local_rank) + slot_info = driver.get_slot_info(host, local_rank) + logging.info('rank and size: {} {}'.format(slot_info.rank, slot_info.size)) + return slot_info.to_response_string().encode('ascii') + + def _put_value(self, scope, key, value): + if scope == PUT_WORKER_ADDRESSES: + host, local_rank = key.split(':') + addresses, secret_key = codec.loads_base64(value) + self._put_worker_addresses(host, int(local_rank), addresses, secret_key) + + super(RendezvousHandler, self)._put_value(scope, key, value) + + def _put_worker_addresses(self, host, local_rank, addresses, secret_key): + driver.register_worker_server(host, local_rank, addresses, secret_key) + + return ElasticRendezvousHandler diff --git a/horovod/run/elastic/settings.py b/horovod/run/elastic/settings.py new file mode 100644 index 0000000000..8665fb4234 --- /dev/null +++ b/horovod/run/elastic/settings.py @@ -0,0 +1,37 @@ +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +from horovod.run.common.util.settings import BaseSettings + + +class ElasticSettings(BaseSettings): + def __init__(self, discovery, min_np, max_np, elastic_timeout, **kwargs): + """ + :param discovery: object used to detect and manage available hosts + :type discovery: horovod.run.elastic.discovery.HostDiscovery + :param min_np: minimum number of processes + :type min_np: int + :param max_np: maximum number of processes + :type max_np: int + :param elastic_timeout: timeout for elastic initialisation after re-scaling in seconds + :type elastic_timeout: int + """ + super(ElasticSettings, self).__init__(elastic=True, **kwargs) + self.discovery = discovery + self.min_np = min_np + self.max_np = max_np + self.elastic_timeout = elastic_timeout diff --git a/horovod/run/elastic/worker.py b/horovod/run/elastic/worker.py new file mode 100644 index 0000000000..81213875c7 --- /dev/null +++ b/horovod/run/elastic/worker.py @@ -0,0 +1,112 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +import os +import threading + +from horovod.run.common.util import network, secret +from horovod.run.elastic.rendezvous import PUT_WORKER_ADDRESSES +from horovod.run.http.http_client import put_data_into_kvstore + + +HOROVOD_GLOO_RENDEZVOUS_ADDR = 'HOROVOD_GLOO_RENDEZVOUS_ADDR' +HOROVOD_GLOO_RENDEZVOUS_PORT = 'HOROVOD_GLOO_RENDEZVOUS_PORT' +HOROVOD_GLOO_IFACE = 'HOROVOD_GLOO_IFACE' +HOROVOD_HOSTNAME = 'HOROVOD_HOSTNAME' +HOROVOD_LOCAL_RANK = 'HOROVOD_LOCAL_RANK' + + +class HostsUpdatedRequest(object): + """Notifies worker that the set of available hosts/slots has changed.""" + def __init__(self, timestamp): + self.timestamp = timestamp + + +class WorkerNotificationManager(object): + def __init__(self): + self._lock = threading.Lock() + self._service = None + self._listeners = set() + + def init(self, rendezvous_addr=None, rendezvous_port=None, + nic=None, hostname=None, local_rank=None): + with self._lock: + if self._service: + return + + rendezvous_addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR) + if not rendezvous_addr: + return + + rendezvous_port = rendezvous_port if rendezvous_port is not None else \ + int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT)) + nic = nic or os.environ.get(HOROVOD_GLOO_IFACE) + hostname = hostname or os.environ.get(HOROVOD_HOSTNAME) + local_rank = local_rank if local_rank is not None else \ + int(os.environ.get(HOROVOD_LOCAL_RANK)) + + secret_key = secret.make_secret_key() + self._service = WorkerNotificationService(secret_key, nic, self) + + value = (self._service.addresses(), secret_key) + put_data_into_kvstore(rendezvous_addr, + rendezvous_port, + PUT_WORKER_ADDRESSES, + self._create_id(hostname, local_rank), + value) + + def register_listener(self, listener): + self._listeners.add(listener) + + def remove_listener(self, listener): + self._listeners.remove(listener) + + def handle_hosts_updated(self, timestamp): + for listener in self._listeners: + listener.on_hosts_updated(timestamp) + + def _create_id(self, hostname, local_rank): + return '{}:{}'.format(hostname, local_rank) + + +class WorkerNotificationService(network.BasicService): + NAME = 'worker notification service' + + def __init__(self, key, nic, manager): + super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME, + key, + nic) + self._manager = manager + + def _handle(self, req, client_address): + if isinstance(req, HostsUpdatedRequest): + self._manager.handle_hosts_updated(req.timestamp) + return network.AckResponse() + + return super(WorkerNotificationService, self)._handle(req, client_address) + + +class WorkerNotificationClient(network.BasicClient): + def __init__(self, addresses, key, verbose, match_intf=False): + super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME, + addresses, + key, + verbose, + match_intf=match_intf) + + def 
notify_hosts_updated(self, timestamp): + self._send(HostsUpdatedRequest(timestamp)) diff --git a/horovod/run/gloo_run.py b/horovod/run/gloo_run.py index bb2c1c7211..f8119e5baf 100644 --- a/horovod/run/gloo_run.py +++ b/horovod/run/gloo_run.py @@ -29,87 +29,12 @@ from pipes import quote from horovod.run.common.util import env as env_util, safe_shell_exec +from horovod.run.common.util.hosts import get_host_assignments, parse_hosts +from horovod.run.driver import driver_service +from horovod.run.elastic.driver import ElasticDriver +from horovod.run.elastic.rendezvous import create_rendezvous_handler from horovod.run.http.http_server import RendezvousServer -from horovod.run.util import threads - - -class HostInfo: - def __init__(self, host_item): - hostname, slots = host_item.strip().split(':') - self.hostname = hostname - self.slots = int(slots) - - -class SlotInfo: - def __init__(self, hostname, rank, local_rank, cross_rank, size): - self.hostname = hostname - self.rank = rank - self.size = size - self.local_rank = local_rank - self.local_size = None - self.cross_rank = cross_rank - self.cross_size = None - - -def _allocate(hosts, np): - """ - Find the allocation of processes on hosts, this function will try to - allocate as many as possible processes on the same host to leverage - local network. - :param hosts: list of addresses and number of processes on each host. - For example, - 'worker-0:2,worker-1:2' - '10.11.11.11:4,10.11.11.12,4' - :type hosts: string - :param np: total number of processes to be allocated - :type np: int - :return: a list of the allocation of process on hosts in a AllocInfo object. - Members in the object include: hostname, rank, local_rank, cross_rank, - total_size, local_size, cross_size - :rtype: list[dict()] - """ - - host_list = [] - # split the host string to host list - for host_item in hosts.split(','): - host_list.append(HostInfo(host_item)) - - rank = 0 - alloc_list = [] - - # key: local_rank; value: cross_size for this local_rank - local_sizes = collections.defaultdict(int) - # key: cross_rank; value: local_size for this cross_rank - cross_sizes = collections.defaultdict(int) - - # allocate processes into slots - for host_idx, host_info in enumerate(host_list): - for local_rank in range(host_info.slots): - if rank == np: - break - cross_rank = host_idx - alloc_list.append( - SlotInfo( - host_info.hostname, - rank, - local_rank, - cross_rank, - np)) - cross_sizes[local_rank] += 1 - local_sizes[cross_rank] += 1 - rank += 1 - - if rank < np: - raise ValueError("Process number should not be larger than " - "total available slots.") - - # Fill in the local_size and cross_size because we can only know these number after - # allocation is done. - for alloc_item in alloc_list: - alloc_item.local_size = local_sizes[alloc_item.cross_rank] - alloc_item.cross_size = cross_sizes[alloc_item.local_rank] - - return alloc_list +from horovod.run.util import network, threads def _pad_rank(rank, size): @@ -140,21 +65,34 @@ def flush(self): f.flush() -def _alloc_info_to_command_fn(run_command, env): - def alloc_info_to_command(alloc_info): +def _slot_info_to_command_fn(run_command, env): + # TODO: Workaround for over-buffered outputs. Investigate how mpirun avoids this problem. + env = copy.copy(env) # copy env so we do not leak env modifications + env['PYTHONUNBUFFERED'] = '1' + + def slot_info_to_command(slot_info): """ - Given an alloc_info, creates a command used by gloo to launch a single job. + Given a slot_info, creates a command used by gloo to launch a single job. 
- :param alloc_info: host and slot to execute the run command on + :param slot_info: host and slot to execute the run command on :return: """ - # generate env for rendezvous - horovod_rendez_env = 'HOROVOD_RANK={rank} HOROVOD_SIZE={size} ' \ - 'HOROVOD_LOCAL_RANK={local_rank} HOROVOD_LOCAL_SIZE={local_size} ' \ - 'HOROVOD_CROSS_RANK={cross_rank} HOROVOD_CROSS_SIZE={cross_size} ' \ - .format(rank=alloc_info.rank, size=alloc_info.size, - local_rank=alloc_info.local_rank, local_size=alloc_info.local_size, - cross_rank=alloc_info.cross_rank, cross_size=alloc_info.cross_size) + host_name = slot_info.hostname + horovod_rendez_env = ( + 'HOROVOD_HOSTNAME={hostname} ' + 'HOROVOD_RANK={rank} ' + 'HOROVOD_SIZE={size} ' + 'HOROVOD_LOCAL_RANK={local_rank} ' + 'HOROVOD_LOCAL_SIZE={local_size} ' + 'HOROVOD_CROSS_RANK={cross_rank} ' + 'HOROVOD_CROSS_SIZE={cross_size} ' + .format(hostname=host_name, + rank=slot_info.rank, + size=slot_info.size, + local_rank=slot_info.local_rank, + local_size=slot_info.local_size, + cross_rank=slot_info.cross_rank, + cross_size=slot_info.cross_size)) return '{horovod_env} {env} {run_command}' .format( horovod_env=horovod_rendez_env, @@ -162,10 +100,20 @@ def alloc_info_to_command(alloc_info): if env_util.is_exportable(key)]), run_command=run_command) - return alloc_info_to_command + return slot_info_to_command + +def _create_elastic_worker_fn(exec_command, run_command, env, event): + get_command_with_env = _slot_info_to_command_fn(run_command, env) -def _exec_command_fn(settings, remote_host_names): + def create_worker(slot_info, events): + command = get_command_with_env(slot_info) + events = [event] + (events or []) + return exec_command(command, slot_info, events) + return create_worker + + +def _exec_command_fn(settings): """ executes the jobs defined by run command on hosts. :param hosts_alloc: list of dict indicating the allocating info. @@ -186,11 +134,13 @@ def _exec_command_fn(settings, remote_host_names): """ ssh_port_arg = '-p {ssh_port}'.format(ssh_port=settings.ssh_port) if settings.ssh_port else '' - def _exec_command(command, alloc_info, event): - index = alloc_info.rank - host_name = alloc_info.hostname + def _exec_command(command, slot_info, events): + index = slot_info.rank + host_name = slot_info.hostname - if host_name in remote_host_names: + host_address = network.resolve_host_address(host_name) + local_addresses = network.get_local_host_addresses() + if host_address not in local_addresses: command = 'ssh -o StrictHostKeyChecking=no {host} {ssh_port_arg} ' \ '{local_command}'\ .format(host=host_name, @@ -217,7 +167,7 @@ def _exec_command(command, alloc_info, event): stderr = MultiFile([sys.stderr, stderr_file]) try: - exit_code = safe_shell_exec.execute(command, index=index, events=[event], stdout=stdout, stderr=stderr) + exit_code = safe_shell_exec.execute(command, index=index, stdout=stdout, stderr=stderr, events=events) if exit_code != 0: print('Process {idx} exit with status code {ec}.'.format(idx=index, ec=exit_code)) except Exception as e: @@ -234,26 +184,7 @@ def _exec_command(command, alloc_info, event): return _exec_command -def launch_gloo(command, exec_command, settings, nics, env, server_ip): - """ - Launches the given command multiple times using gloo. - Each command is launched via exec_command. 
- - :param command: command to launch - :param exec_command: means to execute a single command - :param settings: settings for the distribution - :param nics: common interfaces - :param env: environment to use - :param server_ip: ip to use for rendezvous server - """ - # allocate processes into slots - host_alloc_plan = _allocate(settings.hosts, settings.num_proc) - - # create global rendezvous server - global_rendezv = RendezvousServer(settings.verbose) - # Start rendezvous server and get port that it is listening - global_rendezv_port = global_rendezv.start_server(host_alloc_plan) - +def get_run_command(command, server_ip, nics, port, elastic=False): run_command = ( 'HOROVOD_GLOO_RENDEZVOUS_ADDR={addr} ' 'HOROVOD_GLOO_RENDEZVOUS_PORT={port} ' @@ -261,52 +192,125 @@ def launch_gloo(command, exec_command, settings, nics, env, server_ip): 'HOROVOD_CPU_OPERATIONS=gloo ' 'HOROVOD_GLOO_IFACE={iface} ' 'NCCL_SOCKET_IFNAME={nics} ' + '{elastic}' '{command}' # expect a lot of environment variables - .format(addr=server_ip, - port=global_rendezv_port, - iface=list(nics)[0], # TODO: add multiple ifaces in future - nics=','.join(nics), - command=' '.join(quote(par) for par in command))) + .format(addr=server_ip, + port=port, + iface=list(nics)[0], # TODO: add multiple ifaces in future + nics=','.join(nics), + elastic='HOROVOD_ELASTIC=1 ' if elastic else '', + command=' '.join(quote(par) for par in command))) + return run_command + +def register_shutdown_event(): # Create a event for communication between threads event = threading.Event() - def set_event_on_sigterm(signum, frame): + def set_event_on_signal(signum, frame): event.set() - signal.signal(signal.SIGINT, set_event_on_sigterm) - signal.signal(signal.SIGTERM, set_event_on_sigterm) + signal.signal(signal.SIGINT, set_event_on_signal) + signal.signal(signal.SIGTERM, set_event_on_signal) + return event - # TODO: Workaround for over-buffered outputs. Investigate how mpirun avoids this problem. - env = copy.copy(env) # copy env so we do not leak env modifications - env['PYTHONUNBUFFERED'] = '1' - # In case, the main thread receives a SIGINT, the event will be set so the spawned threads can - # kill their corresponding middleman processes so the jobs can be killed as well. - alloc_info_to_command = _alloc_info_to_command_fn(run_command, env) - args_list = [[alloc_info_to_command(alloc_info), alloc_info, event] - for alloc_info in host_alloc_plan] +def launch_gloo(command, exec_command, settings, nics, env, server_ip): + """ + Launches the given command multiple times using gloo. + Each command is launched via exec_command. 
+ :param command: command to launch + :param exec_command: means to execute a single command + :param settings: settings for the distribution + :param nics: common interfaces + :param env: environment to use + :param server_ip: ip to use for rendezvous server + """ # Make the output directory if it does not exist if settings.output_filename: _mkdir_p(settings.output_filename) + # start global rendezvous server and get port that it is listening on + rendezvous = RendezvousServer(settings.verbose) + + # allocate processes into slots + hosts = parse_hosts(settings.hosts) + host_alloc_plan = get_host_assignments(hosts, settings.num_proc) + + # start global rendezvous server and get port that it is listening on + global_rendezv_port = rendezvous.start_server() + rendezvous.httpd.init(host_alloc_plan) + run_command = get_run_command(command, server_ip, nics, global_rendezv_port) + + slot_info_to_command = _slot_info_to_command_fn(run_command, env) + event = register_shutdown_event() + args_list = [[slot_info_to_command(slot_info), slot_info, [event]] + for slot_info in host_alloc_plan] + # If an error occurs in one thread, entire process will be terminated. # Otherwise, threads will keep running. - res = threads.execute_function_multithreaded(exec_command, args_list, block_until_all_done=True) + res = threads.execute_function_multithreaded(exec_command, + args_list, + block_until_all_done=True) for name, value in sorted(res.items(), key=lambda item: item[1][1]): exit_code, timestamp = value if exit_code != 0: - raise RuntimeError('Gloo job detected that one or more processes exited with non-zero ' + raise RuntimeError('Horovod detected that one or more processes exited with non-zero ' 'status, thus causing the job to be terminated. The first process ' 'to do so was:\nProcess name: {name}\nExit code: {code}\n' .format(name=name, code=exit_code)) -def gloo_run(settings, remote_host_names, nics, env, server_ip, command): +def _get_min_start_hosts(settings): + # This function exists for the purpose of mocking in tests + return 2 if settings.elastic and not settings.nics else 1 + + +def gloo_run(settings, nics, env, server_ip, command): # Each thread will use ssh command to launch the job on each remote host. If an # error occurs in one thread, entire process will be terminated. Otherwise, # threads will keep running and ssh session. - exec_command = _exec_command_fn(settings, remote_host_names) + exec_command = _exec_command_fn(settings) launch_gloo(command, exec_command, settings, nics, env, server_ip) + + +def gloo_run_elastic(settings, env, command): + # Make the output directory if it does not exist + if settings.output_filename: + _mkdir_p(settings.output_filename) + + rendezvous = RendezvousServer(settings.verbose) + driver = ElasticDriver(rendezvous, settings.discovery, + settings.min_np, settings.max_np, + timeout=settings.elastic_timeout, + verbose=settings.verbose) + + handler = create_rendezvous_handler(driver) + global_rendezv_port = rendezvous.start_server(handler) + + # Host-to-host common interface detection requires at least 2 hosts in an elastic job. 
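# Editor's aside (not part of this patch; the address, port and NIC below are made-up
# example values): what the environment prefix built by get_run_command() in this module
# looks like for an elastic job.
from horovod.run.gloo_run import get_run_command

print(get_run_command(['python', 'train.py'], '10.0.0.1', ['eth0'], 12345, elastic=True))
# HOROVOD_GLOO_RENDEZVOUS_ADDR=10.0.0.1 HOROVOD_GLOO_RENDEZVOUS_PORT=12345 \
# HOROVOD_CONTROLLER=gloo HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=eth0 \
# NCCL_SOCKET_IFNAME=eth0 HOROVOD_ELASTIC=1 python train.py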
+ min_hosts = _get_min_start_hosts(settings) + current_hosts = driver.wait_for_available_slots(settings.num_proc, min_hosts=min_hosts) + + nics = driver_service.get_common_interfaces(settings, current_hosts.host_assignment_order) + server_ip = network.get_driver_ip(nics) + + exec_command = _exec_command_fn(settings) + event = register_shutdown_event() + run_command = get_run_command(command, server_ip, nics, global_rendezv_port, elastic=True) + create_worker = _create_elastic_worker_fn(exec_command, run_command, env, event) + + driver.start(settings.num_proc, create_worker) + res = driver.get_results() + driver.stop() + rendezvous.stop_server() + + for name, value in sorted(res.items(), key=lambda item: item[1][1]): + exit_code, timestamp = value + if exit_code != 0: + raise RuntimeError('Horovod detected that one or more processes exited with non-zero ' + 'status, thus causing the job to be terminated. The first process ' + 'to do so was:\nProcess name: {name}\nExit code: {code}\n' + .format(name=name, code=exit_code)) diff --git a/horovod/run/http/http_client.py b/horovod/run/http/http_client.py index a5b360e60e..33d99bf3c9 100644 --- a/horovod/run/http/http_client.py +++ b/horovod/run/http/http_client.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= + import sys -import base64 + from distutils.version import LooseVersion + if LooseVersion(sys.version) < LooseVersion('3.0.0'): from urllib2 import urlopen from urllib2 import Request @@ -24,6 +26,8 @@ from urllib.request import Request from urllib.error import HTTPError, URLError +from horovod.run.common.util import codec + def read_data_from_kvstore(addr, port, scope, key): try: @@ -33,7 +37,7 @@ def read_data_from_kvstore(addr, port, scope, key): req = Request(url) resp = urlopen(req) # TODO: remove base64 encoding because base64 is not efficient - return base64.b64decode(resp.read()) + return codec.loads_base64(resp.read()) except (HTTPError, URLError) as e: raise RuntimeError("Read data from KVStore server failed.", e) @@ -43,7 +47,7 @@ def put_data_into_kvstore(addr, port, scope, key, value): url = "http://{addr}:{port}/{scope}/{key}".format( addr=addr, port=str(port), scope=scope, key=key ) - req = Request(url, data=base64.b64encode(value)) + req = Request(url, data=codec.dumps_base64(value, to_ascii=False)) req.get_method = lambda: "PUT" # for urllib2 compatibility urlopen(req) except (HTTPError, URLError) as e: diff --git a/horovod/run/http/http_server.py b/horovod/run/http/http_server.py index 4da95f4f72..3d946e8c24 100644 --- a/horovod/run/http/http_server.py +++ b/horovod/run/http/http_server.py @@ -14,20 +14,18 @@ # ============================================================================= import collections +import logging import socket import threading -from six.moves import BaseHTTPServer, SimpleHTTPServer +from six.moves import socketserver, BaseHTTPServer, SimpleHTTPServer -from horovod.run.util.threads import in_thread from horovod.run.util.network import find_port +from horovod.run.util.threads import in_thread # Timeout for reading from a single request SINGLE_REQUEST_TIMEOUT = 3 -# Timeout for accepting new request -TOTAL_TIMEOUT = 60 - BAD_REQUEST = 400 TIMEOUT = 408 OK = 200 @@ -41,15 +39,14 @@ class KVStoreHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): def do_GET(self): paths = self.path.split('/') if len(paths) < 3: - print( + logging.error( 'KVStore 
ERROR: Invalid request path: {path}.'.format( path=self.path)) self.send_status_code(BAD_REQUEST) return _, scope, key = paths - with self.server.cache_lock: - value = self.server.cache.get(scope, {}).get(key) + value = self._get_value(scope, key) if value is None: self.send_status_code(404) @@ -63,7 +60,7 @@ def do_GET(self): def do_PUT(self): paths = self.path.split('/') if len(paths) < 3: - print( + logging.error( 'KVStore ERROR: Invalid request path: {path}.'.format( path=self.path)) self.send_status_code(BAD_REQUEST) @@ -77,7 +74,7 @@ def do_PUT(self): value = self.rfile.read(content_length) except socket.timeout: if self.server.verbose: - print( + logging.error( 'KVStore ERROR: Timeout when receiving {content_bytes} ' 'bytes, aborting this incomplete request.' .format( content_bytes=content_length)) @@ -86,12 +83,7 @@ def do_PUT(self): self.send_status_code(TIMEOUT) return - with self.server.cache_lock: - scope_dict = self.server.cache.setdefault(scope, {}) - scope_dict[key] = value - if self.server.verbose: - print(scope, self.server.cache[scope].keys()) - + self._put_value(scope, key, value) self.send_status_code(OK) def send_status_code(self, status_code): @@ -104,13 +96,24 @@ def send_status_code(self, status_code): def log_message(self, format, *args): pass + def _get_value(self, scope, key): + with self.server.cache_lock: + return self.server.cache.get(scope, {}).get(key) + + def _put_value(self, scope, key, value): + with self.server.cache_lock: + scope_dict = self.server.cache.setdefault(scope, {}) + scope_dict[key] = value + if self.server.verbose: + logging.info(scope, self.server.cache[scope].keys()) + class RendezvousHandler(KVStoreHandler): # Override DELETE handler def do_DELETE(self): paths = self.path.split('/') if len(paths) < 3: - print( + logging.error( 'Rendezvous ERROR: Invalid request path: {path}.'.format( path=self.path)) self.send_status_code(BAD_REQUEST) @@ -120,16 +123,25 @@ def do_DELETE(self): with self.server.finished_list_lock: self.server.finished_list[scope].append(key) + if self.server.scope_size[scope] == len(self.server.finished_list[scope]): + with self.server.cache_lock: + self.server.cache.get(scope, {}).clear() self.send_status_code(OK) -class RendezvousHTTPServer(BaseHTTPServer.HTTPServer, object): +class RendezvousHTTPServer(socketserver.ThreadingMixIn, BaseHTTPServer.HTTPServer, object): def __init__(self, addr, handler, verbose): # This class has to inherit from object since HTTPServer is an old-style # class that does not inherit from object. 
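# Editor's aside (not part of this patch; the address and port are made-up example
# values): the KV store handlers above speak plain HTTP with paths of the form
# /<scope>/<key>, so the client helpers in horovod.run.http.http_client map directly
# onto PUT and GET requests against this server.
from horovod.run.http.http_client import put_data_into_kvstore, read_data_from_kvstore

put_data_into_kvstore('10.0.0.1', 12345, 'runfunc', 'func', {'any': 'picklable value'})
print(read_data_from_kvstore('10.0.0.1', 12345, 'runfunc', 'func'))  # {'any': 'picklable value'}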
super(RendezvousHTTPServer, self).__init__(addr, handler) + # Cache that provides the store + self.cache_lock = threading.Lock() + self.cache = {} + + self.verbose = verbose + # Lists for finished rendezvous workers self.finished_list_lock = threading.Lock() self.finished_list = collections.defaultdict(list) @@ -137,76 +149,52 @@ def __init__(self, addr, handler, verbose): # Total size for scopes self.scope_size = {} - # Cache that provides the store - self.cache_lock = threading.Lock() - self.cache = {} + def init(self, host_alloc_plan): + with self.cache_lock: + self.cache.clear() - self.verbose = verbose + with self.finished_list_lock: + self.finished_list.clear() + + self.scope_size.clear() + self._extract_scope_size(host_alloc_plan) - def extract_scope_size(self, host_alloc_plan): + def _extract_scope_size(self, host_alloc_plan): for slot_info in host_alloc_plan: - self.scope_size['global'] = slot_info.size + self.scope_size['global_'] = slot_info.size cross_rank = slot_info.cross_rank self.scope_size['local_' + str(cross_rank)] = slot_info.local_size local_rank = slot_info.local_rank self.scope_size['cross_' + str(local_rank)] = slot_info.cross_size - # Decide whether all ranks have confirmed rendezvous completion. def should_continue(self): - should_continue = False - with self.finished_list_lock: - for scope, cnt in self.scope_size.items(): - if cnt > len(self.finished_list[scope]): - should_continue = True - return should_continue - - def handle_timeout(self): - error_msg = 'Rendezvous ERROR: Rendezvous server timeout after ' \ - '{time} seconds while waiting for all the ranks to send finalize ' \ - 'messages.\n'.format(time=TOTAL_TIMEOUT) - - for scope, finished_list in self.finished_list: - if self.scope_size[scope] > len(finished_list): - error_msg += 'Scope {scope} expects {size} workers, only received' \ - 'finalized message from [{ranks}].\n'.format( - scope=scope, - size=self.scope_size[scope], - ranks=' '.join(finished_list)) - - raise RuntimeError(error_msg) + return True class RendezvousServer: - def __init__(self, verbose): + def __init__(self, verbose=0): self.httpd = None self.listen_thread = None self.verbose = verbose # Rendezvous function finds a available port, create http socket, # and start listening loop to handle request - def start_server(self, host_alloc_plan): + # self.httpd.init needs to be called after server start + def start_server(self, handler_cls=RendezvousHandler): self.httpd, port = find_port( lambda addr: RendezvousHTTPServer( - addr, RendezvousHandler, self.verbose)) - self.httpd.extract_scope_size(host_alloc_plan) + addr, handler_cls, self.verbose)) if self.verbose: - print('Rendezvous INFO: HTTP rendezvous server started.') + logging.info('Rendezvous INFO: HTTP rendezvous server started.') # start the listening loop - self.listen_thread = in_thread(target=self.listen_loop) + self.listen_thread = in_thread(target=self.httpd.serve_forever) return port - # Listening loop for handle request - def listen_loop(self): - while self.httpd.should_continue(): - self.httpd.handle_request() - - self.httpd.server_close() - - if self.verbose: - print('Rendezvous INFO: Rendezvous finishes.') - # Because this thread is daemonized, no need to join. + def stop_server(self): + self.httpd.shutdown() + self.listen_thread.join() class KVStoreHTTPServer(BaseHTTPServer.HTTPServer, object): @@ -236,7 +224,7 @@ def start_server(self): self.listen_thread = in_thread(target=self.httpd.serve_forever) if self.verbose: - print('KVStoreServer INFO: KVStore server started. 
Listen on port ' + str(port)) + logging.info('KVStoreServer INFO: KVStore server started. Listen on port ' + str(port)) return port @@ -246,5 +234,5 @@ def shutdown_server(self): self.httpd.server_close() if self.verbose: - print('KVStoreServer INFO: KVStore server finishes.') + logging.info('KVStoreServer INFO: KVStore server finishes.') # Because this thread is daemonized, no need to join. diff --git a/horovod/run/run_task.py b/horovod/run/run_task.py index c8bab173cb..91c36bef35 100644 --- a/horovod/run/run_task.py +++ b/horovod/run/run_task.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -import cloudpickle + import sys + from horovod.run.common.util.env import get_env_rank_and_size from horovod.run.http.http_client import read_data_from_kvstore, put_data_into_kvstore def main(addr, port): - pickled_func = read_data_from_kvstore(addr, port, 'runfunc', 'func') - func = cloudpickle.loads(pickled_func) + func = read_data_from_kvstore(addr, port, 'runfunc', 'func') try: ret_val = func() except BaseException as e: @@ -28,8 +28,7 @@ def main(addr, port): raise e rank, size = get_env_rank_and_size() - pickled_ret_val = cloudpickle.dumps(ret_val) - put_data_into_kvstore(addr, port, 'runfunc_result', str(rank), pickled_ret_val) + put_data_into_kvstore(addr, port, 'runfunc_result', str(rank), ret_val) if __name__ == '__main__': diff --git a/horovod/run/runner.py b/horovod/run/runner.py index d5f2a86657..1955704000 100644 --- a/horovod/run/runner.py +++ b/horovod/run/runner.py @@ -17,6 +17,7 @@ import argparse import hashlib +import logging import os import re import sys @@ -29,7 +30,6 @@ import six import yaml -import cloudpickle import horovod @@ -39,8 +39,10 @@ from horovod.run.common.util import config_parser, safe_shell_exec, timeout, secret from horovod.run.common.util import settings as hvd_settings from horovod.run.driver import driver_service +from horovod.run.elastic import settings as elastic_settings +from horovod.run.elastic import discovery from horovod.run.util import cache, threads, network, lsf -from horovod.run.gloo_run import gloo_run +from horovod.run.gloo_run import gloo_run, gloo_run_elastic from horovod.run.mpi_run import mpi_run from horovod.run.js_run import js_run, is_jsrun_installed from horovod.run.http.http_client import read_data_from_kvstore, put_data_into_kvstore @@ -228,7 +230,8 @@ def parse_args(): np_arg = parser.add_argument('-np', '--num-proc', action='store', dest='np', type=int, required=not lsf.LSFUtils.using_lsf(), - help='Total number of training processes.') + help='Total number of training processes. In elastic mode, ' + 'number of processes required before training can start.') parser.add_argument('-cb', '--check-build', action=make_check_build_action(np_arg), nargs=0, help='Shows which frameworks and libraries have been built into Horovod.') @@ -347,6 +350,24 @@ def parse_args(): help='Regularization value [0, 1] applied to account for noise in samples. ' '(default: %(default)s)') + group_elastic = parser.add_argument_group('elastic arguments') + group_elastic.add_argument('--min-np', action='store', dest='min_np', type=int, + help='Minimum number of processes running for training to continue. If number of ' + 'available processes dips below this threshold, then training will wait for ' + 'more instances to become available. 
Defaults to --num-proc.') + group_elastic.add_argument('--max-np', action='store', dest='max_np', type=int, + help='Maximum number of training processes, beyond which no additional ' + 'processes will be created. If not specified, then will be unbounded.') + group_elastic.add_argument('--slots-per-host', action='store', dest='slots', type=int, + help='Number of slots for processes per host. Normally 1 slot per GPU per host. ' + 'If slots are provided by the output of the host discovery script, then ' + 'that value will override this parameter.') + group_elastic.add_argument('--elastic-timeout', action='store', dest='elastic_timeout', type=int, + help='Timeout for elastic initialisation after re-scaling the cluster. ' + 'The default value is 600 seconds. Alternatively, ' + 'the environment variable HOROVOD_ELASTIC_TIMEOUT ' + 'can also be used.') + group_timeline = parser.add_argument_group('timeline arguments') group_timeline.add_argument('--timeline-filename', action=make_override_action(override_args), help='JSON file containing timeline of Horovod events used for debugging ' @@ -427,6 +448,14 @@ def parse_args(): help='Path to a host file containing the list of host names and the number of ' 'available slots. Each line of the file must be of the form: ' '<hostname> slots=<slots>') + group_hosts.add_argument('--host-discovery-script', action=make_override_action(override_args), + help='Used for elastic training (autoscaling and fault tolerance). ' + 'An executable script that will print to stdout every available host (one per ' + 'newline character) that can be used to run worker processes. Optionally ' + 'specifies number of slots on the same line as the hostname as: "hostname:slots". ' + 'Providing a discovery script enables elastic training (see elastic arguments). ' + 'The job will fail immediately if execution of the script returns a non-zero exit ' + 'code on the first call. 
Subsequent calls will be retried until timeout.') group_controller_parent = parser.add_argument_group('controller arguments') group_controller = group_controller_parent.add_mutually_exclusive_group() @@ -449,6 +478,11 @@ def parse_args(): config_parser.set_args_from_config(args, config, override_args) config_parser.validate_config_args(args) + args.run_func = None + + if args.check_build: + check_build(args.verbose) + return args @@ -459,6 +493,7 @@ def __init__(self): self.ssh_port = None self.disable_cache = None self.start_timeout = None + self.nic = None self.output_filename = None self.verbose = None self.command = None @@ -471,7 +506,7 @@ def __init__(self): self.cycle_time_ms = None, self.cache_capacity = None, - # hierrachy + # hierarchy self.hierarchical_allreduce = None self.hierarchical_allgather = None @@ -483,6 +518,12 @@ def __init__(self): self.autotune_bayes_opt_max_samples = None self.autotune_gaussian_process_noise = None + # elastic arguments + self.min_np = None + self.max_np = None + self.slots = None + self.elastic_timeout = None + # timeline arguments self.timeline_filename = None self.timeline_mark_cycles = None @@ -508,6 +549,7 @@ def __init__(self): # host arguments self.hosts = None self.hostfile = None + self.host_discovery_script = None # controller arguments self.use_gloo = None @@ -532,40 +574,24 @@ def parse_host_files(filename): return ','.join(hosts) -def parse_host_names(hosts): +def parse_hosts_and_slots(hosts): + host_names = [] + host_to_slots = {} + host_list = hosts.split(',') - all_host_names = [] pattern = re.compile(r'^[\w.-]+:[0-9]+$') for host in host_list: if not pattern.match(host.strip()): raise ValueError('Invalid host input, please make sure it has ' 'format as : worker-0:2,worker-1:2.') - all_host_names.append(host.strip().split(':')[0]) - return all_host_names - - -def _run(args): - if args.check_build: - check_build(args.verbose) - - # If LSF is used, use default values from job config - if lsf.LSFUtils.using_lsf(): - if not args.np: - args.np = lsf.LSFUtils.get_num_processes() - if not args.hosts and not args.hostfile: - args.hosts = ','.join('{host}:{np}'.format(host=host, np=lsf.LSFUtils.get_num_gpus()) - for host in lsf.LSFUtils.get_compute_hosts()) + hostname, slots = host.strip().split(':') + host_names.append(hostname) + host_to_slots[hostname] = int(slots) + return host_names, host_to_slots - # if hosts are not specified, either parse from hostfile, or default as - # localhost - if not args.hosts: - if args.hostfile: - args.hosts = parse_host_files(args.hostfile) - else: - # Set hosts to localhost if not specified - args.hosts = 'localhost:{np}'.format(np=args.np) - all_host_names = parse_host_names(args.hosts) +def _run_static(args): + all_host_names, _ = parse_hosts_and_slots(args.hosts) nics_set = set(args.nics.split(',')) if args.nics else None @@ -587,10 +613,10 @@ def _run(args): tcp_flag=args.tcp_flag, binding_args=args.binding_args, key=secret.make_secret_key(), - timeout=tmout, - num_hosts=len(all_host_names), + start_timeout=tmout, num_proc=args.np, hosts=args.hosts, + num_hosts=len(all_host_names), output_filename=args.output_filename, run_func_mode=args.run_func is not None, nics=nics_set) @@ -631,32 +657,77 @@ def _run(args): if args.run_func: # get the driver IPv4 address - driver_ip = network._get_driver_ip(nics) + driver_ip = network.get_driver_ip(nics) run_func_server = KVStoreServer(verbose=settings.verbose) run_func_server_port = run_func_server.start_server() - pickled_exec_func = 
cloudpickle.dumps(args.run_func) put_data_into_kvstore(driver_ip, run_func_server_port, - 'runfunc', 'func', pickled_exec_func) + 'runfunc', 'func', args.run_func) command = [sys.executable, '-m', 'horovod.run.run_task', str(driver_ip), str(run_func_server_port)] try: - _launch_job(args, remote_host_names, settings, nics, command) + _launch_job(args, settings, nics, command) results = [None] * args.np # TODO: make it parallel to improve performance for i in range(args.np): - pickled_result = read_data_from_kvstore(driver_ip, run_func_server_port, - 'runfunc_result', str(i)) - results[i] = cloudpickle.loads(pickled_result) + results[i] = read_data_from_kvstore(driver_ip, run_func_server_port, + 'runfunc_result', str(i)) return results finally: run_func_server.shutdown_server() else: command = args.command - _launch_job(args, remote_host_names, settings, nics, command) + _launch_job(args, settings, nics, command) return None +def _run_elastic(args): + # construct host discovery component + if args.host_discovery_script: + discover_hosts = discovery.HostDiscoveryScript(args.host_discovery_script, args.slots) + elif args.hosts: + _, available_host_slots = parse_hosts_and_slots(args.hosts) + if len(available_host_slots) < 2: + raise ValueError('Cannot run in fault tolerance mode with fewer than 2 hosts.') + discover_hosts = discovery.FixedHosts(available_host_slots) + else: + raise ValueError('One of --host-discovery-script, --hosts, or --hostnames must be provided') + + # horovodrun has to finish all the checks before this timeout runs out. + if args.start_timeout: + start_timeout = args.start_timeout + else: + # Lookup default timeout from the environment variable. + start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30')) + + tmout = timeout.Timeout(start_timeout, + message='Timed out waiting for {activity}. Please ' + 'check connectivity between servers. You ' + 'may need to increase the --start-timeout ' + 'parameter if you have too many servers.') + settings = elastic_settings.ElasticSettings(discovery=discover_hosts, + num_proc=args.np, + min_np=args.min_np or args.np, + max_np=args.max_np, + verbose=2 if args.verbose else 0, + ssh_port=args.ssh_port, + extra_mpi_args=args.mpi_args, + key=secret.make_secret_key(), + start_timeout=tmout, + elastic_timeout=args.elastic_timeout, + output_filename=args.output_filename, + run_func_mode=args.run_func is not None, + nics=args.nics) + + if not gloo_built(verbose=(settings.verbose >= 2)): + raise ValueError('Gloo support is required to use elastic training, but has not been built. 
Ensure CMake is ' + 'installed and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.') + + env = os.environ.copy() + config_parser.set_env_from_args(env, args) + gloo_run_elastic(settings, env, args.command) + + def is_gloo_used(use_gloo=None, use_mpi=None, use_jsrun=None): # determines whether run_controller will run gloo # for the given (use_gloo, _, use_mpi, _, use_jsrun, _, _) @@ -697,13 +768,17 @@ def run_controller(use_gloo, gloo_run, use_mpi, mpi_run, use_jsrun, js_run, verb 'either MPI is installed (MPI) or CMake is installed (Gloo).') -def _launch_job(args, remote_host_names, settings, nics, command): +def _is_elastic(args): + return args.host_discovery_script is not None or args.min_np is not None + + +def _launch_job(args, settings, nics, command): env = os.environ.copy() config_parser.set_env_from_args(env, args) def gloo_run_fn(): - driver_ip = network._get_driver_ip(nics) - gloo_run(settings, remote_host_names, nics, env, driver_ip, command) + driver_ip = network.get_driver_ip(nics) + gloo_run(settings, nics, env, driver_ip, command) def mpi_run_fn(): mpi_run(settings, nics, env, command) @@ -717,9 +792,37 @@ def js_run_fn(): args.verbose) +def _run(args): + # If LSF is used, use default values from job config + if lsf.LSFUtils.using_lsf(): + if not args.np: + args.np = lsf.LSFUtils.get_num_processes() + if not args.hosts and not args.hostfile and not args.host_discovery_script: + args.hosts = ','.join('{host}:{np}'.format(host=host, np=lsf.LSFUtils.get_num_gpus()) + for host in lsf.LSFUtils.get_compute_hosts()) + + # if hosts are not specified, either parse from hostfile, or default as + # localhost + if not args.hosts and not args.host_discovery_script: + if args.hostfile: + args.hosts = parse_host_files(args.hostfile) + else: + # Set hosts to localhost if not specified + args.hosts = 'localhost:{np}'.format(np=args.np) + + if _is_elastic(args): + return _run_elastic(args) + else: + return _run_static(args) + + def run_commandline(): args = parse_args() - args.run_func = None + + if args.log_level: + logging.addLevelName(logging.NOTSET, 'TRACE') + logging.basicConfig(level=logging.getLevelName(args.log_level)) + _run(args) @@ -728,6 +831,9 @@ def run( args=(), kwargs=None, np=1, + min_np=None, + max_np=None, + slots=None, hosts=None, hostfile=None, start_timeout=None, @@ -748,6 +854,15 @@ def run( :param args: Arguments to pass to `func`. :param kwargs: Keyword arguments to pass to `func`. :param np: Number of Horovod processes. + :param min_np: Minimum number of processes running for training to continue. If number of + available processes dips below this threshold, then training will wait for + more instances to become available. Defaults to np + :param max_np: Maximum number of training processes, beyond which no additional processes + will be created. If not specified, then will be unbounded. + :param slots: Number of slots for processes per host. Normally 1 slot per GPU per host. + If slots are provided by the output of the host discovery script, then that + value will override this parameter. 
+ :param hosts: List of host names and the number of available slots for running processes on each, of the form: : (e.g.: host1:2,host2:4,host3:1 indicating 2 processes can run on host1, @@ -799,6 +914,9 @@ def wrapped_func(): hargs = HorovodArgs() hargs.np = np + hargs.min_np = min_np + hargs.max_np = max_np + hargs.slots = slots hargs.hosts = hosts hargs.hostfile = hostfile hargs.start_timeout = start_timeout diff --git a/horovod/run/task_fn.py b/horovod/run/task_fn.py index 8f26afd092..13e67b83f4 100644 --- a/horovod/run/task_fn.py +++ b/horovod/run/task_fn.py @@ -20,7 +20,7 @@ from horovod.run.task import task_service -def _task_fn(index, driver_addresses, settings): +def _task_fn(index, num_hosts, driver_addresses, settings): task = task_service.HorovodRunTaskService(index, settings.key, settings.nics) try: driver = driver_service.HorovodRunDriverClient( @@ -28,10 +28,10 @@ def _task_fn(index, driver_addresses, settings): driver.register_task(index, task.addresses(), host_hash.host_hash()) - task.wait_for_initial_registration(settings.timeout) + task.wait_for_initial_registration(settings.start_timeout) # Tasks ping each other in a circular fashion to determine interfaces # reachable within the cluster. - next_task_index = (index + 1) % settings.num_hosts + next_task_index = (index + 1) % num_hosts next_task_addresses = driver.all_task_addresses(next_task_index) # We request interface matching to weed out all the NAT'ed interfaces. next_task = task_service.HorovodRunTaskClient( @@ -47,21 +47,20 @@ def _task_fn(index, driver_addresses, settings): next_task.task_to_task_address_check_completed() # Wait to get a notification from previous task that its address checks # are completed as well. - task.wait_for_task_to_task_address_check_finish_signal(settings.timeout) + task.wait_for_task_to_task_address_check_finish_signal(settings.start_timeout) finally: task.shutdown() if __name__ == '__main__': - if len(sys.argv) != 4: - print( - 'Usage: %s ' % - sys.argv[0]) + if len(sys.argv) != 5: + print('Usage: {} '.format(sys.argv[0])) sys.exit(1) index = codec.loads_base64(sys.argv[1]) - driver_addresses = codec.loads_base64(sys.argv[2]) - settings = codec.loads_base64(sys.argv[3]) + num_hosts = codec.loads_base64(sys.argv[2]) + driver_addresses = codec.loads_base64(sys.argv[3]) + settings = codec.loads_base64(sys.argv[4]) - _task_fn(index, driver_addresses, settings) + _task_fn(index, num_hosts, driver_addresses, settings) diff --git a/horovod/run/util/network.py b/horovod/run/util/network.py index 47a8a4e5d1..39912dbf91 100644 --- a/horovod/run/util/network.py +++ b/horovod/run/util/network.py @@ -13,6 +13,8 @@ # limitations under the License. 
# ============================================================================== +from __future__ import absolute_import + import psutil import random import socket @@ -20,33 +22,45 @@ from socket import AF_INET from psutil import net_if_addrs +from horovod.common.util import _cache from horovod.run.util import threads -def _get_local_host_addresses(): - local_addresses = [] + +@_cache +def get_local_host_addresses(): + local_addresses = set() for intf_info_list in psutil.net_if_addrs().values(): for intf_info in intf_info_list: if intf_info.family == socket.AF_INET: - local_addresses.append(intf_info.address) + local_addresses.add(intf_info.address) return local_addresses -def get_local_host_intfs(): - return set(psutil.net_if_addrs().keys()) +def get_local_intfs(nic=None): + common_intfs = set() + for iface, addrs in net_if_addrs().items(): + if nic and iface != nic: + continue + for addr in addrs: + if addr.family == AF_INET and addr.address == '127.0.0.1': + common_intfs.add(iface) + break + return common_intfs -def filter_local_addresses(all_host_names): - local_addresses = _get_local_host_addresses() +def resolve_host_address(host_name): + try: + return socket.gethostbyname(host_name) + except socket.gaierror: + return None - def resolve_host_name(host_name): - try: - return socket.gethostbyname(host_name) - except socket.gaierror: - return None + +def filter_local_addresses(all_host_names): + local_addresses = get_local_host_addresses() args_list = [[host] for host in all_host_names] host_addresses = threads.execute_function_multithreaded( - resolve_host_name, args_list) + resolve_host_address, args_list) # host_addresses is a map remote_host_names = [] @@ -78,7 +92,7 @@ def find_port(server_factory): raise Exception('Unable to find a port to bind to.') -def _get_driver_ip(nics): +def get_driver_ip(nics): """ :param nics: object return by `_driver_fn` :return: driver ip. We make sure all workers can connect to this ip. diff --git a/horovod/spark/driver/rsh.py b/horovod/spark/driver/rsh.py index 54479a6731..5b477a28f5 100644 --- a/horovod/spark/driver/rsh.py +++ b/horovod/spark/driver/rsh.py @@ -22,7 +22,7 @@ def rsh(driver_addresses, key, settings, host_hash, command, env, local_rank, - background=True, event=None): + background=True, events=None): """ Method to run a command remotely given a host hash, local rank and driver addresses. @@ -43,7 +43,7 @@ def rsh(driver_addresses, key, settings, host_hash, command, env, local_rank, :param env: environment to use :param local_rank: local rank on the host of task to run the command in :param background: run command in background if True, returns command result otherwise - :param event: event to abort the command, only if background is True + :param events: events to abort the command, only if background is True """ if ':' in host_hash: raise Exception('Illegal host hash provided. 
Are you using Open MPI 4.0.0+?') @@ -59,7 +59,8 @@ def rsh(driver_addresses, key, settings, host_hash, command, env, local_rank, if not background: stop = None - if event is not None: + events = events or [] + for event in events: stop = threading.Event() on_event(event, task_client.abort_command, stop=stop) diff --git a/horovod/spark/gloo_run.py b/horovod/spark/gloo_run.py index a77dd9ffb0..4b4b7e7c66 100644 --- a/horovod/spark/gloo_run.py +++ b/horovod/spark/gloo_run.py @@ -17,15 +17,15 @@ import time from horovod.run.gloo_run import launch_gloo -from horovod.run.common.util import codec, secret +from horovod.run.common.util import codec from horovod.spark.driver.rsh import rsh def _exec_command_fn(driver_addresses, key, settings, env): - def _exec_command(command, alloc_info, event): - host = alloc_info.hostname - local_rank = alloc_info.local_rank - result = rsh(driver_addresses, key, settings, host, command, env, local_rank, False, event) + def _exec_command(command, slot_info, events): + host = slot_info.hostname + local_rank = slot_info.local_rank + result = rsh(driver_addresses, key, settings, host, command, env, local_rank, False, events) return result, time.time() return _exec_command diff --git a/horovod/spark/runner.py b/horovod/spark/runner.py index eae28046fb..421e844b40 100644 --- a/horovod/spark/runner.py +++ b/horovod/spark/runner.py @@ -47,14 +47,14 @@ def _task_fn(index, driver_addresses, key, settings, use_gloo): try: driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose) driver_client.register_task(index, task.addresses(), host_hash.host_hash()) - task.wait_for_initial_registration(settings.timeout) + task.wait_for_initial_registration(settings.start_timeout) task_indices_on_this_host = driver_client.task_host_hash_indices(host_hash.host_hash()) # With Gloo all tasks wait for the command # With MPI task with first index executes orted which will run mpirun_exec_fn for all tasks. 
minimum_lifetime_after_start = None if use_gloo or task_indices_on_this_host[0] == index: - task.wait_for_command_start(settings.timeout) + task.wait_for_command_start(settings.start_timeout) minimum_lifetime_after_start = timeout.Timeout(MINIMUM_COMMAND_LIFETIME_S, message='Just measuring runtime') task.wait_for_command_termination() @@ -170,7 +170,7 @@ def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None, settings = hvd_settings.Settings(verbose=verbose, extra_mpi_args=extra_mpi_args, key=secret.make_secret_key(), - timeout=tmout, + start_timeout=tmout, nics=nics, run_func_mode=True) @@ -239,7 +239,7 @@ def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None, def _notify_and_register_task_addresses(driver, settings): # wait for num_proc tasks to register - driver.wait_for_initial_registration(settings.timeout) + driver.wait_for_initial_registration(settings.start_timeout) if settings.verbose >= 2: print('Initial Spark task registration is complete.') @@ -256,7 +256,7 @@ def notify_and_register(index): for index in range(settings.num_proc): in_thread(notify_and_register, (index,)) - driver.wait_for_task_to_task_address_updates(settings.timeout) + driver.wait_for_task_to_task_address_updates(settings.start_timeout) if settings.verbose >= 2: print('Spark task-to-task address registration is complete.') diff --git a/horovod/tensorflow/__init__.py b/horovod/tensorflow/__init__.py index 9b70f0a4f7..f3d096ab2f 100644 --- a/horovod/tensorflow/__init__.py +++ b/horovod/tensorflow/__init__.py @@ -20,11 +20,16 @@ from __future__ import division from __future__ import print_function +import os +import warnings + from horovod.common.util import check_extension, gpu_available check_extension('horovod.tensorflow', 'HOROVOD_WITH_TENSORFLOW', __file__, 'mpi_lib') +from horovod.tensorflow import elastic from horovod.tensorflow.compression import Compression +from horovod.tensorflow.functions import broadcast_object, broadcast_object_fn, broadcast_variables from horovod.tensorflow.mpi_ops import allgather, broadcast, _allreduce from horovod.tensorflow.mpi_ops import init, shutdown from horovod.tensorflow.mpi_ops import size, local_size, rank, local_rank, is_homogeneous @@ -33,11 +38,9 @@ from horovod.tensorflow.mpi_ops import nccl_built, ddl_built, ccl_built from horovod.tensorflow.mpi_ops import Average, Sum, Adasum from horovod.tensorflow.mpi_ops import handle_average_backwards_compatibility, check_num_rank_power_of_2 - from horovod.tensorflow.util import _executing_eagerly, _make_subgraph, _cache import tensorflow as tf -import warnings def allreduce(tensor, average=None, device_dense='', device_sparse='', @@ -82,7 +85,7 @@ def allreduce(tensor, average=None, device_dense='', device_sparse='', 'workaround please pass sparse_as_dense=True to DistributedOptimizer') with tf.device(device_sparse): # For IndexedSlices, do two allgathers instead of an allreduce. 
- horovod_size = tf.cast(size(), tensor.values.dtype) + horovod_size = tf.cast(size(), dtype=tensor.values.dtype) values = allgather(tensor.values) indices = allgather(tensor.indices) @@ -122,36 +125,6 @@ def allreduce(tensor, average=None, device_dense='', device_sparse='', return new_tensor -@_cache -def _make_broadcast_group_fn(): - if _executing_eagerly(): - # Eager mode will parallelize independent control flow - def broadcast_group(variables, root_rank): - for var in variables: - var.assign(broadcast(var, root_rank)) - - return _make_subgraph(broadcast_group) - else: - # Graph mode requires an Op - def broadcast_group(variables, root_rank): - return tf.group(*[var.assign(broadcast(var, root_rank)) - for var in variables]) - - return broadcast_group - - -def broadcast_variables(variables, root_rank): - """Broadcasts variables from root rank to all other processes. - - Arguments: - variables: variables for broadcast - root_rank: rank of the process from which global variables will be broadcasted - to all other processes. - """ - broadcast_group = _make_broadcast_group_fn() - return broadcast_group(variables, root_rank) - - try: _global_variables = tf.global_variables except AttributeError: @@ -291,7 +264,7 @@ def compute_gradients(self, *args, **kwargs): allreduce the gradients before returning them. """ gradients = self._optimizer.compute_gradients(*args, **kwargs) - if size() > 1: + if size() > 1 or os.environ.get('HOROVOD_ELASTIC') == '1': grads, vars = zip(*gradients) avg_grads = self._allreduce_grads(grads) return list(zip(avg_grads, vars)) @@ -491,7 +464,7 @@ def __init__(self, tape, device_dense, device_sparse, compression, sparse_as_den def gradient(self, target, sources, output_gradients=None): gradients = super(self.__class__, self).gradient(target, sources, output_gradients) - if size() > 1: + if size() > 1 or os.environ.get('HOROVOD_ELASTIC') == '1': return self._allreduce_grads(gradients) else: return gradients diff --git a/horovod/tensorflow/elastic.py b/horovod/tensorflow/elastic.py new file mode 100644 index 0000000000..3f1acfb16c --- /dev/null +++ b/horovod/tensorflow/elastic.py @@ -0,0 +1,212 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +from distutils.version import LooseVersion + +import tensorflow as tf + +from tensorflow.python.framework import ops + +from horovod.common.elastic import run_fn, ObjectState +from horovod.common.exceptions import HorovodInternalError +from horovod.tensorflow.functions import broadcast_object, broadcast_object_fn, broadcast_variables +from horovod.tensorflow.mpi_ops import _executing_eagerly, init, rank, shutdown + + +_IS_TF2 = LooseVersion(tf.__version__) >= LooseVersion('2.0.0') + + +def run(func): + """Decorator used to run the elastic training process. 
+ + The purpose of this decorator is to allow for uninterrupted execution of the wrapped function + across multiple workers in parallel, as workers come and go from the system. When a new worker is added, + its state needs to be brought to the same point as the other workers, which is done by synchronizing + the state object before executing `func`. + + When a worker is added or removed, other workers will raise an exception to bring them back to such a sync + point before executing `func` again. This ensures that workers do not diverge when such reset events occur. + + It's important to note that collective operations (e.g., broadcast, allreduce) cannot be the call to + the wrapped function. Otherwise, new workers could execute these operations during their initialization + while other workers are attempting to sync state, resulting in deadlock. + + Args: + func: a wrapped function taking any number of args or kwargs. The first argument + must be a `horovod.common.elastic.State` object used to synchronize state across + workers. + """ + from tensorflow.python.framework.errors_impl import UnknownError + + def wrapper(state, *args, **kwargs): + try: + return func(state, *args, **kwargs) + except UnknownError as e: + if 'HorovodAllreduce' in e.message or \ + 'HorovodAllgather' in e.message or \ + 'HorovodBroadcast' in e.message: + raise HorovodInternalError(e) + return run_fn(wrapper, _reset) + + +def _reset(): + shutdown() + init() + + +def _broadcast_model(model, optimizer, backend): + if _executing_eagerly(): + # TensorFlow 2.0 or TensorFlow eager + broadcast_variables(model.variables, root_rank=0) + broadcast_variables(optimizer.variables(), root_rank=0) + else: + bcast_op = broadcast_variables(_global_variables(), root_rank=0) + backend.get_session().run(bcast_op) + + +def _model_built(model): + return model.built if hasattr(model, 'build') else True + + +def _global_variables(): + return tf.global_variables() if not _IS_TF2 else tf.compat.v1.global_variables() + + +def _default_session(): + return ops.get_default_session() if not _IS_TF2 else None + + +class TensorFlowKerasState(ObjectState): + """State representation of a TensorFlow Keras model and optimizer. + + Supports TensorFlow 2 models and optimizers, as well as `keras` and `tf.keras`. + + Args: + model: TensorFlow Keras model. + optimizer: Optional optimizer, can be compiled into model instead. + backend: For TensorFlow v1, backend used by Keras for obtaining the session. + kwargs: Additional properties to sync, will be exposed as attributes of the object. + """ + def __init__(self, model, optimizer=None, backend=None, **kwargs): + self.model = model + if not _model_built(model): + raise ValueError('Model must be built first. 
Run `model.build(input_shape)`.') + + self.optimizer = optimizer or model.optimizer + self.backend = backend + self._save_model() + + def broadcast_object_with_session(obj): + return broadcast_object(obj, session=backend.get_session()) + + broadcast_object_fn = broadcast_object if not backend or _executing_eagerly() else broadcast_object_with_session + + super(TensorFlowKerasState, self).__init__(bcast_object=broadcast_object_fn, + get_rank=rank, + **kwargs) + + def save(self): + self._save_model() + super(TensorFlowKerasState, self).save() + + def restore(self): + self._load_model() + super(TensorFlowKerasState, self).restore() + + def sync(self): + _broadcast_model(self.model, self.optimizer, backend=self.backend) + self._save_model() + super(TensorFlowKerasState, self).sync() + + def _save_model(self): + if _executing_eagerly(): + self._saved_model_state = [tf.identity(var) for var in self.model.variables] + self._saved_optimizer_state = [tf.identity(var) for var in self.optimizer.variables()] + else: + self._saved_model_state = self.model.get_weights() + self._saved_optimizer_state = self.optimizer.get_weights() + + def _load_model(self): + if _executing_eagerly(): + for var, saved_var in zip(self.model.variables, self._saved_model_state): + var.assign(saved_var) + for var, saved_var in zip(self.optimizer.variables(), self._saved_optimizer_state): + var.assign(saved_var) + else: + self.model.set_weights(self._saved_model_state) + self.optimizer.set_weights(self._saved_optimizer_state) + + +class TensorFlowState(ObjectState): + """State representation of a list of TensorFlow variables. + + Supports both TensorFlow v1 and v2. For TensorFlow v2, can only be used when eager execution is enabled. + + Args: + variables: List of `tf.Variable` objects to track (default: `tf.global_variables()`). + session: For TensorFlow v1, session used to materialize variables (default: `ops.get_default_session()`). + kwargs: Additional properties to sync, will be exposed as attributes of the object. 
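A minimal sketch of how the `run` decorator and `TensorFlowState` described above fit together in an eager (TF2) training loop; `model`, `train_step`, and `num_batches` are illustrative names rather than part of this patch:

.. code-block:: python

    import tensorflow as tf
    import horovod.tensorflow as hvd

    hvd.init()

    @hvd.elastic.run
    def train(state):
        # state.batch is one of the extra **kwargs tracked by the state object.
        for state.batch in range(state.batch, num_batches):
            train_step(model)          # user-defined step issuing Horovod allreduces
            if state.batch % 10 == 0:
                state.commit()         # roll-back point used after a reset event

    state = hvd.elastic.TensorFlowState(variables=model.variables, batch=0)
    train(state)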
+ """ + def __init__(self, variables=None, session=None, **kwargs): + self.variables = variables or _global_variables() + self.session = session or _default_session() + self._bcast_op = broadcast_variables(self.variables, root_rank=0) + self._eval_fn = self._to_numpy if _executing_eagerly() else self._eval_var + self._assign_fn = self._assign_var if _IS_TF2 else self._load_var + self._save_model() + + bcast_obj = broadcast_object_fn(session=session) if not _executing_eagerly() else broadcast_object + + def broadcast_object_with_session(obj): + return bcast_obj(obj) + + super(TensorFlowState, self).__init__(bcast_object=broadcast_object_with_session, + get_rank=rank, + **kwargs) + + def save(self): + self._save_model() + super(TensorFlowState, self).save() + + def restore(self): + self._load_model() + super(TensorFlowState, self).restore() + + def sync(self): + if self.session is not None: + self.session.run(self._bcast_op) + self._save_model() + super(TensorFlowState, self).sync() + + def _save_model(self): + self._values = [self._eval_fn(var) for var in self.variables] + + def _eval_var(self, var): + return var.eval(self.session) + + def _to_numpy(self, var): + return var.numpy() + + def _load_model(self): + for var, value in zip(self.variables, self._values): + self._assign_fn(var, value) + + def _load_var(self, var, value): + var.load(value, self.session) + + def _assign_var(self, var, value): + var.assign(value) diff --git a/horovod/tensorflow/functions.py b/horovod/tensorflow/functions.py new file mode 100644 index 0000000000..0ffcea5765 --- /dev/null +++ b/horovod/tensorflow/functions.py @@ -0,0 +1,137 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +import io + +from distutils.version import LooseVersion + +import cloudpickle +import numpy as np +import tensorflow as tf + +from tensorflow.python.framework import ops + +from horovod.tensorflow.mpi_ops import broadcast +from horovod.tensorflow.mpi_ops import local_size, rank, size +from horovod.tensorflow.util import _cache, _executing_eagerly, _make_subgraph + + +@_cache +def _make_broadcast_group_fn(): + if _executing_eagerly(): + # Eager mode will parallelize independent control flow + def broadcast_group(variables, root_rank): + for var in variables: + var.assign(broadcast(var, root_rank)) + + return _make_subgraph(broadcast_group) + else: + # Graph mode requires an Op + def broadcast_group(variables, root_rank): + return tf.group(*[var.assign(broadcast(var, root_rank)) + for var in variables]) + + return broadcast_group + + +def broadcast_variables(variables, root_rank): + """Broadcasts variables from root rank to all other processes. + + Arguments: + variables: variables for broadcast + root_rank: rank of the process from which global variables will be broadcasted + to all other processes. 
+ """ + broadcast_group = _make_broadcast_group_fn() + return broadcast_group(variables, root_rank) + + +def broadcast_object(obj, root_rank=0, session=None, name=None): + """ + Serializes and broadcasts an object from root rank to all other processes. + + Arguments: + obj: An object capable of being serialized without losing any context. + root_rank: The rank of the process from which parameters will be + broadcasted to all other processes. + session: Session for TensorFlow v1 compatibility. + name: Optional name to use during broadcast, will default to the class + type. + Returns: + The object that was broadcast from the `root_rank`. + """ + if name is None: + name = type(obj).__name__ + + def to_numpy(v): + if not _executing_eagerly(): + sess = session or ops.get_default_session() + return sess.run(v) + else: + return v.numpy() + + if rank() == root_rank: + b = io.BytesIO() + cloudpickle.dump(obj, b) + t = tf.convert_to_tensor(bytearray(b.getvalue()), dtype=tf.uint8) + sz = tf.convert_to_tensor([t.shape[0]], dtype=tf.int32) + to_numpy(broadcast(sz, root_rank, name + '.sz')) + else: + sz = tf.convert_to_tensor([0], dtype=tf.int32) + sz = to_numpy(broadcast(sz, root_rank, name + '.sz')) + t = tf.zeros(sz.tolist()[0], dtype=tf.uint8) + + t = to_numpy(broadcast(t, root_rank, name + '.t')) + + if rank() != root_rank: + buf = io.BytesIO(t.tobytes()) + obj = cloudpickle.load(buf) + + return obj + + +def broadcast_object_fn(root_rank=0, session=None, name=None): + name = name or 'broadcast_object_fn' + + sz = tf.placeholder(tf.int32, [1], name='bcast_object_size') + bcast_size = broadcast(sz, root_rank, name + '.sz') + + t = tf.placeholder(tf.uint8, [None], name='bcast_object_data') + bcast_data = broadcast(t, root_rank, name + '.t') + + session = session or ops.get_default_session() + + def _bcast(obj): + if rank() == root_rank: + b = io.BytesIO() + cloudpickle.dump(obj, b) + t_ = bytearray(b.getvalue()) + sz_ = [len(t_)] + session.run(bcast_size, feed_dict={sz: sz_}) + else: + sz_ = [0] + sz_ = session.run(bcast_size, feed_dict={sz: sz_}) + t_ = np.zeros(sz_, dtype=np.uint8) + + t_ = session.run(bcast_data, feed_dict={t: t_}) + + if rank() != root_rank: + buf = io.BytesIO(t_.tobytes()) + obj = cloudpickle.load(buf) + + return obj + return _bcast diff --git a/horovod/tensorflow/keras/__init__.py b/horovod/tensorflow/keras/__init__.py index aa177edd3e..40ea497a3a 100644 --- a/horovod/tensorflow/keras/__init__.py +++ b/horovod/tensorflow/keras/__init__.py @@ -37,7 +37,7 @@ from horovod.tensorflow import Compression import horovod._keras as _impl -from horovod.tensorflow.keras import callbacks +from horovod.tensorflow.keras import callbacks, elastic try: diff --git a/horovod/tensorflow/keras/elastic.py b/horovod/tensorflow/keras/elastic.py new file mode 100644 index 0000000000..f0ff89f466 --- /dev/null +++ b/horovod/tensorflow/keras/elastic.py @@ -0,0 +1,85 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import + +import tensorflow as tf + +from horovod._keras import elastic as _impl +from horovod.tensorflow.elastic import TensorFlowKerasState + + +class KerasState(TensorFlowKerasState): + """State representation of a `tf.keras` model and optimizer. + + Args: + model: Keras model. + optimizer: Optional optimizer, can be compiled into model instead. + kwargs: Additional properties to sync, will be exposed as attributes of the object. + """ + def __init__(self, model, optimizer=None, **kwargs): + super(KerasState, self).__init__(model, optimizer=optimizer, backend=tf.keras.backend, **kwargs) + + +class CommitStateCallback(_impl.CommitStateCallbackImpl, tf.keras.callbacks.Callback): + """ + Keras Callback that will commit the `state` object every `batches_per_commit` + batches at the end of each batch. + """ + + def __init__(self, state, batches_per_commit=1): + """ + Constructs a new CommitStateCallback. + + Args: + state: `horovod.common.elastic.State` object to be committed. + batches_per_commit: Number of batches to complete between each commit (default: 1). + """ + super(CommitStateCallback, self).__init__(tf.keras.backend, state, batches_per_commit) + + +class UpdateBatchStateCallback(_impl.UpdateBatchStateCallbackImpl, tf.keras.callbacks.Callback): + """ + Keras Callback that will update the value of `state.batch` with the current batch number at + the end of each batch. Batch will reset to 0 at the end of each epoch. + + If `steps_per_epoch` is set, then this callback will also ensure that the number of steps + in the first epoch following a reset is shortened by the number of batches already processed. + """ + + def __init__(self, state): + """ + Constructs a new UpdateBatchStateCallback. + + Args: + state: `horovod.common.elastic.State` object to be updated. + """ + super(UpdateBatchStateCallback, self).__init__(tf.keras.backend, state) + + +class UpdateEpochStateCallback(_impl.UpdateEpochStateCallbackImpl, tf.keras.callbacks.Callback): + """ + Keras Callback that will update the value of `state.epoch` with the current epoch number at + the end of each epoch. + """ + + def __init__(self, state): + """ + Constructs a new UpdateEpochStateCallback. + + Args: + state: `horovod.common.elastic.State` object to be updated. + """ + super(UpdateEpochStateCallback, self).__init__(tf.keras.backend, state) diff --git a/horovod/tensorflow/util.py b/horovod/tensorflow/util.py index 19fc92734c..64f6ff9967 100644 --- a/horovod/tensorflow/util.py +++ b/horovod/tensorflow/util.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
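A sketch of wiring the callbacks above into `model.fit`, assuming a compiled `tf.keras` model and the elastic `run` decorator from `horovod.tensorflow.elastic` shown earlier; `dataset`, `steps_per_epoch`, and `total_epochs` are illustrative:

.. code-block:: python

    import horovod.tensorflow.keras as hvd
    from horovod.tensorflow.elastic import run as elastic_run

    state = hvd.elastic.KerasState(model, batch=0, epoch=0)

    callbacks = [
        hvd.elastic.CommitStateCallback(state, batches_per_commit=10),
        hvd.elastic.UpdateBatchStateCallback(state),
        hvd.elastic.UpdateEpochStateCallback(state),
    ]

    @elastic_run
    def train(state):
        # The callbacks keep state.batch/state.epoch current and commit
        # periodically, so a restarted worker resumes near the others.
        state.model.fit(dataset,
                        steps_per_epoch=steps_per_epoch,
                        initial_epoch=state.epoch,
                        epochs=total_epochs,
                        callbacks=callbacks,
                        verbose=1 if hvd.rank() == 0 else 0)

    train(state)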
# ============================================================================== + +from __future__ import absolute_import + from distutils.version import LooseVersion import tensorflow as tf diff --git a/horovod/torch/__init__.py b/horovod/torch/__init__.py index 935999d558..5faae7949d 100644 --- a/horovod/torch/__init__.py +++ b/horovod/torch/__init__.py @@ -18,13 +18,6 @@ from __future__ import division from __future__ import print_function -from contextlib import contextmanager - -import io -import warnings - -import cloudpickle - from horovod.common.util import check_extension try: @@ -34,13 +27,9 @@ check_extension('horovod.torch', 'HOROVOD_WITH_PYTORCH', __file__, 'mpi_lib', '_mpi_lib') -try: - from collections.abc import Iterable -except ImportError: - from collections import Iterable - - +from horovod.torch import elastic from horovod.torch.compression import Compression +from horovod.torch.functions import broadcast_object, broadcast_optimizer_state, broadcast_parameters from horovod.torch.mpi_ops import allreduce, allreduce_async, allreduce_, allreduce_async_ from horovod.torch.mpi_ops import allgather, allgather_async from horovod.torch.mpi_ops import broadcast, broadcast_async, broadcast_, broadcast_async_ @@ -52,597 +41,11 @@ from horovod.torch.mpi_ops import gloo_enabled, gloo_built from horovod.torch.mpi_ops import nccl_built, ddl_built, ccl_built from horovod.torch.mpi_ops import Average, Sum, Adasum +from horovod.torch.optimizer import DistributedOptimizer from horovod.torch.sync_batch_norm import SyncBatchNorm -import torch -import collections - # Please run this function in a subprocess def _check_has_gpu(): import torch return torch.cuda.is_available() - - -class _DistributedOptimizer(torch.optim.Optimizer): - def __init__(self, params, named_parameters, compression, - backward_passes_per_step=1, op=Average): - super(self.__class__, self).__init__(params) - self._compression = compression - - if named_parameters is not None: - named_parameters = list(named_parameters) - else: - named_parameters = [('allreduce.noname.%s' % i, v) - for param_group in self.param_groups - for i, v in enumerate(param_group['params'])] - # make sure that named_parameters are tuples - if any([not isinstance(p, tuple) for p in named_parameters]): - raise ValueError('named_parameters should be a sequence of ' - 'tuples (name, parameter), usually produced by ' - 'model.named_parameters().') - - dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters]) - if len(dups) > 0: - raise ValueError('Parameter names in named_parameters must be unique. ' - 'Found duplicates: %s' % ', '.join(dups)) - - all_param_ids = {id(v) - for param_group in self.param_groups - for v in param_group['params']} - named_param_ids = {id(v) for k, v in named_parameters} - unnamed_param_ids = all_param_ids - named_param_ids - if len(unnamed_param_ids): - raise ValueError('named_parameters was specified, but one or more model ' - 'parameters were not named. 
Python object ids: ' - '%s' % ', '.join(str(id) for id in unnamed_param_ids)) - - self._parameter_names = {v: k for k, v in sorted(named_parameters)} - self.backward_passes_per_step = backward_passes_per_step - self._allreduce_delay = {v: self.backward_passes_per_step - for _, v in sorted(named_parameters)} - self.op = op - self._handles = {} - self._grad_accs = [] - self._requires_update = set() - self._synchronized = False - self._should_synchronize = True - if size() > 1: - self._register_hooks() - - @staticmethod - def find_duplicates(lst): - seen = set() - dups = set() - for el in lst: - if el in seen: - dups.add(el) - seen.add(el) - return dups - - def set_backward_passes_per_step(self, passes): - self.backward_passes_per_step = passes - for p in self._allreduce_delay: - self._allreduce_delay[p] = self.backward_passes_per_step - - def _register_hooks(self): - for param_group in self.param_groups: - for p in param_group['params']: - if p.requires_grad: - p.grad = p.data.new(p.size()).zero_() - self._requires_update.add(p) - p_tmp = p.expand_as(p) - grad_acc = p_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook(self._make_hook(p)) - self._grad_accs.append(grad_acc) - - def _allreduce_grad_async(self, p): - name = self._parameter_names.get(p) - tensor = p.grad - tensor_compressed, ctx = self._compression.compress(tensor) - - handle = allreduce_async_(tensor_compressed, name=name, op=self.op) - return handle, ctx - - def _make_hook(self, p): - def hook(*ignore): - if p in self._handles and self._handles[p][0] is not None: - if self._allreduce_delay[p] <= 0: - raise AssertionError( - "Gradients were computed more than " - "backward_passes_per_step times before call " - "to step(). Increase backward_passes_per_step to " - "accumulate gradients locally.") - assert not p.grad.requires_grad - assert self._allreduce_delay[p] > 0 - handle, ctx = None, None - self._allreduce_delay[p] -= 1 - if self._allreduce_delay[p] == 0: - handle, ctx = self._allreduce_grad_async(p) - self._handles[p] = (handle, ctx) - return hook - - def synchronize(self): - missing_p = self._requires_update - set(self._handles.keys()) - for p in missing_p: - handle, ctx = self._allreduce_grad_async(p) - self._handles[p] = (handle, ctx) - - for p, value in self._handles.items(): - handle, ctx = value - if handle is None: - handle, ctx = self._allreduce_grad_async(p) - self._handles[p] = (handle, ctx) - for p, (handle, _) in self._handles.items(): - output = synchronize(handle) - self._allreduce_delay[p] = self.backward_passes_per_step - p.grad.set_(self._compression.decompress(output, ctx)) - self._handles.clear() - - self._synchronized = True - - @contextmanager - def skip_synchronize(self): - """ - A context manager used to specify that optimizer.step() should - not perform synchronization. - - It's typically used in a following pattern: - - .. code-block:: python - - optimizer.synchronize() - with optimizer.skip_synchronize(): - optimizer.step() - """ - self._should_synchronize = False - try: - yield - finally: - self._should_synchronize = True - - def step(self, closure=None): - if self._should_synchronize: - if self._synchronized: - warnings.warn("optimizer.step() called without " - "optimizer.skip_synchronize() context after " - "optimizer.synchronize(). This can cause training " - "slowdown. 
You may want to consider using " - "optimizer.skip_synchronize() context if you use " - "optimizer.synchronize() in your code.") - self.synchronize() - self._synchronized = False - return super(self.__class__, self).step(closure) - - def zero_grad(self): - if self._handles: - raise AssertionError("optimizer.zero_grad() was called after loss.backward() " - "but before optimizer.step() or optimizer.synchronize(). " - "This is prohibited as it can cause a race condition.") - return super(self.__class__, self).zero_grad() - - -class _DistributedAdasumOptimizer(torch.optim.Optimizer): - def __init__(self, params, named_parameters, compression, - backward_passes_per_step=1): - super(self.__class__, self).__init__(params) - - self._compression = compression - - if named_parameters is not None: - named_parameters = list(named_parameters) - else: - named_parameters = [('allreduce.noname.%s' % i, v) - for param_group in self.param_groups - for i, v in enumerate(param_group['params'])] - - # make sure that named_parameters are tuples - if any([not isinstance(p, tuple) for p in named_parameters]): - raise ValueError('named_parameters should be a sequence of ' - 'tuples (name, parameter), usually produced by ' - 'model.named_parameters().') - - dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters]) - if len(dups) > 0: - raise ValueError('Parameter names in named_parameters must be unique. ' - 'Found duplicates: %s' % ', '.join(dups)) - - all_param_ids = {id(v) - for param_group in self.param_groups - for v in param_group['params']} - named_param_ids = {id(v) for k, v in named_parameters} - unnamed_param_ids = all_param_ids - named_param_ids - if len(unnamed_param_ids): - raise ValueError('named_parameters was specified, but one or more model ' - 'parameters were not named. 
Python object ids: ' - '%s' % ', '.join(str(id) for id in unnamed_param_ids)) - - self._parameter_names = {v: k for k, v in sorted(named_parameters)} - self.backward_passes_per_step = backward_passes_per_step - self._allreduce_delay = {v: self.backward_passes_per_step - for _, v in sorted(named_parameters)} - self._handles = {} - self._grad_accs = [] - self._requires_update = set() - self._synchronized = False - self._should_synchronize = True - - self._starting_models = { - p : torch.zeros_like(p, requires_grad=False) - for _, p in named_parameters - } - - self._register_hooks() - - def set_backward_passes_per_step(self, passes): - self.backward_passes_per_step = passes - for p in self._allreduce_delay: - self._allreduce_delay[p] = self.backward_passes_per_step - - def _register_hooks(self): - for param_group in self.param_groups: - for p in param_group['params']: - if p.requires_grad: - p.grad = p.data.new(p.size()).zero_() - self._requires_update.add(p) - p_tmp = p.expand_as(p) - grad_acc = p_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook(self._make_hook(p)) - self._grad_accs.append(grad_acc) - - def _allreduce_grad_async(self, p): - # Delta optimizer implements this logic: - # start = current.copy() - # step() -> computes 'current - \alpha.f(g)' where f is - # optimizer logic and g is the gradient - # delta = current-start - # allreduce_(delta) - # start += delta - # current = start - # In order to suppport this logic using function hook to improve performance, - # we do: - # delta = (start - \alpha.f(g)) - start - # = -\alpha.f(g) - # set start to zero and step computes -\alpha.f(g) - # where f is the underlying optimizer logic - - name = self._parameter_names.get(p) - start = self._starting_models[p] - - stashed_params = [] - for group in self.param_groups: - stashed_params.append(group['params']) - # only want to step on p - if any([p is v for v in group['params']]): - group['params'] = [p] - else: - group['params'] = [] - - start.data.copy_(p) - - super(self.__class__, self).step() - - # compute delta = curr - start - p.data.sub_(start) - - # allreduce as before - tensor_compressed, ctx = self._compression.compress(p) - handle = allreduce_async_(tensor_compressed.data, name=name, op=Adasum) - - # reset stashed parameters - for stashed, group in zip(stashed_params, self.param_groups): - group['params'] = stashed - - return handle, ctx - - def _make_hook(self, p): - def hook(*ignore): - if p in self._handles and self._handles[p][0] is not None: - if self._allreduce_delay[p] <= 0: - raise AssertionError( - "Gradients were computed more than " - "backward_passes_per_step times before call " - "to step(). 
Increase backward_passes_per_step to " - "accumulate gradients locally.") - assert not p.grad.requires_grad - assert self._allreduce_delay[p] > 0 - handle, ctx = None, None - self._allreduce_delay[p] -= 1 - if self._allreduce_delay[p] == 0: - handle, ctx = self._allreduce_grad_async(p) - self._handles[p] = (handle, ctx) - return hook - - def synchronize(self): - pass - - @contextmanager - def skip_synchronize(self): - raise AssertionError("Skipping synchronization is not supported when using Adasum optimizer.") - - def step(self, closure=None): - loss = None - if closure is not None: - loss = closure() - - missing_p = self._requires_update - set(self._handles.keys()) - for p in missing_p: - handle, ctx = self._allreduce_grad_async(p) - self._handles[p] = (handle, ctx) - - for p, (handle, ctx) in self._handles.items(): - # This means step() is called before backward_passes_per_steps finished. - # We do a synchoronous allreduce here. - if not handle: - handle, ctx = self._allreduce_grad_async(p) - self._handles[p] = (handle, ctx) - delta = synchronize(handle) - delta = self._compression.decompress(delta, ctx) - start = self._starting_models[p] - start.data.add_(delta.data) - p.data.copy_(start) - self._allreduce_delay[p] = self.backward_passes_per_step - self._handles.clear() - return loss - - def zero_grad(self): - if self._handles: - raise AssertionError("optimizer.zero_grad() was called after loss.backward() " - "but before optimizer.step() or optimizer.synchronize(). " - "This is prohibited as it can cause a race condition.") - return super(self.__class__, self).zero_grad() - - -def DistributedOptimizer(optimizer, named_parameters=None, - compression=Compression.none, - backward_passes_per_step=1, - op=Average): - """ - An optimizer that wraps another torch.optim.Optimizer, using an allreduce to - combine gradient values before applying gradients to model weights. - - Allreduce operations are executed after each gradient is computed by ``loss.backward()`` - in parallel with each other. The ``step()`` method ensures that all allreduce operations are - finished before applying gradients to the model. - - DistributedOptimizer exposes the ``synchronize()`` method, which forces allreduce operations - to finish before continuing the execution. It's useful in conjunction with gradient - clipping, or other operations that modify gradients in place before ``step()`` is executed. - Make sure to use ``optimizer.skip_synchronize()`` if you're calling ``synchronize()`` - in your code. - - Example of gradient clipping: - - .. code-block:: python - - output = model(data) - loss = F.nll_loss(output, target) - loss.backward() - optimizer.synchronize() - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) - with optimizer.skip_synchronize(): - optimizer.step() - - Arguments: - optimizer: Optimizer to use for computing gradients and applying updates. - named_parameters: A mapping between parameter names and values. Used for naming of - allreduce operations. Typically just ``model.named_parameters()``. - compression: Compression algorithm used during allreduce to reduce the amount - of data sent during the each parameter update step. Defaults to - not using compression. - backward_passes_per_step: Number of expected backward passes to perform - before calling step()/synchronize(). This - allows accumulating gradients over multiple - mini-batches before reducing and applying them. - op: The reduction operation to use when combining gradients across different ranks. 
- """ - # We dynamically create a new class that inherits from the optimizer that was passed in. - # The goal is to override the `step()` method with an allreduce implementation. - - if op != Adasum or size() == 1: - cls = type(optimizer.__class__.__name__, (optimizer.__class__,), - dict(_DistributedOptimizer.__dict__)) - return cls(optimizer.param_groups, named_parameters, compression, backward_passes_per_step, op) - else: - cls = type(optimizer.__class__.__name__, (optimizer.__class__,), - dict(_DistributedAdasumOptimizer.__dict__)) - return cls(optimizer.param_groups, named_parameters, compression, backward_passes_per_step) - - -def broadcast_parameters(params, root_rank): - """ - Broadcasts the parameters from root rank to all other processes. - Typical usage is to broadcast the ``model.state_dict()``, - ``model.named_parameters()``, or ``model.parameters()``. - - Arguments: - params: One of the following: - - list of parameters to broadcast - - dict of parameters to broadcast - root_rank: The rank of the process from which parameters will be - broadcasted to all other processes. - """ - if isinstance(params, dict): - params = sorted(params.items()) - elif isinstance(params, list): - # support both named_parameters() and regular parameters() - params = [p if isinstance(p, tuple) else (None, p) for p in params] - else: - raise ValueError('invalid params of type: %s' % type(params)) - - # Run asynchronous broadcasts. - handles = [] - for name, p in params: - handle = broadcast_async_(p, root_rank, name) - handles.append(handle) - - # Wait for completion. - for handle in handles: - synchronize(handle) - - -def broadcast_optimizer_state(optimizer, root_rank): - """ - Broadcasts an optimizer state from root rank to all other processes. - - Arguments: - optimizer: An optimizer. - root_rank: The rank of the process from which the optimizer will be - broadcasted to all other processes. - """ - if isinstance(optimizer, torch.optim.LBFGS): - # TODO(travis): L-BFGS cannot be easily supported without serializing - # the entire state_dict, as its structure is deeply nested and contains - # None type parameter values - raise ValueError('cannot broadcast torch.optim.LBFGS state') - - state_dict = optimizer.state_dict() - - # Newly created optimizers will not have their state initialized, so - # do that initialization here - if len(state_dict['state']) == 0: - for group in optimizer.param_groups: - for p in group['params']: - if p.requires_grad and id(p) not in state_dict['state']: - p.grad = p.data.new(p.size()).zero_() - # This function accepts a torch.optim.Optimizer or a DistributedOptimizer - # wrapped around a torch optimizer. Calling step() with a DistributedOptimizer - # forces allreduce on all model parameters, which will result in deadlock - # unless every rank calls step(). Therefore, to finish state initialization - # only call optimizer.step() with a torch.optim.Optimizer. - if optimizer.__module__ == DistributedOptimizer.__module__: - super(optimizer.__class__, optimizer).step() - else: - optimizer.step() - state_dict = optimizer.state_dict() - - # If the state_dict is still empty after initialization, then - # the optimizer is stateless, and there is nothing to broadcast. - # Furthermore, attempting to access the state dict would result in - # an error. 
- if len(state_dict['state']) == 0: - return - - params = [] - callbacks = {} - occurrences = collections.defaultdict(int) - - # Returns the full type structure of the possibly nested objects for recursive casting back - def _get_types(x): - if isinstance(x, Iterable): - return type(x), [_get_types(xi) for xi in x] - else: - return type(x) - - # Casts an object encoded in a tensor back into its original type and subtypes - def _recursive_cast(x, dtype): - if isinstance(dtype, tuple): - t, dtypes = dtype - x = t(x) - return t([_recursive_cast(x[i], dtypes[i]) for i in range(len(x))]) - else: - return dtype(x) - - # Some optimizer parameters may be represented as scalars instead of - # tensors. In such cases, we need to wrap the scalar in a tensor, then - # broadcast, then update the appropriate value in the state_dict with the - # new unwrapped scalar value via a callback. - def _create_callback(pid, name, t, p): - def _from_tensor(): - state_dict['state'][pid][name] = t(p.cpu().numpy()[0]) - return _from_tensor - - def _create_option_callback(index, option_key, option_tensor, dtypes): - def _from_tensor(): - optimizer.param_groups[index][option_key] = _recursive_cast(option_tensor.cpu().numpy()[0], dtypes) - return _from_tensor - - # Param groups are an ordered list, normally there is only one per model, - # but users can add additional param groups for example to train - # previously frozen layers - for index, group in enumerate(state_dict['param_groups']): - # Broadcast options like learning rate - for option_key, option_value in group.items(): - if option_key == 'params': - continue - - # Options like the learning rate are scalar, and need to be wrapped in tensors - key = '%s.%d' % (option_key, index) - dtypes = _get_types(option_value) - option_tensor = torch.Tensor([option_value]) - callbacks[key] = _create_option_callback(index, option_key, option_tensor, dtypes) - params.append((key, option_tensor)) - - # The params list here is ordered by the layers in the model - for pid in group['params']: - if pid not in state_dict['state']: - # The param has not set requires_grad, so skip broadcast - continue - - param_state = state_dict['state'][pid] - for name, p in param_state.items(): - # Some parameter names may appear more than once, in which - # case we ensure they have a unique identifier defined by - # their order - occurrences[name] += 1 - key = '%s.%d' % (str(name), occurrences[name]) - - if not torch.is_tensor(p): - # Wrap the scalar in a FloatTensor, and remember its type - # so we can cast it back after unwrapping - t = type(p) - p = torch.Tensor([p]) - callbacks[key] = _create_callback(pid, name, t, p) - - params.append((key, p)) - - # Synchronized broadcast of all parameters - broadcast_parameters(params, root_rank) - - # Post-broadcast cleanup for non-tensor parameters - for key, p in params: - if key in callbacks: - callbacks[key]() - - -def broadcast_object(obj, root_rank, name=None): - """ - Serializes and broadcasts an object from root rank to all other processes. - Typical usage is to broadcast the `optimizer.state_dict()`, for example: - - .. code-block:: python - - state_dict = broadcast_object(optimizer.state_dict(), 0) - if hvd.rank() > 0: - optimizer.load_state_dict(state_dict) - - Arguments: - obj: An object capable of being serialized without losing any context. - root_rank: The rank of the process from which parameters will be - broadcasted to all other processes. - name: Optional name to use during broadcast, will default to the class - type. 
- Returns: - The object that was broadcast from the `root_rank`. - """ - if name is None: - name = str(type(obj)) - - if rank() == root_rank: - b = io.BytesIO() - cloudpickle.dump(obj, b) - t = torch.ByteTensor(bytearray(b.getvalue())) - sz = torch.IntTensor([t.shape[0]]) - broadcast_(sz, root_rank, name + '.sz') - else: - sz = torch.IntTensor([0]) - broadcast_(sz, root_rank, name + '.sz') - t = torch.ByteTensor(sz.tolist()[0]) - - broadcast_(t, root_rank, name + '.t') - - if rank() != root_rank: - buf = io.BytesIO(t.numpy().tobytes()) - obj = cloudpickle.load(buf) - - return obj diff --git a/horovod/torch/elastic.py b/horovod/torch/elastic.py new file mode 100644 index 0000000000..6bc819e3fb --- /dev/null +++ b/horovod/torch/elastic.py @@ -0,0 +1,85 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +import copy + +from horovod.common.elastic import run_fn, ObjectState +from horovod.torch.mpi_ops import init, rank, shutdown +from horovod.torch.functions import broadcast_object, broadcast_optimizer_state, broadcast_parameters + + +def run(func): + """Decorator used to run the elastic training process. + + The purpose of this decorator is to allow for uninterrupted execution of the wrapped function + across multiple workers in parallel, as workers come and go from the system. When a new worker is added, + its state needs to be brought to the same point as the other workers, which is done by synchronizing + the state object before executing `func`. + + When a worker is added or removed, other workers will raise an exception to bring them back to such a sync + point before executing `func` again. This ensures that workers do not diverge when such reset events occur. + + It's important to note that collective operations (e.g., broadcast, allreduce) cannot be the call to + the wrapped function. Otherwise, new workers could execute these operations during their initialization + while other workers are attempting to sync state, resulting in deadlock. + + Args: + func: a wrapped function taking any number of args or kwargs. The first argument + must be a `horovod.common.elastic.State` object used to synchronize state across + workers. + """ + return run_fn(func, _reset) + + +def _reset(): + shutdown() + init() + + +class TorchState(ObjectState): + """State representation of a PyTorch model and optimizer. + + Args: + model: PyTorch model. + optimizer: PyTorch optimizer. + kwargs: Additional properties to sync, will be exposed as attributes of the object. 
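A comparable usage sketch for `TorchState` and the PyTorch `run` decorator above; `model`, `train_one_epoch`, and `num_epochs` are user-defined, illustrative names:

.. code-block:: python

    import torch
    import horovod.torch as hvd

    hvd.init()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * hvd.size())
    optimizer = hvd.DistributedOptimizer(optimizer,
                                         named_parameters=model.named_parameters())

    @hvd.elastic.run
    def train(state):
        # Resume from the last committed epoch when workers are added or removed.
        for state.epoch in range(state.epoch, num_epochs):
            train_one_epoch(state.model, state.optimizer)
            state.commit()

    state = hvd.elastic.TorchState(model, optimizer, epoch=0)
    train(state)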
+ """ + def __init__(self, model, optimizer, **kwargs): + self.model = model + self._saved_model_state = copy.deepcopy(model.state_dict()) + + self.optimizer = optimizer + self._saved_optimizer_state = copy.deepcopy(optimizer.state_dict()) + + super(TorchState, self).__init__(bcast_object=broadcast_object, + get_rank=rank, + **kwargs) + + def save(self): + self._saved_model_state = copy.deepcopy(self.model.state_dict()) + self._saved_optimizer_state = copy.deepcopy(self.optimizer.state_dict()) + super(TorchState, self).save() + + def restore(self): + self.model.load_state_dict(self._saved_model_state) + self.optimizer.load_state_dict(self._saved_optimizer_state) + super(TorchState, self).restore() + + def sync(self): + broadcast_parameters(self.model.state_dict(), root_rank=0) + broadcast_optimizer_state(self.optimizer, root_rank=0) + super(TorchState, self).sync() diff --git a/horovod/torch/functions.py b/horovod/torch/functions.py new file mode 100644 index 0000000000..12c5845c81 --- /dev/null +++ b/horovod/torch/functions.py @@ -0,0 +1,231 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +import collections +import io + +import cloudpickle +import torch + +from horovod.torch.mpi_ops import broadcast_, broadcast_async_ +from horovod.torch.mpi_ops import synchronize +from horovod.torch.mpi_ops import rank +from horovod.torch.optimizer import DistributedOptimizer + +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable + + +def broadcast_parameters(params, root_rank): + """ + Broadcasts the parameters from root rank to all other processes. + Typical usage is to broadcast the ``model.state_dict()``, + ``model.named_parameters()``, or ``model.parameters()``. + + Arguments: + params: One of the following: + - list of parameters to broadcast + - dict of parameters to broadcast + root_rank: The rank of the process from which parameters will be + broadcasted to all other processes. + """ + if isinstance(params, dict): + params = sorted(params.items()) + elif isinstance(params, list): + # support both named_parameters() and regular parameters() + params = [p if isinstance(p, tuple) else (None, p) for p in params] + else: + raise ValueError('invalid params of type: %s' % type(params)) + + # Run asynchronous broadcasts. + handles = [] + for name, p in params: + handle = broadcast_async_(p, root_rank, name) + handles.append(handle) + + # Wait for completion. + for handle in handles: + synchronize(handle) + + +def broadcast_optimizer_state(optimizer, root_rank): + """ + Broadcasts an optimizer state from root rank to all other processes. + + Arguments: + optimizer: An optimizer. + root_rank: The rank of the process from which the optimizer will be + broadcasted to all other processes. 
+ """ + if isinstance(optimizer, torch.optim.LBFGS): + # TODO(travis): L-BFGS cannot be easily supported without serializing + # the entire state_dict, as its structure is deeply nested and contains + # None type parameter values + raise ValueError('cannot broadcast torch.optim.LBFGS state') + + state_dict = optimizer.state_dict() + + # Newly created optimizers will not have their state initialized, so + # do that initialization here + if len(state_dict['state']) == 0: + for group in optimizer.param_groups: + for p in group['params']: + if p.requires_grad and id(p) not in state_dict['state']: + p.grad = p.data.new(p.size()).zero_() + # This function accepts a torch.optim.Optimizer or a DistributedOptimizer + # wrapped around a torch optimizer. Calling step() with a DistributedOptimizer + # forces allreduce on all model parameters, which will result in deadlock + # unless every rank calls step(). Therefore, to finish state initialization + # only call optimizer.step() with a torch.optim.Optimizer. + if optimizer.__module__ == DistributedOptimizer.__module__: + super(optimizer.__class__, optimizer).step() + else: + optimizer.step() + state_dict = optimizer.state_dict() + + # If the state_dict is still empty after initialization, then + # the optimizer is stateless, and there is nothing to broadcast. + # Furthermore, attempting to access the state dict would result in + # an error. + if len(state_dict['state']) == 0: + return + + params = [] + callbacks = {} + occurrences = collections.defaultdict(int) + + # Returns the full type structure of the possibly nested objects for recursive casting back + def _get_types(x): + if isinstance(x, Iterable): + return type(x), [_get_types(xi) for xi in x] + else: + return type(x) + + # Casts an object encoded in a tensor back into its original type and subtypes + def _recursive_cast(x, dtype): + if isinstance(dtype, tuple): + t, dtypes = dtype + x = t(x) + return t([_recursive_cast(x[i], dtypes[i]) for i in range(len(x))]) + else: + return dtype(x) + + # Some optimizer parameters may be represented as scalars instead of + # tensors. In such cases, we need to wrap the scalar in a tensor, then + # broadcast, then update the appropriate value in the state_dict with the + # new unwrapped scalar value via a callback. 
+ def _create_callback(pid, name, t, p): + def _from_tensor(): + state_dict['state'][pid][name] = t(p.cpu().numpy()[0]) + return _from_tensor + + def _create_option_callback(index, option_key, option_tensor, dtypes): + def _from_tensor(): + optimizer.param_groups[index][option_key] = _recursive_cast(option_tensor.cpu().numpy()[0], dtypes) + return _from_tensor + + # Param groups are an ordered list, normally there is only one per model, + # but users can add additional param groups for example to train + # previously frozen layers + for index, group in enumerate(state_dict['param_groups']): + # Broadcast options like learning rate + for option_key, option_value in group.items(): + if option_key == 'params': + continue + + # Options like the learning rate are scalar, and need to be wrapped in tensors + key = '%s.%d' % (option_key, index) + dtypes = _get_types(option_value) + option_tensor = torch.Tensor([option_value]) + callbacks[key] = _create_option_callback(index, option_key, option_tensor, dtypes) + params.append((key, option_tensor)) + + # The params list here is ordered by the layers in the model + for pid in group['params']: + if pid not in state_dict['state']: + # The param has not set requires_grad, so skip broadcast + continue + + param_state = state_dict['state'][pid] + for name, p in param_state.items(): + # Some parameter names may appear more than once, in which + # case we ensure they have a unique identifier defined by + # their order + occurrences[name] += 1 + key = '%s.%d' % (str(name), occurrences[name]) + + if not torch.is_tensor(p): + # Wrap the scalar in a FloatTensor, and remember its type + # so we can cast it back after unwrapping + t = type(p) + p = torch.Tensor([p]) + callbacks[key] = _create_callback(pid, name, t, p) + + params.append((key, p)) + + # Synchronized broadcast of all parameters + broadcast_parameters(params, root_rank) + + # Post-broadcast cleanup for non-tensor parameters + for key, p in params: + if key in callbacks: + callbacks[key]() + + +def broadcast_object(obj, root_rank=0, name=None): + """ + Serializes and broadcasts an object from root rank to all other processes. + Typical usage is to broadcast the `optimizer.state_dict()`, for example: + + .. code-block:: python + + state_dict = broadcast_object(optimizer.state_dict(), 0) + if hvd.rank() > 0: + optimizer.load_state_dict(state_dict) + + Arguments: + obj: An object capable of being serialized without losing any context. + root_rank: The rank of the process from which parameters will be + broadcasted to all other processes. + name: Optional name to use during broadcast, will default to the class + type. + Returns: + The object that was broadcast from the `root_rank`. 
+ """ + if name is None: + name = type(obj).__name__ + + if rank() == root_rank: + b = io.BytesIO() + cloudpickle.dump(obj, b) + t = torch.ByteTensor(bytearray(b.getvalue())) + sz = torch.IntTensor([t.shape[0]]) + broadcast_(sz, root_rank, name + '.sz') + else: + sz = torch.IntTensor([0]) + broadcast_(sz, root_rank, name + '.sz') + t = torch.ByteTensor(sz.tolist()[0]) + + broadcast_(t, root_rank, name + '.t') + + if rank() != root_rank: + buf = io.BytesIO(t.numpy().tobytes()) + obj = cloudpickle.load(buf) + + return obj diff --git a/horovod/torch/mpi_ops.py b/horovod/torch/mpi_ops.py index a5ad933412..567fa82111 100644 --- a/horovod/torch/mpi_ops.py +++ b/horovod/torch/mpi_ops.py @@ -39,6 +39,7 @@ _NULL = mpi_lib._ffi.NULL _basics = _HorovodBasics(__file__, 'mpi_lib_impl', '_mpi_lib_impl') +from horovod.common.exceptions import HorovodInternalError from horovod.common.util import get_average_backwards_compatibility_fun, gpu_available, num_rank_is_power_2 from horovod.torch.compression import Compression @@ -123,8 +124,11 @@ def _allreduce_async(tensor, output, name, op): true_op = Sum if op == Average else op function = _check_function(_allreduce_function_factory, tensor) - handle = getattr(mpi_lib, function)(tensor, output, divisor, - name.encode() if name is not None else _NULL, true_op) + try: + handle = getattr(mpi_lib, function)(tensor, output, divisor, + name.encode() if name is not None else _NULL, true_op) + except RuntimeError as e: + raise HorovodInternalError(e) _handle_map[handle] = (tensor, output) return handle @@ -275,8 +279,11 @@ def _allgather_function_factory(tensor): def _allgather_async(tensor, output, name): function = _check_function(_allgather_function_factory, tensor) - handle = getattr(mpi_lib, function)( - tensor, output, name.encode() if name is not None else _NULL) + try: + handle = getattr(mpi_lib, function)( + tensor, output, name.encode() if name is not None else _NULL) + except RuntimeError as e: + raise HorovodInternalError(e) _handle_map[handle] = (tensor, output) return handle @@ -355,8 +362,11 @@ def _broadcast_function_factory(tensor): def _broadcast_async(tensor, output, root_rank, name): function = _check_function(_broadcast_function_factory, tensor) - handle = getattr(mpi_lib, function)( - tensor, output, root_rank, name.encode() if name is not None else _NULL) + try: + handle = getattr(mpi_lib, function)( + tensor, output, root_rank, name.encode() if name is not None else _NULL) + except RuntimeError as e: + raise HorovodInternalError(e) _handle_map[handle] = (tensor, output) return handle @@ -502,9 +512,13 @@ def synchronize(handle): """ if handle not in _handle_map: return - mpi_lib.horovod_torch_wait_and_clear(handle) - _, output = _handle_map.pop(handle) - return output + + try: + mpi_lib.horovod_torch_wait_and_clear(handle) + _, output = _handle_map.pop(handle) + return output + except RuntimeError as e: + raise HorovodInternalError(e) def join(device=-1): @@ -521,4 +535,8 @@ def join(device=-1): """ if not _v2_api: raise NotImplementedError("Join Op is not supported for PyTorch < 1.0") - return mpi_lib.horovod_torch_join(device) + + try: + return mpi_lib.horovod_torch_join(device) + except RuntimeError as e: + raise HorovodInternalError(e) diff --git a/horovod/torch/optimizer.py b/horovod/torch/optimizer.py new file mode 100644 index 0000000000..55160da7a2 --- /dev/null +++ b/horovod/torch/optimizer.py @@ -0,0 +1,425 @@ +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. 
+# Modifications copyright Microsoft +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import warnings + +from contextlib import contextmanager + +import torch + +from horovod.torch.compression import Compression +from horovod.torch.mpi_ops import allreduce_async_ +from horovod.torch.mpi_ops import synchronize +from horovod.torch.mpi_ops import size +from horovod.torch.mpi_ops import Average, Adasum + + +class _DistributedOptimizer(torch.optim.Optimizer): + def __init__(self, params, named_parameters, compression, + backward_passes_per_step=1, op=Average): + super(self.__class__, self).__init__(params) + self._compression = compression + + if named_parameters is not None: + named_parameters = list(named_parameters) + else: + named_parameters = [('allreduce.noname.%s' % i, v) + for param_group in self.param_groups + for i, v in enumerate(param_group['params'])] + # make sure that named_parameters are tuples + if any([not isinstance(p, tuple) for p in named_parameters]): + raise ValueError('named_parameters should be a sequence of ' + 'tuples (name, parameter), usually produced by ' + 'model.named_parameters().') + + dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters]) + if len(dups) > 0: + raise ValueError('Parameter names in named_parameters must be unique. ' + 'Found duplicates: %s' % ', '.join(dups)) + + all_param_ids = {id(v) + for param_group in self.param_groups + for v in param_group['params']} + named_param_ids = {id(v) for k, v in named_parameters} + unnamed_param_ids = all_param_ids - named_param_ids + if len(unnamed_param_ids): + raise ValueError('named_parameters was specified, but one or more model ' + 'parameters were not named. 
Python object ids: ' + '%s' % ', '.join(str(id) for id in unnamed_param_ids)) + + self._parameter_names = {v: k for k, v in sorted(named_parameters)} + self.backward_passes_per_step = backward_passes_per_step + self._allreduce_delay = {v: self.backward_passes_per_step + for _, v in sorted(named_parameters)} + self.op = op + self._handles = {} + self._grad_accs = [] + self._requires_update = set() + self._synchronized = False + self._should_synchronize = True + if size() > 1 or os.environ.get('HOROVOD_ELASTIC') == '1': + self._register_hooks() + + def load_state_dict(self, *args, **kwargs): + self._handles = {} + self._synchronized = False + self._should_synchronize = True + for p in self._allreduce_delay: + self._allreduce_delay[p] = self.backward_passes_per_step + super(self.__class__, self).load_state_dict(*args, **kwargs) + + @staticmethod + def find_duplicates(lst): + seen = set() + dups = set() + for el in lst: + if el in seen: + dups.add(el) + seen.add(el) + return dups + + def set_backward_passes_per_step(self, passes): + self.backward_passes_per_step = passes + for p in self._allreduce_delay: + self._allreduce_delay[p] = self.backward_passes_per_step + + def _register_hooks(self): + for param_group in self.param_groups: + for p in param_group['params']: + if p.requires_grad: + p.grad = p.data.new(p.size()).zero_() + self._requires_update.add(p) + p_tmp = p.expand_as(p) + grad_acc = p_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(self._make_hook(p)) + self._grad_accs.append(grad_acc) + + def _allreduce_grad_async(self, p): + name = self._parameter_names.get(p) + tensor = p.grad + tensor_compressed, ctx = self._compression.compress(tensor) + + handle = allreduce_async_(tensor_compressed, name=name, op=self.op) + return handle, ctx + + def _make_hook(self, p): + def hook(*ignore): + if p in self._handles and self._handles[p][0] is not None: + if self._allreduce_delay[p] <= 0: + raise AssertionError( + "Gradients were computed more than " + "backward_passes_per_step times before call " + "to step(). Increase backward_passes_per_step to " + "accumulate gradients locally.") + assert not p.grad.requires_grad + assert self._allreduce_delay[p] > 0 + handle, ctx = None, None + self._allreduce_delay[p] -= 1 + if self._allreduce_delay[p] == 0: + handle, ctx = self._allreduce_grad_async(p) + self._handles[p] = (handle, ctx) + return hook + + def synchronize(self): + missing_p = self._requires_update - set(self._handles.keys()) + for p in missing_p: + handle, ctx = self._allreduce_grad_async(p) + self._handles[p] = (handle, ctx) + + for p, value in self._handles.items(): + handle, ctx = value + if handle is None: + handle, ctx = self._allreduce_grad_async(p) + self._handles[p] = (handle, ctx) + for p, (handle, _) in self._handles.items(): + output = synchronize(handle) + self._allreduce_delay[p] = self.backward_passes_per_step + p.grad.set_(self._compression.decompress(output, ctx)) + self._handles.clear() + + self._synchronized = True + + @contextmanager + def skip_synchronize(self): + """ + A context manager used to specify that optimizer.step() should + not perform synchronization. + + It's typically used in a following pattern: + + .. 
code-block:: python + + optimizer.synchronize() + with optimizer.skip_synchronize(): + optimizer.step() + """ + self._should_synchronize = False + try: + yield + finally: + self._should_synchronize = True + + def step(self, closure=None): + if self._should_synchronize: + if self._synchronized: + warnings.warn("optimizer.step() called without " + "optimizer.skip_synchronize() context after " + "optimizer.synchronize(). This can cause training " + "slowdown. You may want to consider using " + "optimizer.skip_synchronize() context if you use " + "optimizer.synchronize() in your code.") + self.synchronize() + self._synchronized = False + return super(self.__class__, self).step(closure) + + def zero_grad(self): + if self._handles: + raise AssertionError("optimizer.zero_grad() was called after loss.backward() " + "but before optimizer.step() or optimizer.synchronize(). " + "This is prohibited as it can cause a race condition.") + return super(self.__class__, self).zero_grad() + + +class _DistributedAdasumOptimizer(torch.optim.Optimizer): + def __init__(self, params, named_parameters, compression, + backward_passes_per_step=1): + super(self.__class__, self).__init__(params) + + self._compression = compression + + if named_parameters is not None: + named_parameters = list(named_parameters) + else: + named_parameters = [('allreduce.noname.%s' % i, v) + for param_group in self.param_groups + for i, v in enumerate(param_group['params'])] + + # make sure that named_parameters are tuples + if any([not isinstance(p, tuple) for p in named_parameters]): + raise ValueError('named_parameters should be a sequence of ' + 'tuples (name, parameter), usually produced by ' + 'model.named_parameters().') + + dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters]) + if len(dups) > 0: + raise ValueError('Parameter names in named_parameters must be unique. ' + 'Found duplicates: %s' % ', '.join(dups)) + + all_param_ids = {id(v) + for param_group in self.param_groups + for v in param_group['params']} + named_param_ids = {id(v) for k, v in named_parameters} + unnamed_param_ids = all_param_ids - named_param_ids + if len(unnamed_param_ids): + raise ValueError('named_parameters was specified, but one or more model ' + 'parameters were not named. 
Python object ids: ' + '%s' % ', '.join(str(id) for id in unnamed_param_ids)) + + self._parameter_names = {v: k for k, v in sorted(named_parameters)} + self.backward_passes_per_step = backward_passes_per_step + self._allreduce_delay = {v: self.backward_passes_per_step + for _, v in sorted(named_parameters)} + self._handles = {} + self._grad_accs = [] + self._requires_update = set() + self._synchronized = False + self._should_synchronize = True + + self._starting_models = { + p : torch.zeros_like(p, requires_grad=False) + for _, p in named_parameters + } + + self._register_hooks() + + def set_backward_passes_per_step(self, passes): + self.backward_passes_per_step = passes + for p in self._allreduce_delay: + self._allreduce_delay[p] = self.backward_passes_per_step + + def _register_hooks(self): + for param_group in self.param_groups: + for p in param_group['params']: + if p.requires_grad: + p.grad = p.data.new(p.size()).zero_() + self._requires_update.add(p) + p_tmp = p.expand_as(p) + grad_acc = p_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(self._make_hook(p)) + self._grad_accs.append(grad_acc) + + def _allreduce_grad_async(self, p): + # Delta optimizer implements this logic: + # start = current.copy() + # step() -> computes 'current - \alpha.f(g)' where f is + # optimizer logic and g is the gradient + # delta = current-start + # allreduce_(delta) + # start += delta + # current = start + # In order to support this logic using function hooks to improve performance, + # we do: + # delta = (start - \alpha.f(g)) - start + # = -\alpha.f(g) + # set start to zero and step computes -\alpha.f(g) + # where f is the underlying optimizer logic + + name = self._parameter_names.get(p) + start = self._starting_models[p] + + stashed_params = [] + for group in self.param_groups: + stashed_params.append(group['params']) + # only want to step on p + if any([p is v for v in group['params']]): + group['params'] = [p] + else: + group['params'] = [] + + start.data.copy_(p) + + super(self.__class__, self).step() + + # compute delta = curr - start + p.data.sub_(start) + + # allreduce as before + tensor_compressed, ctx = self._compression.compress(p) + handle = allreduce_async_(tensor_compressed.data, name=name, op=Adasum) + + # reset stashed parameters + for stashed, group in zip(stashed_params, self.param_groups): + group['params'] = stashed + + return handle, ctx + + def _make_hook(self, p): + def hook(*ignore): + if p in self._handles and self._handles[p][0] is not None: + if self._allreduce_delay[p] <= 0: + raise AssertionError( + "Gradients were computed more than " + "backward_passes_per_step times before call " + "to step(). 
Increase backward_passes_per_step to " + "accumulate gradients locally.") + assert not p.grad.requires_grad + assert self._allreduce_delay[p] > 0 + handle, ctx = None, None + self._allreduce_delay[p] -= 1 + if self._allreduce_delay[p] == 0: + handle, ctx = self._allreduce_grad_async(p) + self._handles[p] = (handle, ctx) + return hook + + def synchronize(self): + pass + + @contextmanager + def skip_synchronize(self): + raise AssertionError("Skipping synchronization is not supported when using Adasum optimizer.") + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + missing_p = self._requires_update - set(self._handles.keys()) + for p in missing_p: + handle, ctx = self._allreduce_grad_async(p) + self._handles[p] = (handle, ctx) + + for p, (handle, ctx) in self._handles.items(): + # This means step() is called before backward_passes_per_step backward passes have finished. + # We do a synchronous allreduce here. + if not handle: + handle, ctx = self._allreduce_grad_async(p) + self._handles[p] = (handle, ctx) + delta = synchronize(handle) + delta = self._compression.decompress(delta, ctx) + start = self._starting_models[p] + start.data.add_(delta.data) + p.data.copy_(start) + self._allreduce_delay[p] = self.backward_passes_per_step + self._handles.clear() + return loss + + def zero_grad(self): + if self._handles: + raise AssertionError("optimizer.zero_grad() was called after loss.backward() " + "but before optimizer.step() or optimizer.synchronize(). " + "This is prohibited as it can cause a race condition.") + return super(self.__class__, self).zero_grad() + + +def DistributedOptimizer(optimizer, named_parameters=None, + compression=Compression.none, + backward_passes_per_step=1, + op=Average): + """ + An optimizer that wraps another torch.optim.Optimizer, using an allreduce to + combine gradient values before applying gradients to model weights. + + Allreduce operations are executed after each gradient is computed by ``loss.backward()`` + in parallel with each other. The ``step()`` method ensures that all allreduce operations are + finished before applying gradients to the model. + + DistributedOptimizer exposes the ``synchronize()`` method, which forces allreduce operations + to finish before continuing execution. It's useful in conjunction with gradient + clipping, or other operations that modify gradients in place before ``step()`` is executed. + Make sure to use ``optimizer.skip_synchronize()`` if you're calling ``synchronize()`` + in your code. + + Example of gradient clipping: + + .. code-block:: python + + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.synchronize() + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) + with optimizer.skip_synchronize(): + optimizer.step() + + Arguments: + optimizer: Optimizer to use for computing gradients and applying updates. + named_parameters: A mapping between parameter names and values. Used for naming + allreduce operations. Typically just ``model.named_parameters()``. + compression: Compression algorithm used during allreduce to reduce the amount + of data sent during each parameter update step. Defaults to + not using compression. + backward_passes_per_step: Number of expected backward passes to perform + before calling step()/synchronize(). This + allows accumulating gradients over multiple + mini-batches before reducing and applying them. + op: The reduction operation to use when combining gradients across different ranks. 
+ """ + # We dynamically create a new class that inherits from the optimizer that was passed in. + # The goal is to override the `step()` method with an allreduce implementation. + + if op != Adasum or size() == 1: + cls = type(optimizer.__class__.__name__, (optimizer.__class__,), + dict(_DistributedOptimizer.__dict__)) + return cls(optimizer.param_groups, named_parameters, compression, backward_passes_per_step, op) + else: + cls = type(optimizer.__class__.__name__, (optimizer.__class__,), + dict(_DistributedAdasumOptimizer.__dict__)) + return cls(optimizer.param_groups, named_parameters, compression, backward_passes_per_step) diff --git a/test/data/expected_buildkite_pipeline.yaml b/test/data/expected_buildkite_pipeline.yaml index 5a3981b019..bffae9c2ea 100644 --- a/test/data/expected_buildkite_pipeline.yaml +++ b/test/data/expected_buildkite_pipeline.yaml @@ -1,34 +1,4 @@ steps: -- label: ':docker: Build test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2' - plugins: - - docker-compose#6b0df8a98ff97f42f4944dbb745b5b8cbf04b78c: - build: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - image-repository: 823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite - cache-from: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2-latest - config: docker-compose.test.yml - push-retries: 5 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 30 - retry: - automatic: true - agents: - queue: cpu -- label: ':docker: Build test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2' - plugins: - - docker-compose#6b0df8a98ff97f42f4944dbb745b5b8cbf04b78c: - build: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - image-repository: 823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite - cache-from: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2-latest - config: docker-compose.test.yml - push-retries: 5 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 30 - retry: - automatic: true - agents: - queue: cpu - label: ':docker: Build test-cpu-openmpi-py2_7-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2' plugins: - docker-compose#6b0df8a98ff97f42f4944dbb745b5b8cbf04b78c: @@ -89,12 +59,12 @@ steps: automatic: true agents: queue: cpu -- label: ':docker: Build test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0' +- label: ':docker: Build test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0' plugins: - docker-compose#6b0df8a98ff97f42f4944dbb745b5b8cbf04b78c: - build: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 + build: test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 image-repository: 823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite - cache-from: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0-latest + cache-from: 
test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0-latest config: docker-compose.test.yml push-retries: 5 - ecr#v1.2.0: @@ -134,21 +104,6 @@ steps: automatic: true agents: queue: cpu -- label: ':docker: Build test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0' - plugins: - - docker-compose#6b0df8a98ff97f42f4944dbb745b5b8cbf04b78c: - build: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 - image-repository: 823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite - cache-from: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0-latest - config: docker-compose.test.yml - push-retries: 5 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 30 - retry: - automatic: true - agents: - queue: cpu - label: ':docker: Build test-cpu-openmpi-py2_7-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0' plugins: - docker-compose#6b0df8a98ff97f42f4944dbb745b5b8cbf04b78c: @@ -194,12 +149,12 @@ steps: automatic: true agents: queue: cpu -- label: ':docker: Build test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0' +- label: ':docker: Build test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0' plugins: - docker-compose#6b0df8a98ff97f42f4944dbb745b5b8cbf04b78c: - build: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + build: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 image-repository: 823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite - cache-from: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0-latest + cache-from: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0-latest config: docker-compose.test.yml push-retries: 5 - ecr#v1.2.0: @@ -209,12 +164,12 @@ steps: automatic: true agents: queue: cpu -- label: ':docker: Build test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0' +- label: ':docker: Build test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0' plugins: - docker-compose#6b0df8a98ff97f42f4944dbb745b5b8cbf04b78c: - build: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + build: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 image-repository: 823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite - cache-from: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0-latest + cache-from: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0-latest config: docker-compose.test.yml push-retries: 5 - ecr#v1.2.0: @@ -224,12 +179,12 @@ steps: automatic: true agents: queue: cpu 
-- label: ':docker: Build test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0' +- label: ':docker: Build test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0' plugins: - docker-compose#6b0df8a98ff97f42f4944dbb745b5b8cbf04b78c: - build: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + build: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 image-repository: 823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite - cache-from: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0-latest + cache-from: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0:823773083436.dkr.ecr.us-east-1.amazonaws.com/buildkite:SLUG-test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0-latest config: docker-compose.test.yml push-retries: 5 - ecr#v1.2.0: @@ -340,260 +295,8 @@ steps: agents: queue: cpu - wait -- label: ':pytest: Run PyTests (test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':jupyter: Run PyTests test_interactiverun (test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c "cd /horovod/test && pytest -v --capture=no test_interactiverun.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Test Keras MNIST (test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/keras_mnist_advanced.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " \$(cat 
/mpirun_command) python /horovod/examples/pytorch_mnist.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':muscle: Test MXNet MNIST (test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Single Keras MNIST (test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':muscle: Single MXNet MNIST (test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py2_7-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':pytest: Run PyTests (test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':jupyter: Run PyTests test_interactiverun (test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c "cd /horovod/test && pytest -v --capture=no test_interactiverun.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- 
label: ':tensorflow: Test TensorFlow MNIST (test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Test Keras MNIST (test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/keras_mnist_advanced.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':muscle: Test MXNet MNIST (test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Single Keras MNIST (test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':muscle: Single MXNet MNIST (test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_1_0-keras2_0_0-torch0_4_0-mxnet1_4_1-pyspark2_3_2 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu - label: ':pytest: Run PyTests (test-cpu-openmpi-py2_7-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " cd 
/horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_elastic[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-cpu-openmpi-py2_7-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2 @@ -648,7 +351,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-py2_7-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' +- label: ':fire: Test PyTorch MNIST (test-cpu-openmpi-py2_7-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -690,7 +393,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-openmpi-py2_7-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' +- label: ':fire: Single PyTorch MNIST (test-cpu-openmpi-py2_7-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: @@ -719,7 +422,7 @@ steps: agents: queue: cpu - label: ':pytest: Run PyTests (test-cpu-openmpi-py3_6-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-cpu-openmpi-py3_6-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2 @@ -774,7 +477,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' +- label: ':fire: Test PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -816,7 +519,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' +- label: ':fire: Single PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_6_0-keras2_1_2-torch0_4_1-mxnet1_4_1-pyspark2_3_2)' command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: @@ 
-845,7 +548,7 @@ steps: agents: queue: cpu - label: ':pytest: Run PyTests (test-cpu-gloo-py2_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_elastic[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-cpu-gloo-py2_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 @@ -886,7 +589,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-gloo-py2_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-cpu-gloo-py2_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py plugins: - docker-compose#v2.6.0: @@ -942,7 +645,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-gloo-py2_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Single PyTorch MNIST (test-cpu-gloo-py2_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: @@ -971,7 +674,7 @@ steps: agents: queue: cpu - label: ':pytest: Run PyTests (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 @@ -1012,7 +715,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py plugins: - docker-compose#v2.6.0: @@ -1040,8 +743,8 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras Rossmann Run (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - 
command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_run.py --num-proc 2 --data-dir file:///data --epochs 3 --sample-rate 0.01" +- label: ':factory: Elastic Tests (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "cd /horovod/test/integration && pytest -v --log-cli-level 10 --capture=no test_elastic_torch.py test_elastic_tensorflow.py" plugins: - docker-compose#v2.6.0: run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 @@ -1054,8 +757,8 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras Rossmann Estimator (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_estimator.py --num-proc 2 --work-dir /work --data-dir file:///data --epochs 3 --sample-rate 0.01" +- label: ':spark: Spark Keras Rossmann Run (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_run.py --num-proc 2 --data-dir file:///data --epochs 3 --sample-rate 0.01" plugins: - docker-compose#v2.6.0: run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 @@ -1068,179 +771,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':spark: PyTests Spark Estimators (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "cd /horovod/test && pytest --forked -v --capture=no test_spark_keras.py test_spark_torch.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':spark: Spark Torch MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/pytorch_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Single Keras MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':python: Single PyTorch MNIST 
(test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':muscle: Single MXNet MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':pytest: Run PyTests (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/tensorflow_mnist.py - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Test Keras MNIST (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/keras_mnist_advanced.py - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':muscle: Test MXNet MNIST (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/mxnet_mnist.py - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: 
true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':spark: Spark Keras Rossmann Run (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_run.py --num-proc 2 --data-dir file:///data --epochs 3 --sample-rate 0.01" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':spark: Spark Keras Rossmann Estimator (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':spark: Spark Keras Rossmann Estimator (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_estimator.py --num-proc 2 --work-dir /work --data-dir file:///data --epochs 3 --sample-rate 0.01" plugins: - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1250,235 +785,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras MNIST (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':spark: Spark Keras MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':spark: PyTests Spark Estimators (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "cd /horovod/test && pytest --forked -v --capture=no test_spark_keras.py test_spark_torch.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':spark: Spark Torch MNIST (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/pytorch_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Single Keras MNIST (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - 
automatic: true - agents: - queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':muscle: Single MXNet MNIST (test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_7-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':pytest: Run PyTests (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_run.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Test TensorFlow 2.0 MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/tensorflow2_mnist.py - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Test TensorFlow 2.0 Keras MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/tensorflow2_keras_mnist.py - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':muscle: Test MXNet MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/mxnet_mnist.py - plugins: - - docker-compose#v2.6.0: - run: 
test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':muscle: Single MXNet MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':pytest: Run PyTests (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':jupyter: Run PyTests test_interactiverun (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "cd /horovod/test && pytest -v --capture=no test_interactiverun.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 - config: docker-compose.test.yml - pull-retries: 3 - - ecr#v1.2.0: - login: true - timeout_in_minutes: 10 - retry: - automatic: true - agents: - queue: cpu -- label: ':tensorflow: Test TensorFlow Eager MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist_eager.py" - plugins: - - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1487,12 +798,12 @@ steps: retry: automatic: true agents: - queue: cpu -- label: ':tensorflow: Test Keras MNIST 
(test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/keras_mnist_advanced.py" + queue: cpu +- label: ':spark: PyTests Spark Estimators (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "cd /horovod/test && pytest --forked -v --capture=no test_spark_keras.py test_spark_torch.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1502,11 +813,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" +- label: ':spark: Spark Torch MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/pytorch_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1516,11 +827,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test MXNet MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py" +- label: ':tensorflow: Single Keras MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c " python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1530,11 +841,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test Stall (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/test/test_stall.py" +- label: ':fire: Single PyTorch MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1544,11 +855,11 @@ steps: automatic: true agents: queue: cpu -- label: ':terminal: Test Horovodrun (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 python /horovod/examples/tensorflow_mnist.py +- label: ':muscle: Single MXNet MNIST (test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: 
test-cpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_4_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1558,11 +869,11 @@ steps: automatic: true agents: queue: cpu -- label: ':terminal: Test Horovodrun (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: echo 'localhost slots=2' > hostfile +- label: ':pytest: Run PyTests (test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1572,11 +883,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras Rossmann Run (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_run.py --num-proc 2 --data-dir file:///data --epochs 3 --sample-rate 0.01" +- label: ':tensorflow: Test TensorFlow 2.0 MNIST (test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/tensorflow2_mnist.py plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1586,11 +897,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras Rossmann Estimator (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_estimator.py --num-proc 2 --work-dir /work --data-dir file:///data --epochs 3 --sample-rate 0.01" +- label: ':tensorflow: Test TensorFlow 2.0 Keras MNIST (test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/tensorflow2_keras_mnist.py plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1600,11 +911,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" +- label: ':fire: Test PyTorch MNIST (test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: 
test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1614,11 +925,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: PyTests Spark Estimators (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "cd /horovod/test && pytest --forked -v --capture=no test_spark_keras.py test_spark_torch.py" +- label: ':muscle: Test MXNet MNIST (test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/mxnet_mnist.py plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1628,11 +939,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Torch MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/pytorch_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" +- label: ':factory: Elastic Tests (test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "cd /horovod/test/integration && pytest -v --log-cli-level 10 --capture=no test_elastic_torch.py test_elastic_tensorflow2.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1642,11 +953,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Single Keras MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" +- label: ':spark: Spark Torch MNIST (test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/pytorch_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1656,11 +967,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':fire: Single PyTorch MNIST (test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1670,11 +981,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Single MXNet MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':muscle: Single MXNet MNIST (test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " python 
/horovod/examples/mxnet_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_7-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1684,11 +995,11 @@ steps: automatic: true agents: queue: cpu -- label: ':pytest: Run PyTests (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" +- label: ':pytest: Run PyTests (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_run.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1698,11 +1009,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/tensorflow_mnist.py +- label: ':tensorflow: Test TensorFlow 2.0 MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/tensorflow2_mnist.py plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1712,11 +1023,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test Keras MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/keras_mnist_advanced.py +- label: ':tensorflow: Test TensorFlow 2.0 Keras MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/tensorflow2_keras_mnist.py plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1726,11 +1037,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' command: horovodrun 
-np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1740,11 +1051,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test MXNet MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':muscle: Test MXNet MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/mxnet_mnist.py plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1754,11 +1065,11 @@ steps: automatic: true agents: queue: cpu -- label: ':pytest: Run PyTests (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" +- label: ':factory: Elastic Tests (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "cd /horovod/test/integration && pytest -v --log-cli-level 10 --capture=no test_elastic_torch.py test_elastic_tensorflow2.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1768,11 +1079,11 @@ steps: automatic: true agents: queue: cpu -- label: ':jupyter: Run PyTests test_interactiverun (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "cd /horovod/test && pytest -v --capture=no test_interactiverun.py" +- label: ':fire: Single PyTorch MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1782,11 +1093,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist.py" +- label: ':muscle: Single MXNet MNIST (test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-gloo-py3_8-tf2_2_0-keras2_3_1-torch1_5_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml 
pull-retries: 3 - ecr#v1.2.0: @@ -1796,11 +1107,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test TensorFlow Eager MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist_eager.py" +- label: ':pytest: Run PyTests (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1810,11 +1121,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test Keras MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/keras_mnist_advanced.py" +- label: ':jupyter: Run PyTests test_interactiverun (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c "cd /horovod/test && pytest -v --capture=no test_interactiverun.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1824,11 +1135,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" +- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1838,11 +1149,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test MXNet MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py" +- label: ':tensorflow: Test TensorFlow Eager MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist_eager.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1852,11 +1163,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test Stall 
(test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " \$(cat /mpirun_command) python /horovod/test/test_stall.py" +- label: ':tensorflow: Test Keras MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c " \$(cat /mpirun_command) python /horovod/examples/keras_mnist_advanced.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1866,11 +1177,11 @@ steps: automatic: true agents: queue: cpu -- label: ':terminal: Test Horovodrun (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: horovodrun -np 2 -H localhost:2 python /horovod/examples/tensorflow_mnist.py +- label: ':fire: Test PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1880,11 +1191,11 @@ steps: automatic: true agents: queue: cpu -- label: ':terminal: Test Horovodrun (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: echo 'localhost slots=2' > hostfile +- label: ':muscle: Test MXNet MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1894,11 +1205,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras Rossmann Run (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':spark: Spark Keras Rossmann Run (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_run.py --num-proc 2 --data-dir file:///data --epochs 3 --sample-rate 0.01" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1908,11 +1219,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras Rossmann Estimator (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':spark: Spark Keras Rossmann Estimator (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_estimator.py --num-proc 2 --work-dir /work --data-dir file:///data --epochs 3 --sample-rate 0.01" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: 
test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1922,11 +1233,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Keras MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':spark: Spark Keras MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1936,11 +1247,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: PyTests Spark Estimators (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':spark: PyTests Spark Estimators (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c "cd /horovod/test && pytest --forked -v --capture=no test_spark_keras.py test_spark_torch.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1950,11 +1261,11 @@ steps: automatic: true agents: queue: cpu -- label: ':spark: Spark Torch MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':spark: Spark Torch MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/pytorch_spark_mnist.py --num-proc 2 --work-dir /work --data-dir /data --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1964,11 +1275,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Single Keras MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':tensorflow: Single Keras MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c " python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1978,11 +1289,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':fire: Single PyTorch MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: 
test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -1992,11 +1303,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Single MXNet MNIST (test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':muscle: Single MXNet MNIST (test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-openmpi-gloo-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + run: test-cpu-openmpi-py3_6-tf1_14_0-keras2_2_4-torch1_2_0-mxnet1_4_1-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2007,7 +1318,7 @@ steps: agents: queue: cpu - label: ':pytest: Run PyTests (test-cpu-openmpi-py2_7-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_elastic[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-cpu-openmpi-py2_7-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 @@ -2034,7 +1345,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-py2_7-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-cpu-openmpi-py2_7-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -2104,7 +1415,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-openmpi-py2_7-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Single PyTorch MNIST (test-cpu-openmpi-py2_7-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: @@ -2133,7 +1444,7 @@ steps: agents: queue: cpu - label: ':pytest: Run PyTests (test-cpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 
's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-cpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 @@ -2160,7 +1471,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-cpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -2230,7 +1541,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Single PyTorch MNIST (test-cpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: @@ -2259,7 +1570,7 @@ steps: agents: queue: cpu - label: ':pytest: Run PyTests (test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0 @@ -2286,7 +1597,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -2356,7 +1667,7 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)' +- label: ':fire: Single PyTorch MNIST (test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)' command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: @@ -2384,11 +1695,11 @@ steps: automatic: true agents: queue: cpu -- label: ':pytest: Run PyTests (test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" +- label: ':pytest: Run PyTests 
(test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2398,11 +1709,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2412,11 +1723,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test TensorFlow Eager MNIST (test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Test TensorFlow Eager MNIST (test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist_eager.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2426,11 +1737,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test Keras MNIST (test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Test Keras MNIST (test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/keras_mnist_advanced.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2440,11 +1751,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2454,11 +1765,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test MXNet MNIST 
(test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':muscle: Test MXNet MNIST (test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2468,11 +1779,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test Stall (test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':muscle: Test Stall (test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/test/test_stall.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2482,11 +1793,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Single Keras MNIST (test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Single Keras MNIST (test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2496,11 +1807,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Single PyTorch MNIST (test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2510,11 +1821,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Single MXNet MNIST (test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':muscle: Single MXNet MNIST (test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-mpich-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-mpich-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2524,11 +1835,11 @@ steps: automatic: true agents: queue: cpu -- label: ':pytest: Run PyTests (test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 
's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" +- label: ':pytest: Run PyTests (test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2538,11 +1849,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2552,11 +1863,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test TensorFlow Eager MNIST (test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Test TensorFlow Eager MNIST (test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist_eager.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2566,11 +1877,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test Keras MNIST (test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Test Keras MNIST (test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/examples/keras_mnist_advanced.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2580,11 +1891,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST 
(test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2594,11 +1905,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test MXNet MNIST (test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':muscle: Test MXNet MNIST (test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2608,11 +1919,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test Stall (test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':muscle: Test Stall (test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/test/test_stall.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2622,11 +1933,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Single Keras MNIST (test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Single Keras MNIST (test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2636,11 +1947,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Single PyTorch MNIST (test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2650,11 +1961,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Single MXNet MNIST 
(test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':muscle: Single MXNet MNIST (test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_mpi' > /mpirun_command && python /horovod/examples/mxnet_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2664,11 +1975,11 @@ steps: automatic: true agents: queue: cpu -- label: ':pytest: Run PyTests (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" +- label: ':pytest: Run PyTests (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2678,11 +1989,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Test TensorFlow MNIST (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2692,11 +2003,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test TensorFlow Eager MNIST (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Test TensorFlow Eager MNIST (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/examples/tensorflow_mnist_eager.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: 
docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2706,11 +2017,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Test Keras MNIST (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Test Keras MNIST (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/examples/keras_mnist_advanced.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2720,11 +2031,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Test PyTorch MNIST (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2734,11 +2045,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test MXNet MNIST (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':muscle: Test MXNet MNIST (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2748,11 +2059,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Test Stall (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':muscle: Test Stall (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && \$(cat /mpirun_command) python /horovod/test/test_stall.py" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2762,11 +2073,11 @@ steps: automatic: true agents: queue: cpu -- label: ':tensorflow: Single Keras MNIST (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':tensorflow: Single Keras MNIST (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && python /horovod/examples/keras_mnist_advanced.py --epochs 3 --batch-size 64" plugins: - docker-compose#v2.6.0: - run: 
test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2776,11 +2087,11 @@ steps: automatic: true agents: queue: cpu -- label: ':python: Single PyTorch MNIST (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Single PyTorch MNIST (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && python /horovod/examples/pytorch_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2790,11 +2101,11 @@ steps: automatic: true agents: queue: cpu -- label: ':muscle: Single MXNet MNIST (test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':muscle: Single MXNet MNIST (test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "\$(cat /oneccl_env) && echo '/mpirun_command_ofi' > /mpirun_command && python /horovod/examples/mxnet_mnist.py --epochs 3" plugins: - docker-compose#v2.6.0: - run: test-cpu-oneccl-ofi-py3_6-tf1_14_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + run: test-cpu-oneccl-ofi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 config: docker-compose.test.yml pull-retries: 3 - ecr#v1.2.0: @@ -2806,7 +2117,7 @@ steps: queue: cpu - wait - label: ':pytest: Run PyTests (test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 @@ -2820,7 +2131,7 @@ steps: agents: queue: 4x-gpu-g4 - label: ':pytest: Run PyTests (test-gpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 
--gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-gpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 @@ -2834,7 +2145,7 @@ steps: agents: queue: 4x-gpu-g4 - label: ':pytest: Run PyTests (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c "cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 @@ -2848,7 +2159,7 @@ steps: agents: queue: 4x-gpu-g4 - label: ':pytest: Run PyTests (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 @@ -2862,7 +2173,7 @@ steps: agents: queue: 4x-gpu-g4 - label: ':pytest: Run PyTests (test-gpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-gpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 @@ -2876,7 +2187,7 @@ steps: agents: queue: 4x-gpu-g4 - label: ':pytest: Run PyTests 
(test-gpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-gpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0 @@ -2890,7 +2201,7 @@ steps: agents: queue: 4x-gpu-g4 - label: ':pytest: Run PyTests (test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' - command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" + command: bash -c " cd /horovod/test && (echo test_*.py | sed 's/[a-z_]*tensorflow2[a-z_.]*//g' | sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/test_spark.py//g' | sed 's/test_run.py//g' | xargs -n 1 \$(cat /mpirun_command) pytest -v --capture=no) && pytest --forked -v --capture=no test_spark.py test_run.py" plugins: - docker-compose#v2.6.0: run: test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 @@ -2960,7 +2271,7 @@ steps: automatic: true agents: queue: 2x-gpu-g4 -- label: ':python: Test PyTorch MNIST (test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -2988,6 +2299,48 @@ steps: automatic: true agents: queue: 2x-gpu-g4 +- label: ':muscle: Test Stall (test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c " \$(cat /mpirun_command) python /horovod/test/test_stall.py" + plugins: + - docker-compose#v2.6.0: + run: test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 +- label: ':terminal: Test Horovodrun (test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' + command: horovodrun -np 2 -H localhost:2 python /horovod/examples/tensorflow_mnist.py + plugins: + - docker-compose#v2.6.0: + run: test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 +- label: ':terminal: Test Horovodrun 
(test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' + command: echo 'localhost slots=2' > hostfile + plugins: + - docker-compose#v2.6.0: + run: test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 - label: ':spark: Spark Keras Rossmann Run (test-gpu-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_run.py --num-proc 2 --data-dir file:///data --epochs 3 --sample-rate 0.01" plugins: @@ -3072,7 +2425,7 @@ steps: automatic: true agents: queue: 2x-gpu-g4 -- label: ':python: Test PyTorch MNIST (test-gpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-gpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py plugins: - docker-compose#v2.6.0: @@ -3100,6 +2453,20 @@ steps: automatic: true agents: queue: 2x-gpu-g4 +- label: ':factory: Elastic Tests (test-gpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c "cd /horovod/test/integration && pytest -v --log-cli-level 10 --capture=no test_elastic_torch.py test_elastic_tensorflow.py" + plugins: + - docker-compose#v2.6.0: + run: test-gpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 - label: ':spark: Spark Keras Rossmann Run (test-gpu-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_run.py --num-proc 2 --data-dir file:///data --epochs 3 --sample-rate 0.01" plugins: @@ -3184,7 +2551,7 @@ steps: automatic: true agents: queue: 2x-gpu-g4 -- label: ':python: Test PyTorch MNIST (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' command: horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py plugins: - docker-compose#v2.6.0: @@ -3212,6 +2579,20 @@ steps: automatic: true agents: queue: 2x-gpu-g4 +- label: ':factory: Elastic Tests (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c "cd /horovod/test/integration && pytest -v --log-cli-level 10 --capture=no test_elastic_torch.py test_elastic_tensorflow.py" + plugins: + - docker-compose#v2.6.0: + run: test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 - label: ':jupyter: Run PyTests test_interactiverun (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c "cd /horovod/test && pytest -v --capture=no test_interactiverun.py" plugins: @@ -3268,7 +2649,7 @@ steps: automatic: true agents: queue: 2x-gpu-g4 -- label: ':python: Test PyTorch MNIST (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' +- label: ':fire: Test PyTorch 
MNIST (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -3296,6 +2677,48 @@ steps: automatic: true agents: queue: 2x-gpu-g4 +- label: ':muscle: Test Stall (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' + command: bash -c " \$(cat /mpirun_command) python /horovod/test/test_stall.py" + plugins: + - docker-compose#v2.6.0: + run: test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 +- label: ':terminal: Test Horovodrun (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' + command: horovodrun -np 2 -H localhost:2 python /horovod/examples/tensorflow_mnist.py + plugins: + - docker-compose#v2.6.0: + run: test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 +- label: ':terminal: Test Horovodrun (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' + command: echo 'localhost slots=2' > hostfile + plugins: + - docker-compose#v2.6.0: + run: test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 - label: ':spark: Spark Keras Rossmann Run (test-gpu-openmpi-gloo-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_4_1-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_run.py --num-proc 2 --data-dir file:///data --epochs 3 --sample-rate 0.01" plugins: @@ -3366,7 +2789,7 @@ steps: automatic: true agents: queue: 2x-gpu-g4 -- label: ':python: Test PyTorch MNIST (test-gpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-gpu-openmpi-py3_6-tf2_0_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -3450,7 +2873,7 @@ steps: automatic: true agents: queue: 2x-gpu-g4 -- label: ':python: Test PyTorch MNIST (test-gpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-gpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -3576,7 +2999,7 @@ steps: automatic: true agents: queue: 2x-gpu-g4 -- label: ':python: Test PyTorch MNIST (test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' +- label: ':fire: Test PyTorch MNIST (test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c " \$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py" plugins: - docker-compose#v2.6.0: @@ -3604,6 +3027,48 @@ steps: automatic: true agents: queue: 2x-gpu-g4 +- label: ':muscle: Test Stall (test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' + command: bash -c " \$(cat /mpirun_command) python 
/horovod/test/test_stall.py" + plugins: + - docker-compose#v2.6.0: + run: test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 +- label: ':terminal: Test Horovodrun (test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' + command: horovodrun -np 2 -H localhost:2 python /horovod/examples/tensorflow_mnist.py + plugins: + - docker-compose#v2.6.0: + run: test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 +- label: ':terminal: Test Horovodrun (test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' + command: echo 'localhost slots=2' > hostfile + plugins: + - docker-compose#v2.6.0: + run: test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0 + config: docker-compose.test.yml + pull-retries: 3 + - ecr#v1.2.0: + login: true + timeout_in_minutes: 10 + retry: + automatic: true + agents: + queue: 2x-gpu-g4 - label: ':spark: Spark Keras Rossmann Run (test-mixed-openmpi-py3_6-tf1_15_0-keras2_3_1-torch1_3_0-mxnet1_5_0-pyspark2_4_0)' command: bash -c "OMP_NUM_THREADS=1 python /horovod/examples/keras_spark_rossmann_run.py --num-proc 2 --data-dir file:///data --epochs 3 --sample-rate 0.01" plugins: diff --git a/test/integration/data/elastic_tensorflow2_main.py b/test/integration/data/elastic_tensorflow2_main.py new file mode 100644 index 0000000000..6d6d4a9009 --- /dev/null +++ b/test/integration/data/elastic_tensorflow2_main.py @@ -0,0 +1,155 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import + +import argparse +import json +import os +import psutil +import time + +import tensorflow as tf + +import horovod.tensorflow as hvd + +parser = argparse.ArgumentParser(description='TensorFlow 2 Elastic Test', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +parser.add_argument('--batches-per-epoch', type=int, default=10, + help='number of batches per epoch') +parser.add_argument('--batches-per-commit', type=int, default=1, + help='number of batches per commit of the elastic state object') +parser.add_argument('--epochs', type=int, default=3, + help='number of epochs') +parser.add_argument('--logfile', default='/tmp/logfile.txt', + help='log file to record results (one line per epoch)') +parser.add_argument('--discovery-schedule', default='[]', + help='JSON string specifying schedule of host updates each epoch') +parser.add_argument('--exit-schedule', + help='JSON string mapping from (epoch, batch) to list of ranks to exit at that time') +parser.add_argument('--exit-mode', default='exception', + help='means used to cause a worker to exit [exception | kill]') + +args = parser.parse_args() + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +hvd.init() + +batch_size = 32 +data = tf.random.uniform([batch_size, 2]) +indices = tf.random.uniform([batch_size], minval=0, maxval=2, dtype=tf.int64) +target = tf.one_hot(indices, 2) + +lr = 0.001 +model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation='softmax')]) +optimizer = tf.optimizers.SGD(lr * hvd.size()) + +hostname = os.environ.get('HOROVOD_HOSTNAME') +start_rank = int(os.environ.get('HOROVOD_RANK', 0)) + +discovery_schedule = json.loads(args.discovery_schedule) +epoch_to_hosts = {epoch: hosts for epoch, hosts in discovery_schedule if epoch is not None} +default_hosts = discovery_schedule[-1][1] if len(discovery_schedule) > 0 else [] + +exit_schedule = json.loads(args.exit_schedule) if args.exit_schedule else {} + + +def check_exit(epoch, batch): + key = str((epoch, batch)) + if key in exit_schedule: + ranks_to_exit = exit_schedule[key] + if start_rank in ranks_to_exit: + if args.exit_mode == 'exception': + raise RuntimeError('check_rank and exit epoch={} batch={} start_rank={} rank={}' + .format(epoch, batch, start_rank, hvd.rank())) + else: + psutil.Process(os.getpid()).kill() + + +def log_state(state): + state_dict = { + 'epoch': state.epoch, + 'batch': state.batch, + 'commits': state.commits, + 'hostname': hostname, + 'start_rank': start_rank, + 'rank': hvd.rank(), + 'size': hvd.size(), + 'rendezvous': state.rendezvous} + with open(args.logfile, 'a') as f: + f.write(json.dumps(state_dict) + os.linesep) + + +@tf.function +def step(allreduce=True): + # Horovod: use DistributedGradientTape + with tf.GradientTape() as tape: + probs = model(data, training=True) + loss = tf.losses.categorical_crossentropy(target, probs) + + # Horovod: add Horovod Distributed GradientTape. 
+ if allreduce: + tape = hvd.DistributedGradientTape(tape) + + gradients = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(gradients, model.trainable_variables)) + + +step(allreduce=False) + + +@hvd.elastic.run +def train(state): + state.rendezvous += 1 + while state.epoch < args.epochs: + print('epoch {} batch {}'.format(state.epoch, state.batch)) + + while state.batch < args.batches_per_epoch: + check_exit(state.epoch, state.batch) + step() + + state.batch += 1 + if state.batch % args.batches_per_commit == 0: + state.commits += 1 + state.commit() + + if hvd.rank() == 0: + log_state(state) + + current_hosts = epoch_to_hosts.get(state.epoch, default_hosts) + next_hosts = epoch_to_hosts.get(state.epoch + 1, default_hosts) + if current_hosts != next_hosts: + print('host changes: {} -> {}'.format(current_hosts, next_hosts)) + start = int(time.time()) + while state._host_messages.empty(): + if int(time.time()) - start > 3: + raise TimeoutError('Timed out waiting for notifications from driver.') + time.sleep(0.1) + + state.epoch += 1 + state.batch = 0 + state.commits += 1 + state.commit() + + +def on_state_reset(): + optimizer.lr.assign(lr * hvd.size()) + + +state = hvd.elastic.TensorFlowKerasState(model, optimizer, batch=0, epoch=0, commits=0, rendezvous=0) +state.register_reset_callbacks([on_state_reset]) +train(state) diff --git a/test/integration/data/elastic_tensorflow_main.py b/test/integration/data/elastic_tensorflow_main.py new file mode 100644 index 0000000000..e1c6deedc5 --- /dev/null +++ b/test/integration/data/elastic_tensorflow_main.py @@ -0,0 +1,148 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import + +import argparse +import json +import os +import psutil +import time + +import tensorflow as tf + +import horovod.tensorflow as hvd + +parser = argparse.ArgumentParser(description='TensorFlow Elastic Test', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +parser.add_argument('--batches-per-epoch', type=int, default=10, + help='number of batches per epoch') +parser.add_argument('--batches-per-commit', type=int, default=1, + help='number of batches per commit of the elastic state object') +parser.add_argument('--epochs', type=int, default=3, + help='number of epochs') +parser.add_argument('--logfile', default='/tmp/logfile.txt', + help='log file to record results (one line per epoch)') +parser.add_argument('--discovery-schedule', default='[]', + help='JSON string specifying schedule of host updates each epoch') +parser.add_argument('--exit-schedule', + help='JSON string mapping from (epoch, batch) to list of ranks to exit at that time') +parser.add_argument('--exit-mode', default='exception', + help='means used to cause a worker to exit [exception | kill]') + +args = parser.parse_args() + +config = tf.ConfigProto() +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +config.gpu_options.allow_growth = False +config.gpu_options.visible_device_list = '' + +hvd.init() + +base_lr = 0.01 +lr = tf.Variable(base_lr * hvd.size()) +model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation='softmax')]) +optimizer = tf.train.GradientDescentOptimizer(lr) +optimizer = hvd.DistributedOptimizer(optimizer) + +batch_size = 32 +data = tf.random_uniform([batch_size, 2]) +target = tf.random_uniform([batch_size, 1], minval=0, maxval=2, dtype=tf.int64) + +probs = tf.layers.dense(data, 2, activation=None) +loss = tf.losses.sparse_softmax_cross_entropy(target, probs) + +hostname = os.environ.get('HOROVOD_HOSTNAME') +start_rank = int(os.environ.get('HOROVOD_RANK', 0)) + +discovery_schedule = json.loads(args.discovery_schedule) +epoch_to_hosts = {epoch: hosts for epoch, hosts in discovery_schedule if epoch is not None} +default_hosts = discovery_schedule[-1][1] if len(discovery_schedule) > 0 else [] + +exit_schedule = json.loads(args.exit_schedule) if args.exit_schedule else {} + + +def check_exit(epoch, batch): + key = str((epoch, batch)) + if key in exit_schedule: + ranks_to_exit = exit_schedule[key] + if start_rank in ranks_to_exit: + if args.exit_mode == 'exception': + raise RuntimeError('check_rank and exit epoch={} batch={} start_rank={} rank={}' + .format(epoch, batch, start_rank, hvd.rank())) + else: + psutil.Process(os.getpid()).kill() + + +def log_state(state): + state_dict = { + 'epoch': state.epoch, + 'batch': state.batch, + 'commits': state.commits, + 'hostname': hostname, + 'start_rank': start_rank, + 'rank': hvd.rank(), + 'size': hvd.size(), + 'rendezvous': state.rendezvous} + with open(args.logfile, 'a') as f: + f.write(json.dumps(state_dict) + os.linesep) + + +@hvd.elastic.run +def train(state, step): + state.rendezvous += 1 + while state.epoch < args.epochs: + print('epoch {} batch {}'.format(state.epoch, state.batch)) + + while state.batch < args.batches_per_epoch: + check_exit(state.epoch, state.batch) + step() + + state.batch += 1 + if state.batch % args.batches_per_commit == 0: + state.commits += 1 + state.commit() + + if hvd.rank() == 0: + log_state(state) + + current_hosts = epoch_to_hosts.get(state.epoch, default_hosts) + next_hosts = 
epoch_to_hosts.get(state.epoch + 1, default_hosts) + if current_hosts != next_hosts: + print('host changes: {} -> {}'.format(current_hosts, next_hosts)) + start = int(time.time()) + while state._host_messages.empty(): + if int(time.time()) - start > 3: + raise TimeoutError('Timed out waiting for notifications from driver.') + time.sleep(0.1) + + state.epoch += 1 + state.batch = 0 + state.commits += 1 + state.commit() + + +with tf.Session(config=config) as session: + session.run(tf.global_variables_initializer()) + + def on_state_reset(): + lr.load(base_lr * hvd.size(), session) + + state = hvd.elastic.TensorFlowState(session=session, batch=0, epoch=0, commits=0, rendezvous=0) + state.register_reset_callbacks([on_state_reset]) + + train_opt = optimizer.minimize(loss) + train(state, lambda: session.run(train_opt)) diff --git a/test/integration/data/elastic_torch_main.py b/test/integration/data/elastic_torch_main.py new file mode 100644 index 0000000000..51a0ffea83 --- /dev/null +++ b/test/integration/data/elastic_torch_main.py @@ -0,0 +1,142 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import + +import argparse +import json +import os +import psutil +import time + +import torch +import torch.nn.functional as F + +import horovod.torch as hvd + +parser = argparse.ArgumentParser(description='PyTorch Elastic Test', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +parser.add_argument('--batches-per-epoch', type=int, default=10, + help='number of batches per epoch') +parser.add_argument('--batches-per-commit', type=int, default=1, + help='number of batches per commit of the elastic state object') +parser.add_argument('--epochs', type=int, default=3, + help='number of epochs') +parser.add_argument('--logfile', default='/tmp/logfile.txt', + help='log file to record results (one line per epoch)') +parser.add_argument('--discovery-schedule', default='[]', + help='JSON string specifying schedule of host updates each epoch') +parser.add_argument('--exit-schedule', + help='JSON string mapping from (epoch, batch) to list of ranks to exit at that time') +parser.add_argument('--exit-mode', default='exception', + help='means used to cause a worker to exit [exception | kill]') + +args = parser.parse_args() + +hvd.init() + +batch_size = 32 +data = torch.randn(batch_size, 2) +target = torch.LongTensor(batch_size).random_() % 2 + +lr = 0.001 +model = torch.nn.Sequential(torch.nn.Linear(2, 2)) +optimizer = torch.optim.SGD(model.parameters(), lr=lr * hvd.size()) +optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) + +hostname = os.environ.get('HOROVOD_HOSTNAME') +start_rank = int(os.environ.get('HOROVOD_RANK', 0)) + +discovery_schedule = json.loads(args.discovery_schedule) +epoch_to_hosts = {epoch: hosts for epoch, hosts in discovery_schedule if epoch is not None} +default_hosts = 
discovery_schedule[-1][1] if len(discovery_schedule) > 0 else [] + +exit_schedule = json.loads(args.exit_schedule) if args.exit_schedule else {} + + +def check_exit(epoch, batch): + key = str((epoch, batch)) + if key in exit_schedule: + ranks_to_exit = exit_schedule[key] + if start_rank in ranks_to_exit: + if args.exit_mode == 'exception': + raise RuntimeError('check_rank and exit epoch={} batch={} start_rank={} rank={}' + .format(epoch, batch, start_rank, hvd.rank())) + else: + psutil.Process(os.getpid()).kill() + + +def log_state(state): + state_dict = { + 'epoch': state.epoch, + 'batch': state.batch, + 'commits': state.commits, + 'hostname': hostname, + 'start_rank': start_rank, + 'rank': hvd.rank(), + 'size': hvd.size(), + 'rendezvous': state.rendezvous} + with open(args.logfile, 'a') as f: + f.write(json.dumps(state_dict) + os.linesep) + + +@hvd.elastic.run +def train(state): + state.rendezvous += 1 + while state.epoch < args.epochs: + print('epoch {} batch {}'.format(state.epoch, state.batch)) + + while state.batch < args.batches_per_epoch: + check_exit(state.epoch, state.batch) + + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + optimizer.step() + + state.batch += 1 + if state.batch % args.batches_per_commit == 0: + state.commits += 1 + state.commit() + + if hvd.rank() == 0: + log_state(state) + + current_hosts = epoch_to_hosts.get(state.epoch, default_hosts) + next_hosts = epoch_to_hosts.get(state.epoch + 1, default_hosts) + if current_hosts != next_hosts: + print('host changes: {} -> {}'.format(current_hosts, next_hosts)) + start = int(time.time()) + while state._host_messages.empty(): + if int(time.time()) - start > 3: + raise TimeoutError('Timed out waiting for notifications from driver.') + time.sleep(0.1) + + state.epoch += 1 + state.batch = 0 + state.commits += 1 + state.commit() + + +def on_state_reset(): + for param_group in optimizer.param_groups: + param_group['lr'] = lr * hvd.size() + + +state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0, commits=0, rendezvous=0) +state.register_reset_callbacks([on_state_reset]) +train(state) diff --git a/test/integration/elastic_common.py b/test/integration/elastic_common.py new file mode 100644 index 0000000000..9b4daaceb8 --- /dev/null +++ b/test/integration/elastic_common.py @@ -0,0 +1,238 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import + +import contextlib +import json +import os +import sys + +import mock +import pytest + +from horovod.run.common.util import config_parser +from horovod.run.runner import parse_args, _run_elastic + +sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir)) + +from common import override_args, temppath + + +DISCOVERY_SCRIPT_TEMPLATE = """#!/bin/bash +epoch=0 +if [ -f {logfile} ]; then + epoch=$(< {logfile} wc -l | tr -d '[:space:]') +fi +""" + + +def _get_discovery_lines(schedule_step, start, end): + epoch, hosts = schedule_step + hosts_str = os.linesep.join(['echo "{}"'.format(host) for host in hosts]) + if start and end: + return hosts_str + os.linesep + if start: + return 'if [ "$epoch" == "{}" ]; then'.format(epoch) + os.linesep + hosts_str + os.linesep + elif not start and not end: + return 'elif [ "$epoch" == "{}" ]; then'.format(epoch) + os.linesep + hosts_str + os.linesep + else: + return 'else' + os.linesep + hosts_str + os.linesep + 'fi' + os.linesep + + +@contextlib.contextmanager +def _temp_discovery_script(logfile, discovery_schedule): + with temppath() as discovery_script: + with open(discovery_script, 'w') as f: + f.write(DISCOVERY_SCRIPT_TEMPLATE.format(logfile=logfile) + os.linesep) + for i, schedule_step in enumerate(discovery_schedule): + f.write(_get_discovery_lines(schedule_step, + start=i == 0, + end=i == len(discovery_schedule) - 1)) + os.chmod(discovery_script, 0o755) + yield discovery_script + + +class BaseElasticTests(object): + def __init__(self, training_script, *args, **kwargs): + self._training_script = training_script + super(BaseElasticTests, self).__init__(*args, **kwargs) + + def _run(self, discovery_schedule, exit_schedule=None, np=2, min_np=2, max_np=4, hosts=None, exit_mode='exception'): + with temppath() as logfile: + with _temp_discovery_script(logfile, discovery_schedule) as discovery_script: + command_args = ['horovodrun', + '-np', str(np), + '--min-np', str(min_np), + '--log-level', 'DEBUG'] + if hosts is not None: + command_args += ['-H', hosts] + else: + command_args += ['--host-discovery-script', discovery_script, + '--max-np', str(max_np)] + command_args += ['python', self._training_script, + '--logfile', logfile, + '--discovery-schedule', json.dumps(discovery_schedule), + '--exit-schedule', json.dumps(exit_schedule or {}), + '--exit-mode', exit_mode] + print(' '.join(command_args)) + + with override_args(*command_args): + args = parse_args() + env = {} + config_parser.set_env_from_args(env, args) + _run_elastic(args) + + with open(logfile, 'r') as f: + lines = f.readlines() + + print('logfile:') + for line in lines: + print(line) + + return [json.loads(line) for line in lines] + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + @mock.patch('horovod.run.gloo_run._get_min_start_hosts', return_value=1) + def test_hosts_added_and_removed(self, mock_get_min_start_hosts): + for slots, np, min_np, max_np in [(2, 2, 2, 4), (1, 1, 1, 2)]: + discovery_schedule = [ + (0, ['localhost:{}'.format(slots)]), + (1, ['localhost:{}'.format(slots), '127.0.0.1:{}'.format(slots)]), + (None, ['127.0.0.1:{}'.format(slots)]), + ] + + results = self._run(discovery_schedule, np=np, min_np=min_np, max_np=max_np) + for result in results: + print(result) + + assert len(results) == 3 + + assert results[0]['start_rank'] == 0 + assert results[0]['size'] == slots + assert results[0]['hostname'] == 'localhost' + + 
assert results[1]['start_rank'] == 0 + assert results[1]['size'] == slots * 2 + assert results[1]['hostname'] == 'localhost' + + assert results[2]['start_rank'] == slots + assert results[2]['size'] == slots + assert results[2]['hostname'] == '127.0.0.1' + assert results[2]['rendezvous'] == 3 + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + @mock.patch('horovod.run.gloo_run._get_min_start_hosts', return_value=1) + def test_single_rank_failure(self, mock_get_min_start_hosts): + for exit_mode in ['exception', 'kill']: + discovery_schedule = [ + (None, ['localhost:2', '127.0.0.1:2']), + ] + + exit_schedule = { + str((1, 0)): [0], + } + + results = self._run(discovery_schedule, exit_schedule=exit_schedule, exit_mode=exit_mode) + + assert len(results) == 3 + + assert results[0]['start_rank'] == 0 + assert results[0]['size'] == 4 + assert results[0]['rendezvous'] == 1 + + assert results[1]['start_rank'] == 2 + assert results[1]['size'] == 2 + assert results[1]['rendezvous'] == 2 + + assert results[2]['start_rank'] == 2 + assert results[2]['size'] == 2 + assert results[2]['rendezvous'] == 2 + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + @mock.patch('horovod.run.gloo_run._get_min_start_hosts', return_value=1) + def test_fault_tolerance_without_scaling(self, mock_get_min_start_hosts): + for exit_mode in ['exception', 'kill']: + discovery_schedule = [ + (None, ['localhost:2', '127.0.0.1:2']), + ] + + hosts = 'localhost:2,127.0.0.1:2' + + exit_schedule = { + str((1, 0)): [0], + } + + results = self._run(discovery_schedule, hosts=hosts, exit_schedule=exit_schedule, exit_mode=exit_mode) + + assert len(results) == 3 + + assert results[0]['start_rank'] == 0 + assert results[0]['size'] == 4 + assert results[0]['rendezvous'] == 1 + + assert results[1]['start_rank'] == 2 + assert results[1]['size'] == 2 + assert results[1]['rendezvous'] == 2 + + assert results[2]['start_rank'] == 2 + assert results[2]['size'] == 2 + assert results[2]['rendezvous'] == 2 + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + @mock.patch('horovod.run.gloo_run._get_min_start_hosts', return_value=1) + def test_all_ranks_failure(self, mock_get_min_start_hosts): + discovery_schedule = [ + (None, ['localhost:2', '127.0.0.1:2']), + ] + + exit_schedule = { + str((1, 0)): [0, 1, 2, 3], + } + + message = 'Horovod detected that one or more processes exited with non-zero status' + with pytest.raises(RuntimeError, match=message): + self._run(discovery_schedule, exit_schedule=exit_schedule) + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + @mock.patch('horovod.run.gloo_run._get_min_start_hosts', return_value=1) + def test_all_hosts_blacklisted(self, mock_get_min_start_hosts): + discovery_schedule = [ + (None, ['localhost:2', '127.0.0.1:2']), + ] + + exit_schedule = { + str((1, 0)): [0, 2], + } + + message = 'Horovod detected that one or more processes exited with non-zero status' + with pytest.raises(RuntimeError, match=message): + self._run(discovery_schedule, exit_schedule=exit_schedule) + + @mock.patch('horovod.run.elastic.driver.ELASTIC_TIMEOUT_SECS', 1) + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + @mock.patch('horovod.run.gloo_run._get_min_start_hosts', return_value=1) + def test_min_hosts_timeout(self, mock_get_min_start_hosts): + discovery_schedule = [ + (None, ['localhost:2', '127.0.0.1:2']), + ] + + exit_schedule = { + str((1, 0)): [0], + } + + message = 'Horovod 
detected that one or more processes exited with non-zero status' + with pytest.raises(RuntimeError, match=message): + self._run(discovery_schedule, exit_schedule=exit_schedule, np=4, min_np=4) diff --git a/test/integration/test_elastic_tensorflow.py b/test/integration/test_elastic_tensorflow.py new file mode 100644 index 0000000000..483bd9a9fd --- /dev/null +++ b/test/integration/test_elastic_tensorflow.py @@ -0,0 +1,31 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest +import warnings + +from elastic_common import BaseElasticTests + + +class ElasticTensorFlowTests(BaseElasticTests, unittest.TestCase): + def __init__(self, *args, **kwargs): + training_script = os.path.join(os.path.dirname(__file__), 'data/elastic_tensorflow_main.py') + super(ElasticTensorFlowTests, self).__init__(training_script, *args, **kwargs) + warnings.simplefilter('module') diff --git a/test/integration/test_elastic_tensorflow2.py b/test/integration/test_elastic_tensorflow2.py new file mode 100644 index 0000000000..2988c01e36 --- /dev/null +++ b/test/integration/test_elastic_tensorflow2.py @@ -0,0 +1,31 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest +import warnings + +from elastic_common import BaseElasticTests + + +class ElasticTensorFlow2Tests(BaseElasticTests, unittest.TestCase): + def __init__(self, *args, **kwargs): + training_script = os.path.join(os.path.dirname(__file__), 'data/elastic_tensorflow2_main.py') + super(ElasticTensorFlow2Tests, self).__init__(training_script, *args, **kwargs) + warnings.simplefilter('module') diff --git a/test/integration/test_elastic_torch.py b/test/integration/test_elastic_torch.py new file mode 100644 index 0000000000..c56273064c --- /dev/null +++ b/test/integration/test_elastic_torch.py @@ -0,0 +1,31 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest +import warnings + +from elastic_common import BaseElasticTests + + +class ElasticTorchTests(BaseElasticTests, unittest.TestCase): + def __init__(self, *args, **kwargs): + training_script = os.path.join(os.path.dirname(__file__), 'data/elastic_torch_main.py') + super(ElasticTorchTests, self).__init__(training_script, *args, **kwargs) + warnings.simplefilter('module') diff --git a/test/test_elastic_driver.py b/test/test_elastic_driver.py new file mode 100644 index 0000000000..06054762b5 --- /dev/null +++ b/test/test_elastic_driver.py @@ -0,0 +1,492 @@ +# Copyright 2020 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import unittest +import warnings + +import mock +import pytest + +from horovod.run.util import network +from horovod.run.elastic.discovery import FixedHosts, HostManager +from horovod.run.elastic.driver import ElasticDriver +from horovod.run.elastic.rendezvous import create_rendezvous_handler +from horovod.run.elastic.worker import WorkerNotificationManager +from horovod.run.http.http_server import RendezvousServer + + +def wait_for_one(events): + while True: + for event in events: + if event.is_set(): + return + time.sleep(0.01) + + +def sequence(lst): + for v in lst: + yield v + while True: + yield lst[-1] + + +class ElasticDriverTests(unittest.TestCase): + """ + Tests for async processing logic in horovod.elastic. 
+ """ + + def __init__(self, *args, **kwargs): + super(ElasticDriverTests, self).__init__(*args, **kwargs) + warnings.simplefilter('module') + + def test_rank_and_size(self): + """Tests two hosts, two slots each in standard happy path.""" + slots = {'host-1': 2, 'host-2': 2} + discovery = FixedHosts(slots) + + driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4) + driver.wait_for_available_slots(min_np=2) + + rank_results = {} + + def exec_command(slot_info, events): + driver.record_ready(slot_info.hostname, slot_info.local_rank) + updated_slot_info = driver.get_slot_info(slot_info.hostname, slot_info.local_rank) + rank_results[slot_info.rank] = (slot_info, updated_slot_info) + return 0, time.time() + + driver.start(np=2, create_worker_fn=exec_command) + res = driver.get_results() + driver.stop() + + assert len(res) == 4 + for name, (exit_code, timestamp) in res.items(): + assert exit_code == 0, name + + assert len(rank_results) == 4 + for rank, (slot_info, updated_slot_info) in rank_results.items(): + assert slot_info.to_response_string() == updated_slot_info.to_response_string(), rank + + def test_rank_and_size_with_host_failure(self): + """Tests two hosts, two slots each with second host failing before rendezvous completes.""" + slots = {'host-1': 2, 'host-2': 2} + discovery = FixedHosts(slots) + + driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4) + driver.wait_for_available_slots(min_np=2) + + rank_results = {} + + def exec_command(slot_info, events): + if slot_info.hostname == 'host-2': + return 1, time.time() + + driver.record_ready(slot_info.hostname, slot_info.local_rank) + updated_slot_info = driver.get_slot_info(slot_info.hostname, slot_info.local_rank) + rank_results[slot_info.rank] = (slot_info, updated_slot_info) + return 0, time.time() + + driver.start(np=2, create_worker_fn=exec_command) + res = driver.get_results() + driver.stop() + + assert len(res) == 2 + for name, (exit_code, timestamp) in res.items(): + assert exit_code == 0, name + + assert len(rank_results) == 2 + for rank, (slot_info, updated_slot_info) in rank_results.items(): + assert updated_slot_info.size == 2, rank + assert updated_slot_info.rank == slot_info.rank % 2, rank + assert updated_slot_info.local_size == slot_info.local_size, rank + assert updated_slot_info.local_rank == slot_info.local_rank, rank + assert updated_slot_info.cross_size == 1, rank + assert updated_slot_info.cross_rank == 0, rank + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + def test_rank_and_size_with_host_added(self): + """Tests training starts with one host two slots, then a second host is added.""" + slots = {'host-1': 2} + discovery = FixedHosts(slots) + + def add_host(): + slots = {'host-1': 2, 'host-2': 2} + discovery.set(slots) + + driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4) + driver.wait_for_available_slots(min_np=2) + + rank_results = {} + + def exec_command(slot_info, events): + driver.record_ready(slot_info.hostname, slot_info.local_rank) + + if slot_info.hostname == 'host-1': + if slot_info.rank == 0: + add_host() + driver.wait_for_available_slots(4) + driver.record_ready(slot_info.hostname, slot_info.local_rank) + + driver.record_ready(slot_info.hostname, slot_info.local_rank) + updated_slot_info = driver.get_slot_info(slot_info.hostname, slot_info.local_rank) + rank_results[slot_info.rank] = (slot_info, updated_slot_info) + return 0, time.time() + + driver.start(np=2, create_worker_fn=exec_command) + res = driver.get_results() + 
driver.stop() + + assert len(res) == 4 + for name, (exit_code, timestamp) in res.items(): + assert exit_code == 0, name + + assert len(rank_results) == 4 + for rank, (slot_info, updated_slot_info) in rank_results.items(): + assert updated_slot_info.size == 4, rank + assert updated_slot_info.rank == slot_info.rank, rank + assert updated_slot_info.local_size == slot_info.local_size, rank + assert updated_slot_info.local_rank == slot_info.local_rank, rank + assert updated_slot_info.cross_size == 2, rank + assert updated_slot_info.cross_rank == slot_info.cross_rank, rank + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + @mock.patch('horovod.run.elastic.driver.ElasticDriver.get_coordinator_info') + @mock.patch('horovod.run.elastic.driver.ElasticDriver.get_worker_client') + def test_wait_for_available_slots(self, mock_get_worker_client, mock_get_coordinator_info): + """Tests that driver blocks until the min number of slots are available.""" + slots = [{'host-1': 4}, + {'host-1': 4, 'host-2': 8}, + {'host-1': 4, 'host-2': 8, 'host-3': 4}] + mock_discovery = mock.Mock() + mock_discovery.find_available_hosts_and_slots.side_effect = sequence(slots) + + driver = ElasticDriver(mock.Mock(), mock_discovery, min_np=8, max_np=20) + driver.wait_for_available_slots(min_np=16) + assert driver._host_manager.current_hosts.count_available_slots() >= 16 + driver.stop() + + # Notify coordinator 2 times, as the first time we are below min_np and the existing host assignments + # are empty + assert mock_get_worker_client.call_count == 2 + assert mock_get_coordinator_info.call_count == 2 + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + def test_wait_for_min_hosts(self): + """Tests that driver blocks until the min number of hosts and slots are available.""" + slots = [{'host-1': 4}, + {'host-1': 4, 'host-2': 8}, + {'host-1': 4, 'host-2': 8, 'host-3': 4}] + mock_discovery = mock.Mock() + mock_discovery.find_available_hosts_and_slots.side_effect = sequence(slots) + + driver = ElasticDriver(mock.Mock(), mock_discovery, min_np=2, max_np=12) + driver.wait_for_available_slots(min_np=2, min_hosts=2) + + # Even though we only needed 2 slots, because we also needed 2 hosts, we will have at least 12 slots total + assert driver._host_manager.current_hosts.count_available_slots() >= 12 + driver.stop() + + def test_all_workers_fail(self): + """Tests that training fails when all workers fail.""" + slots = {'host-1': 2, 'host-2': 2} + discovery = FixedHosts(slots) + + driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4) + driver.wait_for_available_slots(min_np=2) + + def exec_command(slot_info, events): + driver.record_ready(slot_info.hostname, slot_info.local_rank) + return 1, time.time() + + driver.start(np=2, create_worker_fn=exec_command) + res = driver.get_results() + driver.stop() + + assert len(res) == 4 + for name, (exit_code, timestamp) in res.items(): + assert exit_code == 1, name + + def test_shutdown_on_success(self): + """Tests that shutdown event is triggered when one worker succeeds but the others are still working.""" + slots = {'host-1': 2, 'host-2': 2} + discovery = FixedHosts(slots) + + driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4) + driver.wait_for_available_slots(min_np=2) + + def exec_command(slot_info, events): + if slot_info.rank == 0: + return 0, time.time() + + driver.record_ready(slot_info.hostname, slot_info.local_rank) + wait_for_one(events) + return 1, time.time() + + driver.start(np=2, 
create_worker_fn=exec_command) + res = driver.get_results() + driver.stop() + + assert len(res) == 4 + + exit_code_sum = 0 + for name, (exit_code, timestamp) in res.items(): + exit_code_sum += exit_code + assert exit_code_sum == 3 + + def test_host_shutdown_on_worker_failure(self): + """Tests two hosts, two slots each with one process on second host failing, causing host shutdown.""" + slots = {'host-1': 2, 'host-2': 2} + discovery = FixedHosts(slots) + + driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4) + driver.wait_for_available_slots(min_np=2) + + rank_results = {} + + def exec_command(slot_info, events): + if slot_info.hostname == 'host-1': + if slot_info.local_rank == 0: + return 1, time.time() + + driver.record_ready(slot_info.hostname, slot_info.local_rank) + wait_for_one(events) + return 1, time.time() + + driver.record_ready(slot_info.hostname, slot_info.local_rank) + updated_slot_info = driver.get_slot_info(slot_info.hostname, slot_info.local_rank) + rank_results[slot_info.rank] = (slot_info, updated_slot_info) + return 0, time.time() + + driver.start(np=2, create_worker_fn=exec_command) + res = driver.get_results() + driver.stop() + + assert len(res) == 2 + for name, (exit_code, timestamp) in res.items(): + assert exit_code == 0, name + + assert len(rank_results) == 2 + for rank, (slot_info, updated_slot_info) in rank_results.items(): + assert updated_slot_info.size == 2, rank + assert updated_slot_info.rank == slot_info.rank % 2, rank + assert updated_slot_info.local_size == slot_info.local_size, rank + assert updated_slot_info.local_rank == slot_info.local_rank, rank + assert updated_slot_info.cross_size == 1, rank + assert updated_slot_info.cross_rank == 0, rank + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + def test_worker_notification_manager(self): + """Tests that host add events are sent to the worker notification service and consumed.""" + slots = {'host-1': 2} + discovery = FixedHosts(slots) + + rendezvous = RendezvousServer() + driver = ElasticDriver(rendezvous, discovery, min_np=2, max_np=4) + driver.wait_for_available_slots(min_np=2) + handler = create_rendezvous_handler(driver) + + common_intfs = network.get_local_intfs() + addr = network.get_driver_ip(common_intfs) + port = rendezvous.start_server(handler) + nic = list(common_intfs)[0] + + rank_results = {} + + class NotificationReceiver: + def __init__(self): + self.events = [] + + def on_hosts_updated(self, timestamp): + self.events.append(timestamp) + + def add_host(): + slots = {'host-1': 2, 'host-2': 2} + discovery.set(slots) + + def remove_host(): + slots = {'host-2': 2} + discovery.set(slots) + + def exec_command(slot_info, events): + manager = WorkerNotificationManager() + manager.init(rendezvous_addr=addr, + rendezvous_port=port, + nic=nic, + hostname=slot_info.hostname, + local_rank=slot_info.local_rank) + + notification_receiver = NotificationReceiver() + manager.register_listener(notification_receiver) + + driver.record_ready(slot_info.hostname, slot_info.local_rank) + + if slot_info.rank == 0: + add_host() + driver.wait_for_available_slots(4) + + if slot_info.rank == 0: + remove_host() + + # Busy wait for the number of available slots to decrease + while driver._host_manager.current_hosts.count_available_slots() > 2: + time.sleep(0.01) + + rank_results[slot_info.rank] = notification_receiver.events + return 0, time.time() + + driver.start(np=2, create_worker_fn=exec_command) + res = driver.get_results() + driver.stop() + + assert len(res) == 2 + 
for name, (exit_code, timestamp) in res.items(): + assert exit_code == 0, name + + assert len(rank_results) == 2 + for rank, timestamps in rank_results.items(): + expected = 2 if rank == 0 else 0 + assert len(timestamps) == expected, rank + + rendezvous.stop_server() + + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + @mock.patch('horovod.run.elastic.driver.ElasticDriver.host_assignments') + @mock.patch('horovod.run.elastic.driver.ElasticDriver.get_coordinator_info') + @mock.patch('horovod.run.elastic.driver.ElasticDriver.get_worker_client') + def test_send_notifications_without_assignments(self, mock_get_worker_client, mock_get_coordinator_info, + mock_host_assignments): + """Tests that notifications are still sent correctly even if host assignments cannot be generated.""" + slots = [{'host-1': 8, 'host-2': 4}, + {'host-1': 8, 'host-2': 4}, + {'host-2': 4}, + {'host-2': 4}, + {'host-2': 4, 'host-3': 12}] + discovery = mock.Mock() + discovery.find_available_hosts_and_slots.side_effect = sequence(slots) + + driver = ElasticDriver(mock.Mock(), discovery, min_np=8, max_np=12) + driver.wait_for_available_slots(min_np=16) + driver.stop() + + # On the second call, we should see the number of slots dip below the minimum, but we still want to ensure + # we notify workers every time there is a change, so in total we should observe 3 calls. + assert mock_get_worker_client.call_count == 3 + assert mock_get_coordinator_info.call_count == 3 + + def test_order_available_hosts(self): + """Tests the order is preserved for host assignment as available hosts are updated.""" + # This will be a set in practice, but use a list here to guarantee order. + available_hosts = ['a', 'b', 'c'] + ordered_hosts = [] + ordered_hosts = HostManager.order_available_hosts(available_hosts, ordered_hosts) + assert ordered_hosts == available_hosts + + # We remove a host, add a host, and change the order, but relative order should be preserved + available_hosts = ['d', 'c', 'b'] + ordered_hosts = HostManager.order_available_hosts(available_hosts, ordered_hosts) + assert ordered_hosts == ['b', 'c', 'd'] + + def test_update_available_hosts(self): + """Tests that the current hosts object is immutable, while fetching the latest is correctly updated.""" + mock_discovery = mock.Mock() + mock_discovery.find_available_hosts_and_slots.side_effect = [ + {'a': 2}, + {'a': 2, 'b': 2}, + {'b': 2} + ] + host_manager = HostManager(mock_discovery) + + # Should be empty initially + current_hosts = host_manager.current_hosts + assert current_hosts.available_hosts == set() + assert current_hosts.count_available_slots() == 0 + + host_manager.update_available_hosts() + + # First, check that nothing changed with our existing object, which is immutable + assert current_hosts.available_hosts == set() + assert current_hosts.count_available_slots() == 0 + + # Now verify that the new object has the correct sets + current_hosts = host_manager.current_hosts + assert current_hosts.available_hosts == {'a'} + assert current_hosts.count_available_slots() == 2 + + # Now check again + host_manager.update_available_hosts() + current_hosts = host_manager.current_hosts + assert current_hosts.available_hosts == {'a', 'b'} + assert current_hosts.count_available_slots() == 4 + + # And again + host_manager.update_available_hosts() + current_hosts = host_manager.current_hosts + assert current_hosts.available_hosts == {'b'} + assert current_hosts.count_available_slots() == 2 + + def test_blacklist_host(self): + """Tests the hosts are 
blacklisted, resulting in changes to the available hosts.""" + mock_discovery = mock.Mock() + mock_discovery.find_available_hosts_and_slots.return_value = {'a': 2, 'b': 2} + host_manager = HostManager(mock_discovery) + + host_manager.update_available_hosts() + + # Sanity check before we blacklist + current_hosts = host_manager.current_hosts + assert current_hosts.available_hosts == {'a', 'b'} + assert current_hosts.count_available_slots() == 4 + + # Now blacklist, our existing object should not change (immutable) + host_manager.blacklist('a') + assert current_hosts.available_hosts == {'a', 'b'} + assert current_hosts.count_available_slots() == 4 + + # Check the new object, make sure we've blacklisted the host + current_hosts = host_manager.current_hosts + assert current_hosts.available_hosts == {'b'} + assert current_hosts.count_available_slots() == 2 + + def test_shutdown_on_initial_discovery_failure(self): + """Tests that the driver will shutdown immediately if initial host discovery fails.""" + discovery = mock.Mock() + discovery.find_available_hosts_and_slots.side_effect = RuntimeError() + + discover_hosts = ElasticDriver._discover_hosts + + def wrapped_discover_hosts(obj): + try: + discover_hosts(obj) + except RuntimeError: + # Suppress the error message from the background discovery thread to clean up unit tests + pass + + try: + ElasticDriver._discover_hosts = wrapped_discover_hosts + driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4) + with pytest.raises(RuntimeError): + driver.wait_for_available_slots(min_np=2) + assert driver.finished() + finally: + ElasticDriver._discover_hosts = discover_hosts + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_interactiverun.py b/test/test_interactiverun.py index 282ccdf84d..5e96b5ad9f 100644 --- a/test/test_interactiverun.py +++ b/test/test_interactiverun.py @@ -17,17 +17,15 @@ from __future__ import division from __future__ import print_function -import os -import subprocess -import time -import torch -import traceback import unittest import warnings + import pytest +import torch import horovod.torch as hvd +from horovod.common.util import gloo_built, mpi_built from horovod.run import run @@ -38,7 +36,6 @@ def __init__(self, *args, **kwargs): warnings.simplefilter('module') def test_happy_run(self): - def fn(a, b, c, d): hvd.init() rank = hvd.rank() @@ -51,7 +48,14 @@ def fn(a, b, c, d): else: return None + assert gloo_built() or mpi_built() for use_gloo, use_mpi in [(True, False), (False, True)]: + if use_mpi and not mpi_built(): + continue + + if use_gloo and not gloo_built(): + continue + res1 = run(fn, (1, 20), {"c": 300, "d": 4000}, np=1, use_gloo=use_gloo, use_mpi=use_mpi) self.assertListEqual([[0, 4321]], res1) res2 = run(fn, (1, 20), {"c": 300, "d": 4000}, np=3, use_gloo=use_gloo, use_mpi=use_mpi) @@ -60,18 +64,18 @@ def fn(a, b, c, d): None], res2) def test_failed_run(self): - def fn(): hvd.init() rank = hvd.rank() if rank == 1: raise RuntimeError() - with pytest.raises(RuntimeError, match='Gloo job detected that one or more processes exited'): - run(fn, np=2, use_gloo=True) - - with pytest.raises(RuntimeError, match='mpirun failed'): - run(fn, np=2, use_mpi=True) - + assert gloo_built() or mpi_built() + if gloo_built(): + with pytest.raises(RuntimeError, match='Horovod detected that one or more processes exited'): + run(fn, np=2, use_gloo=True) + if mpi_built(): + with pytest.raises(RuntimeError, match='mpirun failed'): + run(fn, np=2, use_mpi=True) diff --git a/test/test_keras.py 
b/test/test_keras.py index 84c8e85708..02c102ce50 100644 --- a/test/test_keras.py +++ b/test/test_keras.py @@ -19,10 +19,13 @@ from __future__ import division from __future__ import print_function +from distutils.version import LooseVersion + import keras from keras import backend as K import numpy as np +import pytest import tensorflow as tf import warnings @@ -248,3 +251,65 @@ def test_from_config(self): hopt_copy2 = hopt.__class__.from_config(cfg) self.assertEqual(cfg, hopt_copy2.get_config()) + + @pytest.mark.skipif(LooseVersion(tf.__version__) < LooseVersion('1.15.0'), + reason='Synchronizing state requires TensorFlow 1.15 or above') + def test_elastic_state(self): + with self.test_session(config=self.config) as sess: + K.set_session(sess) + + v = 1.0 if hvd.rank() == 0 else 2.0 + model1 = keras.models.Sequential([ + keras.layers.Dense(2, activation='softmax') + ]) + model1.build((2, 2)) + model1.set_weights( + [np.array([[v, v], [v, v]], dtype=np.float32), + np.array([v, v], dtype=np.float32)]) + + model2 = keras.models.Sequential([ + keras.layers.Dense(2, activation='softmax') + ]) + model2.build((2, 2)) + model2.set_weights( + [np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), + np.array([0.0, 0.0], dtype=np.float32)]) + + optimizer = keras.optimizers.Adam(0.001 * hvd.size()) + + state = hvd.elastic.KerasState(model1, optimizer, batch=20 + hvd.rank(), epoch=10 + hvd.rank()) + state.sync() + + model1_weights = model1.get_weights() + model2_weights = model2.get_weights() + + # After sync, all values should match the root rank + for w in state.model.get_weights(): + self.assertAllClose(w, np.ones_like(w)) + assert state.batch == 20 + assert state.epoch == 10 + + # Partially modify then restore + model1.set_weights(model2_weights) + state.batch = 21 + state.epoch = 11 + + state.restore() + + for w1, w2 in zip(model1.get_weights(), model1_weights): + self.assertAllClose(w1, w2) + assert state.batch == 20 + assert state.epoch == 10 + + # Partially modify then commit + model1.set_weights(model2_weights) + state.batch = 21 + state.epoch = 11 + + state.commit() + state.restore() + + for w1, w2 in zip(model1.get_weights(), model2_weights): + self.assertAllClose(w1, w2) + assert state.batch == 21 + assert state.epoch == 11 diff --git a/test/test_run.py b/test/test_run.py index a3f0527a5c..f47b5b46d0 100644 --- a/test/test_run.py +++ b/test/test_run.py @@ -615,7 +615,7 @@ def test_mpi_run_full(self): extra_mpi_args='>mpi-extra args go here<', binding_args='>binding args go here<', key=secret.make_secret_key(), - timeout=tmout, + start_timeout=tmout, num_hosts=1, num_proc=1, hosts='>host names go here<', diff --git a/test/test_spark.py b/test/test_spark.py index 754bd8661f..85780fba5e 100644 --- a/test/test_spark.py +++ b/test/test_spark.py @@ -297,7 +297,7 @@ def test_spark_run_with_non_zero_exit_with_mpi(self): @pytest.mark.skipif(sys.version_info < (3, 0), reason='Horovod on Spark over Gloo only supported on Python3') def test_spark_run_with_non_zero_exit_with_gloo(self): - expected = '^Gloo job detected that one or more processes exited with non-zero ' \ + expected = '^Horovod detected that one or more processes exited with non-zero ' \ 'status, thus causing the job to be terminated. 
diff --git a/test/test_spark.py b/test/test_spark.py
index 754bd8661f..85780fba5e 100644
--- a/test/test_spark.py
+++ b/test/test_spark.py
@@ -297,7 +297,7 @@ def test_spark_run_with_non_zero_exit_with_mpi(self):
 
     @pytest.mark.skipif(sys.version_info < (3, 0), reason='Horovod on Spark over Gloo only supported on Python3')
     def test_spark_run_with_non_zero_exit_with_gloo(self):
-        expected = '^Gloo job detected that one or more processes exited with non-zero ' \
+        expected = '^Horovod detected that one or more processes exited with non-zero ' \
                    'status, thus causing the job to be terminated. The first process ' \
                    'to do so was:\nProcess name: 0\nExit code: 1$'
         self.do_test_spark_run_with_non_zero_exit(use_mpi=False, use_gloo=True,
@@ -459,10 +459,10 @@ def _exec_command(command, alloc_info, event):
             self.assertFalse(str(e.value).startswith('Timed out waiting for Spark tasks to start.'),
                              'Spark timed out before mpi_run was called, test setup is broken.')
-            self.assertRegexpMatches(str(e.value),
-                                     '^Gloo job detected that one or more processes exited with non-zero status, '
-                                     'thus causing the job to be terminated. The first process to do so was:\n'
-                                     'Process name: [0-9]\nExit code: 1+\n$')
+            self.assertEqual('Horovod detected that one or more processes exited with non-zero status, '
+                             'thus causing the job to be terminated. The first process to do so was:\n'
+                             'Process name: 0\n'
+                             'Exit code: 1\n', str(e.value))
 
         num_proc = cores if num_proc is None else num_proc
         self.assertEqual(expected_np, num_proc)
@@ -494,7 +494,8 @@ def _exec_command(command, alloc_info, event):
             self.assertEqual(alloc_info.local_rank, alloc_info.rank)
 
             # command fully derived from alloc_info
-            expected_command = ('HOROVOD_RANK={rank} '
+            expected_command = ('HOROVOD_HOSTNAME=[^ ]+ '
+                                'HOROVOD_RANK={rank} '
                                 'HOROVOD_SIZE={size} '
                                 'HOROVOD_LOCAL_RANK={local_rank} '
                                 'HOROVOD_LOCAL_SIZE={local_size} '
@@ -517,6 +518,7 @@ def _exec_command(command, alloc_info, event):
             # for better comparison replace sections in actual_command that change across runs / hosts
             actual_command = call_args[0][0]
             for replacement in ['_HOROVOD_SECRET_KEY=[^ ]+',
+                                'HOROVOD_HOSTNAME=[^ ]+',
                                 'HOROVOD_GLOO_RENDEZVOUS_ADDR=[^ ]+',
                                 'HOROVOD_GLOO_RENDEZVOUS_PORT=[0-9]+',
                                 'HOROVOD_GLOO_IFACE=[^ ]+',
@@ -534,22 +536,31 @@ def test_rsh_with_non_zero_exit_code(self):
         self.do_test_rsh('false', 1)
 
     def test_rsh_event(self):
+        self.do_test_rsh_events(1)
+
+    def test_rsh_events(self):
+        self.do_test_rsh_events(3)
+
+    def do_test_rsh_events(self, test_events):
+        self.assertGreater(test_events, 0, 'test should not be trivial')
+
         sleep = 10
         command = 'sleep {}'.format(sleep)
-        event = threading.Event()
-        delay(lambda: event.set(), 1.0)
+        for triggered_event in range(test_events):
+            events = [threading.Event() for _ in range(test_events)]
+            delay(lambda: events[triggered_event].set(), 1.0)
 
-        start = time.time()
-        self.do_test_rsh(command, 143, event=event)
-        duration = time.time() - start
+            start = time.time()
+            self.do_test_rsh(command, 143, events=events)
+            duration = time.time() - start
 
-        self.assertGreaterEqual(duration, 1.0)
-        self.assertLess(duration, 2.00 + safe_shell_exec.GRACEFUL_TERMINATION_TIME_S,
-                        'sleep should not finish')
-        self.assertGreater(sleep, 2.00 + safe_shell_exec.GRACEFUL_TERMINATION_TIME_S,
-                           'sleep should be large enough')
+            self.assertGreaterEqual(duration, 1.0)
+            self.assertLess(duration, 2.00 + safe_shell_exec.GRACEFUL_TERMINATION_TIME_S,
+                            'sleep should not finish')
+            self.assertGreater(sleep, 2.00 + safe_shell_exec.GRACEFUL_TERMINATION_TIME_S,
+                               'sleep should be large enough')
 
-    def do_test_rsh(self, command, expected_result, event=None):
+    def do_test_rsh(self, command, expected_result, events=None):
         def fn():
             return 0
@@ -563,7 +574,7 @@ def fn():
         settings = hvd_settings.Settings(verbose=2, key=key)
         env = {}
 
-        res = rsh(driver.addresses(), key, settings, host_hash, command, env, 0, False, event=event)
+        res = rsh(driver.addresses(), key, settings, host_hash, command, env, 0, False, events=events)
         self.assertEqual(expected_result, res)
 
     def test_mpirun_exec_fn(self):
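
The expected_command pattern above documents the environment each Gloo-launched worker receives. A small sketch of how a worker-side script could read that contract (the variable names are taken from the pattern; the fallback values are illustrative):

import os

# Placement information injected by the launcher, per the expected_command
# regex asserted in test_spark.py above.
hostname = os.environ.get('HOROVOD_HOSTNAME', 'unknown')
rank = int(os.environ.get('HOROVOD_RANK', '0'))
size = int(os.environ.get('HOROVOD_SIZE', '1'))
local_rank = int(os.environ.get('HOROVOD_LOCAL_RANK', '0'))
local_size = int(os.environ.get('HOROVOD_LOCAL_SIZE', '1'))

print('worker %d/%d (local %d/%d) on %s' % (rank, size, local_rank, local_size, hostname))
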
diff --git a/test/test_tensorflow.py b/test/test_tensorflow.py
index d60bcfa147..b3636a5629 100644
--- a/test/test_tensorflow.py
+++ b/test/test_tensorflow.py
@@ -21,9 +21,12 @@
 from __future__ import division
 from __future__ import print_function
 
+from distutils.version import LooseVersion
+
 import itertools
 import numpy as np
 import os
+import pytest
 import tensorflow as tf
 from horovod.tensorflow.util import _executing_eagerly, _has_eager
 from tensorflow.python.framework import ops
@@ -50,13 +53,16 @@
 ccl_supported_types = set([tf.uint8, tf.int32, tf.int64, tf.float32, tf.float64])
 
 
-class MPITests(tf.test.TestCase):
+_IS_TF2 = LooseVersion(tf.__version__) >= LooseVersion('2.0.0')
+
+
+class TensorFlowTests(tf.test.TestCase):
     """
     Tests for ops in horovod.tensorflow.
     """
 
     def __init__(self, *args, **kwargs):
-        super(MPITests, self).__init__(*args, **kwargs)
+        super(TensorFlowTests, self).__init__(*args, **kwargs)
         warnings.simplefilter('module')
         if _has_eager:
             if hasattr(tf, 'contrib') and hasattr(tf.contrib, 'eager'):
@@ -74,6 +80,20 @@ def evaluate(self, tensors):
         else:
             return sess.run(tensors)
 
+    def assign(self, variables, values):
+        if _executing_eagerly():
+            for var, val in zip(variables, values):
+                var.assign(val)
+        else:
+            sess = ops.get_default_session()
+            if sess is None:
+                with self.test_session(config=config) as sess:
+                    for var, val in zip(variables, values):
+                        var.load(val, sess)
+            else:
+                for var, val in zip(variables, values):
+                    var.load(val, sess)
+
     def random_uniform(self, *args, **kwargs):
         if hasattr(tf, 'random') and hasattr(tf.random, 'set_seed'):
             tf.random.set_seed(1234)
@@ -1062,10 +1082,110 @@ def test_compression_fp16(self):
             err = np.linalg.norm(expected - actual)
             self.assertLess(err, 0.00000001)
 
+    def test_broadcast_object(self):
+        if LooseVersion(tf.__version__) < LooseVersion('1.15.0'):
+            self.skipTest("Broadcasting object requires TensorFlow 1.15 or above")
+
+        hvd.init()
+
+        with tf.device("/cpu:0"):
+            expected_obj = {
+                'hello': 123,
+                0: [1, 2]
+            }
+            obj = expected_obj if hvd.rank() == 0 else {}
+
+            obj = hvd.broadcast_object(obj, root_rank=0)
+            self.assertDictEqual(obj, expected_obj)
+
+    def test_broadcast_object_fn(self):
+        if LooseVersion(tf.__version__) < LooseVersion('1.15.0'):
+            self.skipTest("Broadcasting object requires TensorFlow 1.15 or above")
+
+        if hvd._executing_eagerly() or _IS_TF2:
+            # Only supported in graph mode for TF 1.X
+            return
+
+        hvd.init()
+
+        with tf.device("/cpu:0"):
+            expected_obj = {
+                'hello': 123,
+                0: [1, 2]
+            }
+            obj = expected_obj if hvd.rank() == 0 else {}
+
+            bcast = hvd.broadcast_object_fn(root_rank=0)
+            obj = bcast(obj)
+            self.assertDictEqual(obj, expected_obj)
+
+    def test_elastic_state(self):
+        if LooseVersion(tf.__version__) < LooseVersion('1.15.0'):
+            self.skipTest("Synchronizing state requires TensorFlow 1.15 or above")
+
+        if not hvd._executing_eagerly() and _IS_TF2:
+            # Only support TF 2.0 in eager mode
+            return
+
+        hvd.init()
+
+        with tf.device("/cpu:0"):
+            v = 1.0 if hvd.rank() == 0 else 2.0
+            weights1 = [
+                np.array([[v, v], [v, v]]),
+                np.array([v, v])
+            ]
+            vars1 = [tf.Variable(arr) for arr in weights1]
+
+            weights2 = [
+                np.array([[1.0, 2.0], [3.0, 4.0]]),
+                np.array([0.0, 0.0])
+            ]
+
+            if not hvd._executing_eagerly():
+                init = tf.global_variables_initializer()
+                self.evaluate(init)
+
+            state = hvd.elastic.TensorFlowState(vars1, batch=20 + hvd.rank(), epoch=10 + hvd.rank())
+            state.sync()
+
+            weights1 = [np.ones_like(w) for w in weights1]
+
+            # After sync, all values should match the root rank
+            for w in self.evaluate(vars1):
+                self.assertAllClose(w, np.ones_like(w))
+            assert state.batch == 20
+            assert state.epoch == 10
+
+            # Partially modify then restore
+            self.assign(vars1, weights2)
+            state.batch = 21
+            state.epoch = 11
+
+            state.restore()
+
+            for w1, w2 in zip(self.evaluate(vars1), weights1):
+                self.assertAllClose(w1, w2)
+            assert state.batch == 20
+            assert state.epoch == 10
+
+            # Partially modify then commit
+            self.assign(vars1, weights2)
+            state.batch = 21
+            state.epoch = 11
+
+            state.commit()
+            state.restore()
+
+            for w1, w2 in zip(self.evaluate(vars1), weights2):
+                self.assertAllClose(w1, w2)
+            assert state.batch == 21
+            assert state.epoch == 11
+
 if _has_eager:
     from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes
-    run_all_in_graph_and_eager_modes(MPITests)
+    run_all_in_graph_and_eager_modes(TensorFlowTests)
 
 if __name__ == '__main__':
     tf.test.main()
diff --git a/test/test_tensorflow2_keras.py b/test/test_tensorflow2_keras.py
index 84ae7cd34b..3061499cac 100644
--- a/test/test_tensorflow2_keras.py
+++ b/test/test_tensorflow2_keras.py
@@ -98,3 +98,60 @@ def test_from_config(self):
 
         hopt_copy2 = hopt.__class__.from_config(cfg)
         self.assertEqual(cfg, hopt_copy2.get_config())
+
+    def test_elastic_state(self):
+        v = 1.0 if hvd.rank() == 0 else 2.0
+        model1 = tf.keras.Sequential([
+            tf.keras.layers.Dense(2, activation='softmax')
+        ])
+        model1.build((2, 2))
+        model1.set_weights(
+            [np.array([[v, v], [v, v]], dtype=np.float32),
+             np.array([v, v], dtype=np.float32)])
+
+        model2 = tf.keras.Sequential([
+            tf.keras.layers.Dense(2, activation='softmax')
+        ])
+        model2.build((2, 2))
+        model2.set_weights(
+            [np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
+             np.array([0.0, 0.0], dtype=np.float32)])
+
+        optimizer = tf.optimizers.Adam(0.001 * hvd.size())
+
+        state = hvd.elastic.KerasState(model1, optimizer, batch=20 + hvd.rank(), epoch=10 + hvd.rank())
+        state.sync()
+
+        model1_weights = model1.get_weights()
+        model2_weights = model2.get_weights()
+
+        # After sync, all values should match the root rank
+        for w in state.model.get_weights():
+            self.assertAllClose(w, np.ones_like(w))
+        assert state.batch == 20
+        assert state.epoch == 10
+
+        # Partially modify then restore
+        model1.set_weights(model2_weights)
+        state.batch = 21
+        state.epoch = 11
+
+        state.restore()
+
+        for w1, w2 in zip(model1.get_weights(), model1_weights):
+            self.assertAllClose(w1, w2)
+        assert state.batch == 20
+        assert state.epoch == 10
+
+        # Partially modify then commit
+        model1.set_weights(model2_weights)
+        state.batch = 21
+        state.epoch = 11
+
+        state.commit()
+        state.restore()
+
+        for w1, w2 in zip(model1.get_weights(), model2_weights):
+            self.assertAllClose(w1, w2)
+        assert state.batch == 21
+        assert state.epoch == 11
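
test_broadcast_object above (in test_tensorflow.py) pins down an API for sharing a plain Python object, such as a dict, from one rank to all others. A minimal usage sketch, assuming TensorFlow 1.15+ and horovod.tensorflow imported as hvd, as in that test; the config dictionary is illustrative:

import horovod.tensorflow as hvd

hvd.init()

# Rank 0 decides the run configuration; every other rank starts empty and
# receives rank 0's copy from broadcast_object.
config = {'lr': 0.001, 'batch_size': 32} if hvd.rank() == 0 else {}
config = hvd.broadcast_object(config, root_rank=0)
assert 'lr' in config  # now identical on every rank
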
diff --git a/test/test_tensorflow_keras.py b/test/test_tensorflow_keras.py
index 1132b1e00f..c6a093c873 100644
--- a/test/test_tensorflow_keras.py
+++ b/test/test_tensorflow_keras.py
@@ -19,9 +19,9 @@
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
 import numpy as np
 import pytest
+import tensorflow as tf
 import warnings
 
 from distutils.version import LooseVersion
@@ -293,3 +293,65 @@ def test_from_config(self):
 
         hopt_copy2 = hopt.__class__.from_config(cfg)
         self.assertEqual(cfg, hopt_copy2.get_config())
+
+    @pytest.mark.skipif(LooseVersion(tf.__version__) < LooseVersion('1.15.0'),
+                        reason='Synchronizing state requires TensorFlow 1.15 or above')
+    def test_elastic_state(self):
+        with self.test_session(config=self.config) as sess:
+            K.set_session(sess)
+
+            v = 1.0 if hvd.rank() == 0 else 2.0
+            model1 = tf.keras.Sequential([
+                tf.keras.layers.Dense(2, activation='softmax')
+            ])
+            model1.build((2, 2))
+            model1.set_weights(
+                [np.array([[v, v], [v, v]], dtype=np.float32),
+                 np.array([v, v], dtype=np.float32)])
+
+            model2 = tf.keras.Sequential([
+                tf.keras.layers.Dense(2, activation='softmax')
+            ])
+            model2.build((2, 2))
+            model2.set_weights(
+                [np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
+                 np.array([0.0, 0.0], dtype=np.float32)])
+
+            optimizer = tf.keras.optimizers.Adam(0.001 * hvd.size())
+
+            state = hvd.elastic.KerasState(model1, optimizer, batch=20 + hvd.rank(), epoch=10 + hvd.rank())
+            state.sync()
+
+            model1_weights = model1.get_weights()
+            model2_weights = model2.get_weights()
+
+            # After sync, all values should match the root rank
+            for w in state.model.get_weights():
+                self.assertAllClose(w, np.ones_like(w))
+            assert state.batch == 20
+            assert state.epoch == 10
+
+            # Partially modify then restore
+            model1.set_weights(model2_weights)
+            state.batch = 21
+            state.epoch = 11
+
+            state.restore()
+
+            for w1, w2 in zip(model1.get_weights(), model1_weights):
+                self.assertAllClose(w1, w2)
+            assert state.batch == 20
+            assert state.epoch == 10
+
+            # Partially modify then commit
+            model1.set_weights(model2_weights)
+            state.batch = 21
+            state.epoch = 11
+
+            state.commit()
+            state.restore()
+
+            for w1, w2 in zip(model1.get_weights(), model2_weights):
+                self.assertAllClose(w1, w2)
+            assert state.batch == 21
+            assert state.epoch == 11
diff --git a/test/test_torch.py b/test/test_torch.py
index 40d0c643d3..2e846e1f09 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1726,5 +1726,63 @@ def test_horovod_sync_batch_norm(self):
         assert (hvd.allreduce(sync_bn.bias.grad, name='sync_bn.bias.grad') - bn.bias.grad).abs().sum() < 1e-6
         assert (hvd.allreduce(ts1.grad, name='ts1.grad') - ts2.grad).abs().sum() < 1e-6
 
+    @pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion('1.0.0'),
+                        reason='Synchronizing state requires PyTorch 1.0 or above')
+    def test_elastic_state(self):
+        hvd.init()
+
+        v = 1.0 if hvd.rank() == 0 else 2.0
+        model1 = torch.nn.Sequential(torch.nn.Linear(2, 2))
+        model1.load_state_dict({
+            '0.weight': torch.tensor([[v, v], [v, v]]),
+            '0.bias': torch.tensor([v, v])
+        })
+
+        model2 = torch.nn.Sequential(torch.nn.Linear(2, 2))
+        model2.load_state_dict({
+            '0.weight': torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
+            '0.bias': torch.tensor([0.0, 0.0])
+        })
+
+        optimizer = torch.optim.SGD(model1.parameters(), lr=0.001 * hvd.size())
+
+        state = hvd.elastic.TorchState(model1, optimizer, batch=20 + hvd.rank(), epoch=10 + hvd.rank())
+        state.sync()
+
+        model1_weights = model1.state_dict().values()
+        model2_weights = model2.state_dict().values()
+
+        # After sync, all values should match the root rank
+        for w in state.model.state_dict().values():
+            np.testing.assert_allclose(w, np.ones_like(w))
+        assert state.batch == 20
+        assert state.epoch == 10
+
+        # Partially modify then restore
+        model1.load_state_dict(model2.state_dict())
+        state.batch = 21
+        state.epoch = 11
+
+        state.restore()
+
+        for w1, w2 in zip(model1.state_dict().values(), model1_weights):
+            np.testing.assert_allclose(w1, w2)
+        assert state.batch == 20
+        assert state.epoch == 10
+
+        # Partially modify then commit
+        model1.load_state_dict(model2.state_dict())
+        state.batch = 21
+        state.epoch = 11
+
+        state.commit()
+        state.restore()
+
+        for w1, w2 in zip(model1.state_dict().values(), model2_weights):
+            np.testing.assert_allclose(w1, w2)
+        assert state.batch == 21
+        assert state.epoch == 11
+
+
 if __name__ == "__main__":
     unittest.main()
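
Taken together, the elastic tests pin down a common State contract across Keras, TensorFlow and PyTorch: sync() adopts the root rank's values, commit() records a restore point, and restore() rolls back to the last commit. A hedged sketch of how a PyTorch training loop might drive that cycle; only the TorchState calls are taken from the tests above, while the model, data and failure handling are placeholders:

import torch
import horovod.torch as hvd

hvd.init()

model = torch.nn.Sequential(torch.nn.Linear(2, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.001 * hvd.size())

# Bundle everything that must survive a change of workers into one state object.
state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0)
state.sync()  # adopt the values currently held by the root rank


def train_one_batch(batch):
    # Placeholder for a real forward/backward/step on `model`.
    pass


for state.epoch in range(state.epoch, 5):
    for state.batch in range(state.batch, 100):
        train_one_batch(state.batch)
        if state.batch % 10 == 0:
            state.commit()  # restore() will return to this point
    state.batch = 0

# On a recoverable failure the caller would roll back and continue:
# state.restore()
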