Keras automatic checkpoint #11197

Merged (23 commits, Feb 26, 2024)
Changes from 15 commits
44 changes: 8 additions & 36 deletions mlflow/pytorch/__init__.py
@@ -24,7 +24,6 @@

import mlflow
from mlflow import pyfunc
from mlflow.client import MlflowClient
from mlflow.environment_variables import MLFLOW_DEFAULT_PREDICTION_DEVICE
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
@@ -37,6 +36,7 @@
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.autologging_utils import autologging_integration, safe_patch
from mlflow.utils.checkpoint_utils import download_checkpoint_artifact
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
from mlflow.utils.environment import (
_CONDA_ENV_FILE_NAME,
@@ -108,7 +108,7 @@ def get_default_conda_env():
.. code-block:: python
:caption: Example

import mlflow.pytorch
import mlflow

# Log PyTorch model
with mlflow.start_run() as run:
@@ -1159,15 +1159,16 @@ def load_checkpoint(model_class, run_id=None, epoch=None, global_step=None):
.. code-block:: python
:caption: Example

import mlflow.pytorch
import mlflow

mlflow.pytorch.autolog(checkpoint=True)

model = MyLightningModuleNet()  # A custom PyTorch Lightning model
train_loader = create_train_dataset_loader()
trainer = Trainer()

with mlflow.start_run() as run:
trainer.fit(net)
trainer.fit(model, train_loader)

run_id = run.info.run_id

@@ -1177,38 +1178,9 @@ def load_checkpoint(model_class, run_id=None, epoch=None, global_step=None):
# load the historical checkpoint logged at the second epoch
checkpoint_model = mlflow.pytorch.load_checkpoint(MyLightningModuleNet, run_id, epoch=2)
"""
from mlflow.utils.mlflow_tags import LATEST_CHECKPOINT_ARTIFACT_TAG_KEY

client = MlflowClient()

if run_id is None:
run = mlflow.active_run()
if run is None:
raise MlflowException(
"There is no active run, please provide the 'run_id' for " "'load_checkpoint' call."
)
run_id = run.info.run_id
else:
run = client.get_run(run_id)

latest_checkpoint_artifact_path = run.data.tags.get(LATEST_CHECKPOINT_ARTIFACT_TAG_KEY)
if latest_checkpoint_artifact_path is None:
raise MlflowException("There is no logged checkpoint artifact in the current run.")

checkpoint_filename = os.path.basename(latest_checkpoint_artifact_path)

if epoch is not None and global_step is not None:
raise MlflowException(
"Only one of 'epoch' and 'global_step' can be set for 'load_checkpoint'."
)
elif global_step is not None:
checkpoint_artifact_path = f"checkpoints/global_step_{global_step}/{checkpoint_filename}"
elif epoch is not None:
checkpoint_artifact_path = f"checkpoints/epoch_{epoch}/{checkpoint_filename}"
else:
checkpoint_artifact_path = latest_checkpoint_artifact_path

downloaded_checkpoint_filepath = client.download_artifacts(run_id, checkpoint_artifact_path)
downloaded_checkpoint_filepath = download_checkpoint_artifact(
run_id=run_id, epoch=epoch, global_step=global_step
)
return model_class.load_from_checkpoint(downloaded_checkpoint_filepath)
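
Reviewer note: the path-resolution logic removed above now presumably lives in the shared `download_checkpoint_artifact` helper in `mlflow/utils/checkpoint_utils.py`, which is not shown in this hunk. A minimal sketch of the behavior being centralized, reconstructed only from the removed lines (the `_sketch` function name is hypothetical), might look like this:

```python
# Sketch only - reconstructed from the removed inline logic above, not the
# actual mlflow.utils.checkpoint_utils implementation.
import os

import mlflow
from mlflow.client import MlflowClient
from mlflow.exceptions import MlflowException
from mlflow.utils.mlflow_tags import LATEST_CHECKPOINT_ARTIFACT_TAG_KEY


def _download_checkpoint_artifact_sketch(run_id=None, epoch=None, global_step=None):
    client = MlflowClient()

    if run_id is None:
        run = mlflow.active_run()
        if run is None:
            raise MlflowException(
                "There is no active run, please provide the 'run_id' for 'load_checkpoint' call."
            )
        run_id = run.info.run_id
    else:
        run = client.get_run(run_id)

    # The autologged callback records the most recent checkpoint artifact path in a run tag.
    latest_checkpoint_artifact_path = run.data.tags.get(LATEST_CHECKPOINT_ARTIFACT_TAG_KEY)
    if latest_checkpoint_artifact_path is None:
        raise MlflowException("There is no logged checkpoint artifact in the current run.")

    checkpoint_filename = os.path.basename(latest_checkpoint_artifact_path)

    if epoch is not None and global_step is not None:
        raise MlflowException(
            "Only one of 'epoch' and 'global_step' can be set for 'load_checkpoint'."
        )
    elif global_step is not None:
        checkpoint_artifact_path = f"checkpoints/global_step_{global_step}/{checkpoint_filename}"
    elif epoch is not None:
        checkpoint_artifact_path = f"checkpoints/epoch_{epoch}/{checkpoint_filename}"
    else:
        checkpoint_artifact_path = latest_checkpoint_artifact_path

    return client.download_artifacts(run_id, checkpoint_artifact_path)
```

The point of the refactor appears to be letting the Keras flavor reuse this same artifact layout (`checkpoints/epoch_<n>/...`, `checkpoints/global_step_<n>/...`) instead of duplicating it per flavor.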


208 changes: 75 additions & 133 deletions mlflow/pytorch/_lightning_autolog.py
@@ -1,4 +1,4 @@
import logging

GitHub Actions / lint check failure on line 1: Unformatted file. Run `ruff format .` or comment `@mlflow-automation autoformat` to format.
import os
import shutil

GitHub Actions / lint check failure on line 3: [*] `shutil` imported but unused. Run `ruff --fix .` or comment `@mlflow-automation autoformat` to fix this error.
import tempfile
@@ -7,7 +7,7 @@
from packaging.version import Version

import mlflow.pytorch
from mlflow.client import MlflowClient

GitHub Actions / lint check failure on line 10: [*] `mlflow.client.MlflowClient` imported but unused. Run `ruff --fix .` or comment `@mlflow-automation autoformat` to fix this error.
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
from mlflow.pytorch import _pytorch_autolog
@@ -17,8 +17,7 @@
MlflowAutologgingQueueingClient,
get_autologging_config,
)
from mlflow.utils.file_utils import create_tmp_dir
from mlflow.utils.mlflow_tags import LATEST_CHECKPOINT_ARTIFACT_TAG_KEY
from mlflow.utils.checkpoint_utils import _MlflowModelCheckpointCallbackBase

logging.basicConfig(level=logging.ERROR)
MIN_REQ_VERSION = Version(_ML_PACKAGE_VERSIONS["pytorch-lightning"]["autologging"]["minimum"])
@@ -283,148 +282,85 @@
self.metrics_logger.flush()


class MlflowModelCheckpointCallback(pl.Callback, metaclass=ExceptionSafeAbstractClass):
class MlflowModelCheckpointCallback(pl.Callback, _MlflowModelCheckpointCallbackBase):
"""Callback for auto-logging pytorch-lightning model checkpoints to MLflow.
This callback implementation only supports pytorch-lightning >= 1.6.0.

Args:
monitor: In automatic model checkpointing, the metric name to monitor if
you set `model_checkpoint_save_best_only` to True.
save_best_only: If True, automatic model checkpointing only saves when
the model is considered the "best" model according to the quantity
monitored, and the previous checkpoint model is overwritten.
mode: one of {"min", "max"}. In automatic model checkpointing,
if save_best_only=True, the decision to overwrite the current save file is made
based on either the maximization or the minimization of the monitored quantity.
save_weights_only: In automatic model checkpointing, if True, then
only the model's weights will be saved. Otherwise, the optimizer states,
lr-scheduler states, etc. are added to the checkpoint too.
save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
saves the model after each epoch. When using an integer, the callback
saves the model at the end of this many batches. Note that if the saving isn't
aligned to epochs, the monitored metric may potentially be less reliable (it
could reflect as little as 1 batch, since the metrics get reset
every epoch). Defaults to `"epoch"`.

.. code-block:: python
:caption: Example

import mlflow
from mlflow.pytorch import MLflowModelCheckpointCallback
from pytorch_lightning import Trainer

mlflow.pytorch.autolog(checkpoint=True)

model = MyLightningModuleNet()  # A custom PyTorch Lightning model
train_loader = create_train_dataset_loader()

mlflow_checkpoint_callback = MLflowModelCheckpointCallback()

trainer = Trainer(
callbacks=[mlflow_checkpoint_callback]
)

with mlflow.start_run() as run:
trainer.fit(model, train_loader)

"""

def __init__(
self,
client,
run_id,
monitor="val_loss",
mode="min",
save_best_only=True,
save_weights_only=False,
save_freq="epoch",
):
"""
Args:
client: An instance of `MlflowClient`.
run_id: The id of the MLflow run which you want to log checkpoints to.
monitor: In automatic model checkpointing, the metric name to monitor if
you set `model_checkpoint_save_best_only` to True.
save_best_only: If True, automatic model checkpointing only saves when
the model is considered the "best" model according to the quantity
monitored, and the previous checkpoint model is overwritten.
mode: one of {"min", "max"}. In automatic model checkpointing,
if save_best_only=True, the decision to overwrite the current save file is made
based on either the maximization or the minimization of the monitored quantity.
save_weights_only: In automatic model checkpointing, if True, then
only the model's weights will be saved. Otherwise, the optimizer states,
lr-scheduler states, etc. are added to the checkpoint too.
save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
saves the model after each epoch. When using an integer, the callback
saves the model at the end of this many batches. Note that if the saving isn't
aligned to epochs, the monitored metric may potentially be less reliable (it
could reflect as little as 1 batch, since the metrics get reset
every epoch). Defaults to `"epoch"`.
"""
self.client = client
self.run_id = run_id
self.monitor = monitor
self.mode = mode
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
self.save_freq = save_freq
self.last_monitor_value = None

if self.save_best_only:
if self.monitor is None:
raise MlflowException(
"If checkpoint 'save_best_only' config is set to True, you need to set "
"'monitor' config as well."
)
if self.mode not in ["min", "max"]:
raise MlflowException(
"If checkpoint 'save_best_only' config is set to True, you need to set "
"'mode' config and available modes includes 'min' and 'max', but you set "
f"'mode' to '{self.mode}'."
)

def _is_new_checkpoint_better(self, new_monitor_value):
if self.last_monitor_value is None:
return True

if self.mode == "min":
return new_monitor_value <= self.last_monitor_value

return new_monitor_value >= self.last_monitor_value

def _save_checkpoint_rank_zero_only(self, trainer: "pl.Trainer", filepath: str):
checkpoint = trainer._checkpoint_connector.dump_checkpoint(self.save_weights_only)
trainer.strategy.save_checkpoint(checkpoint, filepath)

def _check_and_save_checkpoint_if_needed(self, trainer: "pl.Trainer"):
current_epoch = trainer.current_epoch
metric_dict = {k: float(v) for k, v in trainer.callback_metrics.items()}

if self.save_best_only:
if self.monitor not in metric_dict:
_logger.warning(
"Checkpoint logging is skipped, because checkpoint 'save_best_only' config is "
"True, it requires to compare the monitored metric value, but the provided "
"monitored metric value is not available."
)
return

new_monitor_value = metric_dict[self.monitor]
if not self._is_new_checkpoint_better(new_monitor_value):
# Current checkpoint is worse than last saved checkpoint,
# so skip checkpointing.
self.last_monitor_value = new_monitor_value
return

self.last_monitor_value = new_monitor_value

if self.save_best_only:
if self.save_weights_only:
checkpoint_model_filename = "latest_checkpoint.weights.pth"
else:
checkpoint_model_filename = "latest_checkpoint.pth"
checkpoint_metrics_filename = "latest_checkpoint_metrics.json"
checkpoint_artifact_dir = "checkpoints"
else:
if self.save_freq == "epoch":
sub_dir_name = f"epoch_{current_epoch}"
else:
sub_dir_name = f"global_step_{trainer.global_step}"

if self.save_weights_only:
checkpoint_model_filename = "checkpoint.weights.pth"
else:
checkpoint_model_filename = "checkpoint.pth"
checkpoint_metrics_filename = "checkpoint_metrics.json"
checkpoint_artifact_dir = f"checkpoints/{sub_dir_name}"

self.client.set_tag(
self.run_id,
LATEST_CHECKPOINT_ARTIFACT_TAG_KEY,
f"{checkpoint_artifact_dir}/{checkpoint_model_filename}",
super().__init__(
checkpoint_file_suffix="pth",
monitor=monitor,
mode=mode,
save_best_only=save_best_only,
save_weights_only=save_weights_only,
save_freq=save_freq,
)

self.client.log_dict(
self.run_id,
{**metric_dict, "epoch": current_epoch, "global_step": trainer.global_step},
f"{checkpoint_artifact_dir}/{checkpoint_metrics_filename}",
self.trainer = None

def save_checkpoint(self, filepath: str):
# Note: the `trainer.save_checkpoint` implementation internally calls
# `self.strategy.barrier("Trainer.save_checkpoint")`. In DDP training this
# callback only runs in the rank 0 process, so that `barrier` call would
# deadlock. We therefore dump and save the checkpoint directly instead of
# calling `trainer.save_checkpoint`.
checkpoint = self.trainer._checkpoint_connector.dump_checkpoint(
self.save_weights_only
)
self.trainer.strategy.save_checkpoint(checkpoint, filepath)

tmp_dir = create_tmp_dir()
try:
tmp_model_save_path = os.path.join(tmp_dir, checkpoint_model_filename)
# Note: `trainer.save_checkpoint` implementation contains invocation of
# `self.strategy.barrier("Trainer.save_checkpoint")`,
# in DDP training, this callback is only invoked in rank 0 process,
# the `barrier` invocation causes deadlock,
# so I implement `_save_checkpoint_rank_zero_only` instead of
# `trainer.save_checkpoint`.
self._save_checkpoint_rank_zero_only(
trainer,
tmp_model_save_path,
)
self.client.log_artifact(self.run_id, tmp_model_save_path, checkpoint_artifact_dir)
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
@rank_zero_only
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.trainer = trainer

@rank_zero_only
def on_train_batch_end(
@@ -438,12 +374,20 @@
if isinstance(self.save_freq, int) and (
trainer.global_step > 0 and trainer.global_step % self.save_freq == 0
):
self._check_and_save_checkpoint_if_needed(trainer)
self.check_and_save_checkpoint_if_needed(
current_epoch=trainer.current_epoch,
global_step=trainer.global_step,
metric_dict={k: float(v) for k, v in trainer.callback_metrics.items()},
)

@rank_zero_only
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.save_freq == "epoch":
self._check_and_save_checkpoint_if_needed(trainer)
self.check_and_save_checkpoint_if_needed(
current_epoch=trainer.current_epoch,
global_step=trainer.global_step,
metric_dict={k: float(v) for k, v in trainer.callback_metrics.items()},
)
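
Reviewer note: the callback above now keeps only the Lightning-specific pieces (dumping the checkpoint through the trainer's strategy and wiring the hooks), while the best-model comparison and artifact-naming logic moves into `_MlflowModelCheckpointCallbackBase`. That base class lives in `mlflow/utils/checkpoint_utils.py`, which is not part of this hunk; the sketch below is reconstructed from the calls visible here (`checkpoint_file_suffix`, `save_checkpoint`, `check_and_save_checkpoint_if_needed`) and from the removed inline code, so treat every detail as an assumption:

```python
# Hypothetical reconstruction of the shared checkpoint-callback base class.
# Not copied from mlflow/utils/checkpoint_utils.py; details may differ.
import os
import tempfile

import mlflow
from mlflow.exceptions import MlflowException


class _ModelCheckpointCallbackBaseSketch:
    """Framework-agnostic checkpoint logging; subclasses implement `save_checkpoint`."""

    def __init__(
        self,
        checkpoint_file_suffix,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=False,
        save_freq="epoch",
    ):
        self.checkpoint_file_suffix = checkpoint_file_suffix
        self.monitor = monitor
        self.mode = mode
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        self.last_monitor_value = None
        if self.save_best_only and self.mode not in ("min", "max"):
            raise MlflowException(f"Invalid checkpoint 'mode': {self.mode}")

    def save_checkpoint(self, filepath: str):
        # Subclasses write the framework-specific checkpoint file to `filepath`
        # (the Lightning callback above passes checkpoint_file_suffix="pth").
        raise NotImplementedError

    def _is_new_checkpoint_better(self, new_value):
        if self.last_monitor_value is None:
            return True
        if self.mode == "min":
            return new_value <= self.last_monitor_value
        return new_value >= self.last_monitor_value

    def check_and_save_checkpoint_if_needed(self, current_epoch, global_step, metric_dict):
        if self.save_best_only:
            if self.monitor not in metric_dict:
                return  # monitored metric unavailable; skip this checkpoint
            new_value = metric_dict[self.monitor]
            is_better = self._is_new_checkpoint_better(new_value)
            self.last_monitor_value = new_value
            if not is_better:
                return
            artifact_dir = "checkpoints"
            filename = f"latest_checkpoint.{self.checkpoint_file_suffix}"
        else:
            sub_dir = (
                f"epoch_{current_epoch}"
                if self.save_freq == "epoch"
                else f"global_step_{global_step}"
            )
            artifact_dir = f"checkpoints/{sub_dir}"
            filename = f"checkpoint.{self.checkpoint_file_suffix}"
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_path = os.path.join(tmp_dir, filename)
            self.save_checkpoint(local_path)
            # The real helper presumably also sets LATEST_CHECKPOINT_ARTIFACT_TAG_KEY on
            # the run and logs a *_metrics.json dict next to the checkpoint; omitted here.
            mlflow.log_artifact(local_path, artifact_path=artifact_dir)
```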


# PyTorch-Lightning refactored the LoggerConnector class in version 1.4.0 and made metrics
@@ -581,8 +525,6 @@
):
self.callbacks += [
MlflowModelCheckpointCallback(
client=MlflowClient(tracking_uri),
run_id=run_id,
monitor=checkpoint_monitor,
mode=checkpoint_mode,
save_best_only=checkpoint_save_best_only,
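
For reviewers coming from the Keras side, the PyTorch Lightning user-facing flow that this refactor keeps intact looks roughly like the sketch below. The `checkpoint_*` autolog keyword names are inferred from the config values read in the patch above, and `MyLightningModuleNet` / `create_train_dataset_loader` are the placeholder names used in the docstring examples in this diff, so verify both against the released `mlflow.pytorch.autolog` signature and your own code:

```python
# End-to-end sketch of autologged checkpointing with PyTorch Lightning.
# Assumptions: checkpoint_monitor / checkpoint_mode / checkpoint_save_best_only
# mirror the config values visible in this diff; MyLightningModuleNet and
# create_train_dataset_loader are user-defined placeholders.
import mlflow
from pytorch_lightning import Trainer

mlflow.pytorch.autolog(
    checkpoint=True,
    checkpoint_monitor="val_loss",
    checkpoint_mode="min",
    checkpoint_save_best_only=True,
)

model = MyLightningModuleNet()
train_loader = create_train_dataset_loader()
trainer = Trainer(max_epochs=5)

with mlflow.start_run() as run:
    trainer.fit(model, train_loader)

# Reload the latest autologged checkpoint (or pass epoch=... / global_step=...
# to pick a specific historical checkpoint).
latest = mlflow.pytorch.load_checkpoint(MyLightningModuleNet, run.info.run_id)
```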