Add an API to download and save model weights and use it for UC registration #11157
The API is primarily for updating an MLflow model that was logged/saved with save_pretrained=False, so that it can be registered to the Model Registry. Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
…del registration Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
mlflow/transformers/__init__.py
Outdated
# Now the model can be registered to the Model Registry
mlflow.register_model(f"runs:/{run.info.run_id}/pipeline", "qa_pipeline")
"""
with tempfile.TemporaryDirectory() as tmp_dir:
Let's use `mlflow.utils.file_utils.TempDir` :)
We might want to rename `TempDir` to a name like `SafeTempDir` or `LargeTempDir` to imply it's different from `tempfile.TemporaryDirectory`. Out of scope of this PR.
Great catch, thanks! Yeah, agreed that the utility name could say more explicitly that it is for handling large files in Databricks.
# Now the model can be registered to the Model Registry
mlflow.register_model(f"runs:/{run.info.run_id}/pipeline", "qa_pipeline")
"""
Can we make this a no-op if the weights have already been loaded and are present in the directory? Since this is a public API, a user might call this function multiple times by accident and we only want a single copy of the weights. We can inform the user if the model weights are already loaded and let them know in a pleasant way that they've already been loaded and that they're "good to go" :)
I see that we do that in `update_flavor_conf_to_persist_pretrained_model`, but is there a quick way to validate by listing the artifacts in the artifact store before downloading anything from the artifact store to local?
Good idea! I think we can do that by checking whether the `model` directory exists in the artifact list. One concern is that relying on that directory name is not super robust (for example, older versions of MLflow used a `pipeline` directory instead), but I think it's fine as long as the worst-case scenario is just an unnecessary download.
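The short-circuit discussed in this thread could look roughly like the sketch below. It is only an illustration under stated assumptions: the helper names `weights_already_persisted` and `persist_if_needed` are hypothetical (not the functions added in this PR), the weights are assumed to live in a `model` directory directly under the model root, and the artifact paths are assumed to come from a listing such as `MlflowClient().list_artifacts(run_id, path)`.

```python
from typing import Callable, Iterable


def weights_already_persisted(
    artifact_paths: Iterable[str],
    model_root: str,
    weights_dir_name: str = "model",  # assumed directory name, per the comment above
) -> bool:
    """Return True if the weights directory shows up in the artifact listing."""
    root = model_root.strip("/")
    expected = f"{root}/{weights_dir_name}" if root else weights_dir_name
    return any(path.strip("/") == expected for path in artifact_paths)


def persist_if_needed(
    artifact_paths: Iterable[str],
    model_root: str,
    download_and_save: Callable[[], None],
) -> bool:
    """Run the (expensive) download only when the listing shows no weights.

    Returns True if a download was triggered, False if it was skipped.
    """
    if weights_already_persisted(artifact_paths, model_root):
        # Hypothetical hook: tell the user there is nothing left to do.
        print("Model weights are already saved; the model is good to go.")
        return False
    download_and_save()
    return True
```

With this shape, a model saved by an older MLflow version that used a different directory name would simply fail the check and trigger a redundant download, which matches the worst case described above.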
transformers_model=small_seq2seq_pipeline,
artifact_path="model",
save_pretrained=False,
pip_requirements=["mlflow"], # For speed up logging
smart idea :) we should probably do this for more tests!
Haha yeah, we still need to keep some to test the requirement inference logic, but it doesn't have to be repeated as many times as it is now :)
I think as part of refactoring of the test modules for this flavor we should explicitly create an 'integration test suite' that tests happy path of 4-5 pipelines / components declarations and no more. Everything else can be migrated to other suites where we're not creating full pipelines and are mostly just mocking the downloads / instantiations of the pipelines themselves.
What are your thoughts on that?
Totally agree, the current tests feel too heavy for the dev inner loop. We should have more "true unit" tests that isolate each piece of functionality and mock the integration points, and move the heavy e2e tests to a daily/weekly job like the docker tests.
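To make the "true unit test" idea concrete, here is a minimal sketch using only `unittest.mock` from the standard library. Everything in it is hypothetical and for illustration only: `HubClient` stands in for whatever component talks to the network, and `persist_weights` for the logic under test; neither is part of MLflow.

```python
from unittest import mock


class HubClient:
    """Hypothetical stand-in for a client that downloads weights over the network."""

    def download(self, repo_id: str) -> str:
        raise RuntimeError("network access is not available in unit tests")


def persist_weights(client: HubClient, repo_id: str, saved: set) -> bool:
    """Download weights unless already saved; return True if a download happened."""
    if repo_id in saved:
        return False
    client.download(repo_id)
    saved.add(repo_id)
    return True


def test_persist_weights_downloads_once():
    client = HubClient()
    saved = set()
    # Replace the network call with a mock so the test is fast and offline.
    with mock.patch.object(client, "download", return_value="/tmp/fake") as fake_download:
        assert persist_weights(client, "some/repo", saved) is True
        assert persist_weights(client, "some/repo", saved) is False
    fake_download.assert_called_once_with("some/repo")
```

A test like this exercises only the skip-or-download decision, leaving real pipeline instantiation to a small integration suite.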
AWESOME idea on the pre-fetch prior to registering. LOVE IT!
LGTM once the artifact listing check is introduced for a short-circuit avoiding local downloads to do validation :) Great work!
LGTM!
…tration (mlflow#11157) Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
What changes are proposed in this pull request?
This PR includes two (related) changes:

1. Introduce the new `persist_pretrained_model` API.

Transformers "weight-less" models cannot be deployed to Databricks model serving as they are, because doing so would require downloading the weights from the HuggingFace Hub in our control plane, which might impact service availability. Therefore, we block customers from registering such a model to the Databricks Workspace Model Registry, but there should be a way to convert weight-less models into ones that can be registered.

Thus, this PR adds a new API, `mlflow.transformers.persist_pretrained_model(model_uri)`, with which users can download the model weights from the HF Hub and save them into the existing weight-less model.

2. Handle UC registration.

Initially, we also planned to block customers from registering weight-less models to UC. However, it turns out that the UC registration process always runs in the user's notebook, not in our control plane, meaning that the download from the HF Hub is not harmful. We can simply use the new API to make the model registrable on behalf of users, avoiding a frustrating experience.

Note: For the OSS Model Registry, no change is needed, i.e. we do not block registration, because the model weights are downloaded (on the user's side) when they deploy the model to the deployment target. This provides a consistent non-blocking experience across OSS and UC, and since the Workspace Model Registry is considered legacy, I think it's reasonable to have the limitation there.
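The intended user flow for the new API can be sketched as follows, assuming a weight-less model was already logged to a run with `save_pretrained=False`. The wrapper `persist_and_register` and its parameters are hypothetical and exist only for illustration; the two MLflow calls inside it are the ones named in this PR's description and docstring example.

```python
def persist_and_register(run_id: str, artifact_path: str, registered_name: str):
    """Persist weights into a weight-less model, then register it.

    Hypothetical convenience wrapper, not part of MLflow.
    """
    # Imported lazily so that defining this helper does not require mlflow.
    import mlflow

    model_uri = f"runs:/{run_id}/{artifact_path}"

    # Download the weights from the HF Hub and save them into the logged model.
    mlflow.transformers.persist_pretrained_model(model_uri)

    # Now the model can be registered to the Model Registry.
    return mlflow.register_model(model_uri, registered_name)
```

For the docstring example in this PR, the call would be `persist_and_register(run.info.run_id, "pipeline", "qa_pipeline")`.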
Tracker
This PR is filed for the PEFT feature branch. More changes need to be done before merging the feature branch to master:
- `save_pretrained` flag and implement saving/loading logic (PR).

How is this PR tested?
Transformers.Unity.Catalog.Registration.Test.mov
Steps:
save_pretrained=False
Does this PR require documentation update?
Documentation update will be done in a follow-up PR planned before the release.
Release Notes
Is this a user-facing change?
Introduce a new API `mlflow.transformers.persist_pretrained_model` to download and save model weights, so that the Transformers "weight-less" model can be registered to Databricks Workspace Model Registry.

What component(s), interfaces, languages, and integrations does this PR affect?
Components
- `area/artifacts`: Artifact stores and artifact logging
- `area/build`: Build and test infrastructure for MLflow
- `area/deployments`: MLflow Deployments client APIs, server, and third-party Deployments integrations
- `area/docs`: MLflow documentation pages
- `area/examples`: Example code
- `area/model-registry`: Model Registry service, APIs, and the fluent client calls for Model Registry
- `area/models`: MLmodel format, model serialization/deserialization, flavors
- `area/recipes`: Recipes, Recipe APIs, Recipe configs, Recipe Templates
- `area/projects`: MLproject format, project running backends
- `area/scoring`: MLflow Model server, model deployment tools, Spark UDFs
- `area/server-infra`: MLflow Tracking server backend
- `area/tracking`: Tracking Service, tracking client APIs, autologging

Interface
- `area/uiux`: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- `area/docker`: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- `area/sqlalchemy`: Use of SQLAlchemy in the Tracking Service or Model Registry
- `area/windows`: Windows support

Language
- `language/r`: R APIs and clients
- `language/java`: Java APIs and clients
- `language/new`: Proposals for new client languages

Integrations
- `integrations/azure`: Azure and Azure ML integrations
- `integrations/sagemaker`: SageMaker integrations
- `integrations/databricks`: Databricks integrations

How should the PR be classified in the release notes? Choose one:
- `rn/none` - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- `rn/breaking-change` - The PR will be mentioned in the "Breaking Changes" section
- `rn/feature` - A new user-facing feature worth mentioning in the release notes
- `rn/bug-fix` - A user-facing bug fix worth mentioning in the release notes
- `rn/documentation` - A user-facing documentation change worth mentioning in the release notes