Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Refactor - nicer regexp parsing; Trino now inherits from Presto #205

Merged
merged 1 commit into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions data_diff/databases/oracle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

from ..utils import match_regexps

from .database_types import *
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
from .base import DEFAULT_DATETIME_PRECISION, TIMESTAMP_PRECISION_POS
Expand Down Expand Up @@ -99,14 +101,10 @@ def _parse_type(
r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
r"TIMESTAMP\((\d)\)": Timestamp,
}
for regexp, t_cls in regexps.items():
m = re.match(regexp + "$", type_repr)
if m:
datetime_precision = int(m.group(1))
return t_cls(
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS,
)

for m, t_cls in match_regexps(regexps, type_repr):
precision = int(m.group(1))
return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)

return super()._parse_type(
table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale
Expand Down
27 changes: 10 additions & 17 deletions data_diff/databases/presto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

from ..utils import match_regexps

from .database_types import *
from .base import Database, import_helper, _query_conn
from .base import (
Expand Down Expand Up @@ -94,27 +96,18 @@ def _parse_type(
r"timestamp\((\d)\)": Timestamp,
r"timestamp\((\d)\) with time zone": TimestampTZ,
}
for regexp, t_cls in timestamp_regexps.items():
m = re.match(regexp + "$", type_repr)
if m:
datetime_precision = int(m.group(1))
return t_cls(
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS,
)
for m, t_cls in match_regexps(timestamp_regexps, type_repr):
precision = int(m.group(1))
return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)

number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
for regexp, n_cls in number_regexps.items():
m = re.match(regexp + "$", type_repr)
if m:
prec, scale = map(int, m.groups())
return n_cls(scale)
for m, n_cls in match_regexps(number_regexps, type_repr):
_prec, scale = map(int, m.groups())
return n_cls(scale)

string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text}
for regexp, n_cls in string_regexps.items():
m = re.match(regexp + "$", type_repr)
if m:
return n_cls()
for m, n_cls in match_regexps(string_regexps, type_repr):
return n_cls()

return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)

Expand Down
98 changes: 4 additions & 94 deletions data_diff/databases/trino.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
import re

from .database_types import *
from .base import Database, import_helper
from .base import (
MD5_HEXDIGITS,
CHECKSUM_HEXDIGITS,
TIMESTAMP_PRECISION_POS,
DEFAULT_DATETIME_PRECISION,
)
from .presto import Presto
from .base import import_helper
from .base import TIMESTAMP_PRECISION_POS


@import_helper("trino")
Expand All @@ -17,49 +11,12 @@ def import_trino():
return trino


class Trino(Database):
default_schema = "public"
TYPE_CLASSES = {
# Timestamps
"timestamp with time zone": TimestampTZ,
"timestamp without time zone": Timestamp,
"timestamp": Timestamp,
# Numbers
"integer": Integer,
"bigint": Integer,
"real": Float,
"double": Float,
# Text
"varchar": Text,
}
ROUNDS_ON_PREC_LOSS = True

class Trino(Presto):
def __init__(self, **kw):
trino = import_trino()

self._conn = trino.dbapi.connect(**kw)

def quote(self, s: str):
return f'"{s}"'

def md5_to_int(self, s: str) -> str:
return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))"

def to_string(self, s: str):
return f"cast({s} as varchar)"

def _query(self, sql_code: str) -> list:
"""Uses the standard SQL cursor interface"""
c = self._conn.cursor()
c.execute(sql_code)
if sql_code.lower().startswith("select"):
return c.fetchall()
if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE):
return c.fetchone()

def close(self):
self._conn.close()

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')"
Expand All @@ -70,52 +27,5 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
)

def normalize_number(self, value: str, coltype: FractionalType) -> str:
return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")

def select_table_schema(self, path: DbPath) -> str:
schema, table = self._normalize_table_path(path)

return (
f"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision FROM INFORMATION_SCHEMA.COLUMNS "
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)

def _parse_type(
self,
table_path: DbPath,
col_name: str,
type_repr: str,
datetime_precision: int = None,
numeric_precision: int = None,
) -> ColType:
timestamp_regexps = {
r"timestamp\((\d)\)": Timestamp,
r"timestamp\((\d)\) with time zone": TimestampTZ,
}
for regexp, t_cls in timestamp_regexps.items():
m = re.match(regexp + "$", type_repr)
if m:
datetime_precision = int(m.group(1))
return t_cls(
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS,
)

number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
for regexp, n_cls in number_regexps.items():
m = re.match(regexp + "$", type_repr)
if m:
prec, scale = map(int, m.groups())
return n_cls(scale)

string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text}
for regexp, n_cls in string_regexps.items():
m = re.match(regexp + "$", type_repr)
if m:
return n_cls()

return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
return f"TRIM({value})"
11 changes: 9 additions & 2 deletions data_diff/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import math
from typing import Iterable, Tuple, Union, Any, Sequence
from typing import Iterable, Tuple, Union, Any, Sequence, Dict
from typing import TypeVar, Generic
from abc import ABC, abstractmethod
from urllib.parse import urlparse
Expand Down Expand Up @@ -225,7 +225,7 @@ def match_like(pattern: str, strs: Sequence[str]) -> Iterable[str]:


def accumulate(iterable, func=operator.add, *, initial=None):
'Return running totals'
"Return running totals"
# Taken from https://docs.python.org/3/library/itertools.html#itertools.accumulate, to backport 'initial' to 3.7
it = iter(iterable)
total = initial
Expand All @@ -238,3 +238,10 @@ def accumulate(iterable, func=operator.add, *, initial=None):
for element in it:
total = func(total, element)
yield total


def match_regexps(regexps: Dict[str, Any], s: str) -> Sequence[tuple]:
for regexp, v in regexps.items():
m = re.match(regexp + "$", s)
if m:
yield m, v