Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer Azure tenant ID if not set #638

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 20 additions & 0 deletions databricks/sdk/azure.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import logging
from dataclasses import dataclass
from typing import Dict
from urllib import parse

import requests

from .oauth import TokenSource
from .service.provisioning import Workspace


logger = logging.getLogger(__name__)

@dataclass
class AzureEnvironment:
name: str
Expand Down Expand Up @@ -52,3 +58,17 @@ def get_azure_resource_id(workspace: Workspace):
return (f'/subscriptions/{workspace.azure_workspace_info.subscription_id}'
f'/resourceGroups/{workspace.azure_workspace_info.resource_group}'
f'/providers/Microsoft.Databricks/workspaces/{workspace.workspace_name}')


def _load_azure_tenant_id(cfg: 'Config'):
if not cfg.is_azure or cfg.azure_tenant_id is not None or cfg.host is None:
return
logging.debug(f'Loading tenant ID from {cfg.host}/aad/auth')
resp = requests.get(f'{cfg.host}/aad/auth', allow_redirects=False)
entra_id_endpoint = resp.headers.get('Location')
if entra_id_endpoint is None:
logging.debug(f'No Location header in response from {cfg.host}/aad/auth')
return
url = parse.urlparse(entra_id_endpoint)
cfg.azure_tenant_id = url.path.split('/')[1]
logging.debug(f'Loaded tenant ID: {cfg.azure_tenant_id}')
20 changes: 9 additions & 11 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from google.auth.transport.requests import Request
from google.oauth2 import service_account

from .azure import add_sp_management_token, add_workspace_id_header
from .azure import add_sp_management_token, add_workspace_id_header, _load_azure_tenant_id
from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
TokenCache, TokenSource)

Expand Down Expand Up @@ -179,11 +179,10 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS


@credentials_provider('azure-client-secret',
['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
['is_azure', 'azure_client_id', 'azure_client_secret'])
def azure_service_principal(cfg: 'Config') -> HeaderFactory:
""" Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
to every request, while automatically resolving different Azure environment endpoints. """

def token_source_for(resource: str) -> TokenSource:
aad_endpoint = cfg.arm_environment.active_directory_endpoint
return ClientCredentials(client_id=cfg.azure_client_id,
Expand All @@ -192,6 +191,7 @@ def token_source_for(resource: str) -> TokenSource:
endpoint_params={"resource": resource},
use_params=True)

_load_azure_tenant_id(cfg)
_ensure_host_present(cfg, token_source_for)
logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id)
inner = token_source_for(cfg.effective_azure_login_app_id)
Expand Down Expand Up @@ -363,11 +363,13 @@ def refresh(self) -> Token:
class AzureCliTokenSource(CliTokenSource):
""" Obtain the token granted by `az login` CLI command """

def __init__(self, resource: str, subscription: str = ""):
def __init__(self, resource: str, subscription: str = "", tenant: str = None):
cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
if subscription != "":
cmd.append("--subscription")
cmd.append(subscription)
if tenant:
cmd.extend(["--tenant", tenant])
super().__init__(cmd=cmd,
token_type_field='tokenType',
access_token_field='accessToken',
Expand Down Expand Up @@ -395,7 +397,7 @@ def is_human_user(self) -> bool:
@staticmethod
def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
subscription = AzureCliTokenSource.get_subscription(cfg)
if subscription != "":
if cfg.azure_tenant_id == "" and subscription != "":
token_source = AzureCliTokenSource(resource, subscription)
try:
# This will fail if the user has access to the workspace, but not to the subscription
Expand All @@ -406,7 +408,7 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
except OSError:
logger.warning("Failed to get token for subscription. Using resource only token.")

token_source = AzureCliTokenSource(resource)
token_source = AzureCliTokenSource(resource, cfg.azure_tenant_id)
token_source.token()
return token_source

Expand All @@ -425,6 +427,7 @@ def get_subscription(cfg: 'Config') -> str:
@credentials_provider('azure-cli', ['is_azure'])
def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:
""" Adds refreshed OAuth token granted by `az login` command to every request. """
_load_azure_tenant_id(cfg)
token_source = None
mgmt_token_source = None
try:
Expand All @@ -448,11 +451,6 @@ def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:

_ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
logger.info("Using Azure CLI authentication with AAD tokens")
if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "":
logger.warning(
"azure_workspace_resource_id field not provided. "
"It is recommended to specify this field in the Databricks configuration to avoid authentication errors."
)

def inner() -> Dict[str, str]:
token = token_source.token()
Expand Down
20 changes: 20 additions & 0 deletions tests/test_azure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from databricks.sdk.config import Config
import os

__tests__ = os.path.dirname(__file__)


def test_load_azure_tenant_id(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/abc123xyz/oauth2/authorize'})
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id == 'abc123xyz'
assert mock.called_once


def test_load_azure_tenant_id_tenant_id_set(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/abc123xyz/oauth2/authorize'})
cfg = Config(host="https://abc123.azuredatabricks.net", azure_tenant_id="123456789")
assert cfg.azure_tenant_id == '123456789'
assert mock.call_count == 0