Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions pydantic_xml/element/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ class XmlElementReader(abc.ABC):
Provides an interface for extracting element text, attributes and sub-elements.
"""

@property
@abc.abstractmethod
def tag(self) -> str:
"""
Xml element tag.
"""

@abc.abstractmethod
def is_empty(self) -> bool:
"""
Expand Down Expand Up @@ -45,6 +52,14 @@ def find_element(
:return: xml element
"""

@abc.abstractmethod
def get_text(self) -> Optional[str]:
"""
Returns the element text.

:return: element text
"""

@abc.abstractmethod
def pop_text(self) -> Optional[str]:
"""
Expand All @@ -63,6 +78,14 @@ def pop_attrib(self, name: str) -> Optional[str]:
:return: element attribute
"""

@abc.abstractmethod
def get_attributes(self) -> Optional[Dict[str, str]]:
"""
Returns the element attributes.

:return: element attributes
"""

@abc.abstractmethod
def pop_attributes(self) -> Optional[Dict[str, str]]:
"""
Expand Down Expand Up @@ -92,6 +115,14 @@ def find_sub_element(self, path: Sequence[str], search_mode: 'SearchMode') -> Op
:return: found element or `None`
"""

@abc.abstractmethod
def get_elements(self) -> Optional[List['XmlElement[Any]']]:
"""
Returns the element sub-elements.

:return: sub-element
"""

@abc.abstractmethod
def create_snapshot(self) -> 'XmlElement[Any]':
"""
Expand Down Expand Up @@ -306,6 +337,9 @@ def append_element(self, element: 'XmlElement[NativeElement]') -> None:
def get_attrib(self, name: str) -> Optional[str]:
return self._state.attrib.get(name, None) if self._state.attrib else None

def get_text(self) -> Optional[str]:
return self._state.text

def pop_text(self) -> Optional[str]:
result, self._state.text = self._state.text, None

Expand All @@ -314,6 +348,9 @@ def pop_text(self) -> Optional[str]:
def pop_attrib(self, name: str) -> Optional[str]:
return self._state.attrib.pop(name, None) if self._state.attrib else None

def get_attributes(self) -> Optional[Dict[str, str]]:
return self._state.attrib

def pop_attributes(self) -> Optional[Dict[str, str]]:
result, self._state.attrib = self._state.attrib, None

Expand Down Expand Up @@ -358,6 +395,9 @@ def find_element(

return searcher(self._state, tag, look_behind, step_forward)

def get_elements(self) -> Optional[List['XmlElement[NativeElement]']]:
return self._state.elements[self._state.next_element_idx:]


class SearchMode(str, Enum):
"""
Expand Down
45 changes: 44 additions & 1 deletion pydantic_xml/serializers/factories/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import abc
import typing
from typing import Any, Dict, Mapping, Optional, Set, Type
from typing import Any, Dict, List, Mapping, Optional, Set, Type

import pydantic as pd
import pydantic_core as pdc
from pydantic_core import core_schema as pcs

import pydantic_xml as pxml
Expand All @@ -25,6 +27,41 @@ def element_name(self) -> str: ...
@abc.abstractmethod
def nsmap(self) -> Optional[NsMap]: ...

@classmethod
def _check_extra(cls, error_title: str, element: XmlElementReader) -> None:
line_errors: List[pdc.InitErrorDetails] = []

if (text := element.get_text()) is not None:
if text := text.strip():
line_errors.append(
pdc.InitErrorDetails(
type='extra_forbidden',
loc=('<text>',),
input=text,
),
)
if extra_attrs := element.get_attributes():
for name, value in extra_attrs.items():
line_errors.append(
pdc.InitErrorDetails(
type='extra_forbidden',
loc=(f'<attr> {name}',),
input=value,
),
)
if extra_elements := element.get_elements():
for extra_element in extra_elements:
line_errors.append(
pdc.InitErrorDetails(
type='extra_forbidden',
loc=(f'<element> {extra_element.tag}',),
input=extra_element.get_text(),
),
)

if line_errors:
raise pd.ValidationError.from_exception_data(title=error_title, line_errors=line_errors)


class ModelSerializer(BaseModelSerializer):
@classmethod
Expand Down Expand Up @@ -157,6 +194,9 @@ def deserialize(
if (field_value := field_serializer.deserialize(element, context=context)) is not None
}

if self._model.model_config.get('extra', 'ignore') == 'forbid':
self._check_extra(self._model.__name__, element)

return self._model.model_validate(result, strict=False, context=context)


Expand Down Expand Up @@ -239,6 +279,9 @@ def deserialize(

result = self._root_serializer.deserialize(element, context=context)

if self._model.model_config.get('extra', 'ignore') == 'forbid':
self._check_extra(self._model.__name__, element)

return self._model.model_validate(result, strict=False, context=context)


Expand Down
49 changes: 49 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, List, Optional, Tuple, Union
from unittest.mock import ANY

import pydantic as pd
import pytest
Expand Down Expand Up @@ -253,3 +254,51 @@ def validate_field(cls, v: str, info: pd.FieldValidationInfo):
'''

TestModel.from_xml(xml, validation_context)


@pytest.mark.parametrize('search_mode', ['strict', 'ordered', 'unordered'])
def test_extra_forbid(search_mode: str):
class Model(BaseXmlModel, tag='model', extra='forbid', search_mode=search_mode):
attr1: str = attr()
field1: str = element()
field2: str = wrapped('wrapper', element())

xml = '''
<model attr1="attr value 1" attr2="attr value 2">text value
<field1>field value 1</field1>
<wrapper>
<field2>field value 2</field2>
</wrapper>
<field3>field value 3</field3>
</model>
'''

with pytest.raises(pd.ValidationError) as exc:
Model.from_xml(xml)

err = exc.value
assert err.title == 'Model'
assert err.error_count() == 3
assert err.errors() == [
{
'input': 'text value',
'loc': ('<text>',),
'msg': 'Extra inputs are not permitted',
'type': 'extra_forbidden',
'url': ANY,
},
{
'input': 'attr value 2',
'loc': ('<attr> attr2',),
'msg': 'Extra inputs are not permitted',
'type': 'extra_forbidden',
'url': ANY,
},
{
'input': 'field value 3',
'loc': ('<element> field3',),
'msg': 'Extra inputs are not permitted',
'type': 'extra_forbidden',
'url': ANY,
},
]