Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

tests: fix presto, add faker classes #106

Merged
merged 3 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions data_diff/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ def to_string(self, s: str):

def _query(self, sql_code: str) -> list:
"Uses the standard SQL cursor interface"
return _query_conn(self._conn, sql_code)
c = self._conn.cursor()
c.execute(sql_code)
if sql_code.lower().startswith("select"):
return c.fetchall()
# Required for the query to actually run 🤯
if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE):
return c.fetchone()

def close(self):
    "Close the underlying DB-API connection."
    self._conn.close()
Expand Down Expand Up @@ -88,7 +94,7 @@ def _parse_type(
datetime_precision = int(m.group(1))
return cls(
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
rounds=False,
rounds=self.ROUNDS_ON_PREC_LOSS,
)

number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
Expand Down
1 change: 1 addition & 0 deletions dev/presto-conf/standalone/catalog/postgresql.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ connector.name=postgresql
connection-url=jdbc:postgresql://postgres:5432/postgres
connection-user=postgres
connection-password=Password1
allow-drop-table=true
13 changes: 11 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import hashlib
import os

from data_diff import databases as db
import logging

logging.basicConfig(level=logging.INFO)

TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql"
TEST_POSTGRESQL_CONN_STRING: str = None
TEST_SNOWFLAKE_CONN_STRING: str = None
Expand All @@ -13,6 +12,16 @@
TEST_ORACLE_CONN_STRING: str = None
TEST_PRESTO_CONN_STRING: str = None

DEFAULT_N_SAMPLES = 50
N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES))

level = logging.ERROR
if os.environ.get("LOG_LEVEL", False):
level = getattr(logging, os.environ["LOG_LEVEL"].upper())

logging.basicConfig(level=level)
logging.getLogger("diff_tables").setLevel(level)
logging.getLogger("database").setLevel(level)

try:
from .local_settings import *
Expand Down
220 changes: 181 additions & 39 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,120 @@
from contextlib import suppress
import unittest
import time
import logging
import re
import math
import datetime
from decimal import Decimal

from parameterized import parameterized

from data_diff import databases as db
from data_diff.diff_tables import TableDiffer, TableSegment
from .common import CONN_STRINGS

from .common import CONN_STRINGS, N_SAMPLES

logging.getLogger("diff_tables").setLevel(logging.ERROR)
logging.getLogger("database").setLevel(logging.WARN)

CONNS = {k: db.connect_to_uri(v) for k, v in CONN_STRINGS.items()}
CONNS = {k: db.connect_to_uri(v, 1) for k, v in CONN_STRINGS.items()}

CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)

TYPE_SAMPLES = {
"int": [127, -3, -9, 37, 15, 127],
"datetime_no_timezone": [
"2020-01-01 15:10:10",
"2020-02-01 9:9:9",
"2022-03-01 15:10:01.139",
"2022-04-01 15:10:02.020409",
"2022-05-01 15:10:03.003030",
"2022-06-01 15:10:05.009900",
],
"float": [

class PaginatedTable:
    """Iterate over (id, col) rows of a table in keyset-paginated batches.

    We can't query all the rows at once for large tables. It'll occupy too
    much memory. Instead, each batch is fetched with `WHERE id > last_id`
    so pagination does not rely on OFFSET scans.
    """

    # Rows fetched per round-trip; bounds peak memory per batch.
    RECORDS_PER_BATCH = 1000000

    def __init__(self, table, conn):
        # table: table name/path string used verbatim in the query.
        # conn: a data_diff database connection exposing .query(sql, list).
        self.table = table
        self.conn = conn

    def __iter__(self):
        # Return a fresh instance carrying the cursor state (last_id,
        # current batch, index into it) so the faker itself stays reusable.
        iter = PaginatedTable(self.table, self.conn)
        iter.last_id = 0
        iter.values = []
        iter.value_index = 0
        return iter

    # NOTE(review): annotated `-> str` but actually yields a row — the
    # (id, col) records produced by conn.query(..., list); row[0] is used
    # as the numeric keyset cursor below. Confirm and fix the annotation.
    def __next__(self) -> str:
        if self.value_index == len(self.values):  # end of current batch
            query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH}"
            # Oracle has no LIMIT; use the ANSI OFFSET/FETCH form instead.
            if isinstance(self.conn, db.Oracle):
                query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY"

            self.values = self.conn.query(query, list)
            if len(self.values) == 0:  # we must be done!
                raise StopIteration
            # Advance the keyset cursor to the last id seen.
            self.last_id = self.values[-1][0]
            self.value_index = 0

        this_value = self.values[self.value_index]
        self.value_index += 1
        return this_value


class DateTimeFaker:
    """Deterministic datetime sample stream for type tests.

    Yields a fixed set of hand-picked edge-case timestamps first, then
    synthetic values spaced 3s 571us apart, for `max` values in total
    (all MANUAL_FAKES are always emitted, even when max is smaller).
    """

    MANUAL_FAKES = [
        datetime.datetime.fromisoformat(iso)
        for iso in (
            "2020-01-01 15:10:10",
            "2020-02-01 09:09:09",
            "2022-03-01 15:10:01.139",
            "2022-04-01 15:10:02.020409",
            "2022-05-01 15:10:03.003030",
            "2022-06-01 15:10:05.009900",
        )
    ]

    def __init__(self, max):
        # Total number of values the iterator yields.
        self.max = max

    def __len__(self):
        return self.max

    def __iter__(self):
        # State lives on a fresh instance so the faker can be re-iterated.
        it = DateTimeFaker(self.max)
        it.prev = datetime.datetime(2000, 1, 1, 0, 0, 0, 0)
        it.i = 0
        return it

    def __next__(self) -> datetime.datetime:
        index = self.i
        # Hand-picked samples come first, unconditionally.
        if index < len(self.MANUAL_FAKES):
            self.i = index + 1
            return self.MANUAL_FAKES[index]
        if index >= self.max:
            raise StopIteration
        self.i = index + 1
        self.prev = self.prev + datetime.timedelta(seconds=3, microseconds=571)
        return self.prev


class IntFaker:
    """Deterministic int sample stream for type tests.

    Emits a fixed set of edge-case integers, then counts upward from -128,
    for `max` values in total (all MANUAL_FAKES are always emitted, even
    when max is smaller).
    """

    MANUAL_FAKES = [127, -3, -9, 37, 15, 127]

    def __init__(self, max):
        # Total number of values the iterator yields.
        self.max = max

    def __len__(self):
        return self.max

    def __iter__(self):
        # State lives on a fresh instance so the faker can be re-iterated.
        it = IntFaker(self.max)
        it.prev = -128
        it.i = 0
        return it

    def __next__(self) -> int:
        index = self.i
        # Hand-picked samples come first, unconditionally.
        if index < len(self.MANUAL_FAKES):
            self.i = index + 1
            return self.MANUAL_FAKES[index]
        if index >= self.max:
            raise StopIteration
        self.i = index + 1
        self.prev += 1
        return self.prev


class FloatFaker:
MANUAL_FAKES = [
0.0,
0.1,
0.00188,
Expand All @@ -45,15 +131,45 @@
1 / 1094893892389,
1 / 10948938923893289,
3.141592653589793,
],
]

def __init__(self, max):
self.max = max

def __iter__(self):
iter = FloatFaker(self.max)
iter.prev = -10.0001
iter.i = 0
return iter

def __len__(self):
return self.max

def __next__(self) -> float:
if self.i < len(self.MANUAL_FAKES):
fake = self.MANUAL_FAKES[self.i]
self.i += 1
return fake
elif self.i < self.max:
self.prev += 0.00571
self.i += 1
return self.prev
else:
raise StopIteration


# Type category -> lazily generated sample stream; each faker yields
# N_SAMPLES values (configurable via the N_SAMPLES env var in common.py).
TYPE_SAMPLES = {
    "int": IntFaker(N_SAMPLES),
    "datetime_no_timezone": DateTimeFaker(N_SAMPLES),
    "float": FloatFaker(N_SAMPLES),
}

DATABASE_TYPES = {
db.PostgreSQL: {
# https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT
"int": [
# "smallint", # 2 bytes
# "int", # 4 bytes
"int", # 4 bytes
# "bigint", # 8 bytes
],
# https://www.postgresql.org/docs/current/datatype-datetime.html
Expand All @@ -76,7 +192,7 @@
# "tinyint", # 1 byte
# "smallint", # 2 bytes
# "mediumint", # 3 bytes
# "int", # 4 bytes
"int", # 4 bytes
# "bigint", # 8 bytes
],
# https://dev.mysql.com/doc/refman/8.0/en/datetime.html
Expand All @@ -96,6 +212,7 @@
],
},
db.BigQuery: {
"int": ["int"],
"datetime_no_timezone": [
"timestamp",
# "datetime",
Expand All @@ -110,7 +227,7 @@
# https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint
"int": [
# all 38 digits with 0 precision, don't need to test all
# "int",
"int",
# "integer",
# "bigint",
# "smallint",
Expand All @@ -132,7 +249,7 @@
},
db.Redshift: {
"int": [
# "int",
"int",
],
"datetime_no_timezone": [
"TIMESTAMP",
Expand All @@ -146,7 +263,7 @@
},
db.Oracle: {
"int": [
# "int",
"int",
],
"datetime_no_timezone": [
"timestamp with local time zone",
Expand All @@ -163,15 +280,12 @@
# "tinyint", # 1 byte
# "smallint", # 2 bytes
# "mediumint", # 3 bytes
# "int", # 4 bytes
"int", # 4 bytes
# "bigint", # 8 bytes
],
"datetime_no_timezone": [
"timestamp(6)",
"timestamp(3)",
"timestamp(0)",
"timestamp",
"datetime(6)",
"timestamp with time zone",
],
"float": [
"real",
Expand Down Expand Up @@ -203,18 +317,43 @@
)
)


def sanitize(name):
    """Normalize a SQL type name into a short, test-name-safe identifier.

    Fragments are shortened — never truncated — because truncation could
    make distinct types collide and cause hard-to-trace test-name bugs.
    (Pasted PR review-comment chrome that interrupted this function body
    has been removed; the code itself is unchanged.)
    """
    name = name.lower()
    name = re.sub(r"[\(\)]", "", name)  # timestamp(9) -> timestamp9
    # Try to shorten long fields, due to length limitations in some DBs
    name = name.replace(r"without time zone", "n_tz")
    name = name.replace(r"with time zone", "y_tz")
    name = name.replace(r"with local time zone", "y_tz")
    name = name.replace(r"timestamp", "ts")
    return parameterized.to_safe_name(name)


def number_to_human(n):
    """Render a count compactly: 1000 -> '1k', 5000000 -> '5m', 2e9 -> '2b'.

    Values past the largest suffix are clamped to billions ('b').
    """
    suffixes = ["", "k", "m", "b"]
    value = float(n)
    # Which power-of-1000 bucket the value falls in, clamped to the
    # available suffixes (and to 0 for values below 1000 or equal to 0).
    magnitude = 0 if value == 0 else int(math.floor(math.log10(abs(value)) / 3))
    tier = min(max(magnitude, 0), len(suffixes) - 1)
    return "{:.0f}{}".format(value / 10 ** (3 * tier), suffixes[tier])


# Pass --verbose to test run to get a nice output.
def expand_params(testcase_func, param_num, param):
    """Build a readable parameterized-test name (used as name_func).

    Produces e.g. test_types_postgresql_int_to_mysql_int_50.
    NOTE: the pasted merged diff retained both the old `return "%s..."` line
    and the new `name = "%s..."` lines, which was a syntax error; only the
    post-merge version is kept here.
    """
    source_db, target_db, source_type, target_type, type_category = param.args
    source_db_type = source_db.__name__
    target_db_type = target_db.__name__
    name = "%s_%s_%s_to_%s_%s_%s" % (
        testcase_func.__name__,
        sanitize(source_db_type),
        sanitize(source_type),
        sanitize(target_db_type),
        sanitize(target_type),
        number_to_human(N_SAMPLES),  # sample count suffix distinguishes runs
    )
    return name


def _insert_to_table(conn, table, values):
Expand All @@ -232,8 +371,10 @@ def _insert_to_table(conn, table, values):
else:
insertion_query += " VALUES "
for j, sample in values:
if isinstance(sample, (float, Decimal)):
if isinstance(sample, (float, Decimal, int)):
value = str(sample)
elif isinstance(sample, datetime.datetime) and isinstance(conn, db.Presto):
value = f"timestamp '{sample}'"
else:
value = f"'{sample}'"
insertion_query += f"({j}, {value}),"
Expand All @@ -253,6 +394,7 @@ def _drop_table_if_exists(conn, table):
conn.query(f"DROP TABLE {table}", None)
else:
conn.query(f"DROP TABLE IF EXISTS {table}", None)
conn.query("COMMIT", None)


class TestDiffCrossDatabaseTables(unittest.TestCase):
Expand All @@ -266,9 +408,9 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
self.connections = [self.src_conn, self.dst_conn]
sample_values = TYPE_SAMPLES[type_category]

# Limit in MySQL is 64
src_table_name = f"src_{self._testMethodName[:60]}"
dst_table_name = f"dst_{self._testMethodName[:60]}"
# Limit in MySQL is 64, Presto seems to be 63
src_table_name = f"src_{self._testMethodName[11:]}"
dst_table_name = f"dst_{self._testMethodName[11:]}"

src_table_path = src_conn.parse_table_name(src_table_name)
dst_table_path = dst_conn.parse_table_name(dst_table_name)
Expand All @@ -279,7 +421,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type})", None)
_insert_to_table(src_conn, src_table, enumerate(sample_values, 1))

values_in_source = src_conn.query(f"SELECT id, col FROM {src_table}", list)
values_in_source = PaginatedTable(src_table, src_conn)

_drop_table_if_exists(dst_conn, dst_table)
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None)
Expand Down