-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
57c6909
commit 043d9e0
Showing
5 changed files
with
571 additions
and
2 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,282 @@ | ||
""" | ||
Class that accepts an xmltodict-style dictionary, builds a generic Element | ||
hierarchy and can reproduce the dictionary. | ||
The class can be overridden to create classes dealing with specialized Elements. | ||
A Visitor allows for accessing the entire Element-hierarchy. | ||
""" | ||
|
||
import logging | ||
|
||
from dataclasses import dataclass, field, fields | ||
from typing import Optional, Union, List, Dict, ForwardRef, get_origin, get_args | ||
|
||
import random | ||
import string | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
@dataclass | ||
class Element(): | ||
text : Optional[str] = None | ||
children : List["Element"] = field(default_factory=list) | ||
_children : List["Element"] = field(init=False, repr=False) | ||
_tag : str = field(init=False, default="Element") | ||
_parent : Optional["Element"] = field(init=False, default=None) | ||
_attributes : Dict[str,str] = field(init=False, default_factory=dict) | ||
|
||
_catch_all_children = True | ||
_no_validation = [] | ||
|
||
def __post_init__(self): | ||
self._validate_fields() | ||
|
||
def _validate_fields(self): | ||
def _validate(name, instance, field_type): | ||
logger.debug(f"post init validating {name} : {field_type} = {instance}") | ||
if isinstance(field_type, ForwardRef): | ||
field_type = field_type._evaluate(globals(), locals()) | ||
if not isinstance(instance, field_type): | ||
raise TypeError(f"on {self} field `{name}` should be `{field_type}` not `{type(instance)}`") | ||
|
||
for fld in fields(self): | ||
# skip our own known private attributes and additionally defined ones | ||
if fld.name in ["_tag", "_parent", "_attributes"] + self._no_validation: | ||
continue | ||
# ensure default value | ||
if fld.name not in self.__dict__: | ||
setattr(self, fld.name, fld.default) | ||
else: | ||
base_type = get_origin(fld.type) | ||
type_args = get_args(fld.type) | ||
if base_type is list: | ||
_validate(fld.name, self.__dict__[fld.name], list) | ||
list_type = type_args[0] | ||
for index, item in enumerate(self.__dict__[fld.name]): | ||
_validate(f"{fld.name}[{index}]", item, list_type) | ||
elif base_type is Union: | ||
# Optiona[...] == Union[..., NoneType] | ||
if len(type_args) == 2 and type_args[1] is None.__class__: | ||
if self.__dict__[fld.name] is not None: | ||
_validate(fld.name, self.__dict__[fld.name], type_args[0]) | ||
else: | ||
logger.warning("type checking for Union is not (yet) implemented") | ||
elif base_type is None: | ||
_validate(fld.name, self.__dict__[fld.name], fld.type) | ||
else: | ||
logger.warning(f"type checking for {base_type} is not (yet) implemented") | ||
|
||
@property | ||
def children(self): # readonly tuple | ||
return tuple(self._children) + tuple(self.specialized_children) | ||
|
||
@children.setter | ||
def children(self, new_children): | ||
if type(new_children) is property: | ||
new_children = [] | ||
# TODO: validation | ||
self._children = new_children | ||
|
||
@property | ||
def specialized_children(self): | ||
return [ | ||
child | ||
for fld in self.specialized_children_fields | ||
for child in self.__dict__[fld.name] | ||
] | ||
|
||
@property | ||
def specializations(self): | ||
return { | ||
get_args(fld.type)[0] : self.__dict__[fld.name] | ||
for fld in self.specialized_children_fields | ||
} | ||
|
||
@property | ||
def specialized_children_fields(self): | ||
for fld in fields(self): | ||
if fld.metadata.get("child", False): | ||
yield fld | ||
|
||
def append(self, child): | ||
# find specialization | ||
for fld_type, fld in self.specializations.items(): | ||
if isinstance(child, fld_type): | ||
fld.append(child) | ||
return | ||
# default | ||
if self._catch_all_children: | ||
self._children.append(child) | ||
else: | ||
raise ValueError(f"{self} doesn't allow for children of type {type(child)}") | ||
|
||
def __setitem__(self, name, value): | ||
self._attributes[name] = value | ||
|
||
def __getitem__(self, name): | ||
try: | ||
return self._attributes[name] | ||
except KeyError: | ||
return None | ||
|
||
def __getattr__(self, name): | ||
return self[name] | ||
|
||
@property | ||
def root(self): | ||
if self._parent: | ||
return self._parent.root | ||
else: | ||
return self | ||
|
||
def find(self, key, value, skip=None, stack=None): | ||
if stack is None: | ||
stack = [] | ||
|
||
if self in stack: | ||
logger.warn("avoided recursion") | ||
return None | ||
|
||
if self is skip: | ||
return None | ||
|
||
# do I have the key=value attribute? | ||
try: | ||
if self._attributes[key] == value: | ||
return self | ||
except KeyError: | ||
pass | ||
|
||
# recurse down children | ||
for child in self.children: | ||
match = child.find(key, value, skip=skip, stack=stack+[self]) | ||
if match: | ||
return match | ||
|
||
return None | ||
|
||
# def append(self, child): | ||
# if not child: | ||
# raise ValueError(f"invalid child: {child}") | ||
# self.children.append(child) | ||
# child._parent = self | ||
# return self | ||
|
||
def extend(self, children): | ||
for child in children: | ||
self.append(child) | ||
return self | ||
|
||
@property | ||
def attributes(self): | ||
return self._attributes | ||
|
||
def children_oftype(self, cls, recurse=False): | ||
children = [] | ||
for child in self.children: | ||
if isinstance(child, cls) or isinstance(child.wrapped, cls): | ||
children.append(child) | ||
if recurse: | ||
children.extend(child.children_oftype(cls, recurse=True)) | ||
return children | ||
|
||
def as_dict(self, with_tag=False): | ||
# collect attributes | ||
definition = { | ||
f"@{key}" : value for key, value in self.attributes.items() if value is not None | ||
} | ||
|
||
# collect text | ||
if self.text: | ||
definition["#text"] = self.text | ||
|
||
# collect children | ||
for child in self.children: | ||
d = child.as_dict() | ||
try: | ||
definition[child._tag].append(d) | ||
except KeyError: | ||
definition[child._tag] = d | ||
except AttributeError: | ||
definition[child._tag] = [ definition[child._tag] , d ] | ||
|
||
# prune text-only tag | ||
if list(definition.keys()) == [ "#text" ]: | ||
definition = definition["#text"] | ||
|
||
# prune empty tag | ||
if definition == {}: | ||
definition = None | ||
|
||
if with_tag: | ||
return { self._tag : definition } | ||
else: | ||
return definition | ||
|
||
@classmethod | ||
def mapped_class(cls, tag, classes): | ||
if classes: | ||
for clazz in classes: | ||
try: | ||
if clazz._tag == tag: | ||
return clazz | ||
except AttributeError: | ||
pass | ||
return cls | ||
|
||
@classmethod | ||
def from_dict(cls, d, classes=None, depth=0, raise_unmapped=False): | ||
element_type, element_definition = list(d.items())[0] | ||
element_class = cls.mapped_class(element_type, classes) | ||
if element_class == cls and classes: | ||
if raise_unmapped: | ||
raise ValueError(f"unmapped element: {element_type}") | ||
else: | ||
logger.warning(f"unmapped element: {element_type}") | ||
element = element_class() | ||
element._tag = element_type | ||
|
||
if isinstance(element_definition, str): | ||
element_definition = { "#text" : element_definition } | ||
|
||
for key, defintions in element_definition.items(): | ||
if key[0] == "@": | ||
element._attributes[key[1:]] = defintions | ||
elif key == "#text": | ||
element.text = defintions | ||
else: | ||
if not isinstance(defintions, list): | ||
defintions = [ defintions ] | ||
for definition in defintions: | ||
if definition is None: | ||
definition = {} | ||
elif isinstance(definition, str): | ||
definition = { "#text" : definition } | ||
child = cls.from_dict( | ||
{ key : definition }, classes=classes, depth=depth+1, | ||
raise_unmapped=raise_unmapped | ||
) | ||
assert child is not element | ||
element.append(child) | ||
|
||
return element | ||
|
||
def accept(self, visitor): | ||
with visitor: | ||
visitor.visit(self) | ||
for child in self.children: | ||
try: | ||
child.accept(visitor) | ||
except TypeError: | ||
raise ValueError(f"accept() on {child} is missing argument") | ||
|
||
class IdentifiedElement(Element): | ||
def __init__(self, id=None, **kwargs): | ||
super().__init__(**kwargs) | ||
if id is None: | ||
random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) | ||
id = f"{self.__class__.__name__.lower()}_{random_str}" | ||
self["id"] = id |
Oops, something went wrong.