Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keras automatic checkpoint #11197

Merged
merged 23 commits into from Feb 26, 2024

Conversation

WeichenXu123
Copy link
Collaborator

@WeichenXu123 WeichenXu123 commented Feb 20, 2024

馃洜 DevTools 馃洜

Open in GitHub Codespaces

Install mlflow from this PR

pip install git+https://github.com/mlflow/mlflow.git@refs/pull/11197/merge

Checkout with GitHub CLI

gh pr checkout 11197

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

Keras automatic checkpoint implementation.

This PR replaces the old PR: #10955

How is this PR tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests

Does this PR require documentation update?

  • No. You can skip the rest of this section.
  • Yes. I've updated:
    • Examples
    • API references
    • Instructions

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

Several automatic checkpoint arguments are added into keras autologging.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
@github-actions github-actions bot added area/tracking Tracking service, tracking client APIs, autologging rn/feature Mention under Features in Changelogs. labels Feb 20, 2024
Copy link

github-actions bot commented Feb 20, 2024

Documentation preview for 6d9ad09 will be available when this CircleCI job
completes successfully.

More info

Copy link
Collaborator

@chenmoneygithub chenmoneygithub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Weichen, great work! Left some comments.

could reflect as little as 1 batch, since the metrics get reset
every epoch). Defaults to `"epoch"`.
"""
self.client = client
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use the fluent API instead of the client API?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to use client API, it can set tracking uri explicitly (we need to set tracking uri here, for distributed training case). Any reason to use fluent API ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think fluent api is our standard user interface, so it is more robust, also we can shorten the arglist. For distributed training scenario, is mlflow.set_tracking_uri() enough?

But anyway, I don't think this is a blocking issue, just want to open the discussion to align the API design.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, we can remove client arguements from MlflowModelCheckpointCallback constructor.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also removed run_id from MlflowModelCheckpointCallback constructor and use fluent apis instead

mlflow/pytorch/_lightning_autolog.py Outdated Show resolved Hide resolved
@@ -438,15 +359,25 @@ def on_train_batch_end(
batch,
batch_idx,
) -> None:
self.trainer = trainer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little concerned about this line - we are setting a class attribute in on_train_batch_end, which is called multiple times. Do we need this line?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can make it only set it at the first time.

The purpose of setting this attribute is for usage in save_checkpoint, otherwise we need to figure out a way to pass trainer object to save_checkpoint function. Note save_checkpoint is called in base class but its implementation is in sub-class.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we set it in the constructor or in on_train_start hook?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved it into

    @rank_zero_only
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.trainer = trainer

We can't put it in the constructor, because the callbacks argument is also defined in Trainer constructor, and we need to support usage like:

        mlflow_checkpoint_callback = MLflowModelCheckpointCallback()  # we can't pass trainer object here.

        trainer = Trainer(
            callbacks=[mlflow_checkpoint_callback]
        )

mlflow/tensorflow/__init__.py Show resolved Hide resolved
mlflow/tensorflow/_autolog.py Outdated Show resolved Hide resolved
tests/tensorflow/test_tensorflow2_autolog.py Show resolved Hide resolved
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
@harupy harupy added the enable-dev-tests Enables cross-version tests for dev versions label Feb 22, 2024
Copy link
Collaborator

@chenmoneygithub chenmoneygithub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Weichen! We are pretty close, left a few non-blocking comments.

@@ -286,7 +287,7 @@ def on_test_end(self, trainer, pl_module):
self.metrics_logger.flush()


class MlflowModelCheckpointCallback(pl.Callback, metaclass=ExceptionSafeAbstractClass):
class MlflowModelCheckpointCallback(pl.Callback, MlflowModelCheckpointCallbackBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although at this time we don't expect users to directly use this class, since this is a public class, can we add a code example on how to use it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to make it a private class _MlflowModelCheckpointCallbackBase

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh we cannot make it a private class because we are importing it in a different file.

mlflow/pytorch/_lightning_autolog.py Outdated Show resolved Hide resolved
mlflow/pytorch/_lightning_autolog.py Outdated Show resolved Hide resolved
mlflow/tensorflow/__init__.py Outdated Show resolved Hide resolved
mlflow/tensorflow/__init__.py Outdated Show resolved Hide resolved
mlflow/tensorflow/callback.py Show resolved Hide resolved
mlflow/tensorflow/callback.py Outdated Show resolved Hide resolved
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Copy link
Collaborator

@chenmoneygithub chenmoneygithub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Weichen! Approved with a comment

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
@WeichenXu123 WeichenXu123 self-assigned this Feb 26, 2024
Copy link
Member

@harupy harupy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

run_id=run_id, epoch=epoch, global_step=global_step, dst_path=tmp_dir.path()
)

if os.path.basename(downloaded_checkpoint_filepath).split(".")[-2] == "weights":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line looks a bit scary. .split(".")[-2] fails if the result of os.path.basename(downloaded_checkpoint_filepath) doesn't contain a dot.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improved error handling for this case:

        artifact_name = os.path.basename(downloaded_checkpoint_filepath)
        artifact_name_splits = artifact_name.split(".")
        if len(artifact_name_splits) < 2:
            raise MlflowException(
                f"The model checkpoint artifact file name '{artifact_name}' is malformed."
            )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if a str doesn't contain dot, 'xx'.split(".") gets ['xx']

Copy link
Member

@harupy harupy Feb 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd use endswith or regex so we don't need os.path.basename or split.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I prefer split than regex, if we incautiously write a bad regex it might cause performance issue when matching a long string, split is safer to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
@WeichenXu123
Copy link
Collaborator Author

Test failures in tensorflow / dev are not relevent.

@WeichenXu123 WeichenXu123 merged commit b58a12b into mlflow:master Feb 26, 2024
85 of 88 checks passed
serena-ruan pushed a commit to serena-ruan/mlflow that referenced this pull request Feb 28, 2024
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
artjen pushed a commit to artjen/mlflow that referenced this pull request Mar 26, 2024
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Arthur Jenoudet <arthur.jenoudet@databricks.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
area/tracking Tracking service, tracking client APIs, autologging enable-dev-tests Enables cross-version tests for dev versions rn/feature Mention under Features in Changelogs.
Projects
Status: Merged
Development

Successfully merging this pull request may close these issues.

None yet

3 participants