From 49e8a4f75e3580dfc5886b2ee329dd1966fb937e Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 22 Nov 2022 15:58:31 -0300 Subject: [PATCH 1/2] List tables from schema, mid-work (WIP) --- data_diff/sqeleton/abcs/mixins.py | 8 +++-- data_diff/sqeleton/databases/base.py | 24 +++++++++++-- data_diff/sqeleton/databases/bigquery.py | 19 ++++++++-- data_diff/sqeleton/databases/duckdb.py | 4 +-- data_diff/sqeleton/databases/mysql.py | 4 +-- data_diff/sqeleton/databases/postgresql.py | 9 ++--- data_diff/sqeleton/databases/presto.py | 4 +-- data_diff/sqeleton/databases/snowflake.py | 25 ++++++++++++-- data_diff/sqeleton/databases/vertica.py | 21 ++++++++++-- tests/sqeleton/test_database.py | 40 ++++++++++++++++++++++ 10 files changed, 135 insertions(+), 23 deletions(-) diff --git a/data_diff/sqeleton/abcs/mixins.py b/data_diff/sqeleton/abcs/mixins.py index 774dfa30..764c86e6 100644 --- a/data_diff/sqeleton/abcs/mixins.py +++ b/data_diff/sqeleton/abcs/mixins.py @@ -97,9 +97,13 @@ class AbstractMixin_Schema(ABC): TODO: Move AbstractDatabase.query_table_schema() and friends over here """ + def table_information(self) -> Compilable: + "Query to return a table of schema information about existing tables" + raise NotImplementedError() + @abstractmethod - def list_tables(self, like: Compilable = None) -> Compilable: - """Query to select the list of tables in the schema. + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + """Query to select the list of tables in the schema. (query return type: table[str]) If 'like' is specified, the value is applied to the table name, using the 'like' operator. """ diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index de1bcb13..fa892e69 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -11,7 +11,7 @@ import decimal from ..utils import is_uuid, safezip -from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code +from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this from ..abcs.database_types import ( AbstractDatabase, AbstractDialect, @@ -30,6 +30,8 @@ DbPath, Boolean, ) +from ..abcs.mixins import Compilable +from ..abcs.mixins import AbstractMixin_Schema logger = logging.getLogger("database") @@ -101,6 +103,22 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal return callback(sql_code) +class Mixin_Schema(AbstractMixin_Schema): + def table_information(self) -> Compilable: + return table("information_schema", "tables") + + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + self.table_information() + .where( + this.table_schema == table_schema, + this.table_name.like(like) if like is not None else SKIP, + this.table_type == "BASE TABLE", + ) + .select(this.table_name) + ) + + class BaseDialect(AbstractDialect): SUPPORTS_PRIMARY_KEY = False TYPE_CLASSES: Dict[str, type] = {} @@ -354,7 +372,9 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe return fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] - samples_by_row = self.query(table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list) + samples_by_row = self.query( + table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list + ) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") diff --git a/data_diff/sqeleton/databases/bigquery.py b/data_diff/sqeleton/databases/bigquery.py index 07520bf2..a8e715cb 100644 --- a/data_diff/sqeleton/databases/bigquery.py +++ b/data_diff/sqeleton/databases/bigquery.py @@ -11,7 +11,9 @@ TemporalType, Boolean, ) -from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema +from ..abcs import Compilable +from ..queries import this, table, SKIP from .base import BaseDialect, Database, import_helper, parse_table_name, ConnectError, apply_query from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter @@ -51,7 +53,20 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str: return self.to_string(f"cast({value} as int)") -class Dialect(BaseDialect): +class Mixin_Schema(AbstractMixin_Schema): + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + table(table_schema, "INFORMATION_SCHEMA", "TABLES") + .where( + this.table_schema == table_schema, + this.table_name.like(like) if like is not None else SKIP, + this.table_type == "BASE TABLE", + ) + .select(this.table_name) + ) + + +class Dialect(BaseDialect, Mixin_Schema): name = "BigQuery" ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation TYPE_CLASSES = { diff --git a/data_diff/sqeleton/databases/duckdb.py b/data_diff/sqeleton/databases/duckdb.py index 1fefcf1d..ff9f5c0d 100644 --- a/data_diff/sqeleton/databases/duckdb.py +++ b/data_diff/sqeleton/databases/duckdb.py @@ -24,7 +24,7 @@ ThreadLocalInterpreter, TIMESTAMP_PRECISION_POS, ) -from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema @import_helper("duckdb") @@ -54,7 +54,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str: return self.to_string(f"{value}::INTEGER") -class Dialect(BaseDialect): +class Dialect(BaseDialect, Mixin_Schema): name = "DuckDB" ROUNDS_ON_PREC_LOSS = False SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/sqeleton/databases/mysql.py b/data_diff/sqeleton/databases/mysql.py index 5fe13072..bbd21241 100644 --- a/data_diff/sqeleton/databases/mysql.py +++ b/data_diff/sqeleton/databases/mysql.py @@ -17,7 +17,7 @@ ConnectError, BaseDialect, ) -from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Mixin_Schema @import_helper("mysql") @@ -47,7 +47,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM(CAST({value} AS char))" -class Dialect(BaseDialect): +class Dialect(BaseDialect, Mixin_Schema): name = "MySQL" ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/sqeleton/databases/postgresql.py b/data_diff/sqeleton/databases/postgresql.py index d313fe73..18412804 100644 --- a/data_diff/sqeleton/databases/postgresql.py +++ b/data_diff/sqeleton/databases/postgresql.py @@ -11,12 +11,7 @@ Boolean, ) from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from .base import ( - BaseDialect, - ThreadedDatabase, - import_helper, - ConnectError, -) +from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS SESSION_TIME_ZONE = None # Changed by the tests @@ -53,7 +48,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str: return self.to_string(f"{value}::int") -class PostgresqlDialect(BaseDialect): +class PostgresqlDialect(BaseDialect, Mixin_Schema): name = "PostgreSQL" ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/sqeleton/databases/presto.py b/data_diff/sqeleton/databases/presto.py index da3d0404..76af62cf 100644 --- a/data_diff/sqeleton/databases/presto.py +++ b/data_diff/sqeleton/databases/presto.py @@ -19,7 +19,7 @@ Boolean, ) from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter +from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter, Mixin_Schema from .base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -69,7 +69,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str: return self.to_string(f"cast ({value} as int)") -class Dialect(BaseDialect): +class Dialect(BaseDialect, Mixin_Schema): name = "Presto" ROUNDS_ON_PREC_LOSS = True TYPE_CLASSES = { diff --git a/data_diff/sqeleton/databases/snowflake.py b/data_diff/sqeleton/databases/snowflake.py index a9b88cd2..4ea1fb92 100644 --- a/data_diff/sqeleton/databases/snowflake.py +++ b/data_diff/sqeleton/databases/snowflake.py @@ -12,7 +12,9 @@ DbPath, Boolean, ) -from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema +from ..abcs import Compilable +from data_diff.sqeleton.queries import table, this, SKIP from .base import BaseDialect, ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter @@ -46,7 +48,23 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str: return self.to_string(f"{value}::int") -class Dialect(BaseDialect): +class Mixin_Schema(AbstractMixin_Schema): + def table_information(self) -> Compilable: + return table("INFORMATION_SCHEMA", "TABLES") + + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + self.table_information() + .where( + this.TABLE_SCHEMA == table_schema, + this.TABLE_NAME.like(like) if like is not None else SKIP, + this.TABLE_TYPE == "BASE TABLE", + ) + .select(table_name=this.TABLE_NAME) + ) + + +class Dialect(BaseDialect, Mixin_Schema): name = "Snowflake" ROUNDS_ON_PREC_LOSS = False TYPE_CLASSES = { @@ -72,6 +90,9 @@ def quote(self, s: str): def to_string(self, s: str): return f"cast({s} as string)" + def table_information(self) -> Compilable: + return table("INFORMATION_SCHEMA", "TABLES") + class Snowflake(Database): dialect = Dialect() diff --git a/data_diff/sqeleton/databases/vertica.py b/data_diff/sqeleton/databases/vertica.py index 86da1b78..c01e9544 100644 --- a/data_diff/sqeleton/databases/vertica.py +++ b/data_diff/sqeleton/databases/vertica.py @@ -24,7 +24,9 @@ Boolean, ColType_UUID, ) -from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema +from ..abcs import Compilable +from ..queries import table, this, SKIP @import_helper("vertica") @@ -60,7 +62,22 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str: return self.to_string(f"cast ({value} as int)") -class Dialect(BaseDialect): +class Mixin_Schema(AbstractMixin_Schema): + def table_information(self) -> Compilable: + return table("v_catalog", "tables") + + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + self.table_information() + .where( + this.table_schema == table_schema, + this.table_name.like(like) if like is not None else SKIP, + ) + .select(this.table_name) + ) + + +class Dialect(BaseDialect, Mixin_Schema): name = "Vertica" ROUNDS_ON_PREC_LOSS = True diff --git a/tests/sqeleton/test_database.py b/tests/sqeleton/test_database.py index 1577f76f..33c3e128 100644 --- a/tests/sqeleton/test_database.py +++ b/tests/sqeleton/test_database.py @@ -1,9 +1,33 @@ +from typing import Callable, List import unittest from ..common import str_to_checksum, TEST_MYSQL_CONN_STRING +from ..common import str_to_checksum, test_each_database_in_list, TestPerDatabase, get_conn, random_table_suffix +# from data_diff.sqeleton import databases as db +# from data_diff.sqeleton import connect + +from data_diff.sqeleton.queries import table + +from data_diff import databases as dbs from data_diff.databases import connect +TEST_DATABASES = { + dbs.MySQL, + dbs.PostgreSQL, + # dbs.Oracle, + # dbs.Redshift, + dbs.Snowflake, + dbs.DuckDB, + dbs.BigQuery, + dbs.Presto, + dbs.Trino, + dbs.Vertica, +} + +test_each_database: Callable = test_each_database_in_list(TEST_DATABASES) + + class TestDatabase(unittest.TestCase): def setUp(self): self.mysql = connect(TEST_MYSQL_CONN_STRING) @@ -25,3 +49,19 @@ def test_bad_uris(self): self.assertRaises(ValueError, connect, "postgresql:///bla/foo") self.assertRaises(ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1") self.assertRaises(ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup") + + +@test_each_database +class TestSchema(TestPerDatabase): + def test_table_list(self): + name = self.table_src_name + db = self.connection + tbl = table(db.parse_table_name(name), schema={'id': int}) + q = db.dialect.list_tables(db.default_schema, name) + assert not db.query(q) + + db.query(tbl.create()) + assert db.query(q, List[str] ) == [name] + + db.query( tbl.drop() ) + assert not db.query(q) From 3d6632526323136fde1040af672d5482e96bd256 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 23 Nov 2022 14:43:16 -0300 Subject: [PATCH 2/2] Add Oracle, Redshift --- data_diff/sqeleton/databases/oracle.py | 25 +++++++++++++++++++------ tests/sqeleton/test_database.py | 6 +++--- tests/test_database_types.py | 5 +++-- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/data_diff/sqeleton/databases/oracle.py b/data_diff/sqeleton/databases/oracle.py index c7003f05..e9f77930 100644 --- a/data_diff/sqeleton/databases/oracle.py +++ b/data_diff/sqeleton/databases/oracle.py @@ -14,8 +14,10 @@ TimestampTZ, FractionalType, ) -from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema +from ..abcs import Compilable +from ..queries import this, table, SKIP +from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError, Mixin_Schema from .base import TIMESTAMP_PRECISION_POS SESSION_TIME_ZONE = None # Changed by the tests @@ -57,8 +59,19 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: format_str += "0." + "9" * (coltype.precision - 1) + "0" return f"to_char({value}, '{format_str}')" +class Mixin_Schema(AbstractMixin_Schema): + def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: + return ( + table('ALL_TABLES') + .where( + this.OWNER == table_schema, + this.TABLE_NAME.like(like) if like is not None else SKIP, + ) + .select(table_name = this.TABLE_NAME) + ) + -class Dialect(BaseDialect): +class Dialect(BaseDialect, Mixin_Schema): name = "Oracle" SUPPORTS_PRIMARY_KEY = True TYPE_CLASSES: Dict[str, type] = { @@ -73,7 +86,7 @@ class Dialect(BaseDialect): ROUNDS_ON_PREC_LOSS = True def quote(self, s: str): - return f"{s}" + return f'"{s}"' def to_string(self, s: str): return f"cast({s} as varchar(1024))" @@ -143,7 +156,7 @@ class Oracle(ThreadedDatabase): def __init__(self, *, host, database, thread_count, **kw): self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) - self.default_schema = kw.get("user") + self.default_schema = kw.get("user").upper() super().__init__(thread_count=thread_count) @@ -168,5 +181,5 @@ def select_table_schema(self, path: DbPath) -> str: return ( f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" - f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table.upper()}' AND owner = '{schema.upper()}'" + f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table}' AND owner = '{schema}'" ) diff --git a/tests/sqeleton/test_database.py b/tests/sqeleton/test_database.py index 33c3e128..38c6cb6f 100644 --- a/tests/sqeleton/test_database.py +++ b/tests/sqeleton/test_database.py @@ -15,8 +15,8 @@ TEST_DATABASES = { dbs.MySQL, dbs.PostgreSQL, - # dbs.Oracle, - # dbs.Redshift, + dbs.Oracle, + dbs.Redshift, dbs.Snowflake, dbs.DuckDB, dbs.BigQuery, @@ -61,7 +61,7 @@ def test_table_list(self): assert not db.query(q) db.query(tbl.create()) - assert db.query(q, List[str] ) == [name] + self.assertEqual( db.query(q, List[str] ), [name]) db.query( tbl.drop() ) assert not db.query(q) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 2333ee94..991c41cc 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -602,12 +602,13 @@ def _create_indexes(conn, table): try: if_not_exists = "IF NOT EXISTS" if not isinstance(conn, (db.MySQL, db.Oracle)) else "" + quote = conn.dialect.quote conn.query( - f"CREATE INDEX {if_not_exists} xa_{table[1:-1]} ON {table} (id, col)", + f"CREATE INDEX {if_not_exists} xa_{table[1:-1]} ON {table} ({quote('id')}, {quote('col')})", None, ) conn.query( - f"CREATE INDEX {if_not_exists} xb_{table[1:-1]} ON {table} (id)", + f"CREATE INDEX {if_not_exists} xb_{table[1:-1]} ON {table} ({quote('id')})", None, ) except Exception as err: