From db1f4ae9d055ff62f79374cbdfedf0e1da87cdf6 Mon Sep 17 00:00:00 2001 From: hectorcast-db Date: Tue, 21 May 2024 09:28:46 +0200 Subject: [PATCH] Create a method to generate OAuth tokens (#644) ## Changes Add method to get OAuth tokens ## Tests - [X] `make test` run locally - [X] `make fmt` applied - [ ] relevant integration tests applied - [X] Manual test (cannot be run as integration tests due to limitations in the current infrastructure setup) ``` def test(): w = WorkspaceClient(profile='DEFAULT') auth_details = f'"type":"workspace_permission","object_type":"serving-endpoints","object_path":"/serving-endpoints/REDACTED","actions":["query_inference_endpoint"]' auth_details = "[{" + auth_details + "}]" t = w.api_client.get_oauth_token(auth_details) print(t) ``` Result: ``` Token(access_token='REDACTED', token_type='Bearer', refresh_token=None, expiry=datetime.datetime(2024, 5, 16, 11, 9, 8, 221008)) ``` --- .codegen/__init__.py.tmpl | 10 +- databricks/sdk/__init__.py | 10 +- databricks/sdk/config.py | 27 +++-- databricks/sdk/core.py | 22 ++++ databricks/sdk/credentials_provider.py | 156 ++++++++++++++++++------- databricks/sdk/oauth.py | 10 +- databricks/sdk/service/sql.py | 1 - docs/dbdataclasses/sql.rst | 2 - docs/gen-client-docs.py | 4 +- examples/custom_auth.py | 8 +- tests/conftest.py | 6 +- tests/test_config.py | 7 ++ tests/test_core.py | 19 +-- tests/test_dbutils.py | 2 +- 14 files changed, 202 insertions(+), 82 deletions(-) diff --git a/.codegen/__init__.py.tmpl b/.codegen/__init__.py.tmpl index c8177175..55c2c24c 100644 --- a/.codegen/__init__.py.tmpl +++ b/.codegen/__init__.py.tmpl @@ -1,6 +1,6 @@ import databricks.sdk.core as client import databricks.sdk.dbutils as dbutils -from databricks.sdk.credentials_provider import CredentialsProvider +from databricks.sdk.credentials_provider import CredentialsStrategy from databricks.sdk.mixins.files import DbfsExt from databricks.sdk.mixins.compute import ClustersExt @@ -46,10 +46,12 @@ class WorkspaceClient: debug_headers: bool = None, product="unknown", product_version="0.0.0", - credentials_provider: CredentialsProvider = None, + credentials_strategy: CredentialsStrategy = None, + credentials_provider: CredentialsStrategy = None, config: client.Config = None): if not config: config = client.Config({{range $args}}{{.}}={{.}}, {{end}} + credentials_strategy=credentials_strategy, credentials_provider=credentials_provider, debug_truncate_bytes=debug_truncate_bytes, debug_headers=debug_headers, @@ -101,10 +103,12 @@ class AccountClient: debug_headers: bool = None, product="unknown", product_version="0.0.0", - credentials_provider: CredentialsProvider = None, + credentials_strategy: CredentialsStrategy = None, + credentials_provider: CredentialsStrategy = None, config: client.Config = None): if not config: config = client.Config({{range $args}}{{.}}={{.}}, {{end}} + credentials_strategy=credentials_strategy, credentials_provider=credentials_provider, debug_truncate_bytes=debug_truncate_bytes, debug_headers=debug_headers, diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 312d538b..fbe09693 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -1,7 +1,7 @@ import databricks.sdk.core as client import databricks.sdk.dbutils as dbutils from databricks.sdk import azure -from databricks.sdk.credentials_provider import CredentialsProvider +from databricks.sdk.credentials_provider import CredentialsStrategy from databricks.sdk.mixins.compute import ClustersExt from databricks.sdk.mixins.files import DbfsExt from databricks.sdk.mixins.workspace import WorkspaceExt @@ -131,7 +131,8 @@ def __init__(self, debug_headers: bool = None, product="unknown", product_version="0.0.0", - credentials_provider: CredentialsProvider = None, + credentials_strategy: CredentialsStrategy = None, + credentials_provider: CredentialsStrategy = None, config: client.Config = None): if not config: config = client.Config(host=host, @@ -152,6 +153,7 @@ def __init__(self, cluster_id=cluster_id, google_credentials=google_credentials, google_service_account=google_service_account, + credentials_strategy=credentials_strategy, credentials_provider=credentials_provider, debug_truncate_bytes=debug_truncate_bytes, debug_headers=debug_headers, @@ -700,7 +702,8 @@ def __init__(self, debug_headers: bool = None, product="unknown", product_version="0.0.0", - credentials_provider: CredentialsProvider = None, + credentials_strategy: CredentialsStrategy = None, + credentials_provider: CredentialsStrategy = None, config: client.Config = None): if not config: config = client.Config(host=host, @@ -721,6 +724,7 @@ def __init__(self, cluster_id=cluster_id, google_credentials=google_credentials, google_service_account=google_service_account, + credentials_strategy=credentials_strategy, credentials_provider=credentials_provider, debug_truncate_bytes=debug_truncate_bytes, debug_headers=debug_headers, diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index d7d45c20..f2d62ffa 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -11,10 +11,10 @@ import requests from .clock import Clock, RealClock -from .credentials_provider import CredentialsProvider, DefaultCredentials +from .credentials_provider import CredentialsStrategy, DefaultCredentials from .environments import (ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname) -from .oauth import OidcEndpoints +from .oauth import OidcEndpoints, Token from .version import __version__ logger = logging.getLogger('databricks.sdk') @@ -81,7 +81,9 @@ class Config: def __init__(self, *, - credentials_provider: CredentialsProvider = None, + # Deprecated. Use credentials_strategy instead. + credentials_provider: CredentialsStrategy = None, + credentials_strategy: CredentialsStrategy = None, product="unknown", product_version="0.0.0", clock: Clock = None, @@ -89,7 +91,15 @@ def __init__(self, self._header_factory = None self._inner = {} self._user_agent_other_info = [] - self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials() + if credentials_strategy and credentials_provider: + raise ValueError( + "When providing `credentials_strategy` field, `credential_provider` cannot be specified.") + if credentials_provider: + logger.warning( + "parameter 'credentials_provider' is deprecated. Use 'credentials_strategy' instead.") + self._credentials_strategy = next( + s for s in [credentials_strategy, credentials_provider, + DefaultCredentials()] if s is not None) if 'databricks_environment' in kwargs: self.databricks_environment = kwargs['databricks_environment'] del kwargs['databricks_environment'] @@ -107,6 +117,9 @@ def __init__(self, message = self.wrap_debug_info(str(e)) raise ValueError(message) from e + def oauth_token(self) -> Token: + return self._credentials_strategy.oauth_token(self) + def wrap_debug_info(self, message: str) -> str: debug_string = self.debug_string() if debug_string: @@ -436,12 +449,12 @@ def _validate(self): def init_auth(self): try: - self._header_factory = self._credentials_provider(self) - self.auth_type = self._credentials_provider.auth_type() + self._header_factory = self._credentials_strategy(self) + self.auth_type = self._credentials_strategy.auth_type() if not self._header_factory: raise ValueError('not configured') except ValueError as e: - raise ValueError(f'{self._credentials_provider.auth_type()} auth: {e}') from e + raise ValueError(f'{self._credentials_strategy.auth_type()} auth: {e}') from e def __repr__(self): return f'<{self.debug_string()}>' diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index 658ed219..cacbad90 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -4,6 +4,7 @@ from json import JSONDecodeError from types import TracebackType from typing import Any, BinaryIO, Iterator, Type +from urllib.parse import urlencode from requests.adapters import HTTPAdapter @@ -13,12 +14,17 @@ from .credentials_provider import * from .errors import DatabricksError, error_mapper from .errors.private_link import _is_private_link_redirect +from .oauth import retrieve_token from .retries import retried __all__ = ['Config', 'DatabricksError'] logger = logging.getLogger('databricks.sdk') +URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" +JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" +OIDC_TOKEN_PATH = "/oidc/v1/token" + class ApiClient: _cfg: Config @@ -109,6 +115,22 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: flattened = dict(flatten_dict(with_fixed_bools)) return flattened + def get_oauth_token(self, auth_details: str) -> Token: + if not self._cfg.auth_type: + self._cfg.authenticate() + original_token = self._cfg.oauth_token() + headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE} + params = urlencode({ + "grant_type": JWT_BEARER_GRANT_TYPE, + "authorization_details": auth_details, + "assertion": original_token.access_token + }) + return retrieve_token(client_id=self._cfg.client_id, + client_secret=self._cfg.client_secret, + token_url=self._cfg.host + OIDC_TOKEN_PATH, + params=params, + headers=headers) + def do(self, method: str, path: str, diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 02727599..d7173ab3 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -22,12 +22,26 @@ from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, TokenCache, TokenSource) -HeaderFactory = Callable[[], Dict[str, str]] +CredentialsProvider = Callable[[], Dict[str, str]] logger = logging.getLogger('databricks.sdk') -class CredentialsProvider(abc.ABC): +class OAuthCredentialsProvider: + """ OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens. """ + + def __init__(self, credentials_provider: CredentialsProvider, token_provider: Callable[[], Token]): + self._credentials_provider = credentials_provider + self._token_provider = token_provider + + def __call__(self) -> Dict[str, str]: + return self._credentials_provider() + + def oauth_token(self) -> Token: + return self._token_provider() + + +class CredentialsStrategy(abc.ABC): """ CredentialsProvider is the protocol (call-side interface) for authenticating requests to Databricks REST APIs""" @@ -36,20 +50,39 @@ def auth_type(self) -> str: ... @abc.abstractmethod - def __call__(self, cfg: 'Config') -> HeaderFactory: + def __call__(self, cfg: 'Config') -> CredentialsProvider: ... -def credentials_provider(name: str, require: List[str]): +class OauthCredentialsStrategy(CredentialsStrategy): + """ OauthCredentialsProvider is a CredentialsProvider which + supports Oauth tokens""" + + def __init__(self, auth_type: str, headers_provider: Callable[['Config'], OAuthCredentialsProvider]): + self._headers_provider = headers_provider + self._auth_type = auth_type + + def auth_type(self) -> str: + return self._auth_type + + def __call__(self, cfg: 'Config') -> OAuthCredentialsProvider: + return self._headers_provider(cfg) + + def oauth_token(self, cfg: 'Config') -> Token: + return self._headers_provider(cfg).oauth_token() + + +def credentials_strategy(name: str, require: List[str]): """ Given the function that receives a Config and returns RequestVisitor, create CredentialsProvider with a given name and required configuration attribute names to be present for this function to be called. """ - def inner(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider: + def inner(func: Callable[['Config'], CredentialsProvider]) -> CredentialsStrategy: @functools.wraps(func) - def wrapper(cfg: 'Config') -> Optional[HeaderFactory]: + def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]: for attr in require: + getattr(cfg, attr) if not getattr(cfg, attr): return None return func(cfg) @@ -60,8 +93,27 @@ def wrapper(cfg: 'Config') -> Optional[HeaderFactory]: return inner -@credentials_provider('basic', ['host', 'username', 'password']) -def basic_auth(cfg: 'Config') -> HeaderFactory: +def oauth_credentials_strategy(name: str, require: List[str]): + """ Given the function that receives a Config and returns an OauthHeaderFactory, + create an OauthCredentialsProvider with a given name and required configuration + attribute names to be present for this function to be called. """ + + def inner(func: Callable[['Config'], OAuthCredentialsProvider]) -> OauthCredentialsStrategy: + + @functools.wraps(func) + def wrapper(cfg: 'Config') -> Optional[OAuthCredentialsProvider]: + for attr in require: + if not getattr(cfg, attr): + return None + return func(cfg) + + return OauthCredentialsStrategy(name, wrapper) + + return inner + + +@credentials_strategy('basic', ['host', 'username', 'password']) +def basic_auth(cfg: 'Config') -> CredentialsProvider: """ Given username and password, add base64-encoded Basic credentials """ encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode() static_credentials = {'Authorization': f'Basic {encoded}'} @@ -72,8 +124,8 @@ def inner() -> Dict[str, str]: return inner -@credentials_provider('pat', ['host', 'token']) -def pat_auth(cfg: 'Config') -> HeaderFactory: +@credentials_strategy('pat', ['host', 'token']) +def pat_auth(cfg: 'Config') -> CredentialsProvider: """ Adds Databricks Personal Access Token to every request """ static_credentials = {'Authorization': f'Bearer {cfg.token}'} @@ -83,8 +135,8 @@ def inner() -> Dict[str, str]: return inner -@credentials_provider('runtime', []) -def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]: +@credentials_strategy('runtime', []) +def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]: if 'DATABRICKS_RUNTIME_VERSION' not in os.environ: return None @@ -107,8 +159,8 @@ def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]: return None -@credentials_provider('oauth-m2m', ['host', 'client_id', 'client_secret']) -def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]: +@oauth_credentials_strategy('oauth-m2m', ['host', 'client_id', 'client_secret']) +def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]: """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request, if /oidc/.well-known/oauth-authorization-server is available on the given host. """ oidc = cfg.oidc_endpoints @@ -124,11 +176,11 @@ def inner() -> Dict[str, str]: token = token_source.token() return {'Authorization': f'{token.token_type} {token.access_token}'} - return inner + return OAuthCredentialsProvider(inner, token_source.token) -@credentials_provider('external-browser', ['host', 'auth_type']) -def external_browser(cfg: 'Config') -> Optional[HeaderFactory]: +@credentials_strategy('external-browser', ['host', 'auth_type']) +def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: if cfg.auth_type != 'external-browser': return None if cfg.client_id: @@ -178,9 +230,9 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" -@credentials_provider('azure-client-secret', - ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) -def azure_service_principal(cfg: 'Config') -> HeaderFactory: +@oauth_credentials_strategy('azure-client-secret', + ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) +def azure_service_principal(cfg: 'Config') -> CredentialsProvider: """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, while automatically resolving different Azure environment endpoints. """ @@ -203,11 +255,11 @@ def refreshed_headers() -> Dict[str, str]: add_sp_management_token(cloud, headers) return headers - return refreshed_headers + return OAuthCredentialsProvider(refreshed_headers, inner.token) -@credentials_provider('github-oidc-azure', ['host', 'azure_client_id']) -def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]: +@oauth_credentials_strategy('github-oidc-azure', ['host', 'azure_client_id']) +def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]: if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ: # not in GitHub actions return None @@ -250,14 +302,14 @@ def refreshed_headers() -> Dict[str, str]: token = inner.token() return {'Authorization': f'{token.token_type} {token.access_token}'} - return refreshed_headers + return OAuthCredentialsProvider(refreshed_headers, inner.token) GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"] -@credentials_provider('google-credentials', ['host', 'google_credentials']) -def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]: +@oauth_credentials_strategy('google-credentials', ['host', 'google_credentials']) +def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]: if not cfg.is_gcp: return None # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string. @@ -277,6 +329,10 @@ def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]: gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes) + def token() -> Token: + credentials.refresh(request) + return credentials.token + def refreshed_headers() -> Dict[str, str]: credentials.refresh(request) headers = {'Authorization': f'Bearer {credentials.token}'} @@ -285,11 +341,11 @@ def refreshed_headers() -> Dict[str, str]: headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token return headers - return refreshed_headers + return OAuthCredentialsProvider(refreshed_headers, token) -@credentials_provider('google-id', ['host', 'google_service_account']) -def google_id(cfg: 'Config') -> Optional[HeaderFactory]: +@oauth_credentials_strategy('google-id', ['host', 'google_service_account']) +def google_id(cfg: 'Config') -> Optional[CredentialsProvider]: if not cfg.is_gcp: return None credentials, _project_id = google.auth.default() @@ -309,6 +365,10 @@ def google_id(cfg: 'Config') -> Optional[HeaderFactory]: request = Request() + def token() -> Token: + id_creds.refresh(request) + return id_creds.token + def refreshed_headers() -> Dict[str, str]: id_creds.refresh(request) headers = {'Authorization': f'Bearer {id_creds.token}'} @@ -317,7 +377,7 @@ def refreshed_headers() -> Dict[str, str]: headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token return headers - return refreshed_headers + return OAuthCredentialsProvider(refreshed_headers, token) class CliTokenSource(Refreshable): @@ -422,8 +482,8 @@ def get_subscription(cfg: 'Config') -> str: return components[2] -@credentials_provider('azure-cli', ['is_azure']) -def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]: +@credentials_strategy('azure-cli', ['is_azure']) +def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]: """ Adds refreshed OAuth token granted by `az login` command to every request. """ token_source = None mgmt_token_source = None @@ -516,8 +576,8 @@ def _find_executable(name) -> str: raise err -@credentials_provider('databricks-cli', ['host']) -def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]: +@oauth_credentials_strategy('databricks-cli', ['host']) +def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]: try: token_source = DatabricksCliTokenSource(cfg) except FileNotFoundError as e: @@ -538,7 +598,7 @@ def inner() -> Dict[str, str]: token = token_source.token() return {'Authorization': f'{token.token_type} {token.access_token}'} - return inner + return OAuthCredentialsProvider(inner, token_source.token) class MetadataServiceTokenSource(Refreshable): @@ -577,8 +637,8 @@ def refresh(self) -> Token: return Token(access_token=access_token, token_type=token_type, expiry=expiry) -@credentials_provider('metadata-service', ['host', 'metadata_service_url']) -def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]: +@credentials_strategy('metadata-service', ['host', 'metadata_service_url']) +def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]: """ Adds refreshed token granted by Databricks Metadata Service to every request. """ token_source = MetadataServiceTokenSource(cfg) @@ -597,17 +657,25 @@ class DefaultCredentials: def __init__(self) -> None: self._auth_type = 'default' - - def auth_type(self) -> str: - return self._auth_type - - def __call__(self, cfg: 'Config') -> HeaderFactory: - auth_providers = [ + self._auth_providers = [ pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal, github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth, google_credentials, google_id ] - for provider in auth_providers: + + def auth_type(self) -> str: + return self._auth_type + + def oauth_token(self, cfg: 'Config') -> Token: + for provider in self._auth_providers: + auth_type = provider.auth_type() + if auth_type != self._auth_type: + # ignore other auth types if they don't match the selected one + continue + return provider.oauth_token(cfg) + + def __call__(self, cfg: 'Config') -> CredentialsProvider: + for provider in self._auth_providers: auth_type = provider.auth_type() if cfg.auth_type and auth_type != cfg.auth_type: # ignore other auth types if one is explicitly enforced diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 07bc5551..38da6c03 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -21,6 +21,10 @@ # See https://stackoverflow.com/a/75466778/277035 for more info NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327' +URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" +JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" +OIDC_TOKEN_PATH = "/oidc/v1/token" + logger = logging.getLogger(__name__) @@ -358,13 +362,13 @@ def __init__(self, client_secret: str = None): # TODO: is it a circular dependency?.. from .core import Config - from .credentials_provider import credentials_provider + from .credentials_provider import credentials_strategy - @credentials_provider('noop', []) + @credentials_strategy('noop', []) def noop_credentials(_: any): return lambda: {} - config = Config(host=host, credentials_provider=noop_credentials) + config = Config(host=host, credentials_strategy=noop_credentials) if not scopes: scopes = ['all-apis'] if config.is_azure: diff --git a/databricks/sdk/service/sql.py b/databricks/sdk/service/sql.py index bfef44af..557b53e3 100755 --- a/databricks/sdk/service/sql.py +++ b/databricks/sdk/service/sql.py @@ -360,7 +360,6 @@ def from_dict(cls, d: Dict[str, any]) -> ChannelInfo: class ChannelName(Enum): - """Name of the channel""" CHANNEL_NAME_CURRENT = 'CHANNEL_NAME_CURRENT' CHANNEL_NAME_CUSTOM = 'CHANNEL_NAME_CUSTOM' diff --git a/docs/dbdataclasses/sql.rst b/docs/dbdataclasses/sql.rst index fe1469a3..adf3ced5 100644 --- a/docs/dbdataclasses/sql.rst +++ b/docs/dbdataclasses/sql.rst @@ -64,8 +64,6 @@ These dataclasses are used in the SDK to represent API requests and responses fo .. py:class:: ChannelName - Name of the channel - .. py:attribute:: CHANNEL_NAME_CURRENT :value: "CHANNEL_NAME_CURRENT" diff --git a/docs/gen-client-docs.py b/docs/gen-client-docs.py index 5322a25b..79396152 100644 --- a/docs/gen-client-docs.py +++ b/docs/gen-client-docs.py @@ -13,7 +13,7 @@ from typing import Optional, Any, get_args from databricks.sdk import AccountClient, WorkspaceClient -from databricks.sdk.core import credentials_provider +from databricks.sdk.core import credentials_strategy __dir__ = os.path.dirname(__file__) __examples__ = Path(f'{__dir__}/../examples').absolute() @@ -451,7 +451,7 @@ def _write_client_package_doc(self, folder: str, pkg: Package, services: list[st if __name__ == '__main__': - @credentials_provider('noop', []) + @credentials_strategy('noop', []) def noop_credentials(_: any): return lambda: {} diff --git a/examples/custom_auth.py b/examples/custom_auth.py index 02747011..4d335164 100644 --- a/examples/custom_auth.py +++ b/examples/custom_auth.py @@ -1,10 +1,10 @@ -from databricks.sdk.core import (ApiClient, Config, HeaderFactory, - credentials_provider) +from databricks.sdk.core import (ApiClient, Config, CredentialsProvider, + credentials_strategy) from databricks.sdk.service.iam import CurrentUserAPI -@credentials_provider("custom", ["host"]) -def user_input_token(cfg: Config) -> HeaderFactory: +@credentials_strategy("custom", ["host"]) +def user_input_token(cfg: Config) -> CredentialsProvider: pat = input("Enter Databricks PAT: ") def inner() -> dict[str, str]: diff --git a/tests/conftest.py b/tests/conftest.py index 80753ae9..ef740662 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,17 +5,17 @@ from pyfakefs.fake_filesystem_unittest import Patcher from databricks.sdk.core import Config -from databricks.sdk.credentials_provider import credentials_provider +from databricks.sdk.credentials_provider import credentials_strategy -@credentials_provider('noop', []) +@credentials_strategy('noop', []) def noop_credentials(_: any): return lambda: {} @pytest.fixture def config(): - return Config(host='http://localhost', credentials_provider=noop_credentials) + return Config(host='http://localhost', credentials_strategy=noop_credentials) @pytest.fixture diff --git a/tests/test_config.py b/tests/test_config.py index 1411e5be..22a4b71c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,6 +4,13 @@ def test_config_copy_preserves_product_and_product_version(): + c = Config(credentials_strategy=noop_credentials, product='foo', product_version='1.2.3') + c2 = c.copy() + assert c2._product == 'foo' + assert c2._product_version == '1.2.3' + + +def test_config_supports_legacy_credentials_provider(): c = Config(credentials_provider=noop_credentials, product='foo', product_version='1.2.3') c2 = c.copy() assert c2._product == 'foo' diff --git a/tests/test_core.py b/tests/test_core.py index fd2f12ad..74d38bd3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -18,8 +18,9 @@ StreamingResponse) from databricks.sdk.credentials_provider import (CliTokenSource, CredentialsProvider, + CredentialsStrategy, DatabricksCliTokenSource, - HeaderFactory, databricks_cli) + databricks_cli) from databricks.sdk.environments import (ENVIRONMENTS, AzureEnvironment, Cloud, DatabricksEnvironment) from databricks.sdk.service.catalog import PermissionsChange @@ -207,7 +208,7 @@ def system(self): def test_config_copy_shallow_copies_credential_provider(): - class TestCredentialsProvider(CredentialsProvider): + class TestCredentialsStrategy(CredentialsStrategy): def __init__(self): super().__init__() @@ -216,24 +217,24 @@ def __init__(self): def auth_type(self) -> str: return "test" - def __call__(self, cfg: 'Config') -> HeaderFactory: + def __call__(self, cfg: 'Config') -> CredentialsProvider: return lambda: {"token": self._token} def refresh(self): self._token = "token2" - credential_provider = TestCredentialsProvider() - config = Config(credentials_provider=credential_provider) + credentials_strategy = TestCredentialsStrategy() + config = Config(credentials_strategy=credentials_strategy) config_copy = config.copy() assert config.authenticate()["token"] == "token1" assert config_copy.authenticate()["token"] == "token1" - credential_provider.refresh() + credentials_strategy.refresh() assert config.authenticate()["token"] == "token2" assert config_copy.authenticate()["token"] == "token2" - assert config._credentials_provider == config_copy._credentials_provider + assert config._credentials_strategy == config_copy._credentials_strategy def test_config_copy_deep_copies_user_agent_other_info(config): @@ -277,7 +278,7 @@ def __init__(self): def test_config_parsing_non_string_env_vars(monkeypatch): monkeypatch.setenv('DATABRICKS_DEBUG_TRUNCATE_BYTES', '100') - c = Config(host='http://localhost', credentials_provider=noop_credentials) + c = Config(host='http://localhost', credentials_strategy=noop_credentials) assert c.debug_truncate_bytes == 100 @@ -592,7 +593,7 @@ def inner(h: BaseHTTPRequestHandler): ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']), ('AzureChinaCloud', ENVIRONMENTS['CHINA']), ]) def test_azure_environment(azure_environment, expected): - c = Config(credentials_provider=noop_credentials, + c = Config(credentials_strategy=noop_credentials, azure_workspace_resource_id='...', azure_environment=azure_environment) assert c.arm_environment == expected diff --git a/tests/test_dbutils.py b/tests/test_dbutils.py index 4f80e91b..15b26a54 100644 --- a/tests/test_dbutils.py +++ b/tests/test_dbutils.py @@ -140,7 +140,7 @@ def assertions(): command=expect_command) dbutils = RemoteDbUtils( - Config(host='http://localhost', cluster_id='x', credentials_provider=noop_credentials)) + Config(host='http://localhost', cluster_id='x', credentials_strategy=noop_credentials)) return dbutils, assertions return inner