[feat] OAuth support (#307) (#327)
Signed-off-by: Andre Furlan <andre.furlan@databricks.com>
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
Co-authored-by: Jesse <jesse.whitehouse@databricks.com>
andrefurlan-db and susodapop committed May 2, 2023
1 parent 418401f commit f1671ed
Showing 12 changed files with 407 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -16,4 +16,4 @@ test.env
.vscode
*.log
logs/
.venv
.venv*
78 changes: 78 additions & 0 deletions dbt/adapters/databricks/auth.py
@@ -0,0 +1,78 @@
from typing import Any, Dict, Optional
from databricks.sdk.oauth import ClientCredentials, Token, TokenSource
from databricks.sdk.core import CredentialsProvider, HeaderFactory, Config, credentials_provider


class token_auth(CredentialsProvider):
_token: str

def __init__(self, token: str) -> None:
self._token = token

def auth_type(self) -> str:
return "token"

def as_dict(self) -> dict:
return {"token": self._token}

@staticmethod
def from_dict(raw: Optional[dict]) -> CredentialsProvider:
if not raw:
return None
return token_auth(raw["token"])

def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory:
static_credentials = {"Authorization": f"Bearer {self._token}"}

def inner() -> Dict[str, str]:
return static_credentials

return inner


class m2m_auth(CredentialsProvider):
_token_source: TokenSource = None

def __init__(self, host: str, client_id: str, client_secret: str) -> None:
@credentials_provider("noop", [])
def noop_credentials(_: Any): # type: ignore
return lambda: {}

config = Config(host=host, credentials_provider=noop_credentials)
oidc = config.oidc_endpoints
scopes = ["offline_access", "all-apis"]
if not oidc:
raise ValueError(f"{host} does not support OAuth")
if config.is_azure:
# Azure AD only supports full access to Azure Databricks.
scopes = [f"{config.effective_azure_login_app_id}/.default", "offline_access"]
self._token_source = ClientCredentials(
client_id=client_id,
client_secret=client_secret,
token_url=oidc.token_endpoint,
scopes=scopes,
use_header="microsoft" not in oidc.token_endpoint,
use_params="microsoft" in oidc.token_endpoint,
)

def auth_type(self) -> str:
return "oauth"

def as_dict(self) -> dict:
if self._token_source:
return {"token": self._token_source.token().as_dict()}
else:
return {"token": {}}

@staticmethod
def from_dict(host: str, client_id: str, client_secret: str, raw: dict) -> CredentialsProvider:
c = m2m_auth(host=host, client_id=client_id, client_secret=client_secret)
c._token_source._token = Token.from_dict(raw["token"])
return c

def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory:
def inner() -> Dict[str, str]:
token = self._token_source.token()
return {"Authorization": f"{token.token_type} {token.access_token}"}

return inner
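
For orientation, here is a minimal usage sketch (not part of the commit) of how these providers are consumed: calling a provider yields a `HeaderFactory`, and calling that factory yields the HTTP headers attached to each request. The token value below is a placeholder.

```python
# Illustrative only: exercising the token_auth provider defined above.
from dbt.adapters.databricks.auth import token_auth

provider = token_auth("dapi-example-token")  # placeholder, not a real credential
header_factory = provider()   # CredentialsProvider.__call__ returns a HeaderFactory
headers = header_factory()    # the factory returns the headers for each request
print(headers)                # {'Authorization': 'Bearer dapi-example-token'}
```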
133 changes: 130 additions & 3 deletions dbt/adapters/databricks/connections.py
@@ -6,6 +6,7 @@
import os
import re
import sys
import threading
import time
from typing import (
Any,
@@ -49,6 +50,12 @@
from dbt.adapters.databricks.__version__ import version as __version__
from dbt.adapters.databricks.utils import redact_credentials

from databricks.sdk.core import CredentialsProvider
from databricks.sdk.oauth import OAuthClient, RefreshableCredentials
from dbt.adapters.databricks.auth import token_auth, m2m_auth

import keyring

logger = AdapterLogger("Databricks")

CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
@@ -58,20 +65,30 @@
EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS"

REDIRECT_URL = "http://localhost:8020"
CLIENT_ID = "dbt-databricks"
SCOPES = ["all-apis", "offline_access"]


@dataclass
class DatabricksCredentials(Credentials):
database: Optional[str] # type: ignore[assignment]
host: Optional[str] = None
http_path: Optional[str] = None
token: Optional[str] = None
client_id: Optional[str] = None
client_secret: Optional[str] = None
session_properties: Optional[Dict[str, Any]] = None
connection_parameters: Optional[Dict[str, Any]] = None
auth_type: Optional[str] = None

connect_retries: int = 1
connect_timeout: Optional[int] = None
retry_all: bool = False

_credentials_provider: Optional[Dict[str, Any]] = None
_lock = threading.Lock() # to avoid concurrent auth

_ALIASES = {
"catalog": "database",
"target_catalog": "target_database",
@@ -116,6 +133,8 @@ def __post_init__(self) -> None:
"server_hostname",
"http_path",
"access_token",
"client_id",
"client_secret",
"session_configuration",
"catalog",
"schema",
@@ -138,11 +157,23 @@ def __post_init__(self) -> None:
self.connection_parameters = connection_parameters

def validate_creds(self) -> None:
for key in ["host", "http_path", "token"]:
for key in ["host", "http_path"]:
if not getattr(self, key):
raise dbt.exceptions.DbtProfileError(
"The config '{}' is required to connect to Databricks".format(key)
)
if not self.token and self.auth_type != "oauth":
raise dbt.exceptions.DbtProfileError(
("The config `auth_type: oauth` is required when not using access token")
)

if not self.client_id and self.client_secret:
raise dbt.exceptions.DbtProfileError(
(
"The config 'client_id' is required to connect "
"to Databricks when 'client_secret' is present"
)
)

@classmethod
def get_invocation_env(cls) -> Optional[str]:
@@ -232,6 +263,100 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]:
def cluster_id(self) -> Optional[str]:
return self.extract_cluster_id(self.http_path) # type: ignore[arg-type]

def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
self.validate_creds()
host: str = self.host or ""
if self._credentials_provider:
return self._provider_from_dict()
if in_provider:
self._credentials_provider = in_provider.as_dict()
return in_provider

        # dbt may spin up multiple threads; authentication has to be synchronized, so lock here
self._lock.acquire()
try:
if self.token:
provider = token_auth(self.token)
self._credentials_provider = provider.as_dict()
return provider

if self.client_id and self.client_secret:
provider = m2m_auth(
host=host,
client_id=self.client_id or "",
client_secret=self.client_secret or "",
)
self._credentials_provider = provider.as_dict()
return provider

oauth_client = OAuthClient(
host=host,
client_id=self.client_id if self.client_id else CLIENT_ID,
client_secret=None,
redirect_url=REDIRECT_URL,
scopes=SCOPES,
)
            # Optional branch: try the cached token and keep going if it does not work
try:
# try to get cached credentials
credsdict = keyring.get_password("dbt-databricks", host)

if credsdict:
provider = RefreshableCredentials.from_dict(oauth_client, json.loads(credsdict))
# if refresh token is expired, this will throw
try:
if provider.token().valid:
return provider
except Exception as e:
logger.debug(e)
# whatever it is, get rid of the cache
keyring.delete_password("dbt-databricks", host)

            # error with keyring; maybe the machine has no password persistence
except Exception as e:
logger.debug(e)
logger.info("could not retrieved saved token")

# no token, go fetch one
consent = oauth_client.initiate_consent()

provider = consent.launch_external_browser()
# save for later
self._credentials_provider = provider.as_dict()
try:
keyring.set_password("dbt-databricks", host, json.dumps(self._credentials_provider))
            # error with keyring; maybe the machine has no password persistence
except Exception as e:
logger.debug(e)
logger.info("could not save token")

return provider

finally:
self._lock.release()

def _provider_from_dict(self) -> CredentialsProvider:
if self.token:
return token_auth.from_dict(self._credentials_provider)

if self.client_id and self.client_secret:
return m2m_auth.from_dict(
host=self.host or "",
client_id=self.client_id or "",
client_secret=self.client_secret or "",
raw=self._credentials_provider or {"token": {}},
)

oauth_client = OAuthClient(
host=self.host,
client_id=CLIENT_ID,
client_secret=None,
redirect_url=REDIRECT_URL,
scopes=SCOPES,
)

return RefreshableCredentials.from_dict(client=oauth_client, raw=self._credentials_provider)


class DatabricksSQLConnectionWrapper:
"""Wrap a Databricks SQL connector in a way that no-ops transactions"""
@@ -404,6 +529,7 @@ def _get_comment_macro(self) -> Optional[str]:

class DatabricksConnectionManager(SparkConnectionManager):
TYPE: str = "databricks"
credentials_provider: CredentialsProvider = None

def compare_dbr_version(self, major: int, minor: int) -> int:
version = (major, minor)
@@ -549,7 +675,8 @@ def open(cls, connection: Connection) -> Connection:
creds: DatabricksCredentials = connection.credentials
timeout = creds.connect_timeout

creds.validate_creds()
        # keep the provider cached on the class so we don't prompt users multiple times
cls.credentials_provider = creds.authenticate(cls.credentials_provider)

user_agent_entry = f"dbt-databricks/{__version__}"

@@ -569,7 +696,7 @@ def connect() -> DatabricksSQLConnectionWrapper:
conn: DatabricksSQLConnection = dbsql.connect(
server_hostname=creds.host,
http_path=creds.http_path,
access_token=creds.token,
credentials_provider=cls.credentials_provider,
http_headers=http_headers if http_headers else None,
session_configuration=creds.session_properties,
catalog=creds.database,
7 changes: 6 additions & 1 deletion dbt/adapters/databricks/python_submissions.py
@@ -12,6 +12,7 @@
import dbt.exceptions
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.spark import __version__
from databricks.sdk.core import CredentialsProvider

logger = AdapterLogger("Databricks")

@@ -381,6 +382,7 @@ def submit(self, compiled_code: str) -> None:

class DbtDatabricksBasePythonJobHelper(BaseDatabricksHelper):
credentials: DatabricksCredentials # type: ignore[assignment]
_credentials_provider: CredentialsProvider = None

def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
super().__init__(
@@ -400,8 +402,11 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> No
http_headers: Dict[str, str] = credentials.get_all_http_headers(
connection_parameters.pop("http_headers", {})
)
self._credentials_provider = credentials.authenticate(self._credentials_provider)
header_factory = self._credentials_provider()
headers = header_factory()

self.auth_header.update({"User-Agent": user_agent, **http_headers})
self.auth_header.update({"User-Agent": user_agent, **http_headers, **headers})

@property
def cluster_id(self) -> Optional[str]: # type: ignore[override]
8 changes: 5 additions & 3 deletions dbt/include/databricks/profile_template.yml
@@ -5,9 +5,11 @@ prompts:
hint: yourorg.databricks.com
http_path:
hint: 'HTTP Path'
token:
hint: 'dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'
hide_input: true
_choose_access_token:
'use access token':
token:
hint: 'dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'
hide_input: true
_choose_unity_catalog:
'use Unity Catalog':
catalog:
4 changes: 2 additions & 2 deletions dev-requirements.txt
@@ -23,5 +23,5 @@ tox>=3.2.0
types-requests

dbt-spark==1.4.*
dbt-core==1.4.*
dbt-tests-adapter==1.4.*
# dbt-core==1.4.*
dbt-tests-adapter>=1.4.0
26 changes: 26 additions & 0 deletions docs/oauth.md
@@ -0,0 +1,26 @@
# Configure OAuth for dbt-databricks

This feature is in [Public Preview](https://docs.databricks.com/release-notes/release-types.html).

The dbt-databricks adapter now supports authentication via OAuth on AWS and Azure. This is a much safer method: it lets you generate short-lived (one-hour) OAuth access tokens, which eliminates the risk of accidentally exposing longer-lived credentials, such as Databricks personal access tokens, through version-control check-ins or other means. OAuth also enables better server-side session invalidation and scoping.

Once an admin has configured OAuth in Databricks, you can simply add the config `auth_type` and set it to `oauth`. The `token` config is no longer necessary.

For Azure, your admin needs to create a public Azure AD application for dbt and provide you with its `client_id`.

``` YAML
jaffle_shop:
outputs:
dev:
host: <databricks host name>
http_path: <http path for warehouse or cluster>
catalog: <UC catalog name>
schema: <schema name>
auth_type: oauth # new
client_id: <azure application ID> # only necessary for Azure
type: databricks
target: dev
```
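
The commit also adds `client_id` and `client_secret` fields for machine-to-machine (service principal) authentication. The sketch below is a hypothetical profile illustrating those fields with placeholder values; it is based on the credential fields added in this commit, not on officially documented examples.

```yaml
jaffle_shop:
  outputs:
    dev:
      type: databricks
      host: <databricks host name>
      http_path: <http path for warehouse or cluster>
      catalog: <UC catalog name>
      schema: <schema name>
      auth_type: oauth
      client_id: <service principal application ID>    # placeholder
      client_secret: <service principal OAuth secret>  # placeholder
  target: dev
```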



4 changes: 3 additions & 1 deletion requirements.txt
@@ -1,2 +1,4 @@
databricks-sql-connector>=2.5.0
dbt-spark==1.4.*
dbt-spark>=1.4.0
databricks-sdk>=0.1.1
keyring>=23.13.0
2 changes: 2 additions & 0 deletions setup.py
@@ -57,6 +57,8 @@ def _get_plugin_version():
install_requires=[
"dbt-spark~={}".format(dbt_spark_version),
"databricks-sql-connector>=2.5.0",
"databricks-sdk>=0.1.1",
"keyring>=23.13.0"
],
zip_safe=False,
classifiers=[