Skip to content

Commit

Permalink
Merge pull request #61 from dapper91/forward-refs
Browse files Browse the repository at this point in the history
forward refs support added.
  • Loading branch information
dapper91 authored Jun 19, 2023
2 parents f6580e6 + 22808c9 commit b81c46c
Show file tree
Hide file tree
Showing 10 changed files with 534 additions and 32 deletions.
7 changes: 7 additions & 0 deletions pydantic_xml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,13 @@ def __init_serializer__(cls) -> None:

cls.__xml_serializer__ = ModelSerializerFactory.build_root(cls)

@classmethod
def update_forward_refs(cls, **kwargs: Any) -> None:
super().update_forward_refs(**kwargs)

if cls.__xml_serializer__ is not None:
cls.__xml_serializer__.resolve_forward_refs()

@classmethod
def from_xml_tree(cls, root: etree.Element) -> Optional['BaseXmlModel']:
"""
Expand Down
1 change: 1 addition & 0 deletions pydantic_xml/serializers/factories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .forwardref import ForwardRefSerializerFactory
from .heterogeneous import HeterogeneousSerializerFactory
from .homogeneous import HomogeneousSerializerFactory
from .mapping import MappingSerializerFactory
Expand Down
53 changes: 53 additions & 0 deletions pydantic_xml/serializers/factories/forwardref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Optional, Type

import pydantic as pd

import pydantic_xml as pxml
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.encoder import XmlEncoder
from pydantic_xml.serializers.serializer import Serializer


class ForwardRefSerializerFactory:
"""
Primitive type serializer factory.
"""

class ForwardRefSerializer(Serializer):
def __init__(
self,
model: Type['pxml.BaseXmlModel'],
model_field: pd.fields.ModelField,
ctx: Serializer.Context,
):
self._model = model
self._model_field = model_field
self._ctx = ctx

def serialize(
self, element: XmlElementWriter, value: Any, *, encoder: XmlEncoder, skip_empty: bool = False,
) -> Optional[XmlElementWriter]:
raise pxml.errors.ModelFieldError(
self._model.__name__,
self._model_field.name,
"field is not yet prepared so type is still a ForwardRef, you might need to call update_forward_refs()",
)

def deserialize(self, element: Optional[XmlElementReader]) -> Any:
raise pxml.errors.ModelFieldError(
self._model.__name__,
self._model_field.name,
"field is not yet prepared so type is still a ForwardRef, you might need to call update_forward_refs()",
)

def resolve_forward_refs(self) -> Serializer:
return self._build_field_serializer(self._model, self._model_field, self._ctx)

@classmethod
def build(
cls,
model: Type['pxml.BaseXmlModel'],
model_field: pd.fields.ModelField,
ctx: Serializer.Context,
) -> 'Serializer':
return cls.ForwardRefSerializer(model, model_field, ctx)
25 changes: 19 additions & 6 deletions pydantic_xml/serializers/factories/heterogeneous.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses as dc
from copy import deepcopy
import typing
from typing import Any, List, Optional, Type

import pydantic as pd
Expand All @@ -8,7 +8,7 @@
from pydantic_xml import errors
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.encoder import XmlEncoder
from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer
from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer, SubFieldWrapper
from pydantic_xml.utils import QName, merge_nsmaps


Expand All @@ -32,10 +32,15 @@ def __init__(

self._inner_serializers = []
for sub_field in model_field.sub_fields:
sub_field = deepcopy(sub_field)
sub_field.name = model_field.name
sub_field.alias = model_field.alias
sub_field.field_info = model_field.field_info
sub_field = typing.cast(
pd.fields.ModelField,
SubFieldWrapper(
model_field.name,
model_field.alias,
model_field.field_info,
sub_field,
),
)

self._inner_serializers.append(
self._build_field_serializer(
Expand All @@ -50,6 +55,14 @@ def __init__(
),
)

def resolve_forward_refs(self) -> 'Serializer':
self._inner_serializers = [
serializer.resolve_forward_refs()
for serializer in self._inner_serializers
]

return self

def serialize(
self, element: XmlElementWriter, value: List[Any], *, encoder: XmlEncoder, skip_empty: bool = False,
) -> Optional[XmlElementWriter]:
Expand Down
23 changes: 16 additions & 7 deletions pydantic_xml/serializers/factories/homogeneous.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses as dc
from copy import deepcopy
import typing
from typing import Any, List, Optional, Type

import pydantic as pd
Expand All @@ -8,7 +8,7 @@
from pydantic_xml import errors
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.encoder import XmlEncoder
from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer
from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer, SubFieldWrapper
from pydantic_xml.utils import QName, merge_nsmaps


Expand All @@ -29,11 +29,15 @@ def __init__(
nsmap = merge_nsmaps(nsmap, ctx.parent_nsmap)

self._element_name = QName.from_alias(tag=name, ns=ns, nsmap=nsmap).uri

item_field = deepcopy(model_field.sub_fields[0])
item_field.name = model_field.name
item_field.alias = model_field.alias
item_field.field_info = model_field.field_info
item_field = typing.cast(
pd.fields.ModelField,
SubFieldWrapper(
model_field.name,
model_field.alias,
model_field.field_info,
model_field.sub_fields[0],
),
)

self._inner_serializer = self._build_field_serializer(
model,
Expand Down Expand Up @@ -73,6 +77,11 @@ def deserialize(self, element: Optional[XmlElementReader]) -> Optional[List[Any]

return result or None

def resolve_forward_refs(self) -> 'Serializer':
self._inner_serializer = self._inner_serializer.resolve_forward_refs()

return self

@classmethod
def build(
cls,
Expand Down
8 changes: 8 additions & 0 deletions pydantic_xml/serializers/factories/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def deserialize(self, element: Optional[XmlElementReader]) -> Optional['pxml.Bas

return self._model.parse_obj(obj)

def resolve_forward_refs(self) -> 'Serializer':
self._field_serializers = {
field_name: serializer.resolve_forward_refs()
for field_name, serializer in self._field_serializers.items()
}

return self

class DeferredSerializer(ModelSerializer):

def __init__(self, model_field: pd.fields.ModelField):
Expand Down
38 changes: 28 additions & 10 deletions pydantic_xml/serializers/factories/union.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from copy import deepcopy
import typing
from typing import Any, List, Optional, Type

import pydantic as pd
Expand All @@ -8,7 +8,7 @@
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.encoder import XmlEncoder
from pydantic_xml.serializers.factories.model import ModelSerializerFactory
from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer, is_xml_model
from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer, SubFieldWrapper, is_xml_model


class UnionSerializerFactory:
Expand All @@ -25,10 +25,15 @@ def __init__(

inner_serializers: List[Serializer] = []
for sub_field in model_field.sub_fields:
sub_field = deepcopy(sub_field)
sub_field.name = model_field.name
sub_field.alias = model_field.alias
sub_field.field_info = model_field.field_info
sub_field = typing.cast(
pd.fields.ModelField,
SubFieldWrapper(
model_field.name,
model_field.alias,
model_field.field_info,
sub_field,
),
)

inner_serializers.append(self._build_field_serializer(model, sub_field, ctx))

Expand All @@ -55,10 +60,15 @@ def __init__(

inner_serializers: List[ModelSerializerFactory.ModelSerializer] = []
for sub_field in model_field.sub_fields:
sub_field = deepcopy(sub_field)
sub_field.name = model_field.name
sub_field.alias = model_field.alias
sub_field.field_info = model_field.field_info
sub_field = typing.cast(
pd.fields.ModelField,
SubFieldWrapper(
model_field.name,
model_field.alias,
model_field.field_info,
sub_field,
),
)

serializer = self._build_field_serializer(model, sub_field, ctx)
assert isinstance(serializer, ModelSerializerFactory.ModelSerializer), "unexpected serializer type"
Expand Down Expand Up @@ -103,6 +113,14 @@ def deserialize(self, element: Optional[XmlElementReader]) -> Optional[List[Any]

return result

def resolve_forward_refs(self) -> 'Serializer':
self._inner_serializers = [
typing.cast(ModelSerializerFactory.ModelSerializer, serializer.resolve_forward_refs())
for serializer in self._inner_serializers
]

return self

@classmethod
def build(
cls,
Expand Down
20 changes: 16 additions & 4 deletions pydantic_xml/serializers/factories/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import dataclasses as dc
from copy import deepcopy
import typing
from typing import Any, Optional, Sized, Type

import pydantic as pd

import pydantic_xml as pxml
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.encoder import XmlEncoder
from pydantic_xml.serializers.serializer import Serializer
from pydantic_xml.serializers.serializer import Serializer, SubFieldWrapper
from pydantic_xml.utils import QName, merge_nsmaps


Expand All @@ -25,14 +25,21 @@ def __init__(
self._nsmap = nsmap = merge_nsmaps(nsmap, ctx.parent_nsmap)
self._search_mode = ctx.search_mode

model_field = deepcopy(model_field)
field_info = model_field.field_info

assert path is not None, "path is not provided"
assert isinstance(field_info, pxml.XmlWrapperInfo), "unexpected field info type"

# copy field_info from wrapped entity
model_field.field_info = field_info.entity or pd.fields.FieldInfo()
model_field = typing.cast(
pd.fields.ModelField,
SubFieldWrapper(
model_field.name,
model_field.alias,
field_info.entity or pd.fields.FieldInfo(),
model_field,
),
)

self._path = tuple(QName.from_alias(tag=part, ns=ns, nsmap=nsmap).uri for part in path.split('/'))
self._inner_serializer = self._build_field_serializer(
Expand Down Expand Up @@ -69,6 +76,11 @@ def deserialize(self, element: Optional[XmlElementReader]) -> Optional[Any]:
else:
return None

def resolve_forward_refs(self) -> 'Serializer':
self._inner_serializer = self._inner_serializer.resolve_forward_refs()

return self

@classmethod
def build(
cls,
Expand Down
41 changes: 36 additions & 5 deletions pydantic_xml/serializers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing
from enum import IntEnum
from inspect import isclass
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, ForwardRef, Optional, Tuple, Type, Union

if sys.version_info < (3, 10):
UnionTypes = (Union,)
Expand All @@ -22,6 +22,17 @@
from .encoder import XmlEncoder


class SubFieldWrapper:
def __init__(self, name: str, alias: str, field_info: pd.fields.FieldInfo, model_field: pd.fields.ModelField):
self.model_field = model_field
self.name = name
self.alias = alias
self.field_info = field_info

def __getattr__(self, item: str) -> Any:
return getattr(self.model_field, item)


class Location(IntEnum):
"""
Field data location.
Expand Down Expand Up @@ -120,6 +131,15 @@ def deserialize(self, element: Optional[XmlElementReader]) -> Any:
:return: deserialized value
"""

def resolve_forward_refs(self) -> 'Serializer':
"""
Resolve forward references if exist
:return: resolved serializer
"""

return self

@classmethod
def _get_field_location(cls, field_info: pd.fields.FieldInfo) -> Location:
if isinstance(field_info, pxml.XmlElementInfo):
Expand Down Expand Up @@ -156,21 +176,21 @@ def _build_field_serializer(
model_field: pd.fields.ModelField,
ctx: Context,
) -> 'Serializer':
field_type = model_field.type_
field_info = model_field.field_info
if cls._has_forward_refs(model_field):
return factories.ForwardRefSerializerFactory.build(model, model_field, ctx)

shape_type = PydanticShapeType.from_shape(model_field.shape)
if shape_type is PydanticShapeType.UNKNOWN:
raise TypeError(f"fields of type {model_field.type_} are not supported")

if is_xml_model(field_type):
if is_xml_model(model_field.type_):
is_model_field = True
else:
is_model_field = False

is_union_type = is_union(model_field.outer_type_) and not is_optional(model_field.outer_type_)

field_location = cls._get_field_location(field_info)
field_location = cls._get_field_location(model_field.field_info)

if field_location is Location.WRAPPED:
return factories.WrappedSerializerFactory.build(model, model_field, ctx)
Expand All @@ -188,3 +208,14 @@ def _build_field_serializer(
return factories.HeterogeneousSerializerFactory.build(model, model_field, field_location, ctx)
else:
raise AssertionError("unreachable")

@classmethod
def _has_forward_refs(cls, model_field: pd.fields.ModelField) -> bool:
if isinstance(model_field.type_, ForwardRef):
return True

for field in model_field.sub_fields or []:
if isinstance(field.type_, ForwardRef):
return True

return False
Loading

0 comments on commit b81c46c

Please sign in to comment.