Commit b2a77d9
[PECO-1386] Python submissions: ignore .netrc when making REST API calls (#338)

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
susodapop committed Jan 18, 2024
1 parent d03a733 commit b2a77d9
Showing 3 changed files with 55 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
```diff
@@ -6,6 +6,7 @@
 - Allow schema to be specified in testing (thanks @case-k-git!) ([538](https://github.com/databricks/dbt-databricks/pull/538))
 - Fix dbt incremental_strategy behavior by fixing schema table existing check (thanks @case-k-git!) ([530](https://github.com/databricks/dbt-databricks/pull/530))
 - Fixed bug that was causing streaming tables to be dropped and recreated instead of refreshed. ([552](https://github.com/databricks/dbt-databricks/pull/552))
+- Fix: Python models authentication could be overridden by a `.netrc` file in the user's home directory ([338](https://github.com/databricks/dbt-databricks/pull/338))

 ### Under the Hood
```
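The fix addresses a credential fallback in `requests`: when a `Session` has no explicit `auth` set (and `trust_env` is `True`, the default), the library looks up the request host in `~/.netrc` and, if it finds a matching `machine` entry, applies those credentials as HTTP Basic auth, overwriting any `Authorization` header already present. A minimal sketch of the failure mode, mirroring the old `headers=self.auth_header` pattern removed in this commit (hostname, token, and path are illustrative, not from the commit):

```python
# Illustrative reproduction of issue #337 -- all values are made up.
import requests

session = requests.Session()
# session.auth is unset, so requests calls get_netrc_auth() for the host;
# a matching "machine" entry in ~/.netrc silently replaces this Bearer
# header with Basic credentials before the request is sent.
session.post(
    "https://example.cloud.databricks.com/api/2.0/workspace/mkdirs",
    headers={"Authorization": "Bearer dapiXXXXXXXX"},
    json={"path": "/Shared/dbt_python_models"},
)
```

Setting `Session.auth` explicitly, as this commit does, short-circuits that lookup.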
20 changes: 20 additions & 0 deletions dbt/adapters/databricks/connections.py
```diff
@@ -73,6 +73,10 @@
 from databricks.sdk.oauth import OAuthClient, SessionCredentials
 from dbt.adapters.databricks.auth import token_auth, m2m_auth

+from requests.auth import AuthBase
+from requests import PreparedRequest
+from databricks.sdk.core import HeaderFactory
+
 import keyring

 logger = AdapterLogger("Databricks")
@@ -119,6 +123,22 @@ def emit(self, record: logging.LogRecord) -> None:
 DEFAULT_MAX_IDLE_TIME = 600


+class BearerAuth(AuthBase):
+    """See issue #337.
+
+    We use this mix-in to stop `requests` from implicitly reading .netrc.
+    Solution taken from the Stack Overflow post linked in the issue description.
+    """
+
+    def __init__(self, headers: HeaderFactory):
+        self.headers = headers()
+
+    def __call__(self, r: PreparedRequest) -> PreparedRequest:
+        r.headers.update(**self.headers)
+        return r
+
+
 @dataclass
 class DatabricksCredentials(Credentials):
     database: Optional[str] = None  # type: ignore[assignment]
```
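With this mix-in, auth is attached to the session as an explicit `AuthBase` callable rather than a plain header dict, so `requests` never falls back to `.netrc`. A minimal usage sketch (the factory below is a hypothetical stand-in for the `HeaderFactory` callable the Databricks SDK returns; it is not part of this commit):

```python
# Hypothetical usage sketch -- not part of this commit.
import requests

from dbt.adapters.databricks.connections import BearerAuth


def fake_header_factory() -> dict:
    # Stand-in for the SDK's HeaderFactory; a real one returns freshly
    # minted auth headers such as {"Authorization": "Bearer <token>"}.
    return {"Authorization": "Bearer dapiXXXXXXXX"}


session = requests.Session()
# Because session.auth is now set, requests skips its implicit
# get_netrc_auth() lookup entirely, so a ~/.netrc entry for the host
# can no longer override the Bearer token.
session.auth = BearerAuth(fake_header_factory)
```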
58 changes: 34 additions & 24 deletions dbt/adapters/databricks/python_submissions.py
```diff
@@ -15,16 +15,17 @@
 from dbt.events import AdapterLogger
 import dbt.exceptions
 from dbt.adapters.base import PythonJobHelper
-from dbt.adapters.spark import __version__

+from databricks.sdk.core import CredentialsProvider
 from requests.adapters import HTTPAdapter
+from dbt.adapters.databricks.connections import BearerAuth


 logger = AdapterLogger("Databricks")

 DEFAULT_POLLING_INTERVAL = 10
 SUBMISSION_LANGUAGE = "python"
 DEFAULT_TIMEOUT = 60 * 60 * 24
-DBT_SPARK_VERSION = __version__.version


 class BaseDatabricksHelper(PythonJobHelper):
@@ -43,9 +44,8 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
         self.session.mount("https://", adapter)

         self.check_credentials()
-        self.auth_header = {
-            "Authorization": f"Bearer {self.credentials.token}",
-            "User-Agent": f"dbt-labs-dbt-spark/{DBT_SPARK_VERSION} (Databricks)",
+        self.extra_headers = {
+            "User-Agent": f"dbt-databricks/{version}",
         }

     @property
@@ -66,7 +66,7 @@ def check_credentials(self) -> None:
     def _create_work_dir(self, path: str) -> None:
         response = self.session.post(
             f"https://{self.credentials.host}/api/2.0/workspace/mkdirs",
-            headers=self.auth_header,
+            headers=self.extra_headers,
             json={
                 "path": path,
             },
@@ -86,7 +86,7 @@ def _upload_notebook(self, path: str, compiled_code: str) -> None:
         b64_encoded_content = base64.b64encode(compiled_code.encode()).decode()
         response = self.session.post(
             f"https://{self.credentials.host}/api/2.0/workspace/import",
-            headers=self.auth_header,
+            headers=self.extra_headers,
             json={
                 "path": path,
                 "content": b64_encoded_content,
@@ -131,7 +131,7 @@ def _submit_job(self, path: str, cluster_spec: dict) -> str:
             job_spec.update({"libraries": libraries})  # type: ignore
         submit_response = self.session.post(
             f"https://{self.credentials.host}/api/2.1/jobs/runs/submit",
-            headers=self.auth_header,
+            headers=self.extra_headers,
             json=job_spec,
         )
         if submit_response.status_code != 200:
@@ -157,7 +157,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None:
             status_func=self.session.get,
             status_func_kwargs={
                 "url": f"https://{self.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}",
-                "headers": self.auth_header,
+                "headers": self.extra_headers,
             },
             get_state_func=lambda response: response.json()["state"]["life_cycle_state"],
             terminal_states=("TERMINATED", "SKIPPED", "INTERNAL_ERROR"),
@@ -168,7 +168,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None:
         # get end state to return to user
         run_output = self.session.get(
             f"https://{self.credentials.host}" f"/api/2.1/jobs/runs/get-output?run_id={run_id}",
-            headers=self.auth_header,
+            headers=self.extra_headers,
         )
         json_run_output = run_output.json()
         result_state = json_run_output["metadata"]["state"]["result_state"]
@@ -231,10 +231,10 @@ def __init__(
         self,
         credentials: DatabricksCredentials,
         cluster_id: str,
-        auth_header: dict,
+        extra_headers: dict,
         session: Session,
     ) -> None:
-        self.auth_header = auth_header
+        self.extra_headers = extra_headers
         self.cluster_id = cluster_id
         self.host = credentials.host
         self.session = session
@@ -253,7 +253,7 @@ def create(self) -> str:

         response = self.session.post(
             f"https://{self.host}/api/1.2/contexts/create",
-            headers=self.auth_header,
+            headers=self.extra_headers,
             json={
                 "clusterId": self.cluster_id,
                 "language": SUBMISSION_LANGUAGE,
@@ -269,7 +269,7 @@ def destroy(self, context_id: str) -> str:
         # https://docs.databricks.com/dev-tools/api/1.2/index.html#delete-an-execution-context
         response = self.session.post(
             f"https://{self.host}/api/1.2/contexts/destroy",
-            headers=self.auth_header,
+            headers=self.extra_headers,
             json={
                 "clusterId": self.cluster_id,
                 "contextId": context_id,
@@ -286,7 +286,7 @@ def get_cluster_status(self) -> Dict:

         response = self.session.get(
             f"https://{self.host}/api/2.0/clusters/get",
-            headers=self.auth_header,
+            headers=self.extra_headers,
             json={"cluster_id": self.cluster_id},
         )
         if response.status_code != 200:
@@ -309,7 +309,7 @@ def start_cluster(self) -> None:

         response = self.session.post(
             f"https://{self.host}/api/2.0/clusters/start",
-            headers=self.auth_header,
+            headers=self.extra_headers,
             json={"cluster_id": self.cluster_id},
         )
         if response.status_code != 200:
@@ -346,10 +346,10 @@ def __init__(
         self,
         credentials: DatabricksCredentials,
         cluster_id: str,
-        auth_header: dict,
+        extra_headers: dict,
         session: Session,
     ) -> None:
-        self.auth_header = auth_header
+        self.extra_headers = extra_headers
         self.cluster_id = cluster_id
         self.host = credentials.host
         self.session = session
@@ -358,7 +358,7 @@ def execute(self, context_id: str, command: str) -> str:
         # https://docs.databricks.com/dev-tools/api/1.2/index.html#run-a-command
         response = self.session.post(
             f"https://{self.host}/api/1.2/commands/execute",
-            headers=self.auth_header,
+            headers=self.extra_headers,
             json={
                 "clusterId": self.cluster_id,
                 "contextId": context_id,
@@ -377,7 +377,7 @@ def status(self, context_id: str, command_id: str) -> Dict[str, Any]:
         # https://docs.databricks.com/dev-tools/api/1.2/index.html#get-information-about-a-command
         response = self.session.get(
             f"https://{self.host}/api/1.2/commands/status",
-            headers=self.auth_header,
+            headers=self.extra_headers,
             params={
                 "clusterId": self.cluster_id,
                 "contextId": context_id,
@@ -404,8 +404,18 @@ def submit(self, compiled_code: str) -> None:
             config = {"existing_cluster_id": self.cluster_id}
             self._submit_through_notebook(compiled_code, self._update_with_acls(config))
         else:
-            context = DBContext(self.credentials, self.cluster_id, self.auth_header, self.session)
-            command = DBCommand(self.credentials, self.cluster_id, self.auth_header, self.session)
+            context = DBContext(
+                self.credentials,
+                self.cluster_id,
+                self.extra_headers,
+                self.session,
+            )
+            command = DBCommand(
+                self.credentials,
+                self.cluster_id,
+                self.extra_headers,
+                self.session,
+            )
             context_id = context.create()
             try:
                 command_id = command.execute(context_id, compiled_code)
@@ -454,9 +464,9 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
         )
         self._credentials_provider = credentials.authenticate(self._credentials_provider)
         header_factory = self._credentials_provider()
-        headers = header_factory()
+        self.session.auth = BearerAuth(header_factory)

-        self.auth_header.update({"User-Agent": user_agent, **http_headers, **headers})
+        self.extra_headers.update({"User-Agent": user_agent, **http_headers})

     @property
     def cluster_id(self) -> Optional[str]:  # type: ignore[override]
```
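Taken together, the all-purpose-cluster path now drives the Command Execution API 1.2 with a session whose `auth` is a `BearerAuth` and whose per-request headers carry only the `User-Agent`. A rough end-to-end sketch under assumed values (host, cluster id, token, and version string are all made up; response shapes follow the API 1.2 docs linked in the code):

```python
# Hypothetical end-to-end sketch of the API 1.2 flow above -- not part
# of this commit; host, cluster id, token, and version are invented.
import time

import requests

from dbt.adapters.databricks.connections import BearerAuth

host = "example.cloud.databricks.com"
cluster_id = "0118-000000-abcdef12"
extra_headers = {"User-Agent": "dbt-databricks/0.0.0"}

session = requests.Session()
# Auth lives on the session (never in extra_headers), so .netrc is ignored.
session.auth = BearerAuth(lambda: {"Authorization": "Bearer dapiXXXXXXXX"})

# 1. Create an execution context on the cluster (DBContext.create).
context_id = session.post(
    f"https://{host}/api/1.2/contexts/create",
    headers=extra_headers,
    json={"clusterId": cluster_id, "language": "python"},
).json()["id"]

try:
    # 2. Run the compiled model code in that context (DBCommand.execute).
    command_id = session.post(
        f"https://{host}/api/1.2/commands/execute",
        headers=extra_headers,
        json={"clusterId": cluster_id, "contextId": context_id, "command": "print(1)"},
    ).json()["id"]

    # 3. Poll until the command reaches a terminal state (DBCommand.status).
    while True:
        response = session.get(
            f"https://{host}/api/1.2/commands/status",
            headers=extra_headers,
            params={
                "clusterId": cluster_id,
                "contextId": context_id,
                "commandId": command_id,
            },
        )
        if response.json()["status"] in ("Finished", "Cancelled", "Error"):
            break
        time.sleep(10)  # DEFAULT_POLLING_INTERVAL
finally:
    # 4. Always tear the context down (DBContext.destroy).
    session.post(
        f"https://{host}/api/1.2/contexts/destroy",
        headers=extra_headers,
        json={"clusterId": cluster_id, "contextId": context_id},
    )
```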
