Skip to content

Commit

Permalink
Add a helper function to get MlflowCredentialContext by run_id (#8323)
Browse files Browse the repository at this point in the history
Signed-off-by: Liang Zhang <liang.zhang@databricks.com>
  • Loading branch information
liangz1 committed Apr 26, 2023
1 parent 1c074cf commit d540a72
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
9 changes: 9 additions & 0 deletions mlflow/utils/databricks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ def wrapper(*args, **kwargs):
return decorator


def get_mlflow_credential_context_by_run_id(run_id):
from mlflow.tracking.artifact_utils import get_artifact_uri
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri

run_root_artifact_uri = get_artifact_uri(run_id=run_id)
profile = get_databricks_profile_uri_from_artifact_uri(run_root_artifact_uri)
return MlflowCredentialContext(profile)


class MlflowCredentialContext:
"""Sets and clears credentials on a context using the provided profile URL."""

Expand Down
16 changes: 16 additions & 0 deletions tests/utils/test_databricks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mlflow.exceptions import MlflowException
from mlflow.utils import databricks_utils
from mlflow.utils.databricks_utils import (
get_mlflow_credential_context_by_run_id,
get_workspace_info_from_dbutils,
get_workspace_info_from_databricks_secrets,
is_databricks_default_tracking_uri,
Expand Down Expand Up @@ -336,3 +337,18 @@ def test_is_running_in_ipython_environment_works(get_ipython):

with mock.patch("IPython.get_ipython", return_value=get_ipython):
assert is_running_in_ipython_environment() == (get_ipython is not None)


def test_get_mlflow_credential_context_by_run_id():
with mock.patch(
"mlflow.tracking.artifact_utils.get_artifact_uri", return_value="dbfs:/path/to/artifact"
) as mock_get_artifact_uri, mock.patch(
"mlflow.utils.uri.get_databricks_profile_uri_from_artifact_uri",
return_value="databricks://path/to/profile",
) as mock_get_databricks_profile, mock.patch(
"mlflow.utils.databricks_utils.MlflowCredentialContext"
) as mock_credential_context:
get_mlflow_credential_context_by_run_id(run_id="abc")
mock_get_artifact_uri.assert_called_once_with(run_id="abc")
mock_get_databricks_profile.assert_called_once_with("dbfs:/path/to/artifact")
mock_credential_context.assert_called_once_with("databricks://path/to/profile")

0 comments on commit d540a72

Please sign in to comment.