Automatic model checkpointing for pytorch-lightning training #10935

Merged: 54 commits, Feb 14, 2024

Changes from 7 commits

Commits (54)
7d1f3a6
init
WeichenXu123 Jan 29, 2024
7b8aa8a
update
WeichenXu123 Jan 29, 2024
fbd78d0
update
WeichenXu123 Jan 29, 2024
fc4b78c
update
WeichenXu123 Jan 30, 2024
510dce5
update
WeichenXu123 Jan 30, 2024
987c0b8
merge master
WeichenXu123 Jan 30, 2024
1176113
update
WeichenXu123 Jan 30, 2024
e5b3916
update
WeichenXu123 Feb 1, 2024
2d477c5
update test
WeichenXu123 Feb 1, 2024
9f19dc1
update
WeichenXu123 Feb 1, 2024
0586388
update
WeichenXu123 Feb 1, 2024
13ef85f
update
WeichenXu123 Feb 1, 2024
029b978
update
WeichenXu123 Feb 1, 2024
d749f75
fix
WeichenXu123 Feb 2, 2024
5dd187e
comment
WeichenXu123 Feb 2, 2024
e446073
update
WeichenXu123 Feb 2, 2024
62aa928
skipif test
WeichenXu123 Feb 2, 2024
81e30d6
Merge branch 'master' into pl-model-checkpoint
WeichenXu123 Feb 6, 2024
739a350
update
WeichenXu123 Feb 6, 2024
1c63358
update
WeichenXu123 Feb 6, 2024
4e414b3
update
WeichenXu123 Feb 6, 2024
5fd4d08
update doc
WeichenXu123 Feb 6, 2024
557091d
update
WeichenXu123 Feb 6, 2024
0bc2f9a
Merge branch 'master' into pl-model-checkpoint
WeichenXu123 Feb 7, 2024
d8c4ccc
format
WeichenXu123 Feb 7, 2024
d5b77bb
address comments
WeichenXu123 Feb 7, 2024
113be24
split test
WeichenXu123 Feb 7, 2024
d8aa855
fix doc
WeichenXu123 Feb 7, 2024
df07b7b
fix doc
WeichenXu123 Feb 7, 2024
21af0aa
format
WeichenXu123 Feb 7, 2024
250840d
validation in constructor
WeichenXu123 Feb 7, 2024
da14b5b
validate in constructor
WeichenXu123 Feb 7, 2024
9f0e0c4
Merge remote-tracking branch 'base/master' into pl-model-checkpoint
mlflow-automation Feb 7, 2024
fa7a6e1
Autoformat: https://github.com/mlflow/mlflow/actions/runs/7810748460
mlflow-automation Feb 7, 2024
c12d334
update
WeichenXu123 Feb 7, 2024
3c8f186
update tag key
WeichenXu123 Feb 7, 2024
7e23cfc
address comments
WeichenXu123 Feb 7, 2024
b9653e6
address comment
WeichenXu123 Feb 7, 2024
ce752ac
update doc
WeichenXu123 Feb 7, 2024
d562550
remove mock from tests
WeichenXu123 Feb 8, 2024
3305894
nit
WeichenXu123 Feb 8, 2024
54f14e5
format
WeichenXu123 Feb 8, 2024
91ea6a0
Merge remote-tracking branch 'base/master' into pl-model-checkpoint
mlflow-automation Feb 9, 2024
9d1da54
Autoformat: https://github.com/mlflow/mlflow/actions/runs/7838319186
mlflow-automation Feb 9, 2024
d1dfdb6
update
WeichenXu123 Feb 9, 2024
9f27ca6
improve test
WeichenXu123 Feb 9, 2024
0f33c52
update tests
WeichenXu123 Feb 14, 2024
4891c91
address comments
WeichenXu123 Feb 14, 2024
afc7846
format
WeichenXu123 Feb 14, 2024
e0bfbfb
Merge remote-tracking branch 'base/master' into pl-model-checkpoint
mlflow-automation Feb 14, 2024
01c5b69
Autoformat: https://github.com/mlflow/mlflow/actions/runs/7897092879
mlflow-automation Feb 14, 2024
4d5521b
format
Feb 14, 2024
27dc379
format
Feb 14, 2024
7a8527c
format
Feb 14, 2024
61 changes: 61 additions & 0 deletions mlflow/pytorch/__init__.py
@@ -24,6 +24,7 @@

import mlflow
from mlflow import pyfunc
from mlflow.client import MlflowClient
from mlflow.environment_variables import MLFLOW_DEFAULT_PREDICTION_DEVICE
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
@@ -901,6 +902,13 @@ def autolog(
silent=False,
registered_model_name=None,
extra_tags=None,
model_checkpoint=True,
model_checkpoint_monitor="val_loss",
model_checkpoint_mode="min",
model_checkpoint_save_best_only=True,
model_checkpoint_save_weights_only=True,
model_checkpoint_every_n_epochs=None,
model_checkpoint_train_time_interval_S=600,

Collaborator:
This variable name is a bit strange. Can we combine model_checkpoint_every_n_epochs and model_checkpoint_train_time_interval_S into save_freq? I find that pretty clean and easy to use:
https://github.com/keras-team/keras/blob/fe2f54aa5bc42fb23a96449cf90434ab9bb6a2cd/keras/callbacks/model_checkpoint.py#L112


Collaborator Author (WeichenXu123):
One question:

The save_freq in the Keras checkpointing callback supports "checkpoint after every N batches (steps)". Do we want to support this, or do we only want to support epoch-based checkpointing?

My suggestion is that we only support epoch-based checkpointing. For batch-based checkpointing, the per-batch metric validation result is less accurate, and in pytorch-lightning per-batch validation is not available unless you log the metric with on_step=True, i.e. LightningModule.log(metric_name, value, on_epoch=True, on_step=True).


Collaborator Author (WeichenXu123):
Demo notebook is attached in the PR description.


Collaborator:
save_freq supports both per-epoch saving and per-N-steps saving. Both scenarios are commonly used based on my experience with the production team.

> in pytorch-lightning, per-batch validation is not available

I think we can just use the callback hook? It should have the training stats available: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback.on_after_backward


Collaborator Author (WeichenXu123), Jan 31, 2024:
Summary:

In ModelCheckpointCallback, if we use per-N-steps saving, the callback needs to check the "monitor" metric, and that metric must be updated on every step; we need something like log(metric_name, value, on_step=True) to enable a per-step updated metric.

We have 2 options:
(1) In the MLflow ModelCheckpointCallback, update the monitor metric in on_after_backward.
(2) Document this and tell the user to add log(metric_name, value, on_step=True) in their module class's training_step method for the metric used as "monitor".

Option (1) has one issue: in on_after_backward we can't get the data for metric computation, so we might have to choose option (2).

Option (2) is the way the current built-in pytorch-lightning ModelCheckpoint callback works.

@chenmoneygithub


Collaborator:
Yea if 2) is lightning's behavior, we can proceed with that one.
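
For readers following this thread, here is a minimal sketch of option (2): the user logs the monitored metric with on_step=True inside their own LightningModule. The module, layer sizes, and metric name below are hypothetical and not part of this PR.

```python
import torch
import pytorch_lightning as pl


class MyModule(pl.LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # on_step=True exposes a per-step value (needed for step-based checkpointing);
        # on_epoch=True keeps the aggregated value for epoch-based checkpointing.
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)
```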

): # pylint: disable=unused-argument
"""
Enables (or disables) and configures autologging from `PyTorch Lightning
Expand Down Expand Up @@ -955,6 +963,26 @@ def autolog(
new model version of the registered model with this name. The registered model is
created if it does not already exist.
extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
:param model_checkpoint: Enable automatic model checkpointing. This feature only supports
pytorch-lightning >= 1.4.0.
:param model_checkpoint_monitor: In automatic model checkpointing, the name of the metric to
monitor when `model_checkpoint_save_best_only` is set to True.
:param model_checkpoint_save_best_only: If True, automatic model checkpointing only saves when
the model is considered the "best", and the latest best model according to the monitored
quantity is not overwritten.
:param model_checkpoint_mode: One of {"min", "max"}. In automatic model checkpointing,
if save_best_only=True, the decision to overwrite the current save file is made based on
either the maximization or the minimization of the monitored quantity.
:param model_checkpoint_save_weights_only: In automatic model checkpointing, if True, only
the model's weights are saved. Otherwise, the optimizer states, lr-scheduler states,
etc. are added to the checkpoint as well.
:param model_checkpoint_every_n_epochs: Number of epochs between checkpoints for automatic
model checkpointing.
:param model_checkpoint_train_time_interval_S: Automatic model checkpoints are monitored
at the specified time interval in seconds. For all practical purposes, this cannot be
smaller than the amount of time it takes to process a single training batch. This is
not guaranteed to execute at the exact time specified, but should be close.
This is mutually exclusive with `model_checkpoint_every_n_epochs`.

.. testcode:: python
:caption: Example
@@ -1099,3 +1127,36 @@ def print_auto_logged_info(r):
autolog.__doc__ = autolog.__doc__.replace("MIN_REQ_VERSION", str(MIN_REQ_VERSION)).replace(
"MAX_REQ_VERSION", str(MAX_REQ_VERSION)
)


def load_latest_checkpoint(model_class, run_id=None):

Collaborator Author (WeichenXu123):
Note:

Whether "save_weights_only" is true or false, loading back always requires the "model_class"; it does not need the model object.

This is different from the Keras-side API:
https://github.com/mlflow/mlflow/pull/10955/files#r1471374884


Collaborator:
makes sense to me!

"""
If you enable model_checkpoint in autologging, checkpointed models are logged as MLflow
artifacts during pytorch-lightning model training. Using this API, you can load the
latest checkpointed model.

:param model_class: The class of the training model.
:param run_id: The id of the run to which the model is logged. If not provided, the
current active run is used.
"""
from mlflow.pytorch._lightning_autolog import _LATEST_CHECKPOINT_ARTIFACT_TAG_KEY

client = MlflowClient()

if run_id is None:
run = mlflow.active_run()
if run is None:
raise MlflowException(
"There is no active run, please provide the 'run_id' for "
"'load_best_checkpoint' call."
WeichenXu123 marked this conversation as resolved.
Show resolved Hide resolved
)
run_id = run.info.run_id
else:
run = client.get_run(run_id)

latest_checkpoint_artifact = run.data.tags.get(_LATEST_CHECKPOINT_ARTIFACT_TAG_KEY)
if latest_checkpoint_artifact is None:
raise MlflowException("There is no logged checkpoint artifact in the current run.")

downloaded_checkpoint_filepath = client.download_artifacts(run_id, latest_checkpoint_artifact)
return model_class.load_from_checkpoint(downloaded_checkpoint_filepath)
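
As an editorial aid, here is a minimal end-to-end sketch of the two APIs added in this file: enabling checkpointing through autolog and loading the latest checkpoint back. MyLightningModule and MyDataModule are hypothetical user-defined classes (assumed to log a "val_loss" metric during validation), and the argument names follow this commit of the PR, so they may differ in the merged version.

```python
import mlflow
import pytorch_lightning as pl

mlflow.pytorch.autolog(
    model_checkpoint=True,
    model_checkpoint_monitor="val_loss",
    model_checkpoint_mode="min",
    model_checkpoint_save_best_only=True,
    model_checkpoint_save_weights_only=True,
)

model = MyLightningModule()
trainer = pl.Trainer(max_epochs=5)

with mlflow.start_run() as run:
    trainer.fit(model, datamodule=MyDataModule())

# model_class is always required, regardless of save_weights_only.
loaded_model = mlflow.pytorch.load_latest_checkpoint(
    MyLightningModule, run_id=run.info.run_id
)
```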
147 changes: 147 additions & 0 deletions mlflow/pytorch/_lightning_autolog.py
@@ -1,10 +1,13 @@
import logging
import os
import shutil
import tempfile
import time
import warnings

from packaging.version import Version

from mlflow.utils.file_utils import create_tmp_dir
import mlflow.pytorch
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
@@ -287,6 +290,109 @@ def on_test_end(self, trainer, pl_module):
self.metrics_logger.flush()


_LATEST_CHECKPOINT_ARTIFACT_TAG_KEY = "_latest_checkpoint_artifact"


class __MLflowModelCheckpointCallback(pl.Callback, metaclass=ExceptionSafeAbstractClass):

def __init__(
self,
monitor,
mode,
save_best_only,
save_weights_only,
every_n_epochs,
train_time_interval_S,
):
self.monitor = monitor
self.mode = mode
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
self.every_n_epochs = every_n_epochs
self.train_time_interval_S = train_time_interval_S
self.latest_checkpoint_timestamp = time.time()
self.last_monitor_value = None

def _is_new_checkpoint_better(self, new_monitor_value):
if self.last_monitor_value is None:
return True

if self.mode == "min":
return new_monitor_value <= self.last_monitor_value

if self.mode == "max":
return new_monitor_value >= self.last_monitor_value

assert False, "Illegal __MLflowModelCheckpoint config."

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
current_epoch = trainer.current_epoch
metric_dict = {k: float(v) for k, v in trainer.callback_metrics.items()}

should_checkpoint = False
if self.every_n_epochs and (current_epoch % self.every_n_epochs == 0):
should_checkpoint = True
elif (
self.train_time_interval_S and
time.time() - self.latest_checkpoint_timestamp > self.train_time_interval_S
):
should_checkpoint = True

if not should_checkpoint:
return

if self.save_best_only:
if self.monitor not in metric_dict:
# "save-best-only" requires comparing the monitor metric value,

Collaborator:
It might be a better UX to raise an explicit error than silently failing.


Collaborator Author (WeichenXu123), Feb 6, 2024:
I prefer to use:

                _logger.error(
                    "If MLflowModelCheckpoint 'save_best_only' config is True, it requires to "
                    "compare the monitored metric value, but the provided monitored metric value "
                    "is not available."
                )

instead of raising an exception, because the default checkpoint configs are:

    checkpoint=True,
    checkpoint_monitor="val_loss",
    checkpoint_mode="min",
    checkpoint_save_best_only=True,

but sometimes the model might not log the "val_loss" metric, so if we raise an exception here it breaks the whole autologging run, and we don't need to break the other parts of autologging (e.g. logging params / metrics).

@chenmoneygithub


Collaborator:
makes sense to me!

On top of that, may we call out "checkpoint logging is skipped" in the error message?

# but the provided monitor metric is not available,
# skip model checkpoint autologging
return

new_monitor_value = metric_dict[self.monitor]
if not self._is_new_checkpoint_better(new_monitor_value):
# Current checkpoint is worse than last saved checkpoint,
# so skip checkpointing.
return

self.last_monitor_value = new_monitor_value

if self.save_best_only:
if self.save_weights_only:
checkpoint_model_filename = "latest_checkpoint_model.weights.pth"
else:
checkpoint_model_filename = "latest_checkpoint_model.pth"
checkpoint_metrics_filename = "latest_checkpoint_metrics.json"
checkpoint_artifact_dir = ""
else:
if self.save_weights_only:
checkpoint_model_filename = f"checkpoint_model_epoch_{current_epoch}.weights.pth"
else:
checkpoint_model_filename = f"checkpoint_model_epoch_{current_epoch}.pth"
checkpoint_metrics_filename = f"checkpoint_metrics_epoch_{current_epoch}.json"
checkpoint_artifact_dir = "checkpoints"

mlflow.set_tag(
_LATEST_CHECKPOINT_ARTIFACT_TAG_KEY,
os.path.join(checkpoint_artifact_dir, checkpoint_model_filename)
)

mlflow.log_dict(
{**metric_dict, "epoch": current_epoch},
os.path.join(checkpoint_artifact_dir, checkpoint_metrics_filename)
)

tmp_dir = create_tmp_dir()
try:
tmp_model_save_path = os.path.join(tmp_dir, checkpoint_model_filename)
trainer.save_checkpoint(tmp_model_save_path, weights_only=self.save_weights_only)

Collaborator Author (WeichenXu123), Jan 30, 2024:
Note:

save_checkpoint can save both model weights and trainer states.

We can restore the "trainer" via https://pytorch-lightning.readthedocs.io/en/0.8.5/weights_loading.html#restoring-training-state

So shall we provide a helper function to restore the trainer (similar to load_latest_checkpoint)?


Collaborator:
Yes, resuming training from a saved checkpoint is a common use case. Also, users are not necessarily loading the latest checkpoint, so I think we can provide a public API mlflow.pytorch.load_checkpoint that loads a pytorch checkpoint, including model weights and optimizer states. This can also be used for vanilla pytorch workflows, not only lightning workflows.


Collaborator Author (WeichenXu123):
Oh, in the newer API this becomes:

    trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

So we can't return a trainer directly without a fit invocation.

https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#resume-training-state


Collaborator:
Got it. My previous comment was that we could also load the model weights specifically, without relying on the trainer instance, which means:

    model = mlflow.pytorch.load_checkpoint(model_class, mlflow_uri)
    trainer = Trainer(model)

For example, if we are sharing this model and someone using vanilla pytorch wants to finetune it with a custom training loop, they can load the checkpoints.
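
Following up on this thread, a minimal sketch (editorial, not part of this PR) of resuming training from a checkpoint artifact logged by this callback. It assumes a newer pytorch-lightning version where Trainer.fit accepts ckpt_path (as cited above), that the checkpoint was saved with save_weights_only=False, and that MyLightningModule / MyDataModule are hypothetical user-defined classes; the artifact name matches the save_best_only branch above.

```python
import pytorch_lightning as pl
from mlflow.client import MlflowClient

run_id = "..."  # ID of the MLflow run that logged the checkpoint

# Download the checkpoint file logged as a run artifact; adjust the artifact
# path for your checkpointing configuration.
ckpt_path = MlflowClient().download_artifacts(run_id, "latest_checkpoint_model.pth")

model = MyLightningModule()
trainer = pl.Trainer(max_epochs=10)

# Restores model weights plus trainer/optimizer state before continuing training.
trainer.fit(model, datamodule=MyDataModule(), ckpt_path=ckpt_path)
```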


mlflow.log_artifact(tmp_model_save_path, checkpoint_artifact_dir)
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)

self.latest_checkpoint_timestamp = time.time()


# PyTorch-Lightning refactored the LoggerConnector class in version 1.4.0 and made metrics
# update on demand. Prior to this, the metrics from the current step were not available to
# callbacks immediately, so the view of metrics was off by one step.
@@ -396,6 +502,47 @@ def patched_fit(original, self, *args, **kwargs):
)
]

model_checkpoint = get_autologging_config(
mlflow.pytorch.FLAVOR_NAME, "model_checkpoint", True
)
if model_checkpoint:
if _pl_version >= Version("1.4.0"):
model_checkpoint_monitor = get_autologging_config(
mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_monitor", "val_loss"
)
model_checkpoint_mode = get_autologging_config(
mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_mode", "min"
)
model_checkpoint_save_best_only = get_autologging_config(
mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_save_best_only", True
)
model_checkpoint_save_weights_only = get_autologging_config(
mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_save_weights_only", True
)
model_checkpoint_every_n_epochs = get_autologging_config(
mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_every_n_epochs", None
)
model_checkpoint_train_time_interval_S = get_autologging_config(
mlflow.pytorch.FLAVOR_NAME, "model_checkpoint_train_time_interval_S", None
)

# __MLflowModelCheckpointCallback only supports pytorch-lightning >= 1.4.0
if not any(isinstance(callback, __MLflowModelCheckpointCallback) for callback in self.callbacks):
self.callbacks += [
__MLflowModelCheckpointCallback(
monitor=model_checkpoint_monitor,
mode=model_checkpoint_mode,
save_best_only=model_checkpoint_save_best_only,
save_weights_only=model_checkpoint_save_weights_only,
every_n_epochs=model_checkpoint_every_n_epochs,
train_time_interval_S=model_checkpoint_train_time_interval_S,
)
]
else:
warnings.warn(
"Automatic model checkpointing is disabled because this feature only "
"supports pytorch-lightning >= 1.4.0.")

client.flush(synchronous=False)

result = original(self, *args, **kwargs)
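
As a usage note grounded in the config key read above: because the callback is only attached when the model_checkpoint autologging config is enabled, a user can opt out of checkpointing while keeping the rest of autologging.

```python
import mlflow

# Disables only automatic checkpointing; params/metrics autologging is unaffected.
mlflow.pytorch.autolog(model_checkpoint=False)
```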