diff --git a/data_diff/sqeleton/abcs/database_types.py b/data_diff/sqeleton/abcs/database_types.py index 145b8452..f82e681b 100644 --- a/data_diff/sqeleton/abcs/database_types.py +++ b/data_diff/sqeleton/abcs/database_types.py @@ -13,6 +13,7 @@ DbTime = datetime +@dataclass class ColType: supported = True @@ -140,6 +141,21 @@ class JSON(ColType): pass +@dataclass +class Array(ColType): + item_type: ColType + + +# Unlike JSON, structs are not free-form and have a very specific set of fields and their types. +# We do not parse & use those fields now, but we can do this later. +# For example, in BigQuery: +# - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type +# - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals +@dataclass +class Struct(ColType): + pass + + @dataclass class Integer(NumericType, IKey): precision: int = 0 @@ -227,6 +243,10 @@ def parse_type( ) -> ColType: "Parse type info as returned by the database" + @abstractmethod + def to_comparable(self, value: str, coltype: ColType) -> str: + """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" + from typing import TypeVar, Generic diff --git a/data_diff/sqeleton/abcs/mixins.py b/data_diff/sqeleton/abcs/mixins.py index 89462dd9..b07a7315 100644 --- a/data_diff/sqeleton/abcs/mixins.py +++ b/data_diff/sqeleton/abcs/mixins.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID, JSON +from .database_types import Array, TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID, JSON, Struct from .compiler import Compilable @@ -8,6 +8,11 @@ class AbstractMixin(ABC): class AbstractMixin_NormalizeValue(AbstractMixin): + + @abstractmethod + def to_comparable(self, value: str, coltype: ColType) -> str: + """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" + @abstractmethod def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Creates an SQL expression, that converts 'value' to a normalized timestamp. @@ -51,7 +56,15 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: def normalize_json(self, value: str, _coltype: JSON) -> str: """Creates an SQL expression, that converts 'value' to its minified json string representation.""" - raise NotImplementedError() + return self.to_string(value) + + def normalize_array(self, value: str, _coltype: Array) -> str: + """Creates an SQL expression, that serialized an array into a JSON string.""" + return self.to_string(value) + + def normalize_struct(self, value: str, _coltype: Struct) -> str: + """Creates an SQL expression, that serialized a typed struct into a JSON string.""" + return self.to_string(value) def normalize_value_by_type(self, value: str, coltype: ColType) -> str: """Creates an SQL expression, that converts 'value' to a normalized representation. @@ -79,6 +92,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.normalize_boolean(value, coltype) elif isinstance(coltype, JSON): return self.normalize_json(value, coltype) + elif isinstance(coltype, Array): + return self.normalize_array(value, coltype) + elif isinstance(coltype, Struct): + return self.normalize_struct(value, coltype) return self.to_string(value) diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index e9e0884d..8ef01373 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -17,7 +17,8 @@ from ..queries.ast_classes import Random from ..abcs.database_types import ( AbstractDatabase, - T_Dialect, + Array, + Struct, AbstractDialect, AbstractTable, ColType, @@ -165,6 +166,10 @@ def concat(self, items: List[str]) -> str: joined_exprs = ", ".join(items) return f"concat({joined_exprs})" + def to_comparable(self, value: str, coltype: ColType) -> str: + """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" + return value + def is_distinct_from(self, a: str, b: str) -> str: return f"{a} is distinct from {b}" @@ -229,7 +234,7 @@ def parse_type( """ """ cls = self._parse_type_repr(type_repr) - if not cls: + if cls is None: return UnknownColType(type_repr) if issubclass(cls, TemporalType): @@ -257,10 +262,7 @@ def parse_type( ) ) - elif issubclass(cls, (Text, Native_UUID)): - return cls() - - elif issubclass(cls, JSON): + elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)): return cls() raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.") diff --git a/data_diff/sqeleton/databases/bigquery.py b/data_diff/sqeleton/databases/bigquery.py index 0b4dc66c..c2090e5c 100644 --- a/data_diff/sqeleton/databases/bigquery.py +++ b/data_diff/sqeleton/databases/bigquery.py @@ -1,5 +1,10 @@ -from typing import List, Union +import re +from typing import Any, List, Union from ..abcs.database_types import ( + ColType, + Array, + JSON, + Struct, Timestamp, Datetime, Integer, @@ -10,6 +15,7 @@ FractionalType, TemporalType, Boolean, + UnknownColType, ) from ..abcs.mixins import ( AbstractMixin_MD5, @@ -36,6 +42,7 @@ def md5_as_int(self, s: str) -> str: class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" @@ -57,6 +64,27 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"cast({value} as int)") + def normalize_json(self, value: str, _coltype: JSON) -> str: + # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: + # Got error: 400 Grouping is not defined for arguments of type ARRAY at … + # So we do the best effort and compare it as strings, hoping that the JSON forms + # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. + return f"to_json_string({value})" + + def normalize_array(self, value: str, _coltype: Array) -> str: + # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: + # Got error: 400 Grouping is not defined for arguments of type ARRAY at … + # So we do the best effort and compare it as strings, hoping that the JSON forms + # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. + return f"to_json_string({value})" + + def normalize_struct(self, value: str, _coltype: Struct) -> str: + # BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.: + # Got error: 400 Grouping is not defined for arguments of type ARRAY at … + # So we do the best effort and compare it as strings, hoping that the JSON forms + # match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc. + return f"to_json_string({value})" + class Mixin_Schema(AbstractMixin_Schema): def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: @@ -112,11 +140,12 @@ class Dialect(BaseDialect, Mixin_Schema): "BIGNUMERIC": Decimal, "FLOAT64": Float, "FLOAT32": Float, - # Text "STRING": Text, - # Boolean "BOOL": Boolean, + "JSON": JSON, } + TYPE_ARRAY_RE = re.compile(r'ARRAY<(.+)>') + TYPE_STRUCT_RE = re.compile(r'STRUCT<(.+)>') MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} def random(self) -> str: @@ -134,6 +163,40 @@ def type_repr(self, t) -> str: except KeyError: return super().type_repr(t) + def parse_type( + self, + table_path: DbPath, + col_name: str, + type_repr: str, + *args: Any, # pass-through args + **kwargs: Any, # pass-through args + ) -> ColType: + col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs) + if isinstance(col_type, UnknownColType): + + m = self.TYPE_ARRAY_RE.fullmatch(type_repr) + if m: + item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs) + col_type = Array(item_type=item_type) + + # We currently ignore structs' structure, but later can parse it too. Examples: + # - STRUCT (unnamed) + # - STRUCT (named) + # - STRUCT> (with complex fields) + # - STRUCT> (nested) + m = self.TYPE_STRUCT_RE.fullmatch(type_repr) + if m: + col_type = Struct() + + return col_type + + def to_comparable(self, value: str, coltype: ColType) -> str: + """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" + if isinstance(coltype, (JSON, Array, Struct)): + return self.normalize_value_by_type(value, coltype) + else: + return super().to_comparable(value, coltype) + def set_timezone_to_utc(self) -> str: raise NotImplementedError() diff --git a/data_diff/sqeleton/queries/ast_classes.py b/data_diff/sqeleton/queries/ast_classes.py index ae75c855..7975c8fa 100644 --- a/data_diff/sqeleton/queries/ast_classes.py +++ b/data_diff/sqeleton/queries/ast_classes.py @@ -352,7 +352,9 @@ class IsDistinctFrom(ExprNode, LazyOps): type = bool def compile(self, c: Compiler) -> str: - return c.dialect.is_distinct_from(c.compile(self.a), c.compile(self.b)) + a = c.dialect.to_comparable(c.compile(self.a), self.a.type) + b = c.dialect.to_comparable(c.compile(self.b), self.b.type) + return c.dialect.is_distinct_from(a, b) @dataclass(eq=False, order=False) diff --git a/tests/sqeleton/test_query.py b/tests/sqeleton/test_query.py index 3856d802..efc41c02 100644 --- a/tests/sqeleton/test_query.py +++ b/tests/sqeleton/test_query.py @@ -26,6 +26,9 @@ def concat(self, l: List[str]) -> str: s = ", ".join(l) return f"concat({s})" + def to_comparable(self, s: str) -> str: + return s + def to_string(self, s: str) -> str: return f"cast({s} as varchar)"