[feat] OAuth support (#307) #327

Merged 2 commits on May 1, 2023
2 changes: 1 addition & 1 deletion .gitignore
@@ -16,4 +16,4 @@ test.env
.vscode
*.log
logs/
.venv
.venv*
78 changes: 78 additions & 0 deletions dbt/adapters/databricks/auth.py
@@ -0,0 +1,78 @@
from typing import Any, Dict, Optional
from databricks.sdk.oauth import ClientCredentials, Token, TokenSource
from databricks.sdk.core import CredentialsProvider, HeaderFactory, Config, credentials_provider


class token_auth(CredentialsProvider):
_token: str

def __init__(self, token: str) -> None:
self._token = token

def auth_type(self) -> str:
return "token"

def as_dict(self) -> dict:
return {"token": self._token}

@staticmethod
def from_dict(raw: Optional[dict]) -> CredentialsProvider:
if not raw:
return None
return token_auth(raw["token"])

def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory:
static_credentials = {"Authorization": f"Bearer {self._token}"}

def inner() -> Dict[str, str]:
return static_credentials

return inner


class m2m_auth(CredentialsProvider):
_token_source: TokenSource = None

def __init__(self, host: str, client_id: str, client_secret: str) -> None:
@credentials_provider("noop", [])
def noop_credentials(_: Any): # type: ignore
return lambda: {}

config = Config(host=host, credentials_provider=noop_credentials)
oidc = config.oidc_endpoints
scopes = ["offline_access", "all-apis"]
if not oidc:
raise ValueError(f"{host} does not support OAuth")
if config.is_azure:
# Azure AD only supports full access to Azure Databricks.
scopes = [f"{config.effective_azure_login_app_id}/.default", "offline_access"]
self._token_source = ClientCredentials(
client_id=client_id,
client_secret=client_secret,
token_url=oidc.token_endpoint,
scopes=scopes,
use_header="microsoft" not in oidc.token_endpoint,
use_params="microsoft" in oidc.token_endpoint,
)

def auth_type(self) -> str:
return "oauth"

def as_dict(self) -> dict:
if self._token_source:
return {"token": self._token_source.token().as_dict()}
else:
return {"token": {}}

@staticmethod
def from_dict(host: str, client_id: str, client_secret: str, raw: dict) -> CredentialsProvider:
c = m2m_auth(host=host, client_id=client_id, client_secret=client_secret)
c._token_source._token = Token.from_dict(raw["token"])
return c

def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory:
def inner() -> Dict[str, str]:
token = self._token_source.token()
return {"Authorization": f"{token.token_type} {token.access_token}"}

return inner
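
Both providers are callables that return a `HeaderFactory`; calling that factory yields the HTTP headers to attach to each request. A minimal sketch of how they could be exercised on their own (all tokens, hosts, and client credentials below are placeholders, and the `m2m_auth` part only works against a reachable workspace with a real service principal):

```python
# Sketch only: placeholder credentials, not part of this PR.
from dbt.adapters.databricks.auth import m2m_auth, token_auth

# Personal access token flow: the headers are static.
pat = token_auth("dapiXXXXXXXXXXXX")        # hypothetical token
headers = pat()()                            # provider -> HeaderFactory -> headers
assert headers == {"Authorization": "Bearer dapiXXXXXXXXXXXX"}

# as_dict/from_dict round trip, which the adapter uses to cache credentials.
restored = token_auth.from_dict(pat.as_dict())
assert restored.auth_type() == "token"

# Machine-to-machine flow: ClientCredentials fetches and refreshes the OAuth
# token on demand. Constructing it hits the workspace's OIDC endpoints, so the
# values below must be real for this to run.
m2m = m2m_auth(
    host="https://example.cloud.databricks.com",
    client_id="<service principal application id>",
    client_secret="<service principal secret>",
)
print(m2m()())  # {"Authorization": "Bearer <short-lived access token>"}
```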
133 changes: 130 additions & 3 deletions dbt/adapters/databricks/connections.py
@@ -6,6 +6,7 @@
import os
import re
import sys
import threading
import time
from typing import (
Any,
@@ -49,6 +50,12 @@
from dbt.adapters.databricks.__version__ import version as __version__
from dbt.adapters.databricks.utils import redact_credentials

from databricks.sdk.core import CredentialsProvider
from databricks.sdk.oauth import OAuthClient, RefreshableCredentials
from dbt.adapters.databricks.auth import token_auth, m2m_auth

import keyring

logger = AdapterLogger("Databricks")

CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
@@ -58,20 +65,30 @@
EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS"

REDIRECT_URL = "http://localhost:8020"
CLIENT_ID = "dbt-databricks"
SCOPES = ["all-apis", "offline_access"]


@dataclass
class DatabricksCredentials(Credentials):
database: Optional[str] # type: ignore[assignment]
host: Optional[str] = None
http_path: Optional[str] = None
token: Optional[str] = None
client_id: Optional[str] = None
client_secret: Optional[str] = None
session_properties: Optional[Dict[str, Any]] = None
connection_parameters: Optional[Dict[str, Any]] = None
auth_type: Optional[str] = None

connect_retries: int = 1
connect_timeout: Optional[int] = None
retry_all: bool = False

_credentials_provider: Optional[Dict[str, Any]] = None
_lock = threading.Lock() # to avoid concurrent auth

_ALIASES = {
"catalog": "database",
"target_catalog": "target_database",
@@ -116,6 +133,8 @@ def __post_init__(self) -> None:
"server_hostname",
"http_path",
"access_token",
"client_id",
"client_secret",
"session_configuration",
"catalog",
"schema",
@@ -138,11 +157,23 @@ def __post_init__(self) -> None:
self.connection_parameters = connection_parameters

def validate_creds(self) -> None:
for key in ["host", "http_path", "token"]:
for key in ["host", "http_path"]:
if not getattr(self, key):
raise dbt.exceptions.DbtProfileError(
"The config '{}' is required to connect to Databricks".format(key)
)
if not self.token and self.auth_type != "oauth":
raise dbt.exceptions.DbtProfileError(
("The config `auth_type: oauth` is required when not using access token")
)

if not self.client_id and self.client_secret:
raise dbt.exceptions.DbtProfileError(
(
"The config 'client_id' is required to connect "
"to Databricks when 'client_secret' is present"
)
)

@classmethod
def get_invocation_env(cls) -> Optional[str]:
@@ -232,6 +263,100 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]:
def cluster_id(self) -> Optional[str]:
return self.extract_cluster_id(self.http_path) # type: ignore[arg-type]

def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
self.validate_creds()
host: str = self.host or ""
if self._credentials_provider:
return self._provider_from_dict()
if in_provider:
self._credentials_provider = in_provider.as_dict()
return in_provider

# dbt will spin up multiple threads; authentication must be synchronized, so take a lock here
self._lock.acquire()
try:
if self.token:
provider = token_auth(self.token)
self._credentials_provider = provider.as_dict()
return provider

if self.client_id and self.client_secret:
provider = m2m_auth(
host=host,
client_id=self.client_id or "",
client_secret=self.client_secret or "",
)
self._credentials_provider = provider.as_dict()
return provider

oauth_client = OAuthClient(
host=host,
client_id=self.client_id if self.client_id else CLIENT_ID,
client_secret=None,
redirect_url=REDIRECT_URL,
scopes=SCOPES,
)
# optional branch. Try and keep going if it does not work
try:
# try to get cached credentials
credsdict = keyring.get_password("dbt-databricks", host)

if credsdict:
provider = RefreshableCredentials.from_dict(oauth_client, json.loads(credsdict))
# if refresh token is expired, this will throw
try:
if provider.token().valid:
return provider
except Exception as e:
logger.debug(e)
# whatever it is, get rid of the cache
keyring.delete_password("dbt-databricks", host)

# error with keyring; maybe the machine has no password persistence
except Exception as e:
logger.debug(e)
logger.info("could not retrieved saved token")

# no token, go fetch one
consent = oauth_client.initiate_consent()

provider = consent.launch_external_browser()
# save for later
self._credentials_provider = provider.as_dict()
try:
keyring.set_password("dbt-databricks", host, json.dumps(self._credentials_provider))
# error with keyring; maybe the machine has no password persistence
except Exception as e:
logger.debug(e)
logger.info("could not save token")

return provider

finally:
self._lock.release()

def _provider_from_dict(self) -> CredentialsProvider:
if self.token:
return token_auth.from_dict(self._credentials_provider)

if self.client_id and self.client_secret:
return m2m_auth.from_dict(
host=self.host or "",
client_id=self.client_id or "",
client_secret=self.client_secret or "",
raw=self._credentials_provider or {"token": {}},
)

oauth_client = OAuthClient(
host=self.host,
client_id=CLIENT_ID,
client_secret=None,
redirect_url=REDIRECT_URL,
scopes=SCOPES,
)

return RefreshableCredentials.from_dict(client=oauth_client, raw=self._credentials_provider)


class DatabricksSQLConnectionWrapper:
"""Wrap a Databricks SQL connector in a way that no-ops transactions"""
@@ -404,6 +529,7 @@ def _get_comment_macro(self) -> Optional[str]:

class DatabricksConnectionManager(SparkConnectionManager):
TYPE: str = "databricks"
credentials_provider: CredentialsProvider = None

def compare_dbr_version(self, major: int, minor: int) -> int:
version = (major, minor)
@@ -549,7 +675,8 @@ def open(cls, connection: Connection) -> Connection:
creds: DatabricksCredentials = connection.credentials
timeout = creds.connect_timeout

creds.validate_creds()
# keep the provider around so we don't prompt users multiple times
cls.credentials_provider = creds.authenticate(cls.credentials_provider)

user_agent_entry = f"dbt-databricks/{__version__}"

@@ -569,7 +696,7 @@ def connect() -> DatabricksSQLConnectionWrapper:
conn: DatabricksSQLConnection = dbsql.connect(
server_hostname=creds.host,
http_path=creds.http_path,
access_token=creds.token,
credentials_provider=cls.credentials_provider,
http_headers=http_headers if http_headers else None,
session_configuration=creds.session_properties,
catalog=creds.database,
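
Note that `authenticate()` memoizes the chosen provider as a plain dict on `_credentials_provider`, so concurrent threads rehydrate it via `_provider_from_dict()` instead of re-prompting, and the interactive browser flow additionally caches that dict in the OS keyring under the `dbt-databricks` service name. A small sketch of the keyring round trip it relies on (host and token payload are placeholders):

```python
# Sketch of the caching that authenticate() performs; all values are placeholders.
import json

import keyring

host = "https://example.cloud.databricks.com"
creds = {"token": {"access_token": "<access token>", "token_type": "Bearer"}}

# After an interactive login, the serialized provider is stored...
keyring.set_password("dbt-databricks", host, json.dumps(creds))

# ...and the next invocation loads it instead of opening a browser again.
cached = keyring.get_password("dbt-databricks", host)
if cached:
    restored = json.loads(cached)

# If the cached token can no longer be refreshed, the entry is dropped.
keyring.delete_password("dbt-databricks", host)
```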
7 changes: 6 additions & 1 deletion dbt/adapters/databricks/python_submissions.py
@@ -12,6 +12,7 @@
import dbt.exceptions
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.spark import __version__
from databricks.sdk.core import CredentialsProvider

logger = AdapterLogger("Databricks")

@@ -381,6 +382,7 @@ def submit(self, compiled_code: str) -> None:

class DbtDatabricksBasePythonJobHelper(BaseDatabricksHelper):
credentials: DatabricksCredentials # type: ignore[assignment]
_credentials_provider: CredentialsProvider = None

def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
super().__init__(
@@ -400,8 +402,11 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
http_headers: Dict[str, str] = credentials.get_all_http_headers(
connection_parameters.pop("http_headers", {})
)
self._credentials_provider = credentials.authenticate(self._credentials_provider)
header_factory = self._credentials_provider()
headers = header_factory()

self.auth_header.update({"User-Agent": user_agent, **http_headers})
self.auth_header.update({"User-Agent": user_agent, **http_headers, **headers})

@property
def cluster_id(self) -> Optional[str]: # type: ignore[override]
8 changes: 5 additions & 3 deletions dbt/include/databricks/profile_template.yml
@@ -5,9 +5,11 @@ prompts:
hint: yourorg.databricks.com
http_path:
hint: 'HTTP Path'
token:
hint: 'dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'
hide_input: true
_choose_access_token:
'use access token':
token:
hint: 'dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'
hide_input: true
_choose_unity_catalog:
'use Unity Catalog':
catalog:
4 changes: 2 additions & 2 deletions dev-requirements.txt
@@ -23,5 +23,5 @@ tox>=3.2.0
types-requests

dbt-spark==1.4.*
dbt-core==1.4.*
dbt-tests-adapter==1.4.*
# dbt-core==1.4.*
dbt-tests-adapter>=1.4.0
26 changes: 26 additions & 0 deletions docs/oauth.md
@@ -0,0 +1,26 @@
# Configure OAuth for DBT Databricks

This feature is in [Public Preview](https://docs.databricks.com/release-notes/release-types.html).

The Databricks dbt adapter now supports authentication via OAuth in AWS and Azure. This is a much safer method, as it lets you generate short-lived (one-hour) OAuth access tokens, eliminating the risk of accidentally exposing longer-lived tokens such as Databricks personal access tokens through version-control check-ins or other means. OAuth also enables better server-side session invalidation and scoping.

Once an admin has correctly configured OAuth in Databricks, you can simply add the config `auth_type` and set it to `oauth`. The `token` config is no longer necessary.

For Azure, your admin needs to create a public Azure AD application for dbt and provide you with its `client_id`.

``` YAML
jaffle_shop:
outputs:
dev:
host: <databricks host name>
http_path: <http path for warehouse or cluster>
catalog: <UC catalog name>
schema: <schema name>
auth_type: oauth # new
client_id: <azure application ID> # only necessary for Azure
type: databricks
target: dev
```
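
This PR also adds `client_id`/`client_secret` fields to `DatabricksCredentials` and an `m2m_auth` provider, so a service-principal (machine-to-machine) profile should be possible as well. The docs do not spell out that shape; a sketch with placeholder values might look like:

``` YAML
# Sketch only: a hypothetical service-principal (M2M) profile.
jaffle_shop:
  outputs:
    prod:
      type: databricks
      host: <databricks host name>
      http_path: <http path for warehouse or cluster>
      catalog: <UC catalog name>
      schema: <schema name>
      auth_type: oauth
      client_id: <service principal application ID>
      client_secret: <service principal secret>
  target: prod
```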



4 changes: 3 additions & 1 deletion requirements.txt
@@ -1,2 +1,4 @@
databricks-sql-connector>=2.5.0
dbt-spark==1.4.*
dbt-spark>=1.4.0
databricks-sdk>=0.1.1
keyring>=23.13.0
2 changes: 2 additions & 0 deletions setup.py
@@ -57,6 +57,8 @@ def _get_plugin_version():
install_requires=[
"dbt-spark~={}".format(dbt_spark_version),
"databricks-sql-connector>=2.5.0",
"databricks-sdk>=0.1.1",
"keyring>=23.13.0"
],
zip_safe=False,
classifiers=[