diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 1570fb3c..22691d40 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -40,9 +40,11 @@ class NumericType(ColType): # 'precision' signifies how many fractional digits (after the dot) we want to compare precision: int + class FractionalType(NumericType): pass + class Float(FractionalType): pass diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 49c03fd9..b9526e8d 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -96,7 +96,7 @@ def _update_column(self): def _quote_column(self, c): if self._schema: - c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive. + c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive. return self.database.quote(c) def with_schema(self) -> "TableSegment": diff --git a/tests/test_api.py b/tests/test_api.py index e9e56f67..46658479 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,14 +8,8 @@ class TestApi(unittest.TestCase): - @classmethod - def setUpClass(cls): - # Avoid leaking connections that require waiting for the GC, which can - # cause deadlocks for table-level modifications. - cls.preql = preql.Preql(TEST_MYSQL_CONN_STRING) - def setUp(self) -> None: - # self.preql = preql.Preql(TEST_MYSQL_CONN_STRING) + self.preql = preql.Preql(TEST_MYSQL_CONN_STRING) self.preql( r""" table test_api { diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 89832ed6..09ea2484 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -3,7 +3,7 @@ import time import re import math -import datetime +from datetime import datetime, timedelta from decimal import Decimal from parameterized import parameterized @@ -52,12 +52,12 @@ def __next__(self) -> str: class DateTimeFaker: MANUAL_FAKES = [ - datetime.datetime.fromisoformat("2020-01-01 15:10:10"), - datetime.datetime.fromisoformat("2020-02-01 09:09:09"), - datetime.datetime.fromisoformat("2022-03-01 15:10:01.139"), - datetime.datetime.fromisoformat("2022-04-01 15:10:02.020409"), - datetime.datetime.fromisoformat("2022-05-01 15:10:03.003030"), - datetime.datetime.fromisoformat("2022-06-01 15:10:05.009900"), + datetime.fromisoformat("2020-01-01 15:10:10"), + datetime.fromisoformat("2020-02-01 09:09:09"), + datetime.fromisoformat("2022-03-01 15:10:01.139"), + datetime.fromisoformat("2022-04-01 15:10:02.020409"), + datetime.fromisoformat("2022-05-01 15:10:03.003030"), + datetime.fromisoformat("2022-06-01 15:10:05.009900"), ] def __init__(self, max): @@ -65,20 +65,20 @@ def __init__(self, max): def __iter__(self): iter = DateTimeFaker(self.max) - iter.prev = datetime.datetime(2000, 1, 1, 0, 0, 0, 0) + iter.prev = datetime(2000, 1, 1, 0, 0, 0, 0) iter.i = 0 return iter def __len__(self): return self.max - def __next__(self) -> datetime.datetime: + def __next__(self) -> datetime: if self.i < len(self.MANUAL_FAKES): fake = self.MANUAL_FAKES[self.i] self.i += 1 return fake elif self.i < self.max: - self.prev = self.prev + datetime.timedelta(seconds=3, microseconds=571) + self.prev = self.prev + timedelta(seconds=3, microseconds=571) self.i += 1 return self.prev else: @@ -373,7 +373,7 @@ def _insert_to_table(conn, table, values): for j, sample in values: if isinstance(sample, (float, Decimal, int)): value = str(sample) - elif isinstance(sample, datetime.datetime) and isinstance(conn, db.Presto): + elif isinstance(sample, datetime) and isinstance(conn, db.Presto): value = f"timestamp '{sample}'" else: value = f"'{sample}'" @@ -422,6 +422,11 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego _insert_to_table(src_conn, src_table, enumerate(sample_values, 1)) values_in_source = PaginatedTable(src_table, src_conn) + if source_db is db.Presto: + if source_type.startswith("decimal"): + values_in_source = [(a, Decimal(b)) for a, b in values_in_source] + elif source_type.startswith("timestamp"): + values_in_source = [(a, datetime.fromisoformat(b.rstrip(" UTC"))) for a, b in values_in_source] _drop_table_if_exists(dst_conn, dst_table) dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None)