Skip to content

Commit

Permalink
Merge 71f05ab into 55051de
Browse files Browse the repository at this point in the history
  • Loading branch information
konradhalas committed May 12, 2019
2 parents 55051de + 71f05ab commit 082af61
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 764 deletions.
98 changes: 3 additions & 95 deletions dacite/config.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,9 @@
from dataclasses import dataclass, field as dc_field, fields, Field
from typing import Dict, Any, Callable, List, Optional, Type

from dacite.data import Data
from dacite.dataclasses import has_field_default_value
from dacite.exceptions import InvalidConfigurationError
from dacite.types import cast_value


class ValueNotFoundError(Exception):
    """Internal signal that a field's key was absent from the input data.

    Caught by the ``from_dict`` machinery, which then falls back to the
    field's default value instead of failing outright.
    """
from dataclasses import dataclass, field as dc_field
from typing import Dict, Any, Callable, Optional, Type


@dataclass
class Config:
    """User-facing configuration for ``from_dict``.

    Mapping-style parameters (``remap``, ``prefixed``, ``transform``) and
    list-style parameters (``cast``, ``flattened``) may use dotted names
    ("outer.inner") to target fields of nested dataclasses; ``make_inner``
    strips one level of that prefix when recursing.
    """

    remap: Dict[str, str] = dc_field(default_factory=dict)
    prefixed: Dict[str, str] = dc_field(default_factory=dict)
    cast: List[str] = dc_field(default_factory=list)
    transform: Dict[str, Callable[[Any], Any]] = dc_field(default_factory=dict)
    flattened: List[str] = dc_field(default_factory=list)
    type_hooks: Dict[Type, Callable[[Any], Any]] = dc_field(default_factory=dict)
    forward_references: Optional[Dict[str, Any]] = None
    check_types: bool = True

    def validate(self, data_class: Type, data: Data) -> None:
        """Validate the configured parameters against data_class and the input data.

        Raises InvalidConfigurationError on the first parameter entry that
        names an unknown field or a missing data key.
        """
        self._validate_field_name(data_class, "remap")
        self._validate_data_key(data_class, data, "remap")
        self._validate_field_name(data_class, "prefixed")
        self._validate_data_key(
            data_class, data, "prefixed", lambda value, keys: any(key.startswith(value) for key in keys)
        )
        # The remaining parameters only reference field names, never data keys.
        for parameter in ("cast", "transform", "flattened"):
            self._validate_field_name(data_class, parameter)

    def make_inner(self, field: Field) -> "Config":
        """Build the Config for a nested dataclass field.

        Keeps only the entries scoped to ``field`` (dotted "<field>." prefix),
        with that prefix stripped; ``check_types`` is inherited unchanged.
        """
        nested_dicts = {
            name: self._extract_nested_dict(field, getattr(self, name))
            for name in ("remap", "prefixed", "transform")
        }
        nested_lists = {name: self._extract_nested_list(field, getattr(self, name)) for name in ("cast", "flattened")}
        return Config(check_types=self.check_types, **nested_dicts, **nested_lists)

    # pylint: disable=unsupported-membership-test,unsubscriptable-object,no-member
    def get_value(self, field: Field, data: Data) -> Any:
        """Extract the raw value for ``field`` from ``data``.

        Resolution order: flattened fields receive the whole ``data`` mapping;
        prefixed fields receive the sub-dict of keys sharing their prefix
        (prefix stripped); otherwise the (possibly remapped) key is looked up
        directly. Configured ``transform`` and ``cast`` steps are then applied.

        Raises ValueNotFoundError when the key is absent from ``data``.
        """
        name = field.name
        if name in self.flattened:
            value = data
        elif name in self.prefixed:
            value = self._extract_nested_dict_for_prefix(self.prefixed[name], data)
        else:
            try:
                value = data[self.remap.get(name, name)]
            except KeyError:
                raise ValueNotFoundError()
        if name in self.transform:
            value = self.transform[name](value)
        if name in self.cast:
            value = cast_value(field.type, value)
        return value

    def _validate_field_name(self, data_class: Type, parameter: str) -> None:
        """Ensure every top-level name configured under ``parameter`` is a real field of data_class."""
        known_names = {f.name for f in fields(data_class)}
        for configured_name in getattr(self, parameter):
            if "." in configured_name:
                # Dotted names target nested dataclasses and are validated there.
                continue
            if configured_name not in known_names:
                raise InvalidConfigurationError(
                    parameter=parameter, available_choices=known_names, value=configured_name
                )

    def _validate_data_key(self, data_class: Type, data: Data, parameter: str, validator=lambda v, c: v in c) -> None:
        """Ensure each configured data key occurs in ``data``, unless the target field has a default."""
        available_keys = set(data.keys())
        fields_by_name = {f.name: f for f in fields(data_class)}
        for field_name, data_key in getattr(self, parameter).items():
            if "." in field_name:
                continue
            field = fields_by_name[field_name]
            if validator(data_key, available_keys) or has_field_default_value(field):
                continue
            raise InvalidConfigurationError(parameter=parameter, available_choices=available_keys, value=data_key)

    def _extract_nested_dict(self, field: Field, params: Dict[str, Any]) -> Dict[str, Any]:
        """Select the entries of ``params`` scoped to ``field``, dropping the "<field>." prefix."""
        return self._extract_nested_dict_for_prefix(prefix=field.name + ".", data=params)

    def _extract_nested_list(self, field: Field, params: List[str]) -> List[str]:
        """Select the entries of ``params`` scoped to ``field``, dropping the "<field>." prefix."""
        prefix = field.name + "."
        return [name[len(prefix):] for name in params if name.startswith(prefix)]

    def _extract_nested_dict_for_prefix(self, prefix: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the sub-dict of ``data`` whose keys start with ``prefix``, prefix stripped."""
        return {key[len(prefix):]: value for key, value in data.items() if key.startswith(prefix)}
12 changes: 7 additions & 5 deletions dacite/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import fields, is_dataclass
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any

from dacite.config import Config, ValueNotFoundError
from dacite.config import Config
from dacite.data import Data
from dacite.dataclasses import get_default_value_for_field, create_instance, DefaultValueNotFoundError
from dacite.exceptions import (
Expand All @@ -20,6 +20,7 @@
is_union,
extract_generic,
is_optional,
transform_value,
)

T = TypeVar("T")
Expand All @@ -36,7 +37,6 @@ def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None)
init_values: Data = {}
post_init_values: Data = {}
config = config or Config()
config.validate(data_class, data)
try:
data_class_hints = get_type_hints(data_class, globalns=config.forward_references)
except NameError as error:
Expand All @@ -46,15 +46,17 @@ def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None)
field.type = data_class_hints[field.name]
try:
try:
value = _build_value(
type_=field.type, data=config.get_value(field, data), config=config.make_inner(field)
field_data = data[field.name]
transformed_value = transform_value(
type_hooks=config.type_hooks, target_type=field.type, value=field_data
)
value = _build_value(type_=field.type, data=transformed_value, config=config)
except DaciteFieldError as error:
error.update_path(field.name)
raise
if config.check_types and not is_instance(value, field.type):
raise WrongTypeError(field_path=field.name, field_type=field.type, value=value)
except ValueNotFoundError:
except KeyError:
try:
value = get_default_value_for_field(field)
except DefaultValueNotFoundError:
Expand Down
4 changes: 0 additions & 4 deletions dacite/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ class DefaultValueNotFoundError(Exception):
pass


def has_field_default_value(field: Field) -> bool:
    """Check whether a dataclass field can fall back to a default.

    True when the field declares an explicit default, a default factory,
    or an Optional type (which implicitly defaults to None).
    """
    if field.default != MISSING or field.default_factory != MISSING:  # type: ignore
        return True
    return is_optional(field.type)


def get_default_value_for_field(field: Field) -> Any:
if field.default != MISSING:
return field.default
Expand Down
16 changes: 1 addition & 15 deletions dacite/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Set, Type, Optional
from typing import Any, Type, Optional


def _name(type_: Type) -> str:
Expand Down Expand Up @@ -50,20 +50,6 @@ def __str__(self) -> str:
)


class InvalidConfigurationError(DaciteError):
    """Raised when a Config parameter references an unknown field or a missing data key.

    Carries the offending parameter name, the set of valid choices, and the
    value that failed validation.
    """

    def __init__(self, parameter: str, available_choices: Set[str], value: str) -> None:
        super().__init__()
        self.parameter = parameter
        self.value = value
        self.available_choices = available_choices

    def __str__(self):
        choices = ", ".join(self.available_choices)
        return f'invalid value in "{self.parameter}" configuration: "{self.value}". Choices are: {choices}'


class ForwardReferenceError(DaciteError):
def __init__(self, message: str) -> None:
super().__init__()
Expand Down
37 changes: 25 additions & 12 deletions dacite/types.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from typing import Type, Any, Optional, Union, Collection, TypeVar, cast
from typing import Type, Any, Optional, Union, Collection, TypeVar, Dict, Callable

T = TypeVar("T", bound=Any)


def cast_value(type_: Type[T], value: Any) -> T:
if is_optional(type_):
type_ = extract_optional(type_)
if is_generic_collection(type_):
collection_cls = extract_origin_collection(type_)
def transform_value(type_hooks: Dict[Type, Callable[[Any], Any]], target_type: Type, value: Any) -> Any:
    """Recursively convert ``value`` toward ``target_type`` using ``type_hooks``.

    A hook registered for ``target_type`` is applied first. Optional types are
    unwrapped (None short-circuits to None) and the inner type is transformed
    recursively, so a hook for the inner type also fires. Generic collections
    are rebuilt element-by-element (keys and values for dicts) with the same
    recursion, but only when ``value`` already is an instance of the
    collection's origin class. Anything else is returned unchanged.
    """
    if target_type in type_hooks:
        value = type_hooks[target_type](value)
    if is_optional(target_type):
        if value is None:
            return None
        target_type = extract_optional(target_type)
        # Recurse so a hook registered for the inner (non-Optional) type applies too.
        return transform_value(type_hooks, target_type, value)
    if is_generic_collection(target_type) and isinstance(value, extract_origin_collection(target_type)):
        collection_cls = extract_origin_collection(target_type)
        if issubclass(collection_cls, dict):
            key_cls, item_cls = extract_generic(target_type)
            return collection_cls(
                {
                    transform_value(type_hooks, key_cls, key): transform_value(type_hooks, item_cls, item)
                    for key, item in value.items()
                }
            )
        item_cls = extract_generic(target_type)[0]
        return collection_cls(transform_value(type_hooks, item_cls, item) for item in value)
    return value


def extract_origin_collection(collection: Type) -> Type:
def is_generic_collection(type_: Type) -> bool:
    """Check whether ``type_`` is a generic collection type (e.g. List[int], Dict[str, int])."""
    if not is_generic(type_):
        return False
    origin = extract_origin_collection(type_)
    # issubclass can raise on non-class origins produced by some typing
    # constructs / Python versions; treat any such failure as "not a collection".
    try:
        return bool(origin and issubclass(origin, Collection))
    except (TypeError, AttributeError):
        return False


def extract_generic(type_: Type) -> tuple:
Expand Down

0 comments on commit 082af61

Please sign in to comment.