Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/pages/data-binding/elements.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ choice because it works in predictable time since it doesn't require any look-ah

.. grid-item-card:: Model

.. literalinclude:: ../../../../examples/snippets/model_mode_strict.py
.. literalinclude:: ../../../../examples/snippets/lxml/model_mode_strict.py
:language: python
:start-after: model-start
:end-before: model-end
Expand All @@ -220,15 +220,15 @@ choice because it works in predictable time since it doesn't require any look-ah

.. tab-item:: XML

.. literalinclude:: ../../../../examples/snippets/model_mode_strict.py
.. literalinclude:: ../../../../examples/snippets/lxml/model_mode_strict.py
:language: xml
:lines: 2-
:start-after: xml-start
:end-before: xml-end

.. tab-item:: JSON

.. literalinclude:: ../../../../examples/snippets/model_mode_strict.py
.. literalinclude:: ../../../../examples/snippets/lxml/model_mode_strict.py
:language: json
:lines: 2-
:start-after: json-start
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ class Company(
error = e.errors()[0]
assert error == {
'loc': ('founded',),
'msg': 'Field required',
'msg': '[line 2]: Field required',
'ctx': {'orig': 'Field required', 'sourceline': 2},
'type': 'missing',
'input': ANY,
'url': ANY,
}
else:
raise AssertionError('exception not raised')
52 changes: 37 additions & 15 deletions pydantic_xml/element/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from pydantic_xml.typedefs import NsMap

PathElementT = TypeVar('PathElementT')
PathT = Tuple[PathElementT, ...]


class XmlElementReader(abc.ABC):
"""
Expand Down Expand Up @@ -90,7 +93,7 @@ def pop_element(self, tag: str, search_mode: 'SearchMode') -> Optional['XmlEleme
"""

@abc.abstractmethod
def find_sub_element(self, path: Sequence[str], search_mode: 'SearchMode') -> Optional['XmlElementReader']:
def find_sub_element(self, path: Sequence[str], search_mode: 'SearchMode') -> PathT['XmlElementReader']:
"""
Searches for an element at the provided path. If the element is not found returns `None`.

Expand Down Expand Up @@ -122,13 +125,21 @@ def to_native(self) -> Any:
"""

@abc.abstractmethod
def get_unbound(self) -> List[Tuple[Tuple[str, ...], str]]:
def get_unbound(self) -> List[Tuple[PathT['XmlElementReader'], Optional[str], str]]:
"""
Returns unbound entities.

:return: list of unbound entities
"""

@abc.abstractmethod
def get_sourceline(self) -> int:
"""
Returns source line of the element in the xml document.

:return: source line
"""


class XmlElementWriter(abc.ABC):
"""
Expand Down Expand Up @@ -265,6 +276,7 @@ def __init__(
attributes: Optional[Dict[str, str]] = None,
elements: Optional[List['XmlElement[NativeElement]']] = None,
nsmap: Optional[NsMap] = None,
sourceline: int = -1,
):
self._tag = tag
self._nsmap = nsmap
Expand All @@ -275,6 +287,11 @@ def __init__(
elements=elements or [],
next_element_idx=0,
)
self._sourceline = sourceline

@abc.abstractmethod
def get_sourceline(self) -> int:
return self._sourceline

@property
def tag(self) -> str:
Expand Down Expand Up @@ -345,15 +362,17 @@ def pop_element(self, tag: str, search_mode: 'SearchMode') -> Optional['XmlEleme

return searcher(self._state, tag, False, True)

def find_sub_element(self, path: Sequence[str], search_mode: 'SearchMode') -> Optional['XmlElement[NativeElement]']:
def find_sub_element(self, path: Sequence[str], search_mode: 'SearchMode') -> PathT['XmlElement[NativeElement]']:
assert len(path) > 0, "path can't be empty"

root, path = path[0], path[1:]
element = self.find_element(root, search_mode)
if element and path:
return element.find_sub_element(path, search_mode)

return element
root, *path = path
if (element := self.find_element(root, search_mode)) is not None:
if path:
return (element,) + element.find_sub_element(path, search_mode)
else:
return (element,)
else:
return ()

def find_element_or_create(
self,
Expand All @@ -379,21 +398,24 @@ def find_element(

return searcher(self._state, tag, look_behind, step_forward)

def get_unbound(self, path: Tuple[str, ...] = ()) -> List[Tuple[Tuple[str, ...], str]]:
result: List[Tuple[Tuple[str, ...], str]] = []
def get_unbound(
self,
path: PathT[XmlElementReader] = (),
) -> List[Tuple[PathT[XmlElementReader], Optional[str], str]]:
result: List[Tuple[PathT[XmlElementReader], Optional[str], str]] = []

if self._state.text and (text := self._state.text.strip()):
result.append((path, text))
result.append((path, None, text))

if self._state.tail and (tail := self._state.tail.strip()):
result.append((path, tail))
result.append((path, None, tail))

if attrs := self._state.attrib:
for name, value in attrs.items():
result.append((path + (f'@{name}',), value))
result.append((path, name, value))

for sub_element in self._state.elements:
result.extend(sub_element.get_unbound(path + (sub_element.tag,)))
result.extend(sub_element.get_unbound(path + (sub_element,)))

return result

Expand Down
5 changes: 5 additions & 0 deletions pydantic_xml/element/native/lxml.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from typing import Optional, Union

from lxml import etree
Expand Down Expand Up @@ -30,6 +31,7 @@ def from_native(cls, element: ElementT) -> 'XmlElement':
for sub_element in element
if not is_xml_comment(sub_element)
],
sourceline=typing.cast(int, element.sourceline) if element.sourceline is not None else -1,
)

def to_native(self) -> ElementT:
Expand All @@ -48,6 +50,9 @@ def to_native(self) -> ElementT:
def make_element(self, tag: str, nsmap: Optional[NsMap]) -> 'XmlElement':
return XmlElement(tag, nsmap=nsmap)

def get_sourceline(self) -> int:
return self._sourceline


def force_str(val: Union[str, bytes]) -> str:
if isinstance(val, bytes):
Expand Down
3 changes: 3 additions & 0 deletions pydantic_xml/element/native/std.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def to_native(self) -> ElementT:
def make_element(self, tag: str, nsmap: Optional[NsMap]) -> 'XmlElement':
return XmlElement(tag)

def get_sourceline(self) -> int:
return -1


def is_xml_comment(element: ElementT) -> bool:
return element.tag is etree.Comment # type: ignore[comparison-overlap]
9 changes: 8 additions & 1 deletion pydantic_xml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,14 @@ def from_xml_tree(cls: Type[ModelT], root: etree.Element, context: Optional[Dict
assert cls.__xml_serializer__ is not None, f"model {cls.__name__} is partially initialized"

if root.tag == cls.__xml_serializer__.element_name:
obj = typing.cast(ModelT, cls.__xml_serializer__.deserialize(XmlElement.from_native(root), context=context))
obj = typing.cast(
ModelT, cls.__xml_serializer__.deserialize(
XmlElement.from_native(root),
context=context,
sourcemap={},
loc=(),
),
)
return obj
else:
raise errors.ParsingError(
Expand Down
30 changes: 21 additions & 9 deletions pydantic_xml/serializers/factories/heterogeneous.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import pydantic as pd
from pydantic_core import core_schema as pcs

from pydantic_xml import errors
from pydantic_xml import errors, utils
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.serializer import TYPE_FAMILY, SchemaTypeFamily, Serializer
from pydantic_xml.typedefs import EntityLocation
from pydantic_xml.typedefs import EntityLocation, Location


class ElementSerializer(Serializer):
@classmethod
def from_core_schema(cls, schema: pcs.TuplePositionalSchema, ctx: Serializer.Context) -> 'ElementSerializer':
model_name = ctx.model_name
computed = ctx.field_computed
inner_serializers: List[Serializer] = []
for item_schema in schema['items_schema']:
inner_serializers.append(Serializer.parse_core_schema(item_schema, ctx))

return cls(computed, tuple(inner_serializers))
return cls(model_name, computed, tuple(inner_serializers))

def __init__(self, computed: bool, inner_serializers: Tuple[Serializer, ...]):
def __init__(self, model_name: str, computed: bool, inner_serializers: Tuple[Serializer, ...]):
self._model_name = model_name
self._computed = computed
self._inner_serializers = inner_serializers

Expand All @@ -44,17 +47,26 @@ def deserialize(
element: Optional[XmlElementReader],
*,
context: Optional[Dict[str, Any]],
sourcemap: Dict[Location, int],
loc: Location,
) -> Optional[List[Any]]:
if self._computed:
return None

if element is None:
return None

result = [
serializer.deserialize(element, context=context)
for serializer in self._inner_serializers
]
result: List[Any] = []
item_errors: Dict[Union[None, str, int], pd.ValidationError] = {}
for idx, serializer in enumerate(self._inner_serializers):
try:
result.append(serializer.deserialize(element, context=context, sourcemap=sourcemap, loc=loc + (idx,)))
except pd.ValidationError as err:
item_errors[idx] = err

if item_errors:
raise utils.build_validation_error(title=self._model_name, errors_map=item_errors)

if all((value is None for value in result)):
return None
else:
Expand Down
32 changes: 25 additions & 7 deletions pydantic_xml/serializers/factories/homogeneous.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import itertools as it
from typing import Any, Dict, List, Optional, Union

import pydantic as pd
from pydantic_core import core_schema as pcs

from pydantic_xml import errors
from pydantic_xml import errors, utils
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.serializer import TYPE_FAMILY, SchemaTypeFamily, Serializer
from pydantic_xml.typedefs import EntityLocation
from pydantic_xml.typedefs import EntityLocation, Location

HomogeneousCollectionTypeSchema = Union[
pcs.TupleVariableSchema,
Expand All @@ -18,12 +20,14 @@
class ElementSerializer(Serializer):
@classmethod
def from_core_schema(cls, schema: HomogeneousCollectionTypeSchema, ctx: Serializer.Context) -> 'ElementSerializer':
model_name = ctx.model_name
computed = ctx.field_computed
inner_serializer = Serializer.parse_core_schema(schema['items_schema'], ctx)

return cls(computed, inner_serializer)
return cls(model_name, computed, inner_serializer)

def __init__(self, computed: bool, inner_serializer: Serializer):
def __init__(self, model_name: str, computed: bool, inner_serializer: Serializer):
self._model_name = model_name
self._computed = computed
self._inner_serializer = inner_serializer

Expand All @@ -49,16 +53,30 @@ def deserialize(
element: Optional[XmlElementReader],
*,
context: Optional[Dict[str, Any]],
sourcemap: Dict[Location, int],
loc: Location,
) -> Optional[List[Any]]:
if self._computed:
return None

if element is None:
return None

result = []
while (value := self._inner_serializer.deserialize(element, context=context)) is not None:
result.append(value)
serializer = self._inner_serializer
result: List[Any] = []
item_errors: Dict[Union[None, str, int], pd.ValidationError] = {}
for idx in it.count():
try:
value = serializer.deserialize(element, context=context, sourcemap=sourcemap, loc=loc + (idx,))
if value is None:
break
except pd.ValidationError as err:
item_errors[idx] = err
else:
result.append(value)

if item_errors:
raise utils.build_validation_error(title=self._model_name, errors_map=item_errors)

return result or None

Expand Down
9 changes: 7 additions & 2 deletions pydantic_xml/serializers/factories/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic_xml import errors
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.serializer import TYPE_FAMILY, SchemaTypeFamily, SearchMode, Serializer
from pydantic_xml.typedefs import EntityLocation, NsMap
from pydantic_xml.typedefs import EntityLocation, Location, NsMap
from pydantic_xml.utils import QName, merge_nsmaps, select_ns


Expand Down Expand Up @@ -49,6 +49,8 @@ def deserialize(
element: Optional[XmlElementReader],
*,
context: Optional[Dict[str, Any]],
sourcemap: Dict[Location, int],
loc: Location,
) -> Optional[Dict[str, str]]:
if self._computed:
return None
Expand Down Expand Up @@ -115,12 +117,15 @@ def deserialize(
element: Optional[XmlElementReader],
*,
context: Optional[Dict[str, Any]],
sourcemap: Dict[Location, int],
loc: Location,
) -> Optional[Dict[str, str]]:
if self._computed:
return None

if element and (sub_element := element.pop_element(self._element_name, self._search_mode)) is not None:
return super().deserialize(sub_element, context=context)
sourcemap[loc] = sub_element.get_sourceline()
return super().deserialize(sub_element, context=context, sourcemap=sourcemap, loc=loc)
else:
return None

Expand Down
Loading