This repository was archived by the owner on May 17, 2024. It is now read-only.

Added support for Boolean types #282

Merged
merged 1 commit into from Nov 8, 2022
4 changes: 4 additions & 0 deletions data_diff/databases/base.py
@@ -30,6 +30,7 @@
Text,
DbTime,
DbPath,
Boolean,
)

logger = logging.getLogger("database")
@@ -188,6 +189,9 @@ def parse_type(
elif issubclass(cls, Integer):
return cls()

elif issubclass(cls, Boolean):
return cls()

elif issubclass(cls, Decimal):
if numeric_scale is None:
numeric_scale = 0 # Needed for Oracle.
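For readers skimming the diff: a minimal, self-contained sketch of what the new `issubclass(cls, Boolean)` branch in `parse_type` does. The classes and the `parse_type_sketch` helper below are simplified stand-ins, not the project's actual code; the point is that the dialect looks the raw type name up in its `TYPE_CLASSES` map and, if it resolves to `Boolean`, instantiates it with no precision or scale arguments.

```python
# Simplified stand-ins for data_diff's ColType hierarchy -- illustration only.
class ColType: ...

class Integer(ColType): ...

class Boolean(ColType):
    supported = True

# Each dialect diff below registers "boolean" in a map like this one.
TYPE_CLASSES = {"int": Integer, "boolean": Boolean}

def parse_type_sketch(type_repr: str) -> ColType:
    """Hypothetical helper mimicking the relevant branches of parse_type."""
    cls = TYPE_CLASSES.get(type_repr)
    if cls is None:
        raise ValueError(f"unsupported type: {type_repr!r}")
    if issubclass(cls, Integer):
        return cls()
    elif issubclass(cls, Boolean):   # the branch added by this PR
        return cls()                 # booleans need no precision/scale
    raise NotImplementedError(type_repr)

print(type(parse_type_sketch("boolean")).__name__)  # Boolean
```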
10 changes: 10 additions & 0 deletions data_diff/databases/database_types.py
@@ -26,6 +26,9 @@ class PrecisionType(ColType):
rounds: bool


class Boolean(ColType):
supported = True

class TemporalType(PrecisionType):
pass

@@ -250,6 +253,11 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
"""
...

def normalize_boolean(self, value: str, coltype: Boolean) -> str:
"""Creates an SQL expression, that converts 'value' to either '0' or '1'.
"""
return self.to_string(value)

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
"""Creates an SQL expression, that converts 'value' to a normalized representation.

@@ -272,6 +280,8 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
return self.normalize_number(value, coltype)
elif isinstance(coltype, ColType_UUID):
return self.normalize_uuid(value, coltype)
elif isinstance(coltype, Boolean):
return self.normalize_boolean(value, coltype)
return self.to_string(value)


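A rough sketch of the dispatch path added above: `normalize_value_by_type` now routes `Boolean` columns through `normalize_boolean`, whose default implementation just hands the value to the dialect's `to_string`. The `MiniDialect` class and its `to_string` body are illustrative assumptions, not the project's real mixin.

```python
class ColType: ...

class Boolean(ColType):
    supported = True

class MiniDialect:
    """Toy stand-in for the abstract dialect mixin shown above."""

    def to_string(self, value: str) -> str:
        # Placeholder; real dialects emit something like f"{value}::varchar".
        return f"cast({value} as varchar)"

    def normalize_boolean(self, value: str, coltype: Boolean) -> str:
        # Default behaviour from this PR: fall back to the string cast.
        return self.to_string(value)

    def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
        if isinstance(coltype, Boolean):              # new branch
            return self.normalize_boolean(value, coltype)
        return self.to_string(value)

print(MiniDialect().normalize_value_by_type("is_active", Boolean()))
# cast(is_active as varchar)
```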
3 changes: 3 additions & 0 deletions data_diff/databases/mysql.py
@@ -8,6 +8,7 @@
TemporalType,
FractionalType,
ColType_UUID,
Boolean,
)
from .base import ThreadedDatabase, import_helper, ConnectError, BaseDialect
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS
@@ -39,6 +40,8 @@ class Dialect(BaseDialect):
"char": Text,
"varbinary": Text,
"binary": Text,
# Boolean
"boolean": Boolean,
}

def quote(self, s: str):
6 changes: 6 additions & 0 deletions data_diff/databases/postgresql.py
@@ -8,6 +8,7 @@
Native_UUID,
Text,
FractionalType,
Boolean,
)
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS
@@ -48,6 +49,8 @@ class PostgresqlDialect(BaseDialect):
"text": Text,
# UUID
"uuid": Native_UUID,
# Boolean
"boolean": Boolean,
}

def quote(self, s: str):
@@ -71,6 +74,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
def normalize_number(self, value: str, coltype: FractionalType) -> str:
return self.to_string(f"{value}::decimal(38, {coltype.precision})")

def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"{value}::int")

def _convert_db_precision_to_digits(self, p: int) -> int:
# Subtracting 2 due to weird precision issues in PostgreSQL
return super()._convert_db_precision_to_digits(p) - 2
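To make the PostgreSQL override concrete: assuming the dialect's `to_string` wraps its argument in a `::varchar` cast (that detail is not shown in this diff), the normalized expression for a boolean column would look roughly like this.

```python
# Hypothetical column expression; the ::varchar wrapper is an assumption about
# PostgresqlDialect.to_string, which this diff does not show.
value = '"is_active"'
normalized = f"{value}::int"            # body of normalize_boolean above
wrapped = f"{normalized}::varchar"      # assumed to_string() behaviour
print(wrapped)  # "is_active"::int::varchar -> compares as '0' or '1' in checksums
```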
6 changes: 6 additions & 0 deletions data_diff/databases/presto.py
@@ -16,6 +16,7 @@
ColType,
ColType_UUID,
TemporalType,
Boolean,
)
from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter
from .base import (
@@ -56,6 +57,8 @@ class Dialect(BaseDialect):
"double": Float,
# Text
"varchar": Text,
# Boolean
"boolean": Boolean,
}

def explain_as_text(self, query: str) -> str:
@@ -95,6 +98,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
def normalize_number(self, value: str, coltype: FractionalType) -> str:
return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")

def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"cast ({value} as int)")

def parse_type(
self,
table_path: DbPath,
7 changes: 6 additions & 1 deletion data_diff/databases/snowflake.py
@@ -1,7 +1,7 @@
from typing import Union, List
import logging

from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath
from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath, Boolean
from .base import BaseDialect, ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter


@@ -27,6 +27,8 @@ class Dialect(BaseDialect):
"FLOAT": Float,
# Text
"TEXT": Text,
# Boolean
"BOOLEAN": Boolean,
}

def explain_as_text(self, query: str) -> str:
@@ -43,6 +45,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
def normalize_number(self, value: str, coltype: FractionalType) -> str:
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")

def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"{value}::int")

def quote(self, s: str):
return f'"{s}"'

6 changes: 6 additions & 0 deletions data_diff/databases/vertica.py
@@ -22,6 +22,7 @@
Text,
Timestamp,
TimestampTZ,
Boolean,
)


@@ -47,6 +48,8 @@ class Dialect(BaseDialect):
# Text
"char": Text,
"varchar": Text,
# Boolean
"boolean": Boolean,
}

def quote(self, s: str):
@@ -77,6 +80,9 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
# Trim doesn't work on CHAR type
return f"TRIM(CAST({value} AS VARCHAR))"

def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"cast ({value} as int)")

def is_distinct_from(self, a: str, b: str) -> str:
return f"not ({a} <=> {b})"

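Taken together, each dialect's boolean normalization reduces to a small cast before the usual `to_string` wrapping. A side-by-side sketch (the column name is a placeholder; the MySQL row reflects the base-class fallback, since its Dialect adds only a type mapping and BOOLEAN is an alias for TINYINT(1)):

```python
value = "is_active"  # placeholder column expression

# SQL fragments produced by the normalize_boolean overrides in this PR,
# before each dialect's own to_string() wrapping is applied.
per_dialect = {
    "postgresql": f"{value}::int",
    "snowflake": f"{value}::int",
    "presto": f"cast ({value} as int)",
    "vertica": f"cast ({value} as int)",
    "mysql": value,  # no override: values are already stored as 0/1
}

for dialect, sql in per_dialect.items():
    print(f"{dialect:10s} -> {sql}")
```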
65 changes: 61 additions & 4 deletions tests/test_database_types.py
@@ -70,6 +70,9 @@ def init_conns():
"varchar(100)",
"char(100)",
],
"boolean": [
"boolean",
],
},
db.MySQL: {
# https://dev.mysql.com/doc/refman/8.0/en/integer-types.html
Expand Down Expand Up @@ -100,6 +103,9 @@ def init_conns():
"char(100)",
"varbinary(100)",
],
"boolean": [
"boolean",
],
},
db.BigQuery: {
"int": ["int"],
@@ -115,6 +121,9 @@ def init_conns():
"uuid": [
"STRING",
],
"boolean": [
"boolean",
],
},
db.Snowflake: {
# https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint
@@ -144,6 +153,9 @@ def init_conns():
"varchar",
"varchar(100)",
],
"boolean": [
"boolean",
],
},
db.Redshift: {
"int": [
@@ -164,6 +176,9 @@ def init_conns():
"varchar(100)",
"char(100)",
],
"boolean": [
"boolean",
],
},
db.Oracle: {
"int": [
@@ -187,6 +202,8 @@ def init_conns():
"NCHAR(100)",
"NVARCHAR2(100)",
],
"boolean": [ # Oracle has no boolean type
],
},
db.Presto: {
"int": [
@@ -210,6 +227,9 @@ def init_conns():
"varchar",
"char(100)",
],
"boolean": [
"boolean",
],
},
db.Databricks: {
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/int-type.html
@@ -233,6 +253,9 @@ def init_conns():
"uuid": [
"STRING",
],
"boolean": [
"boolean",
],
},
db.Trino: {
"int": [
@@ -253,6 +276,9 @@ def init_conns():
"varchar",
"char(100)",
],
"boolean": [
"boolean",
],
},
db.Clickhouse: {
"int": [
@@ -277,6 +303,9 @@ def init_conns():
"uuid": [
"String",
],
"boolean": [
"boolean",
],
},
db.Vertica: {
"int": ["int"],
@@ -295,6 +324,9 @@ def init_conns():
"varchar(100)",
"char(100)",
],
"boolean": [
"boolean",
],
},
}

@@ -368,6 +400,18 @@ def __iter__(self):
def __len__(self):
return self.max

class BooleanFaker:
MANUAL_FAKES = [False, True, True, False]

def __init__(self, max):
self.max = max

def __iter__(self):
return iter(self.MANUAL_FAKES[:self.max])

def __len__(self):
return min(self.max, len(self.MANUAL_FAKES))


class FloatFaker:
MANUAL_FAKES = [
@@ -417,6 +461,7 @@ def __iter__(self):
"datetime": DateTimeFaker(N_SAMPLES),
"float": FloatFaker(N_SAMPLES),
"uuid": UUID_Faker(N_SAMPLES),
"boolean": BooleanFaker(N_SAMPLES),
}


@@ -529,6 +574,9 @@ def _insert_to_table(conn, table, values, type):
if isinstance(sample, bytearray):
value = f"'{sample.decode()}'"

elif type == 'boolean':
value = str(bool(sample))

elif isinstance(conn, db.Clickhouse):
if type.startswith("DateTime64"):
value = f"'{sample.replace(tzinfo=None)}'"
@@ -567,10 +615,19 @@ def _insert_to_table(conn, table, values, type):
insertion_query += " UNION ALL ".join(selects)
conn.query(insertion_query, None)
selects = []
insertion_query = default_insertion_query
else:
conn.query(insertion_query[0:-1], None)
insertion_query = default_insertion_query

if insertion_query != default_insertion_query:
# Very bad, but this whole function needs to go
if isinstance(conn, db.Oracle):
insertion_query += " UNION ALL ".join(selects)
conn.query(insertion_query, None)
else:
conn.query(insertion_query[0:-1], None)

if not isinstance(conn, (db.BigQuery, db.Databricks, db.Clickhouse)):
conn.query("COMMIT", None)

@@ -686,16 +743,16 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
self.table2 = TableSegment(self.dst_conn, dst_table_path, ("id",), None, ("col",), case_sensitive=False)

start = time.monotonic()
self.assertEqual(N_SAMPLES, self.table.count())
self.assertEqual(len(sample_values), self.table.count())
count_source_duration = time.monotonic() - start

start = time.monotonic()
self.assertEqual(N_SAMPLES, self.table2.count())
self.assertEqual(len(sample_values), self.table2.count())
count_target_duration = time.monotonic() - start

# When testing, we configure these to their lowest possible values for
# the DEFAULT_N_SAMPLES.
# When benchmarking, we try to dynamically create some more optimal
# When benchmarking, we try to dynamically create some more optima
# configuration with each segment being ~250k rows.
ch_factor = min(max(int(N_SAMPLES / 250_000), 2), 128) if BENCHMARK else 2
ch_threshold = min(DEFAULT_BISECTION_THRESHOLD, int(N_SAMPLES / ch_factor)) if BENCHMARK else 3
@@ -710,7 +767,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
checksum_duration = time.monotonic() - start
expected = []
self.assertEqual(expected, diff)
self.assertEqual(0, differ.stats.get("rows_downloaded", 0))
self.assertEqual(0, differ.stats.get("rows_downloaded", 0)) # This may fail if the hash is different, but downloaded values are equal

# This section downloads all rows to ensure that Python agrees with the
# database, in terms of comparison.
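Finally, a small sketch of why the test assertions switched from `N_SAMPLES` to `len(sample_values)`: `BooleanFaker` caps itself at four hand-picked values, so boolean columns insert fewer rows than the other fakers would. The `N_SAMPLES` value below is illustrative; the real suite configures it elsewhere.

```python
class BooleanFaker:
    MANUAL_FAKES = [False, True, True, False]

    def __init__(self, max):
        self.max = max

    def __iter__(self):
        # Yields at most four samples, regardless of the requested max.
        return iter(self.MANUAL_FAKES[: self.max])

    def __len__(self):
        return min(self.max, len(self.MANUAL_FAKES))

N_SAMPLES = 50  # illustrative only
sample_values = list(BooleanFaker(N_SAMPLES))
print(len(sample_values))  # 4, not 50 -- hence assertEqual(len(sample_values), ...)
```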