diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 31d72a7..9c4744a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,13 @@ Changelog ========= + +2.4.0 (2023-11-06) +------------------ + +- attributes with default namespace bug fixed. See https://github.com/dapper91/pydantic-xml/issues/137. + + 2.3.0 (2023-10-22) ------------------ diff --git a/pydantic_xml/serializers/factories/mapping.py b/pydantic_xml/serializers/factories/mapping.py index a26c1a7..d42a97d 100644 --- a/pydantic_xml/serializers/factories/mapping.py +++ b/pydantic_xml/serializers/factories/mapping.py @@ -6,13 +6,13 @@ from pydantic_xml.element import XmlElementReader, XmlElementWriter from pydantic_xml.serializers.serializer import TYPE_FAMILY, SchemaTypeFamily, SearchMode, Serializer from pydantic_xml.typedefs import EntityLocation, NsMap -from pydantic_xml.utils import QName, merge_nsmaps +from pydantic_xml.utils import QName, merge_nsmaps, select_ns class AttributesSerializer(Serializer): @classmethod def from_core_schema(cls, schema: pcs.CoreSchema, ctx: Serializer.Context) -> 'AttributesSerializer': - ns = ctx.entity_ns or ctx.parent_ns + ns = select_ns(ctx.entity_ns, ctx.parent_ns) nsmap = merge_nsmaps(ctx.entity_nsmap, ctx.parent_nsmap) namespaced_attrs = ctx.namespaced_attrs computed = ctx.field_computed @@ -66,7 +66,7 @@ class ElementSerializer(AttributesSerializer): @classmethod def from_core_schema(cls, schema: pcs.CoreSchema, ctx: Serializer.Context) -> 'ElementSerializer': name = ctx.entity_path or ctx.field_alias or ctx.field_name - ns = ctx.entity_ns or ctx.parent_ns + ns = select_ns(ctx.entity_ns, ctx.parent_ns) nsmap = merge_nsmaps(ctx.entity_nsmap, ctx.parent_nsmap) namespaced_attrs = ctx.namespaced_attrs search_mode = ctx.search_mode diff --git a/pydantic_xml/serializers/factories/model.py b/pydantic_xml/serializers/factories/model.py index ef95d53..a2383f2 100644 --- a/pydantic_xml/serializers/factories/model.py +++ b/pydantic_xml/serializers/factories/model.py @@ -11,7 +11,7 @@ from pydantic_xml.element import XmlElementReader, XmlElementWriter from pydantic_xml.serializers.serializer import SearchMode, Serializer, XmlEntityInfoP from pydantic_xml.typedefs import EntityLocation, NsMap -from pydantic_xml.utils import QName, merge_nsmaps +from pydantic_xml.utils import QName, merge_nsmaps, select_ns class BaseModelSerializer(Serializer, abc.ABC): @@ -283,7 +283,7 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> ' assert issubclass(model_cls, pxml.BaseXmlModel), "unexpected model type" name = ctx.entity_path or model_cls.__xml_tag__ or ctx.field_alias or ctx.field_name or model_cls.__name__ - ns = ctx.entity_ns or model_cls.__xml_ns__ or ctx.parent_ns + ns = select_ns(ctx.entity_ns, model_cls.__xml_ns__, ctx.parent_ns) nsmap = merge_nsmaps(ctx.entity_nsmap, model_cls.__xml_nsmap__, ctx.parent_nsmap) search_mode = ctx.search_mode computed = ctx.field_computed diff --git a/pydantic_xml/serializers/factories/primitive.py b/pydantic_xml/serializers/factories/primitive.py index 43e1af5..d137e3f 100644 --- a/pydantic_xml/serializers/factories/primitive.py +++ b/pydantic_xml/serializers/factories/primitive.py @@ -6,7 +6,7 @@ from pydantic_xml.element import XmlElementReader, XmlElementWriter from pydantic_xml.serializers.serializer import SearchMode, Serializer, encode_primitive from pydantic_xml.typedefs import EntityLocation, NsMap -from pydantic_xml.utils import QName, merge_nsmaps +from pydantic_xml.utils import QName, merge_nsmaps, select_ns PrimitiveTypeSchema = Union[ pcs.NoneSchema, @@ -67,10 +67,16 @@ class AttributeSerializer(Serializer): def from_core_schema(cls, schema: PrimitiveTypeSchema, ctx: Serializer.Context) -> 'AttributeSerializer': namespaced_attrs = ctx.namespaced_attrs name = ctx.entity_path or ctx.field_alias or ctx.field_name - ns = ctx.entity_ns or (ctx.parent_ns if namespaced_attrs else None) + ns = select_ns(ctx.entity_ns, ctx.parent_ns if namespaced_attrs else None) nsmap = merge_nsmaps(ctx.entity_nsmap, ctx.parent_nsmap) computed = ctx.field_computed + if ns == '': + raise errors.ModelFieldError( + ctx.model_name, + ctx.field_name, + "attributes with default namespace are forbidden", + ) if name is None: raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "entity name is not provided") @@ -113,7 +119,7 @@ class ElementSerializer(TextSerializer): @classmethod def from_core_schema(cls, schema: PrimitiveTypeSchema, ctx: Serializer.Context) -> 'ElementSerializer': name = ctx.entity_path or ctx.field_alias or ctx.field_name - ns = ctx.entity_ns or ctx.parent_ns + ns = select_ns(ctx.entity_ns, ctx.parent_ns) nsmap = merge_nsmaps(ctx.entity_nsmap, ctx.parent_nsmap) search_mode = ctx.search_mode computed = ctx.field_computed diff --git a/pydantic_xml/serializers/factories/raw.py b/pydantic_xml/serializers/factories/raw.py index acd1d97..37a86f5 100644 --- a/pydantic_xml/serializers/factories/raw.py +++ b/pydantic_xml/serializers/factories/raw.py @@ -6,14 +6,14 @@ from pydantic_xml.element import XmlElementReader, XmlElementWriter from pydantic_xml.serializers.serializer import SearchMode, Serializer from pydantic_xml.typedefs import EntityLocation, NsMap -from pydantic_xml.utils import QName, merge_nsmaps +from pydantic_xml.utils import QName, merge_nsmaps, select_ns class ElementSerializer(Serializer): @classmethod def from_core_schema(cls, schema: pcs.IsInstanceSchema, ctx: Serializer.Context) -> 'ElementSerializer': name = ctx.entity_path or ctx.field_alias or ctx.field_name - ns = ctx.entity_ns or ctx.parent_ns + ns = select_ns(ctx.entity_ns, ctx.parent_ns) nsmap = merge_nsmaps(ctx.entity_nsmap, ctx.parent_nsmap) search_mode = ctx.search_mode computed = ctx.field_computed diff --git a/pydantic_xml/serializers/factories/wrapper.py b/pydantic_xml/serializers/factories/wrapper.py index e0628d7..217e76b 100644 --- a/pydantic_xml/serializers/factories/wrapper.py +++ b/pydantic_xml/serializers/factories/wrapper.py @@ -5,14 +5,14 @@ from pydantic_xml.element import XmlElementReader, XmlElementWriter from pydantic_xml.serializers.serializer import SearchMode, Serializer from pydantic_xml.typedefs import NsMap -from pydantic_xml.utils import QName, merge_nsmaps +from pydantic_xml.utils import QName, merge_nsmaps, select_ns class ElementPathSerializer(Serializer): @classmethod def from_core_schema(cls, schema: pcs.CoreSchema, ctx: Serializer.Context) -> 'ElementPathSerializer': path = ctx.entity_path - ns = ctx.entity_ns or ctx.parent_ns + ns = select_ns(ctx.entity_ns, ctx.parent_ns) nsmap = merge_nsmaps(ctx.entity_nsmap, ctx.parent_nsmap) search_mode = ctx.search_mode computed = ctx.field_computed diff --git a/pydantic_xml/serializers/serializer.py b/pydantic_xml/serializers/serializer.py index 8989967..0b33517 100644 --- a/pydantic_xml/serializers/serializer.py +++ b/pydantic_xml/serializers/serializer.py @@ -11,6 +11,7 @@ from pydantic_xml.element import SearchMode, XmlElementReader, XmlElementWriter from pydantic_xml.errors import ModelError from pydantic_xml.typedefs import EntityLocation, NsMap +from pydantic_xml.utils import select_ns from . import factories @@ -143,7 +144,8 @@ def entity_wrapped(self) -> Optional['XmlEntityInfoP']: @cached_property def parent_ns(self) -> Optional[str]: if parent_ctx := self.parent_ctx: - return parent_ctx.entity_ns or parent_ctx.parent_ns + ns = select_ns(parent_ctx.entity_ns, parent_ctx.parent_ns) + return ns return None diff --git a/pydantic_xml/utils.py b/pydantic_xml/utils.py index 94a3abf..a001470 100644 --- a/pydantic_xml/utils.py +++ b/pydantic_xml/utils.py @@ -84,3 +84,11 @@ def register_nsmap(nsmap: NsMap) -> None: def get_slots(o: object) -> Iterable[str]: return it.chain.from_iterable(getattr(cls, '__slots__', []) for cls in o.__class__.__mro__) + + +def select_ns(*nss: Optional[str]) -> Optional[str]: + for ns in nss: + if ns is not None: + return ns + + return None diff --git a/pyproject.toml b/pyproject.toml index c531307..2ab26c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pydantic-xml" -version = "2.3.0" +version = "2.4.0" description = "pydantic xml extension" authors = ["Dmitry Pershin "] license = "Unlicense" diff --git a/tests/test_namespaces.py b/tests/test_namespaces.py index 77d057a..d9d0b0b 100644 --- a/tests/test_namespaces.py +++ b/tests/test_namespaces.py @@ -42,7 +42,7 @@ class TestModel(BaseXmlModel, tag='model'): @pytest.mark.skipif(not is_lxml_native(), reason='not lxml used') def test_lxml_default_namespace_serialisation(): class TestSubModel(BaseXmlModel, tag='submodel', ns='', nsmap={'': 'http://test3.org', 'tst': 'http://test4.org'}): - attr1: int = attr(ns='') + attr1: int = attr() attr2: int = attr(ns='tst') element1: str = element(ns='') @@ -357,3 +357,31 @@ class TestModel(BaseTestModel, tag='model', ns='tst', nsmap={'tst': 'http://test actual_xml = actual_obj.to_xml() assert_xml_equal(actual_xml, xml1) + + +def test_submodel_namespaces_default_namespace_inheritance(): + class TestSubModel(BaseXmlModel, tag='submodel', ns='', nsmap={'': 'http://test2.org'}): + attr1: int = attr() + attr2: int = attr() + element1: str = element() + + class TestModel(BaseXmlModel, tag='model', ns='tst', nsmap={'tst': 'http://test1.org'}): + submodel: TestSubModel + + xml = ''' + + + value + + + ''' + + actual_obj = TestModel.from_xml(xml) + expected_obj = TestModel( + submodel=TestSubModel(element1='value', attr1=1, attr2=2), + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml)