diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index ebfe5e36..b917e9ce 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,5 +1,7 @@ import re +from ..utils import match_regexps + from .database_types import * from .base import ThreadedDatabase, import_helper, ConnectError, QueryError from .base import DEFAULT_DATETIME_PRECISION, TIMESTAMP_PRECISION_POS @@ -99,14 +101,10 @@ def _parse_type( r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, r"TIMESTAMP\((\d)\)": Timestamp, } - for regexp, t_cls in regexps.items(): - m = re.match(regexp + "$", type_repr) - if m: - datetime_precision = int(m.group(1)) - return t_cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, - rounds=self.ROUNDS_ON_PREC_LOSS, - ) + + for m, t_cls in match_regexps(regexps, type_repr): + precision = int(m.group(1)) + return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) return super()._parse_type( table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 25116339..abd0793e 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,5 +1,7 @@ import re +from ..utils import match_regexps + from .database_types import * from .base import Database, import_helper, _query_conn from .base import ( @@ -94,27 +96,18 @@ def _parse_type( r"timestamp\((\d)\)": Timestamp, r"timestamp\((\d)\) with time zone": TimestampTZ, } - for regexp, t_cls in timestamp_regexps.items(): - m = re.match(regexp + "$", type_repr) - if m: - datetime_precision = int(m.group(1)) - return t_cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, - rounds=self.ROUNDS_ON_PREC_LOSS, - ) + for m, t_cls in match_regexps(timestamp_regexps, type_repr): + precision = int(m.group(1)) + return t_cls(precision=precision, 
rounds=self.ROUNDS_ON_PREC_LOSS) number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} - for regexp, n_cls in number_regexps.items(): - m = re.match(regexp + "$", type_repr) - if m: - prec, scale = map(int, m.groups()) - return n_cls(scale) + for m, n_cls in match_regexps(number_regexps, type_repr): + _prec, scale = map(int, m.groups()) + return n_cls(scale) string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} - for regexp, n_cls in string_regexps.items(): - m = re.match(regexp + "$", type_repr) - if m: - return n_cls() + for m, n_cls in match_regexps(string_regexps, type_repr): + return n_cls() return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index d2e16c84..c3e3e581 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,13 +1,7 @@ -import re - from .database_types import * -from .base import Database, import_helper -from .base import ( - MD5_HEXDIGITS, - CHECKSUM_HEXDIGITS, - TIMESTAMP_PRECISION_POS, - DEFAULT_DATETIME_PRECISION, -) +from .presto import Presto +from .base import import_helper +from .base import TIMESTAMP_PRECISION_POS @import_helper("trino") @@ -17,49 +11,12 @@ def import_trino(): return trino -class Trino(Database): - default_schema = "public" - TYPE_CLASSES = { - # Timestamps - "timestamp with time zone": TimestampTZ, - "timestamp without time zone": Timestamp, - "timestamp": Timestamp, - # Numbers - "integer": Integer, - "bigint": Integer, - "real": Float, - "double": Float, - # Text - "varchar": Text, - } - ROUNDS_ON_PREC_LOSS = True - +class Trino(Presto): def __init__(self, **kw): trino = import_trino() self._conn = trino.dbapi.connect(**kw) - def quote(self, s: str): - return f'"{s}"' - - def md5_to_int(self, s: str) -> str: - return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" - - def to_string(self, s: 
str): - return f"cast({s} as varchar)" - - def _query(self, sql_code: str) -> list: - """Uses the standard SQL cursor interface""" - c = self._conn.cursor() - c.execute(sql_code) - if sql_code.lower().startswith("select"): - return c.fetchall() - if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE): - return c.fetchone() - - def close(self): - self._conn.close() - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')" @@ -70,52 +27,5 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')" ) - def normalize_number(self, value: str, coltype: FractionalType) -> str: - return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") - - def select_table_schema(self, path: DbPath) -> str: - schema, table = self._normalize_table_path(path) - - return ( - f"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision FROM INFORMATION_SCHEMA.COLUMNS " - f"WHERE table_name = '{table}' AND table_schema = '{schema}'" - ) - - def _parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - ) -> ColType: - timestamp_regexps = { - r"timestamp\((\d)\)": Timestamp, - r"timestamp\((\d)\) with time zone": TimestampTZ, - } - for regexp, t_cls in timestamp_regexps.items(): - m = re.match(regexp + "$", type_repr) - if m: - datetime_precision = int(m.group(1)) - return t_cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, - rounds=self.ROUNDS_ON_PREC_LOSS, - ) - - number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} - for regexp, n_cls in number_regexps.items(): - m = re.match(regexp + "$", type_repr) - if m: - prec, scale = map(int, m.groups()) - return 
def match_regexps(regexps: Dict[str, Any], s: str) -> Iterable[Tuple["re.Match", Any]]:
    """Lazily yield ``(match, value)`` for each pattern that matches all of *s*.

    Args:
        regexps: Mapping of regular-expression source strings to arbitrary
            values. Patterns are tried in the mapping's iteration order.
        s: The string to test against every pattern.

    Yields:
        A ``(match_object, value)`` pair for each pattern matching the whole
        of *s*. Capture-group numbering is that of the pattern as written.
    """
    for regexp, v in regexps.items():
        # fullmatch() instead of match(regexp + "$"): appending "$" anchors
        # only the last branch of a top-level alternation ("a|bc" would
        # accept "ab"), and "$" also matches just before a trailing newline.
        m = re.fullmatch(regexp, s)
        if m:
            yield m, v