Skip to content

Commit

Permalink
Merge pull request #54 from lsst/tickets/DM-41247-revert
Browse files Browse the repository at this point in the history
DM-41247: Revert "Merge pull request #50 from lsst/tickets/DM-41247"
  • Loading branch information
timj committed Apr 12, 2024
2 parents a32f5f9 + 98f00bc commit a98f79b
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 284 deletions.
23 changes: 4 additions & 19 deletions python/felis/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,30 +373,15 @@ def merge(files: Iterable[io.TextIOBase]) -> None:
type=click.Choice(["RSP", "default"]),
default="default",
)
@click.option(
"-d", "--require-description", is_flag=True, help="Require description for all objects", default=False
)
@click.option(
"-t", "--check-redundant-datatypes", is_flag=True, help="Check for redundant datatypes", default=False
)
@click.option("-d", "--require-description", is_flag=True, help="Require description for all objects")
@click.argument("files", nargs=-1, type=click.File())
def validate(
schema_name: str,
require_description: bool,
check_redundant_datatypes: bool,
files: Iterable[io.TextIOBase],
) -> None:
def validate(schema_name: str, require_description: bool, files: Iterable[io.TextIOBase]) -> None:
"""Validate one or more felis YAML files."""
schema_class = get_schema(schema_name)
if schema_name != "default":
logger.info(f"Using schema '{schema_class.__name__}'")
logger.info(f"Using schema '{schema_class.__name__}'")

schema_class.Config.require_description = require_description
if require_description:
logger.info("Requiring descriptions for all objects")
schema_class.Config.check_redundant_datatypes = check_redundant_datatypes
if check_redundant_datatypes:
logger.info("Checking for redundant datatypes")
Schema.require_description(True)

rc = 0
for file in files:
Expand Down
152 changes: 25 additions & 127 deletions python/felis/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,13 @@
from __future__ import annotations

import logging
import re
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Annotated, Any, Literal, TypeAlias

from astropy import units as units # type: ignore
from astropy.io.votable import ucd # type: ignore
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from sqlalchemy import dialects
from sqlalchemy import types as sqa_types
from sqlalchemy.engine import create_mock_engine
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import TypeEngine

from .db.sqltypes import get_type_func
from .types import FelisType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,7 +93,7 @@ class BaseObject(BaseModel):
@classmethod
def check_description(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Check that the description is present if required."""
if Schema.Config.require_description:
if Schema.is_description_required():
if "description" not in values or not values["description"]:
raise ValueError("Description is required and must be non-empty")
if len(values["description"].strip()) < DESCR_MIN_LENGTH:
Expand All @@ -128,51 +119,6 @@ class DataType(Enum):
TIMESTAMP = "timestamp"


_DIALECTS = {
"mysql": create_mock_engine("mysql://", executor=None).dialect,
"postgresql": create_mock_engine("postgresql://", executor=None).dialect,
}
"""Dictionary of dialect names to SQLAlchemy dialects."""

_DIALECT_MODULES = {"mysql": getattr(dialects, "mysql"), "postgresql": getattr(dialects, "postgresql")}
"""Dictionary of dialect names to SQLAlchemy dialect modules."""

_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression to match data types in the form "type(length)"""


def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Convert a type string such as ``VARCHAR(128)`` into a SQLAlchemy type.

    Parameters
    ----------
    type_string
        Type name, optionally with parenthesized parameters.
    dialect
        When given, the type class is looked up in that dialect's module
        instead of the generic ``sqlalchemy.types``.
    length
        Fallback length applied to sized types when the string itself did
        not set one.

    Returns
    -------
    `TypeEngine`
        The instantiated SQLAlchemy type object.

    Raises
    ------
    ValueError
        If the string cannot be parsed, the dialect is unsupported, or the
        type name is unknown.
    """
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params = match.groups()
    if dialect is None:
        type_class = getattr(sqa_types, type_name.upper(), None)
    else:
        try:
            dialect_module = _DIALECT_MODULES[dialect.name]
        except KeyError as e:
            # Chain the lookup failure so the original error stays visible.
            raise ValueError(f"Unsupported dialect: {dialect}") from e
        type_class = getattr(dialect_module, type_name.upper(), None)

    if not type_class:
        # Bug fix: report the unknown type *name*; the original message
        # interpolated type_class, which is always None on this path.
        raise ValueError(f"Unsupported type: {type_name}")

    if params:
        # "DECIMAL(10,2)" -> [10, 2]; non-numeric params pass through as str.
        args = [int(param) if param.isdigit() else param for param in params.split(",")]
        type_obj = type_class(*args)
    else:
        type_obj = type_class()

    # Apply the fallback length only for sized types that did not already
    # receive one from the parsed parameters.
    if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
        type_obj.length = length

    return type_obj


class Column(BaseObject):
"""A column in a table."""

Expand Down Expand Up @@ -261,56 +207,6 @@ def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:

return values

@model_validator(mode="before")
@classmethod
def check_redundant_datatypes(cls, values: dict[str, Any]) -> dict[str, Any]:
    """Check for redundant datatypes on columns.

    A dialect-specific override such as ``mysql:datatype`` is considered
    redundant when it compiles to exactly the same SQL as the generic
    ``datatype`` would on that dialect; such an override raises
    ``ValueError``.
    """
    # Entirely skipped unless the schema-wide flag was turned on.
    if not Schema.Config.check_redundant_datatypes:
        return values
    # Nothing to compare when no dialect-specific override is present.
    if all(f"{dialect}:datatype" not in values for dialect in _DIALECTS.keys()):
        return values

    # An override only makes sense relative to a generic datatype.
    datatype: str | None = values.get("datatype") or None
    if datatype is None:
        raise ValueError(f"Datatype must be provided for column '{values['@id']}'")
    length: int | None = values.get("length") or None

    # Build the generic SQLAlchemy type object for the declared datatype.
    datatype_func = get_type_func(datatype)
    felis_type = FelisType.felis_type(datatype)
    if felis_type.is_sized:
        if length is not None:
            datatype_obj = datatype_func(length)
        else:
            raise ValueError(
                f"Length must be provided for sized type '{datatype}' in column '{values['@id']}'"
            )
    else:
        datatype_obj = datatype_func()

    # Compare the compiled SQL of each override against the generic type on
    # that dialect; identical output means the override is redundant.
    for dialect_name, dialect in _DIALECTS.items():
        if f"{dialect_name}:datatype" in values:
            datatype_string = values[f"{dialect_name}:datatype"]
            db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
            if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
                raise ValueError(
                    "'{}:datatype: {}' is the same as 'datatype: {}' in column '{}'".format(
                        dialect_name, datatype_string, values["datatype"], values["@id"]
                    )
                )
            else:
                logger.debug(
                    "Valid type override of 'datatype: {}' with '{}:datatype: {}' in column '{}'".format(
                        values["datatype"], dialect_name, datatype_string, values["@id"]
                    )
                )
                logger.debug(
                    "Compiled datatype '{}' with {} compiled override '{}'".format(
                        datatype_obj.compile(dialect), dialect_name, db_datatype_obj.compile(dialect)
                    )
                )

    return values


class Constraint(BaseObject):
"""A database table constraint."""
Expand Down Expand Up @@ -508,22 +404,15 @@ def visit_constraint(self, constraint: Constraint) -> None:
class Schema(BaseObject):
"""The database schema containing the tables."""

class Config:
class ValidationConfig:
"""Validation configuration which is specific to Felis."""

require_description = False
_require_description = False
"""Flag to require a description for all objects.
This is set by the `require_description` class method.
"""

check_redundant_datatypes = False
"""Flag to enable checking for redundant datatypes on columns.
An example would be providing both ``mysql:datatype: DOUBLE`` and
``datatype: double`` as MySQL would have used that type by default.
"""

version: SchemaVersion | str | None = None
"""The version of the schema."""

Expand All @@ -541,29 +430,21 @@ def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
raise ValueError("Table names must be unique")
return tables

def _create_id_map(self: Schema) -> Schema:
"""Create a map of IDs to objects.
This method should not be called by users. It is called automatically
by the `model_post_init` method. If the ID map is already populated,
this method will return immediately.
"""
@model_validator(mode="after")
def create_id_map(self: Schema) -> Schema:
"""Create a map of IDs to objects."""
if len(self.id_map):
logger.debug("Ignoring call to create_id_map() - ID map was already populated")
logger.debug("ID map was already populated")
return self
visitor: SchemaIdVisitor = SchemaIdVisitor()
visitor.visit_schema(self)
logger.debug(f"Created schema ID map with {len(self.id_map.keys())} objects")
logger.debug(f"ID map contains {len(self.id_map.keys())} objects")
if len(visitor.duplicates):
raise ValueError(
"Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
)
return self

def model_post_init(self, ctx: Any) -> None:
"""Post-initialization hook for the model."""
self._create_id_map()

def __getitem__(self, id: str) -> BaseObject:
"""Get an object by its ID."""
if id not in self:
Expand All @@ -573,3 +454,20 @@ def __getitem__(self, id: str) -> BaseObject:
def __contains__(self, id: str) -> bool:
    """Return whether an object with the given ID exists in the schema."""
    known_ids = self.id_map
    return id in known_ids

@classmethod
def require_description(cls, rd: bool = True) -> None:
    """Enable or disable the description requirement for all objects.

    The requirement covers the schema itself as well as its tables,
    columns, and constraints. Callers validating schemas should use this
    method instead of mutating the flag directly.
    """
    logger.debug(f"Setting description requirement to '{rd}'")
    config = cls.ValidationConfig
    config._require_description = rd

@classmethod
def is_description_required(cls) -> bool:
    """Report whether a description is currently required for all objects."""
    flag = cls.ValidationConfig._require_description
    return flag
25 changes: 15 additions & 10 deletions python/felis/db/sqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

import builtins
from collections.abc import Mapping
from typing import Any, Callable
from typing import Any

from sqlalchemy import SmallInteger, types
from sqlalchemy import Float, SmallInteger, types
from sqlalchemy.dialects import mysql, oracle, postgresql
from sqlalchemy.ext.compiler import compiles

Expand All @@ -39,12 +39,24 @@ class TINYINT(SmallInteger):
__visit_name__ = "TINYINT"


class DOUBLE(Float):
    """The non-standard DOUBLE type."""

    # Visit name used by SQLAlchemy's compiler dispatch machinery.
    __visit_name__ = "DOUBLE"


@compiles(TINYINT)
def compile_tinyint(type_: Any, compiler: Any, **kw: Any) -> str:
    """Render the SQL name for the custom TINYINT type."""
    sql_name = "TINYINT"
    return sql_name


@compiles(DOUBLE)
def compile_double(type_: Any, compiler: Any, **kw: Any) -> str:
    """Render the SQL name for the double precision type."""
    sql_name = "DOUBLE"
    return sql_name


_TypeMap = Mapping[str, types.TypeEngine | type[types.TypeEngine]]

boolean_map: _TypeMap = {MYSQL: mysql.BIT(1), ORACLE: oracle.NUMBER(1), POSTGRES: postgresql.BOOLEAN()}
Expand Down Expand Up @@ -148,7 +160,7 @@ def float(**kwargs: Any) -> types.TypeEngine:

def double(**kwargs: Any) -> types.TypeEngine:
    """Return the SQLAlchemy type used for a double precision float."""
    base_type = DOUBLE()
    return _vary(base_type, double_map, kwargs)


def char(length: builtins.int, **kwargs: Any) -> types.TypeEngine:
Expand Down Expand Up @@ -181,13 +193,6 @@ def timestamp(**kwargs: Any) -> types.TypeEngine:
return types.TIMESTAMP()


def get_type_func(type_name: str) -> Callable:
    """Return the function for the type with the given name.

    Raises ``ValueError`` when no module-level name matches.
    """
    names = globals()
    if type_name not in names:
        raise ValueError(f"Unknown type: {type_name}")
    return names[type_name]


def _vary(
type_: types.TypeEngine,
variant_map: _TypeMap,
Expand Down
Empty file removed tests/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_validate_default_with_require_description(self) -> None:
raise e
finally:
# Turn the flag off so it does not affect subsequent tests.
Schema.Config.require_description = False
Schema.require_description(False)

self.assertEqual(result.exit_code, 0)

Expand Down
15 changes: 8 additions & 7 deletions tests/test_datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,6 @@ def test_validation(self) -> None:
class ColumnTestCase(unittest.TestCase):
"""Test the `Column` class."""

def setUp(self) -> None:
    """Set up the test by turning off the description requirement in case
    it was turned on by another test.
    """
    # The flag is a class attribute shared process-wide, so reset it here.
    Schema.Config.require_description = False

def test_validation(self) -> None:
"""Test validation of the `Column` class."""
# Default initialization should throw an exception.
Expand Down Expand Up @@ -130,7 +124,11 @@ def test_validation(self) -> None:
def test_require_description(self) -> None:
"""Test the require_description flag for the `Column` class."""
# Turn on description requirement for this test.
Schema.Config.require_description = True
Schema.require_description(True)

# Make sure that setting the flag for description requirement works
# correctly.
self.assertTrue(Schema.is_description_required(), "description should be required")

# Creating a column without a description when required should throw an
# error.
Expand All @@ -157,6 +155,9 @@ def test_require_description(self) -> None:
with self.assertRaises(ValidationError):
Column(**{"name": "testColumn", "@id": "#test_col_id", "datatype": "string", "description": "xy"})

# Turn off flag or it will affect subsequent tests.
Schema.require_description(False)


class ConstraintTestCase(unittest.TestCase):
"""Test the `UniqueConstraint`, `Index`, `CheckConstraint`, and
Expand Down

0 comments on commit a98f79b

Please sign in to comment.