Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

List tables from schema #311

Merged
merged 2 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions data_diff/sqeleton/abcs/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,13 @@ class AbstractMixin_Schema(ABC):
TODO: Move AbstractDatabase.query_table_schema() and friends over here
"""

def table_information(self) -> Compilable:
    """Query to return a table of schema information about existing tables.

    This declaration only raises NotImplementedError; a class that wants
    to offer the query has to override the method.
    """
    raise NotImplementedError

@abstractmethod
def list_tables(self, like: Compilable = None) -> Compilable:
"""Query to select the list of tables in the schema.
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
"""Query to select the list of tables in the schema. (query return type: table[str])

If 'like' is specified, the value is applied to the table name, using the 'like' operator.
"""
24 changes: 22 additions & 2 deletions data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import decimal

from ..utils import is_uuid, safezip
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this
from ..abcs.database_types import (
AbstractDatabase,
AbstractDialect,
Expand All @@ -30,6 +30,8 @@
DbPath,
Boolean,
)
from ..abcs.mixins import Compilable
from ..abcs.mixins import AbstractMixin_Schema

logger = logging.getLogger("database")

Expand Down Expand Up @@ -101,6 +103,22 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
return callback(sql_code)


class Mixin_Schema(AbstractMixin_Schema):
    """Table-listing queries built on top of ``information_schema.tables``."""

    def table_information(self) -> Compilable:
        """Return the table expression for ``information_schema.tables``."""
        return table("information_schema", "tables")

    def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
        """Build a query selecting the names of base tables in ``table_schema``.

        Rows are restricted to ``table_type == "BASE TABLE"``. When ``like`` is
        given it is applied to the table name with the ``like`` operator;
        otherwise ``SKIP`` is passed in place of that condition.
        """
        info = self.table_information()
        name_filter = SKIP if like is None else this.table_name.like(like)
        matching = info.where(
            this.table_schema == table_schema,
            name_filter,
            this.table_type == "BASE TABLE",
        )
        return matching.select(this.table_name)


class BaseDialect(AbstractDialect):
SUPPORTS_PRIMARY_KEY = False
TYPE_CLASSES: Dict[str, type] = {}
Expand Down Expand Up @@ -354,7 +372,9 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
return

fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]
samples_by_row = self.query(table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list)
samples_by_row = self.query(
table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list
)
if not samples_by_row:
raise ValueError(f"Table {table_path} is empty.")

Expand Down
19 changes: 17 additions & 2 deletions data_diff/sqeleton/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
TemporalType,
Boolean,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from ..queries import this, table, SKIP
from .base import BaseDialect, Database, import_helper, parse_table_name, ConnectError, apply_query
from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter

Expand Down Expand Up @@ -51,7 +53,20 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"cast({value} as int)")


class Dialect(BaseDialect):
class Mixin_Schema(AbstractMixin_Schema):
    """Table-listing query for the BigQuery dialect."""

    def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
        """Build a query selecting the names of base tables in ``table_schema``.

        The rows come from ``<table_schema>.INFORMATION_SCHEMA.TABLES``, so the
        schema name is used both in the table path and as a filter value. When
        ``like`` is given it is applied to the table name with the ``like``
        operator; otherwise ``SKIP`` is passed in place of that condition.
        """
        info = table(table_schema, "INFORMATION_SCHEMA", "TABLES")
        conditions = [
            this.table_schema == table_schema,
            SKIP if like is None else this.table_name.like(like),
            this.table_type == "BASE TABLE",
        ]
        return info.where(*conditions).select(this.table_name)


class Dialect(BaseDialect, Mixin_Schema):
name = "BigQuery"
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation
TYPE_CLASSES = {
Expand Down
4 changes: 2 additions & 2 deletions data_diff/sqeleton/databases/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ThreadLocalInterpreter,
TIMESTAMP_PRECISION_POS,
)
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema


@import_helper("duckdb")
Expand Down Expand Up @@ -54,7 +54,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"{value}::INTEGER")


class Dialect(BaseDialect):
class Dialect(BaseDialect, Mixin_Schema):
name = "DuckDB"
ROUNDS_ON_PREC_LOSS = False
SUPPORTS_PRIMARY_KEY = True
Expand Down
4 changes: 2 additions & 2 deletions data_diff/sqeleton/databases/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ConnectError,
BaseDialect,
)
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Mixin_Schema


@import_helper("mysql")
Expand Down Expand Up @@ -47,7 +47,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
return f"TRIM(CAST({value} AS char))"


class Dialect(BaseDialect):
class Dialect(BaseDialect, Mixin_Schema):
name = "MySQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
Expand Down
25 changes: 19 additions & 6 deletions data_diff/sqeleton/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
TimestampTZ,
FractionalType,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from ..queries import this, table, SKIP
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError, Mixin_Schema
from .base import TIMESTAMP_PRECISION_POS

SESSION_TIME_ZONE = None # Changed by the tests
Expand Down Expand Up @@ -57,8 +59,19 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
format_str += "0." + "9" * (coltype.precision - 1) + "0"
return f"to_char({value}, '{format_str}')"

class Mixin_Schema(AbstractMixin_Schema):
    """Table-listing query for the Oracle dialect, based on ``ALL_TABLES``."""

    def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
        """Build a query selecting the names of tables owned by ``table_schema``.

        The ``TABLE_NAME`` column is selected under the alias ``table_name``.
        When ``like`` is given it is applied to the table name with the
        ``like`` operator; otherwise ``SKIP`` is passed in place of that
        condition.
        """
        all_tables = table("ALL_TABLES")
        owner_filter = this.OWNER == table_schema
        name_filter = SKIP if like is None else this.TABLE_NAME.like(like)
        return all_tables.where(owner_filter, name_filter).select(table_name=this.TABLE_NAME)


class Dialect(BaseDialect):
class Dialect(BaseDialect, Mixin_Schema):
name = "Oracle"
SUPPORTS_PRIMARY_KEY = True
TYPE_CLASSES: Dict[str, type] = {
Expand All @@ -73,7 +86,7 @@ class Dialect(BaseDialect):
ROUNDS_ON_PREC_LOSS = True

def quote(self, s: str):
return f"{s}"
return f'"{s}"'

def to_string(self, s: str):
return f"cast({s} as varchar(1024))"
Expand Down Expand Up @@ -143,7 +156,7 @@ class Oracle(ThreadedDatabase):
def __init__(self, *, host, database, thread_count, **kw):
    """Store the connection settings and derive the default schema.

    Args:
        host: Host part of the DSN.
        database: Optional database part; when truthy the DSN becomes
            ``"<host>/<database>"``, otherwise just ``host``.
        thread_count: Forwarded to the parent initializer.
        **kw: Extra connection keyword arguments, kept in ``self.kwargs``.
            The optional ``user`` entry determines the default schema.
    """
    self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)

    # The schema is upper-cased once here so that catalog filters which
    # compare it verbatim (the ``owner`` / ``OWNER`` conditions) match.
    # ``user`` is optional: without the None guard, ``.upper()`` would raise
    # AttributeError whenever no user was supplied.
    user = kw.get("user")
    self.default_schema = user.upper() if user is not None else None

    super().__init__(thread_count=thread_count)

Expand All @@ -168,5 +181,5 @@ def select_table_schema(self, path: DbPath) -> str:

return (
f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale"
f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table.upper()}' AND owner = '{schema.upper()}'"
f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table}' AND owner = '{schema}'"
)
9 changes: 2 additions & 7 deletions data_diff/sqeleton/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@
Boolean,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import (
BaseDialect,
ThreadedDatabase,
import_helper,
ConnectError,
)
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS

SESSION_TIME_ZONE = None # Changed by the tests
Expand Down Expand Up @@ -53,7 +48,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"{value}::int")


class PostgresqlDialect(BaseDialect):
class PostgresqlDialect(BaseDialect, Mixin_Schema):
name = "PostgreSQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
Expand Down
4 changes: 2 additions & 2 deletions data_diff/sqeleton/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Boolean,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter
from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter, Mixin_Schema
from .base import (
MD5_HEXDIGITS,
CHECKSUM_HEXDIGITS,
Expand Down Expand Up @@ -69,7 +69,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"cast ({value} as int)")


class Dialect(BaseDialect):
class Dialect(BaseDialect, Mixin_Schema):
name = "Presto"
ROUNDS_ON_PREC_LOSS = True
TYPE_CLASSES = {
Expand Down
25 changes: 23 additions & 2 deletions data_diff/sqeleton/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
DbPath,
Boolean,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from data_diff.sqeleton.queries import table, this, SKIP
from .base import BaseDialect, ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter


Expand Down Expand Up @@ -46,7 +48,23 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"{value}::int")


class Dialect(BaseDialect):
class Mixin_Schema(AbstractMixin_Schema):
    """Table-listing queries for the Snowflake dialect (upper-case column names)."""

    def table_information(self) -> Compilable:
        """Return the table expression for ``INFORMATION_SCHEMA.TABLES``."""
        return table("INFORMATION_SCHEMA", "TABLES")

    def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
        """Build a query selecting the names of base tables in ``table_schema``.

        The ``TABLE_NAME`` column is selected under the alias ``table_name``.
        When ``like`` is given it is applied to the table name with the
        ``like`` operator; otherwise ``SKIP`` is passed in place of that
        condition.
        """
        info = self.table_information()
        name_filter = SKIP if like is None else this.TABLE_NAME.like(like)
        matching = info.where(
            this.TABLE_SCHEMA == table_schema,
            name_filter,
            this.TABLE_TYPE == "BASE TABLE",
        )
        return matching.select(table_name=this.TABLE_NAME)


class Dialect(BaseDialect, Mixin_Schema):
name = "Snowflake"
ROUNDS_ON_PREC_LOSS = False
TYPE_CLASSES = {
Expand All @@ -72,6 +90,9 @@ def quote(self, s: str):
def to_string(self, s: str):
return f"cast({s} as string)"

def table_information(self) -> Compilable:
    """Return the table expression for ``INFORMATION_SCHEMA.TABLES``."""
    # NOTE(review): this repeats Mixin_Schema.table_information, which this
    # Dialect already inherits; one of the two definitions looks redundant.
    information_tables = table("INFORMATION_SCHEMA", "TABLES")
    return information_tables


class Snowflake(Database):
dialect = Dialect()
Expand Down
21 changes: 19 additions & 2 deletions data_diff/sqeleton/databases/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
Boolean,
ColType_UUID,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from ..queries import table, this, SKIP


@import_helper("vertica")
Expand Down Expand Up @@ -60,7 +62,22 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"cast ({value} as int)")


class Dialect(BaseDialect):
class Mixin_Schema(AbstractMixin_Schema):
    """Table-listing queries for the Vertica dialect, based on ``v_catalog.tables``."""

    def table_information(self) -> Compilable:
        """Return the table expression for ``v_catalog.tables``."""
        return table("v_catalog", "tables")

    def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
        """Build a query selecting the names of tables in ``table_schema``.

        When ``like`` is given it is applied to the table name with the
        ``like`` operator; otherwise ``SKIP`` is passed in place of that
        condition. No table-type filter is applied here.
        """
        catalog = self.table_information()
        name_filter = SKIP if like is None else this.table_name.like(like)
        matching = catalog.where(this.table_schema == table_schema, name_filter)
        return matching.select(this.table_name)


class Dialect(BaseDialect, Mixin_Schema):
name = "Vertica"
ROUNDS_ON_PREC_LOSS = True

Expand Down
40 changes: 40 additions & 0 deletions tests/sqeleton/test_database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
from typing import Callable, List
import unittest

from ..common import str_to_checksum, TEST_MYSQL_CONN_STRING
from ..common import str_to_checksum, test_each_database_in_list, TestPerDatabase, get_conn, random_table_suffix
# from data_diff.sqeleton import databases as db
# from data_diff.sqeleton import connect

from data_diff.sqeleton.queries import table

from data_diff import databases as dbs
from data_diff.databases import connect


TEST_DATABASES = {
dbs.MySQL,
dbs.PostgreSQL,
dbs.Oracle,
dbs.Redshift,
dbs.Snowflake,
dbs.DuckDB,
dbs.BigQuery,
dbs.Presto,
dbs.Trino,
dbs.Vertica,
}

test_each_database: Callable = test_each_database_in_list(TEST_DATABASES)


class TestDatabase(unittest.TestCase):
def setUp(self):
self.mysql = connect(TEST_MYSQL_CONN_STRING)
Expand All @@ -25,3 +49,19 @@ def test_bad_uris(self):
self.assertRaises(ValueError, connect, "postgresql:///bla/foo")
self.assertRaises(ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1")
self.assertRaises(ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup")


@test_each_database
class TestSchema(TestPerDatabase):
    """Checks the dialect's ``list_tables`` query against a live connection."""

    def test_table_list(self):
        """A table appears in the listing only between its create and drop."""
        name = self.table_src_name
        db = self.connection
        tbl = table(db.parse_table_name(name), schema={"id": int})
        q = db.dialect.list_tables(db.default_schema, name)

        # unittest assertions are used instead of bare ``assert`` so the
        # checks still run when Python is started with -O.
        self.assertFalse(db.query(q))

        db.query(tbl.create())
        try:
            self.assertEqual(db.query(q, List[str]), [name])
        finally:
            # Always drop the table, even if the assertion above fails, so a
            # failing run does not leave the table behind.
            db.query(tbl.drop())

        self.assertFalse(db.query(q))
5 changes: 3 additions & 2 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,12 +602,13 @@ def _create_indexes(conn, table):

try:
if_not_exists = "IF NOT EXISTS" if not isinstance(conn, (db.MySQL, db.Oracle)) else ""
quote = conn.dialect.quote
conn.query(
f"CREATE INDEX {if_not_exists} xa_{table[1:-1]} ON {table} (id, col)",
f"CREATE INDEX {if_not_exists} xa_{table[1:-1]} ON {table} ({quote('id')}, {quote('col')})",
None,
)
conn.query(
f"CREATE INDEX {if_not_exists} xb_{table[1:-1]} ON {table} (id)",
f"CREATE INDEX {if_not_exists} xb_{table[1:-1]} ON {table} ({quote('id')})",
None,
)
except Exception as err:
Expand Down