/
fields.py
executable file
·278 lines (212 loc) · 7.75 KB
/
fields.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""
flask_marshmallow.fields
~~~~~~~~~~~~~~~~~~~~~~~~
Custom, Flask-specific fields.
See the `marshmallow.fields` module for the list of all fields available from the
marshmallow library.
"""
import re
import typing
from collections.abc import Sequence
from flask import current_app, url_for
from marshmallow import fields, missing
__all__ = [
    # Canonical field names, each followed by its "Url"-spelled alias.
    "URLFor", "UrlFor",
    "AbsoluteURLFor", "AbsoluteUrlFor",
    # Fields that have no alias.
    "Hyperlinks", "File", "Config",
]
_tpl_pattern = re.compile(r"\s*<\s*(\S*)\s*>\s*")
def _tpl(val: str) -> typing.Optional[str]:
"""Return value within ``< >`` if possible, else return ``None``."""
match = _tpl_pattern.match(val)
if match:
return match.groups()[0]
return None
def _get_value(obj, key, default=missing):
    """Look up ``key`` on ``obj``; a slightly modified ``marshmallow.utils.get_value``.

    A dot-delimited ``key`` (e.g. ``"author.id"``) is resolved one segment at
    a time. If any intermediate value on that path is ``None``, the lookup
    stops and returns ``None`` (this differs from marshmallow). A lookup that
    fails yields ``default``.
    """
    if "." not in key:
        return _get_value_for_key(obj, key, default)
    return _get_value_for_keys(obj, key.split("."), default)
def _get_value_for_keys(obj, keys, default):
    """Walk ``obj`` along the sequence of ``keys`` and return the final value.

    Each step uses ``_get_value_for_key``. If an intermediate value is
    ``None``, ``None`` is returned immediately instead of continuing.
    (The marshmallow implementation differs here.)
    """
    # Read the last key first so that an empty ``keys`` raises IndexError.
    last_key = keys[-1]
    current = obj
    for segment in keys[:-1]:
        current = _get_value_for_key(current, segment, default)
        if current is None:
            return None
    return _get_value_for_key(current, last_key, default)
def _get_value_for_key(obj, key, default):
if not hasattr(obj, "__getitem__"):
return getattr(obj, key, default)
try:
return obj[key]
except (KeyError, IndexError, TypeError, AttributeError):
return getattr(obj, key, default)
class URLFor(fields.Field):
    """Field that outputs the URL for an endpoint. Acts identically to
    Flask's ``url_for`` function, except that arguments can be pulled from the
    object to be serialized, and ``**values`` should be passed to the ``values``
    parameter.

    Usage: ::

        url = URLFor("author_get", values=dict(id="<id>"))
        https_url = URLFor(
            "author_get",
            values=dict(id="<id>", _scheme="https", _external=True),
        )

    If any ``< >`` attribute resolves to ``None`` on the serialized object,
    the field serializes to ``None`` instead of building a URL.

    :param str endpoint: Flask endpoint name.
    :param dict values: Same keyword arguments as Flask's url_for, except string
        arguments enclosed in `< >` will be interpreted as attributes to pull
        from the object (dotted paths such as ``<author.id>`` are allowed).
    :param kwargs: keyword arguments to pass to marshmallow field (e.g. ``required``).
    :raises AttributeError: during serialization, if a ``< >`` attribute
        cannot be found on the object.
    """

    # The output is built from ``values`` and ``obj``; the per-field ``value``
    # is never used, so marshmallow's attribute lookup is disabled.
    _CHECK_ATTRIBUTE = False

    def __init__(
        self,
        endpoint: str,
        values: typing.Optional[typing.Dict[str, typing.Any]] = None,
        **kwargs,
    ):
        self.endpoint = endpoint
        self.values = values or {}
        fields.Field.__init__(self, **kwargs)

    def _serialize(self, value, key, obj, **kwargs):
        """Output the URL for the endpoint, given the kwargs passed to
        ``__init__``.

        ``**kwargs`` is accepted (and ignored) to match marshmallow's
        ``Field._serialize(value, attr, obj, **kwargs)`` signature. Without it,
        a ``serialize()`` call that forwards extra keyword arguments would
        raise TypeError.
        """
        param_values = {}
        for name, attr_tpl in self.values.items():
            attr_name = _tpl(str(attr_tpl))
            if not attr_name:
                # Not a ``< >`` template: pass the value to url_for unchanged.
                param_values[name] = attr_tpl
                continue
            attribute_value = _get_value(obj, attr_name, default=missing)
            if attribute_value is None:
                # A None URL component means there is no URL to build.
                return None
            if attribute_value is missing:
                raise AttributeError(
                    f"{attr_name!r} is not a valid " f"attribute of {obj!r}"
                )
            param_values[name] = attribute_value
        return url_for(self.endpoint, **param_values)


UrlFor = URLFor
class AbsoluteURLFor(URLFor):
    """Field that outputs the absolute URL for an endpoint.

    Identical to ``URLFor``, except that ``_external=True`` is always passed to
    Flask's ``url_for``; it overrides any ``_external`` entry in ``values``.

    :param str endpoint: Flask endpoint name.
    :param dict values: Same as for ``URLFor``. The caller's dict is not
        modified.
    :param kwargs: keyword arguments to pass to marshmallow field.
    """

    def __init__(
        self,
        endpoint: str,
        values: typing.Optional[typing.Dict[str, typing.Any]] = None,
        **kwargs,
    ):
        # Build a new dict rather than writing into the caller's ``values``.
        # Mutating it would leak ``_external=True`` into any other field (for
        # example a plain URLFor) constructed with the same dict.
        values = {**(values or {}), "_external": True}
        URLFor.__init__(self, endpoint=endpoint, values=values, **kwargs)


AbsoluteUrlFor = AbsoluteURLFor
def _rapply(
d: typing.Union[dict, typing.Iterable], func: typing.Callable, *args, **kwargs
):
"""Apply a function to all values in a dictionary or
list of dictionaries, recursively.
"""
if isinstance(d, (tuple, list)):
return [_rapply(each, func, *args, **kwargs) for each in d]
if isinstance(d, dict):
return {key: _rapply(value, func, *args, **kwargs) for key, value in d.items()}
else:
return func(d, *args, **kwargs)
def _url_val(val: typing.Any, key: str, obj: typing.Any, **kwargs):
    """Leaf function that ``Hyperlinks`` applies while walking its schema.

    A ``URLFor`` value is serialized against ``obj`` under the name ``key``.
    Any other value is returned as-is.
    """
    if not isinstance(val, URLFor):
        return val
    return val.serialize(key, obj, **kwargs)
class Hyperlinks(fields.Field):
    """Field that outputs a dictionary of hyperlinks,
    given a dictionary schema with :class:`~flask_marshmallow.fields.URLFor`
    objects as values.

    Example: ::

        _links = Hyperlinks(
            {
                "self": URLFor("author", values=dict(id="<id>")),
                "collection": URLFor("author_list"),
            }
        )

    `URLFor` objects can be nested within the dictionary. ::

        _links = Hyperlinks(
            {
                "self": {
                    "href": URLFor("book", values=dict(id="<id>")),
                    "title": "book detail",
                }
            }
        )

    :param dict schema: A dict that maps names to
        :class:`~flask_marshmallow.fields.URLFor` fields, plain values, or
        nested dicts/lists of those.
    """

    # Output comes entirely from ``schema`` and ``obj``; the per-field
    # ``value`` is never used, so marshmallow's attribute lookup is disabled.
    _CHECK_ATTRIBUTE = False

    def __init__(self, schema: typing.Dict[str, typing.Any], **kwargs):
        self.schema = schema
        fields.Field.__init__(self, **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        """Serialize ``self.schema`` for ``obj``.

        Dicts, lists and tuples in the schema are walked recursively (tuples
        come out as lists). Every ``URLFor`` leaf is replaced by its
        serialized URL; every other leaf is returned unchanged.

        ``**kwargs`` is accepted to match marshmallow's
        ``Field._serialize(value, attr, obj, **kwargs)`` signature. It is not
        forwarded to the nested ``URLFor`` fields.
        """
        return _rapply(self.schema, _url_val, key=attr, obj=obj)
class File(fields.Field):
    """A binary file field for uploaded files.

    Examples: ::

        class ImageSchema(Schema):
            image = File(required=True)

    An empty sequence (e.g. ``""`` or ``[]``) is replaced by marshmallow's
    ``missing`` before normal deserialization runs. Any other value that is
    not a ``werkzeug.datastructures.FileStorage`` fails with the ``"invalid"``
    error.
    """

    default_error_messages = {"invalid": "Not a valid file."}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Metadata used by apispec: document the field as a binary string.
        self.metadata.update(type="string", format="binary")

    def deserialize(
        self,
        value: typing.Any,
        attr: typing.Optional[str] = None,
        data: typing.Optional[typing.Mapping[str, typing.Any]] = None,
        **kwargs,
    ):
        # An empty sequence counts as "no value supplied".
        is_empty_sequence = isinstance(value, Sequence) and len(value) == 0
        return super().deserialize(
            missing if is_empty_sequence else value, attr, data, **kwargs
        )

    def _deserialize(self, value, attr, data, **kwargs):
        from werkzeug.datastructures import FileStorage

        if isinstance(value, FileStorage):
            return value
        raise self.make_error("invalid")
class Config(fields.Field):
    """A field for Flask configuration values.

    Examples: ::

        from flask import Flask

        app = Flask(__name__)
        app.config["API_TITLE"] = "Pet API"


        class FooSchema(Schema):
            user = String()
            title = Config("API_TITLE")

    This field should only be used in an output schema. A ``ValueError`` will
    be raised if the config key is not found in the app config.

    :param str key: The key of the configuration value.
    """

    # The output is read from ``current_app.config``, not from the serialized
    # object, so marshmallow's attribute lookup is disabled.
    _CHECK_ATTRIBUTE = False

    def __init__(self, key: str, **kwargs):
        fields.Field.__init__(self, **kwargs)
        self.key = key

    def _serialize(self, value, attr, obj, **kwargs):
        app_config = current_app.config
        config_key = self.key
        if config_key in app_config:
            return app_config[config_key]
        raise ValueError(f"The key {config_key!r} is not found in the app config.")