
Fixed Keras DistributedOptimizer to use base from_config method to support 2.0 learning rate schedules (#1588)

Signed-off-by: Travis Addair <taddair@uber.com>
tgaddair committed Jan 9, 2020
1 parent a3953c2 commit 438880e2a1744b309527b8b10d0bd88a77c6e565
Showing with 116 additions and 21 deletions.
  1. +6 −3 .buildkite/gen-pipeline.sh
  2. +8 −16 horovod/_keras/__init__.py
  3. +100 −0 test/test_tensorflow2_keras.py
  4. +2 −2 test/test_tensorflow_keras.py
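
For context on the fix: in TF 2.0, calling get_config() on a tf.keras optimizer serializes an attached LearningRateSchedule as a nested dict, and it is the base from_config() that knows how to turn that dict back into a schedule object. A minimal sketch of that round trip (plain tf.keras, not part of this commit):

    import tensorflow as tf

    # Attach a learning rate schedule to a stock optimizer.
    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        0.001, decay_steps=100000, decay_rate=0.96, staircase=True)
    opt = tf.keras.optimizers.Adam(schedule)

    cfg = opt.get_config()
    # cfg['learning_rate'] is now a nested dict along the lines of
    # {'class_name': 'ExponentialDecay', 'config': {...}}, not a float.

    # The base from_config() deserializes that nested dict back into a
    # schedule object; passing cfg straight to __init__ would keep it a dict.
    restored = tf.keras.optimizers.Adam.from_config(cfg)
    print(type(restored.learning_rate).__name__)  # expected: ExponentialDecay

Delegating to the base from_config() is what lets the wrapped DistributedOptimizer below pick up these schedules.
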
@@ -100,7 +100,9 @@ run_all() {
local exclude_keras_if_needed=""
if [[ ${test} == *"tf2_"* ]] || [[ ${test} == *"tfhead"* ]]; then
# TODO: support for Keras + TF 2.0 and TF-Keras 2.0
exclude_keras_if_needed="| sed 's/[a-z_]*keras[a-z_.]*//g'"
exclude_keras_if_needed="| sed 's/test_keras.py//g' | sed 's/test_tensorflow_keras.py//g'"
else
exclude_keras_if_needed="| sed 's/[a-z_]*tensorflow2[a-z_.]*//g'"
fi

local exclude_interactiverun="| sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g'"
@@ -199,11 +201,12 @@ run_gloo() {
exclude_spark_if_needed="| sed 's/[a-z_]*spark[a-z_.]*//g'"
fi

- local exclude_interactiverun="| sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g'"
+ # These tests are covered in MPI, and testing them in Gloo does not cover any new code paths
+ local excluded_tests="| sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g' | sed 's/[a-z_]*tensorflow2[a-z_.]*//g'"

run_test "${test}" "${pytest_queue}" \
":pytest: Run PyTests (${test})" \
"bash -c \"cd /horovod/test && (echo test_*.py ${exclude_spark_if_needed} ${exclude_interactiverun} | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no)\""
"bash -c \"cd /horovod/test && (echo test_*.py ${exclude_spark_if_needed} ${excluded_tests} | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no)\""

run_test "${test}" "${queue}" \
":tensorflow: Test Keras MNIST (${test})" \
@@ -20,17 +20,14 @@
def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse,
compression, sparse_as_dense):
class _DistributedOptimizer(keras.optimizers.Optimizer):
- def __init__(self, name, device_dense, device_sparse, compression, sparse_as_dense,
- config):
- if name is None:
- name = "Distributed%s" % self.__class__.__base__.__name__
- self._name = name
+ def __init__(self, **kwargs):
+ self._name = name or "Distributed%s" % self.__class__.__base__.__name__
self._device_dense = device_dense
self._device_sparse = device_sparse
self._compression = compression
self._sparse_as_dense = sparse_as_dense
self._get_gradients_used = False
- super(self.__class__, self).__init__(**config)
+ super(self.__class__, self).__init__(**kwargs)

def get_gradients(self, loss, params):
"""
@@ -64,24 +61,19 @@ def get_gradients(self, loss, params):

def apply_gradients(self, *args, **kwargs):
if not self._get_gradients_used:
- raise Exception('`apply_gradients()` was called without a call to '
- '`get_gradients()`. If you\'re using TensorFlow 2.0, '
- 'please specify `experimental_run_tf_function=False` in '
- '`compile()`.')
+ raise Exception('`apply_gradients()` was called without a call to '
+ '`get_gradients()`. If you\'re using TensorFlow 2.0, '
+ 'please specify `experimental_run_tf_function=False` in '
+ '`compile()`.')
return super(self.__class__, self).apply_gradients(*args, **kwargs)

- @classmethod
- def from_config(cls, cfg):
- return cls(name, device_dense, device_sparse, compression, sparse_as_dense, cfg)

# We dynamically create a new class that inherits from the optimizer that was passed in.
# The goal is to override get_gradients() method with an allreduce implementation.
# This class will have the same name as the optimizer it's wrapping, so that the saved
# model could be easily restored without Horovod.
cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
dict(_DistributedOptimizer.__dict__))
- return cls(name, device_dense, device_sparse, compression, sparse_as_dense,
- optimizer.get_config())
+ return cls.from_config(optimizer.get_config())


def _eval(backend, op_or_result):
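
The comment block above describes the wrapping pattern; the sketch below shows the same idea in isolation (a hypothetical wrap_optimizer helper, not the Horovod API, with the allreduce step stubbed out):

    import tensorflow as tf

    def wrap_optimizer(optimizer):
        def get_gradients(self, loss, params):
            # Horovod would allreduce the gradients across workers here;
            # this sketch simply delegates to the wrapped optimizer.
            return super(cls, self).get_gradients(loss, params)

        # Build a subclass at runtime that keeps the wrapped optimizer's class
        # name, so a model saved with it can be reloaded without the wrapper.
        cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
                   {'get_gradients': get_gradients})

        # Rebuild the instance through the base from_config(), which also
        # restores TF 2.0 learning rate schedule objects from the config dict.
        return cls.from_config(optimizer.get_config())

    opt = wrap_optimizer(tf.keras.optimizers.Adam(0.001))
    print(type(opt).__name__)  # 'Adam', but with get_gradients() overridden
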
@@ -0,0 +1,100 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for horovod.tensorflow.keras."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import warnings

from tensorflow import keras

import horovod.tensorflow.keras as hvd


class Tf2KerasTests(tf.test.TestCase):
"""
Tests for ops in horovod.tensorflow.keras.
"""

def __init__(self, *args, **kwargs):
super(Tf2KerasTests, self).__init__(*args, **kwargs)
warnings.simplefilter('module')
hvd.init()

gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

def test_train_model_lr_schedule(self):
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
0.001 * hvd.size(),
decay_steps=100000,
decay_rate=0.96,
staircase=True)
opt = tf.keras.optimizers.Adam(lr_schedule)
opt = hvd.DistributedOptimizer(opt)

model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
model.add(keras.layers.ThresholdedReLU(0.5))
model.compile(loss=keras.losses.mean_squared_error,
optimizer=opt,
metrics=[keras.metrics.categorical_accuracy],
experimental_run_tf_function=False)

x = np.random.random((1, 3))
y = np.random.random((1, 3, 2))

# No assertions, we just need to verify that it doesn't hang or error
callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)]
model.fit(x,
y,
steps_per_epoch=10,
callbacks=callbacks,
epochs=1)

def test_sparse_as_dense(self):
opt = keras.optimizers.RMSprop(lr=0.0001)
opt = hvd.DistributedOptimizer(opt, sparse_as_dense=True)

model = keras.models.Sequential()
model.add(keras.layers.Embedding(1000, 64, input_length=10))
model.compile(loss=keras.losses.mean_squared_error,
optimizer=opt,
experimental_run_tf_function=False)

x = np.random.randint(1000, size=(32, 10))
y = np.random.random((32, 10, 64))
# No assertions, we just need to verify that it doesn't hang
model.train_on_batch(x, y)

def test_from_config(self):
opt = keras.optimizers.Adam()
hopt = hvd.DistributedOptimizer(opt)
cfg = hopt.get_config()

hopt_copy1 = hopt.from_config(cfg)
self.assertEqual(cfg, hopt_copy1.get_config())

hopt_copy2 = hopt.__class__.from_config(cfg)
self.assertEqual(cfg, hopt_copy2.get_config())
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

"""Tests for horovod.keras."""
"""Tests for horovod.tensorflow.keras."""

from __future__ import absolute_import
from __future__ import division
@@ -36,7 +36,7 @@

class TfKerasTests(tf.test.TestCase):
"""
- Tests for ops in horovod.keras.
+ Tests for ops in horovod.tensorflow.keras.
"""

def __init__(self, *args, **kwargs):
