Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions data_diff/databases/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ class NumericType(ColType):
# 'precision' signifies how many fractional digits (after the dot) we want to compare
precision: int


class FractionalType(NumericType):
pass


class Float(FractionalType):
pass

Expand Down
2 changes: 1 addition & 1 deletion data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _update_column(self):

def _quote_column(self, c):
if self._schema:
c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive.
c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive.
return self.database.quote(c)

def with_schema(self) -> "TableSegment":
Expand Down
8 changes: 1 addition & 7 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,8 @@


class TestApi(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Avoid leaking connections that require waiting for the GC, which can
# cause deadlocks for table-level modifications.
cls.preql = preql.Preql(TEST_MYSQL_CONN_STRING)

def setUp(self) -> None:
# self.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
self.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
self.preql(
r"""
table test_api {
Expand Down
27 changes: 16 additions & 11 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import re
import math
import datetime
from datetime import datetime, timedelta
from decimal import Decimal
from parameterized import parameterized

Expand Down Expand Up @@ -52,33 +52,33 @@ def __next__(self) -> str:

class DateTimeFaker:
MANUAL_FAKES = [
datetime.datetime.fromisoformat("2020-01-01 15:10:10"),
datetime.datetime.fromisoformat("2020-02-01 09:09:09"),
datetime.datetime.fromisoformat("2022-03-01 15:10:01.139"),
datetime.datetime.fromisoformat("2022-04-01 15:10:02.020409"),
datetime.datetime.fromisoformat("2022-05-01 15:10:03.003030"),
datetime.datetime.fromisoformat("2022-06-01 15:10:05.009900"),
datetime.fromisoformat("2020-01-01 15:10:10"),
datetime.fromisoformat("2020-02-01 09:09:09"),
datetime.fromisoformat("2022-03-01 15:10:01.139"),
datetime.fromisoformat("2022-04-01 15:10:02.020409"),
datetime.fromisoformat("2022-05-01 15:10:03.003030"),
datetime.fromisoformat("2022-06-01 15:10:05.009900"),
]

def __init__(self, max):
self.max = max

def __iter__(self):
iter = DateTimeFaker(self.max)
iter.prev = datetime.datetime(2000, 1, 1, 0, 0, 0, 0)
iter.prev = datetime(2000, 1, 1, 0, 0, 0, 0)
iter.i = 0
return iter

def __len__(self):
return self.max

def __next__(self) -> datetime.datetime:
def __next__(self) -> datetime:
if self.i < len(self.MANUAL_FAKES):
fake = self.MANUAL_FAKES[self.i]
self.i += 1
return fake
elif self.i < self.max:
self.prev = self.prev + datetime.timedelta(seconds=3, microseconds=571)
self.prev = self.prev + timedelta(seconds=3, microseconds=571)
self.i += 1
return self.prev
else:
Expand Down Expand Up @@ -373,7 +373,7 @@ def _insert_to_table(conn, table, values):
for j, sample in values:
if isinstance(sample, (float, Decimal, int)):
value = str(sample)
elif isinstance(sample, datetime.datetime) and isinstance(conn, db.Presto):
elif isinstance(sample, datetime) and isinstance(conn, db.Presto):
value = f"timestamp '{sample}'"
else:
value = f"'{sample}'"
Expand Down Expand Up @@ -422,6 +422,11 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
_insert_to_table(src_conn, src_table, enumerate(sample_values, 1))

values_in_source = PaginatedTable(src_table, src_conn)
if source_db is db.Presto:
if source_type.startswith("decimal"):
values_in_source = [(a, Decimal(b)) for a, b in values_in_source]
elif source_type.startswith("timestamp"):
values_in_source = [(a, datetime.fromisoformat(b.rstrip(" UTC"))) for a, b in values_in_source]

_drop_table_if_exists(dst_conn, dst_table)
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None)
Expand Down