TF-Keras support for TensorFlow 2.0 (#1238)

Signed-off-by: Alex Sergeev <alsrgv@users.noreply.github.com>
alsrgv committed Jul 21, 2019
1 parent d7ef9e1 commit 5e6905309a6935e2356845c0873fb0f3ebec9aa5
@@ -148,11 +148,19 @@ RUN sed -i "s/last_step=20000/last_step=100/" /horovod/examples/tensorflow_mnist
# Hack TensorFlow Eager MNIST example to be smaller.
RUN sed -i "s/dataset.take(20000/dataset.take(100/" /horovod/examples/tensorflow_mnist_eager.py

# Hack TensorFlow 2.0 example to be smaller.
RUN sed -i "s/dataset.take(10000/dataset.take(100/" /horovod/examples/tensorflow2_mnist.py

# Hack Keras MNIST advanced example to be smaller.
RUN sed -i "s/epochs = .*/epochs = 9/" /horovod/examples/keras_mnist_advanced.py
RUN sed -i "s/model.add(Conv2D(32, kernel_size=(3, 3),/model.add(Conv2D(1, kernel_size=(3, 3),/" /horovod/examples/keras_mnist_advanced.py
RUN sed -i "s/model.add(Conv2D(64, (3, 3), activation='relu'))//" /horovod/examples/keras_mnist_advanced.py

# Hack TensorFlow 2.0 Keras MNIST advanced example to be smaller.
RUN sed -i "s/epochs = .*/epochs = 9/" /horovod/examples/tensorflow2_keras_mnist.py
RUN sed -i "s/tf.keras.layers.Conv2D(32, \\[3, 3\\],/tf.keras.layers.Conv2D(1, [3, 3],/" /horovod/examples/tensorflow2_keras_mnist.py
RUN sed -i "s/tf.keras.layers.Conv2D(64, \\[3, 3\\], activation='relu')),//" /horovod/examples/tensorflow2_keras_mnist.py

# Hack PyTorch MNIST example to be smaller.
RUN sed -i "s/'--epochs', type=int, default=10,/'--epochs', type=int, default=2,/" /horovod/examples/pytorch_mnist.py
RUN sed -i "s/self.fc1 = nn.Linear(320, 50)/self.fc1 = nn.Linear(784, 50)/" /horovod/examples/pytorch_mnist.py
@@ -105,9 +105,15 @@ RUN sed -i "s/last_step=20000/last_step=100/" /horovod/examples/tensorflow_mnist
# Hack TensorFlow Eager MNIST example to be smaller.
RUN sed -i "s/dataset.take(20000/dataset.take(100/" /horovod/examples/tensorflow_mnist_eager.py

# Hack TensorFlow 2.0 example to be smaller.
RUN sed -i "s/dataset.take(10000/dataset.take(100/" /horovod/examples/tensorflow2_mnist.py

# Hack Keras MNIST advanced example to be smaller.
RUN sed -i "s/epochs = .*/epochs = 9/" /horovod/examples/keras_mnist_advanced.py

# Hack TensorFlow 2.0 Keras MNIST advanced example to be smaller.
RUN sed -i "s/epochs = .*/epochs = 9/" /horovod/examples/tensorflow2_keras_mnist.py

# Hack PyTorch MNIST example to be smaller.
RUN sed -i "s/'--epochs', type=int, default=10,/'--epochs', type=int, default=2,/" /horovod/examples/pytorch_mnist.py

@@ -0,0 +1,86 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf
import horovod.tensorflow.keras as hvd

# Horovod: initialize Horovod.
hvd.init()

# Horovod: pin GPU to be used to process local rank (one GPU per process)
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

(mnist_images, mnist_labels), _ = \
    tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank())

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
     tf.cast(mnist_labels, tf.int64))
)
dataset = dataset.repeat().shuffle(10000).batch(128)

mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
    tf.keras.layers.Conv2D(64, [3, 3], activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Horovod: adjust learning rate based on number of GPUs.
opt = tf.optimizers.Adam(0.001 * hvd.size())

# Horovod: add Horovod DistributedOptimizer.
opt = hvd.DistributedOptimizer(opt)

mnist_model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
                    optimizer=opt,
                    metrics=['accuracy'])

callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),

    # Horovod: average metrics among workers at the end of every epoch.
    #
    # Note: This callback must be in the list before the ReduceLROnPlateau,
    # TensorBoard or other metrics-based callbacks.
    hvd.callbacks.MetricAverageCallback(),

    # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
    # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
    # the first three epochs. See https://arxiv.org/abs/1706.02677 for details.
    hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=3, verbose=1),
]
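# Illustration (not part of this example): a metrics-based callback such as
# ReduceLROnPlateau would be appended after the Horovod callbacks above, e.g.:
#   callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', patience=10, verbose=1))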

# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
if hvd.rank() == 0:
    callbacks.append(tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

# Horovod: write logs on worker 0.
verbose = 1 if hvd.rank() == 0 else 0

# Train the model.
# Horovod: adjust number of steps based on number of GPUs.
mnist_model.fit(dataset, steps_per_epoch=500 // hvd.size(), callbacks=callbacks, epochs=24, verbose=verbose)
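
# Usage note (illustration, not part of this file): with e.g. 4 workers,
# steps_per_epoch=500 // hvd.size() gives each worker 125 steps per epoch.
# Examples like this one are typically launched with Horovod's `horovodrun`
# wrapper, e.g. on a single machine with 4 GPUs:
#   horovodrun -np 4 python tensorflow2_keras_mnist.py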
@@ -26,25 +26,29 @@
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, [3, 3], activation='relu'),
    tf.keras.layers.Conv2D(16, [3, 3], activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10)
])
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

# Horovod: adjust learning rate based on number of GPUs.
opt = tf.optimizers.Adam(0.001 * hvd.size())

(mnist_images, mnist_labels), _ = \
    tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank())

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
     tf.cast(mnist_labels, tf.int64))
)
dataset = dataset.repeat().shuffle(1000).batch(32)
dataset = dataset.repeat().shuffle(10000).batch(128)

mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
    tf.keras.layers.Conv2D(64, [3, 3], activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax')
])
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

# Horovod: adjust learning rate based on number of GPUs.
opt = tf.optimizers.Adam(0.001 * hvd.size())

checkpoint_dir = './checkpoints'
checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)
@@ -76,7 +80,7 @@ def training_step(images, labels, first_batch):


# Horovod: adjust number of steps based on number of GPUs.
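# (Illustration) With e.g. 4 workers, dataset.take(10000 // hvd.size()) yields
# 2500 batches per worker, so the total number of batches processed stays
# roughly constant as workers are added.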
for batch, (images, labels) in enumerate(dataset.take(20000 // hvd.size())):
for batch, (images, labels) in enumerate(dataset.take(10000 // hvd.size())):
    loss_value = training_step(images, labels, batch == 0)

    if batch % 10 == 0 and hvd.local_rank() == 0:
@@ -70,24 +70,28 @@ def get_gradients(self, loss, params):
optimizer.get_config())


def broadcast_global_variables(backend, root_rank):
    bcast_op = hvd.broadcast_global_variables(root_rank)
    return backend.get_session().run(bcast_op)
def _eval(backend, op_or_result):
    if hvd._executing_eagerly():
        return op_or_result
    else:
        return backend.get_session().run(op_or_result)


if hasattr(hvd, 'broadcast_global_variables'):
    def broadcast_global_variables(backend, root_rank):
        return _eval(backend, hvd.broadcast_global_variables(root_rank))


def allreduce(backend, value, name, average):
    allreduce_op = hvd.allreduce(tf.constant(value, name=name), average=average)
    return backend.get_session().run(allreduce_op)
    return _eval(backend, hvd.allreduce(tf.constant(value, name=name), average=average))


def allgather(backend, value, name):
    allgather_op = hvd.allgather(tf.constant(value, name=name))
    return backend.get_session().run(allgather_op)
    return _eval(backend, hvd.allgather(tf.constant(value, name=name)))


def broadcast(backend, value, root_rank, name):
    bcast_op = hvd.broadcast(tf.constant(value, name=name), root_rank)
    return backend.get_session().run(bcast_op)
    return _eval(backend, hvd.broadcast(tf.constant(value, name=name), root_rank))
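# Note on the dispatch above (descriptive comment, not part of the original source):
# under eager execution the Horovod op already carries a concrete value, so `_eval`
# returns it as-is; under graph mode the op is run through the Keras backend session.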


def load_model(keras, wrap_optimizer, filepath, custom_optimizers, custom_objects):
@@ -23,11 +23,24 @@ def __init__(self, backend, root_rank, device='', *args):
        self.backend = backend
        self.root_rank = root_rank
        self.device = device
        self.broadcast_done = False

    def on_batch_end(self, batch, logs=None):
        if self.broadcast_done:
            return

    def on_train_begin(self, logs=None):
        with tf.device(self.device):
            bcast_op = hvd.broadcast_global_variables(self.root_rank)
            self.backend.get_session().run(bcast_op)
            if hvd._executing_eagerly() and hasattr(self.model, 'variables'):
                # TensorFlow 2.0 or TensorFlow eager
                hvd.broadcast_variables(self.model.variables,
                                        root_rank=self.root_rank)
                hvd.broadcast_variables(self.model.optimizer.variables(),
                                        root_rank=self.root_rank)
            else:
                bcast_op = hvd.broadcast_global_variables(self.root_rank)
                self.backend.get_session().run(bcast_op)

        self.broadcast_done = True


class MetricAverageCallbackImpl(object):
@@ -51,13 +64,17 @@ def _average_metrics_in_place(self, logs):
        # Reduce every metric among workers. Sort metrics by name
        # to ensure consistent order.
        for metric, value in sorted(logs.items()):
            if metric not in self.variables:
                self.variables[metric], self.allreduce_ops[metric] = \
                    self._make_variable(metric, value)
            if hvd._executing_eagerly():
                reduced_logs[metric] = \
                    hvd.allreduce(tf.constant(value, name=metric)).numpy()
            else:
                self.backend.set_value(self.variables[metric], value)
            reduced_logs[metric] = \
                self.backend.get_session().run(self.allreduce_ops[metric])
                if metric not in self.variables:
                    self.variables[metric], self.allreduce_ops[metric] = \
                        self._make_variable(metric, value)
                else:
                    self.backend.set_value(self.variables[metric], value)
                reduced_logs[metric] = \
                    self.backend.get_session().run(self.allreduce_ops[metric])
        # Override the reduced values back into logs dictionary
        # for other callbacks to use.
        for metric, value in reduced_logs.items():
@@ -109,7 +126,7 @@ def _adjust_learning_rate(self, epoch):
        # See the paper cited above for more information about momentum correction.
        self.restore_momentum = self.backend.get_value(self.model.optimizer.momentum)
        self.backend.set_value(self.model.optimizer.momentum,
                               self.restore_momentum * new_lr / old_lr)
                               self.restore_momentum * new_lr / old_lr)

    def _restore_momentum_if_needed(self):
        if self.restore_momentum:
@@ -84,9 +84,10 @@ def allreduce(tensor, average=True, device_dense='', device_sparse='',
@_cache
def _make_broadcast_group_fn():
    if _executing_eagerly():
        # Eager mode requires Tensor
        # Eager mode will parallelize independent control flow
        def broadcast_group(variables, root_rank):
            return [var.assign(broadcast(var, root_rank)) for var in variables]
            for var in variables:
                var.assign(broadcast(var, root_rank))

        return _make_subgraph(broadcast_group)
    else:
