Automatic model checkpointing for pytorch-lightning training #10935

Merged: 54 commits merged into mlflow:master on Feb 14, 2024

Conversation

WeichenXu123 (Collaborator) commented Jan 29, 2024

🛠 DevTools 🛠

Open in GitHub Codespaces

Install mlflow from this PR

pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10935/merge

Checkout with GitHub CLI

gh pr checkout 10935

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

Automatic model checkpointing for pytorch-lightning training.

Design doc: https://docs.google.com/document/d/1Ke7-8og_KzV3WE5xOS4XKSLZsISdTsVLlol14vJ3S70/edit

Demo notebook: https://e2-dogfood.staging.cloud.databricks.com/?o=6051921418418893#notebook/2173893049403456
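
For reviewers, a minimal usage sketch of the intended workflow (the checkpoint_* argument names follow the defaults quoted later in this thread and are illustrative; exact signatures may differ from the merged code, and the model/dataloaders are placeholders):

import mlflow
import lightning.pytorch as pl

# Turn on pytorch-lightning autologging with automatic checkpointing.
mlflow.pytorch.autolog(
    checkpoint=True,
    checkpoint_monitor="val_loss",
    checkpoint_mode="min",
    checkpoint_save_best_only=True,
)

with mlflow.start_run():
    trainer = pl.Trainer(max_epochs=5)
    # `model`, `train_loader`, and `val_loader` are user-defined placeholders.
    trainer.fit(model, train_loader, val_loader)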

How is this PR tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests

Does this PR require documentation update?

  • No. You can skip the rest of this section.
  • Yes. I've updated:
    • Examples
    • API references
    • Instructions

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
github-actions bot commented Jan 29, 2024

Documentation preview for 7a8527c will be available when this CircleCI job completes successfully.

@WeichenXu123 WeichenXu123 marked this pull request as draft January 29, 2024 15:37
chenmoneygithub (Collaborator) left a comment:

Thanks Weichen for the PR!

I took a high-level pass, and the logic looks good to me. I left one comment on the argument name choice. Can you share a reproducible notebook or GitHub gist showing how to use this new feature? It would be easier to review and spot errors from there, thank you!

model_checkpoint_save_best_only=True,
model_checkpoint_save_weights_only=True,
model_checkpoint_every_n_epochs=None,
model_checkpoint_train_time_interval_S=600,
chenmoneygithub (Collaborator):

This variable name is a bit strange. Can we combine model_checkpoint_every_n_epochs and model_checkpoint_train_time_interval_S into save_freq? I find that pretty clean and easy to use:
https://github.com/keras-team/keras/blob/fe2f54aa5bc42fb23a96449cf90434ab9bb6a2cd/keras/callbacks/model_checkpoint.py#L112

WeichenXu123 (Collaborator, Author):

One question:

The save_freq in the Keras checkpointing callback supports checkpointing after every N batches (steps). Do we want to support this, or only epoch-based checkpointing?

My suggestion is to support only epoch-based checkpointing. For batch-based checkpointing, the per-batch metric value is less accurate, and in pytorch-lightning, per-step metric values are not available unless you log the metric with on_step=True, i.e. LightningModule.log(metric_name, value, on_epoch=True, on_step=True).
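
For illustration, a minimal sketch of what per-step logging of the monitored metric looks like in a user's module (the module, metric name, and shapes are placeholders):

import torch
import lightning.pytorch as pl

class ExampleModule(pl.LightningModule):  # placeholder module for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        # Logging with on_step=True keeps the monitored metric updated every
        # step, which per-N-steps checkpointing would rely on.
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)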

WeichenXu123 (Collaborator, Author):

Demo notebook is attached in the PR description.

chenmoneygithub (Collaborator):

save_freq supports both per-epoch saving and per-N-steps saving. Both scenarios are commonly used based on my experience with the production team.

in pytorch-lightning, per-batch validation is not available

I think we can just use the callback hook? It should have the training stats available: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback.on_after_backward

WeichenXu123 (Collaborator, Author):

Summary:

In ModelCheckpointCallback, if we use per-N-steps saving, it needs to check the "monitor" metric, and that metric must be updated on every step; we need something like log(metric_name, value, on_step=True) to enable a per-step-updated metric.

We have two options:
(1) In the MLflow ModelCheckpointCallback, update the monitor metric in on_after_backward.
(2) Document that users must add log(metric_name, value, on_step=True) in their module's training_step method for the metric used as "monitor".

Option (1) has one issue: in on_after_backward we can't get the data needed for metric computation, so we might have to choose option (2).

Option (2) is what the current built-in pytorch-lightning ModelCheckpointCallback uses.

chenmoneygithub (Collaborator):

Yea if 2) is lightning's behavior, we can proceed with that one.

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
tmp_dir = create_tmp_dir()
try:
    tmp_model_save_path = os.path.join(tmp_dir, checkpoint_model_filename)
    trainer.save_checkpoint(tmp_model_save_path, weights_only=self.save_weights_only)
WeichenXu123 (Collaborator, Author):

Note:

save_checkpoint can save both model weights and trainer states.

We can restore the trainer via https://pytorch-lightning.readthedocs.io/en/0.8.5/weights_loading.html#restoring-training-state

So shall we provide a helper function to restore the trainer? (similar to load_latest_checkpoint)

chenmoneygithub (Collaborator):

Yes, resuming training from a saved checkpoint is a common use case. Also, users won't necessarily load the latest checkpoint, so I think we can provide a public API, mlflow.pytorch.load_checkpoint, that loads a pytorch checkpoint, including model weights and optimizer states. This could also be used for vanilla pytorch workflows, not just lightning workflows.

WeichenXu123 (Collaborator, Author):

Oh, in the new version the API becomes:

trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

So we can't return a trainer directly without fit invocation.

https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#resume-training-state
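
For example, resuming training state from a checkpoint logged to a run could look like the following sketch (the artifact path is a hypothetical example; only trainer.fit(..., ckpt_path=...) is the documented Lightning API):

import mlflow
import lightning.pytorch as pl

# Download the saved .ckpt artifact from the MLflow run; the artifact path
# below is hypothetical, not necessarily the layout this PR produces.
local_ckpt = mlflow.artifacts.download_artifacts(
    run_id="<run_id>", artifact_path="checkpoints/latest_checkpoint.ckpt"
)

trainer = pl.Trainer(max_epochs=10)
# Lightning restores model weights, optimizer state, and epoch/step counters
# when a checkpoint path is passed to fit(); `model` is the user's module.
trainer.fit(model, ckpt_path=local_ckpt)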

chenmoneygithub (Collaborator):

Got it. My previous comment was that we can also load the model weights directly without relying on the trainer instance, which means:

model = mlflow.pytorch.load_checkpoint(model_class, mlflow_uri)
trainer = trainer(model)

For example, if we are sharing this model, and someone using vanilla pytorch wants to finetune it with a custom training loop, then they can load the checkpoints.
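
A sketch of that vanilla-pytorch scenario (the load_checkpoint name and call shape follow the suggestion above and are not final; MyModelClass, the URI, and the dataloader are placeholders):

import torch
import mlflow.pytorch

mlflow_uri = "<run-or-artifact-uri>"  # placeholder; the exact identifier form was still under discussion
# Hypothetical call shape per the suggestion above; the merged API may differ.
model = mlflow.pytorch.load_checkpoint(MyModelClass, mlflow_uri)

# Vanilla PyTorch fine-tuning loop, no Lightning Trainer involved.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model.train()
for x, y in dataloader:  # user-provided DataLoader
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()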

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
@@ -1099,3 +1127,36 @@ def print_auto_logged_info(r):
autolog.__doc__ = autolog.__doc__.replace("MIN_REQ_VERSION", str(MIN_REQ_VERSION)).replace(
"MAX_REQ_VERSION", str(MAX_REQ_VERSION)
)


def load_latest_checkpoint(model_class, run_id=None):
WeichenXu123 (Collaborator, Author):

Note:

Whether "save_weights_only" is true or false, loading back always requires the "model_class"; it does not need a model object.

This is different from the Keras-side API:
https://github.com/mlflow/mlflow/pull/10955/files#r1471374884

chenmoneygithub (Collaborator):

makes sense to me!

chenmoneygithub (Collaborator):

Thanks Weichen! I played with the notebook; the experience is pretty smooth. Two things about the files we save with checkpoints:

  • Can we put the checkpoints, along with their metrics, into one directory? Right now they are flat files under the parent directory, which is hard to navigate when a large number of checkpoints is logged.
  • We don't necessarily have eval metrics for every checkpoint we save; e.g., if we save checkpoints every 5000 steps, the eval metrics may not be available. We can make "eval at checkpoint saving" optional when saving every N steps.

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
@WeichenXu123 WeichenXu123 marked this pull request as ready for review February 1, 2024 14:08
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
chenmoneygithub (Collaborator) left a comment:

Thanks Weichen for the PR, looks pretty good! Left some nit comments and comments on argument choices.

Resolved review threads on:
  • mlflow/pytorch/__init__.py (5 threads, outdated)
  • mlflow/pytorch/_lightning_autolog.py (2 threads, 1 outdated)

assert False, "Illegal __MLflowModelCheckpoint config."

def _save_checkpoint_rank_zero_only(self, trainer: "pl.Trainer", filepath: str):
chenmoneygithub (Collaborator):

Is trainer: "pl.Trainer" a legal type annotation? Shall we import the class?

WeichenXu123 (Collaborator, Author):

It is legal if pl is available.

chenmoneygithub (Collaborator):

ohh I mean shall we remove the double quotes?

WeichenXu123 (Collaborator, Author):

Double quotes are legal syntax (a string forward reference), and other mlflow code also uses them. :)
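
For reference, a minimal sketch of the pattern (one common variant, not necessarily exactly what this PR's module does), where the quoted annotation never needs to be resolved at runtime:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import lightning.pytorch as pl  # imported only for static type checkers

class ExampleCheckpointCallback:  # illustrative name, not the class in this PR
    def _save_checkpoint_rank_zero_only(self, trainer: "pl.Trainer", filepath: str):
        # The quoted ("forward reference") annotation is valid syntax and is
        # only evaluated if something introspects type hints at runtime.
        trainer.save_checkpoint(filepath)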

chenmoneygithub (Collaborator):

gotcha, I learned something new, thank you!


if self.save_best_only:
    if self.monitor not in metric_dict:
        # "save-best-only" requires comparing the monitor metric value,
chenmoneygithub (Collaborator):

It might be better UX to raise an explicit error than to fail silently.

WeichenXu123 (Collaborator, Author):

I prefer to use:

                _logger.error(
                    "If MLflowModelCheckpoint 'save_best_only' config is True, it requires to "
                    "compare the monitored metric value, but the provided monitored metric value "
                    "is not available."
                )

instead of raising an exception,

because the default checkpoint configs are:

    checkpoint=True,
    checkpoint_monitor="val_loss",
    checkpoint_mode="min",
    checkpoint_save_best_only=True,

but sometimes the model might not log the "val_loss" metric. If we raise an exception here, it breaks the whole autologging run, whereas we don't need to break the other parts of autologging (e.g. logging params / metrics).

chenmoneygithub (Collaborator):

makes sense to me!

On top of that, may we call out "checkpoint logging is skipped" in the error message?

Resolved review thread on mlflow/pytorch/_lightning_autolog.py (outdated)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: mlflow-automation <mlflow-automation@users.noreply.github.com>
Resolved review thread on mlflow/pytorch/__init__.py (outdated)
Comment on lines 658 to 659
trainer.fit_loop.epoch_progress.current.completed += 1
trainer._logger_connector._callback_metrics["val_loss"] -= 0.2
harupy (Member):

why do we need to set this?

WeichenXu123 (Collaborator, Author):

For testing "save_best_only"-related logic: e.g., when "mode" is "min", a new checkpoint is considered better only when the loss decreases.

harupy (Member) commented Feb 9, 2024:

I feel like this will break in the future once the pytorch-lightning team removes or renames _callback_metrics, since they don't know about this usage. Relying on private attributes of third-party packages is too risky and would increase the maintenance burden.

WeichenXu123 (Collaborator, Author):

Updated! :)

WeichenXu123 (Collaborator, Author):

Updated again.

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
WeichenXu123 (Collaborator, Author):

@mlflow-automation autoformat

WeichenXu123 (Collaborator, Author):

@mlflow-automation autoformat

Ubuntu added 3 commits February 14, 2024 06:19
Signed-off-by: Ubuntu <weichen.xu@ip-10-110-25-111.us-west-2.compute.internal>
@WeichenXu123 WeichenXu123 enabled auto-merge (squash) February 14, 2024 06:35
@WeichenXu123 WeichenXu123 merged commit 6f6bf85 into mlflow:master Feb 14, 2024
62 checks passed
annzhang-db pushed a commit to annzhang-db/mlflow that referenced this pull request Feb 14, 2024
@WeichenXu123 WeichenXu123 self-assigned this Feb 15, 2024
sateeshmannar pushed a commit to StateFarmIns/mlflow that referenced this pull request Feb 20, 2024
artjen pushed a commit to artjen/mlflow that referenced this pull request Mar 26, 2024
Labels
  • area/tracking: Tracking service, tracking client APIs, autologging
  • rn/feature: Mention under Features in Changelogs
Projects
  • Status: Merged
Development
  • Successfully merging this pull request may close these issues: none yet

4 participants