diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 4bc73733..11b0b28e 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -1,6 +1,6 @@ from typing import Tuple, Iterator, Optional, Union -from .database import connect_to_uri +from .databases.connect import connect_to_uri from .diff_tables import ( TableSegment, TableDiffer, @@ -9,7 +9,6 @@ DbKey, DbTime, DbPath, - parse_table_name, ) @@ -18,10 +17,11 @@ def connect_to_table( ): """Connects to a URI and creates a TableSegment instance""" + db = connect_to_uri(db_uri, thread_count=thread_count) + if isinstance(table_name, str): - table_name = parse_table_name(table_name) + table_name = db.parse_table_name(table_name) - db = connect_to_uri(db_uri, thread_count=thread_count) return TableSegment(db, table_name, key_column, **kwargs) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 0514bd57..b0190eec 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -9,9 +9,8 @@ TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR, - parse_table_name, ) -from .database import connect_to_uri, parse_table_name +from .databases.connect import connect_to_uri from .parse_time import parse_time_before_now, UNITS_STR, ParseError import rich @@ -51,7 +50,7 @@ @click.option("--max-age", default=None, help="Considers only rows younger than specified. 
See --min-age.") @click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff") @click.option("-d", "--debug", is_flag=True, help="Print debug info") -@click.option("--json", 'json_output', is_flag=True, help="Print JSONL output for machine readability") +@click.option("--json", "json_output", is_flag=True, help="Print JSONL output for machine readability") @click.option("-v", "--verbose", is_flag=True, help="Print extra info") @click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug") @click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.") @@ -104,7 +103,7 @@ def main( try: threads = int(threads) except ValueError: - logger.error("Error: threads must be a number, 'auto', or 'serial'.") + logging.error("Error: threads must be a number, 'auto', or 'serial'.") return if threads < 1: logging.error("Error: threads must be >= 1") @@ -129,8 +128,8 @@ def main( logging.error("Error while parsing age expression: %s" % e) return - table1 = TableSegment(db1, parse_table_name(table1_name), key_column, update_column, columns, **options) - table2 = TableSegment(db2, parse_table_name(table2_name), key_column, update_column, columns, **options) + table1 = TableSegment(db1, db1.parse_table_name(table1_name), key_column, update_column, columns, **options) + table2 = TableSegment(db2, db2.parse_table_name(table2_name), key_column, update_column, columns, **options) differ = TableDiffer( bisection_factor=bisection_factor, diff --git a/data_diff/database.py b/data_diff/database.py deleted file mode 100644 index 01a0850a..00000000 --- a/data_diff/database.py +++ /dev/null @@ -1,939 +0,0 @@ -import math -from functools import lru_cache, wraps -from itertools import zip_longest -import re -from abc import ABC, abstractmethod -import logging -from typing import Sequence, Tuple, Optional, List, Type -from concurrent.futures import ThreadPoolExecutor -import threading 
-from typing import Dict -import dsnparse -import sys - -from runtype import dataclass - -from .sql import DbPath, SqlOrStr, Compiler, Explain, Select -from .database_types import * - - -logger = logging.getLogger("database") - - -def parse_table_name(t): - return tuple(t.split(".")) - - -def import_helper(package: str = None, text=""): - def dec(f): - @wraps(f) - def _inner(): - try: - return f() - except ModuleNotFoundError as e: - s = text - if package: - s += f"You can install it using 'pip install data-diff[{package}]'." - raise ModuleNotFoundError(f"{e}\n\n{s}\n") - - return _inner - - return dec - - -@import_helper("postgresql") -def import_postgresql(): - import psycopg2 - import psycopg2.extras - - psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) - return psycopg2 - - -@import_helper("mysql") -def import_mysql(): - import mysql.connector - - return mysql.connector - - -@import_helper("snowflake") -def import_snowflake(): - import snowflake.connector - - return snowflake - - -def import_mssql(): - import pymssql - - return pymssql - - -def import_oracle(): - import cx_Oracle - - return cx_Oracle - - -@import_helper("presto") -def import_presto(): - import prestodb - - return prestodb - - -@import_helper(text="Please install BigQuery and configure your google-cloud access.") -def import_bigquery(): - from google.cloud import bigquery - - return bigquery - - -class ConnectError(Exception): - pass - - -class QueryError(Exception): - pass - - -def _one(seq): - (x,) = seq - return x - - -def _query_conn(conn, sql_code: str) -> list: - c = conn.cursor() - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() - - - -class Database(AbstractDatabase): - """Base abstract class for databases. - - Used for providing connection code and implementation specific SQL utilities. 
- - Instanciated using :meth:`~data_diff.connect_to_uri` - """ - - DATETIME_TYPES: Dict[str, type] = {} - default_schema: str = None - - @property - def name(self): - return type(self).__name__ - - def query(self, sql_ast: SqlOrStr, res_type: type): - "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" - - compiler = Compiler(self) - sql_code = compiler.compile(sql_ast) - logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) - if getattr(self, "_interactive", False) and isinstance(sql_ast, Select): - explained_sql = compiler.compile(Explain(sql_ast)) - logger.info(f"EXPLAIN for SQL SELECT") - logger.info(self._query(explained_sql)) - answer = input("Continue? [y/n] ") - if not answer.lower() in ["y", "yes"]: - sys.exit(1) - - res = self._query(sql_code) - if res_type is int: - res = _one(_one(res)) - if res is None: # May happen due to sum() of 0 items - return None - return int(res) - elif res_type is tuple: - assert len(res) == 1, (sql_code, res) - return res[0] - elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1: - if res_type.__args__ == (int,): - return [_one(row) for row in res] - elif res_type.__args__ == (Tuple,): - return [tuple(row) for row in res] - else: - raise ValueError(res_type) - return res - - def enable_interactive(self): - self._interactive = True - - def _convert_db_precision_to_digits(self, p: int) -> int: - """Convert from binary precision, used by floats, to decimal precision.""" - # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format - return math.floor(math.log(2**p, 10)) - - def _parse_type( - self, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - """ """ - - cls = self.DATETIME_TYPES.get(type_repr) - if cls: - return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, - 
rounds=self.ROUNDS_ON_PREC_LOSS, - ) - - cls = self.NUMERIC_TYPES.get(type_repr) - if cls: - if issubclass(cls, Integer): - # Some DBs have a constant numeric_scale, so they don't report it. - # We fill in the constant, so we need to ignore it for integers. - return cls(precision=0) - - elif issubclass(cls, Decimal): - if numeric_scale is None: - raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}.") - return cls(precision=numeric_scale) - - assert issubclass(cls, Float) - # assert numeric_scale is None - return cls( - precision=self._convert_db_precision_to_digits( - numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION - ) - ) - - return UnknownColType(type_repr) - - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) - - return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]: - rows = self.query(self.select_table_schema(path), list) - if not rows: - raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - - if filter_columns is not None: - accept = {i.lower() for i in filter_columns} - rows = [r for r in rows if r[0].lower() in accept] - - # Return a dict of form {name: type} after normalization - return {row[0]: self._parse_type(*row) for row in rows} - - # @lru_cache() - # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]: - # return self.query_table_schema(path) - - def _normalize_table_path(self, path: DbPath) -> DbPath: - if len(path) == 1: - if self.default_schema: - return self.default_schema, path[0] - elif len(path) != 2: - raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected form: schema.table") - - return path - - def parse_table_name(self, name: str) -> DbPath: - return parse_table_name(name) - - -class ThreadedDatabase(Database): - """Access the database through singleton threads. - - Used for database connectors that do not support sharing their connection between different threads. - """ - - def __init__(self, thread_count=1): - self._init_error = None - self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn) - self.thread_local = threading.local() - - def set_conn(self): - assert not hasattr(self.thread_local, "conn") - try: - self.thread_local.conn = self.create_connection() - except ModuleNotFoundError as e: - self._init_error = e - - def _query(self, sql_code: str): - r = self._queue.submit(self._query_in_worker, sql_code) - return r.result() - - def _query_in_worker(self, sql_code: str): - "This method runs in a worker thread" - if self._init_error: - raise self._init_error - return _query_conn(self.thread_local.conn, sql_code) - - @abstractmethod - def create_connection(self): - ... 
- - def close(self): - self._queue.shutdown() - - -CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower -MD5_HEXDIGITS = 32 - -_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2 -CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1 - -DEFAULT_DATETIME_PRECISION = 6 -DEFAULT_NUMERIC_PRECISION = 24 - -TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20 - - -class PostgreSQL(ThreadedDatabase): - DATETIME_TYPES = { - "timestamp with time zone": TimestampTZ, - "timestamp without time zone": Timestamp, - "timestamp": Timestamp, - # "datetime": Datetime, - } - NUMERIC_TYPES = { - "double precision": Float, - "real": Float, - "decimal": Decimal, - "integer": Integer, - "numeric": Decimal, - "bigint": Integer, - } - ROUNDS_ON_PREC_LOSS = True - - default_schema = "public" - - def __init__(self, host, port, user, password, *, database, thread_count, **kw): - self.args = dict(host=host, port=port, database=database, user=user, password=password, **kw) - - super().__init__(thread_count=thread_count) - - def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 2 due to wierd precision issues in PostgreSQL - return super()._convert_db_precision_to_digits(p) - 2 - - def create_connection(self): - pg = import_postgresql() - try: - c = pg.connect(**self.args) - # c.cursor().execute("SET TIME ZONE 'UTC'") - return c - except pg.OperationalError as e: - raise ConnectError(*e.args) from e - - def quote(self, s: str): - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" - - def to_string(self, s: str): - return f"{s}::varchar" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, 
{TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"{value}::decimal(38, {coltype.precision})") - - -class Presto(Database): - default_schema = "public" - DATETIME_TYPES = { - "timestamp with time zone": TimestampTZ, - "timestamp without time zone": Timestamp, - "timestamp": Timestamp, - # "datetime": Datetime, - } - NUMERIC_TYPES = { - "integer": Integer, - "real": Float, - "double": Float, - } - ROUNDS_ON_PREC_LOSS = True - - def __init__(self, host, port, user, password, *, catalog, schema=None, **kw): - prestodb = import_presto() - self.args = dict(host=host, user=user, catalog=catalog, schema=schema, **kw) - - self._conn = prestodb.dbapi.connect(**self.args) - - def quote(self, s: str): - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" - - def to_string(self, s: str): - return f"cast({s} as varchar)" - - def _query(self, sql_code: str) -> list: - "Uses the standard SQL cursor interface" - return _query_conn(self._conn, sql_code) - - def close(self): - self._conn.close() - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - # TODO - if coltype.rounds: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - else: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") - - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) - - return ( - f"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision FROM 
INFORMATION_SCHEMA.COLUMNS " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def _parse_type( - self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None - ) -> ColType: - timestamp_regexps = { - r"timestamp\((\d)\)": Timestamp, - r"timestamp\((\d)\) with time zone": TimestampTZ, - } - for regexp, cls in timestamp_regexps.items(): - m = re.match(regexp + "$", type_repr) - if m: - datetime_precision = int(m.group(1)) - return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, - rounds=False, - ) - - number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} - for regexp, cls in number_regexps.items(): - m = re.match(regexp + "$", type_repr) - if m: - prec, scale = map(int, m.groups()) - return cls(scale) - - cls = self.NUMERIC_TYPES.get(type_repr) - if cls: - if issubclass(cls, Integer): - assert numeric_precision is not None - return cls(0) - - assert issubclass(cls, Float) - return cls( - precision=self._convert_db_precision_to_digits( - numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION - ) - ) - - return UnknownColType(type_repr) - - -class MySQL(ThreadedDatabase): - DATETIME_TYPES = { - "datetime": Datetime, - "timestamp": Timestamp, - } - NUMERIC_TYPES = { - "double": Float, - "float": Float, - "decimal": Decimal, - "int": Integer, - } - ROUNDS_ON_PREC_LOSS = True - - def __init__(self, host, port, user, password, *, database, thread_count, **kw): - args = dict(host=host, port=port, database=database, user=user, password=password, **kw) - self._args = {k: v for k, v in args.items() if v is not None} - - super().__init__(thread_count=thread_count) - - self.default_schema = user - - def create_connection(self): - mysql = import_mysql() - try: - return mysql.connect(charset="utf8", use_unicode=True, **self._args) - except mysql.Error as e: - if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: - raise ConnectError("Bad 
user name or password") from e - elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: - raise ConnectError("Database does not exist") from e - else: - raise ConnectError(*e.args) from e - - def quote(self, s: str): - return f"`{s}`" - - def md5_to_int(self, s: str) -> str: - return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" - - def to_string(self, s: str): - return f"cast({s} as char)" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") - - s = self.to_string(f"cast({value} as datetime(6))") - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - - def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - - -class Oracle(ThreadedDatabase): - ROUNDS_ON_PREC_LOSS = True - - def __init__(self, host, port, user, password, *, database, thread_count, **kw): - assert not port - self.kwargs = dict(user=user, password=password, dsn="%s/%s" % (host, database), **kw) - super().__init__(thread_count=thread_count) - - def create_connection(self): - self._oracle = import_oracle() - try: - return self._oracle.connect(**self.kwargs) - except Exception as e: - raise ConnectError(*e.args) from e - - def _query(self, sql_code: str): - try: - return super()._query(sql_code) - except self._oracle.DatabaseError as e: - raise QueryError(e) - - def md5_to_int(self, s: str) -> str: - # standard_hash is faster than DBMS_CRYPTO.Hash - # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? 
- return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" - - def quote(self, s: str): - return f"{s}" - - def to_string(self, s: str): - return f"cast({s} as varchar(1024))" - - def select_table_schema(self, path: DbPath) -> str: - if len(path) > 1: - raise ValueError("Unexpected table path for oracle") - (table,) = path - - return ( - f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" - f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'" - ) - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" - - def normalize_number(self, value: str, coltype: NumericType) -> str: - # FM999.9990 - format_str = "FM" + "9" * (38 - coltype.precision) - if coltype.precision: - format_str += "0." + "9" * (coltype.precision - 1) + "0" - return f"to_char({value}, '{format_str}')" - - def _parse_type( - self, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - """ """ - regexps = { - r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, - r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, - } - for regexp, cls in regexps.items(): - m = re.match(regexp + "$", type_repr) - if m: - datetime_precision = int(m.group(1)) - return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, - rounds=self.ROUNDS_ON_PREC_LOSS, - ) - - cls = { - "NUMBER": Decimal, - "FLOAT": Float, - }.get(type_repr, None) - if cls: - if issubclass(cls, Decimal): - assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale) - return cls(precision=numeric_scale) - - assert issubclass(cls, Float) - return cls( - precision=self._convert_db_precision_to_digits( - numeric_precision if numeric_precision is not None else 
DEFAULT_NUMERIC_PRECISION - ) - ) - - return UnknownColType(type_repr) - - -class Redshift(PostgreSQL): - NUMERIC_TYPES = { - **PostgreSQL.NUMERIC_TYPES, - "double": Float, - "real": Float, - } - - # def _convert_db_precision_to_digits(self, p: int) -> int: - # return super()._convert_db_precision_to_digits(p // 2) - - def md5_to_int(self, s: str) -> str: - return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"{value}::timestamp(6)" - # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. - secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" - # Get the milliseconds from timestamp. - ms = f"extract(ms from {timestamp})" - # Get the microseconds from timestamp, without the milliseconds! - us = f"extract(us from {timestamp})" - # epoch = Total time since epoch in microseconds. - epoch = f"{secs}*1000000 + {ms}*1000 + {us}" - timestamp6 = ( - f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" - ) - else: - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"{value}::decimal(38,{coltype.precision})") - - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) - - return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns " - f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" - ) - - -class MsSQL(ThreadedDatabase): - "AKA sql-server" - - def __init__(self, host, port, user, password, *, database, thread_count, **kw): - args = 
dict(server=host, port=port, database=database, user=user, password=password, **kw) - self._args = {k: v for k, v in args.items() if v is not None} - - super().__init__(thread_count=thread_count) - - def create_connection(self): - mssql = import_mssql() - try: - return mssql.connect(**self._args) - except mssql.Error as e: - raise ConnectError(*e.args) from e - - def quote(self, s: str): - return f"[{s}]" - - def md5_to_int(self, s: str) -> str: - return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))" - # return f"CONVERT(bigint, (CHECKSUM({s})))" - - def to_string(self, s: str): - return f"CONVERT(varchar, {s})" - - -class BigQuery(Database): - DATETIME_TYPES = { - "TIMESTAMP": Timestamp, - "DATETIME": Datetime, - } - NUMERIC_TYPES = { - "INT64": Integer, - "INT32": Integer, - "NUMERIC": Decimal, - "BIGNUMERIC": Decimal, - "FLOAT64": Float, - "FLOAT32": Float, - } - ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation - - def __init__(self, project, *, dataset, **kw): - bigquery = import_bigquery() - - self._client = bigquery.Client(project, **kw) - self.project = project - self.dataset = dataset - - self.default_schema = dataset - - def quote(self, s: str): - return f"`{s}`" - - def md5_to_int(self, s: str) -> str: - return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" - - def _normalize_returned_value(self, value): - if isinstance(value, bytes): - return value.decode() - return value - - def _query(self, sql_code: str): - from google.cloud import bigquery - - try: - res = list(self._client.query(sql_code)) - except Exception as e: - msg = "Exception when trying to execute SQL code:\n %s\n\nGot error: %s" - raise ConnectError(msg % (sql_code, e)) - - if res and isinstance(res[0], bigquery.table.Row): - res = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in res] - return res - - def to_string(self, s: str): - return f"cast({s} as string)" - - def 
close(self): - self._client.close() - - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) - - return ( - f"SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale FROM {schema}.INFORMATION_SCHEMA.COLUMNS " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" - - if coltype.precision == 0: - return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})" - elif coltype.precision == 6: - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - - timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - return ( - f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) - - def normalize_number(self, value: str, coltype: NumericType) -> str: - if isinstance(coltype, Integer): - return self.to_string(value) - return f"format('%.{coltype.precision}f', {value})" - - def parse_table_name(self, name: str) -> DbPath: - path = parse_table_name(name) - return self._normalize_table_path(path) - - -class Snowflake(Database): - DATETIME_TYPES = { - "TIMESTAMP_NTZ": Timestamp, - "TIMESTAMP_LTZ": Timestamp, - "TIMESTAMP_TZ": TimestampTZ, - } - NUMERIC_TYPES = { - "NUMBER": Decimal, - "FLOAT": Float, - } - ROUNDS_ON_PREC_LOSS = False - - def __init__( - self, - account: str, - _port: int, - user: str, - password: str, - *, - warehouse: str, - schema: str, - database: str, - role: str = None, - **kw, - ): - snowflake = import_snowflake() - logging.getLogger("snowflake.connector").setLevel(logging.WARNING) - - # Got an error: snowflake.connector.network.RetryRequest: could not find io module state (interpreter shutdown?) 
- # It's a known issue: https://github.com/snowflakedb/snowflake-connector-python/issues/145 - # Found a quick solution in comments - logging.getLogger("snowflake.connector.network").disabled = True - - assert '"' not in schema, "Schema name should not contain quotes!" - self._conn = snowflake.connector.connect( - user=user, - password=password, - account=account, - role=role, - database=database, - warehouse=warehouse, - schema=f'"{schema}"', - **kw, - ) - - self.default_schema = schema - - def close(self): - self._conn.close() - - def _query(self, sql_code: str) -> list: - "Uses the standard SQL cursor interface" - return _query_conn(self._conn, sql_code) - - def quote(self, s: str): - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" - - def to_string(self, s: str): - return f"cast({s} as string)" - - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) - return super().select_table_schema((schema, table)) - - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - if coltype.rounds: - timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" - else: - timestamp = f"cast({value} as timestamp({coltype.precision}))" - - return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" - - def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") - - -@dataclass -class MatchUriPath: - database_cls: Type[Database] - params: List[str] - kwparams: List[str] = [] - help_str: str - - def match_path(self, dsn): - dsn_dict = dict(dsn.query) - matches = {} - for param, arg in zip_longest(self.params, dsn.paths): - if param is None: - raise ValueError(f"Too many parts to path. 
Expected format: {self.help_str}") - - optional = param.endswith("?") - param = param.rstrip("?") - - if arg is None: - try: - arg = dsn_dict.pop(param) - except KeyError: - if not optional: - raise ValueError(f"URI must specify '{param}'. Expected format: {self.help_str}") - - arg = None - - assert param and param not in matches - matches[param] = arg - - for param in self.kwparams: - try: - arg = dsn_dict.pop(param) - except KeyError: - raise ValueError(f"URI must specify '{param}'. Expected format: {self.help_str}") - - assert param and arg and param not in matches, (param, arg, matches.keys()) - matches[param] = arg - - for param, value in dsn_dict.items(): - if param in matches: - raise ValueError( - f"Parameter '{param}' already provided as positional argument. Expected format: {self.help_str}" - ) - - matches[param] = value - - return matches - - -MATCH_URI_PATH = { - "postgresql": MatchUriPath(PostgreSQL, ["database?"], help_str="postgresql://:@/"), - "mysql": MatchUriPath(MySQL, ["database?"], help_str="mysql://:@/"), - "oracle": MatchUriPath(Oracle, ["database?"], help_str="oracle://:@/"), - # "mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://:@/"), - "redshift": MatchUriPath(Redshift, ["database?"], help_str="redshift://:@/"), - "snowflake": MatchUriPath( - Snowflake, - ["database", "schema"], - ["warehouse"], - help_str="snowflake://:@//?warehouse=", - ), - "presto": MatchUriPath(Presto, ["catalog", "schema"], help_str="presto://@//"), - "bigquery": MatchUriPath(BigQuery, ["dataset"], help_str="bigquery:///"), -} - - -def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database: - """Connect to the given database uri - - thread_count determines the max number of worker threads per database, - if relevant. None means no limit. - - Parameters: - db_uri (str): The URI for the database to connect - thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. 
(default: 1) - - Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. - - Supported schemes: - - postgresql - - mysql - - oracle - - snowflake - - bigquery - - redshift - - presto - """ - - dsn = dsnparse.parse(db_uri) - if len(dsn.schemes) > 1: - raise NotImplementedError("No support for multiple schemes") - (scheme,) = dsn.schemes - - try: - matcher = MATCH_URI_PATH[scheme] - except KeyError: - raise NotImplementedError(f"Scheme {scheme} currently not supported") - - cls = matcher.database_cls - kw = matcher.match_path(dsn) - - if scheme == "bigquery": - return cls(dsn.host, **kw) - - if issubclass(cls, ThreadedDatabase): - return cls(dsn.host, dsn.port, dsn.user, dsn.password, thread_count=thread_count, **kw) - - return cls(dsn.host, dsn.port, dsn.user, dsn.password, **kw) diff --git a/data_diff/databases/__init__.py b/data_diff/databases/__init__.py new file mode 100644 index 00000000..be78f320 --- /dev/null +++ b/data_diff/databases/__init__.py @@ -0,0 +1,11 @@ +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError + +from .postgresql import PostgreSQL +from .mysql import MySQL +from .oracle import Oracle +from .snowflake import Snowflake +from .bigquery import BigQuery +from .redshift import Redshift +from .presto import Presto + +from .connect import connect_to_uri diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py new file mode 100644 index 00000000..9602b583 --- /dev/null +++ b/data_diff/databases/base.py @@ -0,0 +1,235 @@ +import math +import sys +import logging +from typing import Dict, Tuple, Optional, Sequence +from functools import lru_cache, wraps +from concurrent.futures import ThreadPoolExecutor +import threading +from abc import abstractmethod + +from .database_types import AbstractDatabase, ColType, Integer, Decimal, Float, UnknownColType +from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select + +logger = logging.getLogger("database") + + +def 
parse_table_name(t): + return tuple(t.split(".")) + + +def import_helper(package: str = None, text=""): + def dec(f): + @wraps(f) + def _inner(): + try: + return f() + except ModuleNotFoundError as e: + s = text + if package: + s += f"You can install it using 'pip install data-diff[{package}]'." + raise ModuleNotFoundError(f"{e}\n\n{s}\n") + + return _inner + + return dec + + +class ConnectError(Exception): + pass + + +class QueryError(Exception): + pass + + +def _one(seq): + (x,) = seq + return x + + +def _query_conn(conn, sql_code: str) -> list: + c = conn.cursor() + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + + +class Database(AbstractDatabase): + """Base abstract class for databases. + + Used for providing connection code and implementation specific SQL utilities. + + Instanciated using :meth:`~data_diff.connect_to_uri` + """ + + DATETIME_TYPES: Dict[str, type] = {} + default_schema: str = None + + @property + def name(self): + return type(self).__name__ + + def query(self, sql_ast: SqlOrStr, res_type: type): + "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" + + compiler = Compiler(self) + sql_code = compiler.compile(sql_ast) + logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code) + if getattr(self, "_interactive", False) and isinstance(sql_ast, Select): + explained_sql = compiler.compile(Explain(sql_ast)) + logger.info(f"EXPLAIN for SQL SELECT") + logger.info(self._query(explained_sql)) + answer = input("Continue? 
[y/n] ") + if not answer.lower() in ["y", "yes"]: + sys.exit(1) + + res = self._query(sql_code) + if res_type is int: + res = _one(_one(res)) + if res is None: # May happen due to sum() of 0 items + return None + return int(res) + elif res_type is tuple: + assert len(res) == 1, (sql_code, res) + return res[0] + elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1: + if res_type.__args__ == (int,): + return [_one(row) for row in res] + elif res_type.__args__ == (Tuple,): + return [tuple(row) for row in res] + else: + raise ValueError(res_type) + return res + + def enable_interactive(self): + self._interactive = True + + def _convert_db_precision_to_digits(self, p: int) -> int: + """Convert from binary precision, used by floats, to decimal precision.""" + # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format + return math.floor(math.log(2**p, 10)) + + def _parse_type( + self, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: + """ """ + + cls = self.DATETIME_TYPES.get(type_repr) + if cls: + return cls( + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + rounds=self.ROUNDS_ON_PREC_LOSS, + ) + + cls = self.NUMERIC_TYPES.get(type_repr) + if cls: + if issubclass(cls, Integer): + # Some DBs have a constant numeric_scale, so they don't report it. + # We fill in the constant, so we need to ignore it for integers. + return cls(precision=0) + + elif issubclass(cls, Decimal): + if numeric_scale is None: + raise ValueError( + f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}." 
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
    """Fetch the table's schema and normalize it into a {column_name: ColType} dict.

    When `filter_columns` is given, only columns whose name matches one of
    them (case-insensitively) are kept. Raises RuntimeError if the table
    does not exist or reports no columns.
    """
    raw_schema = self.query(self.select_table_schema(path), list)
    if not raw_schema:
        table_repr = ".".join(path)
        raise RuntimeError(f"{self.name}: Table '{table_repr}' does not exist, or has no columns")

    if filter_columns is None:
        relevant = raw_schema
    else:
        wanted = {c.lower() for c in filter_columns}
        relevant = [r for r in raw_schema if r[0].lower() in wanted]

    # Normalize each raw row into the project's ColType representation.
    result = {}
    for row in relevant:
        result[row[0]] = self._parse_type(*row)
    return result
+ """ + + def __init__(self, thread_count=1): + self._init_error = None + self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn) + self.thread_local = threading.local() + + def set_conn(self): + assert not hasattr(self.thread_local, "conn") + try: + self.thread_local.conn = self.create_connection() + except ModuleNotFoundError as e: + self._init_error = e + + def _query(self, sql_code: str): + r = self._queue.submit(self._query_in_worker, sql_code) + return r.result() + + def _query_in_worker(self, sql_code: str): + "This method runs in a worker thread" + if self._init_error: + raise self._init_error + return _query_conn(self.thread_local.conn, sql_code) + + @abstractmethod + def create_connection(self): + ... + + def close(self): + self._queue.shutdown() + + +CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower +MD5_HEXDIGITS = 32 + +_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2 +CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1 + +DEFAULT_DATETIME_PRECISION = 6 +DEFAULT_NUMERIC_PRECISION = 24 + +TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20 diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py new file mode 100644 index 00000000..735cdc83 --- /dev/null +++ b/data_diff/databases/bigquery.py @@ -0,0 +1,97 @@ +from .database_types import * +from .base import Database, import_helper, parse_table_name, ConnectError +from .base import TIMESTAMP_PRECISION_POS + + +@import_helper(text="Please install BigQuery and configure your google-cloud access.") +def import_bigquery(): + from google.cloud import bigquery + + return bigquery + + +class BigQuery(Database): + DATETIME_TYPES = { + "TIMESTAMP": Timestamp, + "DATETIME": Datetime, + } + NUMERIC_TYPES = { + "INT64": Integer, + "INT32": Integer, + "NUMERIC": Decimal, + "BIGNUMERIC": Decimal, + "FLOAT64": Float, + "FLOAT32": Float, + } + ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation + + def __init__(self, project, *, dataset, 
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
    """Render a BigQuery timestamp as text, padded to 6 fractional digits.

    Fast paths exist for precision 0 and 6; otherwise the value is
    formatted at full precision, truncated to `coltype.precision`, and
    right-padded with zeros back to 6 digits.
    """
    if coltype.rounds:
        timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
        return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})"

    if coltype.precision == 0:
        # BUG FIX: the format-string literal was missing its closing quote
        # ('%F %H:%M:%S.000000, {value}), producing invalid SQL.
        return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000', {value})"
    elif coltype.precision == 6:
        return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"

    timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
    return (
        f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
    )
normalize_number(self, value: str, coltype: NumericType) -> str: + if isinstance(coltype, Integer): + return self.to_string(value) + return f"format('%.{coltype.precision}f', {value})" + + def parse_table_name(self, name: str) -> DbPath: + path = parse_table_name(name) + return self._normalize_table_path(path) diff --git a/data_diff/databases/connect.py b/data_diff/databases/connect.py new file mode 100644 index 00000000..8a705046 --- /dev/null +++ b/data_diff/databases/connect.py @@ -0,0 +1,124 @@ +from typing import Type, List, Optional +from itertools import zip_longest +import dsnparse + +from runtype import dataclass + +from .base import Database, ThreadedDatabase +from .postgresql import PostgreSQL +from .mysql import MySQL +from .oracle import Oracle +from .snowflake import Snowflake +from .bigquery import BigQuery +from .redshift import Redshift +from .presto import Presto + + +@dataclass +class MatchUriPath: + database_cls: Type[Database] + params: List[str] + kwparams: List[str] = [] + help_str: str + + def match_path(self, dsn): + dsn_dict = dict(dsn.query) + matches = {} + for param, arg in zip_longest(self.params, dsn.paths): + if param is None: + raise ValueError(f"Too many parts to path. Expected format: {self.help_str}") + + optional = param.endswith("?") + param = param.rstrip("?") + + if arg is None: + try: + arg = dsn_dict.pop(param) + except KeyError: + if not optional: + raise ValueError(f"URI must specify '{param}'. Expected format: {self.help_str}") + + arg = None + + assert param and param not in matches + matches[param] = arg + + for param in self.kwparams: + try: + arg = dsn_dict.pop(param) + except KeyError: + raise ValueError(f"URI must specify '{param}'. Expected format: {self.help_str}") + + assert param and arg and param not in matches, (param, arg, matches.keys()) + matches[param] = arg + + for param, value in dsn_dict.items(): + if param in matches: + raise ValueError( + f"Parameter '{param}' already provided as positional argument. 
Expected format: {self.help_str}" + ) + + matches[param] = value + + return matches + + +MATCH_URI_PATH = { + "postgresql": MatchUriPath(PostgreSQL, ["database?"], help_str="postgresql://:@/"), + "mysql": MatchUriPath(MySQL, ["database?"], help_str="mysql://:@/"), + "oracle": MatchUriPath(Oracle, ["database?"], help_str="oracle://:@/"), + # "mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://:@/"), + "redshift": MatchUriPath(Redshift, ["database?"], help_str="redshift://:@/"), + "snowflake": MatchUriPath( + Snowflake, + ["database", "schema"], + ["warehouse"], + help_str="snowflake://:@//?warehouse=", + ), + "presto": MatchUriPath(Presto, ["catalog", "schema"], help_str="presto://@//"), + "bigquery": MatchUriPath(BigQuery, ["dataset"], help_str="bigquery:///"), +} + + +def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database: + """Connect to the given database uri + + thread_count determines the max number of worker threads per database, + if relevant. None means no limit. + + Parameters: + db_uri (str): The URI for the database to connect + thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) + + Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 
+ + Supported schemes: + - postgresql + - mysql + - oracle + - snowflake + - bigquery + - redshift + - presto + """ + + dsn = dsnparse.parse(db_uri) + if len(dsn.schemes) > 1: + raise NotImplementedError("No support for multiple schemes") + (scheme,) = dsn.schemes + + try: + matcher = MATCH_URI_PATH[scheme] + except KeyError: + raise NotImplementedError(f"Scheme {scheme} currently not supported") + + cls = matcher.database_cls + kw = matcher.match_path(dsn) + + if scheme == "bigquery": + return cls(dsn.host, **kw) + + if issubclass(cls, ThreadedDatabase): + return cls(dsn.host, dsn.port, dsn.user, dsn.password, thread_count=thread_count, **kw) + + return cls(dsn.host, dsn.port, dsn.user, dsn.password, **kw) diff --git a/data_diff/database_types.py b/data_diff/databases/database_types.py similarity index 99% rename from data_diff/database_types.py rename to data_diff/databases/database_types.py index c441c4cb..0150010a 100644 --- a/data_diff/database_types.py +++ b/data_diff/databases/database_types.py @@ -152,7 +152,6 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.normalize_number(value, coltype) return self.to_string(f"{value}") - def _normalize_table_path(self, path: DbPath) -> DbPath: ... 
from .database_types import *
from .base import ThreadedDatabase, import_helper, ConnectError


@import_helper("mssql")
def import_mssql():
    # NOTE(review): driver assumed to be pymssql (matches the 'mssql' pip extra
    # naming used by the other backends) — confirm before enabling this backend.
    import pymssql

    return pymssql


class MsSQL(ThreadedDatabase):
    "AKA sql-server"

    # BUG FIX: this module previously had no imports at all, so ThreadedDatabase,
    # import_mssql and ConnectError were undefined and the module failed on import.

    def __init__(self, host, port, user, password, *, database, thread_count, **kw):
        args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
        # Drop unset options so the driver's own defaults apply.
        self._args = {k: v for k, v in args.items() if v is not None}

        super().__init__(thread_count=thread_count)

    def create_connection(self):
        mssql = import_mssql()
        try:
            return mssql.connect(**self._args)
        except mssql.Error as e:
            raise ConnectError(*e.args) from e

    def quote(self, s: str):
        return f"[{s}]"

    def md5_to_int(self, s: str) -> str:
        return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))"
        # return f"CONVERT(bigint, (CHECKSUM({s})))"

    def to_string(self, s: str):
        return f"CONVERT(varchar, {s})"
def create_connection(self): + mysql = import_mysql() + try: + return mysql.connect(charset="utf8", use_unicode=True, **self._args) + except mysql.Error as e: + if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: + raise ConnectError("Bad user name or password") from e + elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: + raise ConnectError("Database does not exist") from e + else: + raise ConnectError(*e.args) from e + + def quote(self, s: str): + return f"`{s}`" + + def md5_to_int(self, s: str) -> str: + return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" + + def to_string(self, s: str): + return f"cast({s} as char)" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") + + s = self.to_string(f"cast({value} as datetime(6))") + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: NumericType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py new file mode 100644 index 00000000..6a00fd36 --- /dev/null +++ b/data_diff/databases/oracle.py @@ -0,0 +1,105 @@ +import re + +from .database_types import * +from .base import ThreadedDatabase, import_helper, ConnectError, QueryError +from .base import DEFAULT_DATETIME_PRECISION, DEFAULT_NUMERIC_PRECISION + + +@import_helper("oracle") +def import_oracle(): + import cx_Oracle + + return cx_Oracle + + +class Oracle(ThreadedDatabase): + ROUNDS_ON_PREC_LOSS = True + + def __init__(self, host, port, user, password, *, database, thread_count, **kw): + assert not port + self.kwargs = dict(user=user, password=password, dsn="%s/%s" % (host, database), **kw) + super().__init__(thread_count=thread_count) + + def 
create_connection(self): + self._oracle = import_oracle() + try: + return self._oracle.connect(**self.kwargs) + except Exception as e: + raise ConnectError(*e.args) from e + + def _query(self, sql_code: str): + try: + return super()._query(sql_code) + except self._oracle.DatabaseError as e: + raise QueryError(e) + + def md5_to_int(self, s: str) -> str: + # standard_hash is faster than DBMS_CRYPTO.Hash + # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? + return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" + + def quote(self, s: str): + return f"{s}" + + def to_string(self, s: str): + return f"cast({s} as varchar(1024))" + + def select_table_schema(self, path: DbPath) -> str: + if len(path) > 1: + raise ValueError("Unexpected table path for oracle") + (table,) = path + + return ( + f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" + f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'" + ) + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" + + def normalize_number(self, value: str, coltype: NumericType) -> str: + # FM999.9990 + format_str = "FM" + "9" * (38 - coltype.precision) + if coltype.precision: + format_str += "0." 
def _parse_type(
    self,
    col_name: str,
    type_repr: str,
    datetime_precision: int = None,
    numeric_precision: int = None,
    numeric_scale: int = None,
) -> ColType:
    """Map an Oracle type description (from USER_TAB_COLUMNS) to a ColType.

    Handles TIMESTAMP(n) [WITH [LOCAL] TIME ZONE], NUMBER and FLOAT;
    anything else is returned as UnknownColType.
    """
    timestamp_regexps = {
        r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
        r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
    }
    for regexp, t_cls in timestamp_regexps.items():
        m = re.match(regexp + "$", type_repr)
        if m:
            # The precision captured from the type string takes precedence
            # over the datetime_precision column.
            return t_cls(precision=int(m.group(1)), rounds=self.ROUNDS_ON_PREC_LOSS)

    n_cls = {
        "NUMBER": Decimal,
        "FLOAT": Float,
    }.get(type_repr)
    if n_cls:
        if issubclass(n_cls, Decimal):
            # Was an assert; raise a real error so validation also fires under -O.
            if numeric_scale is None:
                raise ValueError(
                    f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}."
                )
            return n_cls(precision=numeric_scale)

        # Only Float remains; convert its binary precision to decimal digits.
        return n_cls(
            precision=self._convert_db_precision_to_digits(
                numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
            )
        )

    return UnknownColType(type_repr)
"decimal": Decimal, + "integer": Integer, + "numeric": Decimal, + "bigint": Integer, + } + ROUNDS_ON_PREC_LOSS = True + + default_schema = "public" + + def __init__(self, host, port, user, password, *, database, thread_count, **kw): + self.args = dict(host=host, port=port, database=database, user=user, password=password, **kw) + + super().__init__(thread_count=thread_count) + + def _convert_db_precision_to_digits(self, p: int) -> int: + # Subtracting 2 due to wierd precision issues in PostgreSQL + return super()._convert_db_precision_to_digits(p) - 2 + + def create_connection(self): + pg = import_postgresql() + try: + c = pg.connect(**self.args) + # c.cursor().execute("SET TIME ZONE 'UTC'") + return c + except pg.OperationalError as e: + raise ConnectError(*e.args) from e + + def quote(self, s: str): + return f'"{s}"' + + def md5_to_int(self, s: str) -> str: + return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" + + def to_string(self, s: str): + return f"{s}::varchar" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" + + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: NumericType) -> str: + return self.to_string(f"{value}::decimal(38, {coltype.precision})") diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py new file mode 100644 index 00000000..a0f75010 --- /dev/null +++ b/data_diff/databases/presto.py @@ -0,0 +1,114 @@ +import re + +from .database_types import * +from .base import Database, import_helper, _query_conn +from .base import ( + MD5_HEXDIGITS, + CHECKSUM_HEXDIGITS, + TIMESTAMP_PRECISION_POS, + DEFAULT_DATETIME_PRECISION, + 
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
    """Cast the timestamp to text at 6-digit precision, then pad to a fixed width."""
    # TODO: rounding (coltype.rounds) is not yet handled separately — both
    # cases currently format the value the same way.
    formatted = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
    prec_pos = TIMESTAMP_PRECISION_POS + coltype.precision
    return f"RPAD(RPAD({formatted}, {prec_pos}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
_parse_type( + self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None + ) -> ColType: + timestamp_regexps = { + r"timestamp\((\d)\)": Timestamp, + r"timestamp\((\d)\) with time zone": TimestampTZ, + } + for regexp, cls in timestamp_regexps.items(): + m = re.match(regexp + "$", type_repr) + if m: + datetime_precision = int(m.group(1)) + return cls( + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + rounds=False, + ) + + number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} + for regexp, cls in number_regexps.items(): + m = re.match(regexp + "$", type_repr) + if m: + prec, scale = map(int, m.groups()) + return cls(scale) + + cls = self.NUMERIC_TYPES.get(type_repr) + if cls: + if issubclass(cls, Integer): + assert numeric_precision is not None + return cls(0) + + assert issubclass(cls, Float) + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) + + return UnknownColType(type_repr) diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py new file mode 100644 index 00000000..bb8cbddd --- /dev/null +++ b/data_diff/databases/redshift.py @@ -0,0 +1,44 @@ +from .database_types import * +from .postgresql import PostgreSQL, MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS + + +class Redshift(PostgreSQL): + NUMERIC_TYPES = { + **PostgreSQL.NUMERIC_TYPES, + "double": Float, + "real": Float, + } + + def md5_to_int(self, s: str) -> str: + return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" + + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.rounds: + timestamp = f"{value}::timestamp(6)" + # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. 
def select_table_schema(self, path: DbPath) -> str:
    """Build the information_schema query for this table's columns.

    Table and schema names are lowercased before comparison — presumably
    because Redshift folds unquoted identifiers to lowercase.
    """
    schema, table = self._normalize_table_path(path)
    where = f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'"
    columns = "column_name, data_type, datetime_precision, numeric_precision, numeric_scale"
    return f"SELECT {columns} FROM information_schema.columns {where}"
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
    """Render the timestamp as 'YYYY-MM-DD HH24:MI:SS.FF6' text.

    When the column type rounds on precision loss, the value is rounded at
    the requested precision via its epoch-nanosecond representation;
    otherwise it is cast down to that precision directly (Snowflake itself
    sets ROUNDS_ON_PREC_LOSS = False).
    """
    if coltype.rounds:
        seconds = f"date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000"
        ts_expr = f"to_timestamp(round({seconds}, {coltype.precision}))"
    else:
        ts_expr = f"cast({value} as timestamp({coltype.precision}))"
    return f"to_char({ts_expr}, 'YYYY-MM-DD HH24:MI:SS.FF6')"
import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max -from .database import Database -from .database_types import NumericType, PrecisionType, UnknownColType, Schema, Schema_CaseInsensitive, Schema_CaseSensitive +from .databases.base import Database +from .databases.database_types import ( + NumericType, + PrecisionType, + UnknownColType, + Schema, + Schema_CaseInsensitive, + Schema_CaseSensitive, +) logger = logging.getLogger("diff_tables") @@ -34,11 +41,6 @@ def split_space(start, end, count): return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1] -def parse_table_name(t): - return tuple(t.split(".")) - - - @dataclass(frozen=False) class TableSegment: """Signifies a segment of rows (and selected columns) within a table diff --git a/data_diff/sql.py b/data_diff/sql.py index 455aec85..6946c26f 100644 --- a/data_diff/sql.py +++ b/data_diff/sql.py @@ -6,8 +6,7 @@ from runtype import dataclass -from .database_types import AbstractDatabase, DbPath, DbKey, DbTime - +from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime class Sql: diff --git a/tests/common.py b/tests/common.py index 1fd610a0..e0b35ac1 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,6 +1,6 @@ import hashlib -from data_diff import database as db +from data_diff import databases as db import logging logging.basicConfig(level=logging.INFO) diff --git a/tests/test_database.py b/tests/test_database.py index 924925c2..717468ec 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,7 +1,7 @@ import unittest from .common import str_to_checksum, TEST_MYSQL_CONN_STRING -from data_diff.database import connect_to_uri +from data_diff.databases import connect_to_uri class TestDatabase(unittest.TestCase): diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 2f665618..4515de29 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -4,14 +4,11 @@ import logging from 
decimal import Decimal -from parameterized import parameterized, parameterized_class -import preql +from parameterized import parameterized -from data_diff import database as db +from data_diff import databases as db from data_diff.diff_tables import TableDiffer, TableSegment -from parameterized import parameterized, parameterized_class from .common import CONN_STRINGS -import logging logging.getLogger("diff_tables").setLevel(logging.ERROR) @@ -187,11 +184,6 @@ type_pairs = [] -# => -# { source: (preql, connection) -# target: (preql, connection) -# source_type: (int, tinyint), -# target_type: (int, bigint) } for source_db, source_type_categories in DATABASE_TYPES.items(): for target_db, target_type_categories in DATABASE_TYPES.items(): for ( diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 3cd97212..a5eefcd9 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -4,7 +4,7 @@ import preql import arrow # comes with preql -from data_diff.database import connect_to_uri +from data_diff.databases import connect_to_uri from data_diff.diff_tables import TableDiffer, TableSegment, split_space from .common import TEST_MYSQL_CONN_STRING, str_to_checksum diff --git a/tests/test_normalize_fields.py b/tests/test_normalize_fields.py index 7893022f..dc1a8c4b 100644 --- a/tests/test_normalize_fields.py +++ b/tests/test_normalize_fields.py @@ -5,9 +5,8 @@ import preql -from data_diff.database import BigQuery, MySQL, Snowflake, connect_to_uri, Oracle from data_diff.sql import Select -from data_diff import database as db +from data_diff import databases as db from .common import CONN_STRINGS @@ -41,7 +40,7 @@ def _test_dates_for_db(self, item, precision=3): sample_date3 = datetime(2021, 5, 2, 11, 23, 34, 000000, tzinfo=timezone.utc) dates = [sample_date1, sample_date2, sample_date3] - if db_id in (BigQuery, Oracle): + if db_id in (db.BigQuery, db.Oracle): # TODO BigQuery doesn't seem to support timezone for datetime dates = 
[d.replace(tzinfo=None) for d in dates] @@ -49,9 +48,9 @@ def _test_dates_for_db(self, item, precision=3): date_types = [t.format(p=precision) for t in DATE_TYPES[db_id]] date_type_tables = {dt: self._new_table(dt) for dt in date_types} - if db_id is BigQuery: + if db_id is db.BigQuery: date_type_tables = {dt: f"data_diff.{name}" for dt, name in date_type_tables.items()} - elif db_id is MySQL: + elif db_id is db.MySQL: pql.run_statement("SET @@session.time_zone='+00:00'") used_tables = list(date_type_tables.values()) @@ -60,7 +59,7 @@ def _test_dates_for_db(self, item, precision=3): try: for date_type, table in date_type_tables.items(): - if db_id is not Oracle: + if db_id is not db.Oracle: pql.run_statement(f"DROP TABLE IF EXISTS {table}") pql.run_statement(f"CREATE TABLE {table}(id int, v {date_type})") pql.commit() @@ -69,7 +68,7 @@ def _test_dates_for_db(self, item, precision=3): for index, date in enumerate(dates, 1): # print(f"insert into {table}(v) values ('{date}')") - if db_id is BigQuery: + if db_id is db.BigQuery: pql.run_statement( f"insert into {table}(id, v) values ({index}, cast(timestamp '{date}' as {date_type}))" ) @@ -77,14 +76,14 @@ def _test_dates_for_db(self, item, precision=3): pql.run_statement(f"insert into {table}(id, v) values ({index}, timestamp '{date}')") pql.commit() - conn = connect_to_uri(conn_string) + conn = db.connect_to_uri(conn_string) assert type(conn) is db_id # Might change in the future - if db_id is MySQL: + if db_id is db.MySQL: conn.query("SET @@session.time_zone='+00:00'", None) for date_type, table in date_type_tables.items(): - if db_id is Snowflake: + if db_id is db.Snowflake: table = table.upper() schema = conn.query_table_schema(table.split(".")) schema = {k.lower(): v for k, v in schema.items()} diff --git a/tests/test_sql.py b/tests/test_sql.py index 032d6a7d..bc4828c0 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,6 +1,6 @@ import unittest -from data_diff.database import connect_to_uri +from 
data_diff.databases import connect_to_uri from data_diff.sql import Checksum, Compare, Compiler, Count, Enum, Explain, In, Select, TableName from .common import TEST_MYSQL_CONN_STRING