diff --git a/data_diff/database.py b/data_diff/database.py index fa582e05..302b1d5d 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -63,6 +63,9 @@ def import_presto(): class ConnectError(Exception): pass +class QueryError(Exception): + pass + def _one(seq): (x,) = seq @@ -156,9 +159,8 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str: - Dates are expected in the format: "YYYY-MM-DD HH:mm:SS.FFFFFF" - (number of F depends on coltype.precision) - Or if precision=0 then - "YYYY-MM-DD HH:mm:SS" (without the dot) + + Rounded up/down according to coltype.rounds """ ... @@ -474,18 +476,26 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: class Oracle(ThreadedDatabase): + ROUNDS_ON_PREC_LOSS = True + def __init__(self, host, port, user, password, *, database, thread_count, **kw): assert not port self.kwargs = dict(user=user, password=password, dsn="%s/%s" % (host, database), **kw) super().__init__(thread_count=thread_count) def create_connection(self): - oracle = import_oracle() + self._oracle = import_oracle() try: - return oracle.connect(**self.kwargs) + return self._oracle.connect(**self.kwargs) except Exception as e: raise ConnectError(*e.args) from e + def _query(self, sql_code: str): + try: + return super()._query(sql_code) + except self._oracle.DatabaseError as e: + raise QueryError(e) + def md5_to_int(self, s: str) -> str: # standard_hash is faster than DBMS_CRYPTO.Hash # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? @@ -509,9 +519,7 @@ def select_table_schema(self, path: DbPath) -> str: def normalize_value_by_type(self, value: str, coltype: ColType) -> str: if isinstance(coltype, PrecisionType): - if coltype.precision == 0: - return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS')" - return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision or ''}')" + return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" return self.to_string(f"{value}") def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: @@ -524,7 +532,9 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr m = re.match(regexp + "$", type_repr) if m: datetime_precision = int(m.group(1)) - return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION) + return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, + rounds=self.ROUNDS_ON_PREC_LOSS + ) return UnknownColType(type_repr) @@ -533,6 +543,25 @@ class Redshift(Postgres): def md5_to_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: + if isinstance(coltype, TemporalType): + if coltype.rounds: + timestamp = f"{value}::timestamp(6)" + # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. + secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" + # Get the milliseconds from timestamp. + ms = f"extract(ms from {timestamp})" + # Get the microseconds from timestamp, without the milliseconds! + us = f"extract(us from {timestamp})" + # epoch = Total time since epoch in microseconds. + epoch = f"{secs}*1000000 + {ms}*1000 + {us}" + timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" + else: + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + + return self.to_string(f"{value}") + class MsSQL(ThreadedDatabase): "AKA sql-server" diff --git a/tests/test_database_types.py b/tests/test_database_types.py index f9ce1648..c9c9d895 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -1,10 +1,10 @@ +from contextlib import suppress import unittest -import preql import time from data_diff import database as db -from data_diff.diff_tables import TableDiffer, TableSegment, split_space +from data_diff.diff_tables import TableDiffer, TableSegment from parameterized import parameterized, parameterized_class -from .common import CONN_STRINGS, str_to_checksum +from .common import CONN_STRINGS import logging logging.getLogger("diff_tables").setLevel(logging.WARN) @@ -18,11 +18,11 @@ "int": [127, -3, -9, 37, 15, 127], "datetime_no_timezone": [ "2020-01-01 15:10:10", - "2020-01-01 9:9:9", - "2022-01-01 15:10:01.139", - "2022-01-01 15:10:02.020409", - "2022-01-01 15:10:03.003030", - "2022-01-01 15:10:05.009900", + "2020-02-01 9:9:9", + "2022-03-01 15:10:01.139", + "2022-04-01 15:10:02.020409", + "2022-05-01 15:10:03.003030", + "2022-06-01 15:10:05.009900", ], "float": [0.0, 0.1, 0.10, 10.0, 100.98], } @@ -101,7 +101,7 @@ # "int", ], "datetime_no_timezone": [ - # "TIMESTAMP", + "TIMESTAMP", ], # https://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html#r_Numeric_types201-floating-point-types "float": [ @@ -115,9 +115,9 @@ # "int", ], "datetime_no_timezone": [ - # "timestamp", - # "timestamp(6)", - # "timestamp(9)", + "timestamp with local time zone", + "timestamp(6) with local time zone", + "timestamp(9) with local time zone", ], "float": [ # "float", @@ -179,14 +179,29 @@ def expand_params(testcase_func, param_num, param): def _insert_to_table(conn, table, values): - insertion_query = f"INSERT INTO {table} (id, col) VALUES " - for j, sample in values: - insertion_query += f"({j}, '{sample}')," - - conn.query(insertion_query[0:-1], None) + insertion_query = f"INSERT INTO {table} (id, col) " + + if isinstance(conn, db.Oracle): + selects = [] + for j, sample in values: + selects.append( f"SELECT {j}, timestamp '{sample}' FROM dual" ) + insertion_query += ' UNION ALL '.join(selects) + else: + insertion_query += ' VALUES ' + for j, sample in values: + insertion_query += f"({j}, '{sample}')," + insertion_query = insertion_query[0:-1] + + conn.query(insertion_query, None) if not isinstance(conn, db.BigQuery): conn.query("COMMIT", None) +def _drop_table_if_exists(conn, table): + with suppress(db.QueryError): + if isinstance(conn, db.Oracle): + conn.query(f"DROP TABLE {table}", None) + else: + conn.query(f"DROP TABLE IF EXISTS {table}", None) class TestDiffCrossDatabaseTables(unittest.TestCase): @parameterized.expand(type_pairs, name_func=expand_params) @@ -204,14 +219,14 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego src_table = src_conn.quote(".".join(src_table_path)) dst_table = dst_conn.quote(".".join(dst_table_path)) - src_conn.query(f"DROP TABLE IF EXISTS {src_table}", None) - src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type});", None) + _drop_table_if_exists(src_conn, src_table) + src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type})", None) _insert_to_table(src_conn, src_table, enumerate(sample_values, 1)) values_in_source = src_conn.query(f"SELECT id, col FROM {src_table}", list) - dst_conn.query(f"DROP TABLE IF EXISTS {dst_table}", None) - dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type});", None) + _drop_table_if_exists(dst_conn, dst_table) + dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None) _insert_to_table(dst_conn, dst_table, values_in_source) self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), quote_columns=False) @@ -235,3 +250,4 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego duration = time.time() - start # print(f"source_db={source_db.__name__} target_db={target_db.__name__} source_type={source_type} target_type={target_type} duration={round(duration * 1000, 2)}ms") +