Skip to content

Commit

Permalink
started redesign of xml module
Browse files Browse the repository at this point in the history
  • Loading branch information
christophevg committed Jan 3, 2024
1 parent 57c6909 commit 043d9e0
Show file tree
Hide file tree
Showing 5 changed files with 571 additions and 2 deletions.
Empty file added bpmn_tools/future/__init__.py
Empty file.
282 changes: 282 additions & 0 deletions bpmn_tools/future/xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
"""
Class that accepts an xmltodict-style dictionary, builds a generic Element
hierarchy and can reproduce the dictionary.
The class can be overridden to create classes dealing with specialized Elements.
A Visitor allows for accessing the entire Element-hierarchy.
"""

import logging

from dataclasses import dataclass, field, fields
from typing import Optional, Union, List, Dict, ForwardRef, get_origin, get_args

import random
import string

logger = logging.getLogger(__name__)

@dataclass
class Element():
text : Optional[str] = None
children : List["Element"] = field(default_factory=list)
_children : List["Element"] = field(init=False, repr=False)
_tag : str = field(init=False, default="Element")
_parent : Optional["Element"] = field(init=False, default=None)
_attributes : Dict[str,str] = field(init=False, default_factory=dict)

_catch_all_children = True
_no_validation = []

def __post_init__(self):
self._validate_fields()

def _validate_fields(self):
def _validate(name, instance, field_type):
logger.debug(f"post init validating {name} : {field_type} = {instance}")
if isinstance(field_type, ForwardRef):
field_type = field_type._evaluate(globals(), locals())
if not isinstance(instance, field_type):
raise TypeError(f"on {self} field `{name}` should be `{field_type}` not `{type(instance)}`")

for fld in fields(self):
# skip our own known private attributes and additionally defined ones
if fld.name in ["_tag", "_parent", "_attributes"] + self._no_validation:
continue
# ensure default value
if fld.name not in self.__dict__:
setattr(self, fld.name, fld.default)
else:
base_type = get_origin(fld.type)
type_args = get_args(fld.type)
if base_type is list:
_validate(fld.name, self.__dict__[fld.name], list)
list_type = type_args[0]
for index, item in enumerate(self.__dict__[fld.name]):
_validate(f"{fld.name}[{index}]", item, list_type)
elif base_type is Union:
# Optiona[...] == Union[..., NoneType]
if len(type_args) == 2 and type_args[1] is None.__class__:
if self.__dict__[fld.name] is not None:
_validate(fld.name, self.__dict__[fld.name], type_args[0])
else:
logger.warning("type checking for Union is not (yet) implemented")
elif base_type is None:
_validate(fld.name, self.__dict__[fld.name], fld.type)
else:
logger.warning(f"type checking for {base_type} is not (yet) implemented")

@property
def children(self): # readonly tuple
return tuple(self._children) + tuple(self.specialized_children)

@children.setter
def children(self, new_children):
if type(new_children) is property:
new_children = []
# TODO: validation
self._children = new_children

@property
def specialized_children(self):
return [
child
for fld in self.specialized_children_fields
for child in self.__dict__[fld.name]
]

@property
def specializations(self):
return {
get_args(fld.type)[0] : self.__dict__[fld.name]
for fld in self.specialized_children_fields
}

@property
def specialized_children_fields(self):
for fld in fields(self):
if fld.metadata.get("child", False):
yield fld

def append(self, child):
# find specialization
for fld_type, fld in self.specializations.items():
if isinstance(child, fld_type):
fld.append(child)
return
# default
if self._catch_all_children:
self._children.append(child)
else:
raise ValueError(f"{self} doesn't allow for children of type {type(child)}")

def __setitem__(self, name, value):
self._attributes[name] = value

def __getitem__(self, name):
try:
return self._attributes[name]
except KeyError:
return None

def __getattr__(self, name):
return self[name]

@property
def root(self):
if self._parent:
return self._parent.root
else:
return self

def find(self, key, value, skip=None, stack=None):
if stack is None:
stack = []

if self in stack:
logger.warn("avoided recursion")
return None

if self is skip:
return None

# do I have the key=value attribute?
try:
if self._attributes[key] == value:
return self
except KeyError:
pass

# recurse down children
for child in self.children:
match = child.find(key, value, skip=skip, stack=stack+[self])
if match:
return match

return None

# def append(self, child):
# if not child:
# raise ValueError(f"invalid child: {child}")
# self.children.append(child)
# child._parent = self
# return self

def extend(self, children):
for child in children:
self.append(child)
return self

@property
def attributes(self):
return self._attributes

def children_oftype(self, cls, recurse=False):
children = []
for child in self.children:
if isinstance(child, cls) or isinstance(child.wrapped, cls):
children.append(child)
if recurse:
children.extend(child.children_oftype(cls, recurse=True))
return children

def as_dict(self, with_tag=False):
# collect attributes
definition = {
f"@{key}" : value for key, value in self.attributes.items() if value is not None
}

# collect text
if self.text:
definition["#text"] = self.text

# collect children
for child in self.children:
d = child.as_dict()
try:
definition[child._tag].append(d)
except KeyError:
definition[child._tag] = d
except AttributeError:
definition[child._tag] = [ definition[child._tag] , d ]

# prune text-only tag
if list(definition.keys()) == [ "#text" ]:
definition = definition["#text"]

# prune empty tag
if definition == {}:
definition = None

if with_tag:
return { self._tag : definition }
else:
return definition

@classmethod
def mapped_class(cls, tag, classes):
if classes:
for clazz in classes:
try:
if clazz._tag == tag:
return clazz
except AttributeError:
pass
return cls

@classmethod
def from_dict(cls, d, classes=None, depth=0, raise_unmapped=False):
element_type, element_definition = list(d.items())[0]
element_class = cls.mapped_class(element_type, classes)
if element_class == cls and classes:
if raise_unmapped:
raise ValueError(f"unmapped element: {element_type}")
else:
logger.warning(f"unmapped element: {element_type}")
element = element_class()
element._tag = element_type

if isinstance(element_definition, str):
element_definition = { "#text" : element_definition }

for key, defintions in element_definition.items():
if key[0] == "@":
element._attributes[key[1:]] = defintions
elif key == "#text":
element.text = defintions
else:
if not isinstance(defintions, list):
defintions = [ defintions ]
for definition in defintions:
if definition is None:
definition = {}
elif isinstance(definition, str):
definition = { "#text" : definition }
child = cls.from_dict(
{ key : definition }, classes=classes, depth=depth+1,
raise_unmapped=raise_unmapped
)
assert child is not element
element.append(child)

return element

def accept(self, visitor):
with visitor:
visitor.visit(self)
for child in self.children:
try:
child.accept(visitor)
except TypeError:
raise ValueError(f"accept() on {child} is missing argument")

class IdentifiedElement(Element):
def __init__(self, id=None, **kwargs):
super().__init__(**kwargs)
if id is None:
random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
id = f"{self.__class__.__name__.lower()}_{random_str}"
self["id"] = id
Loading

0 comments on commit 043d9e0

Please sign in to comment.