This repository has been archived by the owner on May 17, 2024. It is now read-only.

Make local dbt data diffs concurrent #776

Merged · 27 commits · Dec 5, 2023
15 changes: 10 additions & 5 deletions data_diff/databases/base.py
@@ -931,7 +931,7 @@ def name(self):
def compile(self, sql_ast):
return self.dialect.compile(Compiler(self), sql_ast)

def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
def query(self, sql_ast: Union[Expr, Generator], res_type: type = None, log_message: Optional[str] = None):
"""Query the given SQL code/AST, and attempt to convert the result to type 'res_type'

If given a generator, it will execute all the yielded sql queries with the same thread and cursor.
@@ -956,7 +956,10 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
if sql_code is SKIP:
return SKIP

logger.debug("Running SQL (%s):\n%s", self.name, sql_code)
if log_message:
logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code)
else:
logger.debug("Running SQL (%s):\n%s", self.name, sql_code)

if self._interactive and isinstance(sql_ast, Select):
explained_sql = self.compile(Explain(sql_ast))
@@ -1022,7 +1025,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
Note: This method exists instead of select_table_schema(), just because not all databases support
accessing the schema using a SQL query.
"""
rows = self.query(self.select_table_schema(path), list)
rows = self.query(self.select_table_schema(path), list, log_message=path)
if not rows:
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

@@ -1044,7 +1047,7 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]:
"""Query the table for its unique columns for table in 'path', and return {column}"""
if not self.SUPPORTS_UNIQUE_CONSTAINT:
raise NotImplementedError("This database doesn't support 'unique' constraints")
res = self.query(self.select_table_unique_columns(path), List[str])
res = self.query(self.select_table_unique_columns(path), List[str], log_message=path)
return list(res)

def _process_table_schema(
@@ -1086,7 +1089,9 @@ def _refine_coltypes(
fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]

samples_by_row = self.query(
table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list
table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size),
list,
log_message=table_path,
)
if not samples_by_row:
raise ValueError(f"Table {table_path} is empty.")
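
A minimal sketch of the pattern behind this change, using a bare-bones `Database` stub rather than data-diff's real class: once queries from several concurrent diffs interleave in the log, the optional `log_message` tag says which table each SQL statement belongs to.

```python
import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)

class Database:
    # Illustrative stub; data-diff's real Database class carries much more state.
    name = "example"

    def query(self, sql_code: str, log_message: Optional[str] = None) -> Any:
        # Without the tag, a bare SQL log line no longer says which of the
        # concurrently diffed tables it belongs to.
        if log_message:
            logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code)
        else:
            logger.debug("Running SQL (%s):\n%s", self.name, sql_code)
        # A real implementation would execute sql_code against a connection here.
```

Call sites then pass whatever identifies the work, as `query_table_schema` does above with `log_message=path`.
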
43 changes: 27 additions & 16 deletions data_diff/dbt.py
@@ -8,6 +8,7 @@
import pydantic
import rich
from rich.prompt import Prompt
from concurrent.futures import ThreadPoolExecutor, as_completed

from data_diff.errors import (
DataDiffCustomSchemaNoConfigError,
@@ -80,7 +81,6 @@ def dbt_diff(
production_schema_flag: Optional[str] = None,
) -> None:
print_version_info()
diff_threads = []
set_entrypoint_name(os.getenv("DATAFOLD_TRIGGERED_BY", "CLI-dbt"))
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state)
models = dbt_parser.get_models(dbt_selection)
@@ -112,7 +112,11 @@
else:
dbt_parser.set_connection()

with log_status_handler.status if log_status_handler else nullcontext():
futures = {}

with log_status_handler.status if log_status_handler else nullcontext(), ThreadPoolExecutor(
max_workers=dbt_parser.threads
) as executor:
for model in models:
if log_status_handler:
log_status_handler.set_prefix(f"Diffing {model.alias} \n")
Expand Down Expand Up @@ -140,12 +144,12 @@ def dbt_diff(

if diff_vars.primary_keys:
if is_cloud:
diff_thread = run_as_daemon(
future = executor.submit(
_cloud_diff, diff_vars, config.datasource_id, api, org_meta, log_status_handler
)
diff_threads.append(diff_thread)
else:
_local_diff(diff_vars, json_output)
future = executor.submit(_local_diff, diff_vars, json_output, log_status_handler)
futures[future] = model
else:
if json_output:
print(
@@ -165,10 +169,12 @@
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
)

# wait for all threads
if diff_threads:
for thread in diff_threads:
thread.join()
for future in as_completed(futures):
model = futures[future]
try:
future.result() # if error occurred, it will be raised here
except Exception as e:
logger.error(f"An error occurred during the execution of a diff task: {model.unique_id} - {e}")

_extension_notification()
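
Stripped to its essentials, the concurrency change above is the standard futures-dict pattern from `concurrent.futures`: submit one task per model, remember which future maps to which model, and surface worker exceptions as they complete. A self-contained sketch with placeholder model names and a stand-in diff function, not the PR's exact code:

```python
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)

def diff_model(name: str) -> None:
    print(f"diffing {name}")  # stand-in for _local_diff / _cloud_diff

models = ["orders", "customers", "payments"]  # hypothetical model names
futures = {}

with ThreadPoolExecutor(max_workers=4) as executor:
    for name in models:
        futures[executor.submit(diff_model, name)] = name  # future -> model name

    for future in as_completed(futures):
        name = futures[future]
        try:
            future.result()  # re-raises any exception from the worker thread
        except Exception as e:
            logger.error(f"Diff task failed for {name}: {e}")
```

Because `as_completed` yields futures as they finish, a slow model does not delay error reporting for the rest, and an exception in one diff does not abort the others.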

@@ -265,15 +271,17 @@ def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str,
return prod_database, prod_schema, prod_alias


def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:
def _local_diff(
diff_vars: TDiffVars, json_output: bool = False, log_status_handler: Optional[LogStatusHandler] = None
) -> None:
if log_status_handler:
log_status_handler.diff_started(diff_vars.dev_path[-1])
dev_qualified_str = ".".join(diff_vars.dev_path)
prod_qualified_str = ".".join(diff_vars.prod_path)
diff_output_str = _diff_output_base(dev_qualified_str, prod_qualified_str)

table1 = connect_to_table(
diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads
)
table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads)
table1 = connect_to_table(diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys))
table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys))

try:
table1_columns = table1.get_schema()
@@ -373,6 +381,9 @@ def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:
diff_output_str += no_differences_template()
rich.print(diff_output_str)

if log_status_handler:
log_status_handler.diff_finished(diff_vars.dev_path[-1])


def _initialize_api() -> Optional[DatafoldAPI]:
datafold_host = os.environ.get("DATAFOLD_HOST")
@@ -406,7 +417,7 @@ def _cloud_diff(
log_status_handler: Optional[LogStatusHandler] = None,
) -> None:
if log_status_handler:
log_status_handler.cloud_diff_started(diff_vars.dev_path[-1])
log_status_handler.diff_started(diff_vars.dev_path[-1])
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
payload = TCloudApiDataDiff(
data_source1_id=datasource_id,
@@ -476,7 +487,7 @@
rich.print(diff_output_str)

if log_status_handler:
log_status_handler.cloud_diff_finished(diff_vars.dev_path[-1])
log_status_handler.diff_finished(diff_vars.dev_path[-1])
except BaseException as ex: # Catch KeyboardInterrupt too
error = ex
finally:
6 changes: 3 additions & 3 deletions data_diff/dbt_parser.py
@@ -446,17 +446,17 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str

from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta] or None
if from_meta:
logger.debug("Found PKs via META: " + str(from_meta))
logger.debug(f"Found PKs via META [{node.name}]: " + str(from_meta))
return from_meta

from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags] or None
if from_tags:
logger.debug("Found PKs via Tags: " + str(from_tags))
logger.debug(f"Found PKs via Tags [{node.name}]: " + str(from_tags))
return from_tags
if node.unique_id in unique_columns:
from_uniq = unique_columns.get(node.unique_id)
if from_uniq is not None:
logger.debug("Found PKs via Uniqueness tests: " + str(from_uniq))
logger.debug(f"Found PKs via Uniqueness tests [{node.name}]: {str(from_uniq)}")
return list(from_uniq)

except (KeyError, IndexError, TypeError) as e:
61 changes: 42 additions & 19 deletions data_diff/joindiff_tables.py
@@ -162,7 +162,7 @@ def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tre
yield from self._diff_segments(None, table1, table2, info_tree, None)
else:
yield from self._bisect_and_diff_tables(table1, table2, info_tree)
logger.info("Diffing complete")
logger.info(f"Diffing complete: {table1.table_path} <> {table2.table_path}")
if self.materialize_to_table:
logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table))

@@ -193,8 +193,8 @@ def _diff_segments(
partial(self._collect_stats, 1, table1, info_tree),
partial(self._collect_stats, 2, table2, info_tree),
partial(self._test_null_keys, table1, table2),
partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols),
partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols, table1, table2),
partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols, table1, table2),
partial(
self._materialize_diff,
db,
@@ -205,8 +205,8 @@
else None,
):
assert len(a_cols) == len(b_cols)
logger.debug("Querying for different rows")
diff = db.query(diff_rows, list)
logger.debug(f"Querying for different rows: {table1.table_path}")
diff = db.query(diff_rows, list, log_message=table1.table_path)
info_tree.info.set_diff(diff, schema=tuple(diff_rows.schema.items()))
for is_xa, is_xb, *x in diff:
if is_xa and is_xb:
@@ -227,7 +227,7 @@
yield "+", tuple(b_row)

def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
logger.debug("Testing for duplicate keys")
logger.debug(f"Testing for duplicate keys: {table1.table_path} <> {table2.table_path}")

# Test duplicate keys
for ts in [table1, table2]:
@@ -240,24 +240,24 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):

unvalidated = list(set(key_columns) - set(unique))
if unvalidated:
logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated}")
logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated} for {ts.table_path}")
# Validate that there are no duplicate keys
self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated]
q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True))
total, total_distinct = ts.database.query(q, tuple)
total, total_distinct = ts.database.query(q, tuple, log_message=ts.table_path)
if total != total_distinct:
raise ValueError("Duplicate primary keys")

def _test_null_keys(self, table1, table2):
logger.debug("Testing for null keys")
logger.debug(f"Testing for null keys: {table1.table_path} <> {table2.table_path}")

# Test null keys
for ts in [table1, table2]:
t = ts.make_select()
key_columns = ts.key_columns

q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
nulls = ts.database.query(q, list)
nulls = ts.database.query(q, list, log_message=ts.table_path)
if nulls:
if self.skip_null_keys:
logger.warning(
@@ -267,7 +267,7 @@ def _test_null_keys(self, table1, table2):
raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")

def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
logger.debug(f"Collecting stats for table #{i}")
logger.debug(f"Collecting stats for table #{i}: {table_seg.table_path}")
db = table_seg.database

# Metrics
@@ -288,7 +288,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
)
col_exprs["count"] = Count()

res = db.query(table_seg.make_select().select(**col_exprs), tuple)
res = db.query(table_seg.make_select().select(**col_exprs), tuple, log_message=table_seg.table_path)

for col_name, value in safezip(col_exprs, res):
if value is not None:
@@ -303,7 +303,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
else:
self.stats[stat_name] = value

logger.debug("Done collecting stats for table #%s", i)
logger.debug("Done collecting stats for table #%s: %s", i, table_seg.table_path)

def _create_outer_join(self, table1, table2):
db = table1.database
@@ -334,23 +334,46 @@ def _create_outer_join(self, table1, table2):
diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols))
return diff_rows, a_cols, b_cols, is_diff_cols, all_rows

def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
logger.debug("Counting differences per column")
is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple)
def _count_diff_per_column(
self,
db,
diff_rows,
cols,
is_diff_cols,
table1: Optional[TableSegment] = None,
table2: Optional[TableSegment] = None,
):
logger.info(type(table1))
logger.debug(f"Counting differences per column: {table1.table_path} <> {table2.table_path}")
is_diff_cols_counts = db.query(
diff_rows.select(sum_(this[c]) for c in is_diff_cols),
tuple,
log_message=f"{table1.table_path} <> {table2.table_path}",
)
diff_counts = {}
for name, count in safezip(cols, is_diff_cols_counts):
diff_counts[name] = diff_counts.get(name, 0) + (count or 0)
self.stats["diff_counts"] = diff_counts

def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
def _sample_and_count_exclusive(
self,
db,
diff_rows,
a_cols,
b_cols,
table1: Optional[TableSegment] = None,
table2: Optional[TableSegment] = None,
):
if isinstance(db, (Oracle, MsSQL)):
exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1))
else:
exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b)

if not self.sample_exclusive_rows:
logger.debug("Counting exclusive rows")
self.stats["exclusive_count"] = db.query(exclusive_rows_query.count(), int)
logger.debug(f"Counting exclusive rows: {table1.table_path} <> {table2.table_path}")
self.stats["exclusive_count"] = db.query(
exclusive_rows_query.count(), int, log_message=f"{table1.table_path} <> {table2.table_path}"
)
return

logger.info("Counting and sampling exclusive rows")
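
The signature changes in this file follow one pattern: helpers gain optional `TableSegment` arguments used only to tag their log output with the table pair being diffed, and `functools.partial` keeps binding those arguments up front. A reduced sketch with a hypothetical helper, not the real `_count_diff_per_column`:

```python
from functools import partial
from typing import Optional, Tuple

def count_diffs_per_column(db: str, query: str, table_path: Optional[Tuple[str, ...]] = None) -> None:
    # table_path is purely log context; the computation itself ignores it
    print(f"Counting differences per column: {table_path} on {db}")

# Bind the arguments, log context included, at construction time; the caller
# later invokes the task with no arguments, as _diff_segments does above with
# its list of partial(...) tasks.
task = partial(count_diffs_per_column, "dev_db", "SELECT ...", table_path=("analytics", "orders"))
task()
```
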
28 changes: 14 additions & 14 deletions data_diff/utils.py
@@ -485,31 +485,31 @@ def __init__(self):
super().__init__()
self.status = Status("")
self.prefix = ""
self.cloud_diff_status = {}
self.diff_status = {}

def emit(self, record):
log_entry = self.format(record)
if self.cloud_diff_status:
self._update_cloud_status(log_entry)
if self.diff_status:
self._update_diff_status(log_entry)
else:
self.status.update(self.prefix + log_entry)

def set_prefix(self, prefix_string):
self.prefix = prefix_string

def cloud_diff_started(self, model_name):
self.cloud_diff_status[model_name] = "[yellow]In Progress[/]"
self._update_cloud_status()
def diff_started(self, model_name):
self.diff_status[model_name] = "[yellow]In Progress[/]"
self._update_diff_status()

def cloud_diff_finished(self, model_name):
self.cloud_diff_status[model_name] = "[green]Finished [/]"
self._update_cloud_status()
def diff_finished(self, model_name):
self.diff_status[model_name] = "[green]Finished [/]"
self._update_diff_status()

def _update_cloud_status(self, log=None):
cloud_status_string = "\n"
for model_name, status in self.cloud_diff_status.items():
cloud_status_string += f"{status} {model_name}\n"
self.status.update(f"{cloud_status_string}{log or ''}")
def _update_diff_status(self, log=None):
status_string = "\n"
for model_name, status in self.diff_status.items():
status_string += f"{status} {model_name}\n"
self.status.update(f"{status_string}{log or ''}")


class UnknownMeta(type):
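
With the `cloud_` prefix dropped, the handler serves cloud and local diffs alike, so one usage sketch covers both. It assumes `LogStatusHandler` is importable as shown above and uses a hypothetical model name:

```python
import logging
from data_diff.utils import LogStatusHandler  # the handler shown in the diff above

handler = LogStatusHandler()
root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.INFO)

with handler.status:  # rich's Status renders the live block while it is open
    handler.diff_started("orders")   # a worker marks its model In Progress
    root.info("running checks")      # interleaved log lines appear below the model list
    handler.diff_finished("orders")  # the worker flips its model to Finished
```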