diff --git a/horovod/spark/lightning/estimator.py b/horovod/spark/lightning/estimator.py
index d832300699..a0baf8c7c0 100644
--- a/horovod/spark/lightning/estimator.py
+++ b/horovod/spark/lightning/estimator.py
@@ -439,7 +439,7 @@ def _read_checkpoint(self, run_id):
         return store.read(last_ckpt_path)
 
     def _create_model(self, run_results, run_id, metadata):
-        serialized_checkpoint, history = run_results[0]
+        serialized_checkpoint = run_results[0]
         serialized_checkpoint.seek(0)
         best_checkpoint = torch.load(serialized_checkpoint, map_location=torch.device('cpu'))
 
@@ -447,6 +447,8 @@ def _create_model(self, run_results, run_id, metadata):
         model.load_state_dict(best_checkpoint['model'])
         model.eval()
 
+        history = best_checkpoint['logged_metrics']
+
         # Optimizer is part of the model no need to return it to transformer.
         # TODO: (Pengz) Update the latest state of the optimizer in the model for retraining.
         optimizer = None
diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py
index f506a875e9..f4f180955c 100644
--- a/horovod/spark/lightning/remote.py
+++ b/horovod/spark/lightning/remote.py
@@ -248,15 +248,11 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
 
         serialized_checkpoint = io.BytesIO()
         module = best_model if not is_legacy else best_model._model
 
-        # TODO: find a way to pass trainer.logged_metrics out.
-        output = {'model': module.state_dict()}
+        output = {'model': module.state_dict(), 'logged_metrics': trainer.logged_metrics}
 
         torch.save(output, serialized_checkpoint)
 
-        # Save logged metrics as history, which will saved in transformer.
-        history = trainer.logged_metrics
-
-        return serialized_checkpoint, history
+        return serialized_checkpoint
 
     return train
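
The change folds the logged metrics into the serialized checkpoint dict instead of returning them as a second element of the run results, so the remote training function hands back a single object and the driver recovers both pieces from it. Below is a minimal, self-contained sketch of that round-trip; the toy `nn.Linear` model and the placeholder metric values are illustrative only and are not taken from the Horovod code.

import io

import torch
import torch.nn as nn

# Toy stand-ins for the real Lightning module and trainer.logged_metrics
# (hypothetical values, not Horovod code).
model = nn.Linear(4, 2)
logged_metrics = {'train_loss': torch.tensor(0.25), 'val_loss': torch.tensor(0.31)}

# Worker side: bundle the weights and the logged metrics into one serialized
# checkpoint, so only a single object is returned from the remote train function.
serialized_checkpoint = io.BytesIO()
torch.save({'model': model.state_dict(), 'logged_metrics': logged_metrics},
           serialized_checkpoint)

# Driver side: rewind the buffer, load on CPU, and split the pieces back out.
serialized_checkpoint.seek(0)
best_checkpoint = torch.load(serialized_checkpoint, map_location=torch.device('cpu'))

restored = nn.Linear(4, 2)
restored.load_state_dict(best_checkpoint['model'])
restored.eval()
history = best_checkpoint['logged_metrics']
print(history)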