Skip to content

Commit

Permalink
rewrite with dataclasses and typehints
Browse files Browse the repository at this point in the history
  • Loading branch information
glorpen committed Apr 27, 2022
1 parent 31f7b65 commit 0d4dc60
Showing 1 changed file with 174 additions and 87 deletions.
261 changes: 174 additions & 87 deletions src/glorpen/config/config.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,184 @@
'''
.. moduleauthor:: Arkadiusz Dzięgiel <arkadiusz.dziegiel@glorpen.pl>
'''
# from glorpen.config.fields import path_validation_error
from collections import OrderedDict
from glorpen.config import exceptions
import abc
import dataclasses
import itertools
import textwrap
import types
import typing

class Config(object):

class ConfigType(abc.ABC):
def __init__(self, config):
super(ConfigType, self).__init__()
self.config = config

@abc.abstractmethod
def as_model(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str):
pass


ValueErrorItems = typing.Union[dict, typing.Sequence]


class ConfigValueError(ValueError):
def __init__(self, error):
super(ConfigValueError, self).__init__(f"Found validation errors:\n{error}")


class DictValueError(ValueError):
def __init__(self, items: ValueErrorItems):
# msg = "Invalid fields.\n" + self._format_row(items)
msg = self._format_row(items)
super(DictValueError, self).__init__(msg)

def _format_row(self, items: ValueErrorItems):
return textwrap.indent("\n".join(self._format_items(items)), "")

def _format_items(self, items: ValueErrorItems):
if hasattr(items, "keys"):
key_max_len = max(len(str(i)) for i in items.keys())
item_sets = items.items()
key_suffix = ": "
else:
key_max_len = 1
item_sets = [("-", v) for v in items]
key_suffix = " "

msg_offset = key_max_len + len(key_suffix)

for k, e in item_sets:
f_key = str(k).rjust(key_max_len)
f_msg = textwrap.indent(str(e), " " * msg_offset)[msg_offset:]
yield f"{f_key}{key_suffix}{f_msg}"


class Config:
"""Config validator and normalizer."""

def __init__(self, spec):

_registered_types: typing.List[ConfigType]

def __init__(self):
super(Config, self).__init__()

self.spec = spec

def walk(self, normalized_tree, path=[]):
if hasattr(normalized_tree, "values"):
for k,v in normalized_tree.values.items():
lpath = tuple(list(path) + [k])
yield from self.walk(v, lpath)
yield lpath, v

def get(self, raw_value):
if not self.spec.is_value_supported(raw_value):
raise exceptions.ConfigException("value is not supported")

self._registered_types = []

@classmethod
def _handle_optional_values(cls, type, default_factory):
if (type is types.NoneType) or (
typing.get_origin(type) is typing.Union and types.NoneType in typing.get_args(type)):
return None
if default_factory:
return default_factory()

raise ValueError("No value provided")

def as_model(self, data: typing.Any, type, metadata=None, default_factory=None, path=""):
if data is None:
return self._handle_optional_values(type, default_factory)

if dataclasses.is_dataclass(type):
return self._from_dataclass(data, type, path=path)
else:
origin = typing.get_origin(type)
if origin is None:
return self._from_type(data=data, type=type, args=(), metadata=metadata, path=path)
else:
return self._from_type(data=data, type=origin, args=typing.get_args(type), metadata=metadata, path=path)

def to_model(self, data, cls):
try:
normalized_value = self.spec.normalize(raw_value)
except exceptions.ConfigException as e:
raise exceptions.TraceableConfigException(e)

index, required_deps_by_path = self._find_dependencies(normalized_value)
self._resolve_dependencies(index, required_deps_by_path)

packed_tree = self.spec.pack(normalized_value)
self._validate(index, packed_tree)

return packed_tree

def _find_dependencies(self, normalized_value):
index = {}
required_deps_by_path = {}

for path, i in self.walk(normalized_value):
index[path] = i
deps = i.field.get_dependencies(i)
if deps:
required_deps_by_path[path] = deps

return index, required_deps_by_path

def _resolve_dependencies(self, index, required_deps_by_path):
resolved_paths = {}
something_was_done = True

# iterate over required deps until we cannot resolve anything
# it should happen in two cases:
# - all dependencies were resolved
# - only circular deps are left
while something_was_done:
something_was_done = False
for req_path, req_deps in required_deps_by_path.items():
if req_path in resolved_paths:
continue

unknown_deps = set(req_deps).intersection(set(required_deps_by_path).difference(resolved_paths))
if unknown_deps:
# nested deps - skipping
continue

something_was_done = True

values = []

for i in req_deps:
values.append(resolved_paths[i] if i in resolved_paths else index[i].value)
if req_path not in index:
raise Exception("Path %r was not found in config" % (req_path,) )
resolved_paths[req_path] = index[req_path].field.interpolate(index[req_path], values)

unsolvable_deps = set(required_deps_by_path).difference(resolved_paths)
if unsolvable_deps:
raise Exception("Paths could not be solved: %r", unsolvable_deps)

def _validate(self, index, packed_tree):
return self.as_model(data, cls)
except ValueError as e:
raise ConfigValueError(e) from None

@classmethod
def _get_default_factory(cls, field: dataclasses.Field):
if field.default is not dataclasses.MISSING:
return lambda: field.default
elif field.default_factory is not dataclasses.MISSING:
return field.default_factory
else:
return None

def _from_dataclass(self, data: typing.Dict, cls, path: str):
kwargs = {}
errors = {}
for path, f in index.items():
for field in dataclasses.fields(cls):
try:
f.field.validate(f.packed, packed_tree)
except Exception as e:
errors[path] = e
kwargs[field.name] = self.as_model(data.get(field.name), field.type, metadata=field.metadata,
default_factory=self._get_default_factory(field), path=f"{path}.{field.name}")
except ValueError as e:
errors[field.name] = e

if errors:
raise Exception(errors)

def help(self):
return self.spec.help_config
raise DictValueError(errors)

return cls(**kwargs)

def _from_type(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str):
for reg_type in self._registered_types:
value = reg_type.as_model(data=data, type=type, args=args, metadata=metadata, path=path)
if value is not None:
return value

raise ValueError(f"Could not convert to {type}")

def register(self, type: ConfigType):
self._registered_types.append(type(self))


class UnionType(ConfigType):
def as_model(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str):
if type is typing.Union:
return self._try_each_type(data, args, metadata=metadata, path=path)

def _try_each_type(self, data, types, path: str, metadata=None):
errors = []
for tp in types:
try:
return self.config.as_model(data, tp, metadata=metadata, path=path)
except ValueError as e:
errors.append(e)

raise DictValueError(errors)


class SimpleTypes(ConfigType):
@classmethod
def _try_convert(cls, data, conv):
try:
return conv(data)
except Exception as e:
raise ValueError(e)

def as_model(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str):
if type in (int, str, bool, float):
return self._try_convert(data, type)


class SequenceTypes(ConfigType):
def as_model(self, data: typing.Any, type, args: typing.Tuple, metadata: dict, path: str):
if type is tuple:
errors = {}
ret = []

for index, (tp, value) in enumerate(itertools.zip_longest(args, data)):
try:
ret.append(self.config.as_model(value, tp, path=f"{path}.{index}"))
except Exception as e:
errors[index] = e

if errors:
raise DictValueError(errors)

return tuple(ret)


def default():
c = Config()
c.register(UnionType)
c.register(SimpleTypes)
c.register(SequenceTypes)

return c

0 comments on commit 0d4dc60

Please sign in to comment.