Skip to content

Commit

Permalink
Migrate Mlflow API request to databricks sdk authentication way and s…
Browse files Browse the repository at this point in the history
…upport OAuth (#12011)

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Jun 1, 2024
1 parent a39aec8 commit 315506d
Show file tree
Hide file tree
Showing 22 changed files with 393 additions and 158 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deployments.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install --no-dependencies tests/resources/mlflow-test-plugin
pip install .[gateway] \
pytest pytest-timeout pytest-asyncio httpx psutil sentence-transformers transformers
pytest pytest-timeout pytest-asyncio httpx psutil sentence-transformers transformers databricks-sdk
- name: Run tests
run: |
pytest tests/deployments
5 changes: 5 additions & 0 deletions mlflow/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,8 @@ def get(self):
MLFLOW_BOTO_CLIENT_ADDRESSING_STYLE = _EnvironmentVariable(
"MLFLOW_BOTO_CLIENT_ADDRESSING_STYLE", str, "auto"
)

#: Specify the timeout in seconds for Databricks endpoint HTTP request retries.
MLFLOW_DATABRICKS_ENDPOINT_HTTP_RETRY_TIMEOUT = _EnvironmentVariable(
"MLFLOW_DATABRICKS_ENDPOINT_HTTP_RETRY_TIMEOUT", int, 500
)
56 changes: 49 additions & 7 deletions mlflow/legacy_databricks_cli/configure/provider.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# This module is copied from legacy databricks CLI python library
# module `databricks_cli.configure.provider`, see
# module `databricks_cli.configure.provider`,
# but with some modifications so that `EnvironmentVariableConfigProvider` supports
# the 'DATABRICKS_CLIENT_ID' and 'DATABRICKS_CLIENT_SECRET' environment variables.
#
# This is the original legacy databricks CLI python library provider module code:
# https://github.com/databricks/databricks-cli/blob/0.18.0/databricks_cli/configure/provider.py
# because the latest Databricks Runtime does not contain legacy databricks CLI
#
# The latest Databricks Runtime does not include the legacy Databricks CLI,
# but MLflow still depends on it.

import logging
Expand All @@ -24,6 +29,8 @@
INSECURE = "insecure"
JOBS_API_VERSION = "jobs-api-version"
DEFAULT_SECTION = "DEFAULT"
CLIENT_ID = "client_id"
CLIENT_SECRET = "client_secret"

# User-provided override for the DatabricksConfigProvider
_config_provider = None
Expand Down Expand Up @@ -265,8 +272,19 @@ def get_config(self):
refresh_token = os.environ.get("DATABRICKS_REFRESH_TOKEN")
insecure = os.environ.get("DATABRICKS_INSECURE")
jobs_api_version = os.environ.get("DATABRICKS_JOBS_API_VERSION")
client_id = os.environ.get("DATABRICKS_CLIENT_ID")
client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET")

config = DatabricksConfig(
host, username, password, token, refresh_token, insecure, jobs_api_version
host,
username,
password,
token,
refresh_token,
insecure,
jobs_api_version,
client_id=client_id,
client_secret=client_secret,
)
if config.is_valid:
return config
Expand All @@ -276,8 +294,8 @@ def get_config(self):
class ProfileConfigProvider(DatabricksConfigProvider):
"""Loads from the databrickscfg file."""

def __init__(self, profile=DEFAULT_SECTION):
self.profile = profile
def __init__(self, profile=None):
self.profile = profile or DEFAULT_SECTION

def get_config(self):
raw_config = _fetch_from_fs()
Expand All @@ -288,8 +306,18 @@ def get_config(self):
refresh_token = _get_option_if_exists(raw_config, self.profile, REFRESH_TOKEN)
insecure = _get_option_if_exists(raw_config, self.profile, INSECURE)
jobs_api_version = _get_option_if_exists(raw_config, self.profile, JOBS_API_VERSION)
client_id = _get_option_if_exists(raw_config, self.profile, CLIENT_ID)
client_secret = _get_option_if_exists(raw_config, self.profile, CLIENT_SECRET)
config = DatabricksConfig(
host, username, password, token, refresh_token, insecure, jobs_api_version
host,
username,
password,
token,
refresh_token,
insecure,
jobs_api_version,
client_id=client_id,
client_secret=client_secret,
)
if config.is_valid:
return config
Expand Down Expand Up @@ -362,6 +390,8 @@ def __init__(
refresh_token=None,
insecure=None,
jobs_api_version=None,
client_id=None,
client_secret=None,
):
self.host = host
self.username = username
Expand All @@ -370,6 +400,8 @@ def __init__(
self.refresh_token = refresh_token
self.insecure = insecure
self.jobs_api_version = jobs_api_version
self.client_id = client_id
self.client_secret = client_secret

@classmethod
def from_token(cls, host, token, refresh_token=None, insecure=None, jobs_api_version=None):
Expand Down Expand Up @@ -415,6 +447,16 @@ def is_valid_with_token(self):
def is_valid_with_password(self):
return self.host is not None and self.username is not None and self.password is not None

@property
def is_valid_with_client_id_secret(self):
return (
self.host is not None and self.client_id is not None and self.client_secret is not None
)

@property
def is_valid(self):
return self.is_valid_with_token or self.is_valid_with_password
return (
self.is_valid_with_token
or self.is_valid_with_password
or self.is_valid_with_client_id_secret
)
Loading

0 comments on commit 315506d

Please sign in to comment.