Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db committed Sep 5, 2023
1 parent f505b4f commit b9533a5
Showing 1 changed file with 28 additions and 84 deletions.
112 changes: 28 additions & 84 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def emit(self, record: logging.LogRecord) -> None:
dbt_adapter_logger = AdapterLogger("databricks-sql-connector")

pysql_logger = logging.getLogger("databricks.sql")
pysql_logger_level = os.environ.get(
"DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN"
).upper()
pysql_logger_level = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper()
pysql_logger.setLevel(pysql_logger_level)

pysql_handler = DbtCoreHandler(dbt_logger=dbt_adapter_logger, level=pysql_logger_level)
Expand Down Expand Up @@ -195,9 +193,7 @@ def validate_creds(self) -> None:
)
if not self.token and self.auth_type != "oauth":
raise dbt.exceptions.DbtProfileError(
(
"The config `auth_type: oauth` is required when not using access token"
)
("The config `auth_type: oauth` is required when not using access token")
)

if not self.client_id and self.client_secret:
Expand All @@ -221,9 +217,7 @@ def get_invocation_env(cls) -> Optional[str]:
return invocation_env

@classmethod
def get_all_http_headers(
cls, user_http_session_headers: Dict[str, str]
) -> Dict[str, str]:
def get_all_http_headers(cls, user_http_session_headers: Dict[str, str]) -> Dict[str, str]:
http_session_headers_str: Optional[str] = os.environ.get(
DBT_DATABRICKS_HTTP_SESSION_HEADERS
)
Expand Down Expand Up @@ -258,17 +252,13 @@ def type(self) -> str:
def unique_field(self) -> str:
return cast(str, self.host)

def connection_info(
self, *, with_aliases: bool = False
) -> Iterable[Tuple[str, Any]]:
def connection_info(self, *, with_aliases: bool = False) -> Iterable[Tuple[str, Any]]:
as_dict = self.to_dict(omit_none=False)
connection_keys = set(self._connection_keys(with_aliases=with_aliases))
aliases: List[str] = []
if with_aliases:
aliases = [k for k, v in self._ALIASES.items() if v in connection_keys]
for key in itertools.chain(
self._connection_keys(with_aliases=with_aliases), aliases
):
for key in itertools.chain(self._connection_keys(with_aliases=with_aliases), aliases):
if key in as_dict:
yield key, as_dict[key]

Expand Down Expand Up @@ -341,9 +331,7 @@ def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
credsdict = keyring.get_password("dbt-databricks", host)

if credsdict:
provider = SessionCredentials.from_dict(
oauth_client, json.loads(credsdict)
)
provider = SessionCredentials.from_dict(oauth_client, json.loads(credsdict))
# if refresh token is expired, this will throw
try:
if provider.token().valid:
Expand All @@ -365,9 +353,7 @@ def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
# save for later
self._credentials_provider = provider.as_dict()
try:
keyring.set_password(
"dbt-databricks", host, json.dumps(self._credentials_provider)
)
keyring.set_password("dbt-databricks", host, json.dumps(self._credentials_provider))
# error with keyring. Maybe machine has no password persistency
except Exception as e:
logger.debug(e)
Expand Down Expand Up @@ -398,9 +384,7 @@ def _provider_from_dict(self) -> CredentialsProvider:
scopes=SCOPES,
)

return SessionCredentials.from_dict(
client=oauth_client, raw=self._credentials_provider
)
return SessionCredentials.from_dict(client=oauth_client, raw=self._credentials_provider)


class DatabricksSQLConnectionWrapper:
Expand Down Expand Up @@ -487,9 +471,7 @@ class DatabricksSQLCursorWrapper:
_user_agent: str
_creds: DatabricksCredentials

def __init__(
self, cursor: DatabricksSQLCursor, creds: DatabricksCredentials, user_agent: str
):
def __init__(self, cursor: DatabricksSQLCursor, creds: DatabricksCredentials, user_agent: str):
self._cursor = cursor
self._creds = creds
self._user_agent = user_agent
Expand Down Expand Up @@ -542,19 +524,15 @@ def pollRefreshPipeline(

stopped_states = ("COMPLETED", "FAILED", "CANCELED")
host: str = self._creds.host or ""
headers = (
self._cursor.connection.thrift_backend._auth_provider._header_factory()
)
headers = self._cursor.connection.thrift_backend._auth_provider._header_factory()
headers["User-Agent"] = self._user_agent

pipeline_id = _get_table_view_pipeline_id(host, headers, model_name)
pipeline = _get_pipeline_state(host, headers, pipeline_id)
# get the most recently created update for the pipeline
latest_update = _find_update(pipeline)
if not latest_update:
raise dbt.exceptions.DbtRuntimeError(
f"No update created for pipeline: {pipeline_id}"
)
raise dbt.exceptions.DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

state = latest_update.get("state")
# we use update_id to retrieve the update in the polling loop
Expand Down Expand Up @@ -610,20 +588,14 @@ def pollRefreshPipeline(
state = None

if exceeded_timeout:
raise dbt.exceptions.DbtRuntimeError(
"timed out waiting for materialized view refresh"
)
raise dbt.exceptions.DbtRuntimeError("timed out waiting for materialized view refresh")

if state == "FAILED":
msg = _get_update_error_msg(host, headers, pipeline_id, update_id)
raise dbt.exceptions.DbtRuntimeError(
f"error refreshing model {model_name} {msg}"
)
raise dbt.exceptions.DbtRuntimeError(f"error refreshing model {model_name} {msg}")

if state == "CANCELED":
raise dbt.exceptions.DbtRuntimeError(
f"refreshing model {model_name} cancelled"
)
raise dbt.exceptions.DbtRuntimeError(f"refreshing model {model_name} cancelled")

return

Expand All @@ -642,9 +614,7 @@ def hex_query_id(self) -> str:
This UUID can be tied back to the Databricks query history API
"""

_as_hex = uuid.UUID(
bytes=self._cursor.active_result_set.command_id.operationId.guid
)
_as_hex = uuid.UUID(bytes=self._cursor.active_result_set.command_id.operationId.guid)

return str(_as_hex)

Expand Down Expand Up @@ -676,9 +646,7 @@ def description(
def schemas(self, catalog_name: str, schema_name: Optional[str] = None) -> None:
self._cursor.schemas(catalog_name=catalog_name, schema_name=schema_name)

def tables(
self, catalog_name: str, schema_name: str, table_name: Optional[str] = None
) -> None:
def tables(self, catalog_name: str, schema_name: str, table_name: Optional[str] = None) -> None:
self._cursor.tables(
catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
)
Expand Down Expand Up @@ -778,9 +746,7 @@ def add_query(
connection = self.get_thread_connection()
if auto_begin and connection.transaction_open is False:
self.begin()
fire_event(
ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))
)
fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name)))

with self.exception_handler(sql):
cursor: Optional[DatabricksSQLCursorWrapper] = None
Expand All @@ -789,14 +755,10 @@ def add_query(
if abridge_sql_log:
log_sql = "{}...".format(log_sql[:512])

fire_event(
SQLQuery(conn_name=cast_to_str(connection.name), sql=log_sql)
)
fire_event(SQLQuery(conn_name=cast_to_str(connection.name), sql=log_sql))
pre = time.time()

cursor = cast(
DatabricksSQLConnectionWrapper, connection.handle
).cursor()
cursor = cast(DatabricksSQLConnectionWrapper, connection.handle).cursor()
cursor.execute(sql, bindings)

fire_event(
Expand Down Expand Up @@ -840,16 +802,12 @@ def _execute_cursor(
) -> Table:
connection = self.get_thread_connection()

fire_event(
ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))
)
fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name)))

with self.exception_handler(log_sql):
cursor: Optional[DatabricksSQLCursorWrapper] = None
try:
fire_event(
SQLQuery(conn_name=cast_to_str(connection.name), sql=log_sql)
)
fire_event(SQLQuery(conn_name=cast_to_str(connection.name), sql=log_sql))
pre = time.time()

handle: DatabricksSQLConnectionWrapper = connection.handle
Expand All @@ -874,9 +832,7 @@ def list_schemas(self, database: str, schema: Optional[str] = None) -> Table:
lambda cursor: cursor.schemas(catalog_name=database, schema_name=schema),
)

def list_tables(
self, database: str, schema: str, identifier: Optional[str] = None
) -> Table:
def list_tables(self, database: str, schema: str, identifier: Optional[str] = None) -> Table:
return self._execute_cursor(
f"GetTables(database={database}, schema={schema}, identifier={identifier})",
lambda cursor: cursor.tables(
Expand Down Expand Up @@ -905,9 +861,7 @@ def open(cls, connection: Connection) -> Connection:
connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr]

http_headers: List[Tuple[str, str]] = list(
creds.get_all_http_headers(
connection_parameters.pop("http_headers", {})
).items()
creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
)

def connect() -> DatabricksSQLConnectionWrapper:
Expand Down Expand Up @@ -952,9 +906,7 @@ def exponential_backoff(attempt: int) -> int:
)

@classmethod
def get_response(
cls, cursor: DatabricksSQLCursorWrapper
) -> DatabricksAdapterResponse:
def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse:
_query_id = getattr(cursor, "hex_query_id", None)
if cursor is None:
logger.debug("No cursor was provided. Query ID not available.")
Expand All @@ -978,9 +930,7 @@ def _should_poll_refresh(sql: str) -> Tuple[bool, str]:
name = ""
refresh_search = re.search(r"refresh\s+materialized\s+view\s+([`\w.]+)", sql)
if not refresh_search:
refresh_search = re.search(
r"create\s+or\s+refresh\s+streaming\s+table\s+([`\w.]+)", sql
)
refresh_search = re.search(r"create\s+or\s+refresh\s+streaming\s+table\s+([`\w.]+)", sql)

if refresh_search:
name = refresh_search.group(1).replace("`", "")
Expand Down Expand Up @@ -1010,9 +960,7 @@ def _get_pipeline_state(host: str, headers: dict, pipeline_id: str) -> dict:

response = requests.get(pipeline_url, headers=headers)
if response.status_code != 200:
raise dbt.exceptions.DbtRuntimeError(
f"Error getting pipeline info: {pipeline_id}"
)
raise dbt.exceptions.DbtRuntimeError(f"Error getting pipeline info: {pipeline_id}")

return response.json()

Expand All @@ -1034,15 +982,11 @@ def _find_update(pipeline: dict, id: str = "") -> Optional[Dict]:
return None


def _get_update_error_msg(
host: str, headers: dict, pipeline_id: str, update_id: str
) -> str:
def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id: str) -> str:
events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events"
response = requests.get(events_url, headers=headers)
if response.status_code != 200:
raise dbt.exceptions.DbtRuntimeError(
f"Error getting pipeline event info: {pipeline_id}"
)
raise dbt.exceptions.DbtRuntimeError(f"Error getting pipeline event info: {pipeline_id}")

events = response.json().get("events", [])
update_events = [
Expand Down

0 comments on commit b9533a5

Please sign in to comment.