-
Notifications
You must be signed in to change notification settings - Fork 14
/
primitive.py
194 lines (156 loc) · 6.26 KB
/
primitive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
from typing import Any, Dict, Optional, Union
from pydantic_core import core_schema as pcs
from pydantic_xml import errors
from pydantic_xml.element import XmlElementReader, XmlElementWriter, is_element_nill, make_element_nill
from pydantic_xml.serializers.serializer import SearchMode, Serializer, encode_primitive
from pydantic_xml.typedefs import EntityLocation, NsMap
from pydantic_xml.utils import QName, merge_nsmaps, select_ns
# Core-schema variants accepted by the primitive serializer factories in this
# module (the ``schema`` argument of each ``from_core_schema``). Member order is
# preserved deliberately; it is visible through ``typing.get_args``.
PrimitiveTypeSchema = Union[
    pcs.NoneSchema,
    pcs.BoolSchema,
    pcs.IntSchema,
    pcs.FloatSchema,
    pcs.StringSchema,
    pcs.BytesSchema,
    pcs.DateSchema,
    pcs.TimeSchema,
    pcs.DatetimeSchema,
    pcs.TimedeltaSchema,
    pcs.UrlSchema,
    pcs.MultiHostUrlSchema,
    pcs.JsonSchema,
    pcs.LiteralSchema,
    pcs.LaxOrStrictSchema,
    pcs.IsInstanceSchema,
]
class TextSerializer(Serializer):
    """Serialize a primitive value as the text content of an XML element.

    Computed fields are written on serialization but always read back as
    ``None``. Nillable fields are marked nil when their value is ``None``.
    """

    @classmethod
    def from_core_schema(cls, schema: PrimitiveTypeSchema, ctx: Serializer.Context) -> 'TextSerializer':
        """Create a text serializer from the field context.

        ``schema`` is part of the factory signature but is not consulted here;
        only ``ctx.field_computed`` and ``ctx.nillable`` are used.
        """
        is_computed = ctx.field_computed
        is_nillable = ctx.nillable
        return cls(is_computed, is_nillable)

    def __init__(self, computed: bool, nillable: bool):
        # computed: skip the field on deserialization.
        # nillable: emit / honour the nil marker for ``None`` values.
        self._computed = computed
        self._nillable = nillable

    def serialize(
            self, element: XmlElementWriter, value: Any, encoded: Any, *, skip_empty: bool = False,
    ) -> Optional[XmlElementWriter]:
        """Write ``encoded`` as the text of ``element`` and return ``element``.

        When ``value`` is ``None`` and ``skip_empty`` is set, ``element`` is
        returned untouched. Otherwise a ``None`` value on a nillable field marks
        the element nil before the (encoded) text is written.
        """
        value_missing = value is None

        if value_missing and skip_empty:
            return element

        if self._nillable and value_missing:
            make_element_nill(element)

        text = encode_primitive(encoded)
        element.set_text(text)
        return element

    def deserialize(
            self,
            element: Optional[XmlElementReader],
            *,
            context: Optional[Dict[str, Any]],
    ) -> Optional[str]:
        """Pop and return the text of ``element``.

        Returns ``None`` for a computed field, a missing element, a nil element
        of a nillable field, or when the popped text is empty/falsy.
        """
        if self._computed or element is None or (self._nillable and is_element_nill(element)):
            return None

        text = element.pop_text()
        # Empty text is reported as "no value" rather than as ''.
        return text if text else None
class AttributeSerializer(Serializer):
    """Serialize a primitive value as an attribute of an XML element.

    Computed fields are written on serialization but always read back as
    ``None``.
    """

    @classmethod
    def from_core_schema(cls, schema: PrimitiveTypeSchema, ctx: Serializer.Context) -> 'AttributeSerializer':
        """Create an attribute serializer from the field context.

        The attribute name is the first truthy of ``ctx.entity_path``,
        ``ctx.field_alias`` and ``ctx.field_name``. The parent namespace is only
        offered to ``select_ns`` when ``ctx.namespaced_attrs`` is true.

        Raises:
            errors.ModelFieldError: if the selected namespace is the empty
                string (default namespace), or if no attribute name is available.
        """
        use_parent_ns = ctx.namespaced_attrs
        attr_name = ctx.entity_path or ctx.field_alias or ctx.field_name
        attr_ns = select_ns(ctx.entity_ns, ctx.parent_ns if use_parent_ns else None)
        attr_nsmap = merge_nsmaps(ctx.entity_nsmap, ctx.parent_nsmap)
        is_computed = ctx.field_computed

        # XML default-namespace declarations never apply to attribute names,
        # which is presumably why an empty-string namespace is rejected here.
        if attr_ns == '':
            raise errors.ModelFieldError(
                ctx.model_name,
                ctx.field_name,
                "attributes with default namespace are forbidden",
            )

        if attr_name is None:
            raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "entity name is not provided")

        return cls(attr_name, attr_ns, attr_nsmap, is_computed)

    def __init__(self, name: str, ns: Optional[str], nsmap: Optional[NsMap], computed: bool):
        qualified = QName.from_alias(tag=name, ns=ns, nsmap=nsmap, is_attr=True)
        self._attr_name = qualified.uri
        self._computed = computed

    @property
    def attr_name(self) -> str:
        """Fully-qualified (URI form) attribute name used for reads and writes."""
        return self._attr_name

    def serialize(
            self, element: XmlElementWriter, value: Any, encoded: Any, *, skip_empty: bool = False,
    ) -> Optional[XmlElementWriter]:
        """Set the attribute on ``element`` from ``encoded`` and return ``element``.

        Nothing is written when ``value`` is ``None`` and ``skip_empty`` is set;
        ``element`` is returned in every case.
        """
        should_write = not (value is None and skip_empty)
        if should_write:
            element.set_attribute(self._attr_name, encode_primitive(encoded))
        return element

    def deserialize(
            self,
            element: Optional[XmlElementReader],
            *,
            context: Optional[Dict[str, Any]],
    ) -> Optional[str]:
        """Pop and return the attribute value from ``element``.

        Returns ``None`` for a computed field or a missing element; otherwise
        whatever ``element.pop_attrib`` returns.
        """
        if self._computed or element is None:
            return None
        return element.pop_attrib(self._attr_name)
class ElementSerializer(TextSerializer):
    """Serialize a primitive value as the text of a dedicated sub-element.

    The sub-element is created/located by qualified name; its text handling is
    delegated to ``TextSerializer``.
    """

    @classmethod
    def from_core_schema(cls, schema: PrimitiveTypeSchema, ctx: Serializer.Context) -> 'ElementSerializer':
        """Create an element serializer from the field context.

        The element name is the first truthy of ``ctx.entity_path``,
        ``ctx.field_alias`` and ``ctx.field_name``.

        Raises:
            errors.ModelFieldError: if no element name is available.
        """
        tag = ctx.entity_path or ctx.field_alias or ctx.field_name
        tag_ns = select_ns(ctx.entity_ns, ctx.parent_ns)
        tag_nsmap = merge_nsmaps(ctx.entity_nsmap, ctx.parent_nsmap)
        mode = ctx.search_mode
        is_computed = ctx.field_computed
        is_nillable = ctx.nillable

        if tag is None:
            raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "entity name is not provided")

        return cls(tag, tag_ns, tag_nsmap, mode, is_computed, is_nillable)

    def __init__(
            self,
            name: str,
            ns: Optional[str],
            nsmap: Optional[NsMap],
            search_mode: SearchMode,
            computed: bool,
            nillable: bool,
    ):
        super().__init__(computed, nillable)
        self._nsmap = nsmap
        self._search_mode = search_mode
        qualified = QName.from_alias(tag=name, ns=ns, nsmap=nsmap)
        self._element_name = qualified.uri

    def serialize(
            self, element: XmlElementWriter, value: Any, encoded: Any, *, skip_empty: bool = False,
    ) -> Optional[XmlElementWriter]:
        """Write the value into a new sub-element of ``element``.

        Returns:
            ``element`` itself when ``value`` is ``None`` and ``skip_empty`` is
            set; ``None`` when ``skip_empty`` is set and the new sub-element
            ended up empty (it is not appended); otherwise the appended
            sub-element.
        """
        if value is None and skip_empty:
            return element

        child = element.make_element(self._element_name, nsmap=self._nsmap)
        super().serialize(child, value, encoded, skip_empty=skip_empty)

        # Under skip_empty an empty child is discarded instead of appended.
        if skip_empty and child.is_empty():
            return None

        element.append_element(child)
        return child

    def deserialize(
            self,
            element: Optional[XmlElementReader],
            *,
            context: Optional[Dict[str, Any]],
    ) -> Optional[str]:
        """Pop the matching sub-element and return its text.

        Returns ``None`` for a computed field, a missing parent element, or
        when no matching sub-element is found (using the configured search
        mode); otherwise defers to ``TextSerializer.deserialize``.
        """
        if self._computed or element is None:
            return None

        child = element.pop_element(self._element_name, self._search_mode)
        if child is None:
            return None

        return super().deserialize(child, context=context)
def from_core_schema(schema: PrimitiveTypeSchema, ctx: Serializer.Context) -> Serializer:
    """Build the primitive serializer matching the field's entity location.

    Location ``EntityLocation.ELEMENT`` yields an ``ElementSerializer``,
    ``EntityLocation.ATTRIBUTE`` an ``AttributeSerializer``, and ``None`` a
    ``TextSerializer``.

    Raises:
        AssertionError: if the location is none of the above.
    """
    location = ctx.entity_location

    # Identity checks are mutually exclusive, so their order is irrelevant.
    if location is None:
        return TextSerializer.from_core_schema(schema, ctx)
    if location is EntityLocation.ATTRIBUTE:
        return AttributeSerializer.from_core_schema(schema, ctx)
    if location is EntityLocation.ELEMENT:
        return ElementSerializer.from_core_schema(schema, ctx)

    raise AssertionError("unreachable")