-
Notifications
You must be signed in to change notification settings - Fork 72
/
base.py
361 lines (297 loc) · 12.6 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
import datetime
import decimal
import uuid
from enum import Enum
from inspect import isclass
import typing
from marshmallow import fields, missing, Schema, validate
from marshmallow.class_registry import get_class
from marshmallow.decorators import post_dump
from marshmallow.utils import _Missing
from marshmallow import INCLUDE, EXCLUDE, RAISE
try:
from marshmallow_union import Union
ALLOW_UNIONS = True
except ImportError:
ALLOW_UNIONS = False
try:
from marshmallow_enum import EnumField, LoadDumpOptions
ALLOW_ENUMS = True
except ImportError:
ALLOW_ENUMS = False
from .exceptions import UnsupportedValueError
from .validation import (
handle_equal,
handle_length,
handle_one_of,
handle_range,
handle_regexp,
)
# Public API of this module.
__all__ = ("JSONSchema",)

# Base JSON schema fragment for each Python type. `_from_python_type` copies
# the fragment's keys (in this order) into the field's schema right after
# its "title".
PY_TO_JSON_TYPES_MAP = {
    dict: {"type": "object"},
    list: {"type": "array"},
    datetime.time: {"type": "string", "format": "time"},
    datetime.timedelta: {
        # TODO explore using 'range'?
        "type": "string"
    },
    datetime.datetime: {"type": "string", "format": "date-time"},
    datetime.date: {"type": "string", "format": "date"},
    uuid.UUID: {"type": "string", "format": "uuid"},
    str: {"type": "string"},
    bytes: {"type": "string"},
    decimal.Decimal: {"type": "number", "format": "decimal"},
    set: {"type": "array"},
    tuple: {"type": "array"},
    float: {"type": "number", "format": "float"},
    int: {"type": "integer"},
    bool: {"type": "boolean"},
    Enum: {"type": "string"},
}

# We use these pairs to get the proper Python type from a marshmallow type.
# `_get_python_type` walks this list in order and returns the Python type of
# the FIRST field class the field is an instance of, so more specific classes
# must come before their base classes. We can't use a mapping because earlier
# Python versions might shuffle dict contents, and then `fields.Number` might
# end up before `fields.Integer`.
MARSHMALLOW_TO_PY_TYPES_PAIRS = [
    # This part of the mapping is carefully selected from marshmallow source
    # code, see marshmallow.BaseSchema.TYPE_MAPPING.
    (fields.UUID, uuid.UUID),
    (fields.String, str),
    (fields.Float, float),
    (fields.Raw, str),
    (fields.Boolean, bool),
    (fields.Integer, int),
    (fields.Time, datetime.time),
    (fields.Date, datetime.date),
    (fields.TimeDelta, datetime.timedelta),
    (fields.DateTime, datetime.datetime),
    (fields.Decimal, decimal.Decimal),
    # These are some mappings that generally make sense for the rest
    # of marshmallow fields.
    (fields.Email, str),
    (fields.Dict, dict),
    (fields.Url, str),
    (fields.List, list),
    (fields.Number, decimal.Decimal),
    (fields.IP, str),
    (fields.IPInterface, str),
    # This one is here just for completeness' sake and to check for
    # unknown marshmallow fields more cleanly.
    (fields.Nested, dict),
]

if ALLOW_ENUMS:
    # We currently only support loading enums from their names, so the
    # possible values always map to strings in the JSON schema.
    MARSHMALLOW_TO_PY_TYPES_PAIRS.append((EnumField, Enum))

# Validator class -> handler. `_get_schema_for_field` calls each handler as
# `handler(schema, field, validator, obj)` and uses its return value as the
# field's new schema.
FIELD_VALIDATORS = {
    validate.Equal: handle_equal,
    validate.Length: handle_length,
    validate.OneOf: handle_one_of,
    validate.Range: handle_range,
    validate.Regexp: handle_regexp,
}
def _resolve_additional_properties(cls) -> bool:
meta = cls.Meta
additional_properties = getattr(meta, "additional_properties", None)
if additional_properties is not None:
if additional_properties in (True, False):
return additional_properties
else:
raise UnsupportedValueError(
"`additional_properties` must be either True or False"
)
unknown = getattr(meta, "unknown", None)
if unknown is None:
return False
elif unknown in (RAISE, EXCLUDE):
return False
elif unknown == INCLUDE:
return True
else:
raise UnsupportedValueError("Unknown value %s for `unknown`" % unknown)
class JSONSchema(Schema):
    """Converts to JSONSchema as defined by http://json-schema.org/.

    Dumping a marshmallow schema instance produces a draft-07 document whose
    root ``$ref`` points at that schema's entry in ``definitions``.
    """

    properties = fields.Method("get_properties")
    type = fields.Constant("object")
    required = fields.Method("get_required")

    def __init__(self, *args, **kwargs) -> None:
        """Set up the internal cache of nested schema definitions, to prevent recursion.

        :param bool nested: if `True`, :meth:`dump` returns the bare object
            schema instead of a root document with ``definitions``; default
            is `False`.
        :param bool props_ordered: if `True`, properties keep the order in
            which the schema class declares them; otherwise they are sorted
            by name. Default is `False`.
            Note: for the marshmallow schema, field ordering must also be
            enabled (via `class Meta`, attribute `ordered`).
        """
        self._nested_schema_classes: typing.Dict[str, typing.Dict[str, typing.Any]] = {}
        self.nested = kwargs.pop("nested", False)
        self.props_ordered = kwargs.pop("props_ordered", False)
        # NOTE(review): `opts` belongs to the class, so this is shared by all
        # instances; it is set before `Schema.__init__`, presumably so that
        # marshmallow picks it up for this instance's output dict class.
        setattr(self.opts, "ordered", self.props_ordered)
        super().__init__(*args, **kwargs)

    @staticmethod
    def _get_property_name(field) -> str:
        """Return the key under which ``field`` is published in the schema.

        An explicit ``name`` in the field metadata wins, then ``data_key``,
        then the name the field is bound to.
        """
        return field.metadata.get("name") or field.data_key or field.name

    @staticmethod
    def _get_field_metadata(field) -> typing.Dict[str, typing.Any]:
        """Return the field metadata merged with the legacy nested form.

        Keys from ``metadata={"metadata": {...}}`` are applied first so that
        top-level keys win. A fresh dict is built so that the field's own
        metadata is never mutated.
        """
        merged = dict(field.metadata.get("metadata", {}))
        merged.update(field.metadata)
        return merged

    def get_properties(self, obj) -> typing.Dict[str, typing.Dict[str, typing.Any]]:
        """Fill out properties field.

        :param obj: the marshmallow schema being described.
        :return: mapping of published property name to its JSON schema.
        """
        properties = self.dict_class()
        items = obj.fields.items()
        if not self.props_ordered:
            items = sorted(items)
        for _field_name, field in items:
            properties[self._get_property_name(field)] = self._get_schema_for_field(
                obj, field
            )
        return properties

    def get_required(self, obj) -> typing.Union[typing.List[str], _Missing]:
        """Fill out required field.

        Uses the same naming rule as :meth:`get_properties`, so every listed
        name is guaranteed to appear in ``properties``.

        :return: sorted required property names, or ``missing`` if none.
        """
        required = [
            self._get_property_name(field)
            for _field_name, field in sorted(obj.fields.items())
            if field.required
        ]
        return required or missing

    def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]:
        """Get schema definition from python type.

        :param obj: the marshmallow schema that owns ``field``.
        :param field: the marshmallow field being described.
        :param pytype: key of ``PY_TO_JSON_TYPES_MAP`` selected for ``field``.
        :return: the field's JSON schema.
        """
        json_schema = {"title": field.attribute or field.name or ""}
        json_schema.update(PY_TO_JSON_TYPES_MAP[pytype])

        if field.dump_only:
            json_schema["readOnly"] = True

        if field.default is not missing and not callable(field.default):
            json_schema["default"] = field.default

        if ALLOW_ENUMS and isinstance(field, EnumField):
            json_schema["enum"] = self._get_enum_values(field)

        if field.allow_none:
            json_schema["type"] = [json_schema["type"], "null"]

        # User metadata may override the generated keys above. "name" is
        # skipped because it is used as the property key instead.
        for md_key, md_val in self._get_field_metadata(field).items():
            if md_key in ("metadata", "name"):
                continue
            json_schema[md_key] = md_val

        if isinstance(field, fields.List):
            json_schema["items"] = self._get_schema_for_field(obj, field.inner)

        if isinstance(field, fields.Dict):
            json_schema["additionalProperties"] = (
                self._get_schema_for_field(obj, field.value_field)
                if field.value_field
                else {}
            )

        return json_schema

    def _get_enum_values(self, field) -> typing.List[str]:
        """Return the enum member names accepted by an ``EnumField``.

        :raises NotImplementedError: if the field loads enums by value.
        """
        assert ALLOW_ENUMS and isinstance(field, EnumField)

        if field.load_by == LoadDumpOptions.value:
            # Python allows enum values to be almost anything, so it's easier
            # to just load from the names of the enums, which have to be strings.
            raise NotImplementedError(
                "Currently do not support JSON schema for enums loaded by value"
            )

        return [value.name for value in field.enum]

    def _from_union_schema(
        self, obj, field
    ) -> typing.Dict[str, typing.List[typing.Any]]:
        """Get a union type schema. Uses anyOf to allow the value to be any of the provided sub fields"""
        assert ALLOW_UNIONS and isinstance(field, Union)
        return {
            "anyOf": [
                self._get_schema_for_field(obj, sub_field)
                for sub_field in field._candidate_fields
            ]
        }

    def _get_python_type(self, field):
        """Get python type based on field subclass.

        :raises UnsupportedValueError: if no known field class matches.
        """
        for map_class, pytype in MARSHMALLOW_TO_PY_TYPES_PAIRS:
            if issubclass(field.__class__, map_class):
                return pytype

        raise UnsupportedValueError("unsupported field type %s" % field)

    def _get_schema_for_field(self, obj, field):
        """Get schema and validators for field.

        A custom ``_jsonschema_type_mapping`` (method or metadata entry)
        replaces the generated schema; known validators are then applied.
        """
        if hasattr(field, "_jsonschema_type_mapping"):
            schema = field._jsonschema_type_mapping()
        elif "_jsonschema_type_mapping" in field.metadata:
            schema = field.metadata["_jsonschema_type_mapping"]
        elif isinstance(field, fields.Nested):
            # Special treatment for nested fields.
            schema = self._from_nested_schema(obj, field)
        elif ALLOW_UNIONS and isinstance(field, Union):
            schema = self._from_union_schema(obj, field)
        else:
            pytype = self._get_python_type(field)
            schema = self._from_python_type(obj, field, pytype)

        # Apply any and all validators that field may have. Custom validators
        # can opt in by naming a known base class.
        for validator in field.validators:
            handler_key = validator.__class__
            if handler_key not in FIELD_VALIDATORS:
                handler_key = getattr(
                    validator, "_jsonschema_base_validator_class", None
                )
            handler = FIELD_VALIDATORS.get(handler_key)
            if handler is not None:
                schema = handler(schema, field, validator, obj)

        return schema

    def _from_nested_schema(self, obj, field):
        """Support nested field.

        Registers the nested schema under ``definitions`` (once per class
        name, and never for the outer schema itself, so recursive schemas
        terminate) and returns a ``$ref`` to it. The ``$ref`` is wrapped in an
        array schema when ``field.many`` is set.
        """
        nested = field.nested
        if isinstance(nested, (str, bytes)):
            nested = get_class(nested)
        elif callable(nested) and not isclass(nested) and not isinstance(nested, Schema):
            # marshmallow also accepts a zero-argument callable returning a
            # schema class or instance; resolve it to the schema.
            nested = nested()

        if isclass(nested) and issubclass(nested, Schema):
            nested_cls = nested
            nested_instance = nested(only=field.only, exclude=field.exclude)
        else:
            nested_cls = nested.__class__
            nested_instance = nested
        name = nested_cls.__name__

        outer_name = obj.__class__.__name__
        # If this is not a schema we've seen, and it's not this schema
        # (checking this for recursive schemas), put it in our schema defs.
        if name not in self._nested_schema_classes and name != outer_name:
            nested_kwargs = {"nested": True}
            if self.props_ordered:
                # Keep property ordering consistent in nested definitions.
                nested_kwargs["props_ordered"] = True
            wrapped_nested = self.__class__(**nested_kwargs)
            wrapped_dumped = wrapped_nested.dump(nested_instance)
            wrapped_dumped["additionalProperties"] = _resolve_additional_properties(
                nested_cls
            )
            self._nested_schema_classes[name] = wrapped_dumped
            self._nested_schema_classes.update(wrapped_nested._nested_schema_classes)

        # The field's schema is just a reference to the definition.
        schema = {"type": "object", "$ref": "#/definitions/{}".format(name)}
        for md_key, md_val in self._get_field_metadata(field).items():
            if md_key in ("metadata", "name"):
                continue
            schema[md_key] = md_val

        has_default = field.default is not missing and not callable(field.default)

        if field.many:
            schema = {
                "type": "array" if field.required else ["array", "null"],
                "items": schema,
            }

        if has_default:
            # For many=True the default is a collection, so it is dumped as
            # one and belongs on the array schema; many=None defers to the
            # nested schema's own `many` setting (the previous behaviour).
            schema["default"] = nested_instance.dump(
                field.default, many=True if field.many else None
            )

        return schema

    def dump(self, obj, **kwargs):
        """Take obj for later use: using class name to namespace definition.

        The definitions cache is reset first. Repeated dumps with the same
        instance therefore neither leak definitions into each other nor
        mutate the ``definitions`` of earlier results.
        """
        self.obj = obj
        self._nested_schema_classes = {}
        return super().dump(obj, **kwargs)

    @post_dump
    def wrap(self, data, **_) -> typing.Dict[str, typing.Any]:
        """Wrap this with the root schema definitions."""
        if self.nested:  # no need to wrap, will be in outer defs
            return data

        cls = self.obj.__class__
        name = cls.__name__

        data["additionalProperties"] = _resolve_additional_properties(cls)
        self._nested_schema_classes[name] = data

        return {
            "$schema": "http://json-schema.org/draft-07/schema#",
            "definitions": self._nested_schema_classes,
            "$ref": "#/definitions/{name}".format(name=name),
        }