Skip to content

Commit

Permalink
Merge pull request #193 from dapper91/dev
Browse files Browse the repository at this point in the history
- named tuple support added.
  • Loading branch information
dapper91 committed May 11, 2024
2 parents b8348a9 + e928d4b commit ce20508
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

2.11.0 (2024-05-11)
------------------

- named tuple support added. See https://github.com/dapper91/pydantic-xml/issues/172


2.10.0 (2024-05-09)
------------------

Expand Down
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ What is not supported?
______________________

- `dataclasses <https://docs.pydantic.dev/usage/dataclasses/>`_
- `callable discriminators <https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator>`_

Getting started
---------------
Expand Down
4 changes: 2 additions & 2 deletions pydantic_xml/serializers/factories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import heterogeneous, homogeneous, is_instance, mapping, model, primitive, raw, tagged_union, tuple
from . import typed_mapping, union, wrapper
from . import call, heterogeneous, homogeneous, is_instance, mapping, model, named_tuple, primitive, raw, tagged_union
from . import tuple, typed_mapping, union, wrapper
16 changes: 16 additions & 0 deletions pydantic_xml/serializers/factories/call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import inspect

from pydantic_core import core_schema as pcs

from pydantic_xml import errors
from pydantic_xml.serializers.factories import named_tuple
from pydantic_xml.serializers.serializer import Serializer


def from_core_schema(schema: pcs.CallSchema, ctx: Serializer.Context) -> Serializer:
func = schema['function']

if inspect.isclass(func) and issubclass(func, tuple):
return named_tuple.from_core_schema(schema, ctx)
else:
raise errors.ModelError("type call is not supported")
1 change: 1 addition & 0 deletions pydantic_xml/serializers/factories/heterogeneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def from_core_schema(schema: pcs.TupleSchema, ctx: Serializer.Context) -> Serial
SchemaTypeFamily.TYPED_MAPPING,
SchemaTypeFamily.UNION,
SchemaTypeFamily.IS_INSTANCE,
SchemaTypeFamily.CALL,
):
raise errors.ModelFieldError(
ctx.model_name, ctx.field_name, "collection item must be of primitive, model, mapping or union type",
Expand Down
2 changes: 2 additions & 0 deletions pydantic_xml/serializers/factories/homogeneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def from_core_schema(schema: HomogeneousCollectionTypeSchema, ctx: Serializer.Co
SchemaTypeFamily.TYPED_MAPPING,
SchemaTypeFamily.UNION,
SchemaTypeFamily.IS_INSTANCE,
SchemaTypeFamily.CALL,
SchemaTypeFamily.TUPLE,
):
raise errors.ModelFieldError(
Expand All @@ -113,6 +114,7 @@ def from_core_schema(schema: HomogeneousCollectionTypeSchema, ctx: Serializer.Co
SchemaTypeFamily.MODEL,
SchemaTypeFamily.UNION,
SchemaTypeFamily.TUPLE,
SchemaTypeFamily.CALL,
) and ctx.entity_location is None:
raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "entity name is not provided")

Expand Down
74 changes: 74 additions & 0 deletions pydantic_xml/serializers/factories/named_tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import typing
from typing import Any, Dict, List, Optional, Tuple

from pydantic_core import core_schema as pcs

from pydantic_xml import errors
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.factories import heterogeneous
from pydantic_xml.serializers.serializer import TYPE_FAMILY, SchemaTypeFamily, Serializer
from pydantic_xml.typedefs import EntityLocation, Location


class ElementSerializer(Serializer):
@classmethod
def from_core_schema(cls, schema: pcs.ArgumentsSchema, ctx: Serializer.Context) -> 'ElementSerializer':
model_name = ctx.model_name
computed = ctx.field_computed
inner_serializers: List[Serializer] = []
for argument_schema in schema['arguments_schema']:
param_schema = argument_schema['schema']
inner_serializers.append(Serializer.parse_core_schema(param_schema, ctx))

return cls(model_name, computed, tuple(inner_serializers))

def __init__(self, model_name: str, computed: bool, inner_serializers: Tuple[Serializer, ...]):
self._inner_serializer = heterogeneous.ElementSerializer(model_name, computed, inner_serializers)

def serialize(
self, element: XmlElementWriter, value: List[Any], encoded: List[Any], *, skip_empty: bool = False,
) -> Optional[XmlElementWriter]:
return self._inner_serializer.serialize(element, value, encoded, skip_empty=skip_empty)

def deserialize(
self,
element: Optional[XmlElementReader],
*,
context: Optional[Dict[str, Any]],
sourcemap: Dict[Location, int],
loc: Location,
) -> Optional[List[Any]]:
return self._inner_serializer.deserialize(element, context=context, sourcemap=sourcemap, loc=loc)


def from_core_schema(schema: pcs.CallSchema, ctx: Serializer.Context) -> Serializer:
arguments_schema = typing.cast(pcs.ArgumentsSchema, schema['arguments_schema'])
for argument_schema in arguments_schema['arguments_schema']:
param_schema = argument_schema['schema']
param_schema, ctx = Serializer.preprocess_schema(param_schema, ctx)

param_type_family = TYPE_FAMILY.get(param_schema['type'])
if param_type_family not in (
SchemaTypeFamily.PRIMITIVE,
SchemaTypeFamily.MODEL,
SchemaTypeFamily.MAPPING,
SchemaTypeFamily.TYPED_MAPPING,
SchemaTypeFamily.UNION,
SchemaTypeFamily.IS_INSTANCE,
SchemaTypeFamily.CALL,
):
raise errors.ModelFieldError(
ctx.model_name, ctx.field_name, "tuple item must be of primitive, model, mapping or union type",
)

if param_type_family not in (SchemaTypeFamily.MODEL, SchemaTypeFamily.UNION) and ctx.entity_location is None:
raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "entity name is not provided")

if ctx.entity_location is EntityLocation.ELEMENT:
return ElementSerializer.from_core_schema(arguments_schema, ctx)
elif ctx.entity_location is None:
return ElementSerializer.from_core_schema(arguments_schema, ctx)
elif ctx.entity_location is EntityLocation.ATTRIBUTE:
raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "attributes of tuple types are not supported")
else:
raise AssertionError("unreachable")
7 changes: 7 additions & 0 deletions pydantic_xml/serializers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class SchemaTypeFamily(IntEnum):
DEFINITION_REF = 10
JSON_OR_PYTHON = 11
IS_INSTANCE = 12
CALL = 13


TYPE_FAMILY = {
Expand Down Expand Up @@ -87,6 +88,8 @@ class SchemaTypeFamily(IntEnum):
'definition-ref': SchemaTypeFamily.DEFINITION_REF,

'json-or-python': SchemaTypeFamily.JSON_OR_PYTHON,

'call': SchemaTypeFamily.CALL,
}


Expand Down Expand Up @@ -265,6 +268,10 @@ def select_serializer(cls, schema: pcs.CoreSchema, ctx: Context) -> 'Serializer'
schema = typing.cast(pcs.IsInstanceSchema, schema)
return factories.is_instance.from_core_schema(schema, ctx)

elif type_family is SchemaTypeFamily.CALL:
schema = typing.cast(pcs.CallSchema, schema)
return factories.call.from_core_schema(schema, ctx)

else:
raise AssertionError("unreachable")

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pydantic-xml"
version = "2.10.0"
version = "2.11.0"
description = "pydantic xml extension"
authors = ["Dmitry Pershin <dapper1291@gmail.com>"]
license = "Unlicense"
Expand Down
168 changes: 168 additions & 0 deletions tests/test_named_tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from typing import List, NamedTuple, Optional, Union

from helpers import assert_xml_equal

from pydantic_xml import BaseXmlModel, RootXmlModel, attr, element


def test_named_tuple_of_primitives_extraction():
class TestTuple(NamedTuple):
field1: int
field2: float
field3: str
field4: Optional[str]

class TestModel(BaseXmlModel, tag='model1'):
elements: TestTuple = element(tag='element')

xml = '''
<model1>
<element>1</element>
<element>2.2</element>
<element>string3</element>
</model1>
'''

actual_obj = TestModel.from_xml(xml)
expected_obj = TestModel(elements=(1, 2.2, "string3", None))

assert actual_obj == expected_obj

actual_xml = actual_obj.to_xml(skip_empty=True)
assert_xml_equal(actual_xml, xml)


def test_named_tuple_of_mixed_types_extraction():
class TestSubModel1(BaseXmlModel):
attr1: int = attr()
element1: float = element()

class TestTuple(NamedTuple):
field1: TestSubModel1
field2: int

class TestModel(BaseXmlModel, tag='model1'):
submodels: TestTuple = element(tag='submodel')

xml = '''
<model1>
<submodel attr1="1">
<element1>2.2</element1>
</submodel>
<submodel>1</submodel>
</model1>
'''

actual_obj = TestModel.from_xml(xml)
expected_obj = TestModel(
submodels=[
TestSubModel1(attr1=1, element1=2.2),
1,
],
)

assert actual_obj == expected_obj

actual_xml = actual_obj.to_xml()
assert_xml_equal(actual_xml, xml)


def test_list_of_named_tuples_extraction():
class TestTuple(NamedTuple):
field1: int
field2: Optional[float] = None

class RootModel(BaseXmlModel, tag='model'):
elements: List[TestTuple] = element(tag='element')

xml = '''
<model>
<element>1</element>
<element>1.1</element>
<element>2</element>
<element></element>
<element>3</element>
<element>3.3</element>
</model>
'''

actual_obj = RootModel.from_xml(xml)
expected_obj = RootModel(
elements=[
(1, 1.1),
(2, None),
(3, 3.3),
],
)

assert actual_obj == expected_obj

actual_xml = actual_obj.to_xml()
assert_xml_equal(actual_xml, xml)


def test_list_of_named_tuples_of_models_extraction():
class SubModel1(RootXmlModel[str], tag='text'):
pass

class SubModel2(RootXmlModel[int], tag='number'):
pass

class TestTuple(NamedTuple):
field1: SubModel1
field2: Optional[SubModel2] = None

class RootModel(BaseXmlModel, tag='model'):
elements: List[TestTuple]

xml = '''
<model>
<text>text1</text>
<number>1</number>
<text>text2</text>
<text>text3</text>
<number>3</number>
</model>
'''

actual_obj = RootModel.from_xml(xml)
expected_obj = RootModel(
elements=[
(SubModel1('text1'), SubModel2(1)),
(SubModel1('text2'), None),
(SubModel1('text3'), SubModel2(3)),
],
)

assert actual_obj == expected_obj

actual_xml = actual_obj.to_xml()
assert_xml_equal(actual_xml, xml)


def test_primitive_union_named_tuple():
class TestTuple(NamedTuple):
field1: Union[int, float]
field2: str
field3: Union[int, float]

class TestModel(BaseXmlModel, tag='model'):
sublements: TestTuple = element(tag='model1')

xml = '''
<model>
<model1>1.1</model1>
<model1>text</model1>
<model1>1</model1>
</model>
'''

actual_obj = TestModel.from_xml(xml)
expected_obj = TestModel(
sublements=(float('1.1'), 'text', 1),
)

assert actual_obj == expected_obj

actual_xml = actual_obj.to_xml()
assert_xml_equal(actual_xml, xml)

0 comments on commit ce20508

Please sign in to comment.