fix the issue when passing history from GPU to CPU #3216

Merged 1 commit on Oct 12, 2021
horovod/spark/lightning/estimator.py (3 additions, 1 deletion)
@@ -439,14 +439,16 @@ def _read_checkpoint(self, run_id):
         return store.read(last_ckpt_path)

     def _create_model(self, run_results, run_id, metadata):
-        serialized_checkpoint, history = run_results[0]
+        serialized_checkpoint = run_results[0]
         serialized_checkpoint.seek(0)
         best_checkpoint = torch.load(serialized_checkpoint, map_location=torch.device('cpu'))

         model = copy.deepcopy(self.getModel())
         model.load_state_dict(best_checkpoint['model'])
         model.eval()

+        history = best_checkpoint['logged_metrics']
+
         # Optimizer is part of the model no need to return it to transformer.
         # TODO: (Pengz) Update the latest state of the optimizer in the model for retraining.
         optimizer = None
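
For context on why the history now rides inside the checkpoint: metrics logged by PyTorch Lightning on a GPU worker are CUDA tensors, and `torch.load` with `map_location=torch.device('cpu')` remaps every tensor in the saved payload onto the CPU, so the driver can read them even without a GPU. A minimal sketch of that round trip (hypothetical tensor values, not Horovod code):

```python
import io
import torch

# Hypothetical stand-in for trainer.logged_metrics; on a GPU worker these
# would be CUDA tensors, e.g. torch.tensor(0.25, device='cuda').
metrics = {'val_loss': torch.tensor(0.25)}

# Bundle the metrics into the same payload as the model weights,
# mirroring the 'logged_metrics' key added by this PR.
buf = io.BytesIO()
torch.save({'model': {}, 'logged_metrics': metrics}, buf)

# map_location remaps all tensors in the payload onto the CPU, so the
# history is readable on a CPU-only driver regardless of where it was logged.
buf.seek(0)
checkpoint = torch.load(buf, map_location=torch.device('cpu'))
history = checkpoint['logged_metrics']
```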
horovod/spark/lightning/remote.py (2 additions, 6 deletions)
@@ -248,15 +248,11 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         serialized_checkpoint = io.BytesIO()
         module = best_model if not is_legacy else best_model._model

-        # TODO: find a way to pass trainer.logged_metrics out.
-        output = {'model': module.state_dict()}
+        output = {'model': module.state_dict(), 'logged_metrics': trainer.logged_metrics}

         torch.save(output, serialized_checkpoint)

-        # Save logged metrics as history, which will saved in transformer.
-        history = trainer.logged_metrics
-
-        return serialized_checkpoint, history
+        return serialized_checkpoint
     return train
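
The return-value change is presumably what fixes the GPU-to-CPU handoff named in the PR title: the old tuple sent `trainer.logged_metrics` back to the driver through plain serialization, and pickling a CUDA tensor records its device with no way to remap it on load, whereas `torch.load` accepts `map_location`. A rough sketch of that failure mode (hypothetical helper, assumes metrics were logged on a CUDA worker):

```python
import pickle


def round_trip(metrics):
    # Plain pickle preserves tensor devices with no remapping hook. If
    # `metrics` holds CUDA tensors, pickle.loads on a machine without CUDA
    # raises a RuntimeError ("Attempting to deserialize object on a CUDA
    # device but torch.cuda.is_available() is False..."), which is why the
    # metrics now travel inside the torch.save payload instead.
    return pickle.loads(pickle.dumps(metrics))
```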