Move saving checkpoints from model to trainer #262
Conversation
Force-pushed from 16cb56e to d497062
Force-pushed from d497062 to dce3af4
Does this work with the design to support lightning?
Yes, by inheriting from the lightning classes and making the necessary adaptations (as we were planning to do before).
Good. Some minor changes in the docs should reflect the changes in the code. Still, the question remains whether this will work with a lightning model, where the model loads the checkpoint. But I can imagine it works, because our Trainer knows about the model and can just call lightning_model.load_from_checkpoint.
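A minimal sketch of that delegation idea, using stand-in classes (`LightningStyleModel` and `Trainer` here are hypothetical, not metatrain or Lightning code; Lightning's real `load_from_checkpoint` classmethod restores weights from a `.ckpt` file, which is faked here to keep the sketch self-contained):

```python
class LightningStyleModel:
    """Stand-in for a pl.LightningModule providing load_from_checkpoint."""

    def __init__(self, weights=None):
        self.weights = weights

    @classmethod
    def load_from_checkpoint(cls, path):
        # faked restore; Lightning would read the file here
        return cls(weights=f"restored-from-{path}")


class Trainer:
    def __init__(self, model_cls):
        self.model_cls = model_cls

    def load_checkpoint(self, path):
        # the trainer knows the model class, so it simply delegates
        return self.model_cls.load_from_checkpoint(path)


model = Trainer(LightningStyleModel).load_checkpoint("model.ckpt")
print(model.weights)  # restored-from-model.ckpt
```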
``save_checkpoint()``, ``load_checkpoint()`` as well as a ``restart()`` and
``export()`` method.
The ``ModelInterface`` is the main model class and must implement the
``load_checkpoint()``, ``restart()`` and ``export()`` methods.
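The quoted docs can be summarized as a pair of abstract base classes. This is only a sketch of the method names the docs list; signatures beyond the names are assumptions, not metatrain's actual API:

```python
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union


class TrainerInterface(ABC):
    """Owns the full training state (model, optimizer, scheduler, ...)."""

    @abstractmethod
    def save_checkpoint(self, model, path: Union[str, Path]) -> None: ...

    @abstractmethod
    def load_checkpoint(self, path: Union[str, Path]) -> None: ...

    @abstractmethod
    def restart(self) -> None: ...

    @abstractmethod
    def export(self) -> None: ...


class ModelInterface(ABC):
    """Main model class: can load (but not save) checkpoints."""

    @classmethod
    @abstractmethod
    def load_checkpoint(cls, path: Union[str, Path]) -> "ModelInterface": ...

    @abstractmethod
    def restart(self) -> None: ...

    @abstractmethod
    def export(self) -> None: ...
```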
load_checkpoint is also a trainer method now.
Yes, and it's documented in the TrainerInterface just below:
@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "ModelInterface":
    pass
This is still present for all architectures (we need a model.load_checkpoint for export, where there is no Trainer).
But then, for consistency, I think we should also provide a save_checkpoint for the model. Which means this PR basically adds a save_checkpoint() and a load_checkpoint() for the trainer.
I don't think so. We should only have what's necessary. If people want to have it (and call it inside Trainer.save_checkpoint), then it's up to them.
Yeah, makes sense.
src/metatrain/cli/export.py (outdated)

@@ -80,6 +80,9 @@ def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") ->
         torch.jit.save(model, path)
     else:
         extensions_path = "extensions/"
-        logger.info(f"Exporting model to {path} and extensions to {extensions_path}")
+        logger.info(
+            f"Exporting model to `{path}` and extensions to `{extensions_path}`"
I think I usually prefer quotes for paths and backticks for variables. But feel free.
Suggested change:
- f"Exporting model to `{path}` and extensions to `{extensions_path}`"
+ f"Exporting model to {path!r} and extensions to {extensions_path!r}"
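As a quick check of what `!r` does with pathlib objects (`PurePosixPath` is used here only to make the output platform-independent; a concrete `Path` would print `PosixPath(...)` or `WindowsPath(...)`):

```python
from pathlib import PurePosixPath

path = PurePosixPath("exported-model.pt")

# !r uses repr(), which for pathlib objects includes the class name:
print(f"Exporting model to {path!r}")
# -> Exporting model to PurePosixPath('exported-model.pt')

# plain formatting with explicit quotes keeps the message clean:
print(f"Exporting model to '{path}'")
# -> Exporting model to 'exported-model.pt'
```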
Are you sure that !r isn't going to print some other stuff if path and/or extensions_path are Path objects?
Ahh, probably you are right. But then I would go for single quotes instead of backticks.
So the idea is that we need to save the checkpoint in the trainer to be able to also save the optimizer state, but we want to load checkpoints from the model for export/restart?
@Luthaf yes, that's the idea. The same checkpoint can be read by the trainer (for restarting training, it reads everything: model, optimizer, scheduler, etc.) or by the model (which just loads the model, for exporting; it can also be called by the trainer's loader).
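A plain-dict sketch of this shared-checkpoint design (real code would go through torch.save/torch.load; the function names and keys here are illustrative, not metatrain's API):

```python
import pickle


def trainer_save_checkpoint(state, path):
    # the trainer saves everything needed to restart training
    with open(path, "wb") as f:
        pickle.dump(state, f)


def trainer_load_checkpoint(path):
    # restart: the trainer reads back the full state
    with open(path, "rb") as f:
        return pickle.load(f)


def model_load_checkpoint(path):
    # export: the model only needs its own weights from the same file
    with open(path, "rb") as f:
        return pickle.load(f)["model"]


trainer_save_checkpoint(
    {"model": {"w": [1.0]}, "optimizer": {"lr": 0.01}, "lr_scheduler": {"step": 3}},
    "model.ckpt",
)
assert model_load_checkpoint("model.ckpt") == {"w": [1.0]}
assert set(trainer_load_checkpoint("model.ckpt")) == {"model", "optimizer", "lr_scheduler"}
```

The key point is that both loaders read the same file; the model's loader simply ignores the trainer-only state.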
I think it is almost ready; just my minor comments about the ticks.
Closes #214, #89, #203. It involves a refactor where the trainer now saves checkpoints instead of the model.
This solves #203, which asks for saving optimizer and scheduler states for much-improved restarting of training. The slightly changed interface allows PET checkpoints to work both for restarting and exporting (#214). Finally, we document what a checkpoint (e.g. model.ckpt) should be (#89).
Contributor (creator of pull-request) checklist
📚 Documentation preview 📚: https://metatrain--262.org.readthedocs.build/en/262/