diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index df63c78..e5fb15e 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -1,3 +1,4 @@ +import json import logging from typing import Hashable, MutableMapping, Type, Optional, Union, Dict from itertools import zip_longest @@ -281,7 +282,7 @@ class Connect: def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: if isinstance(db_conf, dict): - return tuple(db_conf.items()) + return json.dumps(db_conf) return db_conf diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index 9537ce5..f4a3c51 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Any, ClassVar, Dict, Union, Type import attrs @@ -27,6 +28,7 @@ from data_diff.databases.base import ( ThreadLocalInterpreter, TIMESTAMP_PRECISION_POS, CHECKSUM_OFFSET, + apply_query, ) from data_diff.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS from data_diff.version import __version__ @@ -137,8 +139,17 @@ class DuckDB(Database): return True def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): - "Uses the standard SQL cursor interface" - return self._query_conn(self._conn, sql_code) + #"Uses the standard SQL cursor interface" + #return self._query_conn(self._conn, sql_code) + + c = self._conn.cursor() + settings = self._args.get("settings", {}) + for key, value in settings.items(): + c.execute(f"SET {key} = '{value}'") + + callback = partial(self._query_cursor, c) + return apply_query(callback, sql_code) + def close(self): super().close() @@ -157,6 +168,15 @@ class DuckDB(Database): assert custom_user_agent in custom_user_agent_filtered else: connection = ddb.connect(database=self._args["filepath"]) + # Install extensions + extensions = self._args.get("extensions", {}) + for extension in extensions: + connection.install_extension(extension) + connection.load_extension(extension) + # Apply settings + #settings = self._args.get("settings", {}) + #for key, value in settings.items(): + # connection.execute(f"SET {key} = '{value}'") return connection except ddb.OperationalError as e: raise ConnectError(*e.args) from e diff --git a/data_diff/dbt.py b/data_diff/dbt.py index 6e961a2..0206164 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -1,9 +1,10 @@ from contextlib import nullcontext +# import traceback import json import os import re import time -from typing import List, Optional, Dict, Tuple, Union +from typing import List, Optional, Dict, Tuple, Union, Any import keyring import pydantic import rich @@ -59,7 +60,7 @@ class TDiffVars(pydantic.BaseModel): dev_path: List[str] prod_path: List[str] primary_keys: List[str] - connection: Dict[str, Optional[str]] + connection: Dict[str, Any] threads: Optional[int] = None where_filter: Optional[str] = None include_columns: List[str] @@ -176,6 +177,9 @@ def dbt_diff( try: future.result() # if error occurred, it will be raised here except Exception as e: + # logger.error("".join(traceback.TracebackException.from_exception(e).format()) == traceback.format_exc() == "".join( + # traceback.format_exception(type(e), e, e.__traceback__))) + # logger.error("".join(traceback.TracebackException.from_exception(e).format())) logger.error(f"An error occurred during the execution of a diff task: {model.unique_id} - {e}") _extension_notification() diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index eda5f6c..d43344a 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -389,6 +389,8 @@ class DbtParser: conn_info = { "driver": conn_type, "filepath": credentials.get("path"), + "settings": credentials.get("settings"), + "extensions": credentials.get("extensions"), } elif conn_type == "redshift": if (credentials.get("pass") is None and credentials.get("password") is None) or credentials.get(