Skip to content

Commit

Permalink
Reinitialize callbacks on_train_begin
Browse files Browse the repository at this point in the history
Signed-off-by: Travis Addair <taddair@uber.com>
  • Loading branch information
tgaddair committed Sep 18, 2020
1 parent 913c4c2 commit 7121295
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion examples/elastic/tensorflow_keras_mnist_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def on_state_reset():
tf.keras.backend.set_value(model.optimizer.lr, lr * hvd.size())


state = hvd.elastic.KerasState(model, batch=100, epoch=0)
state = hvd.elastic.KerasState(model, batch=0, epoch=0)
state.register_reset_callbacks([on_state_reset])

callbacks = [
Expand Down
8 changes: 8 additions & 0 deletions horovod/_keras/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def __init__(self, backend, state, batches_per_commit, *args):
self.batches_per_commit = batches_per_commit
self.batches_remaining = batches_per_commit

def on_train_begin(self, logs=None):
# Reset this for every sync event to ensure consistency across ranks
self.batches_remaining = self.batches_per_commit

def on_batch_end(self, batch, logs=None):
self.batches_remaining -= 1
if self.batches_remaining == 0:
Expand All @@ -42,6 +46,10 @@ def __init__(self, backend, state, *args):
self.state = state
self.steps_per_epoch = None

def on_train_begin(self, logs=None):
# Reset this for every sync event to ensure consistency across ranks
self.steps_per_epoch = None

def on_epoch_begin(self, epoch, logs=None):
if self.params.get('steps'):
if self.steps_per_epoch is None:
Expand Down

0 comments on commit 7121295

Please sign in to comment.