diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index ebfa7b6b89..57dcea7546 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -286,7 +286,7 @@ def _resolve_config_field( embedded_sections: Tuple[str, ...], accept_partial: bool, ) -> Tuple[Any, List[LookupTrace]]: - inner_hint = extract_inner_hint(hint) + inner_hint = extract_inner_hint(hint, preserve_literal=True) if explicit_value is not None: value = explicit_value diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 06fb97fcdd..4b85316d96 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -19,6 +19,7 @@ overload, ClassVar, TypeVar, + Literal, ) from typing_extensions import get_args, get_origin, dataclass_transform from functools import wraps @@ -84,13 +85,18 @@ def is_valid_hint(hint: Type[Any]) -> bool: return False -def extract_inner_hint(hint: Type[Any], preserve_new_types: bool = False) -> Type[Any]: +def extract_inner_hint( + hint: Type[Any], preserve_new_types: bool = False, preserve_literal: bool = False +) -> Type[Any]: # extract hint from Optional / Literal / NewType hints - inner_hint = extract_inner_type(hint, preserve_new_types) + inner_hint = extract_inner_type(hint, preserve_new_types, preserve_literal) # get base configuration from union type inner_hint = get_config_if_union_hint(inner_hint) or inner_hint # extract origin from generic types (ie List[str] -> List) - return get_origin(inner_hint) or inner_hint + origin = get_origin(inner_hint) or inner_hint + if preserve_literal and origin is Literal: + return inner_hint + return origin or inner_hint def is_secret_hint(hint: Type[Any]) -> bool: diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/utils.py index 51e6b5615a..8f3c1789ce 100644 --- a/dlt/common/configuration/utils.py +++ b/dlt/common/configuration/utils.py @@ -2,7 +2,20 @@ import ast import contextlib import tomlkit -from typing import Any, Dict, Mapping, NamedTuple, Optional, Tuple, Type, Sequence +from typing import ( + Any, + Dict, + Mapping, + NamedTuple, + Optional, + Tuple, + Type, + Sequence, + get_args, + Literal, + get_origin, + List, +) from collections.abc import Mapping as C_Mapping from dlt.common.json import json @@ -51,25 +64,35 @@ def deserialize_value(key: str, value: Any, hint: Type[TAny]) -> TAny: raise return c # type: ignore + literal_values: Tuple[Any, ...] = () + if get_origin(hint) is Literal: + # Literal fields are validated against the literal values + literal_values = get_args(hint) + hint_origin = type(literal_values[0]) + else: + hint_origin = hint + # coerce value - hint_dt = py_type_to_sc_type(hint) + hint_dt = py_type_to_sc_type(hint_origin) value_dt = py_type_to_sc_type(type(value)) # eval only if value is string and hint is "complex" if value_dt == "text" and hint_dt == "complex": - if hint is tuple: + if hint_origin is tuple: # use literal eval for tuples value = ast.literal_eval(value) else: # use json for sequences and mappings value = json.loads(value) # exact types must match - if not isinstance(value, hint): + if not isinstance(value, hint_origin): raise ValueError(value) else: # for types that are not complex, reuse schema coercion rules if value_dt != hint_dt: value = coerce_value(hint_dt, value_dt, value) + if literal_values and value not in literal_values: + raise ConfigValueCannotBeCoercedException(key, value, hint) return value # type: ignore except ConfigValueCannotBeCoercedException: raise diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 99c2604cdf..1597c0054d 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -215,7 +215,9 @@ def is_dict_generic_type(t: Type[Any]) -> bool: return False -def extract_inner_type(hint: Type[Any], preserve_new_types: bool = False) -> Type[Any]: +def extract_inner_type( + hint: Type[Any], preserve_new_types: bool = False, preserve_literal: bool = False +) -> Type[Any]: """Gets the inner type from Literal, Optional, Final and NewType Args: @@ -226,15 +228,15 @@ def extract_inner_type(hint: Type[Any], preserve_new_types: bool = False) -> Typ Type[Any]: Inner type if hint was Literal, Optional or NewType, otherwise hint """ if maybe_modified := extract_type_if_modifier(hint): - return extract_inner_type(maybe_modified, preserve_new_types) + return extract_inner_type(maybe_modified, preserve_new_types, preserve_literal) if is_optional_type(hint): - return extract_inner_type(get_args(hint)[0], preserve_new_types) - if is_literal_type(hint): + return extract_inner_type(get_args(hint)[0], preserve_new_types, preserve_literal) + if is_literal_type(hint) and not preserve_literal: # assume that all literals are of the same type return type(get_args(hint)[0]) if is_newtype_type(hint) and not preserve_new_types: # descend into supertypes of NewType - return extract_inner_type(hint.__supertype__, preserve_new_types) + return extract_inner_type(hint.__supertype__, preserve_new_types, preserve_literal) return hint diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 5fbcd86d92..561ab4506c 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -52,6 +52,7 @@ add_config_dict_to_env, add_config_to_env, ) +from dlt.common.pipeline import TRefreshMode from tests.utils import preserve_environ from tests.common.configuration.utils import ( @@ -239,6 +240,11 @@ def resolve_dynamic_type_field(self) -> Type[Union[int, str]]: return str +@configspec +class ConfigWithLiteralField(BaseConfiguration): + refresh: TRefreshMode = None + + LongInteger = NewType("LongInteger", int) FirstOrderStr = NewType("FirstOrderStr", str) SecondOrderStr = NewType("SecondOrderStr", FirstOrderStr) @@ -1255,3 +1261,20 @@ class EmbeddedConfigurationWithDefaults(BaseConfiguration): c_resolved = resolve.resolve_configuration(c_instance) assert c_resolved.is_resolved() assert c_resolved.conn_str.is_resolved() + + +def test_configuration_with_literal_field(environment: Dict[str, str]) -> None: + """Literal type fields only allow values from the literal""" + environment["REFRESH"] = "not_a_refresh_mode" + + with pytest.raises(ConfigValueCannotBeCoercedException) as einfo: + resolve.resolve_configuration(ConfigWithLiteralField()) + + assert einfo.value.field_name == "refresh" + assert einfo.value.field_value == "not_a_refresh_mode" + assert einfo.value.hint == TRefreshMode + + environment["REFRESH"] = "drop_data" + + spec = resolve.resolve_configuration(ConfigWithLiteralField()) + assert spec.refresh == "drop_data"