New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.
Already on GitHub? Sign in to your account
Keras automatic checkpoint #11197
Keras automatic checkpoint #11197
Conversation
Documentation preview for 6d9ad09 will be available when this CircleCI job More info
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Weichen, great work! Left some comments.
mlflow/utils/checkpoint_utils.py
Outdated
could reflect as little as 1 batch, since the metrics get reset | ||
every epoch). Defaults to `"epoch"`. | ||
""" | ||
self.client = client |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we use the fluent API instead of the client API?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer to use client API, it can set tracking uri explicitly (we need to set tracking uri here, for distributed training case). Any reason to use fluent API ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think fluent api is our standard user interface, so it is more robust, also we can shorten the arglist. For distributed training scenario, is mlflow.set_tracking_uri()
enough?
But anyway, I don't think this is a blocking issue, just want to open the discussion to align the API design.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh, we can remove client
arguements from MlflowModelCheckpointCallback
constructor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also removed run_id from MlflowModelCheckpointCallback
constructor and use fluent apis instead
mlflow/pytorch/_lightning_autolog.py
Outdated
@@ -438,15 +359,25 @@ def on_train_batch_end( | |||
batch, | |||
batch_idx, | |||
) -> None: | |||
self.trainer = trainer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little concerned about this line - we are setting a class attribute in on_train_batch_end
, which is called multiple times. Do we need this line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can make it only set it at the first time.
The purpose of setting this attribute is for usage in save_checkpoint
, otherwise we need to figure out a way to pass trainer
object to save_checkpoint
function. Note save_checkpoint
is called in base class but its implementation is in sub-class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we set it in the constructor or in on_train_start
hook?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved it into
@rank_zero_only
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.trainer = trainer
We can't put it in the constructor, because the callbacks
argument is also defined in Trainer constructor, and we need to support usage like:
mlflow_checkpoint_callback = MLflowModelCheckpointCallback() # we can't pass trainer object here.
trainer = Trainer(
callbacks=[mlflow_checkpoint_callback]
)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Weichen! We are pretty close, left a few non-blocking comments.
@@ -286,7 +287,7 @@ def on_test_end(self, trainer, pl_module): | |||
self.metrics_logger.flush() | |||
|
|||
|
|||
class MlflowModelCheckpointCallback(pl.Callback, metaclass=ExceptionSafeAbstractClass): | |||
class MlflowModelCheckpointCallback(pl.Callback, MlflowModelCheckpointCallbackBase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although at this time we don't expect users to directly use this class, since this is a public class, can we add a code example on how to use it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer to make it a private class _MlflowModelCheckpointCallbackBase
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ohh we cannot make it a private class because we are importing it in a different file.
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Weichen! Approved with a comment
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
mlflow/tensorflow/__init__.py
Outdated
run_id=run_id, epoch=epoch, global_step=global_step, dst_path=tmp_dir.path() | ||
) | ||
|
||
if os.path.basename(downloaded_checkpoint_filepath).split(".")[-2] == "weights": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line looks a bit scary. .split(".")[-2]
fails if the result of os.path.basename(downloaded_checkpoint_filepath)
doesn't contain a dot.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Improved error handling for this case:
artifact_name = os.path.basename(downloaded_checkpoint_filepath)
artifact_name_splits = artifact_name.split(".")
if len(artifact_name_splits) < 2:
raise MlflowException(
f"The model checkpoint artifact file name '{artifact_name}' is malformed."
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if a str doesn't contain dot, 'xx'.split(".")
gets ['xx']
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd use endswith
or regex so we don't need os.path.basename
or split.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Personally I prefer split
than regex
, if we incautiously write a bad regex it might cause performance issue when matching a long string, split
is safer to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Test failures in tensorflow / dev are not relevent. |
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com> Signed-off-by: Arthur Jenoudet <arthur.jenoudet@databricks.com>
馃洜 DevTools 馃洜
Install mlflow from this PR
Checkout with GitHub CLI
Related Issues/PRs
#xxxWhat changes are proposed in this pull request?
Keras automatic checkpoint implementation.
This PR replaces the old PR: #10955
How is this PR tested?
Does this PR require documentation update?
Release Notes
Is this a user-facing change?
Several automatic checkpoint arguments are added into keras autologging.
What component(s), interfaces, languages, and integrations does this PR affect?
Components
area/artifacts
: Artifact stores and artifact loggingarea/build
: Build and test infrastructure for MLflowarea/deployments
: MLflow Deployments client APIs, server, and third-party Deployments integrationsarea/docs
: MLflow documentation pagesarea/examples
: Example codearea/model-registry
: Model Registry service, APIs, and the fluent client calls for Model Registryarea/models
: MLmodel format, model serialization/deserialization, flavorsarea/recipes
: Recipes, Recipe APIs, Recipe configs, Recipe Templatesarea/projects
: MLproject format, project running backendsarea/scoring
: MLflow Model server, model deployment tools, Spark UDFsarea/server-infra
: MLflow Tracking server backendarea/tracking
: Tracking Service, tracking client APIs, autologgingInterface
area/uiux
: Front-end, user experience, plotting, JavaScript, JavaScript dev serverarea/docker
: Docker use across MLflow's components, such as MLflow Projects and MLflow Modelsarea/sqlalchemy
: Use of SQLAlchemy in the Tracking Service or Model Registryarea/windows
: Windows supportLanguage
language/r
: R APIs and clientslanguage/java
: Java APIs and clientslanguage/new
: Proposals for new client languagesIntegrations
integrations/azure
: Azure and Azure ML integrationsintegrations/sagemaker
: SageMaker integrationsintegrations/databricks
: Databricks integrationsHow should the PR be classified in the release notes? Choose one:
rn/none
- No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" sectionrn/breaking-change
- The PR will be mentioned in the "Breaking Changes" sectionrn/feature
- A new user-facing feature worth mentioning in the release notesrn/bug-fix
- A user-facing bug fix worth mentioning in the release notesrn/documentation
- A user-facing documentation change worth mentioning in the release notes