diff --git a/.codegen/__init__.py.tmpl b/.codegen/__init__.py.tmpl index 55c2c24c..c8177175 100644 --- a/.codegen/__init__.py.tmpl +++ b/.codegen/__init__.py.tmpl @@ -1,6 +1,6 @@ import databricks.sdk.core as client import databricks.sdk.dbutils as dbutils -from databricks.sdk.credentials_provider import CredentialsStrategy +from databricks.sdk.credentials_provider import CredentialsProvider from databricks.sdk.mixins.files import DbfsExt from databricks.sdk.mixins.compute import ClustersExt @@ -46,12 +46,10 @@ class WorkspaceClient: debug_headers: bool = None, product="unknown", product_version="0.0.0", - credentials_strategy: CredentialsStrategy = None, - credentials_provider: CredentialsStrategy = None, + credentials_provider: CredentialsProvider = None, config: client.Config = None): if not config: config = client.Config({{range $args}}{{.}}={{.}}, {{end}} - credentials_strategy=credentials_strategy, credentials_provider=credentials_provider, debug_truncate_bytes=debug_truncate_bytes, debug_headers=debug_headers, @@ -103,12 +101,10 @@ class AccountClient: debug_headers: bool = None, product="unknown", product_version="0.0.0", - credentials_strategy: CredentialsStrategy = None, - credentials_provider: CredentialsStrategy = None, + credentials_provider: CredentialsProvider = None, config: client.Config = None): if not config: config = client.Config({{range $args}}{{.}}={{.}}, {{end}} - credentials_strategy=credentials_strategy, credentials_provider=credentials_provider, debug_truncate_bytes=debug_truncate_bytes, debug_headers=debug_headers, diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index fbe09693..312d538b 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -1,7 +1,7 @@ import databricks.sdk.core as client import databricks.sdk.dbutils as dbutils from databricks.sdk import azure -from databricks.sdk.credentials_provider import CredentialsStrategy +from databricks.sdk.credentials_provider import CredentialsProvider from databricks.sdk.mixins.compute import ClustersExt from databricks.sdk.mixins.files import DbfsExt from databricks.sdk.mixins.workspace import WorkspaceExt @@ -131,8 +131,7 @@ def __init__(self, debug_headers: bool = None, product="unknown", product_version="0.0.0", - credentials_strategy: CredentialsStrategy = None, - credentials_provider: CredentialsStrategy = None, + credentials_provider: CredentialsProvider = None, config: client.Config = None): if not config: config = client.Config(host=host, @@ -153,7 +152,6 @@ def __init__(self, cluster_id=cluster_id, google_credentials=google_credentials, google_service_account=google_service_account, - credentials_strategy=credentials_strategy, credentials_provider=credentials_provider, debug_truncate_bytes=debug_truncate_bytes, debug_headers=debug_headers, @@ -702,8 +700,7 @@ def __init__(self, debug_headers: bool = None, product="unknown", product_version="0.0.0", - credentials_strategy: CredentialsStrategy = None, - credentials_provider: CredentialsStrategy = None, + credentials_provider: CredentialsProvider = None, config: client.Config = None): if not config: config = client.Config(host=host, @@ -724,7 +721,6 @@ def __init__(self, cluster_id=cluster_id, google_credentials=google_credentials, google_service_account=google_service_account, - credentials_strategy=credentials_strategy, credentials_provider=credentials_provider, debug_truncate_bytes=debug_truncate_bytes, debug_headers=debug_headers, diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index f2d62ffa..d7d45c20 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -11,10 +11,10 @@ import requests from .clock import Clock, RealClock -from .credentials_provider import CredentialsStrategy, DefaultCredentials +from .credentials_provider import CredentialsProvider, DefaultCredentials from .environments import (ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname) -from .oauth import OidcEndpoints, Token +from .oauth import OidcEndpoints from .version import __version__ logger = logging.getLogger('databricks.sdk') @@ -81,9 +81,7 @@ class Config: def __init__(self, *, - # Deprecated. Use credentials_strategy instead. - credentials_provider: CredentialsStrategy = None, - credentials_strategy: CredentialsStrategy = None, + credentials_provider: CredentialsProvider = None, product="unknown", product_version="0.0.0", clock: Clock = None, @@ -91,15 +89,7 @@ def __init__(self, self._header_factory = None self._inner = {} self._user_agent_other_info = [] - if credentials_strategy and credentials_provider: - raise ValueError( - "When providing `credentials_strategy` field, `credential_provider` cannot be specified.") - if credentials_provider: - logger.warning( - "parameter 'credentials_provider' is deprecated. Use 'credentials_strategy' instead.") - self._credentials_strategy = next( - s for s in [credentials_strategy, credentials_provider, - DefaultCredentials()] if s is not None) + self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials() if 'databricks_environment' in kwargs: self.databricks_environment = kwargs['databricks_environment'] del kwargs['databricks_environment'] @@ -117,9 +107,6 @@ def __init__(self, message = self.wrap_debug_info(str(e)) raise ValueError(message) from e - def oauth_token(self) -> Token: - return self._credentials_strategy.oauth_token(self) - def wrap_debug_info(self, message: str) -> str: debug_string = self.debug_string() if debug_string: @@ -449,12 +436,12 @@ def _validate(self): def init_auth(self): try: - self._header_factory = self._credentials_strategy(self) - self.auth_type = self._credentials_strategy.auth_type() + self._header_factory = self._credentials_provider(self) + self.auth_type = self._credentials_provider.auth_type() if not self._header_factory: raise ValueError('not configured') except ValueError as e: - raise ValueError(f'{self._credentials_strategy.auth_type()} auth: {e}') from e + raise ValueError(f'{self._credentials_provider.auth_type()} auth: {e}') from e def __repr__(self): return f'<{self.debug_string()}>' diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index cacbad90..658ed219 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -4,7 +4,6 @@ from json import JSONDecodeError from types import TracebackType from typing import Any, BinaryIO, Iterator, Type -from urllib.parse import urlencode from requests.adapters import HTTPAdapter @@ -14,17 +13,12 @@ from .credentials_provider import * from .errors import DatabricksError, error_mapper from .errors.private_link import _is_private_link_redirect -from .oauth import retrieve_token from .retries import retried __all__ = ['Config', 'DatabricksError'] logger = logging.getLogger('databricks.sdk') -URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" -JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" -OIDC_TOKEN_PATH = "/oidc/v1/token" - class ApiClient: _cfg: Config @@ -115,22 +109,6 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: flattened = dict(flatten_dict(with_fixed_bools)) return flattened - def get_oauth_token(self, auth_details: str) -> Token: - if not self._cfg.auth_type: - self._cfg.authenticate() - original_token = self._cfg.oauth_token() - headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE} - params = urlencode({ - "grant_type": JWT_BEARER_GRANT_TYPE, - "authorization_details": auth_details, - "assertion": original_token.access_token - }) - return retrieve_token(client_id=self._cfg.client_id, - client_secret=self._cfg.client_secret, - token_url=self._cfg.host + OIDC_TOKEN_PATH, - params=params, - headers=headers) - def do(self, method: str, path: str, diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index d7173ab3..02727599 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -22,26 +22,12 @@ from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, TokenCache, TokenSource) -CredentialsProvider = Callable[[], Dict[str, str]] +HeaderFactory = Callable[[], Dict[str, str]] logger = logging.getLogger('databricks.sdk') -class OAuthCredentialsProvider: - """ OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens. """ - - def __init__(self, credentials_provider: CredentialsProvider, token_provider: Callable[[], Token]): - self._credentials_provider = credentials_provider - self._token_provider = token_provider - - def __call__(self) -> Dict[str, str]: - return self._credentials_provider() - - def oauth_token(self) -> Token: - return self._token_provider() - - -class CredentialsStrategy(abc.ABC): +class CredentialsProvider(abc.ABC): """ CredentialsProvider is the protocol (call-side interface) for authenticating requests to Databricks REST APIs""" @@ -50,39 +36,20 @@ def auth_type(self) -> str: ... @abc.abstractmethod - def __call__(self, cfg: 'Config') -> CredentialsProvider: + def __call__(self, cfg: 'Config') -> HeaderFactory: ... -class OauthCredentialsStrategy(CredentialsStrategy): - """ OauthCredentialsProvider is a CredentialsProvider which - supports Oauth tokens""" - - def __init__(self, auth_type: str, headers_provider: Callable[['Config'], OAuthCredentialsProvider]): - self._headers_provider = headers_provider - self._auth_type = auth_type - - def auth_type(self) -> str: - return self._auth_type - - def __call__(self, cfg: 'Config') -> OAuthCredentialsProvider: - return self._headers_provider(cfg) - - def oauth_token(self, cfg: 'Config') -> Token: - return self._headers_provider(cfg).oauth_token() - - -def credentials_strategy(name: str, require: List[str]): +def credentials_provider(name: str, require: List[str]): """ Given the function that receives a Config and returns RequestVisitor, create CredentialsProvider with a given name and required configuration attribute names to be present for this function to be called. """ - def inner(func: Callable[['Config'], CredentialsProvider]) -> CredentialsStrategy: + def inner(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider: @functools.wraps(func) - def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]: + def wrapper(cfg: 'Config') -> Optional[HeaderFactory]: for attr in require: - getattr(cfg, attr) if not getattr(cfg, attr): return None return func(cfg) @@ -93,27 +60,8 @@ def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]: return inner -def oauth_credentials_strategy(name: str, require: List[str]): - """ Given the function that receives a Config and returns an OauthHeaderFactory, - create an OauthCredentialsProvider with a given name and required configuration - attribute names to be present for this function to be called. """ - - def inner(func: Callable[['Config'], OAuthCredentialsProvider]) -> OauthCredentialsStrategy: - - @functools.wraps(func) - def wrapper(cfg: 'Config') -> Optional[OAuthCredentialsProvider]: - for attr in require: - if not getattr(cfg, attr): - return None - return func(cfg) - - return OauthCredentialsStrategy(name, wrapper) - - return inner - - -@credentials_strategy('basic', ['host', 'username', 'password']) -def basic_auth(cfg: 'Config') -> CredentialsProvider: +@credentials_provider('basic', ['host', 'username', 'password']) +def basic_auth(cfg: 'Config') -> HeaderFactory: """ Given username and password, add base64-encoded Basic credentials """ encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode() static_credentials = {'Authorization': f'Basic {encoded}'} @@ -124,8 +72,8 @@ def inner() -> Dict[str, str]: return inner -@credentials_strategy('pat', ['host', 'token']) -def pat_auth(cfg: 'Config') -> CredentialsProvider: +@credentials_provider('pat', ['host', 'token']) +def pat_auth(cfg: 'Config') -> HeaderFactory: """ Adds Databricks Personal Access Token to every request """ static_credentials = {'Authorization': f'Bearer {cfg.token}'} @@ -135,8 +83,8 @@ def inner() -> Dict[str, str]: return inner -@credentials_strategy('runtime', []) -def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]: +@credentials_provider('runtime', []) +def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]: if 'DATABRICKS_RUNTIME_VERSION' not in os.environ: return None @@ -159,8 +107,8 @@ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]: return None -@oauth_credentials_strategy('oauth-m2m', ['host', 'client_id', 'client_secret']) -def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]: +@credentials_provider('oauth-m2m', ['host', 'client_id', 'client_secret']) +def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]: """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request, if /oidc/.well-known/oauth-authorization-server is available on the given host. """ oidc = cfg.oidc_endpoints @@ -176,11 +124,11 @@ def inner() -> Dict[str, str]: token = token_source.token() return {'Authorization': f'{token.token_type} {token.access_token}'} - return OAuthCredentialsProvider(inner, token_source.token) + return inner -@credentials_strategy('external-browser', ['host', 'auth_type']) -def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: +@credentials_provider('external-browser', ['host', 'auth_type']) +def external_browser(cfg: 'Config') -> Optional[HeaderFactory]: if cfg.auth_type != 'external-browser': return None if cfg.client_id: @@ -230,9 +178,9 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" -@oauth_credentials_strategy('azure-client-secret', - ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) -def azure_service_principal(cfg: 'Config') -> CredentialsProvider: +@credentials_provider('azure-client-secret', + ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) +def azure_service_principal(cfg: 'Config') -> HeaderFactory: """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, while automatically resolving different Azure environment endpoints. """ @@ -255,11 +203,11 @@ def refreshed_headers() -> Dict[str, str]: add_sp_management_token(cloud, headers) return headers - return OAuthCredentialsProvider(refreshed_headers, inner.token) + return refreshed_headers -@oauth_credentials_strategy('github-oidc-azure', ['host', 'azure_client_id']) -def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]: +@credentials_provider('github-oidc-azure', ['host', 'azure_client_id']) +def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]: if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ: # not in GitHub actions return None @@ -302,14 +250,14 @@ def refreshed_headers() -> Dict[str, str]: token = inner.token() return {'Authorization': f'{token.token_type} {token.access_token}'} - return OAuthCredentialsProvider(refreshed_headers, inner.token) + return refreshed_headers GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"] -@oauth_credentials_strategy('google-credentials', ['host', 'google_credentials']) -def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]: +@credentials_provider('google-credentials', ['host', 'google_credentials']) +def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]: if not cfg.is_gcp: return None # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string. @@ -329,10 +277,6 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]: gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes) - def token() -> Token: - credentials.refresh(request) - return credentials.token - def refreshed_headers() -> Dict[str, str]: credentials.refresh(request) headers = {'Authorization': f'Bearer {credentials.token}'} @@ -341,11 +285,11 @@ def refreshed_headers() -> Dict[str, str]: headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token return headers - return OAuthCredentialsProvider(refreshed_headers, token) + return refreshed_headers -@oauth_credentials_strategy('google-id', ['host', 'google_service_account']) -def google_id(cfg: 'Config') -> Optional[CredentialsProvider]: +@credentials_provider('google-id', ['host', 'google_service_account']) +def google_id(cfg: 'Config') -> Optional[HeaderFactory]: if not cfg.is_gcp: return None credentials, _project_id = google.auth.default() @@ -365,10 +309,6 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]: request = Request() - def token() -> Token: - id_creds.refresh(request) - return id_creds.token - def refreshed_headers() -> Dict[str, str]: id_creds.refresh(request) headers = {'Authorization': f'Bearer {id_creds.token}'} @@ -377,7 +317,7 @@ def refreshed_headers() -> Dict[str, str]: headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token return headers - return OAuthCredentialsProvider(refreshed_headers, token) + return refreshed_headers class CliTokenSource(Refreshable): @@ -482,8 +422,8 @@ def get_subscription(cfg: 'Config') -> str: return components[2] -@credentials_strategy('azure-cli', ['is_azure']) -def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]: +@credentials_provider('azure-cli', ['is_azure']) +def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]: """ Adds refreshed OAuth token granted by `az login` command to every request. """ token_source = None mgmt_token_source = None @@ -576,8 +516,8 @@ def _find_executable(name) -> str: raise err -@oauth_credentials_strategy('databricks-cli', ['host']) -def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]: +@credentials_provider('databricks-cli', ['host']) +def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]: try: token_source = DatabricksCliTokenSource(cfg) except FileNotFoundError as e: @@ -598,7 +538,7 @@ def inner() -> Dict[str, str]: token = token_source.token() return {'Authorization': f'{token.token_type} {token.access_token}'} - return OAuthCredentialsProvider(inner, token_source.token) + return inner class MetadataServiceTokenSource(Refreshable): @@ -637,8 +577,8 @@ def refresh(self) -> Token: return Token(access_token=access_token, token_type=token_type, expiry=expiry) -@credentials_strategy('metadata-service', ['host', 'metadata_service_url']) -def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]: +@credentials_provider('metadata-service', ['host', 'metadata_service_url']) +def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]: """ Adds refreshed token granted by Databricks Metadata Service to every request. """ token_source = MetadataServiceTokenSource(cfg) @@ -657,25 +597,17 @@ class DefaultCredentials: def __init__(self) -> None: self._auth_type = 'default' - self._auth_providers = [ - pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal, - github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth, - google_credentials, google_id - ] def auth_type(self) -> str: return self._auth_type - def oauth_token(self, cfg: 'Config') -> Token: - for provider in self._auth_providers: - auth_type = provider.auth_type() - if auth_type != self._auth_type: - # ignore other auth types if they don't match the selected one - continue - return provider.oauth_token(cfg) - - def __call__(self, cfg: 'Config') -> CredentialsProvider: - for provider in self._auth_providers: + def __call__(self, cfg: 'Config') -> HeaderFactory: + auth_providers = [ + pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal, + github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth, + google_credentials, google_id + ] + for provider in auth_providers: auth_type = provider.auth_type() if cfg.auth_type and auth_type != cfg.auth_type: # ignore other auth types if one is explicitly enforced diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 38da6c03..07bc5551 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -21,10 +21,6 @@ # See https://stackoverflow.com/a/75466778/277035 for more info NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327' -URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" -JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" -OIDC_TOKEN_PATH = "/oidc/v1/token" - logger = logging.getLogger(__name__) @@ -362,13 +358,13 @@ def __init__(self, client_secret: str = None): # TODO: is it a circular dependency?.. from .core import Config - from .credentials_provider import credentials_strategy + from .credentials_provider import credentials_provider - @credentials_strategy('noop', []) + @credentials_provider('noop', []) def noop_credentials(_: any): return lambda: {} - config = Config(host=host, credentials_strategy=noop_credentials) + config = Config(host=host, credentials_provider=noop_credentials) if not scopes: scopes = ['all-apis'] if config.is_azure: diff --git a/databricks/sdk/service/sql.py b/databricks/sdk/service/sql.py index 557b53e3..bfef44af 100755 --- a/databricks/sdk/service/sql.py +++ b/databricks/sdk/service/sql.py @@ -360,6 +360,7 @@ def from_dict(cls, d: Dict[str, any]) -> ChannelInfo: class ChannelName(Enum): + """Name of the channel""" CHANNEL_NAME_CURRENT = 'CHANNEL_NAME_CURRENT' CHANNEL_NAME_CUSTOM = 'CHANNEL_NAME_CUSTOM' diff --git a/docs/dbdataclasses/sql.rst b/docs/dbdataclasses/sql.rst index adf3ced5..fe1469a3 100644 --- a/docs/dbdataclasses/sql.rst +++ b/docs/dbdataclasses/sql.rst @@ -64,6 +64,8 @@ These dataclasses are used in the SDK to represent API requests and responses fo .. py:class:: ChannelName + Name of the channel + .. py:attribute:: CHANNEL_NAME_CURRENT :value: "CHANNEL_NAME_CURRENT" diff --git a/docs/gen-client-docs.py b/docs/gen-client-docs.py index 79396152..5322a25b 100644 --- a/docs/gen-client-docs.py +++ b/docs/gen-client-docs.py @@ -13,7 +13,7 @@ from typing import Optional, Any, get_args from databricks.sdk import AccountClient, WorkspaceClient -from databricks.sdk.core import credentials_strategy +from databricks.sdk.core import credentials_provider __dir__ = os.path.dirname(__file__) __examples__ = Path(f'{__dir__}/../examples').absolute() @@ -451,7 +451,7 @@ def _write_client_package_doc(self, folder: str, pkg: Package, services: list[st if __name__ == '__main__': - @credentials_strategy('noop', []) + @credentials_provider('noop', []) def noop_credentials(_: any): return lambda: {} diff --git a/examples/custom_auth.py b/examples/custom_auth.py index 4d335164..02747011 100644 --- a/examples/custom_auth.py +++ b/examples/custom_auth.py @@ -1,10 +1,10 @@ -from databricks.sdk.core import (ApiClient, Config, CredentialsProvider, - credentials_strategy) +from databricks.sdk.core import (ApiClient, Config, HeaderFactory, + credentials_provider) from databricks.sdk.service.iam import CurrentUserAPI -@credentials_strategy("custom", ["host"]) -def user_input_token(cfg: Config) -> CredentialsProvider: +@credentials_provider("custom", ["host"]) +def user_input_token(cfg: Config) -> HeaderFactory: pat = input("Enter Databricks PAT: ") def inner() -> dict[str, str]: diff --git a/tests/conftest.py b/tests/conftest.py index ef740662..80753ae9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,17 +5,17 @@ from pyfakefs.fake_filesystem_unittest import Patcher from databricks.sdk.core import Config -from databricks.sdk.credentials_provider import credentials_strategy +from databricks.sdk.credentials_provider import credentials_provider -@credentials_strategy('noop', []) +@credentials_provider('noop', []) def noop_credentials(_: any): return lambda: {} @pytest.fixture def config(): - return Config(host='http://localhost', credentials_strategy=noop_credentials) + return Config(host='http://localhost', credentials_provider=noop_credentials) @pytest.fixture diff --git a/tests/test_config.py b/tests/test_config.py index 22a4b71c..1411e5be 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,13 +4,6 @@ def test_config_copy_preserves_product_and_product_version(): - c = Config(credentials_strategy=noop_credentials, product='foo', product_version='1.2.3') - c2 = c.copy() - assert c2._product == 'foo' - assert c2._product_version == '1.2.3' - - -def test_config_supports_legacy_credentials_provider(): c = Config(credentials_provider=noop_credentials, product='foo', product_version='1.2.3') c2 = c.copy() assert c2._product == 'foo' diff --git a/tests/test_core.py b/tests/test_core.py index 74d38bd3..fd2f12ad 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -18,9 +18,8 @@ StreamingResponse) from databricks.sdk.credentials_provider import (CliTokenSource, CredentialsProvider, - CredentialsStrategy, DatabricksCliTokenSource, - databricks_cli) + HeaderFactory, databricks_cli) from databricks.sdk.environments import (ENVIRONMENTS, AzureEnvironment, Cloud, DatabricksEnvironment) from databricks.sdk.service.catalog import PermissionsChange @@ -208,7 +207,7 @@ def system(self): def test_config_copy_shallow_copies_credential_provider(): - class TestCredentialsStrategy(CredentialsStrategy): + class TestCredentialsProvider(CredentialsProvider): def __init__(self): super().__init__() @@ -217,24 +216,24 @@ def __init__(self): def auth_type(self) -> str: return "test" - def __call__(self, cfg: 'Config') -> CredentialsProvider: + def __call__(self, cfg: 'Config') -> HeaderFactory: return lambda: {"token": self._token} def refresh(self): self._token = "token2" - credentials_strategy = TestCredentialsStrategy() - config = Config(credentials_strategy=credentials_strategy) + credential_provider = TestCredentialsProvider() + config = Config(credentials_provider=credential_provider) config_copy = config.copy() assert config.authenticate()["token"] == "token1" assert config_copy.authenticate()["token"] == "token1" - credentials_strategy.refresh() + credential_provider.refresh() assert config.authenticate()["token"] == "token2" assert config_copy.authenticate()["token"] == "token2" - assert config._credentials_strategy == config_copy._credentials_strategy + assert config._credentials_provider == config_copy._credentials_provider def test_config_copy_deep_copies_user_agent_other_info(config): @@ -278,7 +277,7 @@ def __init__(self): def test_config_parsing_non_string_env_vars(monkeypatch): monkeypatch.setenv('DATABRICKS_DEBUG_TRUNCATE_BYTES', '100') - c = Config(host='http://localhost', credentials_strategy=noop_credentials) + c = Config(host='http://localhost', credentials_provider=noop_credentials) assert c.debug_truncate_bytes == 100 @@ -593,7 +592,7 @@ def inner(h: BaseHTTPRequestHandler): ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']), ('AzureChinaCloud', ENVIRONMENTS['CHINA']), ]) def test_azure_environment(azure_environment, expected): - c = Config(credentials_strategy=noop_credentials, + c = Config(credentials_provider=noop_credentials, azure_workspace_resource_id='...', azure_environment=azure_environment) assert c.arm_environment == expected diff --git a/tests/test_dbutils.py b/tests/test_dbutils.py index 15b26a54..4f80e91b 100644 --- a/tests/test_dbutils.py +++ b/tests/test_dbutils.py @@ -140,7 +140,7 @@ def assertions(): command=expect_command) dbutils = RemoteDbUtils( - Config(host='http://localhost', cluster_id='x', credentials_strategy=noop_credentials)) + Config(host='http://localhost', cluster_id='x', credentials_provider=noop_credentials)) return dbutils, assertions return inner