Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-43998: Properly validate default values and fix their handling in SQL metadata #61

Merged
merged 3 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 37 additions & 17 deletions python/felis/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from sqlalchemy.types import TypeEngine

from .db.sqltypes import get_type_func
from .types import FelisType
from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -184,7 +184,7 @@ class Column(BaseObject):
nullable: bool = True
"""Whether the column can be ``NULL``."""

value: Any = None
value: str | int | float | bool | None = None
"""The default value of the column."""

autoincrement: bool | None = None
Expand Down Expand Up @@ -225,6 +225,27 @@ class Column(BaseObject):
votable_datatype: str | None = Field(None, alias="votable:datatype")
"""The VOTable datatype of the column."""

@model_validator(mode="after")
def check_value(self) -> Column:
    """Check that the default value is valid for the column's datatype.

    Returns
    -------
    `Column`
        The column being validated, unchanged.

    Raises
    ------
    ValueError
        Raised if a default value is combined with ``autoincrement``, or
        if the Python type of the default value does not match the Felis
        datatype of the column (integer, floating point, string or
        boolean), or if a string default is empty or whitespace-only.
    """
    value = self.value
    if value is None:
        # No default was provided, so there is nothing to check.
        return self

    if self.autoincrement is True:
        raise ValueError("Column cannot have both a default value and be autoincremented")

    # ``bool`` is a subclass of ``int`` in Python, so it must be detected
    # explicitly; otherwise ``True``/``False`` would be accepted as a
    # default value on integer columns.
    is_bool = isinstance(value, bool)

    felis_type = FelisType.felis_type(self.datatype)
    if felis_type.is_numeric:
        if felis_type in (Byte, Short, Int, Long) and (is_bool or not isinstance(value, int)):
            raise ValueError("Default value must be an int for integer type columns")
        elif felis_type in (Float, Double) and not isinstance(value, float):
            raise ValueError("Default value must be a decimal number for float and double columns")
    elif felis_type in (String, Char, Unicode, Text):
        if not isinstance(value, str):
            raise ValueError("Default value must be a string for string columns")
        # Reject whitespace-only strings as well as the empty string, without
        # relying on the model configuration to strip the input first.
        if not value.strip():
            raise ValueError("Default value must be a non-empty string for string columns")
    elif felis_type is Boolean and not is_bool:
        raise ValueError("Default value must be a boolean for boolean columns")
    return self

@field_validator("ivoa_ucd")
@classmethod
def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
Expand Down Expand Up @@ -255,51 +276,50 @@ def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:

return values

@model_validator(mode="after") # type: ignore[arg-type]
@classmethod
def validate_datatypes(cls, col: Column, info: ValidationInfo) -> Column:
@model_validator(mode="after")
JeremyMcCormick marked this conversation as resolved.
Show resolved Hide resolved
def check_datatypes(self, info: ValidationInfo) -> Column:
"""Check for redundant datatypes on columns."""
context = info.context
if not context or not context.get("check_redundant_datatypes", False):
return col
if all(getattr(col, f"{dialect}:datatype", None) is not None for dialect in _DIALECTS.keys()):
return col
return self
if all(getattr(self, f"{dialect}:datatype", None) is not None for dialect in _DIALECTS.keys()):
return self

datatype = col.datatype
length: int | None = col.length or None
datatype = self.datatype
length: int | None = self.length or None

datatype_func = get_type_func(datatype)
felis_type = FelisType.felis_type(datatype)
if felis_type.is_sized:
if length is not None:
datatype_obj = datatype_func(length)
else:
raise ValueError(f"Length must be provided for sized type '{datatype}' in column '{col.id}'")
raise ValueError(f"Length must be provided for sized type '{datatype}' in column '{self.id}'")
else:
datatype_obj = datatype_func()

for dialect_name, dialect in _DIALECTS.items():
db_annotation = f"{dialect_name}_datatype"
if datatype_string := col.model_dump().get(db_annotation):
if datatype_string := self.model_dump().get(db_annotation):
db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
raise ValueError(
"'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
db_annotation,
datatype_string,
col.datatype,
col.id,
self.datatype,
self.id,
"" if length is None else f" with length {length}",
)
)
else:
logger.debug(
f"Type override of 'datatype: {col.datatype}' "
f"with '{db_annotation}: {datatype_string}' in column '{col.id}' "
f"Type override of 'datatype: {self.datatype}' "
f"with '{db_annotation}: {datatype_string}' in column '{self.id}' "
f"compiled to '{datatype_obj.compile(dialect)}' and "
f"'{db_datatype_obj.compile(dialect)}'"
)
return col
return self


class Constraint(BaseObject):
Expand Down
21 changes: 18 additions & 3 deletions python/felis/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
PrimaryKeyConstraint,
ResultProxy,
Table,
TextClause,
UniqueConstraint,
create_mock_engine,
make_url,
Expand Down Expand Up @@ -134,6 +135,9 @@ def get_datatype_with_variants(column_obj: datamodel.Column) -> TypeEngine:
return datatype


_VALID_SERVER_DEFAULTS = ("CURRENT_TIMESTAMP", "NOW()", "LOCALTIMESTAMP", "NULL")


class MetaDataBuilder:
"""A class for building a `MetaData` object from a Felis `Schema`."""

Expand Down Expand Up @@ -263,7 +267,7 @@ def build_column(self, column_obj: datamodel.Column) -> Column:
name = column_obj.name
id = column_obj.id
description = column_obj.description
default = column_obj.value
value = column_obj.value
nullable = column_obj.nullable

# Get datatype, handling variant overrides such as "mysql:datatype".
Expand All @@ -274,13 +278,24 @@ def build_column(self, column_obj: datamodel.Column) -> Column:
column_obj.autoincrement if column_obj.autoincrement is not None else "auto"
)

server_default: str | TextClause | None = None
if value is not None:
server_default = str(value)
if server_default in _VALID_SERVER_DEFAULTS or not isinstance(value, str):
# If the server default is a valid keyword or not a string,
# use it as is.
server_default = text(server_default)

if server_default is not None:
logger.debug(f"Column '{id}' has default value: {server_default}")

column: Column = Column(
name,
datatype,
comment=description,
autoincrement=autoincrement,
nullable=nullable,
server_default=default,
server_default=server_default,
)

self._objects[id] = column
Expand Down Expand Up @@ -469,7 +484,7 @@ def drop_if_exists(self) -> None:
self.connection.execute(text(f"DROP DATABASE IF EXISTS {schema_name}"))
elif db_type == "postgresql":
logger.info(f"Dropping PostgreSQL schema if exists: {schema_name}")
self.connection.execute(sqa_schema.DropSchema(schema_name, if_exists=True))
self.connection.execute(sqa_schema.DropSchema(schema_name, if_exists=True, cascade=True))
else:
raise ValueError(f"Unsupported database type: {db_type}")
except SQLAlchemyError as e:
Expand Down
13 changes: 13 additions & 0 deletions tests/data/sales.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ tables:
"@id": "#orders.order_date"
datatype: timestamp
description: Order date
value: CURRENT_TIMESTAMP
- name: note
"@id": "#orders.note"
description: Order note
datatype: string
length: 256
constraints:
- name: fk_customer_id
"@id": "#orders_fk_customer_id"
Expand Down Expand Up @@ -82,6 +88,13 @@ tables:
"@id": "#items.quantity"
datatype: int
description: Quantity ordered
value: 1
- name: note
"@id": "#items.note"
description: Item note
datatype: string
length: 256
value: "No note"
constraints:
- name: non_negative_quantity
"@id": "#items_non_negative_quantity"
Expand Down
2 changes: 1 addition & 1 deletion tests/data/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ tables:
"@id": "#sdqa_Threshold.createdDate"
datatype: timestamp
description: Database timestamp when the record is inserted.
# value: CURRENT_TIMESTAMP ## DM-43312: This causes an error due to quoting.
value: CURRENT_TIMESTAMP
mysql:datatype: TIMESTAMP
primaryKey: "#sdqa_Threshold.sdqa_thresholdId"
indexes:
Expand Down
80 changes: 78 additions & 2 deletions tests/test_datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import os
import unittest
from collections import defaultdict

import yaml
from pydantic import ValidationError
Expand Down Expand Up @@ -168,7 +169,6 @@ def _check_description(col: Column):
"name": "testColumn",
"@id": "#test_col_id",
"datatype": "string",
"require_description": True,
"description": "",
}
)
Expand All @@ -182,12 +182,88 @@ def _check_description(col: Column):
"name": "testColumn",
"@id": "#test_col_id",
"datatype": "string",
"require_description": True,
"description": "xy",
}
)
)

def test_values(self):
    """Test the `value` field of the `Column` class.

    Every invalid combination is built from fresh column data and checked
    in its own ``assertRaises`` context, so that each value is actually
    exercised and a failure identifies the offending datatype and value.
    """

    # Define a function to return the default column data
    def default_coldata():
        return defaultdict(str, {"name": "testColumn", "@id": "#test_col_id"})

    # Setting both value and autoincrement should throw.
    autoincr_coldata = default_coldata()
    autoincr_coldata["datatype"] = "int"
    autoincr_coldata["autoincrement"] = True
    autoincr_coldata["value"] = 1
    with self.assertRaises(ValueError):
        Column(**autoincr_coldata)

    # Setting an invalid default on a column with an integer type should
    # throw.
    for datatype in ["int", "long", "short", "byte"]:
        for value in ["bad", "1.0", "1", 1.1]:
            bad_int_coldata = default_coldata()
            bad_int_coldata["datatype"] = datatype
            bad_int_coldata["value"] = value
            with self.subTest(datatype=datatype, value=value), self.assertRaises(ValueError):
                Column(**bad_int_coldata)

    # Setting an invalid default on a column with a decimal type should
    # throw.
    for datatype in ["double", "float"]:
        for value in ["bad", "1.0", "1", 1]:
            bad_float_coldata = default_coldata()
            bad_float_coldata["datatype"] = datatype
            bad_float_coldata["value"] = value
            with self.subTest(datatype=datatype, value=value), self.assertRaises(ValueError):
                Column(**bad_float_coldata)

    # Setting a bad default on a string column should throw.
    for datatype in ["string", "char", "unicode", "text"]:
        for value in [1, 1.1, True, "", " ", "  ", "\n", "\t"]:
            bad_str_coldata = default_coldata()
            bad_str_coldata["length"] = 256
            bad_str_coldata["datatype"] = datatype
            bad_str_coldata["value"] = value
            with self.subTest(datatype=datatype, value=value), self.assertRaises(ValueError):
                Column(**bad_str_coldata)

    # Setting a non-boolean value on a boolean column should throw. The
    # ``assertRaises`` context must be entered once per value; wrapping the
    # whole loop would stop at the first exception and skip the rest.
    for value in ["bad", 1, 1.1]:
        bad_bool_coldata = default_coldata()
        bad_bool_coldata["datatype"] = "boolean"
        bad_bool_coldata["value"] = value
        with self.subTest(value=value), self.assertRaises(ValueError):
            Column(**bad_bool_coldata)

    # Setting a valid value on a string column should be okay.
    for datatype in ["string", "char", "unicode", "text"]:
        str_coldata = default_coldata()
        str_coldata["length"] = 256
        str_coldata["value"] = "okay"
        str_coldata["datatype"] = datatype
        Column(**str_coldata)

    # Setting an integer value on a column with an int type should be okay.
    for datatype in ["int", "long", "short", "byte"]:
        int_coldata = default_coldata()
        int_coldata["value"] = 1
        int_coldata["datatype"] = datatype
        Column(**int_coldata)

    # Setting a decimal value on a column with a float type should be okay.
    for datatype in ["double", "float"]:
        float_coldata = default_coldata()
        float_coldata["value"] = 1.1
        float_coldata["datatype"] = datatype
        Column(**float_coldata)

    # Setting a boolean value on a boolean column should be okay.
    bool_coldata = default_coldata()
    bool_coldata["datatype"] = "boolean"
    bool_coldata["value"] = True
    Column(**bool_coldata)


class ConstraintTestCase(unittest.TestCase):
"""Test the `UniqueConstraint`, `Index`, `CheckConstraint`, and
Expand Down