Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support remote Databricks model registries in mlflow.<flavor>.load_model #3330

Merged
merged 3 commits into from
Aug 26, 2020

Conversation

sueann
Copy link
Contributor

@sueann sueann commented Aug 26, 2020

Signed-off-by: Sue Ann Hong sueann@databricks.com

What changes are proposed in this pull request?

Support remote Databricks model registries in mlflow.<flavor>.load_model when called with a models:/ URI.

Currently, the following two ways to call mlflow..load_model from a remote Databricks model registry work:

mlflow.pyfunc.load_model('models://profile@databricks/model_name/Staging')
mlflow.set_tracking_uri(registry_uri)
mlflow.pyfunc.load_model('models:/model_name/Staging')

But the following, more natural, way does not:

mlflow.set_registry_uri(registry_uri)
mlflow.pyfunc.load_model('models:/model_name/Staging')

This is because each flavor’s load_model
calls _download_artifact_from_uri
which calls get_artifact_repository(artifact_uri=root_uri).download_artifacts
and no registry server information is passed into the eventual DBFS artifact repository instance. So then the DBFS artifact repository uses the tracking URI from the context (i.e. set via the global variable or the environment variable).

To fix this, here we have ModelsArtifactRepository automatically use the Databricks model registry server information if specified in the context.

How is this patch tested?

  • Unit tests
  • Manual tests
mlflow.set_registry_uri('databricks://blah')  # invalid but ignored
mlflow.set_tracking_uri('databricks://blurgh')  # invalid but ignored
model = mlflow.pyfunc.load_model(f'models://{profile}@databricks/{model_name}/Staging')
model.predict(1)
mlflow.set_registry_uri(f'databricks://{profile}')  # valid registry URI
mlflow.set_tracking_uri('databricks://blurgh')  # invalid but ignored
model = mlflow.pyfunc.load_model(f'models:/{model_name}/Staging')
model.predict(1)
mlflow.set_registry_uri(None)
mlflow.set_tracking_uri(f'databricks://{profile}')  # valid registry URI
model = mlflow.pyfunc.load_model(f'models:/{model_name}/Staging')
model.predict(1)

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

mlflow.<flavor>.load_model methods will fetch the model from the Databricks model registry specified by mlflow.set_registry_uri if it is set to one.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/projects: MLproject format, project running backends
  • area/scoring: Local serving, model deployment tools, spark UDFs
  • area/server-infra: MLflow server, JavaScript dev server
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, JavaScript, plotting
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

Signed-off-by: Sue Ann Hong <sueann@databricks.com>
@github-actions github-actions bot added area/artifacts Artifact stores and artifact logging area/model-registry Model registry, model registry APIs, and the fluent client calls for model registry rn/feature Mention under Features in Changelogs. labels Aug 26, 2020
Copy link
Collaborator

@ankit-db ankit-db left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried it out, and it works! Small note about adding a test - I tried a few of the cases, but would be good to codify in the test

assert models_repo.artifact_uri == model_uri
assert isinstance(models_repo.repo, DbfsRestArtifactRepository)
mock_repo.assert_called_once_with(
"dbfs://scope:key@databricks/databricks/mlflow-registry/12345/models/keras-model"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also just verify that the right error is thrown if the registry URI is malformed?

Signed-off-by: Sue Ann Hong <sueann@databricks.com>
Signed-off-by: Sue Ann Hong <sueann@databricks.com>
with mock.patch("mlflow.get_registry_uri", return_value="databricks://scope:key:invalid"):
with pytest.raises(MlflowException) as ex:
ModelsArtifactRepository(model_uri)
assert "Key prefixes cannot contain" in ex.value.message
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ankitmathur-db is something like this what you were thinking? please do let me know if there are other cases you were thinking of I didn't cover here. thanks!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is what I was thinking. Have we tested stuff like databricks://scope:key. (with the extra spaces at the end)?

Copy link
Contributor Author

@sueann sueann Aug 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah we don't test it - but that should be handled at a lower level. Created a PR to address that: #3338.

@ankit-db
Copy link
Collaborator

Looks great!

@sueann sueann merged commit 948ddab into mlflow:master Aug 26, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
area/artifacts Artifact stores and artifact logging area/model-registry Model registry, model registry APIs, and the fluent client calls for model registry rn/feature Mention under Features in Changelogs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants