
Move saving checkpoints from model to trainer #262

Merged
frostedoyster merged 7 commits into main from trainer-checkpoints on Jun 21, 2024

Conversation

frostedoyster (Collaborator) commented Jun 14, 2024

Closes #214, #89, #203. It involves a refactor where the trainer now saves checkpoints instead of the model.

This solves #203, which asks for the optimizer and scheduler states to be saved so that training can be restarted much more reliably. The slightly changed interface allows PET checkpoints to work both for restarting and for exporting (#214). Finally, we document what a checkpoint (e.g. model.ckpt) should contain (#89).
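
For illustration, here is a minimal sketch of the trainer-side half of this split; the class, attribute, and checkpoint-key names are assumptions for the example, not the actual metatrain API:

```python
# Hypothetical sketch of trainer-side checkpoint saving; names are
# illustrative, not the actual metatrain classes.
from pathlib import Path
from typing import Union

import torch


class SketchTrainer:
    def __init__(self, model, optimizer, lr_scheduler):
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.epoch = 0

    def save_checkpoint(self, path: Union[str, Path]) -> None:
        # Everything needed to restart training goes into one file: model
        # weights plus optimizer and scheduler state (what #203 asks for).
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.lr_scheduler.state_dict(),
                "epoch": self.epoch,
            },
            path,
        )
```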

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

📚 Documentation preview 📚: https://metatrain--262.org.readthedocs.build/en/262/

frostedoyster marked this pull request as ready for review on June 14, 2024 at 18:52
PicoCentauri (Contributor) commented

Does this work with the design to support lightning?

frostedoyster (Collaborator, Author) commented

Yes, by inheriting from the lightning classes and making the necessary adaptations (as we were planning to do before).

PicoCentauri (Contributor) left a comment

Good. Some minor changes in the docs should reflect the changes in the code. Still, the question remains whether this will work with a lightning model, where the model itself loads the checkpoint.

But I can imagine it works, because our Trainer knows about the model and can just call lightning_model.load_from_checkpoint.
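
A rough sketch of what that could look like, assuming a pytorch-lightning based architecture (all names here are hypothetical):

```python
# Hypothetical sketch: a lightning-backed architecture behind the interface.
from pathlib import Path
from typing import Union

import pytorch_lightning as pl


class LightningArchitecture(pl.LightningModule):
    """A regular LightningModule; its actual definition is omitted here."""


class LightningBackedTrainer:
    def __init__(self):
        self.model = None

    @classmethod
    def load_checkpoint(cls, path: Union[str, Path]) -> "LightningBackedTrainer":
        trainer = cls()
        # Delegate to lightning: LightningModule.load_from_checkpoint restores
        # the weights and hyperparameters saved in a lightning checkpoint.
        trainer.model = LightningArchitecture.load_from_checkpoint(str(path))
        return trainer
```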

``save_checkpoint()``, ``load_checkpoint()`` as well as a ``restart()`` and
``export()`` method.
The ``ModelInterface`` is the main model class and must implement the
``load_checkpoint()``, ``restart()`` and ``export()`` methods.
Contributor

load_checkpoint is also a trainer method now.

Collaborator (Author)

Yes, and it's documented in the TrainerInterface just below

Comment on lines 73 to 75
@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "ModelInterface":
    pass
Contributor

Suggested change
@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "ModelInterface":
    pass

Collaborator (Author)

This is still present for all architectures (we need a model.load_checkpoint for export, where there is no Trainer).

Contributor

But then, for consistency, I think we should also provide a save_checkpoint for the model.

Which basically means this PR adds a save_checkpoint() and a load_checkpoint() for the trainer.

Collaborator (Author)

I don't think so. We should only have what's necessary. If people want to have it (and call it inside Trainer.save_checkpoint), then it's up to them.

Contributor

Yeah, that makes sense.

@@ -80,6 +80,9 @@ def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") ->
         torch.jit.save(model, path)
     else:
         extensions_path = "extensions/"
-        logger.info(f"Exporting model to {path} and extensions to {extensions_path}")
+        logger.info(
+            f"Exporting model to `{path}` and extensions to `{extensions_path}`"
Contributor

I think I usually prefer quotes for paths and backticks for variables. But feel free.

Suggested change
-            f"Exporting model to `{path}` and extensions to `{extensions_path}`"
+            f"Exporting model to {path!r} and extensions to {extensions_path!r}"

Collaborator (Author)
Are you sure that !r isn't going to print some other stuff if path and/or extensions_path are Path objects?

Contributor
Ah, you are probably right. But then I would go for single quotes instead of backticks.
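
For reference, a quick standalone check of that concern: !r formats the repr of a pathlib.Path, which includes the class name rather than just a quoted string.

```python
from pathlib import Path

path = Path("exported-model.pt")

# !r uses repr(), so a Path renders as e.g. PosixPath('exported-model.pt')
print(f"Exporting model to {path!r}")

# explicit single quotes around the str() form give 'exported-model.pt'
print(f"Exporting model to '{path}'")
```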

Luthaf (Contributor) commented Jun 17, 2024

So the idea is that we need to save checkpoints in the trainer to be able to also save the optimizer state, but we want to load checkpoints from the model for export/restart?

frostedoyster (Collaborator, Author) commented Jun 17, 2024

@Luthaf yes, that's the idea. The same checkpoint can be read by the trainer (for restarting training, it reads everything: model, optimizer, scheduler, etc.) or by the model (which just loads the model, for exporting; this can also be called from the trainer's loader).
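
As a small sketch of that dual use (the checkpoint keys here are illustrative; each architecture decides what it actually stores):

```python
# Hypothetical sketch: one checkpoint file, two readers.
import torch


def trainer_load_checkpoint(path):
    # Restart: the trainer restores everything needed to resume training.
    checkpoint = torch.load(path)
    return (
        checkpoint["model_state_dict"],
        checkpoint["optimizer_state_dict"],
        checkpoint["scheduler_state_dict"],
    )


def model_load_checkpoint(path):
    # Export: only the model weights are needed; the rest is ignored.
    checkpoint = torch.load(path)
    return checkpoint["model_state_dict"]
```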

PicoCentauri (Contributor) left a comment

I think it is almost ready; just my minor comments about the quotes/backticks.

frostedoyster merged commit 7943c9c into main on Jun 21, 2024
13 checks passed
frostedoyster deleted the trainer-checkpoints branch on June 21, 2024 at 12:41