Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an API to download and save model weights and use it for UC registration #11157

Merged
merged 3 commits into from Feb 16, 2024

Conversation

B-Step62
Copy link
Collaborator

@B-Step62 B-Step62 commented Feb 15, 2024

🛠 DevTools 🛠

Open in GitHub Codespaces

Install mlflow from this PR

pip install git+https://github.com/mlflow/mlflow.git@refs/pull/11157/merge

Checkout with GitHub CLI

gh pr checkout 11157

What changes are proposed in this pull request?

This PR includes two (related) changes:

1. Introduce new persist_pretrained_model API.
Transformer "weigth-less" models cannot be deployed to Databricks model-serving as it is. This is because otherwise we need to download the weights from HuggingFace Hub in our control plane, which might impact service availability. Therefore, we block customers to register a model to DB Workspace Model Registry, but there should be a way to convert such weigth-less models to be the ones that can be registered.

Thus, this PR adds a new API mlflow.transformers.persist_pretrained_model(model_uri), with which users can download the model weight from the HF Hub and save into the existing weight-less model.

2. Handle UC registration
Initially, we also planned to block customers from registering weight-less models to UC. However, it turns out that the UC registration process is always done in the users notebook not our control plane, meaning that the download from the HF Hub is not harmful. We can simply use the new API to make the model registrable on behalf of users for avoid frustrating experience.

Note: For OSS Model Registry, we need no change i.e. not blocking registration. This is because the model weight is downloaded (on users' side) when they deploy the model to deployment target. This provides consistent non-blocking experience across OSS and UC, and since the Workspace Model Registry is considered as legacy, I think it's reasonable to have limitation there.

Tracker

This PR is filed for the PEFT feature branch. More changes are needed be done before merging the feature branch to master:

  • Introduce save_pretrained flag and implement saving/loading logic (PR).
  • Block registering model to Databricks Workspace Model Registry (WIP)
  • [This PR] Implement an API to download the weight files to the existing weight-less model (so it can be registered without re-logging).
  • Support PEFT model. (PR)
  • Update documentation and examples

How is this PR tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests
Transformers.Unity.Catalog.Registration.Test.mov

Steps:

  1. Log Transformers model with save_pretrained=False
  2. Check the logged model doesn't contain model weight.
  3. Click "Register Model" button and see code snippet for UC registration
  4. Go back to notebook and run registration code.
  5. Check the model weight is persisted in the model artifact and MLModel file is updated properly.

Does this PR require documentation update?

  • No. You can skip the rest of this section.
  • Yes. I've updated:
    • Examples
    • API references
    • Instructions

Documentation update will be done in a follow-up PR planned before the release.

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

Introduce a new API mlflow.transformers.persist_pretrained_model to download and save model weights, so that the Transformers "weight-less" model can be registered to Databricks Workspace Model Registry.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

The API is primary for updating an MLflow model that was logged/saved
with setting save_pretrained=False, so it can be registered to the
Model Registry.

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
…del registration

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
@github-actions github-actions bot added area/model-registry Model registry, model registry APIs, and the fluent client calls for model registry area/models MLmodel format, model serialization/deserialization, flavors rn/feature Mention under Features in Changelogs. labels Feb 15, 2024
Copy link

github-actions bot commented Feb 15, 2024

Documentation preview for 68d6d8c will be available when this CircleCI job completes successfully.

More info

# Now the model can be registered to the Model Registry
mlflow.register_model(f"runs:/{run.info.run_id}/pipeline", "qa_pipeline")
"""
with tempfile.TemporaryDirectory() as tmp_dir:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use mlflow.utils.file_utils.TempDir :)

Copy link
Member

@harupy harupy Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to rename TempDir to a name like SafeTempDir or LargeTempDir to imply it's different from tempfile.TemporaryDirectory. Out of scope of this PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch, thanks! Yeah agree that the utility name can more explicitly speak that it is for handling large files in Databricks.


# Now the model can be registered to the Model Registry
mlflow.register_model(f"runs:/{run.info.run_id}/pipeline", "qa_pipeline")
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this a no-op if the weights have already been loaded and are present in the directory? Since this is a public API, a user might call this function multiple times by accident and we only want a single copy of the weights. We can inform the user if the model weights are already loaded and let them know in a pleasant way that they've already been loaded and that they're "good to go" :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that we do that in update_flavor_conf_to_persist_pretrained_model, but is there a quick way to validate by listing the artifacts in the artifact store before downloading anything from the artifact store to local?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! I think we can do that by checking model file exists or not in the artifact list. One concern is that the reliance of that directory name is not super robust (like in older version of MLflow we didn't have that instead pipeline directory), but I think it's fine as long as the worst case scenario is just unnecessary download.

transformers_model=small_seq2seq_pipeline,
artifact_path="model",
save_pretrained=False,
pip_requirements=["mlflow"], # For speed up logging
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

smart idea :) we should probably do this for more tests!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha yea we still need keep some to test the requirement inference logic, but it doesn't have to repeat that many times like now:)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think as part of refactoring of the test modules for this flavor we should explicitly create an 'integration test suite' that tests happy path of 4-5 pipelines / components declarations and no more. Everything else can be migrated to other suites where we're not creating full pipelines and are mostly just mocking the downloads / instantiations of the pipelines themselves.
What are your thoughts on that?

Copy link
Collaborator Author

@B-Step62 B-Step62 Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree, current tests feel too much for dev inner-loop. We should have more "true unit" tests by isolating each functionality and mocking integration points. And port heavy e2e tests to daily/weekly job like docker tests.

@BenWilson2
Copy link
Member

AWESOME idea on the pre-fetch prior to registering. LOVE IT!

Copy link
Member

@BenWilson2 BenWilson2 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM once the artifact listing check is introduced for a short-circuit avoiding local downloads to do validation :) Great work!

Copy link
Member

@harupy harupy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
@B-Step62 B-Step62 merged commit 5d46f7b into mlflow:peft Feb 16, 2024
61 checks passed
@B-Step62 B-Step62 deleted the uc-register branch February 16, 2024 05:10
B-Step62 added a commit to B-Step62/mlflow that referenced this pull request Feb 19, 2024
…tration (mlflow#11157)

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
B-Step62 added a commit that referenced this pull request Feb 23, 2024
…tration (#11157)

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
B-Step62 added a commit to B-Step62/mlflow that referenced this pull request Feb 26, 2024
…tration (mlflow#11157)

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
B-Step62 added a commit that referenced this pull request Feb 27, 2024
…tration (#11157)

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
B-Step62 added a commit that referenced this pull request Feb 28, 2024
…tration (#11157)

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
area/model-registry Model registry, model registry APIs, and the fluent client calls for model registry area/models MLmodel format, model serialization/deserialization, flavors rn/feature Mention under Features in Changelogs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants