Move saving checkpoints from model to trainer #262
First file (architecture documentation, reStructuredText):

    @@ -16,13 +16,13 @@ to these lines
     hypers = {}
     dataset_info = DatasetInfo()

    -if "continue_from":
    -    model = Model.load_checkpoint("path")
    +if continue_from is not None:
    +    trainer = Trainer.load_checkpoint(continue_from, hypers["training"])
    +    model = Model.load_checkpoint(continue_from)
         model = model.restart(dataset_info)
     else:
    -    model = Model(hypers["architecture"], dataset_info)
    -
    -trainer = Trainer(hypers["training"])
    +    model = Model(hypers["model"], dataset_info)
    +    trainer = Trainer(hypers["training"])

     trainer.train(
         model=model,
    @@ -56,9 +56,8 @@ In order to follow this, a new architectures has two define two classes
     val_datasets passed to the Trainer, as well as the dataset_info passed to the
     model.

    -The ``ModelInterface`` is the main model class and must implement a
    -``save_checkpoint()``, ``load_checkpoint()`` as well as a ``restart()`` and
    -``export()`` method.
    +The ``ModelInterface`` is the main model class and must implement the
    +``load_checkpoint()``, ``restart()`` and ``export()`` methods.

     .. code-block:: python
    @@ -71,9 +70,6 @@ The ``ModelInterface`` is the main model class and must implement a
             self.hypers = model_hypers
             self.dataset_info = dataset_info

    -    def save_checkpoint(self, path: Union[str, Path]):
    -        pass
    -
         @classmethod
         def load_checkpoint(cls, path: Union[str, Path]) -> "ModelInterface":
             pass

Review discussion on lines 73 to 75 (the removal of Model.save_checkpoint):

- "This is still present for all architectures (we need a …"
- "But then I think we should also provide for consistency a … Which means basically this PR adds a …"
- "I don't think so. We should only have what's necessary. If people want to have it (and call it inside the Trainer.save_checkpoint) then it's up to them."
- "Yeah makes sense."
    @@ -105,8 +101,8 @@ a helper function :py:func:`metatrain.utils.export.export` to export a torch
     model to an :py:class:`MetatensorAtomisticModel
     <metatensor.torch.atomistic.MetatensorAtomisticModel>`.

    -The ``TrainerInterface`` class should have the following signature with a required
    -methods for ``train()``.
    +The ``TrainerInterface`` class should have the following signature with required
    +methods for ``train()``, ``save_checkpoint()`` and ``load_checkpoint()``.

     .. code-block:: python
    @@ -123,6 +119,18 @@ methods for ``train()``.
             checkpoint_dir: str,
         ) -> None: ...

    +    def save_checkpoint(self, path: Union[str, Path]) -> None: ...
    +
    +    @classmethod
    +    def load_checkpoint(
    +        cls, path: Union[str, Path], train_hypers: Dict
    +    ) -> "TrainerInterface":
    +        pass
    +
    +The format of checkpoints is not defined by `metatrain` and can be any format that
    +can be loaded by the trainer (to restart training) and by the model (to export the
    +checkpoint).
    +
     The names of the ``ModelInterface`` and the ``TrainerInterface`` are free to choose but
     should be linked to constants in the ``__init__.py`` of each architecture. On top of
     these two constants the ``__init__.py`` must contain constants for the original
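Since the checkpoint format is deliberately left to each architecture, one possible shape is a pickled dict. The sketch below is a hypothetical trainer: the `epoch` field and the pickle format are arbitrary choices for illustration, not part of the metatrain interface:

```python
# A minimal TrainerInterface-style checkpoint round trip. The checkpoint
# layout (a pickled dict) is one arbitrary choice; metatrain does not
# prescribe any format.
import pickle
from pathlib import Path
from typing import Dict, Union


class SketchTrainer:
    def __init__(self, train_hypers: Dict):
        self.hypers = train_hypers
        self.epoch = 0  # illustrative training state to persist

    def save_checkpoint(self, path: Union[str, Path]) -> None:
        # Persist only trainer-owned state; hypers are re-supplied on load.
        state = {"epoch": self.epoch}
        with open(path, "wb") as f:
            pickle.dump(state, f)

    @classmethod
    def load_checkpoint(
        cls, path: Union[str, Path], train_hypers: Dict
    ) -> "SketchTrainer":
        # Matches the signature from the docs: fresh hypers plus saved state.
        trainer = cls(train_hypers)
        with open(path, "rb") as f:
            state = pickle.load(f)
        trainer.epoch = state["epoch"]
        return trainer
```

Note that `load_checkpoint` takes `train_hypers` explicitly, so a restarted run can change training settings while keeping the saved state.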
Second file (export helper):

    @@ -80,6 +80,9 @@ def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") ->
             torch.jit.save(model, path)
         else:
             extensions_path = "extensions/"
    +        logger.info(
    +            f"Exporting model to `{path}` and extensions to `{extensions_path}`"
    +        )
             mts_atomistic_model = model.export()
             mts_atomistic_model.export(path, collect_extensions=extensions_path)
             logger.info("Model exported successfully")

Review discussion on the new log message:

- "I think I usually prefer quotes for paths and back ticks for variables. But feel free."
- "Are you sure that …"
- "Ahh probably you are right. But then I would go for single ticks instead of backticks."
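The export flow around the new log call can be sketched with a stub model. Both `export_model_sketch` and the stub are illustrative: the real helper additionally handles the torch.jit branch and returns a MetatensorAtomisticModel:

```python
# Sketch of the export path touched by this hunk: log the destination,
# convert the model, then write it plus its TorchScript extensions.
import logging
from pathlib import Path
from typing import Union

logger = logging.getLogger(__name__)


def export_model_sketch(
    model, output: Union[str, Path] = "exported-model.pt"
) -> None:
    path = Path(output)
    extensions_path = "extensions/"  # hypothetical fixed destination
    logger.info(
        f"Exporting model to `{path}` and extensions to `{extensions_path}`"
    )
    # `model.export()` stands in for the conversion to an exportable object;
    # the returned object then writes itself and collects its extensions.
    exported = model.export()
    exported.export(path, collect_extensions=extensions_path)
    logger.info("Model exported successfully")
```

A stub model with a recording `export()` method is enough to exercise this flow without torch installed.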
Review discussion on the first file:

- "load_checkpoint is also a trainer method now."
- "Yes, and it's documented in the TrainerInterface just below."