Automatic model checkpointing for pytorch-lightning training #10935
Conversation
Thanks Weichen for the PR!
I took a high-level pass, and the logic looks good to me. Left one comment on the argument name choice. Could you share a reproducible notebook or GitHub gist showing how to use this new feature? It would be easier to review and spot errors from there, thank you!
mlflow/pytorch/__init__.py
Outdated
model_checkpoint_save_best_only=True,
model_checkpoint_save_weights_only=True,
model_checkpoint_every_n_epochs=None,
model_checkpoint_train_time_interval_S=600,
This variable name is a bit strange. Can we combine model_checkpoint_every_n_epochs and model_checkpoint_train_time_interval_S into save_freq? I find that pretty clean and easy to use:
https://github.com/keras-team/keras/blob/fe2f54aa5bc42fb23a96449cf90434ab9bb6a2cd/keras/callbacks/model_checkpoint.py#L112
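For reference, the Keras semantics being pointed to look roughly like this (a sketch using the public tf.keras API; the filepath is arbitrary):

from tensorflow import keras

# save_freq="epoch" checkpoints once per epoch; an integer N means every N batches.
epoch_ckpt = keras.callbacks.ModelCheckpoint("ckpt/model.keras", save_freq="epoch")
step_ckpt = keras.callbacks.ModelCheckpoint("ckpt/model.keras", save_freq=100)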
One question:
The save_freq in the Keras checkpointing callback also supports "checkpoint after every N batches (steps)". Do we want to support this, or only epoch-based checkpointing?
My suggestion is to support only epoch-based checkpointing. For batch-based checkpointing, the per-batch metric validation result is less accurate, and in pytorch-lightning, per-batch validation is not available unless you log the metric with on_step=True, i.e. LightningModule.log(metric_name, value, on_epoch=True, on_step=True).
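A minimal sketch of that logging pattern (assuming pytorch_lightning is importable; _compute_loss is a hypothetical helper):

import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # hypothetical loss computation
        # on_step=True makes the metric available per step, not just per epoch
        self.log("train_loss", loss, on_epoch=True, on_step=True)
        return loss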
Demo notebook is attached in the PR description.
save_freq supports both per-epoch saving and per-N-steps saving. Both scenarios are commonly used, based on my experience with the production team.
Regarding "in pytorch-lightning, per-batch validation is not available": I think we can just use the callback hook? It should have the training stats available: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback.on_after_backward
Summary:
In ModelCheckpointCallback, if we use per-N-steps saving, it needs to check the "monitor" metric, and that metric must be updated on every step; we need something like log(metric_name, value, on_step=True) to enable a per-step-updated metric.
We have 2 options:
(1) In the MLflow ModelCheckpointCallback, we update the monitor metric in on_after_backward.
(2) Document this and tell users to add code in their module class's train_step method: log(metric_name, value, on_step=True) for the metric used as "monitor".
Option (1) has one issue: in on_after_backward we can't get the data for the metric computation, so we might have to choose option (2).
Option (2) is what the current built-in pytorch-lightning ModelCheckpointCallback uses.
Yeah, if (2) is Lightning's behavior, we can proceed with that one.
mlflow/pytorch/_lightning_autolog.py
Outdated
tmp_dir = create_tmp_dir()
try:
    tmp_model_save_path = os.path.join(tmp_dir, checkpoint_model_filename)
    trainer.save_checkpoint(tmp_model_save_path, weights_only=self.save_weights_only)
Note:
save_checkpoint can save both model weights and trainer state. We can restore the trainer via https://pytorch-lightning.readthedocs.io/en/0.8.5/weights_loading.html#restoring-training-state
So shall we provide a helper function to restore the trainer (similar to load_latest_checkpoint)?
Yes, resuming training from a saved checkpoint is a common use case. Also, users are not necessarily loading the latest checkpoint, so I think we can provide a public API, mlflow.pytorch.load_checkpoint, which can load a pytorch checkpoint including model weights and optimizer states. This could also be used in vanilla pytorch workflows, not just lightning workflows.
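Hypothetical usage of such an API, assuming the signature sketched in this thread (MyLightningModule and the run id are placeholders):

import mlflow.pytorch

# Load weights and optimizer state from the checkpoint logged to an MLflow run
model = mlflow.pytorch.load_checkpoint(MyLightningModule, run_id="<run_id>")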
Oh, the new-version API becomes:
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
So we can't return a trainer directly without a fit invocation.
https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#resume-training-state
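For completeness, the Lightning resume pattern referenced above looks like this (a sketch; max_epochs is arbitrary and model is your LightningModule instance):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=10)
# Lightning restores model weights, optimizer state, and loop counters from the checkpoint
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")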
Got it. My previous comment meant we can also load the model weights specifically, without relying on the trainer instance, i.e. something like:
model = mlflow.pytorch.load_checkpoint(model_class, mlflow_uri)
trainer = pl.Trainer()
trainer.fit(model)
For example, if we are sharing this model and someone using vanilla pytorch wants to finetune it with a custom training loop, they can load the checkpoints.
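A sketch of that vanilla-pytorch fine-tuning scenario (assuming load_checkpoint behaves as proposed above; MyModelClass, run_id, and dataloader are placeholders):

import torch
import torch.nn.functional as F
import mlflow.pytorch

model = mlflow.pytorch.load_checkpoint(MyModelClass, run_id)  # hypothetical, per the proposal
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model.train()
for inputs, targets in dataloader:  # dataloader assumed to exist
    optimizer.zero_grad()
    loss = F.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()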
mlflow/pytorch/__init__.py
Outdated
@@ -1099,3 +1127,36 @@ def print_auto_logged_info(r):
autolog.__doc__ = autolog.__doc__.replace("MIN_REQ_VERSION", str(MIN_REQ_VERSION)).replace(
    "MAX_REQ_VERSION", str(MAX_REQ_VERSION)
)

def load_latest_checkpoint(model_class, run_id=None):
Note:
Whether "save_weights_only" is true or false, loading back always requires the "model_class"; it does not need a model object.
This is different from the Keras-side API:
https://github.com/mlflow/mlflow/pull/10955/files#r1471374884
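Illustrating the point (a hypothetical call; note the class itself is passed, not an instance):

# Works the same whether the checkpoint stored weights only or full trainer state
model = mlflow.pytorch.load_latest_checkpoint(MyLightningModule, run_id="<run_id>")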
makes sense to me!
Thanks Weichen! I played with the notebook, the experience is pretty smooth. Two things about the file we save with checkpoints:
(force-pushed from 9dbe7fa to 0586388)
Thanks Weichen for the PR, looks pretty good! Left some comments on nits and argument choice.
assert False, "Illegal __MLflowModelCheckpoint config."

def _save_checkpoint_rank_zero_only(self, trainer: "pl.Trainer", filepath: str):
Is trainer: "pl.Trainer" a legal type annotation? Shall we import the class?
It is legal if pl is available.
ohh I mean shall we remove the double quotes?
Double quotes are legal syntax, and other mlflow code also uses them. :)
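For context, the quoted annotation is a standard forward reference; a common pattern that also avoids a hard runtime import looks like this (a sketch; the function name mirrors the method under review):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pytorch_lightning as pl  # imported only for static type checkers

def save_checkpoint_rank_zero_only(trainer: "pl.Trainer", filepath: str) -> None:
    # The string annotation is resolved lazily, so no NameError occurs at runtime
    ...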
gotcha, I learned something new, thank you!
mlflow/pytorch/_lightning_autolog.py
Outdated
if self.save_best_only:
    if self.monitor not in metric_dict:
        # "save-best-only" requires comparing the monitor metric value,
It might be better UX to raise an explicit error than to fail silently.
I prefer to use:
_logger.error(
"If MLflowModelCheckpoint 'save_best_only' config is True, it requires to "
"compare the monitored metric value, but the provided monitored metric value "
"is not available."
)
instead of raising an exception, because the default checkpoint configs are:
checkpoint=True,
checkpoint_monitor="val_loss",
checkpoint_mode="min",
checkpoint_save_best_only=True,
but sometimes the model might not log the "val_loss" metric. If we raised an exception here, it would break the whole autologging run, and we don't need to break the other parts of autologging (e.g. logging params / metrics).
makes sense to me!
On top of that, may we call out "checkpoint logging is skipped" in the error message?
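One possible revision incorporating that suggestion (a sketch of the message only, not final wording):

_logger.error(
    "Checkpoint logging is skipped: MLflowModelCheckpoint 'save_best_only' requires "
    "comparing the monitored metric value, but the monitored metric is not available."
)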
trainer.fit_loop.epoch_progress.current.completed += 1
trainer._logger_connector._callback_metrics["val_loss"] -= 0.2
why do we need to set this?
For testing the "save_best_only"-related logic: e.g. when "mode" is "min", only when the loss decreases is the new checkpoint a better one.
I feel like this will break in the future once the pytorch-lightning team removes or renames _callback_metrics, since they don't know about this usage. Relying on private attributes of third-party packages is too risky and would increase the maintenance burden.
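One public-API alternative, sketched: drive val_loss from the module itself so the test never touches trainer internals (the toy layer shapes and decrement schedule are assumptions):

import torch
import pytorch_lightning as pl

class FakeModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        # val_loss decreases each epoch, so save_best_only sees an improvement
        self.log("val_loss", 1.0 - 0.2 * self.current_epoch)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)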
Updated! :)
Updated again.
@mlflow-automation autoformat
Related Issues/PRs
#xxx

What changes are proposed in this pull request?
Automatic model checkpointing for pytorch-lightning training.
Design doc: https://docs.google.com/document/d/1Ke7-8og_KzV3WE5xOS4XKSLZsISdTsVLlol14vJ3S70/edit
Demo notebook: https://e2-dogfood.staging.cloud.databricks.com/?o=6051921418418893#notebook/2173893049403456
How is this PR tested?
Does this PR require documentation update?
Release Notes
Is this a user-facing change?
What component(s), interfaces, languages, and integrations does this PR affect?

Components
- area/artifacts: Artifact stores and artifact logging
- area/build: Build and test infrastructure for MLflow
- area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
- area/docs: MLflow documentation pages
- area/examples: Example code
- area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- area/models: MLmodel format, model serialization/deserialization, flavors
- area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
- area/projects: MLproject format, project running backends
- area/scoring: MLflow Model server, model deployment tools, Spark UDFs
- area/server-infra: MLflow Tracking server backend
- area/tracking: Tracking Service, tracking client APIs, autologging

Interface
- area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- area/windows: Windows support

Language
- language/r: R APIs and clients
- language/java: Java APIs and clients
- language/new: Proposals for new client languages

Integrations
- integrations/azure: Azure and Azure ML integrations
- integrations/sagemaker: SageMaker integrations
- integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:
- rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
- rn/feature - A new user-facing feature worth mentioning in the release notes
- rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
- rn/documentation - A user-facing documentation change worth mentioning in the release notes