Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-37744: Make felis compatible with sqlalchemy 2 #18

Merged
merged 4 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/felis/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
import click
import yaml
from pyld import jsonld
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine, create_engine, create_mock_engine, make_url
from sqlalchemy.engine.mock import MockConnection

from . import DEFAULT_CONTEXT, DEFAULT_FRAME, __version__
from .check import CheckingVisitor
Expand Down Expand Up @@ -61,11 +62,12 @@ def create_all(engine_url: str, schema_name: str, dry_run: bool, file: io.TextIO

metadata = schema.metadata

engine: Engine | MockConnection
if not dry_run:
engine = create_engine(engine_url)
else:
_insert_dump = InsertDump()
engine = create_engine(engine_url, strategy="mock", executor=_insert_dump.dump)
engine = create_mock_engine(make_url(engine_url), executor=_insert_dump.dump)
_insert_dump.dialect = engine.dialect
metadata.create_all(engine)

Expand Down
64 changes: 33 additions & 31 deletions python/felis/db/sqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import builtins
from collections.abc import Mapping, MutableMapping
from typing import Any
from collections.abc import Mapping
from typing import Any, Union

from sqlalchemy import Float, SmallInteger, types
from sqlalchemy.dialects import mysql, oracle, postgresql
Expand Down Expand Up @@ -55,121 +55,123 @@ def compile_double(type_: Any, compiler: Any, **kw: Any) -> str:
return "DOUBLE"


boolean_map = {MYSQL: mysql.BIT(1), ORACLE: oracle.NUMBER(1), POSTGRES: postgresql.BOOLEAN()}
_TypeMap = Mapping[str, Union[types.TypeEngine, type[types.TypeEngine]]]

byte_map = {
boolean_map: _TypeMap = {MYSQL: mysql.BIT(1), ORACLE: oracle.NUMBER(1), POSTGRES: postgresql.BOOLEAN()}

byte_map: _TypeMap = {
MYSQL: mysql.TINYINT(),
ORACLE: oracle.NUMBER(3),
POSTGRES: postgresql.SMALLINT(),
}

short_map = {
short_map: _TypeMap = {
MYSQL: mysql.SMALLINT(),
ORACLE: oracle.NUMBER(5),
POSTGRES: postgresql.SMALLINT(),
}

# Skip Oracle
int_map = {
int_map: _TypeMap = {
MYSQL: mysql.INTEGER(),
POSTGRES: postgresql.INTEGER(),
}

long_map = {
long_map: _TypeMap = {
MYSQL: mysql.BIGINT(),
ORACLE: oracle.NUMBER(38, 0),
POSTGRES: postgresql.BIGINT(),
}

float_map = {
float_map: _TypeMap = {
MYSQL: mysql.FLOAT(),
ORACLE: oracle.BINARY_FLOAT(),
POSTGRES: postgresql.FLOAT(),
}

double_map = {
double_map: _TypeMap = {
MYSQL: mysql.DOUBLE(),
ORACLE: oracle.BINARY_DOUBLE(),
POSTGRES: postgresql.DOUBLE_PRECISION(),
}

char_map = {
char_map: _TypeMap = {
MYSQL: mysql.CHAR,
ORACLE: oracle.CHAR,
POSTGRES: postgresql.CHAR,
}

string_map = {
string_map: _TypeMap = {
MYSQL: mysql.VARCHAR,
ORACLE: oracle.VARCHAR2,
POSTGRES: postgresql.VARCHAR,
}

unicode_map = {
unicode_map: _TypeMap = {
MYSQL: mysql.NVARCHAR,
ORACLE: oracle.NVARCHAR2,
POSTGRES: postgresql.VARCHAR,
}

text_map = {
text_map: _TypeMap = {
MYSQL: mysql.LONGTEXT,
ORACLE: oracle.CLOB,
POSTGRES: postgresql.TEXT,
}

binary_map = {
binary_map: _TypeMap = {
MYSQL: mysql.LONGBLOB,
ORACLE: oracle.BLOB,
POSTGRES: postgresql.BYTEA,
}


def boolean(**kwargs: Any) -> types.TypeEngine:
return _vary(types.BOOLEAN(), boolean_map.copy(), kwargs)
return _vary(types.BOOLEAN(), boolean_map, kwargs)


def byte(**kwargs: Any) -> types.TypeEngine:
return _vary(TINYINT(), byte_map.copy(), kwargs)
return _vary(TINYINT(), byte_map, kwargs)


def short(**kwargs: Any) -> types.TypeEngine:
return _vary(types.SMALLINT(), short_map.copy(), kwargs)
return _vary(types.SMALLINT(), short_map, kwargs)


def int(**kwargs: Any) -> types.TypeEngine:
return _vary(types.INTEGER(), int_map.copy(), kwargs)
return _vary(types.INTEGER(), int_map, kwargs)


def long(**kwargs: Any) -> types.TypeEngine:
return _vary(types.BIGINT(), long_map.copy(), kwargs)
return _vary(types.BIGINT(), long_map, kwargs)


def float(**kwargs: Any) -> types.TypeEngine:
return _vary(types.FLOAT(), float_map.copy(), kwargs)
return _vary(types.FLOAT(), float_map, kwargs)


def double(**kwargs: Any) -> types.TypeEngine:
return _vary(DOUBLE(), double_map.copy(), kwargs)
return _vary(DOUBLE(), double_map, kwargs)


def char(length: builtins.int, **kwargs: Any) -> types.TypeEngine:
return _vary(types.CHAR(length), char_map.copy(), kwargs, length)
return _vary(types.CHAR(length), char_map, kwargs, length)


def string(length: builtins.int, **kwargs: Any) -> types.TypeEngine:
return _vary(types.VARCHAR(length), string_map.copy(), kwargs, length)
return _vary(types.VARCHAR(length), string_map, kwargs, length)


def unicode(length: builtins.int, **kwargs: Any) -> types.TypeEngine:
return _vary(types.NVARCHAR(length), unicode_map.copy(), kwargs, length)
return _vary(types.NVARCHAR(length), unicode_map, kwargs, length)


def text(length: builtins.int, **kwargs: Any) -> types.TypeEngine:
return _vary(types.CLOB(length), text_map.copy(), kwargs, length)
return _vary(types.CLOB(length), text_map, kwargs, length)


def binary(length: builtins.int, **kwargs: Any) -> types.TypeEngine:
return _vary(types.BLOB(length), binary_map.copy(), kwargs, length)
return _vary(types.BLOB(length), binary_map, kwargs, length)


def timestamp(**kwargs: Any) -> types.TypeEngine:
Expand All @@ -178,13 +180,13 @@ def timestamp(**kwargs: Any) -> types.TypeEngine:

def _vary(
type_: types.TypeEngine,
variant_map: MutableMapping[str, types.TypeEngine],
overrides: Mapping[str, types.TypeEngine],
variant_map: _TypeMap,
overrides: _TypeMap,
*args: Any,
) -> types.TypeEngine:
for dialect, variant in overrides.items():
variant_map[dialect] = variant
for dialect, variant in variant_map.items():
variants: dict[str, Union[types.TypeEngine, type[types.TypeEngine]]] = dict(variant_map)
variants.update(overrides)
for dialect, variant in variants.items():
# If this is a class and not an instance, instantiate
if isinstance(variant, type):
variant = variant(*args)
Expand Down
7 changes: 3 additions & 4 deletions python/felis/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@


class Schema(NamedTuple):

name: Optional[str]
tables: list[Table]
metadata: MetaData
Expand Down Expand Up @@ -126,11 +125,11 @@ def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> Table:
if primary_key:
table.append_constraint(primary_key)

constraints = [self.visit_constraint(c, table) for c in table_obj.get("constraints", [])]
constraints = [self.visit_constraint(c, table_obj) for c in table_obj.get("constraints", [])]
for constraint in constraints:
table.append_constraint(constraint)

indexes = [self.visit_index(i, table) for i in table_obj.get("indexes", [])]
indexes = [self.visit_index(i, table_obj) for i in table_obj.get("indexes", [])]
for index in indexes:
# FIXME: Hack because there's no table.add_index
index._set_parent(table)
Expand Down Expand Up @@ -170,7 +169,7 @@ def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Column:
column_nullable = column_obj.get("nullable", nullable_default)
column_autoincrement = column_obj.get("autoincrement", "auto")

column = Column(
column: Column = Column(
column_name,
datatype,
comment=column_description,
Expand Down
28 changes: 14 additions & 14 deletions python/felis/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy.engine import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import DeclarativeMeta, Session, sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.schema import MetaData
from sqlalchemy.sql.expression import Insert, insert

Expand All @@ -40,7 +40,7 @@

_Mapping = Mapping[str, Any]

Tap11Base: DeclarativeMeta = declarative_base()
Tap11Base: Any = declarative_base() # Any to avoid mypy mess with SA 2
logger = logging.getLogger("felis")

IDENTIFIER_LENGTH = 128
Expand Down Expand Up @@ -172,17 +172,17 @@ def visit_schema(self, schema_obj: _Mapping) -> None:
session.commit()
else:
# Only if we are mocking (dry run)
conn = self.engine
conn.execute(_insert(self.tables["schemas"], schema))
for table_obj in schema_obj["tables"]:
table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
conn.execute(_insert(self.tables["tables"], table))
for column in columns:
conn.execute(_insert(self.tables["columns"], column))
for key in keys:
conn.execute(_insert(self.tables["keys"], key))
for key_column in key_columns:
conn.execute(_insert(self.tables["key_columns"], key_column))
with self.engine.begin() as conn:
conn.execute(_insert(self.tables["schemas"], schema))
for table_obj in schema_obj["tables"]:
table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
conn.execute(_insert(self.tables["tables"], table))
for column in columns:
conn.execute(_insert(self.tables["columns"], column))
for key in keys:
conn.execute(_insert(self.tables["keys"], key))
for key_column in key_columns:
conn.execute(_insert(self.tables["key_columns"], key_column))

def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
self.checker.check_table(table_obj, schema_obj)
Expand Down Expand Up @@ -370,4 +370,4 @@ def _insert(table: Tap11Base, value: Any) -> Insert:
if type(column_value) == str:
column_value = column_value.replace("'", "''")
values_dict[name] = column_value
return insert(table, values=values_dict)
return insert(table).values(values_dict)
1 change: 0 additions & 1 deletion python/felis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def visit_schema(self, schema_obj: _MutableMapping) -> _Mapping:
return _new_order(schema_obj, ["@context", "name", "@id", "@type", "description", "tables"])

def visit_table(self, table_obj: _MutableMapping, schema_obj: _Mapping) -> _Mapping:

columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
primary_key = self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
constraints = [self.visit_constraint(c, table_obj) for c in table_obj.get("constraints", [])]
Expand Down
40 changes: 32 additions & 8 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
import os
import unittest
from collections.abc import Mapping, MutableMapping
from typing import Any, Optional
from typing import Any, Optional, cast

import sqlalchemy
import yaml

from felis import DEFAULT_FRAME
from felis.db import sqltypes
from felis.sql import SQLVisitor

TESTDIR = os.path.abspath(os.path.dirname(__file__))
Expand All @@ -53,7 +54,7 @@ def _get_unique_constraint(table: sqlalchemy.schema.Table) -> Optional[sqlalchem

def _get_indices(table: sqlalchemy.schema.Table) -> Mapping[str, sqlalchemy.schema.Index]:
"""Return mapping of table indices indexed by index name."""
return {index.name: index for index in table.indexes}
return {cast(str, index.name): index for index in table.indexes}


class VisitorTestCase(unittest.TestCase):
Expand Down Expand Up @@ -102,8 +103,11 @@ def test_make_metadata(self) -> None:
self.assertCountEqual(table.columns.keys(), ["sdqa_imageStatusId", "statusName", "definition"])
self.assertTrue(table.columns["sdqa_imageStatusId"].primary_key)
self.assertFalse(table.indexes)
for column in table.columns.values():
self.assertIsInstance(column.type, sqlalchemy.types.Variant)
for column, ctype in zip(
table.columns.values(),
(sqlalchemy.types.SMALLINT, sqlalchemy.types.VARCHAR, sqlalchemy.types.VARCHAR),
):
self.assertIsInstance(column.type, (ctype, sqlalchemy.types.Variant))

# Details of sdqa_Metric table.
table = tables["sdqa.sdqa_Metric"]
Expand All @@ -112,8 +116,17 @@ def test_make_metadata(self) -> None:
)
self.assertTrue(table.columns["sdqa_metricId"].primary_key)
self.assertFalse(table.indexes)
for column in table.columns.values():
self.assertIsInstance(column.type, sqlalchemy.types.Variant)
for column, ctype in zip(
table.columns.values(),
(
sqlalchemy.types.SMALLINT,
sqlalchemy.types.VARCHAR,
sqlalchemy.types.VARCHAR,
sqlalchemy.types.CHAR,
sqlalchemy.types.VARCHAR,
),
):
self.assertIsInstance(column.type, (ctype, sqlalchemy.types.Variant))
# It defines a unique constraint.
unique = _get_unique_constraint(table)
assert unique is not None, "Constraint must be defined"
Expand All @@ -134,10 +147,21 @@ def test_make_metadata(self) -> None:
],
)
self.assertTrue(table.columns["sdqa_ratingId"].primary_key)
for column in table.columns.values():
self.assertIsInstance(column.type, sqlalchemy.types.Variant)
for column, ctype in zip(
table.columns.values(),
(
sqlalchemy.types.BIGINT,
sqlalchemy.types.SMALLINT,
sqlalchemy.types.SMALLINT,
sqlalchemy.types.BIGINT,
sqltypes.DOUBLE,
sqltypes.DOUBLE,
),
):
self.assertIsInstance(column.type, (ctype, sqlalchemy.types.Variant))
unique = _get_unique_constraint(table)
self.assertIsNotNone(unique)
assert unique is not None, "Constraint must be defined"
self.assertEqual(unique.name, "UQ_sdqaRatingForAmpVisit_metricId_ampVisitId")
self.assertCountEqual(unique.columns, [table.columns["sdqa_metricId"], table.columns["ampVisitId"]])
# It has a bunch of indices.
Expand Down