From 1b0a6d193f82dc5b707e5a332927e45c3e446fe9 Mon Sep 17 00:00:00 2001 From: Dmitry Pershin Date: Sun, 3 Dec 2023 23:02:17 +0500 Subject: [PATCH] nillable element support added. --- docs/source/pages/misc.rst | 23 +++++--- examples/snippets/serialization-nillable.py | 20 +++++++ pydantic_xml/element/__init__.py | 1 + pydantic_xml/element/utils.py | 14 +++++ pydantic_xml/model.py | 24 +++++++-- pydantic_xml/serializers/factories/model.py | 18 +++++-- .../serializers/factories/primitive.py | 29 ++++++++--- pydantic_xml/serializers/serializer.py | 5 ++ tests/test_computed_fields.py | 46 ++++++++++++++++ tests/test_heterogeneous_collections.py | 21 ++++++++ tests/test_primitives.py | 52 ++++++++++++++++++- tests/test_submodels.py | 38 ++++++++++++++ 12 files changed, 269 insertions(+), 22 deletions(-) create mode 100644 examples/snippets/serialization-nillable.py create mode 100644 pydantic_xml/element/utils.py diff --git a/docs/source/pages/misc.rst b/docs/source/pages/misc.rst index d94882f..7251d7f 100644 --- a/docs/source/pages/misc.rst +++ b/docs/source/pages/misc.rst @@ -42,30 +42,39 @@ The following example illustrate how to encode :py:class:`bytes` typed fields as :language: xml -None type encoding -__________________ +Optional type encoding +~~~~~~~~~~~~~~~~~~~~~~ Since xml format doesn't support ``null`` type natively it is not obvious how to encode ``None`` fields -(ignore it, encode it as an empty string or mark it as ``xsi:nil``). The library encodes ``None`` typed fields -as empty strings by default but you can define your own encoding format: +(ignore it, encode it as an empty string or mark it as ``xsi:nil``). +The library encodes ``None`` values as empty strings by default. +There are some alternative ways: + +- Define your own encoding format for ``None`` values: .. literalinclude:: ../../../examples/snippets/py3.9/serialization.py :language: python -or drop ``None`` fields at all: +- Mark an empty elements as `nillable `_: + +.. literalinclude:: ../../../examples/snippets/serialization-nillable.py + :language: python + + +- Drop empty elements: .. code-block:: python from typing import Optional from pydantic_xml import BaseXmlModel, element - class Company(BaseXmlModel): + class Company(BaseXmlModel, skip_empty=True): title: Optional[str] = element(default=None) company = Company() - assert company.to_xml(skip_empty=True) == b'' + assert company.to_xml() == b'' Empty entities exclusion diff --git a/examples/snippets/serialization-nillable.py b/examples/snippets/serialization-nillable.py new file mode 100644 index 0000000..1356b26 --- /dev/null +++ b/examples/snippets/serialization-nillable.py @@ -0,0 +1,20 @@ +from typing import Optional +from xml.etree.ElementTree import canonicalize + +from pydantic_xml import BaseXmlModel, element + + +class Company(BaseXmlModel): + title: Optional[str] = element(default=None, nillable=True) + + +xml_doc = ''' + + +</Company> +''' + +company = Company.from_xml(xml_doc) + +assert company.title is None +assert canonicalize(company.to_xml(), strip_text=True) == canonicalize(xml_doc, strip_text=True) diff --git a/pydantic_xml/element/__init__.py b/pydantic_xml/element/__init__.py index a700c42..6c8bd7d 100644 --- a/pydantic_xml/element/__init__.py +++ b/pydantic_xml/element/__init__.py @@ -1 +1,2 @@ from .element import SearchMode, XmlElement, XmlElementReader, XmlElementWriter +from .utils import is_element_nill, make_element_nill diff --git a/pydantic_xml/element/utils.py b/pydantic_xml/element/utils.py new file mode 100644 index 0000000..7d49827 --- /dev/null +++ b/pydantic_xml/element/utils.py @@ -0,0 +1,14 @@ +from .element import XmlElementReader, XmlElementWriter + +XSI_NS = 'http://www.w3.org/2001/XMLSchema-instance' + + +def is_element_nill(element: XmlElementReader) -> bool: + if (is_nil := element.pop_attrib('{%s}nil' % XSI_NS)) and is_nil == 'true': + return True + else: + return False + + +def make_element_nill(element: XmlElementWriter) -> None: + element.set_attribute('{%s}nil' % XSI_NS, 'true') diff --git a/pydantic_xml/model.py b/pydantic_xml/model.py index fd9408e..4bd3fb6 100644 --- a/pydantic_xml/model.py +++ b/pydantic_xml/model.py @@ -32,12 +32,13 @@ class ComputedXmlEntityInfo(pd.fields.ComputedFieldInfo): Computed field xml meta-information. """ - __slots__ = ('location', 'path', 'ns', 'nsmap', 'wrapped') + __slots__ = ('location', 'path', 'ns', 'nsmap', 'nillable', 'wrapped') location: Optional[EntityLocation] path: Optional[str] ns: Optional[str] nsmap: Optional[NsMap] + nillable: bool wrapped: Optional[XmlEntityInfoP] # to be compliant with XmlEntityInfoP protocol def __post_init__(self) -> None: @@ -57,6 +58,7 @@ def decorator(prop: Any) -> Any: path = kwargs.pop('path', None) ns = kwargs.pop('ns', None) nsmap = kwargs.pop('nsmap', None) + nillable = kwargs.pop('nillable', False) descriptor_proxy = pd.computed_field(**kwargs)(prop) descriptor_proxy.decorator_info = ComputedXmlEntityInfo( @@ -64,6 +66,7 @@ def decorator(prop: Any) -> Any: path=path, ns=ns, nsmap=nsmap, + nillable=nillable, wrapped=None, **dc.asdict(descriptor_proxy.decorator_info), ) @@ -101,6 +104,7 @@ def computed_element( tag: Optional[str] = None, ns: Optional[str] = None, nsmap: Optional[NsMap] = None, + nillable: bool = False, **kwargs: Any, ) -> Union[PropertyT, Callable[[PropertyT], PropertyT]]: """ @@ -110,10 +114,11 @@ def computed_element( :param tag: element tag :param ns: element xml namespace :param nsmap: element xml namespace map + :param nillable: is element nillable. See https://www.w3.org/TR/xmlschema-1/#xsi_nil. :param kwargs: pydantic computed field arguments. See :py:class:`pydantic.computed_field` """ - return computed_entity(EntityLocation.ELEMENT, prop, path=tag, ns=ns, nsmap=nsmap, **kwargs) + return computed_entity(EntityLocation.ELEMENT, prop, path=tag, ns=ns, nsmap=nsmap, nillable=nillable, **kwargs) class XmlEntityInfo(pd.fields.FieldInfo): @@ -121,7 +126,7 @@ class XmlEntityInfo(pd.fields.FieldInfo): Field xml meta-information. """ - __slots__ = ('location', 'path', 'ns', 'nsmap', 'wrapped') + __slots__ = ('location', 'path', 'ns', 'nsmap', 'nillable', 'wrapped') def __init__( self, @@ -130,6 +135,7 @@ def __init__( path: Optional[str] = None, ns: Optional[str] = None, nsmap: Optional[NsMap] = None, + nillable: bool = False, wrapped: Optional[pd.fields.FieldInfo] = None, **kwargs: Any, ): @@ -149,6 +155,7 @@ def __init__( self.path = path self.ns = ns self.nsmap = nsmap + self.nillable = nillable self.wrapped: Optional[XmlEntityInfoP] = wrapped if isinstance(wrapped, XmlEntityInfo) else None if config.REGISTER_NS_PREFIXES and nsmap: @@ -167,17 +174,24 @@ def attr(name: Optional[str] = None, ns: Optional[str] = None, **kwargs: Any) -> return XmlEntityInfo(EntityLocation.ATTRIBUTE, path=name, ns=ns, **kwargs) -def element(tag: Optional[str] = None, ns: Optional[str] = None, nsmap: Optional[NsMap] = None, **kwargs: Any) -> Any: +def element( + tag: Optional[str] = None, + ns: Optional[str] = None, + nsmap: Optional[NsMap] = None, + nillable: bool = False, + **kwargs: Any, +) -> Any: """ Marks a pydantic field as an xml element. :param tag: element tag :param ns: element xml namespace :param nsmap: element xml namespace map + :param nillable: is element nillable. See https://www.w3.org/TR/xmlschema-1/#xsi_nil. :param kwargs: pydantic field arguments. See :py:class:`pydantic.Field` """ - return XmlEntityInfo(EntityLocation.ELEMENT, path=tag, ns=ns, nsmap=nsmap, **kwargs) + return XmlEntityInfo(EntityLocation.ELEMENT, path=tag, ns=ns, nsmap=nsmap, nillable=nillable, **kwargs) def wrapped( diff --git a/pydantic_xml/serializers/factories/model.py b/pydantic_xml/serializers/factories/model.py index a2383f2..fc8d1e9 100644 --- a/pydantic_xml/serializers/factories/model.py +++ b/pydantic_xml/serializers/factories/model.py @@ -8,7 +8,7 @@ import pydantic_xml as pxml from pydantic_xml import errors -from pydantic_xml.element import XmlElementReader, XmlElementWriter +from pydantic_xml.element import XmlElementReader, XmlElementWriter, is_element_nill, make_element_nill from pydantic_xml.serializers.serializer import SearchMode, Serializer, XmlEntityInfoP from pydantic_xml.typedefs import EntityLocation, NsMap from pydantic_xml.utils import QName, merge_nsmaps, select_ns @@ -287,8 +287,9 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> ' nsmap = merge_nsmaps(ctx.entity_nsmap, model_cls.__xml_nsmap__, ctx.parent_nsmap) search_mode = ctx.search_mode computed = ctx.field_computed + nillable = ctx.nillable - return cls(model_cls, name, ns, nsmap, search_mode, computed) + return cls(model_cls, name, ns, nsmap, search_mode, computed, nillable) def __init__( self, @@ -298,12 +299,14 @@ def __init__( nsmap: Optional[NsMap], search_mode: SearchMode, computed: bool, + nillable: bool, ): self._model = model self._element_name = QName.from_alias(tag=name, ns=ns, nsmap=nsmap).uri self._nsmap = nsmap self._search_mode = search_mode self._computed = computed + self._nillable = nillable @property def model(self) -> Type['pxml.BaseXmlModel']: @@ -331,6 +334,12 @@ def serialize( ) -> Optional[XmlElementWriter]: assert self._model.__xml_serializer__ is not None, f"model {self._model.__name__} is partially initialized" + if self._nillable and value is None: + sub_element = element.make_element(self._element_name, nsmap=self._nsmap) + make_element_nill(sub_element) + element.append_element(sub_element) + return sub_element + if value is None: return None @@ -355,7 +364,10 @@ def deserialize( if element is not None and \ (sub_element := element.pop_element(self._element_name, self._search_mode)) is not None: - return self._model.__xml_serializer__.deserialize(sub_element, context=context) + if is_element_nill(sub_element): + return None + else: + return self._model.__xml_serializer__.deserialize(sub_element, context=context) else: return None diff --git a/pydantic_xml/serializers/factories/primitive.py b/pydantic_xml/serializers/factories/primitive.py index d137e3f..addfcff 100644 --- a/pydantic_xml/serializers/factories/primitive.py +++ b/pydantic_xml/serializers/factories/primitive.py @@ -3,7 +3,7 @@ from pydantic_core import core_schema as pcs from pydantic_xml import errors -from pydantic_xml.element import XmlElementReader, XmlElementWriter +from pydantic_xml.element import XmlElementReader, XmlElementWriter, is_element_nill, make_element_nill from pydantic_xml.serializers.serializer import SearchMode, Serializer, encode_primitive from pydantic_xml.typedefs import EntityLocation, NsMap from pydantic_xml.utils import QName, merge_nsmaps, select_ns @@ -32,11 +32,13 @@ class TextSerializer(Serializer): @classmethod def from_core_schema(cls, schema: PrimitiveTypeSchema, ctx: Serializer.Context) -> 'TextSerializer': computed = ctx.field_computed + nillable = ctx.nillable - return cls(computed) + return cls(computed, nillable) - def __init__(self, computed: bool): + def __init__(self, computed: bool, nillable: bool): self._computed = computed + self._nillable = nillable def serialize( self, element: XmlElementWriter, value: Any, encoded: Any, *, skip_empty: bool = False, @@ -44,6 +46,9 @@ def serialize( if value is None and skip_empty: return element + if self._nillable and value is None: + make_element_nill(element) + element.set_text(encode_primitive(encoded)) return element @@ -59,6 +64,9 @@ def deserialize( if element is None: return None + if self._nillable and is_element_nill(element): + return None + return element.pop_text() or None @@ -123,14 +131,23 @@ def from_core_schema(cls, schema: PrimitiveTypeSchema, ctx: Serializer.Context) nsmap = merge_nsmaps(ctx.entity_nsmap, ctx.parent_nsmap) search_mode = ctx.search_mode computed = ctx.field_computed + nillable = ctx.nillable if name is None: raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "entity name is not provided") - return cls(name, ns, nsmap, search_mode, computed) + return cls(name, ns, nsmap, search_mode, computed, nillable) - def __init__(self, name: str, ns: Optional[str], nsmap: Optional[NsMap], search_mode: SearchMode, computed: bool): - super().__init__(computed) + def __init__( + self, + name: str, + ns: Optional[str], + nsmap: Optional[NsMap], + search_mode: SearchMode, + computed: bool, + nillable: bool, + ): + super().__init__(computed, nillable) self._nsmap = nsmap self._search_mode = search_mode diff --git a/pydantic_xml/serializers/serializer.py b/pydantic_xml/serializers/serializer.py index 0b33517..7932651 100644 --- a/pydantic_xml/serializers/serializer.py +++ b/pydantic_xml/serializers/serializer.py @@ -96,6 +96,7 @@ class XmlEntityInfoP(typing.Protocol): path: Optional[str] ns: Optional[str] nsmap: Optional[NsMap] + nillable: bool wrapped: Optional['XmlEntityInfoP'] @@ -137,6 +138,10 @@ def entity_ns(self) -> Optional[str]: def entity_nsmap(self) -> Optional[NsMap]: return self.entity_info.nsmap if self.entity_info is not None else None + @property + def nillable(self) -> bool: + return self.entity_info.nillable if self.entity_info is not None else False + @property def entity_wrapped(self) -> Optional['XmlEntityInfoP']: return self.entity_info.wrapped if self.entity_info is not None else None diff --git a/tests/test_computed_fields.py b/tests/test_computed_fields.py index ea3d34a..6332fd9 100644 --- a/tests/test_computed_fields.py +++ b/tests/test_computed_fields.py @@ -1,3 +1,5 @@ +from typing import Optional + from helpers import assert_xml_equal from pydantic import computed_field @@ -63,6 +65,29 @@ def element2(self) -> str: assert_xml_equal(actual_xml, xml) +def test_computed_nillable_elements(): + class TestModel(BaseXmlModel, tag='model'): + @computed_element(tag='element1', nillable=True) + def computed_element1(self) -> Optional[int]: + return None + + @computed_element(tag='element2', nillable=True) + def computed_element2(self) -> Optional[int]: + return 2 + + xml = ''' + <model> + <element1 xsi:nil="true" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" /> + <element2>2</element2> + </model> + ''' + + actual_obj = TestModel.from_xml(xml) + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + def test_computed_submodel(): class TestSumModel(BaseXmlModel): text: str @@ -87,3 +112,24 @@ def submodel2(self) -> TestSumModel: actual_xml = actual_obj.to_xml() assert_xml_equal(actual_xml, xml) + + +def test_computed_nillable_submodel(): + class TestSumModel(BaseXmlModel): + text: str + + class TestModel(BaseXmlModel, tag='model'): + @computed_element(tag='submodel1', nillable=True) + def submodel1(self) -> Optional[TestSumModel]: + return None + + xml = ''' + <model> + <submodel1 xsi:nil="true" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" /> + </model> + ''' + + actual_obj = TestModel.from_xml(xml) + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) diff --git a/tests/test_heterogeneous_collections.py b/tests/test_heterogeneous_collections.py index e7d4fc8..9ee7571 100644 --- a/tests/test_heterogeneous_collections.py +++ b/tests/test_heterogeneous_collections.py @@ -27,6 +27,27 @@ class TestModel(BaseXmlModel, tag='model1'): assert_xml_equal(actual_xml, xml) +def test_list_of_nillable_primitives_extraction(): + class TestModel(BaseXmlModel, tag='model1'): + elements: Tuple[Optional[int], Optional[float], Optional[str]] = element(tag='element', nillable=True) + + xml = ''' + <model1> + <element>1</element> + <element xsi:nil="true" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" /> + <element>string3</element> + </model1> + ''' + + actual_obj = TestModel.from_xml(xml) + expected_obj = TestModel(elements=(1, None, "string3")) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + def test_tuple_of_submodel_extraction(): class TestSubModel1(BaseXmlModel): attr1: int = attr() diff --git a/tests/test_primitives.py b/tests/test_primitives.py index b67293b..97c12d5 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -1,5 +1,5 @@ import datetime as dt -from typing import Generic, TypeVar +from typing import Generic, Optional, TypeVar from helpers import assert_xml_equal @@ -54,6 +54,37 @@ class TestModel(BaseXmlModel, tag='model'): assert_xml_equal(actual_xml, xml) +def test_nillable_element_extraction(): + class TestModel(BaseXmlModel, tag='model'): + element1: Optional[int] = element(default=None, nillable=True) + element2: Optional[int] = element(default=None, nillable=True) + element3: Optional[int] = element(default=None, nillable=True) + + src_xml = ''' + <model> + <element1 xsi:nil="true" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" /> + <element2 xsi:nil="false" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">2</element2> + <element3>3</element3> + </model> + ''' + + actual_obj = TestModel.from_xml(src_xml) + expected_obj = TestModel(element1=None, element2=2, element3=3) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + expected_xml = ''' + <model> + <element1 xsi:nil="true" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" /> + <element2>2</element2> + <element3>3</element3> + </model> + ''' + + assert_xml_equal(actual_xml, expected_xml) + + def test_model_inheritance(): class TestModel4(BaseXmlModel): attr1: int = attr() @@ -195,6 +226,25 @@ class TestModel(RootXmlModel, tag='model'): assert_xml_equal(actual_xml, xml) +def test_root_model_nillable_element_extraction(): + class TestModel(RootXmlModel, tag='model'): + root: Optional[int] = element(tag="element1", default=None, nillable=True) + + xml = ''' + <model> + <element1 xsi:nil="true" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" /> + </model> + ''' + + actual_obj = TestModel.from_xml(xml) + expected_obj = TestModel() + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + def test_root_model_default(): class TestRootModel(RootXmlModel, tag='sub'): root: int = 1 diff --git a/tests/test_submodels.py b/tests/test_submodels.py index 02ac3fc..4d1d90a 100644 --- a/tests/test_submodels.py +++ b/tests/test_submodels.py @@ -54,6 +54,44 @@ class TestModel(BaseXmlModel, tag='model1'): assert_xml_equal(actual_xml, xml) +def test_nillable_submodel_element_extraction(): + class TestSubModel(BaseXmlModel): + text: int + + class TestModel(BaseXmlModel, tag='model1'): + model2: Optional[TestSubModel] = element(default=None, nillable=True) + model3: Optional[TestSubModel] = element(default=None, nillable=True) + model4: Optional[TestSubModel] = element(default=None, nillable=True) + + xml = ''' + <model1> + <model2 xsi:nil="true" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" /> + <model3 xsi:nil="false" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">3</model3> + <model4>4</model4> + </model1> + ''' + + actual_obj = TestModel.from_xml(xml) + expected_obj = TestModel( + model2=None, + model3=TestSubModel(text=3), + model4=TestSubModel(text=4), + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + expected_xml = ''' + <model1> + <model2 xsi:nil="true" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" /> + <model3>3</model3> + <model4>4</model4> + </model1> + ''' + + assert_xml_equal(actual_xml, expected_xml) + + def test_root_submodel_element_extraction(): class TestSubModel(RootXmlModel, tag='model2'): root: int