DM-41247: Add tool that checks for redundant datatype definitions #50

Merged (19 commits, Apr 12, 2024)

Commits
674b2f7  Use global class variable for setting validation flags (JeremyMcCormick, Apr 4, 2024)
6a89b84  Remove the unnecessary implementation of custom double type (JeremyMcCormick, Apr 4, 2024)
2ef40f0  Add Python init so unittest functions properly (JeremyMcCormick, Apr 4, 2024)
98588ce  Implement redundant type checking for MySQL datatype annotations (JeremyMcCormick, Apr 4, 2024)
892116e  Add test module for datatypes (JeremyMcCormick, Apr 4, 2024)
7e30331  Compile types with correct dialect so they are comparable (JeremyMcCormick, Apr 4, 2024)
7a3e74e  Refactor the redundant datatype validation to use a separate function (JeremyMcCormick, Apr 4, 2024)
cdcd842  Add additional redundant datatype tests (JeremyMcCormick, Apr 5, 2024)
f43d5f6  Remove debug print outs (JeremyMcCormick, Apr 8, 2024)
50a820f  Turn off check of redundant datatypes by default (JeremyMcCormick, Apr 8, 2024)
8a946ea  Populate the ID map for the schema in the post init hook (JeremyMcCormick, Apr 8, 2024)
4f20fac  Change class name for schema config to Config (JeremyMcCormick, Apr 8, 2024)
cce9ebe  Add validation option to CLI for checking redundant datatypes (JeremyMcCormick, Apr 8, 2024)
b63285e  Print a debug message when a valid type override is processed (JeremyMcCormick, Apr 8, 2024)
56f0bf6  Only print name of schema if not default (JeremyMcCormick, Apr 8, 2024)
7a42128  Throw an error if length is missing on a column with a sized Felis type (JeremyMcCormick, Apr 8, 2024)
fb3be4f  Fix mypy errors (JeremyMcCormick, Apr 8, 2024)
e9623eb  Add debug print of compiled type (JeremyMcCormick, Apr 8, 2024)
941933f  Decapitalize Felis datatype in variable comment (JeremyMcCormick, Apr 11, 2024)

23 changes: 19 additions & 4 deletions python/felis/cli.py
@@ -373,15 +373,30 @@ def merge(files: Iterable[io.TextIOBase]) -> None:
type=click.Choice(["RSP", "default"]),
default="default",
)
@click.option("-d", "--require-description", is_flag=True, help="Require description for all objects")
@click.option(
"-d", "--require-description", is_flag=True, help="Require description for all objects", default=False
)
@click.option(
"-t", "--check-redundant-datatypes", is_flag=True, help="Check for redundant datatypes", default=False
)
@click.argument("files", nargs=-1, type=click.File())
def validate(schema_name: str, require_description: bool, files: Iterable[io.TextIOBase]) -> None:
def validate(
schema_name: str,
require_description: bool,
check_redundant_datatypes: bool,
files: Iterable[io.TextIOBase],
) -> None:
"""Validate one or more felis YAML files."""
schema_class = get_schema(schema_name)
logger.info(f"Using schema '{schema_class.__name__}'")
if schema_name != "default":
logger.info(f"Using schema '{schema_class.__name__}'")

schema_class.Config.require_description = require_description
if require_description:
Schema.require_description(True)
logger.info("Requiring descriptions for all objects")
schema_class.Config.check_redundant_datatypes = check_redundant_datatypes
if check_redundant_datatypes:
logger.info("Checking for redundant datatypes")

rc = 0
for file in files:
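
For reference, the new flag can be exercised through Click's test runner, mirroring how the CLI is driven in `tests/test_cli.py`. This is a minimal sketch, assuming `validate` is importable as a Click command from `felis.cli`; `schema.yaml` is a hypothetical input file.

```python
from click.testing import CliRunner

from felis.cli import validate  # assumed to be the Click command object shown above

runner = CliRunner()
# Enable the new redundant-datatype check alongside the existing description check.
result = runner.invoke(
    validate,
    ["--check-redundant-datatypes", "--require-description", "schema.yaml"],
)
print(result.exit_code)
print(result.output)
```
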
152 changes: 127 additions & 25 deletions python/felis/datamodel.py
@@ -22,13 +22,22 @@
from __future__ import annotations

import logging
import re
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Annotated, Any, Literal, TypeAlias

from astropy import units as units # type: ignore
from astropy.io.votable import ucd # type: ignore
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from sqlalchemy import dialects
from sqlalchemy import types as sqa_types
from sqlalchemy.engine import create_mock_engine
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import TypeEngine

from .db.sqltypes import get_type_func
from .types import FelisType

logger = logging.getLogger(__name__)

@@ -93,7 +102,7 @@ class BaseObject(BaseModel):
@classmethod
def check_description(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Check that the description is present if required."""
if Schema.is_description_required():
if Schema.Config.require_description:
if "description" not in values or not values["description"]:
raise ValueError("Description is required and must be non-empty")
if len(values["description"].strip()) < DESCR_MIN_LENGTH:
@@ -119,6 +128,51 @@ class DataType(Enum):
TIMESTAMP = "timestamp"


_DIALECTS = {
"mysql": create_mock_engine("mysql://", executor=None).dialect,
"postgresql": create_mock_engine("postgresql://", executor=None).dialect,
}
"""Dictionary of dialect names to SQLAlchemy dialects."""

_DIALECT_MODULES = {"mysql": getattr(dialects, "mysql"), "postgresql": getattr(dialects, "postgresql")}
"""Dictionary of dialect names to SQLAlchemy dialect modules."""

_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression to match data types in the form "type(length)"""


def string_to_typeengine(
type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
match = _DATATYPE_REGEXP.search(type_string)
if not match:
raise ValueError(f"Invalid type string: {type_string}")

type_name, _, params = match.groups()
if dialect is None:
type_class = getattr(sqa_types, type_name.upper(), None)
else:
try:
dialect_module = _DIALECT_MODULES[dialect.name]
except KeyError:
raise ValueError(f"Unsupported dialect: {dialect}")
type_class = getattr(dialect_module, type_name.upper(), None)

if not type_class:
raise ValueError(f"Unsupported type: {type_class}")

if params:
params = [int(param) if param.isdigit() else param for param in params.split(",")]
type_obj = type_class(*params)
else:
type_obj = type_class()

if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
type_obj.length = length

return type_obj


class Column(BaseObject):
"""A column in a table."""

@@ -207,6 +261,56 @@ def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:

return values

@model_validator(mode="before")
@classmethod
def check_redundant_datatypes(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Check for redundant datatypes on columns."""
if not Schema.Config.check_redundant_datatypes:
return values
if all(f"{dialect}:datatype" not in values for dialect in _DIALECTS.keys()):
return values

datatype: str | None = values.get("datatype") or None
if datatype is None:
raise ValueError(f"Datatype must be provided for column '{values['@id']}'")
length: int | None = values.get("length") or None

datatype_func = get_type_func(datatype)
felis_type = FelisType.felis_type(datatype)
if felis_type.is_sized:
if length is not None:
datatype_obj = datatype_func(length)
else:
raise ValueError(
f"Length must be provided for sized type '{datatype}' in column '{values['@id']}'"
)
else:
datatype_obj = datatype_func()

for dialect_name, dialect in _DIALECTS.items():
if f"{dialect_name}:datatype" in values:
datatype_string = values[f"{dialect_name}:datatype"]
db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
raise ValueError(
"'{}:datatype: {}' is the same as 'datatype: {}' in column '{}'".format(
dialect_name, datatype_string, values["datatype"], values["@id"]
)
)
else:
logger.debug(
"Valid type override of 'datatype: {}' with '{}:datatype: {}' in column '{}'".format(
values["datatype"], dialect_name, datatype_string, values["@id"]
)
)
logger.debug(
"Compiled datatype '{}' with {} compiled override '{}'".format(
datatype_obj.compile(dialect), dialect_name, db_datatype_obj.compile(dialect)
)
)

return values


class Constraint(BaseObject):
"""A database table constraint."""
@@ -404,15 +508,22 @@ def visit_constraint(self, constraint: Constraint) -> None:
class Schema(BaseObject):
"""The database schema containing the tables."""

class ValidationConfig:
class Config:
"""Validation configuration which is specific to Felis."""

_require_description = False
require_description = False
"""Flag to require a description for all objects.

This is set directly on the class, for example by the `validate` CLI command.
"""

check_redundant_datatypes = False
"""Flag to enable checking for redundant datatypes on columns.

An example would be providing both ``mysql:datatype: DOUBLE`` and
``datatype: double`` as MySQL would have used that type by default.
"""

version: SchemaVersion | str | None = None
"""The version of the schema."""

@@ -430,21 +541,29 @@ def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
raise ValueError("Table names must be unique")
return tables

@model_validator(mode="after")
def create_id_map(self: Schema) -> Schema:
"""Create a map of IDs to objects."""
def _create_id_map(self: Schema) -> Schema:
"""Create a map of IDs to objects.

This method should not be called by users. It is called automatically
by the `model_post_init` method. If the ID map is already populated,
this method will return immediately.
"""
if len(self.id_map):
logger.debug("ID map was already populated")
logger.debug("Ignoring call to create_id_map() - ID map was already populated")
return self
visitor: SchemaIdVisitor = SchemaIdVisitor()
visitor.visit_schema(self)
logger.debug(f"ID map contains {len(self.id_map.keys())} objects")
logger.debug(f"Created schema ID map with {len(self.id_map.keys())} objects")
if len(visitor.duplicates):
raise ValueError(
"Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
)
return self

def model_post_init(self, ctx: Any) -> None:
"""Post-initialization hook for the model."""
self._create_id_map()

def __getitem__(self, id: str) -> BaseObject:
"""Get an object by its ID."""
if id not in self:
Expand All @@ -454,20 +573,3 @@ def __getitem__(self, id: str) -> BaseObject:
def __contains__(self, id: str) -> bool:
"""Check if an object with the given ID is in the schema."""
return id in self.id_map

@classmethod
def require_description(cls, rd: bool = True) -> None:
"""Set whether a description is required for all objects.

This includes the schema, tables, columns, and constraints.

Users should call this method to set the requirement for a description
when validating schemas, rather than change the flag value directly.
"""
logger.debug(f"Setting description requirement to '{rd}'")
cls.ValidationConfig._require_description = rd

@classmethod
def is_description_required(cls) -> bool:
"""Return whether a description is required for all objects."""
return cls.ValidationConfig._require_description
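
As a rough illustration of the new helper, `string_to_typeengine` parses a `type(length)` string into a dialect-specific SQLAlchemy type. The sketch below assumes only that `felis.datamodel` is importable; the type strings are invented for illustration.

```python
from sqlalchemy.engine import create_mock_engine

from felis.datamodel import string_to_typeengine

# A mock engine supplies a Dialect object without any database connection,
# the same trick the _DIALECTS table above uses.
pg_dialect = create_mock_engine("postgresql://", executor=None).dialect

# "VARCHAR(128)" is parsed into a PostgreSQL VARCHAR with length 128.
print(string_to_typeengine("VARCHAR(128)", pg_dialect).compile(pg_dialect))

# When the string carries no length, the optional length argument fills it in.
print(string_to_typeengine("VARCHAR", pg_dialect, 64).compile(pg_dialect))

# Unknown type names are rejected immediately with a ValueError.
try:
    string_to_typeengine("NOT_A_TYPE", pg_dialect)
except ValueError as exc:
    print(exc)
```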
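
The `check_redundant_datatypes` validator then compiles the Felis datatype and the dialect override against the same dialect and flags the override when both render identically. A minimal sketch of that comparison, using plain SQLAlchemy types rather than the felis type factories:

```python
from sqlalchemy import types as sqa_types
from sqlalchemy.dialects import mysql
from sqlalchemy.engine import create_mock_engine

mysql_dialect = create_mock_engine("mysql://", executor=None).dialect

# 'datatype: double' ends up as a DOUBLE column on MySQL...
felis_type = sqa_types.DOUBLE()
# ...so an explicit 'mysql:datatype: DOUBLE' override adds no information.
override_type = mysql.DOUBLE()

compiled = felis_type.compile(mysql_dialect)
compiled_override = override_type.compile(mysql_dialect)

# When both compile to the same DDL string, the validator treats the
# override as redundant and raises a ValueError for the column.
print(compiled, compiled_override, compiled == compiled_override)
```
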
25 changes: 10 additions & 15 deletions python/felis/db/sqltypes.py
@@ -21,9 +21,9 @@

import builtins
from collections.abc import Mapping
from typing import Any
from typing import Any, Callable

from sqlalchemy import Float, SmallInteger, types
from sqlalchemy import SmallInteger, types
from sqlalchemy.dialects import mysql, oracle, postgresql
from sqlalchemy.ext.compiler import compiles

@@ -39,24 +39,12 @@ class TINYINT(SmallInteger):
__visit_name__ = "TINYINT"


class DOUBLE(Float):
"""The non-standard DOUBLE type."""

__visit_name__ = "DOUBLE"


@compiles(TINYINT)
def compile_tinyint(type_: Any, compiler: Any, **kw: Any) -> str:
"""Return type name for TINYINT."""
return "TINYINT"


@compiles(DOUBLE)
def compile_double(type_: Any, compiler: Any, **kw: Any) -> str:
"""Return type name for double precision type."""
return "DOUBLE"


_TypeMap = Mapping[str, types.TypeEngine | type[types.TypeEngine]]

boolean_map: _TypeMap = {MYSQL: mysql.BIT(1), ORACLE: oracle.NUMBER(1), POSTGRES: postgresql.BOOLEAN()}
@@ -160,7 +148,7 @@ def float(**kwargs: Any) -> types.TypeEngine:

def double(**kwargs: Any) -> types.TypeEngine:
"""Return SQLAlchemy type for double precision float."""
return _vary(DOUBLE(), double_map, kwargs)
return _vary(types.DOUBLE(), double_map, kwargs)


def char(length: builtins.int, **kwargs: Any) -> types.TypeEngine:
@@ -193,6 +181,13 @@ def timestamp(**kwargs: Any) -> types.TypeEngine:
return types.TIMESTAMP()


def get_type_func(type_name: str) -> Callable:
"""Return the function for the type with the given name."""
if type_name not in globals():
raise ValueError(f"Unknown type: {type_name}")
return globals()[type_name]


def _vary(
type_: types.TypeEngine,
variant_map: _TypeMap,
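
The new `get_type_func` helper resolves a Felis datatype name to the corresponding factory function in this module. A small usage sketch, assuming the `felis` package is installed:

```python
from felis.db.sqltypes import get_type_func

# "double" resolves to the double() factory defined above.
print(get_type_func("double")())  # renders as DOUBLE on the default dialect

# Sized Felis types such as "char" take a length argument.
print(get_type_func("char")(8))  # CHAR(8)

# Names with no matching factory raise ValueError instead of failing later.
try:
    get_type_func("not_a_type")
except ValueError as exc:
    print(exc)
```
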
Empty file added tests/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion tests/test_cli.py
@@ -147,7 +147,7 @@ def test_validate_default_with_require_description(self) -> None:
raise e
finally:
# Turn the flag off so it does not affect subsequent tests.
Schema.require_description(False)
Schema.Config.require_description = False

self.assertEqual(result.exit_code, 0)

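
With `ValidationConfig` renamed to `Config` and the `require_description()` classmethod removed, tests and other callers now toggle validation behaviour by assigning the class-level flags directly and resetting them afterwards, as this change does. A minimal sketch:

```python
from felis.datamodel import Schema

# Turn the stricter checks on for a block of work...
Schema.Config.require_description = True
Schema.Config.check_redundant_datatypes = True
try:
    pass  # validate or construct schemas here
finally:
    # ...and always reset the class-level flags so later tests are unaffected.
    Schema.Config.require_description = False
    Schema.Config.check_redundant_datatypes = False
```
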
15 changes: 7 additions & 8 deletions tests/test_datamodel.py
@@ -56,6 +56,12 @@ def test_validation(self) -> None:
class ColumnTestCase(unittest.TestCase):
"""Test the `Column` class."""

def setUp(self) -> None:
"""Set up the test by turning off the description requirement in case
it was turned on by another test.
"""
Schema.Config.require_description = False

def test_validation(self) -> None:
"""Test validation of the `Column` class."""
# Default initialization should throw an exception.
@@ -124,11 +130,7 @@ def test_validation(self) -> None:
def test_require_description(self) -> None:
"""Test the require_description flag for the `Column` class."""
# Turn on description requirement for this test.
Schema.require_description(True)

# Make sure that setting the flag for description requirement works
# correctly.
self.assertTrue(Schema.is_description_required(), "description should be required")
Schema.Config.require_description = True

# Creating a column without a description when required should throw an
# error.
@@ -155,9 +157,6 @@ def test_require_description(self) -> None:
with self.assertRaises(ValidationError):
Column(**{"name": "testColumn", "@id": "#test_col_id", "datatype": "string", "description": "xy"})

# Turn off flag or it will affect subsequent tests.
Schema.require_description(False)


class ConstraintTestCase(unittest.TestCase):
"""Test the `UniqueConstraint`, `Index`, `CheckCosntraint`, and