Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-38304: Fix crash in load-tap --dry-run command #19

Merged
merged 2 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 28 additions & 15 deletions python/felis/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,6 @@ def load_tap(
if isinstance(normalized["@graph"], dict):
normalized["@graph"] = [normalized["@graph"]]

if not dry_run:
engine = create_engine(engine_url)
else:
_insert_dump = InsertDump()
engine = create_engine(engine_url, strategy="mock", executor=_insert_dump.dump, paramstyle="pyformat")
# After the engine is created, update the executor with the dialect
_insert_dump.dialect = engine.dialect
tap_tables = init_tables(
tap_schema_name,
tap_tables_postfix,
Expand All @@ -174,15 +167,35 @@ def load_tap(
tap_key_columns_table,
)

if engine_url == "sqlite://" and not dry_run:
# In Memory SQLite - Mostly used to test
Tap11Base.metadata.create_all(engine)
if not dry_run:
engine = create_engine(engine_url)

for schema in normalized["@graph"]:
tap_visitor = TapLoadingVisitor(
engine, catalog_name=catalog_name, schema_name=schema_name, mock=dry_run, tap_tables=tap_tables
)
tap_visitor.visit_schema(schema)
if engine_url == "sqlite://" and not dry_run:
# In Memory SQLite - Mostly used to test
Tap11Base.metadata.create_all(engine)

for schema in normalized["@graph"]:
tap_visitor = TapLoadingVisitor(
andy-slac marked this conversation as resolved.
Show resolved Hide resolved
engine,
catalog_name=catalog_name,
schema_name=schema_name,
tap_tables=tap_tables,
)
tap_visitor.visit_schema(schema)
else:
_insert_dump = InsertDump()
conn = create_mock_engine(make_url(engine_url), executor=_insert_dump.dump, paramstyle="pyformat")
# After the engine is created, update the executor with the dialect
_insert_dump.dialect = conn.dialect

for schema in normalized["@graph"]:
tap_visitor = TapLoadingVisitor.from_mock_connection(
conn,
catalog_name=catalog_name,
schema_name=schema_name,
tap_tables=tap_tables,
)
tap_visitor.visit_schema(schema)


@cli.command("modify-tap")
Expand Down
43 changes: 28 additions & 15 deletions python/felis/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from sqlalchemy import Column, Integer, String
from sqlalchemy.engine import Engine
from sqlalchemy.engine.mock import MockConnection
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.schema import MetaData
Expand Down Expand Up @@ -135,20 +136,31 @@ class Tap11KeyColumns(Tap11Base):
class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None]):
def __init__(
self,
engine: Engine,
engine: Engine | None,
catalog_name: Optional[str] = None,
schema_name: Optional[str] = None,
mock: bool = False,
tap_tables: Optional[MutableMapping[str, Any]] = None,
):
self.graph_index: MutableMapping[str, Any] = {}
self.catalog_name = catalog_name
self.schema_name = schema_name
self.engine = engine
self.mock = mock
self._mock_connection: MockConnection | None = None
self.tables = tap_tables or init_tables()
self.checker = FelisValidator()

@classmethod
def from_mock_connection(
cls,
mock_connection: MockConnection,
catalog_name: Optional[str] = None,
schema_name: Optional[str] = None,
tap_tables: Optional[MutableMapping[str, Any]] = None,
) -> TapLoadingVisitor:
visitor = cls(engine=None, catalog_name=catalog_name, schema_name=schema_name, tap_tables=tap_tables)
visitor._mock_connection = mock_connection
return visitor

def visit_schema(self, schema_obj: _Mapping) -> None:
self.checker.check_schema(schema_obj)
schema = self.tables["schemas"]()
Expand All @@ -160,7 +172,7 @@ def visit_schema(self, schema_obj: _Mapping) -> None:
schema.utype = schema_obj.get("votable:utype")
schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

if not self.mock:
if self.engine is not None:
session: Session = sessionmaker(self.engine)()
session.add(schema)
for table_obj in schema_obj["tables"]:
Expand All @@ -172,17 +184,18 @@ def visit_schema(self, schema_obj: _Mapping) -> None:
session.commit()
else:
# Only if we are mocking (dry run)
with self.engine.begin() as conn:
conn.execute(_insert(self.tables["schemas"], schema))
for table_obj in schema_obj["tables"]:
table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
conn.execute(_insert(self.tables["tables"], table))
for column in columns:
conn.execute(_insert(self.tables["columns"], column))
for key in keys:
conn.execute(_insert(self.tables["keys"], key))
for key_column in key_columns:
conn.execute(_insert(self.tables["key_columns"], key_column))
assert self._mock_connection is not None, "Mock connection must not be None"
conn = self._mock_connection
conn.execute(_insert(self.tables["schemas"], schema))
for table_obj in schema_obj["tables"]:
table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
conn.execute(_insert(self.tables["tables"], table))
for column in columns:
conn.execute(_insert(self.tables["columns"], column))
for key in keys:
conn.execute(_insert(self.tables["keys"], key))
for key_column in key_columns:
conn.execute(_insert(self.tables["key_columns"], key_column))

def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
self.checker.check_table(table_obj, schema_obj)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ def test_load_tap(self) -> None:
result = runner.invoke(cli, ["load-tap", f"--engine-url={url}", TEST_YAML], catch_exceptions=False)
self.assertEqual(result.exit_code, 0)

def test_load_tap_mock(self) -> None:
"""Test for load-tap --dry-run command"""

url = "postgresql+psycopg2://"

runner = CliRunner()
result = runner.invoke(
cli, ["load-tap", f"--engine-url={url}", "--dry-run", TEST_YAML], catch_exceptions=False
)
self.assertEqual(result.exit_code, 0)

def test_modify_tap(self) -> None:
"""Test for modify-tap command"""

Expand Down