-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rewrite with dataclasses and typehints
- Loading branch information
Showing
1 changed file
with
174 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,97 +1,184 @@ | ||
''' | ||
.. moduleauthor:: Arkadiusz Dzięgiel <arkadiusz.dziegiel@glorpen.pl> | ||
''' | ||
# from glorpen.config.fields import path_validation_error | ||
from collections import OrderedDict | ||
from glorpen.config import exceptions | ||
import abc | ||
import dataclasses | ||
import itertools | ||
import textwrap | ||
import types | ||
import typing | ||
|
||
class Config(object): | ||
|
||
class ConfigType(abc.ABC): | ||
def __init__(self, config): | ||
super(ConfigType, self).__init__() | ||
self.config = config | ||
|
||
@abc.abstractmethod | ||
def as_model(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str): | ||
pass | ||
|
||
|
||
ValueErrorItems = typing.Union[dict, typing.Sequence] | ||
|
||
|
||
class ConfigValueError(ValueError): | ||
def __init__(self, error): | ||
super(ConfigValueError, self).__init__(f"Found validation errors:\n{error}") | ||
|
||
|
||
class DictValueError(ValueError): | ||
def __init__(self, items: ValueErrorItems): | ||
# msg = "Invalid fields.\n" + self._format_row(items) | ||
msg = self._format_row(items) | ||
super(DictValueError, self).__init__(msg) | ||
|
||
def _format_row(self, items: ValueErrorItems): | ||
return textwrap.indent("\n".join(self._format_items(items)), "") | ||
|
||
def _format_items(self, items: ValueErrorItems): | ||
if hasattr(items, "keys"): | ||
key_max_len = max(len(str(i)) for i in items.keys()) | ||
item_sets = items.items() | ||
key_suffix = ": " | ||
else: | ||
key_max_len = 1 | ||
item_sets = [("-", v) for v in items] | ||
key_suffix = " " | ||
|
||
msg_offset = key_max_len + len(key_suffix) | ||
|
||
for k, e in item_sets: | ||
f_key = str(k).rjust(key_max_len) | ||
f_msg = textwrap.indent(str(e), " " * msg_offset)[msg_offset:] | ||
yield f"{f_key}{key_suffix}{f_msg}" | ||
|
||
|
||
class Config: | ||
"""Config validator and normalizer.""" | ||
|
||
def __init__(self, spec): | ||
|
||
_registered_types: typing.List[ConfigType] | ||
|
||
def __init__(self): | ||
super(Config, self).__init__() | ||
|
||
self.spec = spec | ||
|
||
def walk(self, normalized_tree, path=[]): | ||
if hasattr(normalized_tree, "values"): | ||
for k,v in normalized_tree.values.items(): | ||
lpath = tuple(list(path) + [k]) | ||
yield from self.walk(v, lpath) | ||
yield lpath, v | ||
|
||
def get(self, raw_value): | ||
if not self.spec.is_value_supported(raw_value): | ||
raise exceptions.ConfigException("value is not supported") | ||
|
||
self._registered_types = [] | ||
|
||
@classmethod | ||
def _handle_optional_values(cls, type, default_factory): | ||
if (type is types.NoneType) or ( | ||
typing.get_origin(type) is typing.Union and types.NoneType in typing.get_args(type)): | ||
return None | ||
if default_factory: | ||
return default_factory() | ||
|
||
raise ValueError("No value provided") | ||
|
||
def as_model(self, data: typing.Any, type, metadata=None, default_factory=None, path=""): | ||
if data is None: | ||
return self._handle_optional_values(type, default_factory) | ||
|
||
if dataclasses.is_dataclass(type): | ||
return self._from_dataclass(data, type, path=path) | ||
else: | ||
origin = typing.get_origin(type) | ||
if origin is None: | ||
return self._from_type(data=data, type=type, args=(), metadata=metadata, path=path) | ||
else: | ||
return self._from_type(data=data, type=origin, args=typing.get_args(type), metadata=metadata, path=path) | ||
|
||
def to_model(self, data, cls): | ||
try: | ||
normalized_value = self.spec.normalize(raw_value) | ||
except exceptions.ConfigException as e: | ||
raise exceptions.TraceableConfigException(e) | ||
|
||
index, required_deps_by_path = self._find_dependencies(normalized_value) | ||
self._resolve_dependencies(index, required_deps_by_path) | ||
|
||
packed_tree = self.spec.pack(normalized_value) | ||
self._validate(index, packed_tree) | ||
|
||
return packed_tree | ||
|
||
def _find_dependencies(self, normalized_value): | ||
index = {} | ||
required_deps_by_path = {} | ||
|
||
for path, i in self.walk(normalized_value): | ||
index[path] = i | ||
deps = i.field.get_dependencies(i) | ||
if deps: | ||
required_deps_by_path[path] = deps | ||
|
||
return index, required_deps_by_path | ||
|
||
def _resolve_dependencies(self, index, required_deps_by_path): | ||
resolved_paths = {} | ||
something_was_done = True | ||
|
||
# iterate over required deps until we cannot resolve anything | ||
# it should happen in two cases: | ||
# - all dependencies were resolved | ||
# - only circular deps are left | ||
while something_was_done: | ||
something_was_done = False | ||
for req_path, req_deps in required_deps_by_path.items(): | ||
if req_path in resolved_paths: | ||
continue | ||
|
||
unknown_deps = set(req_deps).intersection(set(required_deps_by_path).difference(resolved_paths)) | ||
if unknown_deps: | ||
# nested deps - skipping | ||
continue | ||
|
||
something_was_done = True | ||
|
||
values = [] | ||
|
||
for i in req_deps: | ||
values.append(resolved_paths[i] if i in resolved_paths else index[i].value) | ||
if req_path not in index: | ||
raise Exception("Path %r was not found in config" % (req_path,) ) | ||
resolved_paths[req_path] = index[req_path].field.interpolate(index[req_path], values) | ||
|
||
unsolvable_deps = set(required_deps_by_path).difference(resolved_paths) | ||
if unsolvable_deps: | ||
raise Exception("Paths could not be solved: %r", unsolvable_deps) | ||
|
||
def _validate(self, index, packed_tree): | ||
return self.as_model(data, cls) | ||
except ValueError as e: | ||
raise ConfigValueError(e) from None | ||
|
||
@classmethod | ||
def _get_default_factory(cls, field: dataclasses.Field): | ||
if field.default is not dataclasses.MISSING: | ||
return lambda: field.default | ||
elif field.default_factory is not dataclasses.MISSING: | ||
return field.default_factory | ||
else: | ||
return None | ||
|
||
def _from_dataclass(self, data: typing.Dict, cls, path: str): | ||
kwargs = {} | ||
errors = {} | ||
for path, f in index.items(): | ||
for field in dataclasses.fields(cls): | ||
try: | ||
f.field.validate(f.packed, packed_tree) | ||
except Exception as e: | ||
errors[path] = e | ||
kwargs[field.name] = self.as_model(data.get(field.name), field.type, metadata=field.metadata, | ||
default_factory=self._get_default_factory(field), path=f"{path}.{field.name}") | ||
except ValueError as e: | ||
errors[field.name] = e | ||
|
||
if errors: | ||
raise Exception(errors) | ||
|
||
def help(self): | ||
return self.spec.help_config | ||
raise DictValueError(errors) | ||
|
||
return cls(**kwargs) | ||
|
||
def _from_type(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str): | ||
for reg_type in self._registered_types: | ||
value = reg_type.as_model(data=data, type=type, args=args, metadata=metadata, path=path) | ||
if value is not None: | ||
return value | ||
|
||
raise ValueError(f"Could not convert to {type}") | ||
|
||
def register(self, type: ConfigType): | ||
self._registered_types.append(type(self)) | ||
|
||
|
||
class UnionType(ConfigType): | ||
def as_model(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str): | ||
if type is typing.Union: | ||
return self._try_each_type(data, args, metadata=metadata, path=path) | ||
|
||
def _try_each_type(self, data, types, path: str, metadata=None): | ||
errors = [] | ||
for tp in types: | ||
try: | ||
return self.config.as_model(data, tp, metadata=metadata, path=path) | ||
except ValueError as e: | ||
errors.append(e) | ||
|
||
raise DictValueError(errors) | ||
|
||
|
||
class SimpleTypes(ConfigType): | ||
@classmethod | ||
def _try_convert(cls, data, conv): | ||
try: | ||
return conv(data) | ||
except Exception as e: | ||
raise ValueError(e) | ||
|
||
def as_model(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str): | ||
if type in (int, str, bool, float): | ||
return self._try_convert(data, type) | ||
|
||
|
||
class SequenceTypes(ConfigType): | ||
def as_model(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str): | ||
if type is tuple: | ||
errors = {} | ||
ret = [] | ||
|
||
for index, (tp, value) in enumerate(itertools.zip_longest(args, data)): | ||
try: | ||
ret.append(self.config.as_model(value, tp, path=f"{path}.{index}")) | ||
except Exception as e: | ||
errors[index] = e | ||
|
||
if errors: | ||
raise DictValueError(errors) | ||
|
||
return tuple(ret) | ||
|
||
|
||
def default(): | ||
c = Config() | ||
c.register(UnionType) | ||
c.register(SimpleTypes) | ||
c.register(SequenceTypes) | ||
|
||
return c | ||
|