Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Support for UUID key column #119

Merged
merged 10 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 77 additions & 23 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
from uuid import UUID
import math
import sys
import logging
from typing import Dict, Tuple, Optional, Sequence
from typing import Dict, Tuple, Optional, Sequence, Type, List
from functools import lru_cache, wraps
from concurrent.futures import ThreadPoolExecutor
import threading
from abc import abstractmethod

from .database_types import AbstractDatabase, ColType, Integer, Decimal, Float, UnknownColType
from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select
from data_diff.utils import is_uuid, safezip
from .database_types import (
ColType_UUID,
AbstractDatabase,
ColType,
Integer,
Decimal,
Float,
PrecisionType,
TemporalType,
UnknownColType,
Text,
)
from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select, TableName

logger = logging.getLogger("database")

Expand Down Expand Up @@ -62,7 +75,7 @@ class Database(AbstractDatabase):
Instantiated using :meth:`~data_diff.connect_to_uri`
"""

DATETIME_TYPES: Dict[str, type] = {}
TYPE_CLASSES: Dict[str, type] = {}
default_schema: str = None

@property
Expand Down Expand Up @@ -93,7 +106,7 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
assert len(res) == 1, (sql_code, res)
return res[0]
elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1:
if res_type.__args__ == (int,):
if res_type.__args__ == (int,) or res_type.__args__ == (str,):
return [_one(row) for row in res]
elif res_type.__args__ == (Tuple,):
return [tuple(row) for row in res]
Expand All @@ -109,8 +122,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
return math.floor(math.log(2**p, 10))

def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
    "Look up the ColType subclass registered for this database type name, or None if unrecognized."
    return self.TYPE_CLASSES.get(type_repr)

def _parse_type(
self,
table_path: DbPath,
col_name: str,
type_repr: str,
datetime_precision: int = None,
Expand All @@ -119,36 +136,38 @@ def _parse_type(
) -> ColType:
""" """

cls = self.DATETIME_TYPES.get(type_repr)
if cls:
cls = self._parse_type_repr(type_repr)
if not cls:
return UnknownColType(type_repr)

if issubclass(cls, TemporalType):
return cls(
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS,
)

cls = self.NUMERIC_TYPES.get(type_repr)
if cls:
if issubclass(cls, Integer):
# Some DBs have a constant numeric_scale, so they don't report it.
# We fill in the constant, so we need to ignore it for integers.
return cls(precision=0)

elif issubclass(cls, Decimal):
if numeric_scale is None:
raise ValueError(
f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}."
)
return cls(precision=numeric_scale)
elif issubclass(cls, Integer):
return cls()

elif issubclass(cls, Decimal):
if numeric_scale is None:
raise ValueError(
f"{self.name}: Unexpected numeric_scale is NULL, for column {'.'.join(table_path)}.{col_name} of type {type_repr}."
)
return cls(precision=numeric_scale)

assert issubclass(cls, Float)
elif issubclass(cls, Float):
# assert numeric_scale is None
return cls(
precision=self._convert_db_precision_to_digits(
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
)
)

return UnknownColType(type_repr)
elif issubclass(cls, Text):
return cls()

raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")

def select_table_schema(self, path: DbPath) -> str:
schema, table = self._normalize_table_path(path)
Expand All @@ -167,8 +186,34 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str
accept = {i.lower() for i in filter_columns}
rows = [r for r in rows if r[0].lower() in accept]

col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in rows}

self._refine_coltypes(path, col_dict)

# Return a dict of form {name: type} after normalization
return {row[0]: self._parse_type(*row) for row in rows}
return col_dict

def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
    "Refine the types in the column dict, by querying the database for a sample of their values"

    # Only plain-text columns are candidates for refinement (into UUID columns).
    candidates = [name for name, coltype in col_dict.items() if isinstance(coltype, Text)]
    if not candidates:
        return

    # Fetch a small sample of values for all candidate columns in a single query.
    select_fields = [self.normalize_uuid(name, ColType_UUID()) for name in candidates]
    sampled_rows = self.query(Select(select_fields, TableName(table_path), limit=16), list)
    sampled_columns = list(zip(*sampled_rows))

    for name, values in safezip(candidates, sampled_columns):
        uuids = [v for v in values if is_uuid(v)]
        if not uuids:
            continue
        if len(uuids) != len(values):
            # Some values parse as UUIDs and some don't — too ambiguous to refine.
            logger.warning(
                f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{name}, disabling UUID support."
            )
            continue
        # Every sampled value is a UUID: promote the column type.
        assert name in col_dict
        col_dict[name] = ColType_UUID()

# @lru_cache()
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:
Expand All @@ -186,6 +231,15 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
def parse_table_name(self, name: str) -> DbPath:
return parse_table_name(name)

def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    """Return the dialect's SQL fragment for limiting a query's row count.

    Only LIMIT is supported by this default implementation; a truthy
    ``offset`` raises NotImplementedError.

    Raises ValueError when ``limit`` is None — previously this silently
    produced the invalid SQL text ``LIMIT None``.
    """
    if offset:
        raise NotImplementedError("No support for OFFSET in query")

    if limit is None:
        raise ValueError("offset_limit() requires a limit value")

    return f"LIMIT {limit}"

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
return f"TRIM({value})"


class ThreadedDatabase(Database):
"""Access the database through singleton threads.
Expand Down
8 changes: 5 additions & 3 deletions data_diff/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ def import_bigquery():


class BigQuery(Database):
DATETIME_TYPES = {
TYPE_CLASSES = {
# Dates
"TIMESTAMP": Timestamp,
"DATETIME": Datetime,
}
NUMERIC_TYPES = {
# Numbers
"INT64": Integer,
"INT32": Integer,
"NUMERIC": Decimal,
"BIGNUMERIC": Decimal,
"FLOAT64": Float,
"FLOAT32": Float,
# Text
"STRING": Text,
}
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation

Expand Down
52 changes: 49 additions & 3 deletions data_diff/databases/database_types.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import decimal
from abc import ABC, abstractmethod
from typing import Sequence, Optional, Tuple, Union, Dict
from typing import Sequence, Optional, Tuple, Union, Dict, Any
from datetime import datetime

from runtype import dataclass

from data_diff.utils import ArithUUID


DbPath = Tuple[str, ...]
DbKey = Union[int, str, bytes]
DbKey = Union[int, str, bytes, ArithUUID]
DbTime = datetime


class ColType:
supported = True
pass


Expand Down Expand Up @@ -50,11 +55,36 @@ class Float(FractionalType):


class Decimal(FractionalType):
    # Exact numeric type (NUMERIC/DECIMAL). Note: `precision` is filled with the
    # column's numeric_scale by Database._parse_type — TODO confirm naming intent.
    @property
    def python_type(self) -> type:
        """Python type used to represent values of this column.

        A precision (scale) of 0 means whole numbers, represented as `int`;
        otherwise values are represented exactly with `decimal.Decimal`.
        """
        if self.precision == 0:
            return int
        return decimal.Decimal


class StringType(ColType):
    # Common base for textual column types (e.g. Text, ColType_UUID).
    pass


class IKey(ABC):
    "Interface for ColType, for using a column as a key in data-diff"

    # The Python type key values are converted to (e.g. int, ArithUUID).
    python_type: type


class ColType_UUID(StringType, IKey):
    # A text column whose sampled values all parse as UUIDs; values are
    # represented as ArithUUID (see data_diff.utils) when used as a key.
    python_type = ArithUUID


@dataclass
class Text(StringType):
    # Plain text columns are not diffable as-is; they may be refined into
    # ColType_UUID when a sample of their values all parse as UUIDs.
    supported = False


@dataclass
class Integer(NumericType):
class Integer(NumericType, IKey):
precision: int = 0
python_type: type = int

def __post_init__(self):
assert self.precision == 0

Expand All @@ -63,6 +93,8 @@ def __post_init__(self):
@dataclass
class UnknownColType(ColType):
    # Raw type name as reported by the database, kept for diagnostics.
    text: str

    # Unrecognized types cannot be diffed.
    supported = False


class AbstractDatabase(ABC):
@abstractmethod
Expand All @@ -80,6 +112,10 @@ def md5_to_int(self, s: str) -> str:
"Provide SQL for computing md5 and returning an int"
...

@abstractmethod
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    "Provide the dialect-specific SQL fragment for limiting/offsetting a query."
    ...

@abstractmethod
def _query(self, sql_code: str) -> list:
"Send query to database and return result"
Expand Down Expand Up @@ -138,6 +174,14 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
"""
...

@abstractmethod
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
    """Creates an SQL expression, that converts 'value' to a normalized uuid.

    i.e. just makes sure there is no trailing whitespace.

    `value` is an SQL expression (e.g. a column name); `coltype` describes
    the column being normalized.
    """
    ...

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
"""Creates an SQL expression, that converts 'value' to a normalized representation.

Expand All @@ -158,6 +202,8 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
return self.normalize_timestamp(value, coltype)
elif isinstance(coltype, FractionalType):
return self.normalize_number(value, coltype)
elif isinstance(coltype, ColType_UUID):
return self.normalize_uuid(value, coltype)
return self.to_string(value)

def _normalize_table_path(self, path: DbPath) -> DbPath:
Expand Down
9 changes: 6 additions & 3 deletions data_diff/databases/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@ def import_mysql():


class MySQL(ThreadedDatabase):
DATETIME_TYPES = {
TYPE_CLASSES = {
# Dates
"datetime": Datetime,
"timestamp": Timestamp,
}
NUMERIC_TYPES = {
# Numbers
"double": Float,
"float": Float,
"decimal": Decimal,
"int": Integer,
"bigint": Integer,
# Text
"varchar": Text,
"char": Text,
}
ROUNDS_ON_PREC_LOSS = True

Expand Down
37 changes: 20 additions & 17 deletions data_diff/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ def import_oracle():


class Oracle(ThreadedDatabase):
TYPE_CLASSES: Dict[str, type] = {
"NUMBER": Decimal,
"FLOAT": Float,
# Text
"CHAR": Text,
"NCHAR": Text,
"NVARCHAR2": Text,
"VARCHAR2": Text,
}
ROUNDS_ON_PREC_LOSS = True

def __init__(self, host, port, user, password, *, database, thread_count, **kw):
Expand Down Expand Up @@ -67,13 +76,13 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:

def _parse_type(
self,
table_name: DbPath,
col_name: str,
type_repr: str,
datetime_precision: int = None,
numeric_precision: int = None,
numeric_scale: int = None,
) -> ColType:
""" """
regexps = {
r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
Expand All @@ -87,20 +96,14 @@ def _parse_type(
rounds=self.ROUNDS_ON_PREC_LOSS,
)

n_cls = {
"NUMBER": Decimal,
"FLOAT": Float,
}.get(type_repr, None)
if n_cls:
if issubclass(n_cls, Decimal):
assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale)
return n_cls(precision=numeric_scale)

assert issubclass(n_cls, Float)
return n_cls(
precision=self._convert_db_precision_to_digits(
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
)
)
return super()._parse_type(type_repr, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)

def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    "Oracle uses the ANSI row-limiting clause (FETCH NEXT) instead of LIMIT; OFFSET is unsupported."
    if offset:
        raise NotImplementedError("No support for OFFSET in query")

    return "FETCH NEXT {} ROWS ONLY".format(limit)

return UnknownColType(type_repr)
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
# Cast is necessary for correct MD5 (trimming not enough)
return f"CAST(TRIM({value}) AS VARCHAR(36))"
Loading