Skip to content

Commit

Permalink
Refactor tap module to use the Pydantic data model
Browse files Browse the repository at this point in the history
The tap module is updated to use the Pydantic data model, instead of
the raw YAML. The "load-tap" command in the command line interface
was also updated to use the Pydantic schema object. The PYLD
transformations were removed from the cli function, as they seemed to
be unnecessary. A few changes were made to the Pydantic data model to
support setting of reasonable defaults. A "votable_datatype" annotation
was added to Colum. The Column datatype was changed to return the enum
instead of a string, as this setting seemed to confuse mypy. Tests were
changed to conform to the new tap module and the minor changes to the
data model. Testing showed that the schema was loaded correctly into
a live PostgreSQL database.
  • Loading branch information
JeremyMcCormick committed Apr 19, 2024
1 parent 00bf779 commit 95c017f
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 133 deletions.
54 changes: 16 additions & 38 deletions python/felis/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,28 +203,8 @@ def load_tap(
This command loads the associated TAP metadata from a Felis FILE
to the TAP_SCHEMA tables.
"""
top_level_object = yaml.load(file, Loader=yaml.SafeLoader)
schema_obj: dict
if isinstance(top_level_object, dict):
schema_obj = top_level_object
if "@graph" not in schema_obj:
schema_obj["@type"] = "felis:Schema"
schema_obj["@context"] = DEFAULT_CONTEXT
elif isinstance(top_level_object, list):
schema_obj = {"@context": DEFAULT_CONTEXT, "@graph": top_level_object}
else:
logger.error("Schema object not of recognizable type")
raise click.exceptions.Exit(1)

normalized = _normalize(schema_obj, embed="@always")
if len(normalized["@graph"]) > 1 and (schema_name or catalog_name):
logger.error("--schema-name and --catalog-name incompatible with multiple schemas")
raise click.exceptions.Exit(1)

# Force normalized["@graph"] to a list, which is what happens when there's
# multiple schemas
if isinstance(normalized["@graph"], dict):
normalized["@graph"] = [normalized["@graph"]]
yaml_data = yaml.load(file, Loader=yaml.SafeLoader)
schema = Schema.model_validate(yaml_data)

tap_tables = init_tables(
tap_schema_name,
Expand All @@ -243,28 +223,26 @@ def load_tap(
# In Memory SQLite - Mostly used to test
Tap11Base.metadata.create_all(engine)

for schema in normalized["@graph"]:
tap_visitor = TapLoadingVisitor(
engine,
catalog_name=catalog_name,
schema_name=schema_name,
tap_tables=tap_tables,
)
tap_visitor.visit_schema(schema)
tap_visitor = TapLoadingVisitor(
engine,
catalog_name=catalog_name,
schema_name=schema_name,
tap_tables=tap_tables,
)
tap_visitor.visit_schema(schema)
else:
_insert_dump = InsertDump()
conn = create_mock_engine(make_url(engine_url), executor=_insert_dump.dump, paramstyle="pyformat")
# After the engine is created, update the executor with the dialect
_insert_dump.dialect = conn.dialect

for schema in normalized["@graph"]:
tap_visitor = TapLoadingVisitor.from_mock_connection(
conn,
catalog_name=catalog_name,
schema_name=schema_name,
tap_tables=tap_tables,
)
tap_visitor.visit_schema(schema)
tap_visitor = TapLoadingVisitor.from_mock_connection(
conn,
catalog_name=catalog_name,
schema_name=schema_name,
tap_tables=tap_tables,
)
tap_visitor.visit_schema(schema)


@cli.command("modify-tap")
Expand Down
13 changes: 8 additions & 5 deletions python/felis/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ class BaseObject(BaseModel):
description: DescriptionStr | None = None
"""A description of the database object."""

@model_validator(mode="after") # type: ignore[arg-type]
votable_utype: str | None = Field(None, alias="votable:utype")
"""The VOTable utype (usage-specific or unique type) of the object."""

@model_validator(mode="after")
@classmethod
def check_description(cls, object: BaseObject, info: ValidationInfo) -> BaseObject:
"""Check that the description is present if required."""
Expand Down Expand Up @@ -222,12 +225,12 @@ class Column(BaseObject):
"""TAP_SCHEMA indication that this column is defined by an IVOA standard.
"""

votable_utype: str | None = Field(None, alias="votable:utype")
"""The VOTable utype (usage-specific or unique type) of the column."""

votable_xtype: str | None = Field(None, alias="votable:xtype")
"""The VOTable xtype (extended type) of the column."""

votable_datatype: str | None = Field(None, alias="votable:datatype")
"""The VOTable datatype of the column."""

@field_validator("ivoa_ucd")
@classmethod
def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
Expand Down Expand Up @@ -387,7 +390,7 @@ class Table(BaseObject):
primary_key: str | list[str] | None = Field(None, alias="primaryKey")
"""The primary key of the table."""

tap_table_index: int | None = Field(None, alias="tap:table_index")
tap_table_index: int = Field(0, alias="tap:table_index")
"""The IVOA TAP_SCHEMA table index of the table."""

mysql_engine: str | None = Field(None, alias="mysql:engine")
Expand Down
142 changes: 62 additions & 80 deletions python/felis/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"]

import logging
from collections.abc import Iterable, Mapping, MutableMapping
from collections.abc import Iterable, MutableMapping
from typing import Any

from sqlalchemy import Column, Integer, String
Expand All @@ -34,14 +34,13 @@
from sqlalchemy.schema import MetaData
from sqlalchemy.sql.expression import Insert, insert

from .check import FelisValidator
from .types import FelisType
from .visitor import Visitor
from felis import datamodel

_Mapping = Mapping[str, Any]
from .datamodel import Constraint, Index, Schema, Table
from .types import FelisType

Tap11Base: Any = declarative_base() # Any to avoid mypy mess with SA 2
logger = logging.getLogger("felis")
logger = logging.getLogger(__name__)

IDENTIFIER_LENGTH = 128
SMALL_FIELD_LENGTH = 32
Expand Down Expand Up @@ -133,7 +132,7 @@ class Tap11KeyColumns(Tap11Base):
)


class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None, None]):
class TapLoadingVisitor:
"""Felis schema visitor for generating TAP schema.
Parameters
Expand Down Expand Up @@ -161,7 +160,6 @@ def __init__(
self.engine = engine
self._mock_connection: MockConnection | None = None
self.tables = tap_tables or init_tables()
self.checker = FelisValidator()

@classmethod
def from_mock_connection(
Expand All @@ -175,25 +173,22 @@ def from_mock_connection(
visitor._mock_connection = mock_connection
return visitor

def visit_schema(self, schema_obj: _Mapping) -> None:
self.checker.check_schema(schema_obj)
if (version_obj := schema_obj.get("version")) is not None:
self.visit_schema_version(version_obj, schema_obj)
def visit_schema(self, schema_obj: Schema) -> None:
schema = self.tables["schemas"]()
# Override with default
self.schema_name = self.schema_name or schema_obj["name"]
self.schema_name = self.schema_name or schema_obj.name

schema.schema_name = self._schema_name()
schema.description = schema_obj.get("description")
schema.utype = schema_obj.get("votable:utype")
schema.schema_index = int(schema_obj.get("tap:schema_index", 0))
schema.description = schema_obj.description
schema.utype = schema_obj.votable_utype
schema.schema_index = schema_obj.tap_schema_index

if self.engine is not None:
session: Session = sessionmaker(self.engine)()

session.add(schema)

for table_obj in schema_obj["tables"]:
for table_obj in schema_obj.tables:
table, columns = self.visit_table(table_obj, schema_obj)
session.add(table)
session.add_all(columns)
Expand All @@ -202,6 +197,8 @@ def visit_schema(self, schema_obj: _Mapping) -> None:
session.add_all(keys)
session.add_all(key_columns)

logger.debug("Committing TAP schema: %s", schema_obj.name)
logger.debug("TAP tables: %s", len(self.tables))
session.commit()
else:
logger.info("Dry run, not inserting into database")
Expand All @@ -211,7 +208,7 @@ def visit_schema(self, schema_obj: _Mapping) -> None:
conn = self._mock_connection
conn.execute(_insert(self.tables["schemas"], schema))

for table_obj in schema_obj["tables"]:
for table_obj in schema_obj.tables:
table, columns = self.visit_table(table_obj, schema_obj)
conn.execute(_insert(self.tables["tables"], table))
for column in columns:
Expand All @@ -223,56 +220,45 @@ def visit_schema(self, schema_obj: _Mapping) -> None:
for key_column in key_columns:
conn.execute(_insert(self.tables["key_columns"], key_column))

def visit_constraints(self, schema_obj: _Mapping) -> tuple:
def visit_constraints(self, schema_obj: Schema) -> tuple:
all_keys = []
all_key_columns = []
for table_obj in schema_obj["tables"]:
for c in table_obj.get("constraints", []):
key, key_columns = self.visit_constraint(c, table_obj)
for table_obj in schema_obj.tables:
for c in table_obj.constraints:
key, key_columns = self.visit_constraint(c)
if not key:
continue
all_keys.append(key)
all_key_columns += key_columns
return all_keys, all_key_columns

def visit_schema_version(
self, version_obj: str | Mapping[str, Any], schema_obj: Mapping[str, Any]
) -> None:
# Docstring is inherited.

# For now we ignore schema versioning completely, still do some checks.
self.checker.check_schema_version(version_obj, schema_obj)

def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
self.checker.check_table(table_obj, schema_obj)
table_id = table_obj["@id"]
def visit_table(self, table_obj: Table, schema_obj: Schema) -> tuple:
table_id = table_obj.id
table = self.tables["tables"]()
table.schema_name = self._schema_name()
table.table_name = self._table_name(table_obj["name"])
table.table_name = self._table_name(table_obj.name)
table.table_type = "table"
table.utype = table_obj.get("votable:utype")
table.description = table_obj.get("description")
table.table_index = int(table_obj.get("tap:table_index", 0))
table.utype = table_obj.votable_utype
table.description = table_obj.description
table.table_index = table_obj.tap_table_index

columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
columns = [self.visit_column(c, table_obj) for c in table_obj.columns]
self.visit_primary_key(table_obj.primary_key, table_obj)

for i in table_obj.get("indexes", []):
for i in table_obj.indexes:
self.visit_index(i, table)

self.graph_index[table_id] = table
return table, columns

def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
self.checker.check_column(column_obj, table_obj)
_id = column_obj["@id"]
# Guaranteed to exist at this point, for mypy use "" as default
datatype_name = column_obj.get("datatype", "")
felis_type = FelisType.felis_type(datatype_name)
def check_column(self, column_obj: datamodel.Column) -> None:
_id = column_obj.id
datatype_name = column_obj.datatype
felis_type = FelisType.felis_type(datatype_name.value)
if felis_type.is_sized:
# It is expected that both arraysize and length are fine for
# length types.
arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
arraysize = column_obj.votable_arraysize or column_obj.length
if arraysize is None:
logger.warning(
f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
Expand All @@ -283,55 +269,53 @@ def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
# datetime types really should have a votable:arraysize, because
# they are converted to strings and the `length` is loosely to the
# string size
if "votable:arraysize" not in column_obj:
if not column_obj.votable_arraysize:
logger.warning(
f"votable:arraysize for {_id} is None for type {datatype_name}. "
f'Using length "*". '
"Consider setting `votable:arraysize` to an appropriate size for "
"materialized datetime/timestamp strings."
)

def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
self.check_column(column_obj, table_obj)
column_id = column_obj["@id"]
table_name = self._table_name(table_obj["name"])
def visit_column(self, column_obj: datamodel.Column, table_obj: Table) -> Tap11Base:
self.check_column(column_obj)
column_id = column_obj.id
table_name = self._table_name(table_obj.name)

column = self.tables["columns"]()
column.table_name = table_name
column.column_name = column_obj["name"]
column.column_name = column_obj.name

felis_datatype = column_obj["datatype"]
felis_type = FelisType.felis_type(felis_datatype)
column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)
felis_datatype = column_obj.datatype
felis_type = FelisType.felis_type(felis_datatype.value)
column.datatype = column_obj.votable_datatype or felis_type.votable_name

arraysize = None
if felis_type.is_sized:
# prefer votable:arraysize to length, fall back to `*`
arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
arraysize = column_obj.votable_arraysize or column_obj.length or "*"
if felis_type.is_timestamp:
arraysize = column_obj.get("votable:arraysize", "*")
arraysize = column_obj.votable_arraysize or "*"
column.arraysize = arraysize

column.xtype = column_obj.get("votable:xtype")
column.description = column_obj.get("description")
column.utype = column_obj.get("votable:utype")
column.xtype = column_obj.votable_xtype
column.description = column_obj.description
column.utype = column_obj.votable_utype

unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
unit = column_obj.ivoa_unit or column_obj.fits_tunit
column.unit = unit
column.ucd = column_obj.get("ivoa:ucd")
column.ucd = column_obj.ivoa_ucd

# We modify this after we process columns
column.indexed = 0

column.principal = column_obj.get("tap:principal", 0)
column.std = column_obj.get("tap:std", 0)
column.column_index = column_obj.get("tap:column_index")
column.principal = column_obj.tap_principal
column.std = column_obj.tap_std
column.column_index = column_obj.tap_column_index

self.graph_index[column_id] = column
return column

def visit_primary_key(self, primary_key_obj: str | Iterable[str], table_obj: _Mapping) -> None:
self.checker.check_primary_key(primary_key_obj, table_obj)
def visit_primary_key(self, primary_key_obj: str | Iterable[str] | None, table_obj: Table) -> None:
if primary_key_obj:
if isinstance(primary_key_obj, str):
primary_key_obj = [primary_key_obj]
Expand All @@ -341,19 +325,18 @@ def visit_primary_key(self, primary_key_obj: str | Iterable[str], table_obj: _Ma
columns[0].indexed = 1
return None

def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
self.checker.check_constraint(constraint_obj, table_obj)
constraint_type = constraint_obj["@type"]
def visit_constraint(self, constraint_obj: Constraint) -> tuple:
constraint_type = constraint_obj.type
key = None
key_columns = []
if constraint_type == "ForeignKey":
constraint_name = constraint_obj["name"]
description = constraint_obj.get("description")
utype = constraint_obj.get("votable:utype")
constraint_name = constraint_obj.name
description = constraint_obj.description
utype = constraint_obj.votable_utype

columns = [self.graph_index[col["@id"]] for col in constraint_obj.get("columns", [])]
columns = [self.graph_index[col_id] for col_id in getattr(constraint_obj, "columns", [])]
refcolumns = [
self.graph_index[refcol["@id"]] for refcol in constraint_obj.get("referencedColumns", [])
self.graph_index[refcol_id] for refcol_id in getattr(constraint_obj, "referenced_columns", [])
]

table_name = None
Expand Down Expand Up @@ -386,9 +369,8 @@ def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tup
key_columns.append(key_column)
return key, key_columns

def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
self.checker.check_index(index_obj, table_obj)
columns = [self.graph_index[col["@id"]] for col in index_obj.get("columns", [])]
def visit_index(self, index_obj: Index, table_obj: Table) -> None:
columns = [self.graph_index[col_id] for col_id in getattr(index_obj, "columns", [])]
# if just one column and it's indexed, update the object
if len(columns) == 1:
columns[0].indexed = 1
Expand Down

0 comments on commit 95c017f

Please sign in to comment.