Automatic model checkpointing for pytorch-lightning training #10935

Merged · 54 commits into master from pl-model-checkpoint · Feb 14, 2024

Commits (54)
7d1f3a6
init
WeichenXu123 Jan 29, 2024
7b8aa8a
update
WeichenXu123 Jan 29, 2024
fbd78d0
update
WeichenXu123 Jan 29, 2024
fc4b78c
update
WeichenXu123 Jan 30, 2024
510dce5
update
WeichenXu123 Jan 30, 2024
987c0b8
merge master
WeichenXu123 Jan 30, 2024
1176113
update
WeichenXu123 Jan 30, 2024
e5b3916
update
WeichenXu123 Feb 1, 2024
2d477c5
update test
WeichenXu123 Feb 1, 2024
9f19dc1
update
WeichenXu123 Feb 1, 2024
0586388
update
WeichenXu123 Feb 1, 2024
13ef85f
update
WeichenXu123 Feb 1, 2024
029b978
update
WeichenXu123 Feb 1, 2024
d749f75
fix
WeichenXu123 Feb 2, 2024
5dd187e
comment
WeichenXu123 Feb 2, 2024
e446073
update
WeichenXu123 Feb 2, 2024
62aa928
skipif test
WeichenXu123 Feb 2, 2024
81e30d6
Merge branch 'master' into pl-model-checkpoint
WeichenXu123 Feb 6, 2024
739a350
update
WeichenXu123 Feb 6, 2024
1c63358
update
WeichenXu123 Feb 6, 2024
4e414b3
update
WeichenXu123 Feb 6, 2024
5fd4d08
update doc
WeichenXu123 Feb 6, 2024
557091d
update
WeichenXu123 Feb 6, 2024
0bc2f9a
Merge branch 'master' into pl-model-checkpoint
WeichenXu123 Feb 7, 2024
d8c4ccc
format
WeichenXu123 Feb 7, 2024
d5b77bb
address comments
WeichenXu123 Feb 7, 2024
113be24
split test
WeichenXu123 Feb 7, 2024
d8aa855
fix doc
WeichenXu123 Feb 7, 2024
df07b7b
fix doc
WeichenXu123 Feb 7, 2024
21af0aa
format
WeichenXu123 Feb 7, 2024
250840d
validation in constructor
WeichenXu123 Feb 7, 2024
da14b5b
validate in constructor
WeichenXu123 Feb 7, 2024
9f0e0c4
Merge remote-tracking branch 'base/master' into pl-model-checkpoint
mlflow-automation Feb 7, 2024
fa7a6e1
Autoformat: https://github.com/mlflow/mlflow/actions/runs/7810748460
mlflow-automation Feb 7, 2024
c12d334
update
WeichenXu123 Feb 7, 2024
3c8f186
update tag key
WeichenXu123 Feb 7, 2024
7e23cfc
address comments
WeichenXu123 Feb 7, 2024
b9653e6
address comment
WeichenXu123 Feb 7, 2024
ce752ac
update doc
WeichenXu123 Feb 7, 2024
d562550
remove mock from tests
WeichenXu123 Feb 8, 2024
3305894
nit
WeichenXu123 Feb 8, 2024
54f14e5
format
WeichenXu123 Feb 8, 2024
91ea6a0
Merge remote-tracking branch 'base/master' into pl-model-checkpoint
mlflow-automation Feb 9, 2024
9d1da54
Autoformat: https://github.com/mlflow/mlflow/actions/runs/7838319186
mlflow-automation Feb 9, 2024
d1dfdb6
update
WeichenXu123 Feb 9, 2024
9f27ca6
improve test
WeichenXu123 Feb 9, 2024
0f33c52
update tests
WeichenXu123 Feb 14, 2024
4891c91
address comments
WeichenXu123 Feb 14, 2024
afc7846
format
WeichenXu123 Feb 14, 2024
e0bfbfb
Merge remote-tracking branch 'base/master' into pl-model-checkpoint
mlflow-automation Feb 14, 2024
01c5b69
Autoformat: https://github.com/mlflow/mlflow/actions/runs/7897092879
mlflow-automation Feb 14, 2024
4d5521b
format
Feb 14, 2024
27dc379
format
Feb 14, 2024
7a8527c
format
Feb 14, 2024
4 changes: 4 additions & 0 deletions docs/source/conf.py
@@ -380,6 +380,10 @@
("py:class", "keras.src.callbacks.callback.Callback"),
("py:class", "keras.callbacks.Callback"),
("py:class", "keras.src.callbacks.Callback"),
("py:class", "pytorch_lightning.callbacks.callback.Callback"),
("py:class", "pytorch_lightning.trainer.trainer.Trainer"),
("py:class", "pytorch_lightning.core.module.LightningModule"),
("py:class", "pytorch_lightning.core.LightningModule"),
]


2 changes: 1 addition & 1 deletion mlflow/pyfunc/__init__.py
@@ -2171,7 +2171,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]:


with mlflow.start_run():
model_info = mlflow.pyfunc.log_model(artifact_path="model", python_model=MyModel()) # pylint: disable=line-too-long
model_info = mlflow.pyfunc.log_model(artifact_path="model", python_model=MyModel()) # noqa # pylint: disable=line-too-long


loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
128 changes: 128 additions & 0 deletions mlflow/pytorch/__init__.py
@@ -24,6 +24,7 @@

import mlflow
from mlflow import pyfunc
from mlflow.client import MlflowClient
from mlflow.environment_variables import MLFLOW_DEFAULT_PREDICTION_DEVICE
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
@@ -902,6 +903,12 @@ def autolog(
silent=False,
registered_model_name=None,
extra_tags=None,
checkpoint=True,
checkpoint_monitor="val_loss",
checkpoint_mode="min",
checkpoint_save_best_only=True,
checkpoint_save_weights_only=False,
checkpoint_save_freq="epoch",
): # pylint: disable=unused-argument
"""
Enables (or disables) and configures autologging from `PyTorch Lightning
@@ -956,6 +963,25 @@ def autolog(
new model version of the registered model with this name. The registered model is
created if it does not already exist.
extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
checkpoint: Enable automatic model checkpointing. This feature only supports
pytorch-lightning >= 1.6.0.
checkpoint_monitor: In automatic model checkpointing, the metric name to monitor if
you set `checkpoint_save_best_only` to True.
checkpoint_save_best_only: If True, automatic model checkpointing only saves when
the model is considered the "best" model according to the monitored quantity,
and the previous checkpoint model is overwritten.
checkpoint_mode: One of {"min", "max"}. In automatic model checkpointing,
if `checkpoint_save_best_only` is True, the decision to overwrite the current save file
is made based on either the maximization or the minimization of the monitored quantity.
checkpoint_save_weights_only: In automatic model checkpointing, if True, then
only the model's weights are saved. Otherwise, the optimizer states,
lr-scheduler states, etc., are added to the checkpoint as well.
checkpoint_save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
saves the model after each epoch. When using an integer, the callback
saves the model at the end of this many batches. Note that if the saving isn't aligned to
epochs, the monitored metric may be less reliable (it could reflect as little as
1 batch, since the metrics get reset every epoch). Defaults to `"epoch"`.

.. code-block:: python
:test:
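
For orientation, a minimal sketch of how the new checkpoint arguments above fit together; `MyLightningModuleNet` and the dataloaders are hypothetical stand-ins, not part of this diff:

import mlflow
import pytorch_lightning as pl

# Keep only the best checkpoint according to validation loss, saved once per epoch.
mlflow.pytorch.autolog(
    checkpoint=True,
    checkpoint_monitor="val_loss",
    checkpoint_mode="min",
    checkpoint_save_best_only=True,
    checkpoint_save_weights_only=False,
    checkpoint_save_freq="epoch",
)

model = MyLightningModuleNet()  # hypothetical LightningModule subclass
trainer = pl.Trainer(max_epochs=5)

with mlflow.start_run():
    # Checkpoints are logged as run artifacts under "checkpoints/..." as training proceeds.
    trainer.fit(model, train_loader, val_loader)  # train_loader/val_loader are hypothetical dataloaders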
@@ -1101,3 +1127,105 @@ def print_auto_logged_info(r):
autolog.__doc__ = autolog.__doc__.replace("MIN_REQ_VERSION", str(MIN_REQ_VERSION)).replace(
"MAX_REQ_VERSION", str(MAX_REQ_VERSION)
)


def load_checkpoint(model_class, run_id=None, epoch=None, global_step=None):
Member:
Is this a public API?

Collaborator Author (WeichenXu123, Feb 7, 2024):
Yes. I will add it to __all__

Collaborator Author:
added.

Member:
How do I use this API? Any example?

Collaborator Author:
This is a demo notebook https://e2-dogfood.staging.cloud.databricks.com/?o=6051921418418893#notebook/2173893049403456

After the model checkpoint is logged, you can call the load_checkpoint API to get the checkpoint model.

Member:
can you add an example in the docstring?

Member:
can you add a Returns section in the docstring?

Collaborator Author:
Returns section and example code are added.

"""
If you enable "checkpoint" in autologging, checkpointed models are logged as
MLflow artifacts during pytorch-lightning model training.
Using this API, you can load the checkpointed model.

If you want to load the latest checkpoint, set both `epoch` and `global_step` to None.
If "checkpoint_save_freq" is set to "epoch" in autologging,
you can set the `epoch` param to load the checkpoint of a specific epoch.
If "checkpoint_save_freq" is set to an integer in autologging,
you can set the `global_step` param to load the checkpoint of a specific global step.
The `epoch` and `global_step` params can't be set together.

Args:
model_class: The class of the training model; it should inherit from
'pytorch_lightning.LightningModule'.
run_id: The id of the run to which the model is logged. If not provided,
the current active run is used.
epoch: The epoch of the checkpoint to be loaded, if you set
"checkpoint_save_freq" to "epoch".
global_step: The global step of the checkpoint to be loaded, if
you set "checkpoint_save_freq" to an integer.

Returns:
The instance of a pytorch-lightning model restored from the specified checkpoint.

.. code-block:: python
:caption: Example

import mlflow.pytorch

mlflow.pytorch.autolog(checkpoint=True)

model = MyLightningModuleNet()  # A custom PyTorch Lightning model
trainer = Trainer()

with mlflow.start_run() as run:
trainer.fit(model)

run_id = run.info.run_id

# load latest checkpoint model
latest_checkpoint_model = mlflow.pytorch.load_checkpoint(MyLightningModuleNet, run_id)

# load the checkpoint model logged at the second epoch
checkpoint_model = mlflow.pytorch.load_checkpoint(MyLightningModuleNet, run_id, epoch=2)
"""
from mlflow.utils.mlflow_tags import LATEST_CHECKPOINT_ARTIFACT_TAG_KEY

client = MlflowClient()

if run_id is None:
run = mlflow.active_run()
if run is None:
raise MlflowException(
"There is no active run, please provide the 'run_id' for " "'load_checkpoint' call."
)
run_id = run.info.run_id
else:
run = client.get_run(run_id)

latest_checkpoint_artifact_path = run.data.tags.get(LATEST_CHECKPOINT_ARTIFACT_TAG_KEY)
if latest_checkpoint_artifact_path is None:
raise MlflowException("There is no logged checkpoint artifact in the current run.")

checkpoint_filename = os.path.basename(latest_checkpoint_artifact_path)

if epoch is not None and global_step is not None:
raise MlflowException(
"Only one of 'epoch' and 'global_step' can be set for 'load_checkpoint'."
)
elif global_step is not None:
checkpoint_artifact_path = f"checkpoints/global_step_{global_step}/{checkpoint_filename}"
elif epoch is not None:
checkpoint_artifact_path = f"checkpoints/epoch_{epoch}/{checkpoint_filename}"
else:
checkpoint_artifact_path = latest_checkpoint_artifact_path

downloaded_checkpoint_filepath = client.download_artifacts(run_id, checkpoint_artifact_path)
return model_class.load_from_checkpoint(downloaded_checkpoint_filepath)
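
The docstring example above covers the latest and epoch-based checkpoints; following the artifact-path logic in the body (checkpoints/global_step_<n>/...), a step-based checkpoint could be loaded as sketched below. The class name and step value are hypothetical, and this assumes autologging was configured with an integer `checkpoint_save_freq`:

import mlflow

# assumes e.g. mlflow.pytorch.autolog(checkpoint=True, checkpoint_save_freq=100)
step_checkpoint_model = mlflow.pytorch.load_checkpoint(
    MyLightningModuleNet,  # hypothetical LightningModule subclass
    run_id,
    global_step=100,
)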


__all__ = [
"autolog",
"load_model",
"save_model",
"log_model",
"get_default_pip_requirements",
"get_default_conda_env",
"load_checkpoint",
]

try:
from mlflow.pytorch._lightning_autolog import MlflowModelCheckpointCallback # noqa: F401

__all__.append("MLflowModelCheckpointCallback")
except ImportError:
# Swallow exception if pytorch-lightning is not installed.
Collaborator:
I think we can skip this try/except block because we are using lazy loading.

Collaborator Author:
Seemingly lazy loading doesn't address it.

lightning_autolog = LazyLoader("mlflow.pytorch._lightning_autolog", globals(), "mlflow.pytorch._lightning_autolog")  # no exception
lightning_autolog.MLflowModelCheckpointCallback  # exception is raised if pytorch-lightning is not installed.

Collaborator Author (WeichenXu123, Feb 7, 2024):
And if we don't swallow the exception here, an exception is raised when initializing the mlflow.pytorch module (even with lazy loading, the module needs initialization when it is loaded).

Collaborator:
You are right, I forgot that this is the pytorch module. In fact this autologging part should go to lightning or pytorch_lightning, but that's independent of this PR.

pass