
Commit 445f6a5

Signed-off-by: Peng Zhang <pengz@uber.com>
irasit committed Oct 12, 2021
1 parent afa2022 commit 445f6a5
Showing 2 changed files with 5 additions and 7 deletions.
4 changes: 3 additions & 1 deletion horovod/spark/lightning/estimator.py
@@ -439,14 +439,16 @@ def _read_checkpoint(self, run_id):
         return store.read(last_ckpt_path)
 
     def _create_model(self, run_results, run_id, metadata):
-        serialized_checkpoint, history = run_results[0]
+        serialized_checkpoint = run_results[0]
         serialized_checkpoint.seek(0)
         best_checkpoint = torch.load(serialized_checkpoint, map_location=torch.device('cpu'))
 
         model = copy.deepcopy(self.getModel())
         model.load_state_dict(best_checkpoint['model'])
         model.eval()
 
+        history = best_checkpoint['logged_metrics']
+
         # Optimizer is part of the model no need to return it to transformer.
         # TODO: (Pengz) Update the latest state of the optimizer in the model for retraining.
         optimizer = None
8 changes: 2 additions & 6 deletions horovod/spark/lightning/remote.py
@@ -248,15 +248,11 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         serialized_checkpoint = io.BytesIO()
         module = best_model if not is_legacy else best_model._model
 
-        # TODO: find a way to pass trainer.logged_metrics out.
-        output = {'model': module.state_dict()}
+        output = {'model': module.state_dict(), 'logged_metrics': trainer.logged_metrics}
 
         torch.save(output, serialized_checkpoint)
 
-        # Save logged metrics as history, which will saved in transformer.
-        history = trainer.logged_metrics
-
-        return serialized_checkpoint, history
+        return serialized_checkpoint
     return train
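
Taken together, the two hunks move `trainer.logged_metrics` into the serialized checkpoint on the training side (remote.py) and recover it as `history` when the estimator deserializes the checkpoint (estimator.py). Below is a minimal, self-contained sketch of that round trip; the `nn.Linear` module and the metric values are made-up placeholders standing in for the Lightning module and `trainer.logged_metrics`, not code from this commit.

import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the LightningModule trained in remote.py.
model = nn.Linear(4, 2)

# Writer side (remote.py): pack the weights and the logged metrics into one payload.
logged_metrics = {'train_loss': 0.123, 'val_loss': 0.456}  # placeholder for trainer.logged_metrics
serialized_checkpoint = io.BytesIO()
torch.save({'model': model.state_dict(), 'logged_metrics': logged_metrics},
           serialized_checkpoint)

# Reader side (estimator._create_model): rewind, load, and split the payload again.
serialized_checkpoint.seek(0)
best_checkpoint = torch.load(serialized_checkpoint, map_location=torch.device('cpu'))

restored = nn.Linear(4, 2)
restored.load_state_dict(best_checkpoint['model'])
restored.eval()

history = best_checkpoint['logged_metrics']  # previously returned as a separate tuple element
print(history)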


