Skip to content

Commit

Permalink
Correct tenant eval within async logic of DbtCloudHook
Browse files Browse the repository at this point in the history
Related: apache#28890 apache#29014

DbtCloudRunJobOperator was recently enhanced with deferrable/async functionality. Unfortunately, the `tenant` evaluation in the async logic of DbtCloudHook was outdated and did not include the most recent change that properly handles a full domain specification.

This PR consolidates the tenant evaluation logic into a common method used by both the sync and async methods in the hook.
  • Loading branch information
josh-fell committed Feb 2, 2023
1 parent 6ec97dc commit b25c544
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 20 deletions.
41 changes: 21 additions & 20 deletions airflow/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,33 +181,46 @@ def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs)
super().__init__(auth_type=TokenAuth)
self.dbt_cloud_conn_id = dbt_cloud_conn_id

@staticmethod
def _get_tenant_domain(conn: Connection) -> str:
    """
    Resolve the dbt Cloud tenant domain name from a connection.

    :param conn: The connection whose ``schema`` and ``host`` values are inspected.
    :return: ``"<schema>.getdbt.com"`` when the deprecated ``schema`` value is set; otherwise the
        connection ``host``, falling back to ``"cloud.getdbt.com"`` when ``host`` is empty.
    """
    if not conn.schema:
        # Preferred path: `host` carries the entire tenant domain name.
        return conn.host or "cloud.getdbt.com"

    warnings.warn(
        "The `schema` parameter is deprecated and use within a dbt Cloud connection will be removed "
        "in a future version. Please use `host` instead and specify the entire tenant domain name.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    # Prior to deprecation, the connection.schema value could _only_ modify the third-level
    # domain value while '.getdbt.com' was always used as the remainder of the domain name.
    return f"{conn.schema}.getdbt.com"

@staticmethod
def get_request_url_params(
    tenant: str, endpoint: str, include_related: list[str] | None = None
) -> tuple[str, dict[str, Any]]:
    """
    Form the request URL and request parameters for a dbt Cloud API call.

    :param tenant: The full tenant domain name (e.g. ``cloud.getdbt.com``) used as the URL host.
    :param endpoint: Endpoint url to be requested; it is appended after ``/api/v2/accounts/``.
        A falsy value appends nothing.
    :param include_related: Optional. List of related fields to pull with the run.
        Valid values are "trigger", "job", "repository", and "environment".
    :return: A tuple of the request URL and the parameter dict. The dict is empty unless
        ``include_related`` is provided.
    """
    data: dict[str, Any] = {}
    if include_related:
        data = {"include_related": include_related}
    # `tenant` is already a complete domain name, so no '.getdbt.com' suffix is appended here.
    url = f"https://{tenant}/api/v2/accounts/{endpoint or ''}"
    return url, data

async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
"""Get Headers, tenants from the connection details"""
headers: dict[str, Any] = {}
connection: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
tenant: str = connection.schema if connection.schema else "cloud"
tenant = self._get_tenant_domain(conn=connection)
package_name, provider_version = _get_provider_info()
headers["User-Agent"] = f"{package_name}-v{provider_version}"
headers["Content-Type"] = "application/json"
Expand Down Expand Up @@ -267,19 +280,7 @@ def connection(self) -> Connection:
return _connection

def get_conn(self, *args, **kwargs) -> Session:
if self.connection.schema:
warnings.warn(
"The `schema` parameter is deprecated and use within a dbt Cloud connection will be removed "
"in a future version. Please use `host` instead and specify the entire tenant domain name.",
category=DeprecationWarning,
stacklevel=2,
)
# Prior to deprecation, the connection.schema value could _only_ modify the third-level
# domain value while '.getdbt.com' was always used as the remainder of the domain name.
tenant = f"{self.connection.schema}.getdbt.com"
else:
tenant = self.connection.host or "cloud.getdbt.com"

tenant = self._get_tenant_domain(conn=self.connection)
self.base_url = f"https://{tenant}/api/v2/accounts/"

session = Session()
Expand Down
15 changes: 15 additions & 0 deletions tests/providers/dbt/cloud/hooks/test_dbt_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,21 @@ def test_tenant_base_url(self, conn_id, url):
hook.get_conn()
assert hook.base_url == url

@patch("airflow.hooks.base.BaseHook.get_connection")
def test_tenant_via_conn_schema_base_url(self, mock_conn):
    """
    Check the base URL produced when the tenant is set through the deprecated ``Connection.schema``.

    TODO: This test can be removed once using `Connection.schema` is removed to set the tenant.
    """
    expected_base_url = "https://single_tenant_domain.getdbt.com/api/v2/accounts/"
    schema_based_conn = Connection(
        conn_id="single_tenant_conn_with_schema",
        conn_type=DbtCloudHook.conn_type,
        login=DEFAULT_ACCOUNT_ID,
        password=TOKEN,
        schema="single_tenant_domain",
    )
    mock_conn.return_value = schema_based_conn

    dbt_hook = DbtCloudHook()
    # Using `Connection.schema` is expected to emit a DeprecationWarning.
    with pytest.warns(DeprecationWarning):
        dbt_hook.get_conn()

    assert dbt_hook.base_url == expected_base_url

@pytest.mark.parametrize(
argnames="conn_id, account_id",
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
Expand Down

0 comments on commit b25c544

Please sign in to comment.