From 1f17b6f21d3b8d8d8618e2b8919ccafa522189c9 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 7 Nov 2022 16:26:55 -0300 Subject: [PATCH] Added support for Boolean types --- data_diff/databases/base.py | 4 ++ data_diff/databases/database_types.py | 10 +++++ data_diff/databases/mysql.py | 3 ++ data_diff/databases/postgresql.py | 6 +++ data_diff/databases/presto.py | 6 +++ data_diff/databases/snowflake.py | 7 ++- data_diff/databases/vertica.py | 6 +++ tests/test_database_types.py | 65 +++++++++++++++++++++++++-- 8 files changed, 102 insertions(+), 5 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 672c4e0b..f31b8f8e 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -30,6 +30,7 @@ Text, DbTime, DbPath, + Boolean, ) logger = logging.getLogger("database") @@ -188,6 +189,9 @@ def parse_type( elif issubclass(cls, Integer): return cls() + elif issubclass(cls, Boolean): + return cls() + elif issubclass(cls, Decimal): if numeric_scale is None: numeric_scale = 0 # Needed for Oracle. diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 296ad475..8bd237be 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -26,6 +26,9 @@ class PrecisionType(ColType): rounds: bool +class Boolean(ColType): + supported = True + class TemporalType(PrecisionType): pass @@ -250,6 +253,11 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: """ ... + def normalize_boolean(self, value: str, coltype: Boolean) -> str: + """Creates an SQL expression, that converts 'value' to either '0' or '1'. + """ + return self.to_string(value) + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: """Creates an SQL expression, that converts 'value' to a normalized representation. 
@@ -272,6 +280,8 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.normalize_number(value, coltype) elif isinstance(coltype, ColType_UUID): return self.normalize_uuid(value, coltype) + elif isinstance(coltype, Boolean): + return self.normalize_boolean(value, coltype) return self.to_string(value) diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 8f8e1730..1f4058dd 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -8,6 +8,7 @@ TemporalType, FractionalType, ColType_UUID, + Boolean, ) from .base import ThreadedDatabase, import_helper, ConnectError, BaseDialect from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS @@ -39,6 +40,8 @@ class Dialect(BaseDialect): "char": Text, "varbinary": Text, "binary": Text, + # Boolean + "boolean": Boolean, } def quote(self, s: str): diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 27df1273..0b31172a 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -8,6 +8,7 @@ Native_UUID, Text, FractionalType, + Boolean, ) from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS @@ -48,6 +49,8 @@ class PostgresqlDialect(BaseDialect): "text": Text, # UUID "uuid": Native_UUID, + # Boolean + "boolean": Boolean, } def quote(self, s: str): @@ -71,6 +74,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"{value}::decimal(38, {coltype.precision})") + def normalize_boolean(self, value: str, coltype: Boolean) -> str: + return self.to_string(f"{value}::int") + def _convert_db_precision_to_digits(self, p: int) -> int: # Subtracting 2 due to wierd precision issues in PostgreSQL return super()._convert_db_precision_to_digits(p) - 2 
diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index de54f5b5..51a47b81 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -16,6 +16,7 @@ ColType, ColType_UUID, TemporalType, + Boolean, ) from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter from .base import ( @@ -56,6 +57,8 @@ class Dialect(BaseDialect): "double": Float, # Text "varchar": Text, + # Boolean + "boolean": Boolean, } def explain_as_text(self, query: str) -> str: @@ -95,6 +98,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") + def normalize_boolean(self, value: str, coltype: Boolean) -> str: + return self.to_string(f"cast ({value} as int)") + def parse_type( self, table_path: DbPath, diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 5ab5705b..7b016d8d 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,7 +1,7 @@ from typing import Union, List import logging -from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath +from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath, Boolean from .base import BaseDialect, ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter @@ -27,6 +27,8 @@ class Dialect(BaseDialect): "FLOAT": Float, # Text "TEXT": Text, + # Boolean + "BOOLEAN": Boolean, } def explain_as_text(self, query: str) -> str: @@ -43,6 +45,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def normalize_number(self, value: str, coltype: FractionalType) -> str: return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + def normalize_boolean(self, value: str, coltype: Boolean) -> str: + return 
self.to_string(f"{value}::int") + def quote(self, s: str): return f'"{s}"' diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 7852800a..d902455b 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -22,6 +22,7 @@ Text, Timestamp, TimestampTZ, + Boolean, ) @@ -47,6 +48,8 @@ class Dialect(BaseDialect): # Text "char": Text, "varchar": Text, + # Boolean + "boolean": Boolean, } def quote(self, s: str): @@ -77,6 +80,9 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type return f"TRIM(CAST({value} AS VARCHAR))" + def normalize_boolean(self, value: str, coltype: Boolean) -> str: + return self.to_string(f"cast ({value} as int)") + def is_distinct_from(self, a: str, b: str) -> str: return f"not ({a} <=> {b})" diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 250b4537..848349e3 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -70,6 +70,9 @@ def init_conns(): "varchar(100)", "char(100)", ], + "boolean": [ + "boolean", + ], }, db.MySQL: { # https://dev.mysql.com/doc/refman/8.0/en/integer-types.html @@ -100,6 +103,9 @@ def init_conns(): "char(100)", "varbinary(100)", ], + "boolean": [ + "boolean", + ], }, db.BigQuery: { "int": ["int"], @@ -115,6 +121,9 @@ def init_conns(): "uuid": [ "STRING", ], + "boolean": [ + "boolean", + ], }, db.Snowflake: { # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint @@ -144,6 +153,9 @@ def init_conns(): "varchar", "varchar(100)", ], + "boolean": [ + "boolean", + ], }, db.Redshift: { "int": [ @@ -164,6 +176,9 @@ def init_conns(): "varchar(100)", "char(100)", ], + "boolean": [ + "boolean", + ], }, db.Oracle: { "int": [ @@ -187,6 +202,8 @@ def init_conns(): "NCHAR(100)", "NVARCHAR2(100)", ], + "boolean": [ # Oracle has no boolean type + ], }, db.Presto: { "int": [ @@ -210,6 +227,9 @@ def init_conns(): "varchar", 
"char(100)", ], + "boolean": [ + "boolean", + ], }, db.Databricks: { # https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/int-type.html @@ -233,6 +253,9 @@ def init_conns(): "uuid": [ "STRING", ], + "boolean": [ + "boolean", + ], }, db.Trino: { "int": [ @@ -253,6 +276,9 @@ def init_conns(): "varchar", "char(100)", ], + "boolean": [ + "boolean", + ], }, db.Clickhouse: { "int": [ @@ -277,6 +303,9 @@ def init_conns(): "uuid": [ "String", ], + "boolean": [ + "boolean", + ], }, db.Vertica: { "int": ["int"], @@ -295,6 +324,9 @@ def init_conns(): "varchar(100)", "char(100)", ], + "boolean": [ + "boolean", + ], }, } @@ -368,6 +400,18 @@ def __iter__(self): def __len__(self): return self.max +class BooleanFaker: + MANUAL_FAKES = [False, True, True, False] + + def __init__(self, max): + self.max = max + + def __iter__(self): + return iter(self.MANUAL_FAKES[:self.max]) + + def __len__(self): + return min(self.max, len(self.MANUAL_FAKES)) + class FloatFaker: MANUAL_FAKES = [ @@ -417,6 +461,7 @@ def __iter__(self): "datetime": DateTimeFaker(N_SAMPLES), "float": FloatFaker(N_SAMPLES), "uuid": UUID_Faker(N_SAMPLES), + "boolean": BooleanFaker(N_SAMPLES), } @@ -529,6 +574,9 @@ def _insert_to_table(conn, table, values, type): if isinstance(sample, bytearray): value = f"'{sample.decode()}'" + elif type == 'boolean': + value = str(bool(sample)) + elif isinstance(conn, db.Clickhouse): if type.startswith("DateTime64"): value = f"'{sample.replace(tzinfo=None)}'" @@ -567,10 +615,19 @@ def _insert_to_table(conn, table, values, type): insertion_query += " UNION ALL ".join(selects) conn.query(insertion_query, None) selects = [] + insertion_query = default_insertion_query else: conn.query(insertion_query[0:-1], None) insertion_query = default_insertion_query + if insertion_query != default_insertion_query: + # Very bad, but this whole function needs to go + if isinstance(conn, db.Oracle): + insertion_query += " UNION ALL ".join(selects) + conn.query(insertion_query, 
None) + else: + conn.query(insertion_query[0:-1], None) + if not isinstance(conn, (db.BigQuery, db.Databricks, db.Clickhouse)): conn.query("COMMIT", None) @@ -686,16 +743,16 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego self.table2 = TableSegment(self.dst_conn, dst_table_path, ("id",), None, ("col",), case_sensitive=False) start = time.monotonic() - self.assertEqual(N_SAMPLES, self.table.count()) + self.assertEqual(len(sample_values), self.table.count()) count_source_duration = time.monotonic() - start start = time.monotonic() - self.assertEqual(N_SAMPLES, self.table2.count()) + self.assertEqual(len(sample_values), self.table2.count()) count_target_duration = time.monotonic() - start # When testing, we configure these to their lowest possible values for # the DEFAULT_N_SAMPLES. - # When benchmarking, we try to dynamically create some more optimal + # When benchmarking, we try to dynamically create some more optimal # configuration with each segment being ~250k rows. ch_factor = min(max(int(N_SAMPLES / 250_000), 2), 128) if BENCHMARK else 2 ch_threshold = min(DEFAULT_BISECTION_THRESHOLD, int(N_SAMPLES / ch_factor)) if BENCHMARK else 3 @@ -710,7 +767,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego checksum_duration = time.monotonic() - start expected = [] self.assertEqual(expected, diff) - self.assertEqual(0, differ.stats.get("rows_downloaded", 0)) + self.assertEqual(0, differ.stats.get("rows_downloaded", 0)) # This may fail if the hash is different, but downloaded values are equal # This section downloads all rows to ensure that Python agrees with the # database, in terms of comparison.