Move saving checkpoints from model to trainer #262

Merged · 7 commits · Jun 21, 2024
Changes from 5 commits
34 changes: 21 additions & 13 deletions docs/src/dev-docs/new-architecture.rst
@@ -16,13 +16,13 @@ to these lines
hypers = {}
dataset_info = DatasetInfo()

if "continue_from":
model = Model.load_checkpoint("path")
if continue_from is not None:
trainer = Trainer.load_checkpoint(continue_from, hypers["training"])
model = Model.load_checkpoint(continue_from)
model = model.restart(dataset_info)
else:
model = Model(hypers["architecture"], dataset_info)

trainer = Trainer(hypers["training"])
model = Model(hypers["model"], dataset_info)
trainer = Trainer(hypers["training"])

trainer.train(
model=model,
@@ -56,9 +56,8 @@ In order to follow this, a new architecture has to define two classes
val_datasets passed to the Trainer, as well as the dataset_info passed to the
model.

The ``ModelInterface`` is the main model class and must implement a
``save_checkpoint()``, ``load_checkpoint()`` as well as a ``restart()`` and
``export()`` method.
The ``ModelInterface`` is the main model class and must implement the
``load_checkpoint()``, ``restart()`` and ``export()`` methods.
Contributor:
load_checkpoint is also a trainer method now.

Collaborator (Author):
Yes, and it's documented in the TrainerInterface just below.


.. code-block:: python

@@ -71,9 +70,6 @@ The ``ModelInterface`` is the main model class and must implement a
self.hypers = model_hypers
self.dataset_info = dataset_info

def save_checkpoint(self, path: Union[str, Path]):
pass

@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "ModelInterface":
pass
Comment on lines 73 to 75

Contributor:
Suggested change (remove these lines):

    @classmethod
    def load_checkpoint(cls, path: Union[str, Path]) -> "ModelInterface":
        pass

Collaborator (Author):
This is still present for all architectures (we need a model.load_checkpoint for export, where there is no Trainer).

Contributor:
But then I think we should also provide a save_checkpoint for the model, for consistency. That basically means this PR adds a save_checkpoint() and a load_checkpoint() for the trainer.

Collaborator (Author):
I don't think so. We should only have what's necessary. If people want to have it (and call it inside Trainer.save_checkpoint) then it's up to them.

Contributor:
Yeah, makes sense.
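
For reference, a minimal sketch of that optional pattern (purely illustrative: the class names are made up, and metatrain itself only requires the trainer-level methods):

    from pathlib import Path
    from typing import Union

    import torch


    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.hypers = {}

        def save_checkpoint(self, path: Union[str, Path]) -> None:
            # Optional model-level checkpointing, not required by metatrain
            torch.save(
                {"model_hypers": self.hypers, "model_state_dict": self.state_dict()},
                path,
            )


    class MyTrainer:
        def save_checkpoint(self, model: MyModel, path: Union[str, Path]) -> None:
            # The required trainer-level entry point simply delegates to the model
            model.save_checkpoint(path)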

@@ -105,8 +101,8 @@ a helper function :py:func:`metatrain.utils.export.export` to export a torch
model to an :py:class:`MetatensorAtomisticModel
<metatensor.torch.atomistic.MetatensorAtomisticModel>`.

The ``TrainerInterface`` class should have the following signature with a required
methods for ``train()``.
The ``TrainerInterface`` class should have the following signature with required
methods for ``train()``, ``save_checkpoint()`` and ``load_checkpoint()``.

.. code-block:: python

@@ -123,6 +119,18 @@ methods for ``train()``.
checkpoint_dir: str,
) -> None: ...

def save_checkpoint(self, path: Union[str, Path]) -> None: ...

@classmethod
def load_checkpoint(
cls, path: Union[str, Path], train_hypers: Dict
) -> "TrainerInterface":
pass

The format of checkpoints is not defined by `metatrain` and can be any format that
can be loaded by the trainer (to restart training) and by the model (to export the
checkpoint).
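
For illustration, one possible layout is a single ``torch.save`` dictionary holding
everything needed to restart training and to rebuild the model (a sketch only; the
helper and keys below are illustrative, not part of the required interface):

.. code-block:: python

    import torch

    def save_checkpoint(trainer, model, path):
        # Everything the trainer needs to resume (hypers, optimizer state, epoch)
        # plus everything the model needs to be rebuilt for export.
        torch.save(
            {
                "model_hypers": model.hypers,
                "model_state_dict": model.state_dict(),
                "train_hypers": trainer.hypers,
                "optimizer_state_dict": trainer.optimizer_state_dict,
                "epoch": trainer.epoch,
            },
            path,
        )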

The names of the ``ModelInterface`` and the ``TrainerInterface`` are free to choose but
should be linked to constants in the ``__init__.py`` of each architecture. On top of
these two constants the ``__init__.py`` must contain constants for the original
5 changes: 4 additions & 1 deletion src/metatrain/cli/export.py
@@ -80,6 +80,9 @@ def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") ->
torch.jit.save(model, path)
else:
extensions_path = "extensions/"
logger.info(f"Exporting model to {path} and extensions to {extensions_path}")
logger.info(
f"Exporting model to `{path}` and extensions to `{extensions_path}`"
Contributor:
I think I usually prefer quotes for paths and backticks for variables. But feel free.

Suggested change:
    f"Exporting model to `{path}` and extensions to `{extensions_path}`"
    f"Exporting model to {path!r} and extensions to {extensions_path!r}"

Collaborator (Author):
Are you sure that !r isn't going to print some other stuff if path and/or extensions_path are Path objects?

Contributor:
Ahh, probably you are right. But then I would go for single quotes instead of backticks.
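
A quick check of what each option would print (a small sketch; assumes a POSIX system, where the repr of a pathlib.Path includes the class name):

    from pathlib import Path

    path = Path("exported-model.pt")

    # !r uses repr(), which for pathlib objects includes the class name:
    print(f"Exporting model to {path!r}")  # Exporting model to PosixPath('exported-model.pt')

    # Plain formatting with explicit quotes keeps just the path text:
    print(f"Exporting model to '{path}'")  # Exporting model to 'exported-model.pt'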

)
mts_atomistic_model = model.export()
mts_atomistic_model.export(path, collect_extensions=extensions_path)
logger.info("Model exported successfully")
11 changes: 6 additions & 5 deletions src/metatrain/cli/train.py
@@ -371,20 +371,21 @@ def train_model(
try:
if continue_from is not None:
logger.info(f"Loading checkpoint from `{continue_from}`")
trainer = Trainer.load_checkpoint(continue_from, hypers["training"])
model = Model.load_checkpoint(continue_from)
model = model.restart(dataset_info)
else:
model = Model(hypers["model"], dataset_info)
trainer = Trainer(hypers["training"])
except Exception as e:
raise ArchitectureError(e)

###########################
# TRAIN MODEL #############
###########################

logger.info("Start training")
logger.info("Calling trainer")
try:
trainer = Trainer(hypers["training"])
trainer.train(
model=model,
devices=devices,
@@ -405,18 +406,18 @@ def train_model(
output_checked = check_suffix(filename=output, suffix=".pt")
logger.info(
"Training finished, saving final checkpoint "
f"to {str(Path(output_checked).stem)}.ckpt"
f"to `{str(Path(output_checked).stem)}.ckpt`"
)
try:
model.save_checkpoint(f"{Path(output_checked).stem}.ckpt")
trainer.save_checkpoint(model, f"{Path(output_checked).stem}.ckpt")
except Exception as e:
raise ArchitectureError(e)

mts_atomistic_model = model.export()
extensions_path = "extensions/"

logger.info(
f"Exporting model to {output_checked} and extensions to {extensions_path}"
f"Exporting model to `{output_checked}` and extensions to `{extensions_path}`"
)
mts_atomistic_model.export(str(output_checked), collect_extensions=extensions_path)

25 changes: 6 additions & 19 deletions src/metatrain/experimental/alchemical_model/model.py
@@ -15,7 +15,6 @@
from ...utils.data.dataset import DatasetInfo
from ...utils.dtype import dtype_to_str
from ...utils.export import export
from ...utils.io import check_suffix
from .utils import systems_to_torch_alchemical_batch


@@ -126,29 +125,17 @@ def forward(
)
return total_energies

def save_checkpoint(self, path: Union[str, Path]):
torch.save(
{
"model_hypers": {
"model_hypers": self.hypers,
"dataset_info": self.dataset_info,
},
"model_state_dict": self.state_dict(),
},
check_suffix(path, ".ckpt"),
)

@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "AlchemicalModel":

# Load the model and the metadata
model_dict = torch.load(path)
# Load the checkpoint
checkpoint = torch.load(path)
model_hypers = checkpoint["model_hypers"]
model_state_dict = checkpoint["model_state_dict"]

# Create the model
model = cls(**model_dict["model_hypers"])

# Load the model weights
model.load_state_dict(model_dict["model_state_dict"])
model = cls(**model_hypers)
model.load_state_dict(model_state_dict)

return model

61 changes: 58 additions & 3 deletions src/metatrain/experimental/alchemical_model/trainer.py
@@ -16,6 +16,7 @@
)
from ...utils.evaluate_model import evaluate_model
from ...utils.external_naming import to_external_name
from ...utils.io import check_suffix
from ...utils.logging import MetricLogger
from ...utils.loss import TensorMapDictLoss
from ...utils.metrics import RMSEAccumulator
@@ -35,6 +36,9 @@
class Trainer:
def __init__(self, train_hypers):
self.hypers = train_hypers
self.optimizer_state_dict = None
self.scheduler_state_dict = None
self.epoch = None

def train(
self,
@@ -178,6 +182,8 @@ def train(
optimizer = torch.optim.Adam(
model.parameters(), lr=self.hypers["learning_rate"]
)
if self.optimizer_state_dict is not None:
optimizer.load_state_dict(self.optimizer_state_dict)

# Create a scheduler:
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
@@ -186,6 +192,8 @@ def train(
factor=self.hypers["scheduler_factor"],
patience=self.hypers["scheduler_patience"],
)
if self.scheduler_state_dict is not None:
lr_scheduler.load_state_dict(self.scheduler_state_dict)

# counters for early stopping:
best_val_loss = float("inf")
@@ -194,9 +202,11 @@ def train(
# per-atom targets:
per_structure_targets = self.hypers["per_structure_targets"]

start_epoch = 0 if self.epoch is None else self.epoch + 1

# Train the model:
logger.info("Starting training")
for epoch in range(self.hypers["num_epochs"]):
for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]):
train_rmse_calculator = RMSEAccumulator()
val_rmse_calculator = RMSEAccumulator()

@@ -279,7 +289,7 @@ def train(
**finalized_val_info,
}

if epoch == 0:
if epoch == start_epoch:
metric_logger = MetricLogger(
logobj=logger,
dataset_info=model.dataset_info,
@@ -293,7 +303,12 @@ def train(
)

if epoch % self.hypers["checkpoint_interval"] == 0:
model.save_checkpoint(Path(checkpoint_dir) / f"model_{epoch}.ckpt")
self.optimizer_state_dict = optimizer.state_dict()
self.scheduler_state_dict = lr_scheduler.state_dict()
self.epoch = epoch
self.save_checkpoint(
model, Path(checkpoint_dir) / f"model_{epoch}.ckpt"
)

# early stopping criterion:
if val_loss < best_val_loss:
@@ -308,3 +323,43 @@ def train(
"without improvement."
)
break

def save_checkpoint(self, model, path: Union[str, Path]):
checkpoint = {
"model_hypers": {
"model_hypers": model.hypers,
"dataset_info": model.dataset_info,
},
"model_state_dict": model.state_dict(),
"train_hypers": self.hypers,
"epoch": self.epoch,
"optimizer_state_dict": self.optimizer_state_dict,
"scheduler_state_dict": self.scheduler_state_dict,
}
torch.save(
checkpoint,
check_suffix(path, ".ckpt"),
)

@classmethod
def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":

# Load the checkpoint
checkpoint = torch.load(path)
model_hypers = checkpoint["model_hypers"]
model_state_dict = checkpoint["model_state_dict"]
epoch = checkpoint["epoch"]
optimizer_state_dict = checkpoint["optimizer_state_dict"]
scheduler_state_dict = checkpoint["scheduler_state_dict"]

# Create the trainer
trainer = cls(train_hypers)
trainer.optimizer_state_dict = optimizer_state_dict
trainer.scheduler_state_dict = scheduler_state_dict
trainer.epoch = epoch

# Create the model
model = AlchemicalModel(**model_hypers)
model.load_state_dict(model_state_dict)

return trainer
10 changes: 0 additions & 10 deletions src/metatrain/experimental/gap/model.py
@@ -1,4 +1,3 @@
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union

import metatensor.torch
@@ -197,15 +196,6 @@ def forward(
out_tensor = self.apply_composition_weights(systems, energies)
return {output_key: out_tensor}

def save_checkpoint(self, path: Union[str, Path]):
# GAP will not save checkpoints, as it does not allow
# restarting training
return

@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "GAP":
raise ValueError("GAP does not allow restarting training")

def export(self) -> MetatensorAtomisticModel:
capabilities = ModelCapabilities(
outputs=self.outputs,
10 changes: 10 additions & 0 deletions src/metatrain/experimental/gap/trainer.py
@@ -1,4 +1,5 @@
import logging
from pathlib import Path
from typing import List, Union

import metatensor
@@ -140,3 +141,12 @@ def train(
model._subset_of_regressors_torch = (
model._subset_of_regressors.export_torch_script_model()
)

def save_checkpoint(self, model, checkpoint_dir: str):
# GAP won't save a checkpoint since it
# doesn't support restarting training
return

@classmethod
def load_checkpoint(cls, path: Union[str, Path], hypers_train) -> "GAP":
raise ValueError("GAP does not allow restarting training")
43 changes: 24 additions & 19 deletions src/metatrain/experimental/pet/model.py
@@ -12,13 +12,14 @@
NeighborListOptions,
System,
)
from pet.hypers import Hypers
from pet.pet import PET as RawPET
from pet.pet import SelfContributionsWrapper

from metatrain.utils.data import DatasetInfo

from ...utils.dtype import dtype_to_str
from ...utils.export import export
from ...utils.io import check_suffix
from .utils import systems_to_batch_dict


@@ -110,29 +111,33 @@ def forward(
output_quantities[output_name] = output_tmap
return output_quantities

def save_checkpoint(self, path: Union[str, Path]):
torch.save(
{
"model_hypers": {
"model_hypers": self.hypers,
"dataset_info": self.dataset_info,
},
"model_state_dict": self.state_dict(),
},
check_suffix(path, ".ckpt"),
)

@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "PET":

# Load the model and the metadata
model_dict = torch.load(path)
checkpoint = torch.load(path)
hypers = checkpoint["hypers"]
dataset_info = checkpoint["dataset_info"]
model = cls(
model_hypers=hypers["ARCHITECTURAL_HYPERS"], dataset_info=dataset_info
)

checkpoint = torch.load(path)
state_dict = checkpoint["checkpoint"]["model_state_dict"]

ARCHITECTURAL_HYPERS = Hypers(model.hypers)
raw_pet = RawPET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))

new_state_dict = {}
for name, value in state_dict.items():
name = name.replace("model.pet_model.", "")
new_state_dict[name] = value

raw_pet.load_state_dict(new_state_dict)

# Create the model
model = cls(**model_dict["model_hypers"])
self_contributions = checkpoint["self_contributions"]
wrapper = SelfContributionsWrapper(raw_pet, self_contributions)

# Load the model weights
model.load_state_dict(model_dict["model_state_dict"])
model.set_trained_model(wrapper)

return model
