From 7b95cfc118c1b31d577f47f5c5a7bc84e7366249 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Wed, 27 Dec 2023 12:44:48 +0100 Subject: [PATCH 1/7] Cease detecting MD5 hashes as UUIDs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit It fails the comparison anyway — because of casing & dashes not fitting into alphanumeric ranges/slices. --- data_diff/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/data_diff/utils.py b/data_diff/utils.py index b9045cc1..a93bdfe6 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -43,7 +43,14 @@ def safezip(*args): return zip(*args) -def is_uuid(u): +UUID_PATTERN = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", re.I) + + +def is_uuid(u: str) -> bool: + # E.g., hashlib.md5(b'hello') is a 32-letter hex number, but not an UUID. + # It would fail UUID-like comparison (< & >) because of casing and dashes. + if not UUID_PATTERN.fullmatch(u): + return False try: UUID(u) except ValueError: From 7802b05216cc821106fedba08f2e7fc6ff5fbcbc Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Wed, 27 Dec 2023 12:47:06 +0100 Subject: [PATCH 2/7] Refactor UUID & ArithUUID from inheritance into composition --- data_diff/databases/base.py | 9 ++++-- data_diff/utils.py | 63 ++++++++++++++++++++++++++++++++----- 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 059854a5..5c43a664 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -19,7 +19,7 @@ from data_diff.abcs.compiler import AbstractCompiler, Compilable from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString -from data_diff.utils import ArithString, is_uuid, join_iter, safezip +from data_diff.utils import ArithString, ArithUUID, is_uuid, join_iter, safezip from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this from 
data_diff.queries.ast_classes import ( Alias, @@ -247,6 +247,9 @@ def _compile(self, compiler: Compiler, elem) -> str: return self.timestamp_value(elem) elif isinstance(elem, bytes): return f"b'{elem.decode()}'" + elif isinstance(elem, ArithUUID): + s = f"'{elem.uuid}'" + return s elif isinstance(elem, ArithString): return f"'{elem}'" assert False, elem @@ -680,8 +683,10 @@ def _constant_value(self, v): return f"'{v}'" elif isinstance(v, datetime): return self.timestamp_value(v) - elif isinstance(v, UUID): + elif isinstance(v, UUID): # probably unused anymore in favour of ArithUUID return f"'{v}'" + elif isinstance(v, ArithUUID): + return f"'{v.uuid}'" elif isinstance(v, decimal.Decimal): return str(v) elif isinstance(v, bytearray): diff --git a/data_diff/utils.py b/data_diff/utils.py index a93bdfe6..e52702db 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -135,23 +135,70 @@ def range(self, other: "ArithString", count: int) -> List[Self]: return [self.new(int=i) for i in checkpoints] -# @attrs.define # not as long as it inherits from UUID -class ArithUUID(UUID, ArithString): +def _any_to_uuid(v: Union[str, int, UUID]) -> UUID: + if isinstance(v, UUID): + return v + elif isinstance(v, str): + return UUID(v) + elif isinstance(v, int): + return UUID(int=v) + else: + raise ValueError(f"Cannot convert a value to UUID: {v!r}") + + +@attrs.define(frozen=True, eq=False, order=False) +class ArithUUID(ArithString): "A UUID that supports basic arithmetic (add, sub)" + uuid: UUID = attrs.field(converter=_any_to_uuid) + def range(self, other: "ArithUUID", count: int) -> List[Self]: + assert isinstance(other, ArithUUID) + checkpoints = split_space(self.uuid.int, other.uuid.int, count) + return [attrs.evolve(self, uuid=i) for i in checkpoints] + def __int__(self): - return self.int + return self.uuid.int def __add__(self, other: int) -> Self: if isinstance(other, int): - return self.new(int=self.int + other) + return attrs.evolve(self, uuid=self.uuid.int + other) 
return NotImplemented - def __sub__(self, other: Union[UUID, int]): + def __sub__(self, other: Union["ArithUUID", int]): if isinstance(other, int): - return self.new(int=self.int - other) - elif isinstance(other, UUID): - return self.int - other.int + return attrs.evolve(self, uuid=self.uuid.int - other) + elif isinstance(other, ArithUUID): + return self.uuid.int - other.uuid.int + return NotImplemented + + def __eq__(self, other: object) -> bool: + if isinstance(other, ArithUUID): + return self.uuid == other.uuid + return NotImplemented + + def __ne__(self, other: object) -> bool: + if isinstance(other, ArithUUID): + return self.uuid != other.uuid + return NotImplemented + + def __gt__(self, other: object) -> bool: + if isinstance(other, ArithUUID): + return self.uuid > other.uuid + return NotImplemented + + def __lt__(self, other: object) -> bool: + if isinstance(other, ArithUUID): + return self.uuid < other.uuid + return NotImplemented + + def __ge__(self, other: object) -> bool: + if isinstance(other, ArithUUID): + return self.uuid >= other.uuid + return NotImplemented + + def __le__(self, other: object) -> bool: + if isinstance(other, ArithUUID): + return self.uuid <= other.uuid return NotImplemented From 9ff05b56699ac57d140bfe728be4b7d9f07d5355 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Wed, 27 Dec 2023 12:47:50 +0100 Subject: [PATCH 3/7] Preserve lower-/upper-case mode of UUIDs and render them back accordingly --- data_diff/abcs/database_types.py | 9 ++++++++- data_diff/databases/base.py | 7 +++++-- data_diff/utils.py | 3 +++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index f3c6381a..894a017c 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -131,7 +131,14 @@ class Native_UUID(ColType_UUID): @attrs.define(frozen=True) class String_UUID(ColType_UUID, StringType): - pass + # Case is important for UUIDs stored as regular 
string, not native UUIDs stored as numbers. + # We slice them internally as numbers, but render them back to SQL as lower/upper case. + # None means we do not know for sure, behave as with False, but it might be unreliable. + lowercase: Optional[bool] = None + uppercase: Optional[bool] = None + + def make_value(self, v: str) -> ArithUUID: + return self.python_type(v, lowercase=self.lowercase, uppercase=self.uppercase) @attrs.define(frozen=True) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 5c43a664..1ebdc395 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -249,7 +249,7 @@ def _compile(self, compiler: Compiler, elem) -> str: return f"b'{elem.decode()}'" elif isinstance(elem, ArithUUID): s = f"'{elem.uuid}'" - return s + return s.upper() if elem.uppercase else s.lower() if elem.lowercase else s elif isinstance(elem, ArithString): return f"'{elem}'" assert False, elem @@ -1109,7 +1109,10 @@ def _refine_coltypes( ) else: assert col_name in col_dict - col_dict[col_name] = String_UUID() + col_dict[col_name] = String_UUID( + lowercase=all(s == s.lower() for s in uuid_samples), + uppercase=all(s == s.upper() for s in uuid_samples), + ) continue if self.SUPPORTS_ALPHANUMS: # Anything but MySQL (so far) diff --git a/data_diff/utils.py b/data_diff/utils.py index e52702db..6a538813 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -151,6 +151,9 @@ class ArithUUID(ArithString): "A UUID that supports basic arithmetic (add, sub)" uuid: UUID = attrs.field(converter=_any_to_uuid) + lowercase: Optional[bool] = None + uppercase: Optional[bool] = None + def range(self, other: "ArithUUID", count: int) -> List[Self]: assert isinstance(other, ArithUUID) checkpoints = split_space(self.uuid.int, other.uuid.int, count) From b9a858c654c5d1031cf61f98259fbc1e093f7fbc Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 18 Dec 2023 15:39:36 +0100 Subject: [PATCH 4/7] Reformat the code as Black insists --- 
data_diff/databases/redshift.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index 44e86b17..c9d54a5f 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -110,7 +110,9 @@ def select_view_columns(self, path: DbPath) -> str: return """select * from pg_get_cols('{}.{}') cols(col_name name, col_type varchar) - """.format(schema, table) + """.format( + schema, table + ) def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]: rows = self.query(self.select_view_columns(path), list) From efff59156fec044cb3f2186ee2efcf7559194f29 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 18 Dec 2023 15:34:22 +0100 Subject: [PATCH 5/7] Group raw column info from rows to structures for schema parsing --- data_diff/__main__.py | 6 ++-- data_diff/databases/base.py | 60 +++++++++++++++++-------------- data_diff/databases/bigquery.py | 19 ++++------ data_diff/databases/clickhouse.py | 22 ++++++------ data_diff/databases/databricks.py | 10 ++++-- data_diff/databases/duckdb.py | 15 +++----- data_diff/databases/oracle.py | 15 +++----- data_diff/databases/presto.py | 19 ++++------ data_diff/databases/redshift.py | 40 ++++++++++++++------- data_diff/databases/vertica.py | 19 ++++------ data_diff/schema.py | 32 +++++++++++++++++ data_diff/table_segment.py | 8 ++--- 12 files changed, 147 insertions(+), 118 deletions(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 06140ad1..c4e698f5 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -12,8 +12,8 @@ from rich.logging import RichHandler import click -from data_diff import Database -from data_diff.schema import create_schema +from data_diff import Database, DbPath +from data_diff.schema import RawColumnInfo, create_schema from data_diff.queries.api import current_timestamp from data_diff.dbt import dbt_diff @@ -72,7 +72,7 @@ def _remove_passwords_in_dict(d: dict) -> None: d[k] = 
remove_password_from_url(v) -def _get_schema(pair): +def _get_schema(pair: Tuple[Database, DbPath]) -> Dict[str, RawColumnInfo]: db, table_path = pair return db.query_table_schema(table_path) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 1ebdc395..39cc4db0 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -19,6 +19,7 @@ from data_diff.abcs.compiler import AbstractCompiler, Compilable from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString +from data_diff.schema import RawColumnInfo from data_diff.utils import ArithString, ArithUUID, is_uuid, join_iter, safezip from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this from data_diff.queries.ast_classes import ( @@ -712,27 +713,18 @@ def type_repr(self, t) -> str: datetime: "TIMESTAMP", }[t] - def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: - return self.TYPE_CLASSES.get(type_repr) - - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: + def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType: "Parse type info as returned by the database" - cls = self._parse_type_repr(type_repr) + cls = self.TYPE_CLASSES.get(info.data_type) if cls is None: - return UnknownColType(type_repr) + return UnknownColType(info.data_type) if issubclass(cls, TemporalType): return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + precision=info.datetime_precision + if info.datetime_precision is not None + else DEFAULT_DATETIME_PRECISION, rounds=self.ROUNDS_ON_PREC_LOSS, ) @@ -743,22 +735,22 @@ def parse_type( return cls() elif issubclass(cls, Decimal): - if numeric_scale is None: - numeric_scale = 0 # Needed for Oracle. 
- return cls(precision=numeric_scale) + if info.numeric_scale is None: + return cls(precision=0) # Needed for Oracle. + return cls(precision=info.numeric_scale) elif issubclass(cls, Float): # assert numeric_scale is None return cls( precision=self._convert_db_precision_to_digits( - numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + info.numeric_precision if info.numeric_precision is not None else DEFAULT_NUMERIC_PRECISION ) ) elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)): return cls() - raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.") + raise TypeError(f"Parsing {info.data_type} returned an unknown type {cls!r}.") def _convert_db_precision_to_digits(self, p: int) -> int: """Convert from binary precision, used by floats, to decimal precision.""" @@ -1023,7 +1015,7 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{name}' AND table_schema = '{schema}'" ) - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]: """Query the table for its schema for table in 'path', and return {column: tuple} where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?) 
@@ -1034,7 +1026,17 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - d = {r[0]: r for r in rows} + d = { + r[0]: RawColumnInfo( + column_name=r[0], + data_type=r[1], + datetime_precision=r[2], + numeric_precision=r[3], + numeric_scale=r[4], + collation_name=r[5] if len(r) > 5 else None, + ) + for r in rows + } assert len(d) == len(rows) return d @@ -1056,7 +1058,11 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]: return list(res) def _process_table_schema( - self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str] = None, where: str = None + self, + path: DbPath, + raw_schema: Dict[str, RawColumnInfo], + filter_columns: Sequence[str] = None, + where: str = None, ): """Process the result of query_table_schema(). @@ -1072,7 +1078,7 @@ def _process_table_schema( accept = {i.lower() for i in filter_columns} filtered_schema = {name: row for name, row in raw_schema.items() if name.lower() in accept} - col_dict = {row[0]: self.dialect.parse_type(path, *row) for _name, row in filtered_schema.items()} + col_dict = {info.column_name: self.dialect.parse_type(path, info) for info in filtered_schema.values()} self._refine_coltypes(path, col_dict, where) @@ -1081,7 +1087,7 @@ def _process_table_schema( def _refine_coltypes( self, table_path: DbPath, col_dict: Dict[str, ColType], where: Optional[str] = None, sample_size=64 - ): + ) -> Dict[str, ColType]: """Refine the types in the column dict, by querying the database for a sample of their values 'where' restricts the rows to be sampled. 
@@ -1089,7 +1095,7 @@ def _refine_coltypes( text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)] if not text_columns: - return + return col_dict fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] @@ -1126,6 +1132,8 @@ def _refine_coltypes( assert col_name in col_dict col_dict[col_name] = String_VaryingAlphanum() + return col_dict + def _normalize_table_path(self, path: DbPath) -> DbPath: if len(path) == 1: return self.default_schema, path[0] diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 26d8aec3..02ee4d33 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -33,6 +33,7 @@ MD5_HEXDIGITS, ) from data_diff.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter +from data_diff.schema import RawColumnInfo @import_helper(text="Please install BigQuery and configure your google-cloud access.") @@ -91,19 +92,13 @@ def type_repr(self, t) -> str: except KeyError: return super().type_repr(t) - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - *args: Any, # pass-through args - **kwargs: Any, # pass-through args - ) -> ColType: - col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs) + def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType: + col_type = super().parse_type(table_path, info) if isinstance(col_type, UnknownColType): - m = self.TYPE_ARRAY_RE.fullmatch(type_repr) + m = self.TYPE_ARRAY_RE.fullmatch(info.type_repr) if m: - item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs) + item_info = attrs.evolve(info, data_type=m.group(1)) + item_type = self.parse_type(table_path, item_info) col_type = Array(item_type=item_type) # We currently ignore structs' structure, but later can parse it too. 
Examples: @@ -111,7 +106,7 @@ def parse_type( # - STRUCT (named) # - STRUCT> (with complex fields) # - STRUCT> (nested) - m = self.TYPE_STRUCT_RE.fullmatch(type_repr) + m = self.TYPE_STRUCT_RE.fullmatch(info.type_repr) if m: col_type = Struct() diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 13082504..7bbc156f 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -14,6 +14,7 @@ ) from data_diff.abcs.database_types import ( ColType, + DbPath, Decimal, Float, Integer, @@ -24,6 +25,7 @@ Timestamp, Boolean, ) +from data_diff.schema import RawColumnInfo # https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database DEFAULT_DATABASE = "default" @@ -75,19 +77,19 @@ def _convert_db_precision_to_digits(self, p: int) -> int: # because it does not help for float with a big integer part. return super()._convert_db_precision_to_digits(p) - 2 - def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: + def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType: nullable_prefix = "Nullable(" - if type_repr.startswith(nullable_prefix): - type_repr = type_repr[len(nullable_prefix) :].rstrip(")") + if info.data_type.startswith(nullable_prefix): + info = attrs.evolve(info, data_type=info.data_type[len(nullable_prefix) :].rstrip(")")) - if type_repr.startswith("Decimal"): - type_repr = "Decimal" - elif type_repr.startswith("FixedString"): - type_repr = "FixedString" - elif type_repr.startswith("DateTime64"): - type_repr = "DateTime64" + if info.data_type.startswith("Decimal"): + info = attrs.evolve(info, data_type="Decimal") + elif info.data_type.startswith("FixedString"): + info = attrs.evolve(info, data_type="FixedString") + elif info.data_type.startswith("DateTime64"): + info = attrs.evolve(info, data_type="DateTime64") - return self.TYPE_CLASSES.get(type_repr) + return super().parse_type(table_path, info) # def timestamp_value(self, t: 
DbTime) -> str: # # return f"'{t}'" diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index c755cfa9..b8ff21df 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -26,6 +26,7 @@ import_helper, parse_table_name, ) +from data_diff.schema import RawColumnInfo @import_helper(text="You can install it using 'pip install databricks-sql-connector'") @@ -138,7 +139,7 @@ def create_connection(self): except databricks.sql.exc.Error as e: raise ConnectionError(*e.args) from e - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]: # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html # So, to obtain information about schema, we should use another approach. @@ -155,7 +156,12 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} + d = { + r.COLUMN_NAME: RawColumnInfo( + column_name=r.COLUMN_NAME, data_type=r.TYPE_NAME, datetime_precision=r.DECIMAL_DIGITS + ) + for r in rows + } assert len(d) == len(rows) return d diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index d057e0d9..31536854 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -3,6 +3,7 @@ import attrs from packaging.version import parse as parse_version +from data_diff.schema import RawColumnInfo from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( Timestamp, @@ -74,24 +75,16 @@ def _convert_db_precision_to_digits(self, p: int) -> int: # Subtracting 2 due to wierd precision issues in PostgreSQL return 
super()._convert_db_precision_to_digits(p) - 2 - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: + def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType: regexps = { r"DECIMAL\((\d+),(\d+)\)": Decimal, } - for m, t_cls in match_regexps(regexps, type_repr): + for m, t_cls in match_regexps(regexps, info.type_repr): precision = int(m.group(2)) return t_cls(precision=precision) - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) + return super().parse_type(table_path, info) def set_timezone_to_utc(self) -> str: return "SET GLOBAL TimeZone='UTC'" diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index ab84f0b6..0383b1f1 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -2,6 +2,7 @@ import attrs +from data_diff.schema import RawColumnInfo from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( Decimal, @@ -105,26 +106,18 @@ def constant_values(self, rows) -> str: def explain_as_text(self, query: str) -> str: raise NotImplementedError("Explain not yet implemented in Oracle") - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: + def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType: regexps = { r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, r"TIMESTAMP\((\d)\)": Timestamp, } - for m, t_cls in match_regexps(regexps, type_repr): + for m, t_cls in match_regexps(regexps, info.type_repr): precision = int(m.group(1)) return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) - return super().parse_type(table_path, col_name, type_repr, 
datetime_precision, numeric_precision, numeric_scale) + return super().parse_type(table_path, info) def set_timezone_to_utc(self) -> str: return "ALTER SESSION SET TIME_ZONE = 'UTC'" diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index ba1c7360..fa8c1f6d 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -4,6 +4,7 @@ import attrs +from data_diff.schema import RawColumnInfo from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( @@ -91,33 +92,25 @@ def quote(self, s: str): def to_string(self, s: str): return f"cast({s} as varchar)" - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - _numeric_scale: int = None, - ) -> ColType: + def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType: timestamp_regexps = { r"timestamp\((\d)\)": Timestamp, r"timestamp\((\d)\) with time zone": TimestampTZ, } - for m, t_cls in match_regexps(timestamp_regexps, type_repr): + for m, t_cls in match_regexps(timestamp_regexps, info.type_repr): precision = int(m.group(1)) return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} - for m, n_cls in match_regexps(number_regexps, type_repr): + for m, n_cls in match_regexps(number_regexps, info.type_repr): _prec, scale = map(int, m.groups()) return n_cls(scale) string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} - for m, n_cls in match_regexps(string_regexps, type_repr): + for m, n_cls in match_regexps(string_regexps, info.type_repr): return n_cls() - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + return super().parse_type(table_path, info) def set_timezone_to_utc(self) -> str: return "SET TIME ZONE '+00:00'" diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index c9d54a5f..924204d0 
100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,4 +1,4 @@ -from typing import ClassVar, List, Dict, Type +from typing import Any, ClassVar, Iterable, List, Dict, Tuple, Type import attrs @@ -20,6 +20,7 @@ TIMESTAMP_PRECISION_POS, PostgresqlDialect, ) +from data_diff.schema import RawColumnInfo @attrs.define(frozen=False) @@ -96,13 +97,12 @@ def select_external_table_schema(self, path: DbPath) -> str: + db_clause ) - def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]: + def query_external_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]: rows = self.query(self.select_external_table_schema(path), list) if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") schema_dict = self._normalize_schema_info(rows) - return schema_dict def select_view_columns(self, path: DbPath) -> str: @@ -114,14 +114,12 @@ def select_view_columns(self, path: DbPath) -> str: schema, table ) - def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]: + def query_pg_get_cols(self, path: DbPath) -> Dict[str, RawColumnInfo]: rows = self.query(self.select_view_columns(path), list) - if not rows: raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns") schema_dict = self._normalize_schema_info(rows) - return schema_dict def select_svv_columns_schema(self, path: DbPath) -> Dict[str, tuple]: @@ -147,19 +145,29 @@ def select_svv_columns_schema(self, path: DbPath) -> Dict[str, tuple]: + db_clause ) - def query_svv_columns(self, path: DbPath) -> Dict[str, tuple]: + def query_svv_columns(self, path: DbPath) -> Dict[str, RawColumnInfo]: rows = self.query(self.select_svv_columns_schema(path), list) if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - d = {r[0]: r for r in rows} + d = { + r[0]: RawColumnInfo( + column_name=r[0], + data_type=r[1], + datetime_precision=r[2], + 
numeric_precision=r[3], + numeric_scale=r[4], + collation_name=r[5] if len(r) > 5 else None, + ) + for r in rows + } assert len(d) == len(rows) return d # when using a non-information_schema source, strip (N) from type(N) etc. to match # typical information_schema output - def _normalize_schema_info(self, rows) -> Dict[str, tuple]: - schema_dict = {} + def _normalize_schema_info(self, rows: Iterable[Tuple[Any]]) -> Dict[str, RawColumnInfo]: + schema_dict: Dict[str, RawColumnInfo] = {} for r in rows: col_name = r[0] type_info = r[1].split("(") @@ -173,11 +181,17 @@ def _normalize_schema_info(self, rows) -> Dict[str, tuple]: precision = int(precision) scale = int(scale) - out = [col_name, base_type, None, precision, scale] - schema_dict[col_name] = tuple(out) + schema_dict[col_name] = RawColumnInfo( + column_name=col_name, + data_type=base_type, + datetime_precision=None, + numeric_precision=precision, + numeric_scale=scale, + collation_name=None, + ) return schema_dict - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]: try: return super().query_table_schema(path) except RuntimeError: diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 23f63acc..8a0e329e 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -2,6 +2,7 @@ import attrs +from data_diff.schema import RawColumnInfo from data_diff.utils import match_regexps from data_diff.databases.base import ( CHECKSUM_HEXDIGITS, @@ -68,27 +69,19 @@ def to_string(self, s: str) -> str: def is_distinct_from(self, a: str, b: str) -> str: return f"not ({a} <=> {b})" - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: + def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType: timestamp_regexps = { r"timestamp\(?(\d?)\)?":
Timestamp, r"timestamptz\(?(\d?)\)?": TimestampTZ, } - for m, t_cls in match_regexps(timestamp_regexps, type_repr): + for m, t_cls in match_regexps(timestamp_regexps, info.type_repr): precision = int(m.group(1)) if m.group(1) else 6 return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) number_regexps = { r"numeric\((\d+),(\d+)\)": Decimal, } - for m, n_cls in match_regexps(number_regexps, type_repr): + for m, n_cls in match_regexps(number_regexps, info.type_repr): _prec, scale = map(int, m.groups()) return n_cls(scale) @@ -96,10 +89,10 @@ def parse_type( r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text, } - for m, n_cls in match_regexps(string_regexps, type_repr): + for m, n_cls in match_regexps(string_regexps, info.type_repr): return n_cls() - return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + return super().parse_type(table_path, info) def set_timezone_to_utc(self) -> str: return "SET TIME ZONE TO 'UTC'" diff --git a/data_diff/schema.py b/data_diff/schema.py index 67b4261f..f0408935 100644 --- a/data_diff/schema.py +++ b/data_diff/schema.py @@ -1,4 +1,7 @@ import logging +from typing import Any, Collection, Iterable, Optional + +import attrs from data_diff.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict from data_diff.abcs.database_types import DbPath @@ -8,6 +11,35 @@ Schema = CaseAwareMapping +@attrs.frozen(kw_only=True) +class RawColumnInfo(Collection[Any]): + """ + A raw row representing the schema info about a column. + + Do not rely on this class too much, it will be removed soon when the schema + selecting & parsing methods are united into one overrideable method. + """ + + column_name: str + data_type: str + datetime_precision: Optional[int] = None + numeric_precision: Optional[int] = None + numeric_scale: Optional[int] = None + collation_name: Optional[str] = None + + # It was a tuple once, so we keep it backward compatible temporarily, until remade to classes. 
+ def __iter__(self) -> Iterable[Any]: + return iter( + (self.column_name, self.data_type, self.datetime_precision, self.numeric_precision, self.numeric_scale) + ) + + def __len__(self) -> int: + return 5 + + def __contains__(self, item: Any) -> bool: + return False # that was not used + + def create_schema(db_name: str, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping: logger.info(f"[{db_name}] Schema = {schema}") diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index d8f84231..180f0aad 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -1,5 +1,5 @@ import time -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import logging from itertools import product @@ -10,7 +10,7 @@ from data_diff.utils import ArithString, split_space from data_diff.databases.base import Database from data_diff.abcs.database_types import DbPath, DbKey, DbTime -from data_diff.schema import Schema, create_schema +from data_diff.schema import RawColumnInfo, Schema, create_schema from data_diff.queries.extras import Checksum from data_diff.queries.api import Count, SKIP, table, this, Expr, min_, max_, Code from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, NormalizeAsString @@ -140,7 +140,7 @@ def __attrs_post_init__(self): def _where(self): return f"({self.where})" if self.where else None - def _with_raw_schema(self, raw_schema: dict) -> Self: + def _with_raw_schema(self, raw_schema: Dict[str, RawColumnInfo]) -> Self: schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where()) return self.new(schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive)) @@ -151,7 +151,7 @@ def with_schema(self) -> Self: return self._with_raw_schema(self.database.query_table_schema(self.table_path)) - def get_schema(self): + def get_schema(self) -> Dict[str, RawColumnInfo]: return 
self.database.query_table_schema(self.table_path) def _make_key_range(self): From 53a65c09dc9c8ccf993b65fcd60c13ae9bff25ab Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 18 Dec 2023 15:36:28 +0100 Subject: [PATCH 6/7] Retrieve collations for selected databases (SQL Server & Snowflake) --- data_diff/abcs/database_types.py | 88 +++++++++++++++++++++++++++++++- data_diff/databases/base.py | 2 +- data_diff/databases/mssql.py | 2 +- data_diff/databases/snowflake.py | 3 +- 4 files changed, 91 insertions(+), 4 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 894a017c..14436438 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -1,6 +1,6 @@ import decimal from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, Type, TypeVar, Union +from typing import Collection, List, Optional, Tuple, Type, TypeVar, Union from datetime import datetime import attrs @@ -15,6 +15,91 @@ N = TypeVar("N") +@attrs.frozen(kw_only=True, eq=False, order=False, unsafe_hash=True) +class Collation: + """ + A pre-parsed or pre-known record about db collation, per column. + + The "greater" collation should be used as a target collation for textual PKs + on both sides of the diff — by converting the "lesser" collation to self. + + Snowflake easily absorbs the performance losses, so it has a boost to always + be greater than any other collation in non-Snowflake databases. + Other databases need to negotiate which side absorbs the performance impact. + """ + + # A boost for special databases that are known to absorb the performance damage well. + absorbs_damage: bool = False + + # Ordinal sorting by ASCII/UTF8 (True), or alphabetic as per locale/country/etc (False). + ordinal: Optional[bool] = None + + # Lowercase first (aAbBcC or abcABC). Otherwise, uppercase first (AaBbCc or ABCabc). + lower_first: Optional[bool] = None + + # 2-letter lower-case locale and upper-case country codes, e.g.
en_US. Ignored for ordinals. + language: Optional[str] = None + country: Optional[str] = None + + # There are also space-, punctuation-, width-, kana-(in)sensitivity, so on. + # Ignore everything not related to xdb alignment. Only case- & accent-sensitivity are common. + case_sensitive: Optional[bool] = None + accent_sensitive: Optional[bool] = None + + # Purely informational, for debugging: + _source: Union[None, str, Collection[str]] = None + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Collation): + return NotImplemented + if self.ordinal and other.ordinal: + # TODO: does it depend on language? what does Albanic_BIN mean in MS SQL? + return True + return ( + self.language == other.language + and (self.country is None or other.country is None or self.country == other.country) + and self.case_sensitive == other.case_sensitive + and self.accent_sensitive == other.accent_sensitive + and self.lower_first == other.lower_first + ) + + def __ne__(self, other: object) -> bool: + if not isinstance(other, Collation): + return NotImplemented + return not self.__eq__(other) + + def __gt__(self, other: object) -> bool: + if not isinstance(other, Collation): + return NotImplemented + if self == other: + return False + if self.absorbs_damage and not other.absorbs_damage: + return False + if other.absorbs_damage and not self.absorbs_damage: + return True # this one is preferred if it cannot absorb damage as its counterpart can + if self.ordinal and not other.ordinal: + return True + if other.ordinal and not self.ordinal: + return False + # TODO: try to align the languages & countries? 
+ return False + + def __ge__(self, other: object) -> bool: + if not isinstance(other, Collation): + return NotImplemented + return self == other or self.__gt__(other) + + def __lt__(self, other: object) -> bool: + if not isinstance(other, Collation): + return NotImplemented + return self != other and not self.__gt__(other) + + def __le__(self, other: object) -> bool: + if not isinstance(other, Collation): + return NotImplemented + return self == other or not self.__gt__(other) + + @attrs.define(frozen=True, kw_only=True) class ColType: # Arbitrary metadata added and fetched at runtime. @@ -112,6 +197,7 @@ def python_type(self) -> type: @attrs.define(frozen=True) class StringType(ColType): python_type = str + collation: Optional[Collation] = attrs.field(default=None, kw_only=True) @attrs.define(frozen=True) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 39cc4db0..21c8d0e6 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1130,7 +1130,7 @@ def _refine_coltypes( ) else: assert col_name in col_dict - col_dict[col_name] = String_VaryingAlphanum() + col_dict[col_name] = String_VaryingAlphanum(collation=col_dict[col_name].collation) return col_dict diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 834ed9cd..758ac3e8 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -201,7 +201,7 @@ def select_table_schema(self, path: DbPath) -> str: info_schema_path.insert(0, self.dialect.quote(database)) return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale, collation_name " f"FROM {'.'.join(info_schema_path)} " f"WHERE table_name = '{name}' AND table_schema = '{schema}'" ) diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 4152a407..1b70085a 100644 --- a/data_diff/databases/snowflake.py +++ 
b/data_diff/databases/snowflake.py @@ -164,7 +164,8 @@ def select_table_schema(self, path: DbPath) -> str: info_schema_path.insert(0, database) return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale" + " , coalesce(collation_name, 'utf8') " f"FROM {'.'.join(info_schema_path)} " f"WHERE table_name = '{name}' AND table_schema = '{schema}'" ) From ec8453129844aa3abb21488971596c8d0cb08fcf Mon Sep 17 00:00:00 2001 From: nolar Date: Wed, 27 Dec 2023 11:50:04 +0000 Subject: [PATCH 7/7] style fixes by ruff --- data_diff/databases/redshift.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index 924204d0..be9ec0fb 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -110,9 +110,7 @@ def select_view_columns(self, path: DbPath) -> str: return """select * from pg_get_cols('{}.{}') cols(col_name name, col_type varchar) - """.format( - schema, table - ) + """.format(schema, table) def query_pg_get_cols(self, path: DbPath) -> Dict[str, RawColumnInfo]: rows = self.query(self.select_view_columns(path), list)