This repository was archived by the owner on May 17, 2024. It is now read-only.

Json matching & tests for sqeleton PR #15 #383

Merged · 4 commits · May 5, 2023
20 changes: 16 additions & 4 deletions data_diff/hashdiff_tables.py
@@ -7,10 +7,10 @@

 from runtype import dataclass

-from data_diff.sqeleton.abcs import ColType_UUID, NumericType, PrecisionType, StringType, Boolean
+from data_diff.sqeleton.abcs import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSONType

 from .info_tree import InfoTree
-from .utils import safezip
+from .utils import safezip, diffs_are_equiv_jsons
 from .thread_utils import ThreadedYielder
 from .table_segment import TableSegment

@@ -24,7 +24,7 @@
 logger = logging.getLogger("hashdiff_tables")


-def diff_sets(a: set, b: set) -> Iterator:
+def diff_sets(a: list, b: list, json_cols: dict = None) -> Iterator:
     sa = set(a)
     sb = set(b)

@@ -38,7 +38,17 @@ def diff_sets(a: set, b: set) -> Iterator:
         if row not in sa:
             d[row[0]].append(("+", row))

+    warned_diff_cols = set()
     for _k, v in sorted(d.items(), key=lambda i: i[0]):
+        if json_cols:
+            parsed_match, overridden_diff_cols = diffs_are_equiv_jsons(v, json_cols)
+            if parsed_match:
+                to_warn = overridden_diff_cols - warned_diff_cols
+                for w in to_warn:
+                    logger.warning(f"Equivalent JSON objects with different string representations detected "
+                                   f"in column '{w}'. These cases are NOT reported as differences.")
+                    warned_diff_cols.add(w)
+                continue
         yield from v


@@ -194,7 +204,9 @@ def _bisect_and_diff_segments(
         # This saves time, as bisection speed is limited by ping and query performance.
         if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
             rows1, rows2 = self._threaded_call("get_values", [table1, table2])
-            diff = list(diff_sets(rows1, rows2))
+            json_cols = {i: colname for i, colname in enumerate(table1.extra_columns)
+                         if isinstance(table1._schema[colname], JSONType)}
+            diff = list(diff_sets(rows1, rows2, json_cols))

             info_tree.info.set_diff(diff)
             info_tree.info.rowcounts = {1: len(rows1), 2: len(rows2)}
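A quick sketch of the new behavior (the key, row values, and the "payload" column name are hypothetical; it assumes this branch's `diff_sets` is importable and that the table's only extra column is of JSONType):

```python
from data_diff.hashdiff_tables import diff_sets

# Two versions of the same row whose JSON column differs only in key order.
# Row layout is (key, extra_column_0, ...), matching diff_sets' expectations.
rows1 = [("id-1", '{"a": 1, "b": 2}')]
rows2 = [("id-1", '{"b": 2, "a": 1}')]

# json_cols maps each JSON extra-column index to its name (hypothetical here).
json_cols = {0: "payload"}

# Without json_cols the pair is reported as a ('-', ...) / ('+', ...) diff;
# with json_cols the JSON-equivalent pair is suppressed and only warned about.
assert len(list(diff_sets(rows1, rows2))) == 2
assert list(diff_sets(rows1, rows2, json_cols)) == []
```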
25 changes: 25 additions & 0 deletions data_diff/utils.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import re
 from typing import Dict, Iterable, Sequence
@@ -144,3 +145,27 @@ def dbt_diff_string_template(
         string_output += f"\n{k}: {v}"

     return string_output


+def _jsons_equiv(a: str, b: str):
+    try:
+        return json.loads(a) == json.loads(b)
+    except (ValueError, TypeError, json.decoder.JSONDecodeError):  # not valid JSON
+        return False


+def diffs_are_equiv_jsons(diff: list, json_cols: dict):
+    if (len(diff) != 2) or ({diff[0][0], diff[1][0]} != {"+", "-"}):
+        return False, set()  # always return a (match, columns) pair, since callers unpack it
+    match = True
+    overridden_diff_cols = set()
+    for i, (col_a, col_b) in enumerate(safezip(diff[0][1][1:], diff[1][1][1:])):  # skip the key at row index 0
+        # Only JSONType columns are parsed as JSON; non-JSON columns must still match exactly
+        match = col_a == col_b
+        if not match and (i in json_cols):
+            if _jsons_equiv(col_a, col_b):
+                overridden_diff_cols.add(json_cols[i])
+                match = True
+        if not match:
+            break
+    return match, overridden_diff_cols
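To illustrate the helper's contract with hypothetical rows (each diff entry is a (sign, row) tuple, and row index 0 is the key):

```python
from data_diff.utils import diffs_are_equiv_jsons

# A '-'/'+' pair for the same key whose first extra column is JSON text
# differing only in whitespace and key order: treated as equivalent.
pair = [
    ("-", ("id-1", '{"x": 1, "y": [2, 3]}')),
    ("+", ("id-1", '{ "y": [2, 3], "x": 1 }')),
]
match, overridden = diffs_are_equiv_jsons(pair, json_cols={0: "payload"})
assert match and overridden == {"payload"}

# Invalid JSON (or a genuine value difference) still counts as a real diff.
pair = [("-", ("id-1", "abc")), ("+", ("id-1", "abd"))]
match, overridden = diffs_are_equiv_jsons(pair, json_cols={0: "payload"})
assert not match
```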
44 changes: 35 additions & 9 deletions tests/test_database_types.py
@@ -74,6 +74,10 @@ def init_conns():
"boolean": [
"boolean",
],
"json": [
"json",
"jsonb"
]
},
db.MySQL: {
# https://dev.mysql.com/doc/refman/8.0/en/integer-types.html
@@ -199,6 +203,9 @@ def init_conns():
"boolean": [
"boolean",
],
"json": [
"super",
]
},
db.Oracle: {
"int": [
@@ -469,12 +476,28 @@ def __iter__(self):
         return (uuid.uuid1(i) for i in range(self.max))


+class JsonFaker:
+    MANUAL_FAKES = [
+        '{"keyText": "text", "keyInt": 3, "keyFloat": 5.4445, "keyBoolean": true}',
+    ]

+    def __init__(self, max):
+        self.max = max

+    def __iter__(self):
+        return iter(self.MANUAL_FAKES[: self.max])

+    def __len__(self):
+        return min(self.max, len(self.MANUAL_FAKES))


 TYPE_SAMPLES = {
     "int": IntFaker(N_SAMPLES),
     "datetime": DateTimeFaker(N_SAMPLES),
     "float": FloatFaker(N_SAMPLES),
     "uuid": UUID_Faker(N_SAMPLES),
     "boolean": BooleanFaker(N_SAMPLES),
+    "json": JsonFaker(N_SAMPLES)
 }


@@ -546,7 +569,7 @@ def expand_params(testcase_func, param_num, param):
     return name


-def _insert_to_table(conn, table_path, values, type):
+def _insert_to_table(conn, table_path, values, coltype):
     tbl = table(table_path)

     current_n_rows = conn.query(tbl.count(), int)
@@ -555,31 +578,34 @@ def _insert_to_table(conn, table_path, values, type):
         return
     elif current_n_rows > 0:
         conn.query(drop_table(table_name))
-        _create_table_with_indexes(conn, table_path, type)
+        _create_table_with_indexes(conn, table_path, coltype)

     # if BENCHMARK and N_SAMPLES > 10_000:
     #     description = f"{conn.name}: {table}"
     #     values = rich.progress.track(values, total=N_SAMPLES, description=description)

-    if type == "boolean":
+    if coltype == "boolean":
         values = [(i, bool(sample)) for i, sample in values]
-    elif re.search(r"(time zone|tz)", type):
+    elif re.search(r"(time zone|tz)", coltype):
         values = [(i, sample.replace(tzinfo=timezone.utc)) for i, sample in values]

     if isinstance(conn, db.Clickhouse):
-        if type.startswith("DateTime64"):
+        if coltype.startswith("DateTime64"):
             values = [(i, f"{sample.replace(tzinfo=None)}") for i, sample in values]

-        elif type == "DateTime":
+        elif coltype == "DateTime":
             # Clickhouse's DateTime cannot store micro/milli/nanoseconds
             values = [(i, str(sample)[:19]) for i, sample in values]

-        elif type.startswith("Decimal("):
-            precision = int(type[8:].rstrip(")").split(",")[1])
+        elif coltype.startswith("Decimal("):
+            precision = int(coltype[8:].rstrip(")").split(",")[1])
             values = [(i, round(sample, precision)) for i, sample in values]
-    elif isinstance(conn, db.BigQuery) and type == "datetime":
+    elif isinstance(conn, db.BigQuery) and coltype == "datetime":
         values = [(i, Code(f"cast(timestamp '{sample}' as datetime)")) for i, sample in values]

+    if isinstance(conn, db.Redshift) and coltype == "json":
+        values = [(i, Code(f"JSON_PARSE('{sample}')")) for i, sample in values]

     insert_rows_in_batches(conn, tbl, values, columns=["id", "col"])
     conn.query(commit)
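One non-obvious detail above is the Redshift branch: its SUPER type does not accept a bare string literal as JSON, so the sample is wrapped in JSON_PARSE(). A minimal sketch of that dialect split (the helper name and rendered SQL strings are illustrative assumptions, not data-diff's exact output):

```python
# Hypothetical helper showing per-dialect handling of one JSON sample.
SAMPLE = '{"keyText": "text", "keyInt": 3, "keyFloat": 5.4445, "keyBoolean": true}'

def render_json_literal(dialect: str, value: str) -> str:
    # Redshift stores JSON in SUPER columns, which require JSON_PARSE();
    # Postgres json/jsonb columns accept the quoted string directly.
    if dialect == "redshift":
        return f"JSON_PARSE('{value}')"
    return f"'{value}'"

print(render_json_literal("postgresql", SAMPLE))  # '{"keyText": ...}'
print(render_json_literal("redshift", SAMPLE))    # JSON_PARSE('{"keyText": ...}')
```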