Skip to content

Commit

Permalink
Added args
Browse files Browse the repository at this point in the history
Signed-off-by: Travis Addair <taddair@uber.com>
  • Loading branch information
tgaddair committed Sep 18, 2020
1 parent 06ab66b commit 323f0cc
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions examples/elastic/tensorflow_keras_mnist_elastic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import argparse

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
Expand All @@ -7,6 +9,22 @@

import horovod.tensorflow.keras as hvd

parser = argparse.ArgumentParser(description='TensorFlow Keras MNIST Elastic',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--batches-per-epoch', type=int, default=500,
help='number of batches per epoch scaled by world size')
parser.add_argument('--batches-per-commit', type=int, default=1,
help='number of batches per commit of the elastic state object')
parser.add_argument('--epochs', type=int, default=24,
help='number of epochs')
parser.add_argument('--learning-rate', type=float, default=1.0,
help='learning rate')
parser.add_argument('--batch-size', type=int, default=128,
help='batch size')

args = parser.parse_args()

# Horovod: initialize Horovod.
hvd.init()

Expand All @@ -16,9 +34,9 @@
config.gpu_options.visible_device_list = str(hvd.local_rank())
K.set_session(tf.Session(config=config))

lr = 1.0
batch_size = 128
epochs = 24
lr = args.learning_rate
batch_size = args.batch_size
epochs = args.epochs
num_classes = 10

(mnist_images, mnist_labels), _ = \
Expand Down Expand Up @@ -77,7 +95,7 @@ def on_state_reset():
# It is important that this callback comes last to ensure that all
# other state is fully up to date before we commit, so we do not lose
# any progress.
hvd.elastic.CommitStateCallback(state),
hvd.elastic.CommitStateCallback(state, batches_per_commit=args.batches_per_commit),
]

# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
Expand All @@ -90,7 +108,7 @@ def train(state):
# Horovod: adjust number of steps based on number of GPUs and number of epochs
# based on the number of previously completed epochs.
state.model.fit(dataset,
steps_per_epoch=500 // hvd.size(),
steps_per_epoch=args.batches_per_epoch // hvd.size(),
callbacks=callbacks,
epochs=epochs - state.epoch,
verbose=1 if hvd.rank() == 0 else 0)
Expand Down

0 comments on commit 323f0cc

Please sign in to comment.