diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 9602b583..4a118bd4 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,14 +1,27 @@ +from uuid import UUID import math import sys import logging -from typing import Dict, Tuple, Optional, Sequence +from typing import Dict, Tuple, Optional, Sequence, Type, List from functools import lru_cache, wraps from concurrent.futures import ThreadPoolExecutor import threading from abc import abstractmethod -from .database_types import AbstractDatabase, ColType, Integer, Decimal, Float, UnknownColType -from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select +from data_diff.utils import is_uuid, safezip +from .database_types import ( + ColType_UUID, + AbstractDatabase, + ColType, + Integer, + Decimal, + Float, + PrecisionType, + TemporalType, + UnknownColType, + Text, +) +from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select, TableName logger = logging.getLogger("database") @@ -62,7 +75,7 @@ class Database(AbstractDatabase): Instanciated using :meth:`~data_diff.connect_to_uri` """ - DATETIME_TYPES: Dict[str, type] = {} + TYPE_CLASSES: Dict[str, type] = {} default_schema: str = None @property @@ -93,7 +106,7 @@ def query(self, sql_ast: SqlOrStr, res_type: type): assert len(res) == 1, (sql_code, res) return res[0] elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1: - if res_type.__args__ == (int,): + if res_type.__args__ == (int,) or res_type.__args__ == (str,): return [_one(row) for row in res] elif res_type.__args__ == (Tuple,): return [tuple(row) for row in res] @@ -109,8 +122,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int: # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format return math.floor(math.log(2**p, 10)) + def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: + return self.TYPE_CLASSES.get(type_repr) + def _parse_type( self, + table_path: DbPath, 
col_name: str, type_repr: str, datetime_precision: int = None, @@ -119,28 +136,27 @@ def _parse_type( ) -> ColType: """ """ - cls = self.DATETIME_TYPES.get(type_repr) - if cls: + cls = self._parse_type_repr(type_repr) + if not cls: + return UnknownColType(type_repr) + + if issubclass(cls, TemporalType): return cls( precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, rounds=self.ROUNDS_ON_PREC_LOSS, ) - cls = self.NUMERIC_TYPES.get(type_repr) - if cls: - if issubclass(cls, Integer): - # Some DBs have a constant numeric_scale, so they don't report it. - # We fill in the constant, so we need to ignore it for integers. - return cls(precision=0) - - elif issubclass(cls, Decimal): - if numeric_scale is None: - raise ValueError( - f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}." - ) - return cls(precision=numeric_scale) + elif issubclass(cls, Integer): + return cls() + + elif issubclass(cls, Decimal): + if numeric_scale is None: + raise ValueError( + f"{self.name}: Unexpected numeric_scale is NULL, for column {'.'.join(table_path)}.{col_name} of type {type_repr}." 
+ ) + return cls(precision=numeric_scale) - assert issubclass(cls, Float) + elif issubclass(cls, Float): # assert numeric_scale is None return cls( precision=self._convert_db_precision_to_digits( @@ -148,7 +164,10 @@ def _parse_type( ) ) - return UnknownColType(type_repr) + elif issubclass(cls, Text): + return cls() + + raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.") def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -167,8 +186,34 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str accept = {i.lower() for i in filter_columns} rows = [r for r in rows if r[0].lower() in accept] + col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in rows} + + self._refine_coltypes(path, col_dict) + # Return a dict of form {name: type} after normalization - return {row[0]: self._parse_type(*row) for row in rows} + return col_dict + + def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]): + "Refine the types in the column dict, by querying the database for a sample of their values" + + text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)] + if not text_columns: + return + + fields = [self.normalize_uuid(c, ColType_UUID()) for c in text_columns] + samples_by_row = self.query(Select(fields, TableName(table_path), limit=16), list) + samples_by_col = list(zip(*samples_by_row)) + for col_name, samples in safezip(text_columns, samples_by_col): + uuid_samples = list(filter(is_uuid, samples)) + + if uuid_samples: + if len(uuid_samples) != len(samples): + logger.warning( + f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support." 
+ ) + else: + assert col_name in col_dict + col_dict[col_name] = ColType_UUID() # @lru_cache() # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]: @@ -186,6 +231,15 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: def parse_table_name(self, name: str) -> DbPath: return parse_table_name(name) + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + if offset: + raise NotImplementedError("No support for OFFSET in query") + + return f"LIMIT {limit}" + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + return f"TRIM({value})" + class ThreadedDatabase(Database): """Access the database through singleton threads. diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 362526c8..411ae795 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -11,17 +11,19 @@ def import_bigquery(): class BigQuery(Database): - DATETIME_TYPES = { + TYPE_CLASSES = { + # Dates "TIMESTAMP": Timestamp, "DATETIME": Datetime, - } - NUMERIC_TYPES = { + # Numbers "INT64": Integer, "INT32": Integer, "NUMERIC": Decimal, "BIGNUMERIC": Decimal, "FLOAT64": Float, "FLOAT32": Float, + # Text + "STRING": Text, } ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 424883e3..caf38852 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -1,15 +1,20 @@ +import decimal from abc import ABC, abstractmethod -from typing import Sequence, Optional, Tuple, Union, Dict +from typing import Sequence, Optional, Tuple, Union, Dict, Any from datetime import datetime from runtype import dataclass +from data_diff.utils import ArithUUID + + DbPath = Tuple[str, ...] 
-DbKey = Union[int, str, bytes] +DbKey = Union[int, str, bytes, ArithUUID] DbTime = datetime class ColType: + supported = True pass @@ -50,11 +55,36 @@ class Float(FractionalType): class Decimal(FractionalType): + @property + def python_type(self) -> type: + if self.precision == 0: + return int + return decimal.Decimal + + +class StringType(ColType): pass +class IKey(ABC): + "Interface for ColType, for using a column as a key in data-diff" + python_type: type + + +class ColType_UUID(StringType, IKey): + python_type = ArithUUID + + +@dataclass +class Text(StringType): + supported = False + + @dataclass -class Integer(NumericType): +class Integer(NumericType, IKey): + precision: int = 0 + python_type: type = int + def __post_init__(self): assert self.precision == 0 @@ -63,6 +93,8 @@ def __post_init__(self): class UnknownColType(ColType): text: str + supported = False + class AbstractDatabase(ABC): @abstractmethod @@ -80,6 +112,10 @@ def md5_to_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" ... + @abstractmethod + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + ... + @abstractmethod def _query(self, sql_code: str) -> list: "Send query to database and return result" @@ -138,6 +174,14 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: """ ... + @abstractmethod + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + """Creates an SQL expression, that converts 'value' to a normalized uuid. + + i.e. just makes sure there is no trailing whitespace. + """ + ... + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: """Creates an SQL expression, that converts 'value' to a normalized representation. 
@@ -158,6 +202,8 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.normalize_timestamp(value, coltype) elif isinstance(coltype, FractionalType): return self.normalize_number(value, coltype) + elif isinstance(coltype, ColType_UUID): + return self.normalize_uuid(value, coltype) return self.to_string(value) def _normalize_table_path(self, path: DbPath) -> DbPath: diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 2ba9550f..7e4e2956 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -11,16 +11,19 @@ def import_mysql(): class MySQL(ThreadedDatabase): - DATETIME_TYPES = { + TYPE_CLASSES = { + # Dates "datetime": Datetime, "timestamp": Timestamp, - } - NUMERIC_TYPES = { + # Numbers "double": Float, "float": Float, "decimal": Decimal, "int": Integer, "bigint": Integer, + # Text + "varchar": Text, + "char": Text, } ROUNDS_ON_PREC_LOSS = True diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 81d5dc38..55109eda 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -13,6 +13,15 @@ def import_oracle(): class Oracle(ThreadedDatabase): + TYPE_CLASSES: Dict[str, type] = { + "NUMBER": Decimal, + "FLOAT": Float, + # Text + "CHAR": Text, + "NCHAR": Text, + "NVARCHAR2": Text, + "VARCHAR2": Text, + } ROUNDS_ON_PREC_LOSS = True def __init__(self, host, port, user, password, *, database, thread_count, **kw): @@ -67,13 +76,13 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: def _parse_type( self, + table_name: DbPath, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None, ) -> ColType: - """ """ regexps = { r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, @@ -87,20 +96,14 @@ def _parse_type( rounds=self.ROUNDS_ON_PREC_LOSS, ) - n_cls = { - "NUMBER": Decimal, - "FLOAT": Float, - }.get(type_repr, 
None) - if n_cls: - if issubclass(n_cls, Decimal): - assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale) - return n_cls(precision=numeric_scale) - - assert issubclass(n_cls, Float) - return n_cls( - precision=self._convert_db_precision_to_digits( - numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION - ) - ) + return super()._parse_type(table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) + + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + if offset: + raise NotImplementedError("No support for OFFSET in query") + + return f"FETCH NEXT {limit} ROWS ONLY" - return UnknownColType(type_repr) + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Cast is necessary for correct MD5 (trimming not enough) + return f"CAST(TRIM({value}) AS VARCHAR(36))" diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 86917c80..2e6ed64f 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -13,19 +13,23 @@ def import_postgresql(): class PostgreSQL(ThreadedDatabase): - DATETIME_TYPES = { + TYPE_CLASSES = { + # Timestamps "timestamp with time zone": TimestampTZ, "timestamp without time zone": Timestamp, "timestamp": Timestamp, - # "datetime": Datetime, - } - NUMERIC_TYPES = { + # Numbers "double precision": Float, "real": Float, "decimal": Decimal, "integer": Integer, "numeric": Decimal, "bigint": Integer, + # Text + "character": Text, + "character varying": Text, + "varchar": Text, + "text": Text, } ROUNDS_ON_PREC_LOSS = True diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 8b78b7e0..400f2971 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -20,16 +20,17 @@ def import_presto(): class Presto(Database): default_schema = "public" - DATETIME_TYPES = { + TYPE_CLASSES = { + # Timestamps "timestamp with time zone": 
TimestampTZ, "timestamp without time zone": Timestamp, "timestamp": Timestamp, - # "datetime": Datetime, - } - NUMERIC_TYPES = { + # Numbers "integer": Integer, "real": Float, "double": Float, + # Text + "varchar": Text, } ROUNDS_ON_PREC_LOSS = True @@ -82,7 +83,12 @@ def select_table_schema(self, path: DbPath) -> str: ) def _parse_type( - self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None + self, + table_path: DbPath, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, ) -> ColType: timestamp_regexps = { r"timestamp\((\d)\)": Timestamp, @@ -104,17 +110,14 @@ def _parse_type( prec, scale = map(int, m.groups()) return n_cls(scale) - n_cls = self.NUMERIC_TYPES.get(type_repr) - if n_cls: - if issubclass(n_cls, Integer): - assert numeric_precision is not None - return n_cls(0) + string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} + for regexp, n_cls in string_regexps.items(): + m = re.match(regexp + "$", type_repr) + if m: + return n_cls() - assert issubclass(n_cls, Float) - return n_cls( - precision=self._convert_db_precision_to_digits( - numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION - ) - ) + return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) - return UnknownColType(type_repr) + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + # Trim doesn't work on CHAR type + return f"TRIM(CAST({value} AS VARCHAR))" diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index 12904fcd..cbda7ff9 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -3,8 +3,8 @@ class Redshift(PostgreSQL): - NUMERIC_TYPES = { - **PostgreSQL.NUMERIC_TYPES, + TYPE_CLASSES = { + **PostgreSQL.TYPE_CLASSES, "double": Float, "real": Float, } diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 
d0dbbbb7..9d3f8448 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -12,14 +12,16 @@ def import_snowflake(): class Snowflake(Database): - DATETIME_TYPES = { + TYPE_CLASSES = { + # Timestamps "TIMESTAMP_NTZ": Timestamp, "TIMESTAMP_LTZ": Timestamp, "TIMESTAMP_TZ": TimestampTZ, - } - NUMERIC_TYPES = { + # Numbers "NUMBER": Decimal, "FLOAT": Float, + # Text + "TEXT": Text, } ROUNDS_ON_PREC_LOSS = False diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index b9526e8d..143714f2 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -5,17 +5,20 @@ import time from operator import attrgetter, methodcaller from collections import defaultdict -from typing import List, Tuple, Iterator, Optional +from typing import List, Tuple, Iterator, Optional, Type import logging from concurrent.futures import ThreadPoolExecutor from runtype import dataclass -from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max +from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max, Value +from .utils import safezip, split_space from .databases.base import Database from .databases.database_types import ( + ArithUUID, NumericType, PrecisionType, + StringType, UnknownColType, Schema, Schema_CaseInsensitive, @@ -30,17 +33,6 @@ DEFAULT_BISECTION_FACTOR = 32 -def safezip(*args): - "zip but makes sure all sequences are the same length" - assert len(set(map(len, args))) == 1 - return zip(*args) - - -def split_space(start, end, count): - size = end - start - return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1] - - @dataclass(frozen=False) class TableSegment: """Signifies a segment of rows (and selected columns) within a table @@ -105,6 +97,7 @@ def with_schema(self) -> "TableSegment": return self schema = self.database.query_table_schema(self.table_path, self._relevant_columns) + logger.debug(f"[{self.database.name}] Schema = 
{schema}") schema_inst: Schema if self.case_sensitive: @@ -121,9 +114,9 @@ def with_schema(self) -> "TableSegment": def _make_key_range(self): if self.min_key is not None: - yield Compare("<=", str(self.min_key), self._key_column) + yield Compare("<=", Value(self.min_key), self._key_column) if self.max_key is not None: - yield Compare("<", self._key_column, str(self.max_key)) + yield Compare("<", self._key_column, Value(self.max_key)) def _make_update_range(self): if self.min_update is not None: @@ -152,6 +145,11 @@ def get_values(self) -> list: def choose_checkpoints(self, count: int) -> List[DbKey]: "Suggests a bunch of evenly-spaced checkpoints to split by (not including start, end)" assert self.is_bounded + if isinstance(self.min_key, ArithUUID): + checkpoints = split_space(self.min_key.int, self.max_key.int, count) + assert isinstance(self.max_key, ArithUUID) + return [ArithUUID(int=i) for i in checkpoints] + return split_space(self.min_key, self.max_key, count) def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]: @@ -297,9 +295,13 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: key_ranges = self._threaded_call("query_key_range", [table1, table2]) mins, maxs = zip(*key_ranges) + key_type = table1._schema[table1.key_column] + key_type2 = table2._schema[table2.key_column] + assert key_type.python_type is key_type2.python_type + # We add 1 because our ranges are exclusive of the end (like in Python) - min_key = min(map(int, mins)) - max_key = max(map(int, maxs)) + 1 + min_key = min(map(key_type.python_type, mins)) + max_key = max(map(key_type.python_type, maxs)) + 1 table1 = table1.new(min_key=min_key, max_key=max_key) table2 = table2.new(min_key=min_key, max_key=max_key) @@ -324,7 +326,7 @@ def _validate_and_adjust_columns(self, table1, table2): col2 = table2._schema[c] if isinstance(col1, PrecisionType): if not isinstance(col2, PrecisionType): - raise TypeError(f"Incompatible types for column {c}: {col1} <-> {col2}") + 
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}") lowest = min(col1, col2, key=attrgetter("precision")) @@ -336,7 +338,7 @@ def _validate_and_adjust_columns(self, table1, table2): elif isinstance(col1, NumericType): if not isinstance(col2, NumericType): - raise TypeError(f"Incompatible types for column {c}: {col1} <-> {col2}") + raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}") lowest = min(col1, col2, key=attrgetter("precision")) @@ -346,12 +348,16 @@ def _validate_and_adjust_columns(self, table1, table2): table1._schema[c] = col1.replace(precision=lowest.precision) table2._schema[c] = col2.replace(precision=lowest.precision) + elif isinstance(col1, StringType): + if not isinstance(col2, StringType): + raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}") + for t in [table1, table2]: for c in t._relevant_columns: ctype = t._schema[c] - if isinstance(ctype, UnknownColType): - logger.warn( - f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. " + if not ctype.supported: + logger.warning( + f"[{t.database.name}] Column '{c}' of type '{ctype}' has no compatibility handling. " "If encoding/formatting differs between databases, it may result in false positives." 
) diff --git a/data_diff/sql.py b/data_diff/sql.py index eb5a6d66..fb6c92a6 100644 --- a/data_diff/sql.py +++ b/data_diff/sql.py @@ -6,7 +6,7 @@ from runtype import dataclass -from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime +from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime, ArithUUID class Sql: @@ -65,6 +65,8 @@ def compile(self, c: Compiler): return "b'%s'" % self.value.decode() elif isinstance(self.value, str): return "'%s'" % self.value + elif isinstance(self.value, ArithUUID): + return "'%s'" % self.value return str(self.value) @@ -75,6 +77,7 @@ class Select(Sql): where: Sequence[SqlOrStr] = None order_by: Sequence[SqlOrStr] = None group_by: Sequence[SqlOrStr] = None + limit: int = None def compile(self, parent_c: Compiler): c = parent_c.replace(in_select=True) @@ -93,6 +96,9 @@ def compile(self, parent_c: Compiler): if self.order_by: select += " ORDER BY " + ", ".join(map(c.compile, self.order_by)) + if self.limit is not None: + select += " " + c.database.offset_limit(0, self.limit) + if parent_c.in_select: select = "(%s)" % select return select diff --git a/data_diff/utils.py b/data_diff/utils.py new file mode 100644 index 00000000..ec6c6eea --- /dev/null +++ b/data_diff/utils.py @@ -0,0 +1,40 @@ +from typing import Sequence, Optional, Tuple, Union, Dict, Any +from uuid import UUID + + +def safezip(*args): + "zip but makes sure all sequences are the same length" + assert len(set(map(len, args))) == 1 + return zip(*args) + + +def split_space(start, end, count): + size = end - start + return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1] + + +class ArithUUID(UUID): + "A UUID that supports basic arithmetic (add, sub)" + + def __add__(self, other: Union[UUID, int]): + if isinstance(other, int): + return type(self)(int=self.int + other) + return NotImplemented + + def __sub__(self, other: Union[UUID, int]): + if isinstance(other, int): + return type(self)(int=self.int - other) + elif 
isinstance(other, UUID): + return self.int - other.int + return NotImplemented + + def __int__(self): + return self.int + + +def is_uuid(u): + try: + UUID(u) + except ValueError: + return False + return True diff --git a/tests/test_database_types.py b/tests/test_database_types.py index a07dd675..29c814e2 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -3,6 +3,7 @@ import time import re import math +import uuid from datetime import datetime, timedelta from decimal import Decimal from parameterized import parameterized @@ -158,10 +159,22 @@ def __next__(self) -> float: raise StopIteration +class UUID_Faker: + def __init__(self, max): + self.max = max + + def __len__(self): + return self.max + + def __iter__(self): + return (uuid.uuid1(i) for i in range(self.max)) + + TYPE_SAMPLES = { "int": IntFaker(N_SAMPLES), "datetime_no_timezone": DateTimeFaker(N_SAMPLES), "float": FloatFaker(N_SAMPLES), + "uuid": UUID_Faker(N_SAMPLES), } DATABASE_TYPES = { @@ -185,6 +198,11 @@ def __next__(self) -> float: "double precision", "numeric(6,3)", ], + "uuid": [ + "text", + "varchar(100)", + "char(100)", + ], }, db.MySQL: { # https://dev.mysql.com/doc/refman/8.0/en/integer-types.html @@ -210,6 +228,10 @@ def __next__(self) -> float: "numeric", "numeric(65, 10)", ], + "uuid": [ + "varchar(100)", + "char(100)", + ], }, db.BigQuery: { "int": ["int"], @@ -222,6 +244,9 @@ def __next__(self) -> float: "float64", "bignumeric", ], + "uuid": [ + "STRING", + ], }, db.Snowflake: { # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint @@ -246,6 +271,10 @@ def __next__(self) -> float: "float", "numeric", ], + "uuid": [ + "varchar", + "varchar(100)", + ], }, db.Redshift: { "int": [ @@ -260,6 +289,11 @@ def __next__(self) -> float: "float8", "numeric", ], + "uuid": [ + "text", + "varchar(100)", + "char(100)", + ], }, db.Oracle: { "int": [ @@ -273,6 +307,14 @@ def __next__(self) -> float: "float": [ "float", 
"numeric", + "real", + "double precision", + ], + "uuid": [ + "CHAR(100)", + "VARCHAR(100)", + "NCHAR(100)", + "NVARCHAR2(100)", ], }, db.Presto: { @@ -293,6 +335,10 @@ def __next__(self) -> float: "decimal(10,2)", "decimal(30,6)", ], + "uuid": [ + "varchar", + "char(100)", + ], }, } @@ -364,8 +410,10 @@ def _insert_to_table(conn, table, values): for j, sample in values: if isinstance(sample, (float, Decimal, int)): value = str(sample) - else: + elif isinstance(sample, datetime): value = f"timestamp '{sample}'" + else: + value = f"'{sample}'" selects.append(f"SELECT {j}, {value} FROM dual") insertion_query += " UNION ALL ".join(selects) else: @@ -394,7 +442,8 @@ def _drop_table_if_exists(conn, table): conn.query(f"DROP TABLE {table}", None) else: conn.query(f"DROP TABLE IF EXISTS {table}", None) - conn.query("COMMIT", None) + if not isinstance(conn, db.BigQuery): + conn.query("COMMIT", None) class TestDiffCrossDatabaseTables(unittest.TestCase): diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index a5eefcd9..c3f25679 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -1,5 +1,6 @@ import datetime import unittest +import uuid import preql import arrow # comes with preql @@ -239,6 +240,46 @@ def test_diff_sorted_by_key(self): self.assertEqual(expected, diff) +class TestStringKeys(TestWithConnection): + def setUp(self): + super().setUp() + + queries = [ + "DROP TABLE IF EXISTS a", + "DROP TABLE IF EXISTS b", + "CREATE TABLE a(id varchar(100), comment varchar(1000))", + "COMMIT", + ] + for i in range(100): + queries.append(f"INSERT INTO a VALUES ('{uuid.uuid1(i)}', '{i}')") + + queries += [ + "COMMIT", + "CREATE TABLE b AS SELECT * FROM a", + "COMMIT", + ] + + self.new_uuid = uuid.uuid1(32132131) + queries.append(f"INSERT INTO a VALUES ('{self.new_uuid}', 'This one is different')") + + # TODO test unexpected values? 
+ + for query in queries: + self.connection.query(query, None) + + self.a = TableSegment(self.connection, ("a",), "id", "comment") + self.b = TableSegment(self.connection, ("b",), "id", "comment") + + def test_string_keys(self): + differ = TableDiffer() + diff = list(differ.diff_tables(self.a, self.b)) + self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) + + self.connection.query(f"INSERT INTO a VALUES ('unexpected', '<-- this bad value should not break us')", None) + + self.assertRaises(ValueError, differ.diff_tables, self.a, self.b) + + class TestTableSegment(TestWithConnection): def setUp(self) -> None: super().setUp()