Skip to content

Commit

Permalink
Validate literals in configspec
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Apr 19, 2024
1 parent ac3baa5 commit 7120835
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 13 deletions.
2 changes: 1 addition & 1 deletion dlt/common/configuration/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def _resolve_config_field(
embedded_sections: Tuple[str, ...],
accept_partial: bool,
) -> Tuple[Any, List[LookupTrace]]:
inner_hint = extract_inner_hint(hint)
inner_hint = extract_inner_hint(hint, preserve_literal=True)

if explicit_value is not None:
value = explicit_value
Expand Down
12 changes: 9 additions & 3 deletions dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
overload,
ClassVar,
TypeVar,
Literal,
)
from typing_extensions import get_args, get_origin, dataclass_transform
from functools import wraps
Expand Down Expand Up @@ -84,13 +85,18 @@ def is_valid_hint(hint: Type[Any]) -> bool:
return False


def extract_inner_hint(hint: Type[Any], preserve_new_types: bool = False) -> Type[Any]:
def extract_inner_hint(
hint: Type[Any], preserve_new_types: bool = False, preserve_literal: bool = False
) -> Type[Any]:
# extract hint from Optional / Literal / NewType hints
inner_hint = extract_inner_type(hint, preserve_new_types)
inner_hint = extract_inner_type(hint, preserve_new_types, preserve_literal)
# get base configuration from union type
inner_hint = get_config_if_union_hint(inner_hint) or inner_hint
# extract origin from generic types (ie List[str] -> List)
return get_origin(inner_hint) or inner_hint
origin = get_origin(inner_hint) or inner_hint
if preserve_literal and origin is Literal:
return inner_hint
return origin or inner_hint


def is_secret_hint(hint: Type[Any]) -> bool:
Expand Down
31 changes: 27 additions & 4 deletions dlt/common/configuration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,20 @@
import ast
import contextlib
import tomlkit
from typing import Any, Dict, Mapping, NamedTuple, Optional, Tuple, Type, Sequence
from typing import (
Any,
Dict,
Mapping,
NamedTuple,
Optional,
Tuple,
Type,
Sequence,
get_args,
Literal,
get_origin,
List,
)
from collections.abc import Mapping as C_Mapping

from dlt.common.json import json
Expand Down Expand Up @@ -51,25 +64,35 @@ def deserialize_value(key: str, value: Any, hint: Type[TAny]) -> TAny:
raise
return c # type: ignore

literal_values: Tuple[Any, ...] = ()
if get_origin(hint) is Literal:
# Literal fields are validated against the literal values
literal_values = get_args(hint)
hint_origin = type(literal_values[0])
else:
hint_origin = hint

# coerce value
hint_dt = py_type_to_sc_type(hint)
hint_dt = py_type_to_sc_type(hint_origin)
value_dt = py_type_to_sc_type(type(value))

# eval only if value is string and hint is "complex"
if value_dt == "text" and hint_dt == "complex":
if hint is tuple:
if hint_origin is tuple:
# use literal eval for tuples
value = ast.literal_eval(value)
else:
# use json for sequences and mappings
value = json.loads(value)
# exact types must match
if not isinstance(value, hint):
if not isinstance(value, hint_origin):
raise ValueError(value)
else:
# for types that are not complex, reuse schema coercion rules
if value_dt != hint_dt:
value = coerce_value(hint_dt, value_dt, value)
if literal_values and value not in literal_values:
raise ConfigValueCannotBeCoercedException(key, value, hint)
return value # type: ignore
except ConfigValueCannotBeCoercedException:
raise
Expand Down
12 changes: 7 additions & 5 deletions dlt/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def is_dict_generic_type(t: Type[Any]) -> bool:
return False


def extract_inner_type(hint: Type[Any], preserve_new_types: bool = False) -> Type[Any]:
def extract_inner_type(
hint: Type[Any], preserve_new_types: bool = False, preserve_literal: bool = False
) -> Type[Any]:
"""Gets the inner type from Literal, Optional, Final and NewType
Args:
Expand All @@ -226,15 +228,15 @@ def extract_inner_type(hint: Type[Any], preserve_new_types: bool = False) -> Typ
Type[Any]: Inner type if hint was Literal, Optional or NewType, otherwise hint
"""
if maybe_modified := extract_type_if_modifier(hint):
return extract_inner_type(maybe_modified, preserve_new_types)
return extract_inner_type(maybe_modified, preserve_new_types, preserve_literal)
if is_optional_type(hint):
return extract_inner_type(get_args(hint)[0], preserve_new_types)
if is_literal_type(hint):
return extract_inner_type(get_args(hint)[0], preserve_new_types, preserve_literal)
if is_literal_type(hint) and not preserve_literal:
# assume that all literals are of the same type
return type(get_args(hint)[0])
if is_newtype_type(hint) and not preserve_new_types:
# descend into supertypes of NewType
return extract_inner_type(hint.__supertype__, preserve_new_types)
return extract_inner_type(hint.__supertype__, preserve_new_types, preserve_literal)
return hint


Expand Down
23 changes: 23 additions & 0 deletions tests/common/configuration/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
add_config_dict_to_env,
add_config_to_env,
)
from dlt.common.pipeline import TRefreshMode

from tests.utils import preserve_environ
from tests.common.configuration.utils import (
Expand Down Expand Up @@ -239,6 +240,11 @@ def resolve_dynamic_type_field(self) -> Type[Union[int, str]]:
return str


@configspec
class ConfigWithLiteralField(BaseConfiguration):
refresh: TRefreshMode = None


LongInteger = NewType("LongInteger", int)
FirstOrderStr = NewType("FirstOrderStr", str)
SecondOrderStr = NewType("SecondOrderStr", FirstOrderStr)
Expand Down Expand Up @@ -1255,3 +1261,20 @@ class EmbeddedConfigurationWithDefaults(BaseConfiguration):
c_resolved = resolve.resolve_configuration(c_instance)
assert c_resolved.is_resolved()
assert c_resolved.conn_str.is_resolved()


def test_configuration_with_literal_field(environment: Dict[str, str]) -> None:
"""Literal type fields only allow values from the literal"""
environment["REFRESH"] = "not_a_refresh_mode"

with pytest.raises(ConfigValueCannotBeCoercedException) as einfo:
resolve.resolve_configuration(ConfigWithLiteralField())

assert einfo.value.field_name == "refresh"
assert einfo.value.field_value == "not_a_refresh_mode"
assert einfo.value.hint == TRefreshMode

environment["REFRESH"] = "drop_data"

spec = resolve.resolve_configuration(ConfigWithLiteralField())
assert spec.refresh == "drop_data"

0 comments on commit 7120835

Please sign in to comment.