Automatic model checkpointing for pytorch-lightning training #10935
@@ -24,6 +24,7 @@
import mlflow
from mlflow import pyfunc
from mlflow.client import MlflowClient
from mlflow.environment_variables import MLFLOW_DEFAULT_PREDICTION_DEVICE
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS

@@ -901,6 +902,13 @@ def autolog(
    silent=False,
    registered_model_name=None,
    extra_tags=None,
    model_checkpoint=True,
    model_checkpoint_monitor="val_loss",
    model_checkpoint_mode="min",
    model_checkpoint_save_best_only=True,
    model_checkpoint_save_weights_only=True,
    model_checkpoint_every_n_epochs=None,
    model_checkpoint_train_time_interval_S=600,
):  # pylint: disable=unused-argument
    """
    Enables (or disables) and configures autologging from `PyTorch Lightning

@@ -955,6 +963,26 @@ def autolog(
            new model version of the registered model with this name. The registered model is
            created if it does not already exist.
        extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
    :param model_checkpoint: Enable automatic model checkpointing. This feature only supports
        pytorch-lightning >= 1.4.0.
    :param model_checkpoint_monitor: In automatic model checkpointing, the metric name to monitor
        if you set `model_checkpoint_save_best_only` to True.
    :param model_checkpoint_save_best_only: If True, automatic model checkpointing only saves when
        the model is considered the "best", and the latest best model according to the quantity
        monitored will not be overwritten.
    :param model_checkpoint_mode: One of {"min", "max"}. In automatic model checkpointing,
        if save_best_only=True, the decision to overwrite the current save file is made based on
        either the maximization or the minimization of the monitored quantity.
    :param model_checkpoint_save_weights_only: In automatic model checkpointing, if True, then
        only the model's weights will be saved. Otherwise, the optimizer states,
        lr-scheduler states, etc. are added to the checkpoint too.
    :param model_checkpoint_every_n_epochs: Number of epochs between checkpoints for automatic
        model checkpointing.
    :param model_checkpoint_train_time_interval_S: Automatic model checkpoints are monitored
        at the specified time interval in seconds. For all practical purposes, this cannot be
        smaller than the amount of time it takes to process a single training batch. This is
        not guaranteed to execute at the exact time specified, but should be close.
        This must be mutually exclusive with `model_checkpoint_every_n_epochs`.

    .. testcode:: python
        :caption: Example

@@ -1099,3 +1127,36 @@ def print_auto_logged_info(r):
autolog.__doc__ = autolog.__doc__.replace("MIN_REQ_VERSION", str(MIN_REQ_VERSION)).replace(
    "MAX_REQ_VERSION", str(MAX_REQ_VERSION)
)


def load_latest_checkpoint(model_class, run_id=None):
Review note: No matter whether "save_weights_only" is true or false, when loading back it always requires the "model_class"; it does not need the model object. This is different from the Keras-side API: …
Reply: Makes sense to me!
""" | ||
If you enable model_checkpoint in autologging, during pytorch-lightning model | ||
training execution, checkpointed models are logged as MLflow artifacts. | ||
Using this API, you can load the latest checkpointed model. | ||
|
||
:param model_class: The class of the training model | ||
WeichenXu123 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
:param run_id: The id of the run which model is logged to. If not provided, | ||
current active run is used. | ||
""" | ||
from mlflow.pytorch._lightning_autolog import _LATEST_CHECKPOINT_ARTIFACT_TAG_KEY | ||
|
||
client = MlflowClient() | ||
|
||
if run_id is None: | ||
run = mlflow.active_run() | ||
if run is None: | ||
raise MlflowException( | ||
"There is no active run, please provide the 'run_id' for " | ||
"'load_best_checkpoint' call." | ||
WeichenXu123 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
run_id = run.info.run_id | ||
else: | ||
run = client.get_run(run_id) | ||
|
||
best_checkpoint_artifact = run.data.tags.get(_LATEST_CHECKPOINT_ARTIFACT_TAG_KEY) | ||
if best_checkpoint_artifact is None: | ||
raise MlflowException("There is no logged checkpoint artifact in current run.") | ||
WeichenXu123 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
downloaded_checkpoint_filepath = client.download_artifacts(run_id, best_checkpoint_artifact) | ||
return model_class.load_from_checkpoint(downloaded_checkpoint_filepath) |
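A minimal usage sketch of the new surface, assuming load_latest_checkpoint is exposed as mlflow.pytorch.load_latest_checkpoint; the module, layer sizes, and data below are illustrative placeholders, not part of this diff:

import mlflow
import mlflow.pytorch
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class MyLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # "val_loss" is the metric the checkpoint callback monitors by default;
        # if it is never logged, checkpoint autologging is skipped.
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)


data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
loader = DataLoader(data, batch_size=16)

# Enable autologging with the automatic checkpointing parameters added in this PR.
mlflow.pytorch.autolog(
    model_checkpoint=True,
    model_checkpoint_monitor="val_loss",
    model_checkpoint_every_n_epochs=1,
)

with mlflow.start_run() as run:
    pl.Trainer(max_epochs=2).fit(MyLightningModule(), loader, loader)

# Reload the latest checkpointed model from the run's artifacts.
loaded = mlflow.pytorch.load_latest_checkpoint(MyLightningModule, run_id=run.info.run_id)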
@@ -1,10 +1,13 @@
import logging
import os
import shutil
import tempfile
import time
import warnings

from packaging.version import Version

from mlflow.utils.file_utils import create_tmp_dir
import mlflow.pytorch
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS

@@ -287,6 +290,109 @@ def on_test_end(self, trainer, pl_module):
        self.metrics_logger.flush()


_LATEST_CHECKPOINT_ARTIFACT_TAG_KEY = "_latest_checkpoint_artifact"


class __MLflowModelCheckpointCallback(pl.Callback, metaclass=ExceptionSafeAbstractClass):

    def __init__(
        self,
        monitor,
        mode,
        save_best_only,
        save_weights_only,
        every_n_epochs,
        train_time_interval_S,
    ):
        self.monitor = monitor
        self.mode = mode
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.every_n_epochs = every_n_epochs
        self.train_time_interval_S = train_time_interval_S
        self.latest_checkpoint_timestamp = time.time()
        self.last_monitor_value = None

    def _is_new_checkpoint_better(self, new_monitor_value):
        if self.last_monitor_value is None:
            return True

        if self.mode == "min":
            return new_monitor_value <= self.last_monitor_value

        if self.mode == "max":
            return new_monitor_value >= self.last_monitor_value

        assert False, "Illegal __MLflowModelCheckpoint config."

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        current_epoch = trainer.current_epoch
        metric_dict = {k: float(v) for k, v in trainer.callback_metrics.items()}

        should_checkpoint = False
        if self.every_n_epochs and (current_epoch % self.every_n_epochs == 0):
            should_checkpoint = True
        elif (
            self.train_time_interval_S
            and time.time() - self.latest_checkpoint_timestamp > self.train_time_interval_S
        ):
            should_checkpoint = True

        if not should_checkpoint:
            return

        if self.save_best_only:
            if self.monitor not in metric_dict:
                # "save-best-only" requires comparing the monitor metric value,
                # but the provided monitor metric is not available,
                # skip model checkpoint autologging
                return
Review comment: It might be a better UX to raise an explicit error than silently failing.
Reply: I prefer to use … instead of raising an exception, because the default checkpoint configs are …, but sometimes the model might not log the "val_loss" metric, so raising an exception here would break the whole autologging; we don't need to break the other parts of autologging (e.g. logging params / metrics).
Reply: Makes sense to me! On top of that, may we call out "checkpoint logging is skipped" in the error message?
            new_monitor_value = metric_dict[self.monitor]
            if not self._is_new_checkpoint_better(new_monitor_value):
                # Current checkpoint is worse than last saved checkpoint,
                # so skip checkpointing.
                return

            self.last_monitor_value = new_monitor_value

        if self.save_best_only:
            if self.save_weights_only:
                checkpoint_model_filename = "latest_checkpoint_model.weights.pth"
            else:
                checkpoint_model_filename = "latest_checkpoint_model.pth"
            checkpoint_metrics_filename = "latest_checkpoint_metrics.json"
            checkpoint_artifact_dir = ""
        else:
            if self.save_weights_only:
                checkpoint_model_filename = f"checkpoint_model_epoch_{current_epoch}.weights.pth"
            else:
                checkpoint_model_filename = f"checkpoint_model_epoch_{current_epoch}.pth"
            checkpoint_metrics_filename = f"checkpoint_metrics_epoch_{current_epoch}.json"
            checkpoint_artifact_dir = "checkpoints"

        mlflow.set_tag(
            _LATEST_CHECKPOINT_ARTIFACT_TAG_KEY,
            os.path.join(checkpoint_artifact_dir, checkpoint_model_filename)
        )

        mlflow.log_dict(
            {**metric_dict, "epoch": current_epoch},
            os.path.join(checkpoint_artifact_dir, checkpoint_metrics_filename)
        )

        tmp_dir = create_tmp_dir()
        try:
            tmp_model_save_path = os.path.join(tmp_dir, checkpoint_model_filename)
            trainer.save_checkpoint(tmp_model_save_path, weights_only=self.save_weights_only)
Review note: We can restore the "trainer" via https://pytorch-lightning.readthedocs.io/en/0.8.5/weights_loading.html#restoring-training-state. So shall we provide a helper function to restore the trainer? (similar to the …)
Reply: Yes, resuming training from a saved checkpoint is a common use case. Also, users are not necessarily loading the latest checkpoint, so I think we can provide a public API.
Reply: Oh, the new-version API becomes: … So we can't return a trainer directly without https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#resume-training-state.
Reply: Got it. My previous comment was that we can also load the model weights specifically without relying on the … For example, if we are sharing this model and someone using vanilla pytorch wants to finetune it with a custom training loop, then they can load the checkpoints.
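A minimal sketch of that last point, assuming the checkpoint artifact has already been downloaded (e.g. with MlflowClient.download_artifacts) and that PlainTorchModel is an illustrative nn.Module whose parameter names line up with the saved state dict:

import torch

# Lightning checkpoints are plain torch-serialized dicts; the module weights live
# under the "state_dict" key even when saved with weights_only=True.
checkpoint = torch.load("latest_checkpoint_model.weights.pth", map_location="cpu")
state_dict = checkpoint["state_dict"]

model = PlainTorchModel()  # illustrative vanilla torch.nn.Module
model.load_state_dict(state_dict, strict=False)
# ...continue with a custom fine-tuning loop, no Lightning Trainer required...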

            mlflow.log_artifact(tmp_model_save_path, checkpoint_artifact_dir)
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)

        self.latest_checkpoint_timestamp = time.time()


# PyTorch-Lightning refactored the LoggerConnector class in version 1.4.0 and made metrics
# update on demand. Prior to this, the metrics from the current step were not available to
# callbacks immediately, so the view of metrics was off by one step.

@@ -396,6 +502,47 @@ def patched_fit(original, self, *args, **kwargs):
            )
        ]

    model_checkpoint = get_autologging_config(
        mlflow.pytorch.FLAVOR_NAME, "model_checkpoint", True
    )
    if model_checkpoint:
        if _pl_version >= Version("1.4.0"):
            model_checkpoint_monitor = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_monitor", "val_loss"
            )
            model_checkpoint_mode = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_mode", "min"
            )
            model_checkpoint_save_best_only = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_save_best_only", True
            )
            model_checkpoint_save_weights_only = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_save_weights_only", True
            )
            model_checkpoint_every_n_epochs = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_every_n_epochs", None
            )
            model_checkpoint_train_time_interval_S = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_train_time_interval_S", None
            )

            # __MLflowModelCheckpointCallback only supports pytorch-lightning >= 1.4.0
            if not any(
                isinstance(callback, __MLflowModelCheckpointCallback)
                for callback in self.callbacks
            ):
                self.callbacks += [
                    __MLflowModelCheckpointCallback(
                        monitor=model_checkpoint_monitor,
                        mode=model_checkpoint_mode,
                        save_best_only=model_checkpoint_save_best_only,
                        save_weights_only=model_checkpoint_save_weights_only,
                        every_n_epochs=model_checkpoint_every_n_epochs,
                        train_time_interval_S=model_checkpoint_train_time_interval_S,
                    )
                ]
        else:
            warnings.warn(
                "Automatic model checkpointing is disabled because this feature only "
                "supports pytorch-lightning >= 1.4.0."
            )

    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

Review comment: This variable name is a bit strange. Can we combine model_checkpoint_every_n_epochs and model_checkpoint_train_time_interval_S into save_freq? I find that's pretty clean and easy to use: https://github.com/keras-team/keras/blob/fe2f54aa5bc42fb23a96449cf90434ab9bb6a2cd/keras/callbacks/model_checkpoint.py#L112
Reply: One question: save_freq in the Keras checkpointing callback also supports "checkpoint after every N batches (steps)". Do we want to support this, or only epoch-based checkpointing? My suggestion is that we only support epoch-based checkpointing: for batch-based checkpointing, the per-batch metric validation result is less accurate, and in pytorch-lightning, per-batch validation is not available unless you log the metric with on_step=True, i.e. LightningModule.log(metric_name, value, on_epoch=True, on_step=True).
Reply: Demo notebook is attached in the PR description.
Reply: save_freq supports both per-epoch saving and per-N-steps saving; both scenarios are commonly used based on my experience with the production team. I think we can just use the callback hook? It should have the training stats available: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback.on_after_backward
Reply: Summary: in ModelCheckpointCallback, if we use the per-N-steps saving approach, it needs to check the "monitor" metric, and that metric must be updated on every step; we need something like log(metric_name, value, on_step=True) to enable a per-step updated metric. We have 2 options:
(1) In the MLflow ModelCheckpointCallback, update the monitor metric in on_after_backward.
(2) Document this and tell the user to add log(metric_name, value, on_step=True) in their module class training_step method for the metric used as "monitor".
Option (1) has one issue: in on_after_backward we can't get the data for metric computation, so we might have to choose option (2). Option (2) is the way the current built-in pytorch-lightning ModelCheckpointCallback works. @chenmoneygithub
Reply: Yea, if (2) is Lightning's behavior, we can proceed with that one.
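A minimal sketch of option (2), assuming a user-defined LightningModule; the module name, layer sizes, and the "train_loss" metric name are illustrative, not from this PR:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class StepLoggingModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        # Log the monitored metric with on_step=True so a per-step checkpoint
        # callback could compare it at batch granularity.
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)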